# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape.  Product of
      the elements must equal the length of each tensor.

  Returns:
    list of `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      thin input tensor, in order. The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """"Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device process one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk.  The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
       local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have same
    size.

  Returns:
    a list of `tf.Tensor` identical sum-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over `tf.Tensor`.
    chunks_by_dev: list of lists of `tf.Tensor`.

  Returns:
    new list of lists of `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2015).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
    power of 2.  TODO(tucker): relax this restriction.  Also, the
    number of elements in each tensor must be divisible by 2^h where h
    is the number of hops in each phase.  This will also be relaxed in
    the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d+span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends a n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-array elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: the binary reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards:  list of `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator. Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduce values.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)
