import copy
import warnings

import torch

from keras.src import tree
from keras.src.export.export_utils import convert_spec_to_tensor
from keras.src.utils.module_utils import tensorflow as tf
from keras.src.utils.module_utils import torch_xla


class TorchExportArchive:
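    """Torch backend implementation of the export archive endpoints.

    Instead of tracking layers and adding endpoints separately, the torch
    backend exports a `torch.nn.Module` to StableHLO via `torch_xla` and
    wraps the resulting program in a `tf.function`, so that it can later be
    written out as a TF SavedModel. See `track_and_add_endpoint`.
    """
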
    def _track_layer(self, layer):
        raise NotImplementedError(
            "`track` is not supported for `Layer`s and `Model`s in the torch "
            "backend. Use `track_and_add_endpoint` instead."
        )

    def add_endpoint(self, name, fn, input_signature, **kwargs):
        raise NotImplementedError(
            "`add_endpoint` is not supported for `Layer`s and `Model`s in the "
            "torch backend. Use `track_and_add_endpoint` instead."
        )

    def track_and_add_endpoint(
        self, name, resource, input_signature, **kwargs
    ):
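        """Export `resource` as a StableHLO-backed `tf.function` endpoint.

        Args:
            name: Name of the endpoint. Unused by this implementation.
            resource: The `torch.nn.Module` to export (e.g. a Keras `Model`
                on the torch backend).
            input_signature: A (possibly nested) structure of tensor specs
                describing the endpoint inputs. `None` dimensions are
                replaced with `1` when building sample inputs for tracing.
            **kwargs: Unused by this implementation.

        Returns:
            A `tf.function` with `input_signature` attached that runs the
            exported StableHLO program via `tfxla.call_module`.
        """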
        # Disable false alarms related to lifting parameters.
        warnings.filterwarnings("ignore", message=".*created when tracing.*")
        warnings.filterwarnings(
            "ignore", message=".*Unable to find the path of the module.*"
        )

        if not isinstance(resource, torch.nn.Module):
            raise TypeError(
                "`resource` must be an instance of `torch.nn.Module`. "
                f"Received: resource={resource} (of type {type(resource)})"
            )

        sample_inputs = tree.map_structure(
            lambda x: convert_spec_to_tensor(x, replace_none_number=1),
            input_signature,
        )
        sample_inputs = tuple(sample_inputs)

        # Ref: torch_xla.tf_saved_model_integration
        # TODO: Utilize `dynamic_shapes`
        exported = torch.export.export(
            resource, sample_inputs, dynamic_shapes=None, strict=False
        )
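        # Lower the exported program to a StableHLO bundle, using the sample
        # inputs as tracing arguments.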
        options = torch_xla.stablehlo.StableHLOExportOptions(
            override_tracing_arguments=sample_inputs
        )
        stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo(
            exported, options
        )
        state_dict_keys = list(stablehlo_model._bundle.state_dict.keys())

        # Remove unused variables, keeping only the lifted tensor constants.
        for k in state_dict_keys:
            if "lifted" not in k:
                stablehlo_model._bundle.state_dict.pop(k)

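        # Copy the bundle and convert its tensors to `tf.Variable`s so that
        # they can be tracked and serialized by the TF SavedModel machinery.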
        bundle = copy.deepcopy(stablehlo_model._bundle)
        bundle.state_dict = {
            k: tf.Variable(v, trainable=False, name=k)
            for k, v in bundle.state_dict.items()
        }
        bundle.additional_constants = [
            tf.Variable(v, trainable=False) for v in bundle.additional_constants
        ]

        # Track variables in `bundle` for `write_out`.
        self._tf_trackable.variables += (
            list(bundle.state_dict.values()) + bundle.additional_constants
        )

        # Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf
        def make_tf_function(func, bundle):
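            # Wrap a single StableHLO function from `bundle` into a Python
            # callable that invokes it through `tfxla.call_module`.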
            from tensorflow.compiler.tf2xla.python import xla as tfxla

            def _get_shape_with_dynamic(signature):
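                # Return the static shape with dynamic dims set to `None`.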
                shape = copy.copy(signature.shape)
                for i in signature.dynamic_dims:
                    shape[i] = None
                return shape

            def _extract_call_parameters(args, meta, bundle):
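                # Build the positional argument list for the StableHLO
                # program: parameters come from the bundle state dict,
                # constants from `additional_constants`, and the remaining
                # inputs from the user-provided `args`.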
                call_args = []
                if meta.input_pytree_spec is not None:
                    args = tree.flatten(args)
                for loc in meta.input_locations:
                    if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER:
                        call_args.append(bundle.state_dict[loc.name])
                    elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT:
                        call_args.append(
                            bundle.additional_constants[loc.position]
                        )
                    else:
                        call_args.append(args[loc.position])
                return call_args

            def inner(*args):
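                # Collect output dtypes/shapes and call into the compiled
                # StableHLO module.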
                Touts = [sig.dtype for sig in func.meta.output_signature]
                Souts = [
                    _get_shape_with_dynamic(sig)
                    for sig in func.meta.output_signature
                ]
                call_args = _extract_call_parameters(args, func.meta, bundle)
                results = tfxla.call_module(
                    tuple(call_args),
                    version=5,
                    Tout=Touts,  # dtype information
                    Sout=Souts,  # shape information
                    function_list=[],
                    module=func.bytecode,
                )
                if len(Souts) == 1:
                    results = results[0]
                return results

            return inner

        decorated_fn = tf.function(
            make_tf_function(
                stablehlo_model._bundle.stablehlo_funcs[0], bundle
            ),
            input_signature=input_signature,
        )
        return decorated_fn
