# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Operations to stack and unstack tensors."""

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@tf_export("stack")
@dispatch.add_dispatch_support
def stack(values, axis=0, name="stack"):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

  See also `tf.concat`, `tf.tile`, `tf.repeat`.

  Packs the list of tensors in `values` into a tensor with rank one higher than
  each tensor in `values`, by packing them along the `axis` dimension.
  Given a list of length `N` of tensors of shape `(A, B, C)`;

  if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
  if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
  Etc.

  For example:

  >>> x = tf.constant([1, 4])
  >>> y = tf.constant([2, 5])
  >>> z = tf.constant([3, 6])
  >>> tf.stack([x, y, z])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)>
  >>> tf.stack([x, y, z], axis=1)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[1, 2, 3],
         [4, 5, 6]], dtype=int32)>

  This is the opposite of unstack.  The numpy equivalent is `np.stack`

  >>> np.array_equal(np.stack([x, y, z]), tf.stack([x, y, z]))
  True

  Args:
    values: A list of `Tensor` objects with the same shape and type.
    axis: An `int`. The axis to stack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
    name: A name for this operation (optional).

  Returns:
    output: A stacked `Tensor` with the same type as `values`.

  Raises:
    ValueError: If `axis` is out of the range [-(R+1), R+1).
  """
  if axis == 0:
    try:
      # If the input is a constant list, it can be converted to a constant op
      return ops.convert_to_tensor(values, name=name)
    except (TypeError, ValueError, NotImplementedError):
      pass  # Input list contains non-constant tensors

  value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple()  # pylint: disable=protected-access
  if value_shape is not None:
    expanded_num_dims = len(value_shape) + 1
    if axis < -expanded_num_dims or axis >= expanded_num_dims:
      raise ValueError(f"Argument `axis` = {axis} not in range "
                       f"[{-expanded_num_dims}, {expanded_num_dims})")

  return gen_array_ops.pack(values, axis=axis, name=name)


@tf_export("unstack")
@dispatch.add_dispatch_support
def unstack(value, num=None, axis=0, name="unstack"):
  """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.

  Unpacks tensors from `value` by chipping it along the `axis` dimension.

  >>> x = tf.reshape(tf.range(12), (3,4))
  >>>
  >>> p, q, r = tf.unstack(x)
  >>> p.shape.as_list()
  [4]

  >>> i, j, k, l = tf.unstack(x, axis=1)
  >>> i.shape.as_list()
  [3]

  This is the opposite of stack.

  >>> x = tf.stack([i, j, k, l], axis=1)

  More generally if you have a tensor of shape `(A, B, C, D)`:

  >>> A, B, C, D = [2, 3, 4, 5]
  >>> t = tf.random.normal(shape=[A, B, C, D])

  The number of tensor returned is equal to the length of the target `axis`:

  >>> axis = 2
  >>> items = tf.unstack(t, axis=axis)
  >>> len(items) == t.shape[axis]
  True

  The shape of each result tensor is equal to the shape of the input tensor,
  with the target `axis` removed.

  >>> items[0].shape.as_list()  # [A, B, D]
  [2, 3, 5]

  The value of each tensor `items[i]` is equal to the slice of `input` across
  `axis` at index `i`:

  >>> for i in range(len(items)):
  ...   slice = t[:,:,i,:]
  ...   assert tf.reduce_all(slice == items[i])

  #### Python iterable unpacking

  With eager execution you _can_ unstack the 0th axis of a tensor using python's
  iterable unpacking:

  >>> t = tf.constant([1,2,3])
  >>> a,b,c = t

  `unstack` is still necessary because Iterable unpacking doesn't work in
  a `@tf.function`: Symbolic tensors are not iterable.

  You need to use `tf.unstack` here:

  >>> @tf.function
  ... def bad(t):
  ...   a,b,c = t
  ...   return a
  >>>
  >>> bad(t)
  Traceback (most recent call last):
  ...
  OperatorNotAllowedInGraphError: ...

  >>> @tf.function
  ... def good(t):
  ...   a,b,c = tf.unstack(t)
  ...   return a
  >>>
  >>> print(good(t).numpy())
  1

  #### Unknown shapes

  Eager tensors have concrete values, so their shape is always known.
  Inside a `tf.function` the symbolic tensors may have unknown shapes.
  If the length of `axis` is unknown `tf.unstack` will fail because it cannot
  handle an unknown number of tensors:

  >>> @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  ... def bad(t):
  ...   tensors = tf.unstack(t)
  ...   return tensors[0]
  >>>
  >>> bad(tf.constant([1.0, 2.0, 3.0]))
  Traceback (most recent call last):
  ...
  ValueError: Cannot infer argument `num` from shape (None,)

  If you know the `axis` length you can pass it as the `num` argument. But this
  must be a constant value.

  If you actually need a variable number of tensors in a single `tf.function`
  trace, you will need to use exlicit loops and a `tf.TensorArray` instead.

  Args:
    value: A rank `R > 0` `Tensor` to be unstacked.
    num: An `int`. The length of the dimension `axis`. Automatically inferred if
      `None` (the default).
    axis: An `int`. The axis to unstack along. Defaults to the first dimension.
      Negative values wrap around, so the valid range is `[-R, R)`.
    name: A name for the operation (optional).

  Returns:
    The list of `Tensor` objects unstacked from `value`.

  Raises:
    ValueError: If `axis` is out of the range `[-R, R)`.
    ValueError: If `num` is unspecified and cannot be inferred.
    InvalidArgumentError: If `num` does not match the shape of `value`.
  """
  if num is None:
    value = ops.convert_to_tensor(value)
    value_shape = value.get_shape()
    if value_shape.ndims is not None:
      if axis < -value_shape.ndims or axis >= value_shape.ndims:
        raise ValueError(f"Argument `axis` = {axis} not in range "
                         f"[{-value_shape.ndims}, {value_shape.ndims})")
      num = value_shape.dims[axis].value
    if num is None:
      raise ValueError(f"Cannot infer argument `num` from shape {value_shape}")
  return gen_array_ops.unpack(value, num=num, axis=axis, name=name)
