# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cond function for Control Flow Operations."""

from tensorflow.python.eager import context
from tensorflow.python.eager.polymorphic_function import eager_function_run
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.
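
  For example, a Python side effect inside a branch function runs once, when
  the graph is built, rather than when the branch is taken (a minimal sketch
  reusing `x`, `y`, and `z` from above):

  ```python
  def true_fn():
    print("building true_fn")  # Runs once, at graph construction time.
    return tf.add(x, z)
  r = tf.cond(x < y, true_fn, lambda: tf.square(y))
  ```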

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.
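
  For example, a minimal sketch of the default singleton unpacking versus
  `strict=True`:

  ```python
  pred = tf.constant(True)
  # strict=False (the default): the singleton lists are unpacked, so r is a
  # single Tensor rather than a list.
  r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)])
  # strict=True: the singleton list structure is preserved, so r is a list
  # containing one Tensor.
  r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)],
              strict=True)
  ```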

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be executed if `pred` is true.
    false_fn: The callable to be executed if `pred` is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError(
          "cond(): 'true_fn' and 'fn1' may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): 'true_fn' argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError(
          "cond(): 'false_fn' and 'fn2' may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): 'false_fn' argument required")

  if not callable(true_fn):
    raise TypeError("'true_fn' must be callable.")
  if not callable(false_fn):
    raise TypeError("'false_fn' must be callable.")

  if context.executing_eagerly():
    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)

  # Always enable control flow v2 if building a function, regardless of toggle.
  if util.EnableControlFlowV2(ops.get_default_graph()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("'pred' must not be a Python bool.")
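    # switch() returns (output_false, output_true); feeding `pred` as both the
    # data input and the predicate yields one pivot tensor per branch.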
    p_2, p_1 = control_flow_ops.switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = control_flow_ops.CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("'true_fn' must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = control_flow_ops.CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("'false_fn' must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            f"Incompatible return types of 'true_fn' and 'false_fn': {e}")
      except ValueError as e:
        raise ValueError(
            f"Incompatible return values of 'true_fn' and 'false_fn': {e}")

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError(
          "'true_fn' and 'false_fn' must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert (
          isinstance(x, tensor_lib.Tensor)
          and isinstance(y, tensor_lib.Tensor))
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of 'true_fn' and 'false_fn' must have the same type(s). "
            f"Received {x.dtype.name} from 'true_fn' "
            f"and {y.dtype.name} from 'false_fn'.")

    merges = [
        control_flow_ops.merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = nest.map_structure(
        control_flow_ops._convert_flow_to_tensorarray,  # pylint: disable=protected-access
        nest.flatten(orig_res_t, expand_composites=True),
        merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges


@tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  Note: AutoGraph automatically converts Python if-statements into `tf.cond`
  calls inside a `tf.function` when the predicate is a `tf.Tensor`, unless
  `autograph=False` is explicitly passed to `tf.function`. For example, the
  following are equivalent:

  >>> @tf.function
  ... def fun1(x, y):
  ...   if x > 0:  # AutoGraph converts if-statement to tf.cond().
  ...     z = y + 1
  ...   else:
  ...     z = y - 1
  ...   return z
  >>> print(fun1(tf.constant(7), tf.constant(3)).numpy())
  4

  >>> @tf.function
  ... def fun2(x, y):
  ...   pred = x > 0
  ...   true_fn = lambda: y + 1
  ...   false_fn = lambda: y - 1
  ...   return tf.cond(pred, true_fn, false_fn)  # Use tf.cond() explicitly.
  >>> print(fun2(tf.constant(7), tf.constant(3)).numpy())
  4

  For more information, see [tf.function and AutoGraph guide](
  https://www.tensorflow.org/guide/function#autograph_transformations).

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  >>> x, y = tf.constant(2, dtype=tf.int32), tf.constant(4, dtype=tf.int32)
  >>> z = tf.multiply(x, y)
  >>> r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  >>> print(r.numpy())
  10

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
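
  For example, both branches may return the same `(tensor, list)` structure
  (an illustrative sketch):

  >>> r = tf.cond(tf.constant(True),
  ...             lambda: (tf.constant(1), [tf.constant(2)]),
  ...             lambda: (tf.constant(3), [tf.constant(4)]))
  >>> print(r[0].numpy(), r[1][0].numpy())
  1 2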

  Note: It is illegal to "directly" use tensors created inside a cond branch
  outside it, e.g. by storing a reference to a branch tensor in Python state.
  If you need to use a tensor created in a branch function, return it as an
  output of the branch function and use the output from `tf.cond` instead.
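
  For example, rather than stashing a branch tensor in Python state, return it
  through `tf.cond`'s output (an illustrative sketch):

  >>> @tf.function
  ... def f(x):
  ...   # The tensor created inside the branch is returned as an output.
  ...   return tf.cond(x > 0, lambda: x + 1, lambda: x - 1)
  >>> print(f(tf.constant(1)).numpy())
  2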

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be executed if `pred` is true.
    false_fn: The callable to be executed if `pred` is false.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  >>> x = tf.constant(2)
  >>> y = tf.constant(5)
  >>> def f1(): return tf.multiply(x, 7)
  >>> def f2(): return tf.add(y, 3)
  >>> r = tf.cond(tf.less(x, y), f1, f2)
  >>> # r is set to f1().
  >>> # Operations in f2 (e.g., tf.add) are not executed.
  >>> print(r.numpy())
  14

  """
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)


def _UnpackIfSingleton(res):
  if isinstance(res, (list, tuple)) and len(res) == 1:
    return res[0]
  else:
    return res


def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
  """Special cases for `cond` when executing eagerly."""
  pred = ops.convert_to_tensor(pred)
  pred_constant_value = tensor_util.constant_value(pred)
  if pred_constant_value is None:
    # Eager tensors from a parallel device may not have a constant
    # value. Running the cond op itself would work, but we don't have logic to
    # build cond ops without wrapping in a function first.
    if (not isinstance(true_fn, core.PolymorphicFunction)
        or not isinstance(false_fn, core.PolymorphicFunction)):
      raise TypeError("When running tf.cond on a parallel device, 'true_fn' "
                      "and 'false_fn' must be decorated with `tf.function`.")
    functions_run_eagerly = eager_function_run.functions_run_eagerly()
    if functions_run_eagerly:
      # We need to use tf.function to deal with variable creation inside the
      # cond, and skipping it because of run_functions_eagerly would just
      # crash immediately.
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
          "tf.function to work. This primitive will override that setting.")
    eager_function_run.run_functions_eagerly(False)
    try:
      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    finally:
      if functions_run_eagerly is not None:
        eager_function_run.run_functions_eagerly(functions_run_eagerly)
  else:
    # For conditions which are eager tensors with a constant value (most of
    # them), we only call the relevant branch function and execute it eagerly.
    with ops.name_scope(name, "cond", [pred]):
      if pred_constant_value:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result


def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place).  Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
  """
  if (isinstance(a, indexed_slices.IndexedSlices) and
      isinstance(b, indexed_slices.IndexedSlices) and
      a.indices.dtype != b.indices.dtype):
    # pylint: disable=protected-access
    a._indices = math_ops.cast(a.indices, dtypes.int64)
    b._indices = math_ops.cast(b.indices, dtypes.int64)
