# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""

import contextlib


from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export

# Name of the node attribute used to tag every op that belongs to an XLA
# compiled cluster.
_XLA_COMPILE_ATTR = '_xla_compile_id'

# Maximum number of unsupported ops listed individually in the warning log.
_MAX_WARNING_LINES = 5

# Operations that indicate some error in the user's graph. For example, an XLA
# computation should not have any Placeholder op.
_DENYLISTED_OPS = {
    'Placeholder',
}

# XLA doesn't currently support reading of intermediate tensors, thus some ops
# are not supported.
_UNSUPPORTED_OPS = {
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
}


@tf_export('xla.experimental.compile')
@deprecated(
    None, 'xla.experimental.compile is deprecated. Consider using '
    '`@tf.function(jit_compile=True)`.',
    warn_once=True)
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics: the
  call is wrapped in a `tf.function` and that function is invoked immediately.
  In graph mode the XLA cluster is built directly in the current graph.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, 'inputs' should be a list of n
      `Tensor`s.

      `computation` may return a list of `Tensor`s and `Operation`s.
      `Tensor`s must come before `Operation`s in the returned list.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that can be converted to
      `Tensor`s. Note that passing an N-dimension list of compatible values will
      result in an N-dimension list of scalar `Tensor`s rather than a single
      Rank-N `Tensor`. If you need a different behavior, convert parts of
      `inputs` to `Tensor`s with `tf.convert_to_tensor`.

  Returns:
    List of `Tensor`s corresponding to the `Tensor`s from
      the output of `computation` i.e. the same return value as if
      computation(*inputs) is called directly, with the following exceptions:
      * None output: a NoOp would be returned with a control dependency on
         `computation`.
      * Single value output: a tuple containing the value would be returned.
      * Operation-only outputs: a NoOp would be returned with a control
         dependency on `computation`.
      TODO(b/121383831): Investigate into removing these special cases.

  Raises:
    TypeError: If `inputs` is neither `None` nor a sequence.
    ValueError: If a value returned by `computation` is neither an `Operation`
      nor convertible to a `Tensor`, or if `Operation`s are returned in an
      unsupported position.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
      pass the user provided seed to the XLA compiler. As such, the XLA compiler
      generates a random number and uses it as a seed when compiling the
      operation. This implementation causes a violation of the Tensorflow
      defined semantics in two aspects. First, changing the value of the user
      defined seed doesn't change the numbers generated by the operation.
      Second, when a seed is not specified, running the program multiple times
      will generate the same numbers.
  """
  # Graph mode: build the cluster straight into the default graph.
  if not context.executing_eagerly():
    return _compile_internal(computation, inputs)

  # Eager mode: the cluster is built while tracing a function, which is then
  # executed right away.
  @def_function.function
  def xla_compile_wrapper():
    return _compile_internal(computation, inputs)

  return xla_compile_wrapper()


class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside a
  xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
  a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with Tensorflow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside a xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.

  Besides tagging ops, the context records ops whose type is listed in
  `_UNSUPPORTED_OPS` so they can be reported once the computation is built,
  and it rewires control edges that originate outside of this context.
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    # Attribute values are stored as bytes, so encode the name once up front.
    self._name_as_bytes = compat.as_bytes(name)
    # Ops added to this context whose type is in `_UNSUPPORTED_OPS`.
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning describing the unsupported ops recorded so far.

    At most `_MAX_WARNING_LINES` ops are listed by type and name; when more
    were recorded, a second warning states how many were left out. Nothing is
    logged when no unsupported op was recorded.
    """
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found: \n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op: ops.Operation):
    """Remove any external control dependency on this op.

    A control input is "internal" when this context appears somewhere in the
    chain formed by the input's control flow context and its outer contexts;
    every other control input is "external".

    Args:
      op: the operation whose control inputs are rewritten in place.

    Returns:
      A tuple `(internal_control_inputs, external_control_inputs)`. After the
      call `op` only keeps the internal ones.
    """
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      # Walk outwards through the enclosing contexts looking for `self`.
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    # Drop every control input, then re-attach only the internal ones.
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op: ops.Operation):
    """Create op in XLACompileContext and notifies outer context recursively.

    The op is tagged with the `_xla_compile_id` attribute, marked as neither
    feedable nor fetchable, and has its control and data inputs rewired so that
    they are valid inside this context.

    Args:
      op: the operation being added to this context.

    Raises:
      NotImplementedError: if any input of `op` has a reference dtype.
      ValueError: if `op` already carries the `_xla_compile_id` attribute.
    """
    # pylint: disable=protected-access
    # Denylisted ops are only logged here; the op is still added.
    if op.type in _DENYLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph. ', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    # An op that already has the attribute was claimed by another cluster.
    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested, (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may cause
    # mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      # Route every data input through AddValue and swap in the returned
      # tensor whenever it differs from the original input.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      # NOTE(review): Enter()/Exit() are not wrapped in try/finally, so an
      # exception raised while building the identity ops would skip Exit() --
      # confirm whether that matters for callers.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: the tensor being used inside this context.

    Returns:
      The tensor that should be used in place of `val`: the entry recorded in
      `_external_values` for `val.name` when there is one, otherwise `val`
      itself or whatever the outer context's `AddValue` returned for it.
    """
    # `_values` and `_external_values` are not defined in this class;
    # presumably they are initialized by the base class -- TODO confirm.
    if val.name in self._values:
      # Use the real value if it comes from outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    # Remember the mapping so later lookups of `val` return the same tensor.
    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op: ops.Operation):
    """Adds `op` to this context and forwards it to the outer context.

    NOTE(review): `AddOp` already calls the outer context's `AddInnerOp` as its
    last step, so the outer context appears to be notified twice for the same
    op -- confirm this is intended.

    Args:
      op: the operation to register.
    """
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    """Always `None`; see the comment below for the rationale."""
    # Define the gradient loop state associated with the XLACompileContext to
    # be None as the XLACompileContext does not get nested nor does the
    # grad_state outside the XLACompileContext affect the graph inside so the
    # grad_state should be as if this is the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any.

    Returns `False` when `GetWhileContext()` yields a falsy value.
    """
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False


def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  The computation is built inside an `XLACompileContext` so that every op it
  creates is tagged as belonging to one XLA cluster. Resource variables are
  forced on and summary ops are skipped while `computation` runs; both
  settings are restored afterwards, even when `computation` raises.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each input
      can be a nested structure containing values that are convertible to
      tensors. Note that passing an N-dimension list of compatible values will
      result in a N-dimension list of scalar tensors rather than a single Rank-N
      tensors. If you need different behavior, convert part of inputs to tensors
      with `tf.convert_to_tensor`.

  Returns:
    Same data structure as if computation(*inputs) is called directly with some
    exceptions for correctness. Exceptions include: 1) None output 2) Single
    value output 3) Operation-only outputs

  Raises:
    ValueError: If any element in computation outputs is neither an operations
      or a value that can be converted to tensor.
    ValueError: If computation outputs is non-flat and contains any Operations.
    TypeError: If `inputs` is not a sequence (e.g. a list or tuple).
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections_abc.Sequence):
    raise TypeError('inputs must be a list')

  # Flatten the (possibly nested) inputs and convert every leaf to a Tensor.
  flat_inputs = [ops.convert_to_tensor(x) for x in nest.flatten(inputs)]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  # Named `xla_context` so the imported `context` module is not shadowed.
  xla_context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    xla_context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)
    try:
      with _disable_summary_context():
        outputs = computation(*computation_inputs)
    finally:
      # Restore the variable scope even if `computation` raised; otherwise the
      # scope would stay forced to resource variables for later code.
      vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    xla_context.ExitResult(output_tensors)
  finally:
    xla_context.report_unsupported_operations()
    xla_context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wraps the outputs in identity operators that carries control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned non-flat output structure, pack output tensors
  # back into same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors


def is_flat(outputs):
  """Tells whether `outputs` has no nested structure.

    The following values are treated as flat:
    1) None
    2) A single object
    3) A list or tuple of Tensors/Operations

    Only sequences, dictionaries and classes defined with the attrs library are
    recognized as structures. A single user-defined object is therefore
    reported as flat; if it cannot be converted to a Tensor, the error surfaces
    later on.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    True if `outputs` is flat, False otherwise.
  """

  def _is_mapping_or_attrs(value):
    """True for dict-like values and instances of attrs-defined classes."""
    return (isinstance(value, collections_abc.Mapping) or
            hasattr(value.__class__, '__attrs_attrs__'))

  # A sequence is non-flat as soon as one element is itself a structure.
  if isinstance(outputs, collections_abc.Sequence):
    has_nested_element = any(
        isinstance(item, collections_abc.Sequence) or _is_mapping_or_attrs(item)
        for item in outputs)
    if has_nested_element:
      return False

  # Dicts and attrs instances are structures themselves; anything else is
  # either a single unstructured value or a flat sequence of such values.
  return not _is_mapping_or_attrs(outputs)


def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  `None` and single values are first normalized to a tuple. A NoOp is appended
  so that at least one operation is always returned, every non-Operation value
  is converted to a Tensor, and each resulting Tensor is passed through an
  identity op placed on the Tensor's own device.

  The `outputs` object passed in is never modified.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A tuple `(tensors, operations)`: the list of identity-wrapped Tensors and
    the list of Operations (always ending with the appended NoOp) extracted
    from outputs.

  Raises:
    ValueError: If a value is neither an Operation nor convertible to a
      Tensor, or if a Tensor value appears after an Operation.
  """
  # Following code segment is to preserve legacy behavior. Previously we only
  # supported flat outputs and thus for consistency it was nice to convert even
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections_abc.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that return value of this function always contains
  # at least one op that can trigger XlaLaunch node. A new tuple is built
  # instead of using `+=`, which would extend a caller-owned list in place and
  # leak the NoOp into the object returned by the user's computation.
  outputs = tuple(outputs) + (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separates the returned Operations and Tensors while checking that no
  # Tensor follows an Operation. Scanning in order avoids comparing lists of
  # graph objects with `!=`, which would rely on their `==` semantics.
  output_tensors = []
  output_operations = []
  for o in outputs:
    if isinstance(o, ops.Operation):
      output_operations.append(o)
    elif output_operations:
      raise ValueError(
          'XLA computation function must return zero or more Tensor values '
          'followed by zero or more Operations.')
    else:
      output_tensors.append(o)

  # Wrap each Tensor in an identity op created under the Tensor's own device.
  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Every leaf of the flattened structure is converted to a Tensor and passed
  through an identity op created under that Tensor's device.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A tuple of the identity-wrapped Tensors extracted from outputs and an empty
    list, because Operations are not allowed in non-flat outputs.

  Raises:
    ValueError: If a leaf is an Operation or cannot be converted to a Tensor.
  """
  identity_outputs = []
  for leaf in nest.flatten(outputs):
    # Operations cannot be packed back into a nested structure of Tensors.
    if isinstance(leaf, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % leaf.name)

    try:
      tensor = ops.convert_to_tensor(leaf)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Makes sure even pass-through inputs/outputs are touched in compile
    # context by creating an Identity node inside compile context.
    with ops.device(tensor.device or ''):
      wrapped = array_ops.identity(tensor)
    identity_outputs.append(wrapped)

  return identity_outputs, []


@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(), so while this context is
  active `summary_op_util.skip_summary` is replaced by a function that always
  returns True. The previous function is put back on exit, including when the
  body raises. This is a temporary workaround due to XLA not supporting
  summary ops.

  Yields:
    None.
  """
  previous_skip_summary = summary_op_util.skip_summary

  def _always_skip():
    return True

  summary_op_util.skip_summary = _always_skip
  try:
    yield
  finally:
    summary_op_util.skip_summary = previous_skip_summary


class _CapturedObject(object):
  """A placeholder to capture an object.

  The holder starts out empty (`get()` returns `None`) and accepts exactly one
  non-`None` object; any further `capture` call raises.
  """

  def __init__(self):
    # `None` means "nothing captured yet".
    self._object = None

  def capture(self, o):
    """Stores `o` in the holder.

    Args:
      o: the object to capture.

    Raises:
      RuntimeError: if an object has already been captured.
    """
    # Compare against None instead of relying on truthiness: a captured falsy
    # object (0, '', an empty list, ...) must still block a second capture,
    # and objects that cannot be evaluated as a bool must not break the check.
    if self._object is not None:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'bug.')

    self._object = o

  def get(self):
    """Returns the captured object, or `None` if nothing was captured."""
    return self._object


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validate the number of input arguments to an XLA function.

  The number of supplied arguments is `input_arity` plus, when an infeed queue
  is given, its `number_of_tuple_elements`. That total is compared with the
  positional arguments, defaults and varargs reported for `func`.

  Args:
    func: the Python function that will be called to generate the body of an XLA
      computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if function can be called with the supplied number of
      arguments, or an error string if it cannot.
  """

  def _describe(qualifier, count):
    """Formats e.g. 'exactly 1 argument' or 'at least 2 arguments'."""
    plural = '' if count == 1 else 's'
    return '%s %d argument%s' % (qualifier, count, plural)

  supplied = input_arity
  if infeed_queue is not None:
    supplied += infeed_queue.number_of_tuple_elements

  spec = tf_inspect.getargspec(func)
  max_positional = len(spec.args)
  default_count = 0 if spec.defaults is None else len(spec.defaults)
  min_positional = max_positional - default_count
  accepts_varargs = spec.varargs is not None

  # Too few arguments were supplied to call the function.
  if supplied < min_positional:
    if default_count == 0 and not accepts_varargs:
      return _describe('exactly', max_positional)
    return _describe('at least', min_positional)

  # Too many arguments were supplied and there are no varargs to absorb them.
  if not accepts_varargs and supplied > max_positional:
    if default_count == 0:
      return _describe('exactly', max_positional)
    return _describe('at most', max_positional)

  # Either func takes varargs and the minimum is met, or the supplied count
  # lies within the range func accepts.
  return None
