import functools
import math

import numpy as np

from keras.src import tree
from keras.src.trainers.data_adapters import array_slicing
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.data_adapters.data_adapter import DataAdapter


class ArrayDataAdapter(DataAdapter):
    """Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays."""

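    # Typical usage (a sketch; in practice `Model.fit()` constructs the
    # adapter internally):
    #   adapter = ArrayDataAdapter(x, y, batch_size=32, shuffle=True)
    #   for x_batch, y_batch in adapter.get_numpy_iterator():
    #       ...
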
    def __init__(
        self,
        x,
        y=None,
        sample_weight=None,
        batch_size=None,
        steps=None,
        shuffle=False,
        class_weight=None,
    ):
        if not can_convert_arrays((x, y, sample_weight)):
            raise ValueError(
                "Expected all elements of `x` to be array-like. "
                f"Received invalid types: x={x}"
            )

        if sample_weight is not None:
            if class_weight is not None:
                raise ValueError(
                    "You cannot use `class_weight` and `sample_weight` "
                    "at the same time."
                )
            if tree.is_nested(y):
                if isinstance(sample_weight, (list, tuple, dict)):
                    try:
                        tree.assert_same_structure(y, sample_weight)
                    except ValueError:
                        raise ValueError(
                            "You should provide one `sample_weight` array per "
                            "output in `y`. The two structures did not match:\n"
                            f"- y: {y}\n"
                            f"- sample_weight: {sample_weight}\n"
                        )
                else:
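                    # A single `sample_weight` array was provided for a nested
                    # `y`: it must be sample-wise (one scalar per sample) so
                    # that it can be replicated to every output below.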
                    is_samplewise = len(sample_weight.shape) == 1 or (
                        len(sample_weight.shape) == 2
                        and sample_weight.shape[1] == 1
                    )
                    if not is_samplewise:
                        raise ValueError(
                            "For a model with multiple outputs, when providing "
                            "a single `sample_weight` array, it should only "
                            "have one scalar score per sample "
                            "(i.e. shape `(num_samples,)`). If you want to use "
                            "non-scalar sample weights, pass a `sample_weight` "
                            "argument with one array per model output."
                        )
                    # Replicate the same sample_weight array on all outputs.
                    sample_weight = tree.map_structure(
                        lambda _: sample_weight, y
                    )
        if class_weight is not None:
            if tree.is_nested(y):
                raise ValueError(
                    "`class_weight` is only supported for Models with a single "
                    "output."
                )
            sample_weight = data_adapter_utils.class_weight_to_sample_weights(
                y, class_weight
            )

        inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)

        data_adapter_utils.check_data_cardinality(inputs)
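        # `check_data_cardinality` above guarantees that all arrays share the
        # same number of samples, so this set has exactly one element.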
        num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
        self._num_samples = num_samples
        self._inputs = inputs

        # If `batch_size` is not passed but `steps` is, calculate it from the
        # number of samples. Defaults to `32` for backwards compatibility.
        if not batch_size:
            batch_size = int(math.ceil(num_samples / steps)) if steps else 32

        self._size = int(math.ceil(num_samples / batch_size))
        self._batch_size = batch_size
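        # Number of samples left over in the final, smaller batch; 0 when
        # `num_samples` is a multiple of `batch_size`.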
        self._partial_batch_size = num_samples % batch_size
        self._shuffle = shuffle

    def get_numpy_iterator(self):
        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="numpy"
        )

        def slice_and_convert_to_numpy(sliceable, indices=None):
            x = sliceable[indices]
            x = sliceable.convert_to_numpy(x)
            return x

        return self._get_iterator(slice_and_convert_to_numpy, inputs)

    def get_tf_dataset(self):
        from keras.src.utils.module_utils import tensorflow as tf

        shuffle = self._shuffle
        batch_size = self._batch_size
        num_samples = self._num_samples
        num_full_batches = int(num_samples // batch_size)

        # Vectorized version of shuffle.
        # This is a performance improvement over using `from_tensor_slices`.
        # The indices of the data are shuffled and batched, and these indices
        # are then zipped with the data and used to extract a batch of the data
        # at each step. The performance improvements here come from:
        # 1. vectorized batch using gather
        # 2. parallelized map
        # 3. pipelined permutation generation
        # 4. optimized permutation batching
        # 5. disabled static optimizations
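        #
        # Roughly (a sketch, not literal ops), the pipeline is:
        #   permutation of indices -> batches of indices -> gather(data, batch)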

        indices_dataset = tf.data.Dataset.range(1)

        def permutation(_):
            # It turns out to be more performant to make a new set of indices
            # rather than reusing the same range Tensor (presumably because of
            # buffer forwarding).
            indices = tf.range(num_samples, dtype=tf.int64)
            if shuffle and shuffle != "batch":
                indices = tf.random.shuffle(indices)
            return indices

        # We prefetch a single element. Computing large permutations can take
        # quite a while so we don't want to wait for prefetching over an epoch
        # boundary to trigger the next permutation. On the other hand, too many
        # simultaneous shuffles can contend on a hardware level and degrade all
        # performance.
        indices_dataset = indices_dataset.map(permutation).prefetch(1)

        def slice_batch_indices(indices):
            """Convert a Tensor of indices into a dataset of batched indices.

            This step can be accomplished in several ways. The most natural is
            to slice the Tensor in a Dataset map. (With a condition on the upper
            index to handle the partial batch.) However it turns out that
            coercing the Tensor into a shape which is divisible by the batch
            size (and handling the last partial batch separately) allows for a
            much more favorable memory access pattern and improved performance.

            Args:
                indices: Tensor which determines the data order for an entire
                    epoch.

            Returns:
                A Dataset of batched indices.
            """
            num_in_full_batch = num_full_batches * batch_size
            first_k_indices = tf.slice(indices, [0], [num_in_full_batch])
            first_k_indices = tf.reshape(
                first_k_indices, [num_full_batches, batch_size]
            )

            flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)
            if self._partial_batch_size:
                index_remainder = tf.data.Dataset.from_tensors(
                    tf.slice(
                        indices, [num_in_full_batch], [self._partial_batch_size]
                    )
                )
                flat_dataset = flat_dataset.concatenate(index_remainder)

            return flat_dataset

        def slice_inputs(indices_dataset, inputs):
            """Slice inputs into a Dataset of batches.

            Given a Dataset of batch indices and the unsliced inputs,
            this step slices the inputs in a parallelized fashion
            and produces a dataset of input batches.

            Args:
                indices_dataset: A Dataset of batched indices.
                inputs: A python data structure that contains the inputs,
                    targets, and possibly sample weights.

            Returns:
                A Dataset of input batches matching the batch indices.
            """
            inputs = array_slicing.convert_to_sliceable(
                inputs, target_backend="tensorflow"
            )
            inputs = tree.lists_to_tuples(inputs)

            dataset = tf.data.Dataset.zip(
                (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())
            )

            def grab_batch(i, data):
                def grab_one(x):
                    if isinstance(x, array_slicing.TensorflowSparseWrapper):
                        return array_slicing.slice_tensorflow_sparse_wrapper(
                            x, i
                        )
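                    # Returning `None` for containers tells `tree.traverse`
                    # to keep recursing into their children.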
                    if isinstance(x, (list, tuple, dict)):
                        return None
                    if tf.is_tensor(x):
                        return tf.gather(x, i, axis=0)
                    return x

                return tree.traverse(grab_one, data)

            dataset = dataset.map(
                grab_batch, num_parallel_calls=tf.data.AUTOTUNE
            )

            # Default optimizations are disabled to avoid the overhead of
            # (unnecessary) input pipeline graph serialization & deserialization
            options = tf.data.Options()
            options.experimental_optimization.apply_default_optimizations = (
                False
            )
            if self._shuffle:
                options.experimental_external_state_policy = (
                    tf.data.experimental.ExternalStatePolicy.IGNORE
                )
            dataset = dataset.with_options(options)
            return dataset

        indices_dataset = indices_dataset.flat_map(slice_batch_indices)
        if shuffle == "batch":
            indices_dataset = indices_dataset.map(tf.random.shuffle)

        dataset = slice_inputs(indices_dataset, self._inputs)

        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA
        )
        dataset = dataset.with_options(options)
        return dataset.prefetch(tf.data.AUTOTUNE)

    def get_jax_iterator(self):
        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="jax"
        )

        def slice_and_convert_to_jax(sliceable, indices=None):
            x = sliceable[indices]
            x = sliceable.convert_to_jax_compatible(x)
            return x

        return self._get_iterator(slice_and_convert_to_jax, inputs)

    def get_torch_dataloader(self):
        import torch

        from keras.src.backend.torch.core import convert_to_tensor

        class ArrayDataset(torch.utils.data.Dataset):
            def __init__(self, array):
                self.array = array

            def __getitems__(self, indices):
                def slice_and_convert(sliceable):
                    x = sliceable[indices]
                    x = sliceable.convert_to_torch_compatible(x)
                    x = convert_to_tensor(x)
                    return x

                return tree.map_structure(slice_and_convert, self.array)

            def __len__(self):
                return len(self.array[0])

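        # With `shuffle="batch"`, batches cover contiguous index ranges and
        # only the order of the samples within each batch is shuffled.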
        class RandomBatchSampler(torch.utils.data.Sampler):
            def __init__(self, sampler):
                self.sampler = sampler

            def __iter__(self):
                for batch in self.sampler:
                    yield [batch[i] for i in torch.randperm(len(batch))]

            def __len__(self):
                return len(self.sampler)

        if self._shuffle == "batch":
            batch_sampler = RandomBatchSampler(
                torch.utils.data.BatchSampler(
                    range(self._num_samples),
                    batch_size=self._batch_size,
                    drop_last=False,
                )
            )
        elif self._shuffle:
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.RandomSampler(range(self._num_samples)),
                batch_size=self._batch_size,
                drop_last=False,
            )
        else:
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.SequentialSampler(range(self._num_samples)),
                batch_size=self._batch_size,
                drop_last=False,
            )

        # Because ArrayDataset.__getitems__ returns full batches organized in
        # the expected structure, there is nothing to collate.
        def no_op_collate(batch):
            return batch

        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="torch"
        )
        dataset = ArrayDataset(inputs)
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate
        )

    def _get_iterator(self, slice_and_convert_fn, inputs):
        global_permutation = None
        if self._shuffle and self._shuffle != "batch":
            global_permutation = np.random.permutation(self._num_samples)

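        # `indices` selects one batch at a time: a plain slice when not
        # shuffling, a window of the epoch-wide permutation when shuffling,
        # or a permutation of the contiguous batch for `shuffle="batch"`.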
        for i in range(self._size):
            start = i * self._batch_size
            stop = min((i + 1) * self._batch_size, self._num_samples)
            if self._shuffle == "batch":
                indices = np.random.permutation(stop - start) + start
            elif self._shuffle:
                indices = global_permutation[start:stop]
            else:
                indices = slice(start, stop)

            slice_indices_and_convert_fn = functools.partial(
                slice_and_convert_fn, indices=indices
            )
            yield tree.map_structure(slice_indices_and_convert_fn, inputs)

    @property
    def num_batches(self):
        return self._size

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def has_partial_batch(self):
        return self._partial_batch_size > 0

    @property
    def partial_batch_size(self):
        return self._partial_batch_size or None


def can_convert_arrays(arrays):
    """Check if array-like inputs can be handled by `ArrayDataAdapter`.

    Args:
        arrays: Structure of `Tensor`s, NumPy arrays, or tensor-like objects.

    Returns:
        `True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
        otherwise.
    """

    return all(
        tree.flatten(tree.map_structure(array_slicing.can_slice_array, arrays))
    )
