import collections
import math

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.utils.module_utils import tensorflow as tf

try:
    import pandas
except ImportError:
    pandas = None


# Leave jax, tf, and torch arrays off this list. Instead we will use
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
if pandas:
    ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)


class Sliceable:
    """`Sliceable` wrapping a tensor.

    A `Sliceable` implements the subscript operator to slice or index against
    the first dimension of the array. It also has conversion methods for each
    one of the backends.

    Args:
        array: the native array or tensor to wrap.

    Attributes:
        shape: the shape of the full dense native array.
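
    Example (illustrative, using the NumPy-backed subclass defined below):

    >>> import numpy as np
    >>> sliceable = NumpySliceable(np.arange(6))
    >>> sliceable[1:4]
    array([1, 2, 3])
    >>> sliceable[[0, 5]]
    array([0, 5])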
    """

    def __init__(self, array):
        self.array = array

    @property
    def shape(self):
        return self.array.shape

    def __getitem__(self, indices):
        """Select elements in the 0th dimension.

        Args:
            indices: the indices to select. Only the 0th dimension needs to be
                supported. May be a `slice` or a list, tuple, `np.ndarray` or
                1D tensor.
        Returns: A slice of `self.array`.
        """
        return self.array[indices]

    @classmethod
    def cast(cls, x, dtype):
        """Cast a tensor to a different dtype.

        Only called on a full array as provided by the user.

        Args:
            x: the tensor to cast.
            dtype: the dtype to cast to.
        Returns: the cast tensor.
        """
        return x.astype(dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        """Convert a tensor to a NumPy array.

        Only called after slicing using `__getitem__`.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        """Convert a tensor to something compatible with `tf.data.Dataset`.

        This can be a NumPy array, `tf.Tensor` or any other type of tensor that
        `tf.data.Dataset.from_tensors` can consume.
        Only called on a full array as provided by the user.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x

    @classmethod
    def convert_to_jax_compatible(cls, x):
        """Convert a tensor to something that the JAX backend can consume.

        This can be a JAX array, a `JAXSparse` instance or a NumPy array.
        Only called after slicing using `__getitem__`.
        Used to convert sparse tensors and densify ragged tensors.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x

    @classmethod
    def convert_to_torch_compatible(cls, x):
        """Convert a tensor to something that the Torch backend can consume.

        This can be a Torch tensor, NumPy array or any other type of tensor that
        `keras.backend.torch.core.convert_to_tensor()` can consume.
        Only called after slicing using `__getitem__`.
        Used to densify sparse tensors and ragged tensors.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x


class NumpySliceable(Sliceable):
    pass


class TensorflowSliceable(Sliceable):
    def __getitem__(self, indices):
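        # A dense `tf.Tensor` supports Python slices natively, but fancy
        # indexing with a list or array of indices requires `tf.gather`.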
        from keras.src.utils.module_utils import tensorflow as tf

        if isinstance(indices, slice):
            return self.array[indices]
        else:
            return tf.gather(self.array, indices, axis=0)

    @classmethod
    def cast(cls, x, dtype):
        from keras.src.backend.tensorflow.core import cast

        return cast(x, dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.tensorflow.core import convert_to_numpy

        return convert_to_numpy(x)


class TensorflowRaggedSliceable(TensorflowSliceable):
    @classmethod
    def convert_to_jax_compatible(cls, x):
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return x.to_tensor()


class TensorflowSparseSliceable(TensorflowSliceable):
    def __init__(self, array):
        super().__init__(to_tensorflow_sparse_wrapper(array))

    @property
    def shape(self):
        return self.array.sparse.shape

    def __getitem__(self, indices):
        return slice_tensorflow_sparse_wrapper(self.array, indices)

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        return to_tensorflow_sparse_wrapper(x)

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return data_adapter_utils.tf_sparse_to_jax_sparse(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        from keras.src.backend.tensorflow import sparse as tf_sparse

        return tf_sparse.sparse_to_dense(x)


class JaxSparseSliceable(Sliceable):
    def __getitem__(self, indices):
        return self.array[indices, ...]

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.jax.core import convert_to_numpy

        return convert_to_numpy(x)

    @classmethod
    def convert_to_tf_dataset_compatible(cls, array):
        return to_tensorflow_sparse_wrapper(
            data_adapter_utils.jax_sparse_to_tf_sparse(array)
        )

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return x.todense()


class TorchSliceable(Sliceable):
    @classmethod
    def cast(cls, x, dtype):
        from keras.src.backend.torch.core import cast

        return cast(x, dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.torch.core import convert_to_numpy

        return convert_to_numpy(x)


class PandasSliceable(Sliceable):
    def __getitem__(self, indices):
        return self.array.iloc[indices]

    @classmethod
    def convert_to_numpy(cls, x):
        return x.to_numpy()

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return cls.convert_to_numpy(x)


class PandasDataFrameSliceable(PandasSliceable):
    pass


class PandasSeriesSliceable(PandasSliceable):
    @classmethod
    def convert_to_numpy(cls, x):
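        # A `pandas.Series` converts to a 1D array; add a trailing axis so the
        # result is 2D, like a single-column `DataFrame`.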
        return np.expand_dims(x.to_numpy(), axis=-1)


class ScipySparseSliceable(Sliceable):
    def __init__(self, array):
        # The COO representation does not support indexing or slicing. Use
        # the CSR representation instead, which does.
        super().__init__(array.tocsr())

    @classmethod
    def convert_to_numpy(cls, x):
        return x.todense()

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        return to_tensorflow_sparse_wrapper(
            data_adapter_utils.scipy_sparse_to_tf_sparse(x)
        )

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return data_adapter_utils.scipy_sparse_to_jax_sparse(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return x.todense()


# `tf.SparseTensor` does not support indexing or `tf.gather`. The COO
# representation it uses does not lend itself to indexing. We add some
# intermediary tensors to ease the indexing and slicing. We put both indices and
# values in `RaggedTensor`s where each row corresponds to a row in the sparse
# tensor. This is because the number of values per row is not fixed.
# `RaggedTensor`s do support indexing and `tf.gather`, although on CPU only.
# We then reconstruct a `SparseTensor` from extracted rows. In theory, there is
# no duplication of data for the indices and values, only the addition of row
# splits for the ragged representation.
# `TensorflowSparseWrapper` is a named tuple which combines the original
# `SparseTensor` (used for the shape) and the ragged representations of indices
# and values for indexing / slicing. We use a named tuple and not a `Sliceable`
# to be able to ingest it in `tf.data.Dataset.from_tensors()` and map it.
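#
# For illustration (kept as comments so the module stays import-free), a 3x4
# sparse tensor with three values would be wrapped as follows:
#
#   st = tf.SparseTensor(
#       indices=[[0, 0], [0, 2], [2, 1]],
#       values=[1.0, 2.0, 3.0],
#       dense_shape=[3, 4],
#   )
#   wrapper = to_tensorflow_sparse_wrapper(st)
#   # wrapper.ragged_indices: [[[0, 0], [0, 2]], [], [[2, 1]]]
#   # wrapper.ragged_values:  [[1.0, 2.0], [], [3.0]]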

TensorflowSparseWrapper = collections.namedtuple(
    "TensorflowSparseWrapper", ["sparse", "ragged_indices", "ragged_values"]
)


def to_tensorflow_sparse_wrapper(sparse):
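    # Column 0 of the COO `indices` is the row each value belongs to; derive
    # row splits from it to build the ragged views of indices and values.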
    from keras.src.utils.module_utils import tensorflow as tf

    row_ids = sparse.indices[:, 0]
    row_splits = tf.experimental.RowPartition.from_value_rowids(
        row_ids
    ).row_splits()

    ragged_indices = tf.cast(
        tf.RaggedTensor.from_row_splits(sparse.indices, row_splits), tf.int64
    )
    ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, row_splits)
    return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values)


def slice_tensorflow_sparse_wrapper(sparse_wrapper, indices):
    from keras.src.utils.module_utils import tensorflow as tf

    if isinstance(indices, slice):
        sparse_indices = sparse_wrapper.ragged_indices[indices]
        sparse_values = sparse_wrapper.ragged_values[indices]
        batch_dim = indices.stop - indices.start
    else:
        sparse_indices = tf.gather(sparse_wrapper.ragged_indices, indices)
        sparse_values = tf.gather(sparse_wrapper.ragged_values, indices)
        if isinstance(indices, list):
            batch_dim = len(indices)
        else:
            batch_dim = indices.shape[0]
            if batch_dim is None:
                batch_dim = tf.shape(indices)[0]

    # Rebuild the COO indices for the selected rows: the new row ids come from
    # the sliced ragged partition, replacing the original row ids in column 0.
    row_ids = sparse_indices.value_rowids()
    sparse_indices = sparse_indices.flat_values[:, 1:]  # drop original row ids
    sparse_indices = tf.concat(
        [tf.expand_dims(row_ids, -1), sparse_indices], axis=1
    )

    sparse_values = sparse_values.flat_values
    sparse_shape = (batch_dim,) + tuple(
        sparse_wrapper.sparse.shape.as_list()[1:]
    )
    return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape)
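
# Continuing the illustration above: slicing rows 1..3 of `wrapper` rebuilds a
# 2x4 `tf.SparseTensor` with row ids re-based to the slice:
#
#   sliced = slice_tensorflow_sparse_wrapper(wrapper, slice(1, 3))
#   # sliced.indices: [[1, 1]], sliced.values: [3.0], dense_shape: (2, 4)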


def can_slice_array(x):
    return (
        x is None
        or isinstance(x, ARRAY_TYPES)
        or data_adapter_utils.is_tensorflow_tensor(x)
        or data_adapter_utils.is_jax_array(x)
        or data_adapter_utils.is_torch_tensor(x)
        or data_adapter_utils.is_scipy_sparse(x)
        or hasattr(x, "__array__")
    )


def convert_to_sliceable(arrays, target_backend=None):
    """Convert a structure of arrays into `Sliceable` instances

    Args:
        arrays: the arrays to convert.
        target_backend: the target backend for the output:
            - `None` indicates that `arrays` will be wrapped into `Sliceable`s
              as-is without using a different representation. This is used by
              `train_validation_split()`.
            - `tensorflow` indicates that
              `Sliceable.convert_to_tf_dataset_compatible` will be called. The
              returned structure therefore contains arrays, not `Sliceable`s.
            - `numpy`, `jax` or `torch` indicates that the arrays will
              eventually be converted to this backend type after slicing. In
              this case, the intermediary `Sliceable`s may use a different
              representation from the input `arrays` for better performance.
    Returns: the same structure with `Sliceable` instances or arrays.
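
    Example (illustrative):

    >>> import numpy as np
    >>> sliceables = convert_to_sliceable({"x": np.zeros((4, 3), "int32")})
    >>> type(sliceables["x"]).__name__
    'NumpySliceable'
    >>> sliceables["x"][1:3].shape
    (2, 3)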
    """

    def convert_single_array(x):
        if x is None:
            return x

        # Special case: handle np "object" arrays containing strings
        if (
            isinstance(x, np.ndarray)
            and str(x.dtype) == "object"
            and backend.backend() == "tensorflow"
            and all(isinstance(e, str) for e in x)
        ):
            x = tf.convert_to_tensor(x, dtype="string")

        # Step 1. Determine which Sliceable class to use.
        if isinstance(x, np.ndarray):
            sliceable_class = NumpySliceable
        elif data_adapter_utils.is_tensorflow_tensor(x):
            if data_adapter_utils.is_tensorflow_ragged(x):
                sliceable_class = TensorflowRaggedSliceable
            elif data_adapter_utils.is_tensorflow_sparse(x):
                sliceable_class = TensorflowSparseSliceable
            else:
                sliceable_class = TensorflowSliceable
        elif data_adapter_utils.is_jax_array(x):
            if data_adapter_utils.is_jax_sparse(x):
                sliceable_class = JaxSparseSliceable
            else:
                x = np.asarray(x)
                sliceable_class = NumpySliceable
        elif data_adapter_utils.is_torch_tensor(x):
            sliceable_class = TorchSliceable
        elif pandas is not None and isinstance(x, pandas.DataFrame):
            sliceable_class = PandasDataFrameSliceable
        elif pandas is not None and isinstance(x, pandas.Series):
            sliceable_class = PandasSeriesSliceable
        elif data_adapter_utils.is_scipy_sparse(x):
            sliceable_class = ScipySparseSliceable
        elif hasattr(x, "__array__"):
            x = np.asarray(x)
            sliceable_class = NumpySliceable
        else:
            raise ValueError(
                "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
                "tf.SparseTensor, jax.np.ndarray, "
                "jax.experimental.sparse.JAXSparse, torch.Tensor, "
                "Pandas Dataframe, or Pandas Series. Received invalid input: "
                f"{x} (of type {type(x)})"
            )

        # Step 2. Normalize floats to floatx.
        def is_non_floatx_float(dtype):
            return (
                dtype is not object
                and backend.is_float_dtype(dtype)
                and not backend.standardize_dtype(dtype) == backend.floatx()
            )

        cast_dtype = None
        if pandas is not None and isinstance(x, pandas.DataFrame):
            if any(is_non_floatx_float(d) for d in x.dtypes.values):
                cast_dtype = backend.floatx()
        else:
            if is_non_floatx_float(x.dtype):
                cast_dtype = backend.floatx()

        if cast_dtype is not None:
            x = sliceable_class.cast(x, cast_dtype)

        # Step 3. Apply target backend specific logic and optimizations.
        if target_backend is None:
            return sliceable_class(x)

        if target_backend == "tensorflow":
            return sliceable_class.convert_to_tf_dataset_compatible(x)

        # With dense arrays and JAX as output, it is faster to use NumPy as an
        # intermediary representation, so wrap the input array in a NumPy
        # array, which should not use extra memory.
        # See https://github.com/google/jax/issues/1276 for an explanation of
        # why slicing a NumPy array is faster than slicing a JAX array.
        if target_backend == "jax" and sliceable_class in (
            TensorflowSliceable,
            TorchSliceable,
        ):
            x = np.asarray(x)
            sliceable_class = NumpySliceable

        return sliceable_class(x)

    return tree.map_structure(convert_single_array, arrays)


def train_validation_split(arrays, validation_split):
    """Split arrays into train and validation subsets in deterministic order.

    The last part of data will become validation data.

    Args:
        arrays: Tensors to split. Allowed inputs are arbitrarily nested
            structures of Tensors and NumPy arrays.
        validation_split: Float between 0 and 1. The proportion of the dataset
            to include in the validation split. The rest of the dataset will be
            included in the training split.

    Returns:
        `(train_arrays, validation_arrays)`
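
    Example (illustrative):

    >>> import numpy as np
    >>> x = np.arange(10)
    >>> train, val = train_validation_split(x, validation_split=0.2)
    >>> train.shape, val.shape
    ((8,), (2,))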
    """

    flat_arrays = tree.flatten(arrays)
    unsplittable = [type(t) for t in flat_arrays if not can_slice_array(t)]
    if unsplittable:
        raise ValueError(
            "Argument `validation_split` is only supported "
            "for tensors or NumPy arrays. "
            f"Found incompatible types in the input: {unsplittable}"
        )

    if all(t is None for t in flat_arrays):
        return arrays, arrays

    first_non_none = None
    for t in flat_arrays:
        if t is not None:
            first_non_none = t
            break

    # Assumes all arrays have the same batch shape or are `None`.
    batch_dim = int(first_non_none.shape[0])
    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))

    if split_at == 0 or split_at == batch_dim:
        raise ValueError(
            f"Training data contains {batch_dim} samples, which is not "
            "sufficient to split it into a validation and training set as "
            f"specified by `validation_split={validation_split}`. Either "
            "provide more data, or a different value for the "
            "`validation_split` argument."
        )

    def _split(t, start, end):
        if t is None:
            return t
        return t[start:end]

    sliceables = convert_to_sliceable(arrays)
    train_arrays = tree.map_structure(
        lambda x: _split(x, start=0, end=split_at), sliceables
    )
    val_arrays = tree.map_structure(
        lambda x: _split(x, start=split_at, end=batch_dim), sliceables
    )
    return train_arrays, val_arrays
