# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
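
  A rough usage sketch (illustrative only; it assumes eager execution and that
  two devices such as '/cpu:0' and '/cpu:1' are available, e.g. two logical
  CPUs; the device names are placeholders):

    with ops.device('/cpu:0'):
      v0 = resource_variable_ops.ResourceVariable(1.0, name='v0')
    with ops.device('/cpu:1'):
      v1 = resource_variable_ops.ResourceVariable(2.0, name='v1')
    packed = PackedDistributedVariable([v0, v1])
    with ops.device('/cpu:0'):
      packed.value()  # Op-by-op: reads the component on '/cpu:0'.
    packed.on_device('/cpu:1').read_value()  # Reads on '/cpu:1'.
    handle = packed.packed_handle  # Packed handle for use inside a function.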
  """

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    # No exact device match; as a fallback, check whether the requested device
    # is the host device of one of the packed devices.
    for i, d in enumerate(self._devices):
      if device_util.get_host_for_device(d) == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    """Returns the per-device handle in eager mode, else the packed handle."""
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self, no_copy=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op(no_copy)

  def value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    """Returns whether every packed component variable is initialized."""
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
      result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    """Applies `update_fn` to the unpacked (eager) or packed variable."""
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""

  def __init__(self, var, device):
    self._var = var
    self._device = device

  def __getattr__(self, name):
    # Exceptions raised inside the context manager can cause a reference
    # cycle.[1] The cycle involves the current frame, which holds the reference
    # to the outer frame. TensorFlow (e.g. iterators) relies on object
    # finalizers to clean up resources. Such references prevent the resource
    # from being deleted and can cause leaks and errors. One corner case is
    # that iterators are kept alive and the garbage collector happens to run
    # after auto control dependencies; this causes the deletion to lose the
    # control dependencies on operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self) -> ops.Operation:
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
