import collections

from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend.config import backend
from keras.src.ops.operation import Operation


@keras_export("keras.Function")
class Function(Operation):
    """Class that encapsulates a computation graph of Keras operations.

    You can use a `Function` to capture the computation graph linking
    some input tensors to some output tensors, and reapply the same
    computation on new inputs.

    A `Function` is similar to a Functional Model, with the difference
    that it is stateless (it does not track state variables)
    and does not implement the `Layer` API.

    Example:

    ```python
    input_1 = keras.KerasTensor(shape=(None, 2, 3))
    input_2 = keras.KerasTensor(shape=(None, 2, 3))
    x = input_1 + input_2
    output = keras.ops.sigmoid(x)
    fn = keras.Function(inputs=[input_1, input_2], outputs=output)

    input_1_val = np.random.random((4, 2, 3))
    input_2_val = np.random.random((4, 2, 3))
    output_val = fn([input_1_val, input_2_val])
    ```

    Args:
        inputs: `KerasTensor` instance or nested structure of
            `KerasTensor` instances.
        outputs: `KerasTensor` instance or nested structure of
            `KerasTensor` instances. They should be computable
            given only the values of `inputs`.
        name: String. The name of the function.
    """

    def __init__(self, inputs, outputs, name=None):
        super().__init__(name=name)

        if backend() == "tensorflow":
            # Temporary workaround for
            # https://github.com/keras-team/keras/issues/931
            # This stops TensorFlow from wrapping `tf.function` output in a
            # `_DictWrapper` object.
            _self_setattr_tracking = getattr(
                self, "_self_setattr_tracking", True
            )
            self._self_setattr_tracking = False
        # Make structure-preserving copies of the input/output structures.
        self._inputs_struct = tree.map_structure(lambda x: x, inputs)
        self._outputs_struct = tree.map_structure(lambda x: x, outputs)
        self._inputs = tree.flatten(inputs)
        self._outputs = tree.flatten(outputs)
        if not self._inputs:
            raise ValueError(
                "`inputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if not self._outputs:
            raise ValueError(
                "`outputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )

        if backend() == "tensorflow":
            self._self_setattr_tracking = _self_setattr_tracking

        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
            self._inputs, self._outputs
        )
        self._nodes = nodes
        self._nodes_by_depth = nodes_by_depth
        self._operations = operations
        self._operations_by_depth = operations_by_depth
        for x in self._inputs:
            if (
                x._keras_history.operation
                and not x._keras_history.operation._outbound_nodes
            ):
                raise ValueError("`inputs` not connected to `outputs`")

    @property
    def operations(self):
        """Flat list of the Function's operations, from inputs to outputs."""
        return self._operations[:]

    @property
    def inputs(self):
        """Flat list of the symbolic inputs of the Function."""
        return self._inputs

    @property
    def outputs(self):
        """Flat list of the symbolic outputs of the Function."""
        return self._outputs

    def compute_output_spec(self, inputs):
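        """Computes the output `KerasTensor` specs for new symbolic inputs.

        A minimal usage sketch, reusing `fn` and the shapes from the class
        docstring example above:

        ```python
        spec = fn.compute_output_spec(
            [
                keras.KerasTensor(shape=(8, 2, 3)),
                keras.KerasTensor(shape=(8, 2, 3)),
            ]
        )
        # spec.shape == (8, 2, 3)
        ```
        """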
        self._assert_input_compatibility(inputs)
        # Check if input shapes are identical to reference input shapes;
        # if so, take a shortcut.
        shortcut = True
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if x.shape != x_ref.shape:
                shortcut = False
                break
        if shortcut:
            return tree.map_structure(
                lambda x: KerasTensor(shape=x.shape, dtype=x.dtype),
                self._outputs_struct,
            )
        # No luck; take the long road through the graph.
        # Original Keras used a cache to avoid recomputing all this
        # when known input shapes were seen again. Perhaps a good
        # idea to bring that back.
        return self._run_through_graph(
            inputs, operation_fn=lambda op: op.compute_output_spec
        )

    def compute_output_shape(self, input_shape):
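        """Computes output shapes for new input shapes.

        A minimal usage sketch, reusing `fn` and the shapes from the class
        docstring example above:

        ```python
        shape = fn.compute_output_shape([(8, 2, 3), (8, 2, 3)])
        # shape == (8, 2, 3)
        ```
        """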
        # Wrap `input_shape` into the structure of KerasTensor to utilize
        # `compute_output_spec`.
        input_shape_struct = tree.map_shape_structure(
            lambda x: KerasTensor(shape=x), input_shape
        )
        # Ensure that dtype and sparse settings are the same as self._inputs,
        # because we only care about the shape in this function.
        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):
            x._dtype = x_ref.dtype
            x._sparse = x_ref.sparse
        output_spec = self.compute_output_spec(input_shape_struct)
        return tree.map_structure(lambda x: x.shape, output_spec)

    def call(self, inputs):
        """Computes output tensors for new inputs."""
        self._assert_input_compatibility(inputs)
        return self._run_through_graph(inputs, operation_fn=lambda op: op)

    def _run_through_graph(self, inputs, operation_fn, call_fn=None):
        """Executes the graph.

        At each node, the outputs are computed via
        `operation_fn(node.operation)(*args, **kwargs)`, or via
        `call_fn(operation_fn(node.operation), *args, **kwargs)`
        if `call_fn` is provided.
        """
        inputs = tree.flatten(inputs)

        # Dictionary mapping reference tensor ids to computed tensors.
        tensor_dict = {}
        for x, y in zip(self.inputs, inputs):
            tensor_dict[id(x)] = y

        nodes_by_depth = self._nodes_by_depth
        depth_keys = list(nodes_by_depth.keys())
        depth_keys.sort(reverse=True)

        for depth in depth_keys:
            nodes = nodes_by_depth[depth]
            for node in nodes:
                if not node.operation or node.is_input:
                    continue  # Input tensors already exist.

                if any(id(x) not in tensor_dict for x in node.input_tensors):
                    continue  # Node is not computable; skip it.

                args, kwargs = node.arguments.fill_in(tensor_dict)
                op = operation_fn(node.operation)
                if call_fn is not None:
                    outputs = call_fn(op, *args, **kwargs)
                else:
                    outputs = op(*args, **kwargs)

                # Update tensor_dict.
                for x, y in zip(node.outputs, tree.flatten(outputs)):
                    tensor_dict[id(x)] = y

        output_tensors = []
        for x in self.outputs:
            output_tensors.append(tensor_dict[id(x)])

        return tree.pack_sequence_as(self._outputs_struct, output_tensors)

    def _assert_input_compatibility(self, inputs):
        try:
            tree.assert_same_structure(inputs, self._inputs_struct)
        except ValueError:
            raise ValueError(
                "Function was called with an invalid input structure. "
                f"Expected input structure: {self._inputs_struct}\n"
                f"Received input structure: {inputs}"
            )
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if len(x.shape) != len(x_ref.shape):
                raise ValueError(
                    f"{self.__class__.__name__} was passed "
                    f"incompatible inputs. For input '{x_ref.name}', "
                    f"expected shape {x_ref.shape}, but received "
                    f"instead a tensor with shape {x.shape}."
                )
            for dim, ref_dim in zip(x.shape, x_ref.shape):
                if ref_dim is not None and dim is not None:
                    if dim != ref_dim:
                        raise ValueError(
                            f"{self.__class__.__name__} was passed "
                            f"incompatible inputs. For input '{x_ref.name}', "
                            f"expected shape {x_ref.shape}, but received "
                            f"instead a tensor with shape {x.shape}."
                        )


def make_node_key(op, node_index):
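    """Returns a unique string key identifying a node of an operation.

    The key concatenates the `id()` of the operation and the node's index
    in `operation._inbound_nodes`, e.g. `"140031733_ib-0"` (id value
    hypothetical).
    """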
    return str(id(op)) + "_ib-" + str(node_index)


def map_graph(inputs, outputs):
    """Validates a graph's topology and gather its operations and nodes.

    Args:
        inputs: List of input tensors.
        outputs: List of output tensors.

    Returns:
        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
        - nodes: set of string node keys (as built by `make_node_key`).
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
        - operations: list of Operation instances.
        - operations_by_depth: dict mapping ints (depth) to lists of Operation
            instances.
    """
    # "depth" is number of operations between output Node and the Node.
    # Nodes are ordered from inputs -> outputs.
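    # For example, for `output = sigmoid(input_1 + input_2)` (as in the
    # `Function` class docstring), the sigmoid node has depth 0, the add
    # node has depth 1, and the two input nodes have depth 2.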
    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)
    network_nodes = {
        make_node_key(node.operation, node.operation._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }

    nodes_depths = {}  # dict {node: depth value}
    operations_depths = {}  # dict {operation: depth value}

    for node in reversed(nodes_in_decreasing_depth):
        # If the depth is not set, the node has no outbound nodes (depth 0).
        depth = nodes_depths.setdefault(node, 0)

        # Update the depth of the corresponding operation
        previous_depth = operations_depths.get(node.operation, 0)
        # If we've seen this operation before at a higher depth,
        # we should use that depth instead of the node depth.
        # This is necessary for shared operations that have inputs at different
        # depth levels in the graph.
        depth = max(depth, previous_depth)
        operations_depths[node.operation] = depth
        nodes_depths[node] = depth

        # Update the depth of inbound nodes.
        # The "depth" of a node is the max of the depths
        # of all nodes it is connected to + 1.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)

    # Handle inputs that are not connected to outputs.
    # We do not error out here because the inputs may be used to compute losses
    # and metrics.
    for input_t in inputs:
        input_operation = input_t._keras_history[0]
        if input_operation and input_operation not in operations_depths:
            operations_depths[input_operation] = 0
            operation_indices[input_operation] = -1
            nodes_depths[input_operation._inbound_nodes[0]] = 0
            network_nodes.add(make_node_key(input_operation, 0))

    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)

    # Build a dict {depth: list of operations with this depth}
    operations_by_depth = collections.defaultdict(list)
    for operation, depth in operations_depths.items():
        operations_by_depth[depth].append(operation)

    # Get sorted list of operation depths.
    depth_keys = list(operations_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Set self.operations ordered by depth.
    operations = []
    for depth in depth_keys:
        operations_for_depth = operations_by_depth[depth]
        # Network.operations needs to have a deterministic order:
        # here we order them by traversal order.
        operations_for_depth.sort(key=lambda x: operation_indices[x])
        operations.extend(operations_for_depth)

    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = set()
    for x in inputs:
        computable_tensors.add(x)

    operations_with_complete_input = []  # To provide a better error message.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            for x in tree.flatten(node.input_tensors):
                if x not in computable_tensors:
                    operation = node.operation
                    raise ValueError(
                        "Graph disconnected: cannot find parent for "
                        f"tensor {x} at operation '{operation}'. "
                        "The following previous operations were accessed "
                        f"without issue: {operations_with_complete_input}"
                    )
                operations_with_complete_input.append(node.operation.name)

            for x in tree.flatten(node.outputs):
                computable_tensors.add(x)

    # Ensure name uniqueness, which will be crucial for serialization
    # (since serialized nodes refer to operations by their name).
    all_names = [operation.name for operation in operations]
    for name in all_names:
        if all_names.count(name) != 1:
            raise ValueError(
                f'The name "{name}" is used {all_names.count(name)} '
                "times in the model. All operation names should be unique."
            )
    return network_nodes, nodes_by_depth, operations, operations_by_depth


def _build_map(inputs, outputs):
    """Topologically sort nodes in order from inputs to outputs.

    It uses a depth-first search to topologically sort nodes that appear in the
    _keras_history connectivity metadata of `outputs`.

    Args:
        inputs: the input tensors of the graph. Traversal stops at these
            tensors; they are treated as leaves.
        outputs: the output tensors whose _keras_history metadata should be
            walked. This may be an arbitrary nested structure.

    Returns:
        A tuple `(ordered_nodes, operation_to_first_traversal_index)`.
        ordered_nodes: list of nodes appearing in the keras history,
            topologically sorted from original inputs to the `outputs`.
            (If outputs have different sets of ancestors, the inputs to one
            output may appear after a different output).
        operation_to_first_traversal_index:
            A dict mapping an operation to the traversal index in the DFS
            where it is seen. Note: if an operation is shared by several
            nodes, the dict will only store the index corresponding to the
            *first* time the operation is seen.
    """
    finished_nodes = set()
    nodes_in_progress = set()
    nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
    operation_indices = {}  # operation -> index in traversal order.
    for output in tree.flatten(outputs):
        _build_map_helper(
            inputs,
            output,
            finished_nodes,
            nodes_in_progress,
            nodes_in_decreasing_depth,
            operation_indices,
        )
    return nodes_in_decreasing_depth, operation_indices


def _build_map_helper(
    inputs,
    tensor,
    finished_nodes,
    nodes_in_progress,
    nodes_in_decreasing_depth,
    operation_indices,
):
    """Recursive helper for `_build_map`."""
    (
        operation,
        node_index,
        _,
    ) = tensor._keras_history
    if not operation:
        return

    node = operation._inbound_nodes[node_index]

    # Don't repeat work for shared subgraphs
    if node in finished_nodes:
        return

    # Prevent cycles.
    if node in nodes_in_progress:
        raise ValueError(
            f"Tensor {tensor} from operation '{operation.name}' is part of a "
            "cycle."
        )

    # Store the traversal order for operation sorting.
    if operation not in operation_indices:
        operation_indices[operation] = len(operation_indices)

    # Propagate to all previous tensors connected to this node.
    nodes_in_progress.add(node)
    if not node.is_input and tensor not in tree.flatten(inputs):
        for input_tensor in node.input_tensors:
            _build_map_helper(
                inputs,
                input_tensor,
                finished_nodes,
                nodes_in_progress,
                nodes_in_decreasing_depth,
                operation_indices,
            )

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)
