import gc
import threading

from keras.src import backend
from keras.src.api_export import keras_export

# Module-level, per-thread containers. `GLOBAL_STATE_TRACKER` backs
# `set_global_attribute` / `get_global_attribute`; `GLOBAL_SETTINGS_TRACKER`
# is not read in this section of the module (presumably used by settings
# helpers elsewhere — verify). `clear_session()` rebinds both to fresh
# `threading.local()` objects.
GLOBAL_STATE_TRACKER, GLOBAL_SETTINGS_TRACKER = (
    threading.local(),
    threading.local(),
)


def set_global_attribute(name, value):
    """Store `value` as attribute `name` on the global state tracker.

    The tracker is a `threading.local`, so the stored value is only visible
    to the thread that set it.

    Args:
        name: Attribute name to write.
        value: Value to store (any object, including `None`).
    """
    tracker = GLOBAL_STATE_TRACKER
    setattr(tracker, name, value)


def get_global_attribute(name, default=None, set_to_default=False):
    """Look up `name` on the calling thread's global state tracker.

    A missing attribute and an attribute explicitly stored as `None` are
    treated identically: both fall back to `default`. Falsy but non-`None`
    stored values (e.g. `0`, `False`, `""`) are returned as-is.

    Args:
        name: Attribute name to read.
        default: Fallback used when the stored value is missing or `None`.
            A `default` of `None` means "no fallback" (`None` is returned).
        set_to_default: If `True` and the fallback is used, also store
            `default` on the tracker so later lookups find it directly.

    Returns:
        The stored value, or `default` when the stored value is missing or
        `None`.
    """
    found = getattr(GLOBAL_STATE_TRACKER, name, None)
    if found is not None or default is None:
        # Either a real value is stored, or there is no fallback to apply.
        return found
    if set_to_default:
        set_global_attribute(name, default)
    return default


def _reset_tensorflow_session():
    """Reset TF's default graph and, when eager, its kernel cache."""
    from keras.src.utils.module_utils import tensorflow as tf

    tf.compat.v1.reset_default_graph()
    if not tf.executing_eagerly():
        return
    # Clear pending nodes in eager executors, kernel caches and
    # step_containers.
    from tensorflow.python.eager import context

    context.context().clear_kernel_cache()


def _reset_torch_session():
    """Reset TorchDynamo's caches (guards, compiled functions, ...)."""
    import torch._dynamo as dynamo

    dynamo.reset()


@keras_export(["keras.utils.clear_session", "keras.backend.clear_session"])
def clear_session(free_memory=True):
    """Resets all state generated by Keras.

    Keras manages a global state, which it uses to implement the Functional
    model-building API and to uniquify autogenerated layer names.

    If you are creating many models in a loop, this global state will consume
    an increasing amount of memory over time, and you may want to clear it.
    Calling `clear_session()` releases the global state: this helps avoid
    clutter from old models and layers, especially when memory is limited.

    Args:
        free_memory: Whether to call Python garbage collection.
            It's usually a good practice to call it to make sure
            memory used by deleted objects is immediately freed.
            However, it may take a few seconds to execute, so
            when using `clear_session()` in a short loop,
            you may want to skip it.

    Example 1: calling `clear_session()` when creating models in a loop

    ```python
    for _ in range(100):
      # Without `clear_session()`, each iteration of this loop will
      # slightly increase the size of the global state managed by Keras
      model = keras.Sequential([
          keras.layers.Dense(10) for _ in range(10)])

    for _ in range(100):
      # With `clear_session()` called at the beginning,
      # Keras starts with a blank state at each iteration
      # and memory consumption is constant over time.
      keras.backend.clear_session()
      model = keras.Sequential([
          keras.layers.Dense(10) for _ in range(10)])
    ```

    Example 2: resetting the layer name generation counter

    >>> layers = [keras.layers.Dense(10) for _ in range(10)]
    >>> new_layer = keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense_10
    >>> keras.backend.clear_session()
    >>> new_layer = keras.layers.Dense(10)
    >>> print(new_layer.name)
    dense
    """
    global GLOBAL_STATE_TRACKER
    global GLOBAL_SETTINGS_TRACKER

    # Rebinding the module-level trackers discards the stored state for every
    # thread that reads them through this module, not only the caller's.
    # NOTE(review): code holding a direct reference to the old objects would
    # keep seeing stale state — presumably all access goes through the
    # get/set helpers; verify.
    GLOBAL_STATE_TRACKER = threading.local()
    GLOBAL_SETTINGS_TRACKER = threading.local()

    # Backend-specific cleanup; other backends get no extra cleanup here.
    resetter = {
        "tensorflow": _reset_tensorflow_session,
        "torch": _reset_torch_session,
    }.get(backend.backend())
    if resetter is not None:
        resetter()

    if free_memory:
        # Manually trigger garbage collection.
        gc.collect()
