import copy
import inspect
import typing
import warnings

from keras.src import backend
from keras.src import ops
from keras.src import tree
from keras.src.backend.common import global_state
from keras.src.layers.core.input_layer import Input
from keras.src.layers.core.input_layer import InputLayer
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer
from keras.src.legacy.saving import saving_utils
from keras.src.legacy.saving import serialization as legacy_serialization
from keras.src.models.model import Model
from keras.src.ops.function import Function
from keras.src.ops.function import _build_map
from keras.src.ops.function import make_node_key
from keras.src.ops.node import KerasHistory
from keras.src.ops.node import Node
from keras.src.ops.operation import Operation
from keras.src.saving import serialization_lib
from keras.src.utils import tracking


class Functional(Function, Model):
    """A `Functional` model is a `Model` defined as a directed graph of layers.

    Three types of `Model` exist: subclassed `Model`, `Functional` model,
    and `Sequential` (a special case of `Functional`).

    A `Functional` model can be instantiated by passing two arguments to
    `__init__()`. The first argument is the `keras.Input` objects
    that represent the inputs to the model.
    The second argument specifies the output tensors that represent
    the outputs of this model. Both arguments can be a nested structure
    of tensors.

    Example:

    ```
    inputs = {'x1': keras.Input(shape=(10,), name='x1'),
              'x2': keras.Input(shape=(1,), name='x2')}
    t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
    outputs = keras.layers.Add()([t, inputs['x2']])
    model = keras.Model(inputs, outputs)
    ```

    A `Functional` model constructed using the Functional API can also
    include raw Keras 3 ops.

    Example:

    ```python
    inputs = keras.Input(shape=(10,))
    x = keras.layers.Dense(1)(inputs)
    outputs = ops.nn.relu(x)
    model = keras.Model(inputs, outputs)
    ```

    A new `Functional` model can also be created by using the
    intermediate tensors. This enables you to quickly extract sub-components
    of the model.

    Example:

    ```python
    inputs = keras.Input(shape=(None, None, 3))
    processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
    conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
    pooling = keras.layers.GlobalAveragePooling2D()(conv)
    feature = keras.layers.Dense(10)(pooling)

    full_model = keras.Model(inputs, feature)
    backbone = keras.Model(processed, conv)
    activations = keras.Model(conv, feature)
    ```

    Note that the `backbone` and `activations` models are not
    created with `keras.Input` objects, but with the tensors
    that are originated from `keras.Input` objects.
    Under the hood, the layers and weights will
    be shared across these models, so that user can train the `full_model`, and
    use `backbone` or `activations` to do feature extraction.
    The inputs and outputs of the model can be nested structures of tensors as
    well, and the created models are standard `Functional` model that support
    all the existing API.

    Args:
        inputs: List of input tensors (must be created via `keras.Input()`
            or originated from `keras.Input()`).
        outputs: List of output tensors.
        name: String, optional. Name of the model.
        trainable: Boolean, optional. If the model's variables should be
            trainable.
    """

    def __new__(cls, *args, **kwargs):
        return typing.cast(cls, super().__new__(cls))

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        if isinstance(inputs, dict):
            for k, v in inputs.items():
                if isinstance(v, backend.KerasTensor) and k != v.name:
                    warnings.warn(
                        "When providing `inputs` as a dict, all keys in the "
                        "dict must match the names of the corresponding "
                        f"tensors. Received key '{k}' mapping to value {v} "
                        f"which has name '{v.name}'. Change the tensor name to "
                        f"'{k}' (via `Input(..., name='{k}')`)"
                    )

        trainable = kwargs.pop("trainable", None)
        flat_inputs = tree.flatten(inputs)
        flat_outputs = tree.flatten(outputs)
        for x in flat_inputs:
            if not isinstance(x, backend.KerasTensor):
                raise ValueError(
                    "All `inputs` values must be KerasTensors. Received: "
                    f"inputs={inputs} including invalid value {x} of "
                    f"type {type(x)}"
                )
        for x in flat_outputs:
            if not isinstance(x, backend.KerasTensor):
                raise ValueError(
                    "All `outputs` values must be KerasTensors. Received: "
                    f"outputs={outputs} including invalid value {x} of "
                    f"type {type(x)}"
                )

        if not all(is_input_keras_tensor(t) for t in flat_inputs):
            inputs, outputs = clone_graph_nodes(inputs, outputs)

        Function.__init__(self, inputs, outputs, name=name)

        if trainable is not None:
            self.trainable = trainable

        self._layers = self.layers
        self.build(None)
        # We will convert directly (to the correct dtype per input).
        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        output_layers = [x._keras_history[0] for x in self.outputs]
        self.output_names = [x.name for x in output_layers]

    def _lock_state(self):
        # Unlike other layers, we allow Functional state to be mutable after
        # build. E.g. to attach a layer to a model that is not part of the
        # functional DAG.
        pass

    def _obj_type(self):
        return "Functional"

    @property
    def layers(self):
        layers = []
        for operation in self._operations:
            if isinstance(operation, Layer):
                layers.append(operation)
        return layers

    @layers.setter
    def layers(self, _):
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def call(self, inputs, training=None, mask=None, **kwargs):
        # Add support for training, masking
        inputs = self._standardize_inputs(inputs)
        if mask is None:
            masks = [None] * len(inputs)
        else:
            masks = tree.flatten(mask)
            for x, mask in zip(inputs, masks):
                if mask is not None:
                    backend.set_keras_mask(x, mask)
        outputs = self._run_through_graph(
            inputs,
            operation_fn=lambda op: operation_fn(
                op, training=training, **kwargs
            ),
        )
        return unpack_singleton(outputs)

    def compute_output_spec(self, inputs, training=None, mask=None):
        # From Function
        return super().compute_output_spec(inputs)

    def compute_output_shape(self, input_shape):
        # From Function
        return super().compute_output_shape(input_shape)

    def build(self, input_shape):
        self.built = True

    @property
    def input_shape(self):
        input_shapes = tree.map_structure(lambda x: x.shape, self.inputs)
        if isinstance(input_shapes, list) and len(input_shapes) == 1:
            return input_shapes[0]
        return input_shapes

    @property
    def output_shape(self):
        output_shapes = tree.map_structure(lambda x: x.shape, self.outputs)
        if isinstance(output_shapes, list) and len(output_shapes) == 1:
            return output_shapes[0]
        return output_shapes

    def _assert_input_compatibility(self, *args):
        return super(Model, self)._assert_input_compatibility(*args)

    def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False):
        try:
            # We first normalize to tuples before performing the check to
            # suppress warnings when encountering mismatched tuples and lists.
            tree.assert_same_structure(
                tree.lists_to_tuples(inputs),
                tree.lists_to_tuples(self._inputs_struct),
            )
        except:
            model_inputs_struct = tree.map_structure(
                lambda x: x.name, self._inputs_struct
            )
            inputs_struct = tree.map_structure(
                lambda x: f"Tensor(shape={x.shape})", inputs
            )
            msg = (
                "The structure of `inputs` doesn't match the expected "
                f"structure.\nExpected: {model_inputs_struct}\n"
                f"Received: inputs={inputs_struct}"
            )
            if raise_exception:
                raise ValueError(msg)
            warnings.warn(msg)

    def _convert_inputs_to_tensors(self, flat_inputs):
        converted = []
        for x, input in zip(flat_inputs, self._inputs):
            if x is None:  # TODO: check if optional
                converted.append(x)
            else:
                converted.append(
                    ops.convert_to_tensor(
                        x, dtype=input.dtype, sparse=input.sparse
                    )
                )
        return converted

    def _adjust_input_rank(self, flat_inputs):
        flat_ref_shapes = [x.shape for x in self._inputs]
        adjusted = []
        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
            if x is None:
                adjusted.append(x)
                continue
            x_rank = len(x.shape)
            ref_rank = len(ref_shape)
            if x_rank == ref_rank:
                adjusted.append(x)
                continue
            if x_rank == ref_rank + 1:
                if x.shape[-1] == 1:
                    adjusted.append(ops.squeeze(x, axis=-1))
                    continue
            if x_rank == ref_rank - 1:
                if ref_shape[-1] == 1:
                    adjusted.append(ops.expand_dims(x, axis=-1))
                    continue
            raise ValueError(
                f"Invalid input shape for input {x}. Expected shape "
                f"{ref_shape}, but input has incompatible shape {x.shape}"
            )
        # Add back metadata.
        for i in range(len(flat_inputs)):
            if hasattr(flat_inputs[i], "_keras_history"):
                adjusted[i]._keras_history = flat_inputs[i]._keras_history
            mask = backend.get_keras_mask(flat_inputs[i])
            if mask is not None:
                backend.set_keras_mask(adjusted[i], mask)
        return adjusted

    def _standardize_inputs(self, inputs):
        raise_exception = False
        if (
            isinstance(self._inputs_struct, list)
            and len(self._inputs_struct) == 1
            and ops.is_tensor(inputs)
        ):
            inputs = [inputs]
        elif isinstance(inputs, dict) and not isinstance(
            self._inputs_struct, dict
        ):
            # This is to avoid warning
            # when we have reconciable dict/list structs
            if hasattr(self._inputs_struct, "__len__") and all(
                isinstance(i, backend.KerasTensor) for i in self._inputs_struct
            ):
                expected_keys = set(i.name for i in self._inputs_struct)
                keys = set(inputs.keys())
                if expected_keys.issubset(keys):
                    inputs = [inputs[i.name] for i in self._inputs_struct]
                else:
                    raise_exception = True
            elif isinstance(self._inputs_struct, backend.KerasTensor):
                if self._inputs_struct.name in inputs:
                    inputs = [inputs[self._inputs_struct.name]]
                else:
                    raise_exception = True
            else:
                raise_exception = True
        if (
            isinstance(self._inputs_struct, dict)
            and not isinstance(inputs, dict)
            and list(self._inputs_struct.keys())
            != sorted(self._inputs_struct.keys())
        ):
            raise_exception = True
        self._maybe_warn_inputs_struct_mismatch(
            inputs, raise_exception=raise_exception
        )

        flat_inputs = tree.flatten(inputs)
        flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
        return self._adjust_input_rank(flat_inputs)

    @property
    def input(self):
        # For backwards compatibility,
        # override `input` to retrieve the used-provided
        # constructor inputs
        return self._inputs_struct

    @property
    def output(self):
        return self._outputs_struct

    def add_loss(self, loss):
        # Symbolic only. TODO
        raise NotImplementedError

    @property
    def input_spec(self):
        if hasattr(self, "_manual_input_spec"):
            return self._manual_input_spec

        def shape_with_no_batch_size(x):
            x = list(x)
            if x:
                x[0] = None
            return tuple(x)

        def make_spec_for_tensor(x, name=None):
            optional = False
            if isinstance(x._keras_history[0], InputLayer):
                if x._keras_history[0].optional:
                    optional = True
            return InputSpec(
                shape=shape_with_no_batch_size(x.shape),
                allow_last_axis_squeeze=True,
                name=x._keras_history[0].name if name is None else name,
                optional=optional,
            )

        if isinstance(self._inputs_struct, dict):
            if all(
                isinstance(x, backend.KerasTensor)
                for x in self._inputs_struct.values()
            ):
                # Case where `_nested_inputs` is a plain dict of Inputs.
                names = sorted(self._inputs_struct.keys())
                return [
                    make_spec_for_tensor(self._inputs_struct[name], name=name)
                    for name in names
                ]
            return None  # Deeply nested dict: skip checks.
        return [make_spec_for_tensor(x) for x in self.inputs]

    @input_spec.setter
    def input_spec(self, value):
        self._manual_input_spec = value

    def get_config(self):
        if not functional_like_constructor(self.__class__):
            # Subclassed networks are not serializable
            # (unless serialization is implemented by
            # the author of the subclassed network).
            return Model.get_config(self)

        config = {
            "name": self.name,
            "trainable": self.trainable,
        }
        # Build a map from a layer unique name (make_node_key)
        # to the index of the nodes that are saved in the config.
        # Only nodes in network_nodes are saved.
        node_reindexing_map = {}
        for operation in self.operations:
            if issubclass(operation.__class__, Functional):
                # Functional models start with a pre-existing node
                # linking their input to output.
                kept_nodes = 1
            else:
                kept_nodes = 0
            for original_node_index, node in enumerate(
                operation._inbound_nodes
            ):
                node_key = make_node_key(operation, original_node_index)
                if node_key in self._nodes:
                    # i.e. we mark it to be saved
                    node_reindexing_map[node_key] = kept_nodes
                    kept_nodes += 1

        # serialize and save the layers in layer_configs
        layer_configs = []
        for operation in self.operations:  # From the earliest layers on.
            filtered_inbound_nodes = []
            for original_node_index, node in enumerate(
                operation._inbound_nodes
            ):
                node_key = make_node_key(operation, original_node_index)
                if node_key in self._nodes:
                    # The node is relevant to the model:
                    # add to filtered_inbound_nodes.
                    node_data = serialize_node(node, own_nodes=self._nodes)
                    if node_data is not None:
                        filtered_inbound_nodes.append(node_data)

            serialize_obj_fn = serialization_lib.serialize_keras_object
            if global_state.get_global_attribute("use_legacy_config", False):
                # Legacy format serialization used for H5 and SavedModel
                serialize_obj_fn = legacy_serialization.serialize_keras_object
            layer_config = serialize_obj_fn(operation)
            layer_config["name"] = operation.name
            layer_config["inbound_nodes"] = filtered_inbound_nodes
            layer_configs.append(layer_config)
        config["layers"] = layer_configs

        # Gather info about inputs and outputs.
        def get_tensor_config(tensor):
            operation = tensor._keras_history[0]
            node_index = tensor._keras_history[1]
            tensor_index = tensor._keras_history[2]
            node_key = make_node_key(operation, node_index)
            assert node_key in self._nodes
            new_node_index = node_reindexing_map[node_key]
            return [operation.name, new_node_index, tensor_index]

        def map_tensors(tensors):
            if isinstance(tensors, backend.KerasTensor):
                return [get_tensor_config(tensors)]
            return tree.map_structure(get_tensor_config, tensors)

        config["input_layers"] = map_tensors(self._inputs_struct)
        config["output_layers"] = map_tensors(self._outputs_struct)
        return copy.deepcopy(config)


def functional_from_config(cls, config, custom_objects=None):
    """Instantiates a Functional model from its config (from `get_config()`).

    Args:
        cls: Class of the model, e.g. a custom subclass of `Model`.
        config: Output of `get_config()` for the original model instance.
        custom_objects: Optional dict of custom objects.

    Returns:
        An instance of `cls`.
    """
    # Layer instances created during
    # the graph reconstruction process
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
        """Add node to layer list

        Arg:
            layer: layer object
            node_data: Node data specifying layer call
        """
        if layer not in unprocessed_nodes:
            unprocessed_nodes[layer] = [node_data]
        else:
            unprocessed_nodes[layer].append(node_data)

    def process_node(layer, node_data):
        """Reconstruct node by linking to inbound layers

        Args:
            layer: Layer to process
            node_data: List of layer configs
        """
        args, kwargs = deserialize_node(node_data, created_layers)
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed.
        layer(*args, **kwargs)

    def process_layer(layer_data):
        """Deserializes a layer and index its inbound nodes.

        Args:
            layer_data: layer config dict.
        """
        layer_name = layer_data["name"]

        # Instantiate layer.
        if "module" not in layer_data:
            # Legacy format deserialization (no "module" key)
            # used for H5 and SavedModel formats
            layer = saving_utils.model_from_config(
                layer_data, custom_objects=custom_objects
            )
        else:
            layer = serialization_lib.deserialize_keras_object(
                layer_data, custom_objects=custom_objects
            )
        if not isinstance(layer, Operation):
            raise ValueError(
                "Unexpected object from deserialization, expected a layer or "
                f"operation, got a {type(layer)}"
            )
        created_layers[layer_name] = layer

        # Gather layer inputs.
        inbound_nodes_data = layer_data["inbound_nodes"]
        for node_data in inbound_nodes_data:
            # We don't process nodes (i.e. make layer calls)
            # on the fly because the inbound node may not yet exist,
            # in case of layer shared at different topological depths
            # (e.g. a model such as A(B(A(B(x)))))
            add_unprocessed_node(layer, node_data)

    # Extract config used to instantiate Functional model from the config. The
    # remaining config will be passed as keyword arguments to the Model
    # constructor.
    functional_config = {}
    for key in ["layers", "input_layers", "output_layers"]:
        functional_config[key] = config.pop(key)
    for key in ["name", "trainable"]:
        if key in config:
            functional_config[key] = config.pop(key)
        else:
            functional_config[key] = None

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in functional_config["layers"]:
        process_layer(layer_data)

    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
        for layer_data in functional_config["layers"]:
            layer = created_layers[layer_data["name"]]

            # Process all nodes in layer, if not yet processed
            if layer in unprocessed_nodes:
                node_data_list = unprocessed_nodes[layer]

                # Process nodes in order
                node_index = 0
                while node_index < len(node_data_list):
                    node_data = node_data_list[node_index]
                    try:
                        process_node(layer, node_data)

                    # If the node does not have all inbound layers
                    # available, stop processing and continue later
                    except IndexError:
                        break

                    node_index += 1

                # If not all nodes processed then store unprocessed nodes
                if node_index < len(node_data_list):
                    unprocessed_nodes[layer] = node_data_list[node_index:]
                # If all nodes processed remove the layer
                else:
                    del unprocessed_nodes[layer]

    # Create list of input and output tensors and return new class
    name = functional_config["name"]
    trainable = functional_config["trainable"]

    def get_tensor(layer_name, node_index, tensor_index):
        assert layer_name in created_layers
        layer = created_layers[layer_name]
        if isinstance(layer, Functional):
            # Functional models start out with a built-in node.
            node_index -= 1
        layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
        return layer_output_tensors[tensor_index]

    def map_tensors(tensors):
        if (
            isinstance(tensors, list)
            and len(tensors) == 3
            and isinstance(tensors[0], str)
        ):
            # Leaf
            return get_tensor(*tensors)
        if isinstance(tensors, dict):
            return {k: map_tensors(v) for k, v in tensors.items()}
        if isinstance(tensors, tuple):
            return tuple([map_tensors(v) for v in tensors])
        return [map_tensors(v) for v in tensors]

    input_tensors = map_tensors(functional_config["input_layers"])
    output_tensors = map_tensors(functional_config["output_layers"])
    if isinstance(output_tensors, list) and len(output_tensors) == 1:
        output_tensors = output_tensors[0]

    return cls(
        inputs=input_tensors,
        outputs=output_tensors,
        name=name,
        trainable=trainable,
        **config,
    )


def operation_fn(operation, **call_context_args):
    """Wraps each op to inject the call-context args."""

    def call(*args, **kwargs):
        # Propagate all registered call-context args
        for name, value in call_context_args.items():
            if (
                name in getattr(operation, "_call_context_args", {})
                and value is not None
            ):
                kwargs[name] = value

        return operation(*args, **kwargs)

    return call


def functional_like_constructor(cls):
    init_args = inspect.getfullargspec(cls.__init__).args[1:]
    functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:]
    if init_args == functional_init_args:
        return True
    return False


def unpack_singleton(x):
    if isinstance(x, (list, tuple)) and len(x) == 1:
        return x[0]
    return x


def serialize_node(node, own_nodes=()):
    if not node.input_tensors:
        # Does not need to be serialized.
        return

    def serialize_keras_tensor(x):
        # Serialize KerasTensor while converting
        # node indices to only include nodes relevant to `own_nodes`.
        if isinstance(x, backend.KerasTensor):
            operation, node_index, tensor_index = x._keras_history
            irrelevant_node_count = 0
            for i, node in enumerate(operation._inbound_nodes[:node_index]):
                node_key = make_node_key(operation, i)
                if node_key not in own_nodes:
                    irrelevant_node_count += 1
            x._keras_history = KerasHistory(
                operation, node_index - irrelevant_node_count, tensor_index
            )
            serialized = serialization_lib.serialize_keras_object(x)
            x._keras_history = KerasHistory(operation, node_index, tensor_index)
            return serialized
        return x

    args = node.arguments.args
    kwargs = node.arguments.kwargs

    args = tree.map_structure(serialize_keras_tensor, args)
    kwargs = tree.map_structure(serialize_keras_tensor, kwargs)
    return {
        "args": serialization_lib.serialize_keras_object(args),
        "kwargs": serialization_lib.serialize_keras_object(kwargs),
    }


def deserialize_node(node_data, created_layers):
    """Return (args, kwargs) for calling the node layer."""
    if not node_data:
        return [], {}

    if isinstance(node_data, list):
        # Legacy case.
        input_tensors = []
        for input_data in node_data:
            inbound_layer_name = input_data[0]
            inbound_node_index = input_data[1]
            inbound_tensor_index = input_data[2]
            if len(input_data) == 3:
                kwargs = {}
            elif len(input_data) == 4:
                kwargs = input_data[3]
            else:
                raise ValueError(
                    "Cannot deserialize the model (invalid config data?)"
                )
            inbound_layer = created_layers[inbound_layer_name]

            # Raise an error if the corresponding layer node
            # has not yet been created
            if len(inbound_layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {inbound_layer}\n"
                    "inbound_layer._inbound_nodes = "
                    f"{inbound_layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
            input_tensors.append(
                inbound_node.output_tensors[inbound_tensor_index]
            )
        return [unpack_singleton(input_tensors)], kwargs

    args = serialization_lib.deserialize_keras_object(node_data["args"])
    kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])

    def convert_revived_tensor(x):
        if isinstance(x, backend.KerasTensor):
            history = x._pre_serialization_keras_history
            if history is None:
                return x
            layer = created_layers.get(history[0], None)
            if layer is None:
                raise ValueError(f"Unknown layer: {history[0]}")
            inbound_node_index = history[1]
            inbound_tensor_index = history[2]
            if len(layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {layer}\n"
                    f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = layer._inbound_nodes[inbound_node_index]
            return inbound_node.output_tensors[inbound_tensor_index]
        return x

    args = tree.map_structure(convert_revived_tensor, args)
    kwargs = tree.map_structure(convert_revived_tensor, kwargs)
    return args, kwargs


def is_input_keras_tensor(x):
    (
        operation,
        node_index,
        _,
    ) = x._keras_history
    node = operation._inbound_nodes[node_index]
    return node.is_input


def clone_single_keras_tensor(x):
    return backend.KerasTensor(
        shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=x.name + "_clone"
    )


def clone_keras_tensors(tensors, kt_id_mapping):
    def swap(x):
        if not isinstance(x, backend.KerasTensor):
            return x
        if id(x) in kt_id_mapping:
            return kt_id_mapping[id(x)]
        new_x = clone_single_keras_tensor(x)
        kt_id_mapping[id(x)] = new_x
        return new_x

    return tree.map_structure(swap, tensors)


def find_nodes_by_inputs_and_outputs(inputs, outputs):
    nodes, _ = _build_map(inputs, outputs)
    return nodes


def clone_graph_nodes(inputs, outputs):
    """Clone the `Node` between the inputs and output tensors.

    This function is used to create a new functional model from any intermediate
    Keras tensors. The clone of the nodes mimic the behavior of reconstructing
    the functional graph network by re-executing all the `__call__()` methods.
    The cloned nodes will be appended to the layers.

    Note that a new `keras.Input` will be created for any items in the
    `inputs`

    Args:
    inputs: A nested structure of `KerasTensor` instances.
    outputs: A nested structure of `KerasTensor` instances.

    Returns:
        A pair of inputs and outputs, with cloned `KerasTensor` instances.
        They can be used to create a new functional model.
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)
    cloned_inputs = []
    cloned_outputs = []
    # We not only need to create copies of Nodes (mimic the calls), also need to
    # clone Keras tensors to avoid the override of _keras_history attached on
    # the Keras tensor. The following dict is used to track any keras tensor we
    # cloned The key is the string ID of the original keras tensor, and value is
    # the cloned Keras tensor instance.
    kt_id_mapping = {}
    op_id_mapping = {}

    for kt_input in tree.flatten(inputs):
        if is_input_keras_tensor(kt_input):
            # For any existing Keras tensor from keras.Input, leave them as is.
            cloned_inputs.append(kt_input)
            kt_id_mapping[id(kt_input)] = kt_input
        else:
            # We need to create a new Keras tensor for any intermediate tensor
            cloned_input = Input(
                batch_shape=kt_input.shape,
                dtype=kt_input.dtype,
                sparse=kt_input.sparse,
                name=kt_input.name + "CLONE",
            )
            cloned_inputs.append(cloned_input)
            kt_id_mapping[id(kt_input)] = cloned_input
            op_id_mapping[id(kt_input._keras_history[0])] = (
                cloned_input._keras_history[0]
            )
    cloned_inputs = tree.pack_sequence_as(inputs, cloned_inputs)

    for kt_output in tree.flatten(outputs):
        cpy = clone_single_keras_tensor(kt_output)
        # We reuse the _keras_history here, which contains the old information.
        cpy._keras_history = kt_output._keras_history
        cloned_outputs.append(cpy)
        kt_id_mapping[id(kt_output)] = cpy
    cloned_outputs = tree.pack_sequence_as(outputs, cloned_outputs)

    for node in nodes_to_clone:
        if id(node.operation) in op_id_mapping:
            operation = op_id_mapping[id(node.operation)]
        else:
            operation = node.operation
        # Clone any Keras tensor to avoid override of _keras_history
        # Or reuse an existing Keras tensor if it has already been cloned.
        output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)
        if not isinstance(operation, InputLayer):
            call_args_copy = clone_keras_tensors(
                node.arguments.args, kt_id_mapping
            )
            call_kwargs_copy = clone_keras_tensors(
                node.arguments.kwargs, kt_id_mapping
            )
        else:
            call_args_copy = ()
            call_kwargs_copy = {}
        # Creating new nodes based on the existing node information.  Node wires
        # itself to inbound and outbound layers.  The Node constructor actually
        # updates this layer's self._inbound_nodes, sets _keras_history on the
        # outputs, and adds itself to the `_outbound_nodes` of the layers that
        # produced the inputs to this layer call.
        Node(
            operation,
            call_args=call_args_copy,
            call_kwargs=call_kwargs_copy,
            outputs=output_copy,
        )
    return cloned_inputs, cloned_outputs
