import torch
import torch._dynamo as dynamo
import torch.nn.functional as tnn

from keras.src.backend.config import floatx
from keras.src.backend.torch.core import convert_to_tensor
from keras.src.backend.torch.core import get_device
from keras.src.backend.torch.core import to_torch_dtype
from keras.src.random.seed_generator import SeedGenerator
from keras.src.random.seed_generator import draw_seed
from keras.src.random.seed_generator import make_default_seed


# torch.Generator not supported with dynamo
# see: https://github.com/pytorch/pytorch/issues/88576
@dynamo.disable()
def torch_seed_generator(seed):
    first_seed, second_seed = draw_seed(seed)
    device = get_device()
    if device == "meta":
        # Generator is not supported by the meta device.
        return None
    generator = torch.Generator(device=get_device())
    generator.manual_seed(int(first_seed + second_seed))
    return generator


def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    dtype = to_torch_dtype(dtype)
    # Do not use generator during symbolic execution.
    if get_device() == "meta":
        return torch.normal(
            mean, stddev, size=shape, dtype=dtype, device=get_device()
        )
    generator = torch_seed_generator(seed)
    return torch.normal(
        mean,
        stddev,
        size=shape,
        generator=generator,
        dtype=dtype,
        device=get_device(),
    )


def categorical(logits, num_samples, dtype="int32", seed=None):
    logits = convert_to_tensor(logits)
    dtype = to_torch_dtype(dtype)
    probs = torch.softmax(logits, dim=-1)
    # Do not use generator during symbolic execution.
    if get_device() == "meta":
        return torch.multinomial(
            probs,
            num_samples,
            replacement=True,
        ).type(dtype)
    generator = torch_seed_generator(seed)
    return torch.multinomial(
        probs,
        num_samples,
        replacement=True,
        generator=generator,
    ).type(dtype)


def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    dtype = to_torch_dtype(dtype)
    requested_shape = shape
    if len(requested_shape) == 0:
        shape = (1,)
    # Do not use generator during symbolic execution.
    if get_device() == "meta":
        rand_tensor = torch.rand(size=shape, dtype=dtype, device=get_device())
    else:
        generator = torch_seed_generator(seed)
        rand_tensor = torch.rand(
            size=shape, generator=generator, dtype=dtype, device=get_device()
        )

    output = (maxval - minval) * rand_tensor + minval

    if len(requested_shape) == 0:
        return output[0]
    return output


def randint(shape, minval, maxval, dtype="int32", seed=None):
    dtype = to_torch_dtype(dtype)
    # Do not use generator during symbolic execution.
    if get_device() == "meta":
        return torch.randint(
            low=minval,
            high=maxval,
            size=shape,
            dtype=dtype,
            device=get_device(),
        )
    generator = torch_seed_generator(seed)
    return torch.randint(
        low=minval,
        high=maxval,
        size=shape,
        generator=generator,
        dtype=dtype,
        device=get_device(),
    )


def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
    dtype = to_torch_dtype(dtype)
    # Take a larger standard normal dist, discard values outside 2 * stddev
    # Offset by mean and stddev
    x = normal(tuple(shape) + (4,), mean=0, stddev=1, dtype=dtype, seed=seed)
    valid = (x > -2) & (x < 2)
    indexes = valid.max(-1, keepdim=True)[1]
    trunc_x = torch.empty(shape, dtype=dtype, device=get_device())
    trunc_x.data.copy_(x.gather(-1, indexes).squeeze(-1))
    trunc_x.data.mul_(stddev).add_(mean)
    return trunc_x


def _get_concrete_noise_shape(inputs, noise_shape):
    if noise_shape is None:
        return inputs.shape

    concrete_inputs_shape = inputs.shape
    concrete_noise_shape = []
    for i, value in enumerate(noise_shape):
        concrete_noise_shape.append(
            concrete_inputs_shape[i] if value is None else value
        )
    return concrete_noise_shape


def dropout(inputs, rate, noise_shape=None, seed=None):
    if (
        seed is not None
        and not (isinstance(seed, SeedGenerator) and seed._initial_seed is None)
        or noise_shape is not None
    ):
        keep_prob = 1.0 - rate
        noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
        keep_prob_matrix = torch.full(
            noise_shape, keep_prob, device=get_device()
        )
        generator = torch_seed_generator(seed)

        # Do not use generator during symbolic execution.
        if get_device() == "meta":
            mask = torch.bernoulli(keep_prob_matrix)
        else:
            mask = torch.bernoulli(keep_prob_matrix, generator=generator)

        mask = mask.bool()
        mask = torch.broadcast_to(mask, inputs.shape)
        return torch.where(
            mask,
            inputs / keep_prob,
            torch.zeros_like(inputs, dtype=inputs.dtype),
        )
    # Fast path, unseeded (since torch doesn't support seeding dropout!!!!)
    # Using the above implementation is possible, but much slower.
    return torch.nn.functional.dropout(
        inputs, p=rate, training=True, inplace=False
    )


def shuffle(x, axis=0, seed=None):
    # Ref: https://github.com/pytorch/pytorch/issues/71409
    x = convert_to_tensor(x)

    # Get permutation indices
    # Do not use generator during symbolic execution.
    if get_device() == "meta":
        row_perm = torch.rand(x.shape[: axis + 1], device=get_device()).argsort(
            axis
        )
    else:
        generator = torch_seed_generator(seed)
        row_perm = torch.rand(
            x.shape[: axis + 1], generator=generator, device=get_device()
        ).argsort(axis)
    for _ in range(x.ndim - axis - 1):
        row_perm.unsqueeze_(-1)

    # Reformat this for the gather operation
    row_perm = row_perm.repeat(
        *[1 for _ in range(axis + 1)], *(x.shape[axis + 1 :])
    )
    return x.gather(axis, row_perm)


def gamma(shape, alpha, dtype=None, seed=None):
    dtype = dtype or floatx()
    dtype = to_torch_dtype(dtype)
    alpha = torch.broadcast_to(convert_to_tensor(alpha), shape)
    beta = torch.ones(shape, device=get_device())
    prev_rng_state = torch.random.get_rng_state()
    # Do not draw seed during symbolic execution
    if not get_device() == "meta":
        first_seed, second_seed = draw_seed(seed)
        torch.manual_seed(first_seed + second_seed)
    gamma_distribution = torch.distributions.gamma.Gamma(alpha, beta)
    sample = gamma_distribution.sample().type(dtype)
    torch.random.set_rng_state(prev_rng_state)
    return sample


def binomial(shape, counts, probabilities, dtype=None, seed=None):
    dtype = dtype or floatx()
    dtype = to_torch_dtype(dtype)
    counts = torch.broadcast_to(convert_to_tensor(counts), shape)
    probabilities = torch.broadcast_to(convert_to_tensor(probabilities), shape)
    prev_rng_state = torch.random.get_rng_state()
    # Do not draw seed during symbolic execution
    if not get_device() == "meta":
        first_seed, second_seed = draw_seed(seed)
        torch.manual_seed(first_seed + second_seed)
    binomial_distribution = torch.distributions.binomial.Binomial(
        total_count=counts, probs=probabilities
    )
    sample = binomial_distribution.sample().type(dtype)
    torch.random.set_rng_state(prev_rng_state)
    return sample


def beta(shape, alpha, beta, dtype=None, seed=None):
    dtype = dtype or floatx()
    dtype = to_torch_dtype(dtype)
    alpha = torch.broadcast_to(convert_to_tensor(alpha), shape)
    beta = torch.broadcast_to(convert_to_tensor(beta), shape)
    prev_rng_state = torch.random.get_rng_state()
    # Do not draw seed during symbolic execution
    if not get_device() == "meta":
        first_seed, second_seed = draw_seed(seed)
        torch.manual_seed(first_seed + second_seed)
    beta_distribution = torch.distributions.beta.Beta(
        concentration1=alpha, concentration0=beta
    )
    sample = beta_distribution.sample().type(dtype)
    torch.random.set_rng_state(prev_rng_state)
    return sample
