import math

import numpy as np

from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend.common.backend_utils import canonicalize_axis
from keras.src.backend.common.backend_utils import to_tuple_or_list


def broadcast_shapes(shape1, shape2):
    """Broadcast input shapes to a unified shape.

    Args:
        shape1: A tuple or list of integers.
        shape2: A tuple or list of integers.

    Returns:
        output_shape (list of integers): The broadcast shape. Entries may be
            `None` for dimensions of unknown size.

    Example:
    >>> broadcast_shapes((5, 3), (1, 3))
    [5, 3]
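
    # `None` (unknown) dimensions broadcast with concrete sizes.
    >>> broadcast_shapes((None, 3), (5, 3))
    [5, 3]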
    """
    shape1 = list(shape1)
    shape2 = list(shape2)
    origin_shape1 = shape1
    origin_shape2 = shape2

    if len(shape1) > len(shape2):
        shape2 = [1] * (len(shape1) - len(shape2)) + shape2
    if len(shape1) < len(shape2):
        shape1 = [1] * (len(shape2) - len(shape1)) + shape1
    output_shape = list(shape1)
    for i in range(len(shape1)):
        if shape1[i] == 1:
            output_shape[i] = shape2[i]
        elif shape1[i] is None:
            output_shape[i] = None if shape2[i] == 1 else shape2[i]
        else:
            if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]:
                output_shape[i] = shape1[i]
            else:
                raise ValueError(
                    "Cannot broadcast shape, the failure dim has value "
                    f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. "
                    f"Input shapes are: {origin_shape1} and {origin_shape2}."
                )

    return output_shape


def compute_expand_dims_output_shape(input_shape, axis):
    """Compute the output shape for the `expand_dims` operation.

    Args:
        input_shape: Input shape.
        axis: Int or sequence of ints specifying where to insert new axes.

    Returns:
        Tuple of ints: The output shape after the `expand_dims` operation.
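
    Example:

    # Insert a new leading axis, then a new trailing axis.
    >>> compute_expand_dims_output_shape((2, 3), 0)
    (1, 2, 3)
    >>> compute_expand_dims_output_shape((2, 3), -1)
    (2, 3, 1)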
    """
    input_shape = list(input_shape)
    if axis is None:
        axis = len(input_shape)
    axis = to_tuple_or_list(axis)
    out_ndim = len(axis) + len(input_shape)
    axis = [canonicalize_axis(a, out_ndim) for a in axis]
    shape_iter = iter(input_shape)
    new_shape = [
        1 if ax in axis else next(shape_iter) for ax in range(out_ndim)
    ]
    return tuple(new_shape)


def compute_pooling_output_shape(
    input_shape,
    pool_size,
    strides,
    padding="valid",
    data_format="channels_last",
):
    """Computes the output shape of pooling operations.

    Args:
        input_shape: Input shape. Must be a tuple of integers.
        pool_size: Size of the pooling operation. Must be a tuple of integers.
        strides: Stride of the pooling operation. Must be a tuple of integers
            or `None`. If `None`, it defaults to `pool_size`.
        padding: Padding method. Available methods are `"valid"` or `"same"`.
            Defaults to `"valid"`.
        data_format: String, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. Defaults to `"channels_last"`.

    Returns:
        Tuple of ints: The output shape of the pooling operation.

    Examples:

    # Basic usage with square pooling on a single image
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2))
    (1, 2, 2, 1)

    # Strided pooling on a single image with strides different from pool_size
    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1))
    (1, 3, 3, 1)

    # Pooling on a batch of images
    >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2))
    (32, 2, 2, 3)
    """
    strides = pool_size if strides is None else strides
    input_shape_origin = list(input_shape)
    input_shape = np.array(input_shape)
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
    else:
        spatial_shape = input_shape[2:]
    none_dims = []
    for i in range(len(spatial_shape)):
        if spatial_shape[i] is None:
            # Set `None` shape to a manual value so that we can run numpy
            # computation on `spatial_shape`.
            spatial_shape[i] = -1
            none_dims.append(i)
    pool_size = np.array(pool_size)
    if padding == "valid":
        output_spatial_shape = (
            np.floor((spatial_shape - pool_size) / strides) + 1
        )
        for i in range(len(output_spatial_shape)):
            if i not in none_dims and output_spatial_shape[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received: "
                    f"`inputs.shape={input_shape_origin}` and "
                    f"`pool_size={pool_size}`."
                )
    elif padding == "same":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "Argument `padding` must be either 'valid' or 'same'. Received: "
            f"padding={padding}"
        )
    output_spatial_shape = [int(i) for i in output_spatial_shape]
    for i in none_dims:
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        output_shape = (
            (input_shape_origin[0],)
            + output_spatial_shape
            + (input_shape_origin[-1],)
        )
    else:
        output_shape = (
            input_shape_origin[0],
            input_shape_origin[1],
        ) + output_spatial_shape
    return output_shape


def compute_conv_output_shape(
    input_shape,
    filters,
    kernel_size,
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
):
    """Compute the output shape of conv ops."""
    if data_format == "channels_last":
        spatial_shape = input_shape[1:-1]
        kernel_shape = kernel_size + (input_shape[-1], filters)
    else:
        spatial_shape = input_shape[2:]
        kernel_shape = kernel_size + (input_shape[1], filters)
    if len(kernel_shape) != len(input_shape):
        raise ValueError(
            "Kernel shape must have the same length as input, but received "
            f"kernel of shape {kernel_shape} and "
            f"input of shape {input_shape}."
        )
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * len(spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * len(spatial_shape)
    if len(dilation_rate) != len(spatial_shape):
        raise ValueError(
            "`dilation_rate` must be an int, or a tuple/list with length "
            "equal to the number of spatial dimensions of the input, but "
            f"received `dilation_rate={dilation_rate}` and "
            f"input of shape {input_shape}."
        )
    none_dims = []
    spatial_shape = np.array(spatial_shape)
    for i in range(len(spatial_shape)):
        if spatial_shape[i] is None:
            # Set `None` shape to a manual value so that we can run numpy
            # computation on `spatial_shape`.
            spatial_shape[i] = -1
            none_dims.append(i)

    kernel_spatial_shape = np.array(kernel_shape[:-2])
    dilation_rate = np.array(dilation_rate)
    if padding == "valid":
        output_spatial_shape = (
            np.floor(
                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)
                / strides
            )
            + 1
        )
        for i in range(len(output_spatial_shape)):
            if i not in none_dims and output_spatial_shape[i] < 0:
                raise ValueError(
                    "Computed output size would be negative. Received "
                    f"`inputs shape={input_shape}`, "
                    f"`kernel shape={kernel_shape}`, "
                    f"`dilation_rate={dilation_rate}`."
                )
    elif padding == "same" or padding == "causal":
        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1
    else:
        raise ValueError(
            "`padding` must be either `'valid'`, `'same'` or `'causal'`. "
            f"Received: padding={padding}"
        )
    output_spatial_shape = [int(i) for i in output_spatial_shape]
    for i in none_dims:
        output_spatial_shape[i] = None
    output_spatial_shape = tuple(output_spatial_shape)
    if data_format == "channels_last":
        output_shape = (
            (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)
        )
    else:
        output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape
    return output_shape


def compute_matmul_output_shape(shape1, shape2):
    """Compute the output shape of a `matmul` operation.

    Args:
        shape1: Shape of the left operand.
        shape2: Shape of the right operand.

    Returns:
        Tuple of ints: The output shape for the `matmul` operation.
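
    Example:

    # Leading dimensions broadcast; the inner dimensions contract.
    >>> compute_matmul_output_shape((2, 3, 4), (4, 5))
    (2, 3, 5)

    # Two 1-D operands contract to a scalar shape.
    >>> compute_matmul_output_shape((3,), (3,))
    ()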
    """
    shape1_is_1d = len(shape1) == 1
    shape2_is_1d = len(shape2) == 1
    # Promote 1-D operands to 2-D for the shape computation; the added
    # dimension is squeezed out again at the end.
    if shape1_is_1d:
        shape1 = (1, shape1[0])
    if shape2_is_1d:
        shape2 = (shape2[0], 1)
    if (
        shape1[-1] is not None
        and shape2[-2] is not None
        and shape1[-1] != shape2[-2]
    ):
        raise ValueError(
            "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be "
            f"equal, but received `x1.shape={shape1}` and "
            f"`x2.shape={shape2}`."
        )

    leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2])
    last_2_dims_shape = [shape1[-2], shape2[-1]]
    output_shape = leading_shape + last_2_dims_shape
    if shape1_is_1d:
        del output_shape[-2]
    if shape2_is_1d:
        del output_shape[-1]
    return tuple(output_shape)


def compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):
    """Converts `-1` in `newshape` to either an actual dimension or `None`.

    This utility does not special-case the 0th dimension (batch size).
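
    Example:

    # The `-1` dimension is inferred from the remaining entries.
    >>> compute_reshape_output_shape((2, 3, 4), (6, -1), "newshape")
    (6, 4)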
    """
    unknown_dim_count = newshape.count(-1)
    if unknown_dim_count > 1:
        raise ValueError(
            "There must be at most one unknown dimension (-1) in "
            f"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}."
        )

    # If there is a None in input_shape, we can't infer what the -1 is
    if None in input_shape:
        return tuple(dim if dim != -1 else None for dim in newshape)

    input_size = math.prod(input_shape)
    # If the `newshape` is fully defined, return it
    if unknown_dim_count == 0:
        if input_size != math.prod(newshape):
            raise ValueError(
                "The total size of the tensor must be unchanged. Received: "
                f"input_shape={input_shape}, {newshape_arg_name}={newshape}"
            )
        return newshape

    # We have one -1 in `newshape`, compute the actual value
    known_output_size = 1
    unknown_dim_index = None
    for index, dim in enumerate(newshape):
        if dim == -1:
            unknown_dim_index = index
        else:
            known_output_size *= dim

    if known_output_size == 0 or input_size % known_output_size != 0:
        raise ValueError(
            "The total size of the tensor must be unchanged, however, the "
            "input size cannot by divided by the specified dimensions in "
            f"{newshape_arg_name}. Received: input_shape={input_shape}, "
            f"{newshape_arg_name}={newshape}"
        )

    output_shape = list(newshape)
    output_shape[unknown_dim_index] = input_size // known_output_size
    return tuple(output_shape)


def compute_transpose_output_shape(input_shape, axes):
    """Compute the output shape for the `transpose` operation.

    Args:
        input_shape: Input shape.
        axes: Permutation of the dimensions for the `transpose` operation.

    Returns:
        Tuple of ints: The output shape after the `transpose` operation.
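
    Example:

    # Swap the first two axes; `axes=None` reverses all axes.
    >>> compute_transpose_output_shape((2, 3, 4), (1, 0, 2))
    (3, 2, 4)
    >>> compute_transpose_output_shape((2, 3, 4), None)
    (4, 3, 2)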
    """
    input_shape = list(input_shape)
    if axes is None:
        return tuple(input_shape[::-1])

    if len(axes) != len(input_shape):
        raise ValueError(
            "axis must be a list of the same length as the input shape, "
            f"expected {len(input_shape)}, but received {len(axes)}."
        )
    return tuple(input_shape[ax] for ax in axes)


def compute_take_along_axis_output_shape(input_shape, indices_shape, axis):
    """Compute the output shape of a `take_along_axis` operation.

    Args:
        input_shape: Shape of the input tensor.
        indices_shape: Shape of the indices tensor.
        axis: Axis along which to take values. If `None`, the input is
            flattened first.

    Returns:
        List of ints: The broadcast output shape. Entries may be `None` for
            unknown dimensions.
    """
    input_shape = list(input_shape)
    indices_shape = list(indices_shape)
    if axis is None:
        input_shape = (
            [None] if None in input_shape else [int(np.prod(input_shape))]
        )
        # The input is flattened, so the take happens along the single
        # remaining axis.
        axis = 0

    if len(input_shape) != len(indices_shape):
        raise ValueError(
            "`x` and `indices` must have the same number of dimensions, "
            f"but receive shape {input_shape} and {indices_shape}."
        )

    input_shape[axis] = indices_shape[axis]
    output_shape = broadcast_shapes(input_shape, indices_shape)
    return output_shape


def reduce_shape(shape, axis=None, keepdims=False):
    """Compute the output shape of a reduction operation.

    Args:
        shape: Shape of the input tensor.
        axis: Sequence of ints specifying the axes to reduce over. If
            `None`, all axes are reduced.
        keepdims: Whether reduced axes are kept as size-1 dimensions.

    Returns:
        Tuple of ints: The output shape after reduction.
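
    Example:

    # Reduce over axis 1, with and without `keepdims`.
    >>> reduce_shape((2, 3, 4), axis=(1,))
    (2, 4)
    >>> reduce_shape((2, 3, 4), axis=(1,), keepdims=True)
    (2, 1, 4)
    """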
    shape = list(shape)
    if axis is None:
        if keepdims:
            return (1,) * len(shape)
        else:
            return ()

    if keepdims:
        for ax in axis:
            shape[ax] = 1
        return tuple(shape)
    else:
        for ax in sorted(axis, reverse=True):
            del shape[ax]
        return tuple(shape)


@keras_export("keras.utils.get_source_inputs")
def get_source_inputs(tensor):
    """Returns the list of input tensors necessary to compute `tensor`.

    Output will always be a list of tensors
    (potentially with 1 element).

    Args:
        tensor: The tensor to start from.

    Returns:
        List of input tensors.
    """
    if not hasattr(tensor, "_keras_history"):
        return tensor

    operation, node_index, _ = tensor._keras_history
    if not operation or not operation._inbound_nodes:
        return [tensor]
    else:
        node = operation._inbound_nodes[node_index]
        if node.is_input:
            # Reached input node, stop recursion.
            return tree.flatten(node.output_tensors)
        else:
            source_tensors = []
            for input_tensor in node.input_tensors:
                previous_sources = get_source_inputs(input_tensor)
                # Avoid input redundancy.
                for x in previous_sources:
                    if all(x is not t for t in source_tensors):
                        source_tensors.append(x)
            return source_tensors
