"""JSON utilities for legacy saving formats (h5 and SavedModel)"""

import collections
import enum
import functools
import json

import numpy as np

from keras.src.legacy.saving import serialization
from keras.src.saving import serialization_lib
from keras.src.utils.module_utils import tensorflow as tf

_EXTENSION_TYPE_SPEC = "_EXTENSION_TYPE_SPEC"


class Encoder(json.JSONEncoder):
    """JSON encoder and decoder that handles TensorShapes and tuples."""

    def default(self, obj):
        """Encodes objects for types that aren't handled by the default
        encoder."""
        if tf.available and isinstance(obj, tf.TensorShape):
            # An unknown-rank TensorShape has no list form; encode as None.
            return {
                "class_name": "TensorShape",
                "items": None if obj.rank is None else obj.as_list(),
            }
        # Fall through to the generic type-by-type serializer.
        return get_json_type(obj)

    def encode(self, obj):
        """Encodes `obj`, first wrapping tuples so they survive the
        JSON round trip (plain JSON would turn them into lists)."""
        wrapped = _encode_tuple(obj)
        return super().encode(wrapped)


def _encode_tuple(x):
    if isinstance(x, tuple):
        return {
            "class_name": "__tuple__",
            "items": tuple(_encode_tuple(i) for i in x),
        }
    elif isinstance(x, list):
        return [_encode_tuple(i) for i in x]
    elif isinstance(x, dict):
        return {key: _encode_tuple(value) for key, value in x.items()}
    else:
        return x


def decode(json_string):
    """Decodes a JSON string, restoring tuples, TensorShapes, and other
    wrapped objects via `_decode_helper`."""
    decoded = json.loads(json_string, object_hook=_decode_helper)
    return decoded


def decode_and_deserialize(
    json_string, module_objects=None, custom_objects=None
):
    """Decodes the JSON and deserializes any Keras objects found in the dict."""
    # Bind the lookup tables into the object hook so every decoded dict
    # is run through Keras-aware deserialization.
    hook = functools.partial(
        _decode_helper,
        deserialize=True,
        module_objects=module_objects,
        custom_objects=custom_objects,
    )
    return json.loads(json_string, object_hook=hook)


def _decode_helper(
    obj, deserialize=False, module_objects=None, custom_objects=None
):
    """A decoding helper that is TF-object aware.

    Used as a `json.loads` object hook: every decoded dict is passed
    through here, and dicts carrying a recognized `class_name` marker are
    replaced with the corresponding Python/TF object.

    Args:
      obj: A decoded dictionary that may represent an object.
      deserialize: Boolean. When True, deserializes any Keras
        objects found in `obj`. Defaults to `False`.
      module_objects: A dictionary of built-in objects to look the name up in.
        Generally, `module_objects` is provided by midlevel library
        implementers.
      custom_objects: A dictionary of custom objects to look the name up in.
        Generally, `custom_objects` is provided by the end user.

    Returns:
      The decoded object.
    """
    if isinstance(obj, dict) and "class_name" in obj:
        # TF-specific markers are only meaningful when TF is importable.
        if tf.available:
            if obj["class_name"] == "TensorShape":
                # "items" is a list of dims, or None for unknown rank.
                return tf.TensorShape(obj["items"])
            elif obj["class_name"] == "TypeSpec":
                from tensorflow.python.framework import type_spec_registry

                # Look up the registered TypeSpec class by name, then
                # rebuild it from its serialized state (which may itself
                # contain nested markers, hence the recursive call).
                return type_spec_registry.lookup(obj["type_spec"])._deserialize(
                    _decode_helper(obj["serialized"])
                )
            elif obj["class_name"] == "CompositeTensor":
                spec = obj["spec"]
                tensors = []
                # Each entry is a (dtype_name, nested_list_of_values) pair
                # produced by `get_json_type`.
                for dtype, tensor in obj["tensors"]:
                    tensors.append(
                        tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))
                    )
                # Reassemble the composite from its spec and flat tensors.
                return tf.nest.pack_sequence_as(
                    _decode_helper(spec), tensors, expand_composites=True
                )

        if obj["class_name"] == "__tuple__":
            # NOTE: the recursive call intentionally drops the
            # deserialize/module_objects/custom_objects kwargs; nested
            # Keras objects inside a tuple are handled by the outer hook.
            return tuple(_decode_helper(i) for i in obj["items"])
        elif obj["class_name"] == "__ellipsis__":
            return Ellipsis
        elif deserialize and "__passive_serialization__" in obj:
            # __passive_serialization__ is added by the JSON encoder when
            # encoding an object that has a `get_config()` method.
            try:
                if (
                    "module" not in obj
                ):  # TODO(nkovela): Add TF SavedModel scope
                    # No "module" key: legacy serialization format.
                    return serialization.deserialize_keras_object(
                        obj,
                        module_objects=module_objects,
                        custom_objects=custom_objects,
                    )
                else:
                    # "module" key present: new-style serialization format.
                    return serialization_lib.deserialize_keras_object(
                        obj,
                        module_objects=module_objects,
                        custom_objects=custom_objects,
                    )
            except ValueError:
                # Best-effort: if deserialization fails, fall through and
                # return the raw dict unchanged.
                pass
        elif obj["class_name"] == "__bytes__":
            return obj["value"].encode("utf-8")
    return obj


def get_json_type(obj):
    """Serializes any object to a JSON-serializable structure.

    Checks are ordered from most to least specific; the order is
    load-bearing (e.g. Keras objects with `get_config` must be caught
    before the generic `callable` check).

    Args:
        obj: the object to serialize

    Returns:
        JSON-serializable structure representing `obj`.

    Raises:
        TypeError: if `obj` cannot be serialized.
    """
    # if obj is a serializable Keras class instance
    # e.g. optimizer, layer
    if hasattr(obj, "get_config"):
        # TODO(nkovela): Replace with legacy serialization
        serialized = serialization.serialize_keras_object(obj)
        # Flag consumed by `_decode_helper` to trigger deserialization.
        serialized["__passive_serialization__"] = True
        return serialized

    # if obj is any numpy type
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            # numpy scalar -> native Python scalar.
            return obj.item()

    # misc functions (e.g. loss function)
    if callable(obj):
        return obj.__name__

    # if obj is a python 'type'
    if type(obj).__name__ == type.__name__:
        return obj.__name__

    # TF v1 Dimension: unwrap to its integer value (or None if unknown).
    if tf.available and isinstance(obj, tf.compat.v1.Dimension):
        return obj.value

    if tf.available and isinstance(obj, tf.TensorShape):
        return obj.as_list()

    if tf.available and isinstance(obj, tf.DType):
        return obj.name

    if isinstance(obj, collections.abc.Mapping):
        return dict(obj)

    if obj is Ellipsis:
        return {"class_name": "__ellipsis__"}

    # if isinstance(obj, wrapt.ObjectProxy):
    #     return obj.__wrapped__

    if tf.available and isinstance(obj, tf.TypeSpec):
        from tensorflow.python.framework import type_spec_registry

        try:
            # Only registered TypeSpecs can round-trip: the registry name
            # is what `_decode_helper` uses to look the class back up.
            type_spec_name = type_spec_registry.get_name(type(obj))
            return {
                "class_name": "TypeSpec",
                "type_spec": type_spec_name,
                "serialized": obj._serialize(),
            }
        except ValueError:
            raise ValueError(
                f"Unable to serialize {obj} to JSON, because the TypeSpec "
                f"class {type(obj)} has not been registered."
            )
    if tf.available and isinstance(obj, tf.__internal__.CompositeTensor):
        # Decompose into a spec plus the flat component tensors; each
        # tensor is stored as (dtype_name, nested Python lists) so
        # `_decode_helper` can rebuild it with `tf.constant`.
        spec = tf.type_spec_from_value(obj)
        tensors = []
        for tensor in tf.nest.flatten(obj, expand_composites=True):
            tensors.append((tensor.dtype.name, tensor.numpy().tolist()))
        return {
            "class_name": "CompositeTensor",
            "spec": get_json_type(spec),
            "tensors": tensors,
        }

    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, bytes):
        # Assumes the bytes are valid UTF-8; round-tripped by
        # `_decode_helper`'s "__bytes__" branch.
        return {"class_name": "__bytes__", "value": obj.decode("utf-8")}

    raise TypeError(
        f"Unable to serialize {obj} to JSON. Unrecognized type {type(obj)}."
    )
