from keras.src import backend
from keras.src.utils.module_utils import tensorflow as tf


def get_tensor_spec(t, dynamic_batch=False, name=None):
    """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

    Args:
        t: The object to describe. A `tf.TypeSpec` is used as is, a
            composite tensor contributes its `_type_spec`, and any other
            object exposing both `shape` and `dtype` is described by a new
            `tf.TensorSpec`.
        dynamic_batch: If `True`, the first dimension of the returned spec
            is set to `None`. Specs of unknown rank or rank 0 are returned
            unchanged.
        name: Name given to the `tf.TensorSpec` built from `shape` and
            `dtype`. It is not applied when `t` is already a spec or a
            composite tensor.

    Returns:
        The spec describing `t`, or `None` when `t` is not tensor-like.
        When the batch dimension is relaxed, the result is a new object;
        a spec passed in by the caller is never modified.
    """
    import copy

    if isinstance(t, tf.TypeSpec):
        spec = t
    elif isinstance(t, tf.__internal__.CompositeTensor):
        # Check for ExtensionTypes
        spec = t._type_spec
    elif hasattr(t, "shape") and hasattr(t, "dtype"):
        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
    else:
        return None  # Allow non-Tensors to pass through.

    if not dynamic_batch:
        return spec

    shape = spec.shape
    if shape.rank is None or shape.rank == 0:
        # There is no batch dimension to relax.
        return spec

    dims = shape.as_list()
    dims[0] = None
    # Shallow-copy before overriding the shape so that a spec which did not
    # originate in this function (the `t` argument itself, or the spec
    # obtained from a composite tensor) is left untouched.
    spec = copy.copy(spec)
    spec._shape = tf.TensorShape(dims)
    return spec


def ensure_tensor(inputs, dtype=None):
    """Converts `inputs` to a TensorFlow tensor of the requested dtype.

    Args:
        inputs: A `tf.Tensor`, `tf.SparseTensor` or `tf.RaggedTensor`, which
            is kept as is, or any other value, which is passed to
            `tf.convert_to_tensor`.
        dtype: Optional target dtype. When given, the result is cast to it
            if its dtype differs.

    Returns:
        A `tf.Tensor`, `tf.SparseTensor` or `tf.RaggedTensor`.
    """
    tf_tensor_types = (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)
    if not isinstance(inputs, tf_tensor_types):
        is_torch_tensor = (
            backend.backend() == "torch" and backend.is_tensor(inputs)
        )
        if is_torch_tensor:
            # Plain `np.asarray()` conversion fails with PyTorch, so go
            # through the backend's own NumPy conversion first.
            inputs = backend.convert_to_numpy(inputs)
        inputs = tf.convert_to_tensor(inputs, dtype)

    needs_cast = dtype is not None and inputs.dtype != dtype
    return tf.cast(inputs, dtype) if needs_cast else inputs


def is_ragged_tensor(x):
    """Returns whether the type of `x` looks like a TF `RaggedTensor`.

    The check is purely textual: it looks for
    `"ragged_tensor.RaggedTensor"` in the string form of `type(x)` and
    performs no `isinstance` test.
    """
    type_repr = str(type(x))
    return "ragged_tensor.RaggedTensor" in type_repr


def sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Bincounts `inputs` along the last axis into a `tf.SparseTensor`.

    Args:
        inputs: Tensor of indices to count.
        depth: Number of bins; used as both `minlength` and `maxlength`.
        binary_output: Forwarded to `tf.sparse.bincount`.
        dtype: Dtype the counts are cast to.
        count_weights: Optional weights forwarded to `tf.sparse.bincount`.

    Returns:
        A `tf.SparseTensor` whose dense shape is `(depth,)` for rank-1
        `inputs` and `(batch_size, depth)` otherwise.
    """
    counts = tf.cast(
        tf.sparse.bincount(
            inputs,
            weights=count_weights,
            minlength=depth,
            maxlength=depth,
            axis=-1,
            binary_output=binary_output,
        ),
        dtype,
    )
    if inputs.shape.rank != 1:
        # The leading dimension is read from the result at run time.
        dense_shape = (tf.shape(counts)[0], depth)
    else:
        dense_shape = (depth,)
    # Rebuild the sparse tensor so that its dense shape ends in `depth`.
    return tf.SparseTensor(
        indices=counts.indices,
        values=counts.values,
        dense_shape=dense_shape,
    )


def dense_bincount(inputs, depth, binary_output, dtype, count_weights=None):
    """Bincounts `inputs` along the last axis into a dense tensor.

    Args:
        inputs: Tensor of indices to count.
        depth: Number of bins; used as both `minlength` and `maxlength`.
        binary_output: Forwarded to `tf.math.bincount`.
        dtype: Dtype of the resulting counts.
        count_weights: Optional weights forwarded to `tf.math.bincount`.

    Returns:
        A dense tensor whose static shape is set to `(depth,)` for rank-1
        `inputs` and to `(batch_size, depth)` otherwise, where `batch_size`
        is the static first dimension of `inputs`.
    """
    counts = tf.math.bincount(
        inputs,
        weights=count_weights,
        minlength=depth,
        maxlength=depth,
        dtype=dtype,
        axis=-1,
        binary_output=binary_output,
    )
    if inputs.shape.rank == 1:
        static_shape = (depth,)
    else:
        static_shape = (inputs.shape.as_list()[0], depth)
    # Pin the static shape of the result.
    counts.set_shape(tf.TensorShape(static_shape))
    return counts


def expand_dims(inputs, axis):
    """Inserts a length-1 axis into a sparse, ragged, or dense tensor.

    `tf.SparseTensor` inputs go through `tf.sparse.expand_dims`; everything
    else goes through `tf.expand_dims`.
    """
    expand_fn = (
        tf.sparse.expand_dims
        if isinstance(inputs, tf.SparseTensor)
        else tf.expand_dims
    )
    return expand_fn(inputs, axis)


def tf_encode_categorical_inputs(
    inputs,
    output_mode,
    depth,
    dtype="float32",
    sparse=False,
    count_weights=None,
    idf_weights=None,
):
    """Encodes categorical inputs according to `output_mode`.

    Faster method that relies on bincount.

    Args:
        inputs: Tensor of category indices.
        output_mode: `"int"` returns the inputs cast to `dtype`.
            `"one_hot"` and `"multi_hot"` produce binary bincounts,
            `"tf_idf"` produces counts scaled by `idf_weights`, and any
            other value produces plain counts.
        depth: Number of bins in the encoded output.
        dtype: Dtype of the result.
        sparse: Whether to compute and return a `tf.SparseTensor`.
        count_weights: Optional weights forwarded to the bincount helper.
        idf_weights: Weights applied to the counts; required when
            `output_mode` is `"tf_idf"`.

    Returns:
        The encoded tensor, sparse when `sparse` is truthy.

    Raises:
        ValueError: If the inputs would need an output rank above 2, or if
            `output_mode` is `"tf_idf"` and `idf_weights` is `None`.
    """
    if output_mode == "int":
        return tf.identity(tf.cast(inputs, dtype))

    original_shape = inputs.shape
    # A scalar is treated as a single sample, whatever the output mode.
    if inputs.shape.rank == 0:
        inputs = expand_dims(inputs, -1)
    # For one-hot, add a trailing axis unless the last dimension is already 1.
    if output_mode == "one_hot" and inputs.shape[-1] != 1:
        inputs = expand_dims(inputs, -1)

    if inputs.shape.rank > 2:
        raise ValueError(
            "When output_mode is not `'int'`, maximum supported output rank "
            f"is 2. Received output_mode {output_mode} and input shape "
            f"{original_shape}, "
            f"which would result in output rank {inputs.shape.rank}."
        )

    binary_output = output_mode in ("multi_hot", "one_hot")
    bincount_fn = sparse_bincount if sparse else dense_bincount
    bincounts = tf.cast(
        bincount_fn(inputs, depth, binary_output, dtype, count_weights),
        dtype,
    )
    if output_mode != "tf_idf":
        return bincounts

    if idf_weights is None:
        raise ValueError(
            "When output mode is `'tf_idf'`, idf_weights must be provided. "
            f"Received: output_mode={output_mode} and idf_weights={idf_weights}"
        )

    if not sparse:
        return tf.multiply(bincounts, idf_weights)

    # Look up one weight per stored value, using the value's last-axis index.
    value_weights = tf.gather(idf_weights, bincounts.indices[:, -1])
    return tf.SparseTensor(
        bincounts.indices,
        value_weights * bincounts.values,
        bincounts.dense_shape,
    )
