import jax
import numpy as np
from jax import lax

from keras.src import backend
from keras.src.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_jax,
)
from keras.src.backend.numpy.core import cast
from keras.src.backend.numpy.core import convert_to_tensor
from keras.src.backend.numpy.core import is_tensor
from keras.src.utils.module_utils import scipy


def relu(x):
    """Rectified linear unit: elementwise `max(x, 0)` in the dtype of `x`."""
    tensor = convert_to_tensor(x)
    zero = np.array(0.0, tensor.dtype)
    return np.maximum(tensor, zero)


def relu6(x):
    """ReLU capped at 6: elementwise `min(max(x, 0), 6)`.

    `np.clip` would promote bfloat16 to float32, so the clamp is spelled out
    with `np.maximum` followed by `np.minimum` to keep the input dtype.
    """
    tensor = convert_to_tensor(x)
    lower = np.array(0.0, tensor.dtype)
    upper = np.array(6.0, tensor.dtype)
    return np.minimum(np.maximum(tensor, lower), upper)


def sigmoid(x):
    """Logistic sigmoid `1 / (1 + exp(-x))`, with constants in x's dtype."""
    tensor = convert_to_tensor(x)
    one = np.array(1.0, tensor.dtype)
    return one / (one + np.exp(-tensor))


def sparse_sigmoid(x):
    """Piecewise-linear sigmoid.

    Returns 0 where `x <= -1`, 1 where `x >= 1`, and `(x + 1) / 2` in
    between, all in the dtype of `x`.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    linear = np.array(0.5 * (tensor + 1), dtype)
    upper = np.where(tensor >= 1, np.array(1.0, dtype), linear)
    return np.where(tensor <= -1, np.array(0.0, dtype), upper)


def tanh(x):
    """Hyperbolic tangent, elementwise.

    The input goes through `convert_to_tensor` first, like the other
    activations in this module. Without that step, Python floats and lists
    would reach `np.tanh` directly and come back as float64 instead of the
    standardized dtype.
    """
    x = convert_to_tensor(x)
    return np.tanh(x)


def tanh_shrink(x):
    """Tanh shrink activation: `x - tanh(x)`."""
    tensor = convert_to_tensor(x)
    return tensor - np.tanh(tensor)


def softplus(x):
    """Softplus `log(1 + exp(x))`, evaluated as `logaddexp(x, 0)`."""
    tensor = convert_to_tensor(x)
    zero = np.array(0.0, tensor.dtype)
    return np.logaddexp(tensor, zero)


def softsign(x):
    """Softsign activation: `x / (1 + |x|)`."""
    tensor = convert_to_tensor(x)
    denominator = np.array(1.0, tensor.dtype) + np.abs(tensor)
    return tensor / denominator


def soft_shrink(x, threshold=0.5):
    """Soft shrinkage: move values toward zero by `threshold`.

    Args:
        x: input tensor (or anything `convert_to_tensor` accepts).
        threshold: shrinkage amount.

    Returns:
        `x - threshold` where `x > threshold`, `x + threshold` where
        `x < -threshold`, and 0 elsewhere, in the dtype of `x`.
    """
    # Convert first, like the other activations. Without this, a Python list
    # or scalar fails on `x.dtype` below.
    x = convert_to_tensor(x)
    dtype = x.dtype
    return np.where(
        x > threshold,
        np.array(x - threshold, dtype=dtype),
        np.where(
            x < -threshold,
            np.array(x + threshold, dtype=dtype),
            np.array(0.0, dtype=dtype),
        ),
    )


def sparse_plus(x):
    """Sparse plus activation.

    Returns 0 where `x <= -1`, `(x + 1)**2 / 4` where `-1 < x < 1`, and `x`
    where `x >= 1`, in the dtype of `x`.
    """
    # Convert first, like the other activations. Without this, a Python list
    # or scalar fails on `x.dtype` below.
    x = convert_to_tensor(x)
    return np.where(
        x <= -1,
        np.zeros_like(x),
        np.where(x < 1, np.array((1 / 4) * (x + 1) ** 2, dtype=x.dtype), x),
    )


def silu(x):
    """SiLU (swish) activation: `x * sigmoid(x)`."""
    tensor = convert_to_tensor(x)
    gate = sigmoid(tensor)
    return tensor * gate


def squareplus(x, b=4):
    """Squareplus activation: `(x + sqrt(x**2 + b)) / 2`.

    `b` is converted to the dtype of `x` before use.
    """
    tensor = convert_to_tensor(x)
    b = convert_to_tensor(b, dtype=tensor.dtype)
    return (tensor + np.sqrt(tensor**2 + b)) / 2


def log_sigmoid(x):
    """Log-sigmoid, via the identity `log(sigmoid(x)) == -softplus(-x)`."""
    tensor = convert_to_tensor(x)
    return -softplus(-tensor)


def leaky_relu(x, negative_slope=0.2):
    """Leaky ReLU computed as `max(x, negative_slope * x)`.

    This matches the usual leaky-ReLU definition when
    `0 <= negative_slope <= 1`.
    """
    tensor = convert_to_tensor(x)
    slope = np.array(negative_slope, tensor.dtype)
    return np.maximum(tensor, slope * tensor)


def hard_sigmoid(x):
    """Hard sigmoid: `clip(x / 6 + 0.5, 0, 1)` in the dtype of `x`."""
    # Convert first, like the other activations. Without this, a Python list
    # or scalar fails on `x.dtype` below.
    x = convert_to_tensor(x)
    # NumPy would promote Python-number operands to float64, so the
    # constants are first materialized as arrays of x's dtype.
    x = x / np.array(6.0, x.dtype) + np.array(0.5, x.dtype)
    return np.where(
        x <= 0.0,
        np.array(0.0, x.dtype),
        np.where(x >= 1.0, np.array(1.0, x.dtype), x),
    )


def hard_silu(x):
    """Hard SiLU (hard swish) activation: `x * hard_sigmoid(x)`."""
    # Convert up front so that list or scalar inputs work and both factors
    # share one dtype.
    x = convert_to_tensor(x)
    return x * hard_sigmoid(x)


def elu(x, alpha=1.0):
    """Exponential linear unit.

    Returns `x` where `x >= 0` and `alpha * (exp(x) - 1)` elsewhere.
    """
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    negative_branch = np.array(alpha, dtype) * np.expm1(tensor)
    return np.where(tensor >= np.array(0.0, dtype), tensor, negative_branch)


def selu(
    x,
    alpha=1.6732632423543772848170429916717,
    scale=1.0507009873554804934193349852946,
):
    """Scaled ELU: `scale * elu(x, alpha)`, with `scale` cast to x's dtype."""
    tensor = convert_to_tensor(x)
    factor = np.array(scale, tensor.dtype)
    return factor * elu(tensor, alpha)


def gelu(x, approximate=True):
    """Gaussian error linear unit, following JAX's implementation.

    With `approximate=True`, uses the tanh form
    `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))`.
    Otherwise uses the erf form `x * (1 + erf(x / sqrt(2))) / 2`.
    Intermediate constants are cast to the dtype of `x`.
    """
    x = convert_to_tensor(x)
    dtype = x.dtype
    if not approximate:
        root_two = np.sqrt(2).astype(dtype)
        erf_term = (scipy.special.erf(x / root_two) + 1).astype(dtype)
        return x * erf_term / np.array(2, dtype)

    coeff = np.sqrt(2 / np.pi).astype(dtype)
    cubic = (x**3).astype(dtype)
    inner = coeff * (x + np.array(0.044715, dtype) * cubic)
    cdf = np.array(0.5, dtype) * (np.array(1.0, dtype) + np.tanh(inner))
    return x * cdf


def celu(x, alpha=1.0):
    """Continuously differentiable ELU.

    Computes `max(x, 0) + alpha * expm1(min(x, 0) / alpha)`.
    """
    tensor = convert_to_tensor(x)
    zero = np.array(0.0, dtype=tensor.dtype)
    alpha = np.array(alpha, tensor.dtype)
    positive_part = np.maximum(tensor, zero)
    negative_part = alpha * np.expm1(np.minimum(tensor, zero) / alpha)
    return positive_part + negative_part


def glu(x, axis=-1):
    """Gated linear unit.

    Splits `x` into two halves along `axis` and returns
    `first_half * sigmoid(second_half)`.

    Raises:
        ValueError: if the size of `axis` is odd.
    """
    x = convert_to_tensor(x)
    if x.shape[axis] % 2:
        raise ValueError(
            "axis size must be divisible by 2. "
            f"Received: x.shape={x.shape} with axis={axis}"
        )
    values, gates = np.split(x, 2, axis)
    return values * (1 / (1 + np.exp(-gates)))


def hard_tanh(x):
    """Hard tanh: clamp `x` to `[-1, 1]` and return it in x's dtype."""
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    clipped = np.clip(tensor, np.asarray(-1.0, dtype), np.asarray(1.0, dtype))
    return np.array(clipped, dtype=dtype)


def hard_shrink(x, threshold=0.5):
    """Hard shrinkage: keep `x` where `|x| > threshold`, otherwise 0."""
    tensor = convert_to_tensor(x)
    dtype = tensor.dtype
    keep = np.abs(tensor) > np.asarray(threshold, dtype)
    shrunk = np.where(keep, tensor, np.array(0.0, dtype=dtype))
    return np.array(shrunk, dtype=dtype)


def threshold(x, threshold, default_value):
    """Keep `x` where it exceeds `threshold`; elsewhere use `default_value`."""
    tensor = convert_to_tensor(x)
    fill = np.array(default_value, dtype=tensor.dtype)
    return np.where(tensor > threshold, tensor, fill)


def softmax(x, axis=None):
    """Softmax over `axis` (all elements when `axis` is None).

    The maximum is subtracted before exponentiating for numerical stability.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    numerator = np.exp(shifted)
    denominator = np.sum(numerator, axis=axis, keepdims=True)
    return numerator / denominator


def log_softmax(x, axis=None):
    """Log-softmax over `axis`: `(x - max) - log(sum(exp(x - max)))`."""
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    log_normalizer = np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    return shifted - log_normalizer


def sparsemax(logits, axis=-1):
    """Sparsemax: project `logits` onto the probability simplex along `axis`.

    Args:
        logits: input tensor.
        axis: axis along which to normalize.

    Returns:
        Tensor with the shape of `logits`. Entries along `axis` are
        non-negative and sum to 1. For floating-point `logits` the result
        keeps their dtype.
    """
    logits = convert_to_tensor(logits)
    dtype = logits.dtype
    # Sort logits along the specified axis in descending order.
    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
    logits_cumsum = np.cumsum(logits_sorted, axis=axis)
    # 1-based ranks, reshaped so they broadcast along `axis`. They are cast
    # to the logits dtype because an int64 `arange` would silently promote
    # float32/float16 arithmetic to float64.
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1
    r = np.arange(1, logits.shape[axis] + 1).astype(dtype).reshape(r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Support size k and threshold tau. k is cast for the same reason as r.
    k = np.sum(support, axis=axis, keepdims=True).astype(dtype)
    logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)
    tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k
    return np.maximum(logits - tau, 0.0)


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
    data_format="channels_last",
    include_batch_and_channels=True,
):
    """Expand a window or stride spec into a per-dimension tuple.

    Args:
        x: an integer, repeated for every spatial dimension, or a sequence
            of `num_spatial_dims` integers (tuple, list, ...).
        num_spatial_dims: number of spatial dimensions.
        data_format: "channels_last" or "channels_first". Decides where the
            batch and channel 1s are inserted.
        include_batch_and_channels: when True, add 1s for the batch and
            channel axes, giving `num_spatial_dims + 2` entries.

    Returns:
        A tuple of integers.
    """
    if isinstance(x, (int, np.integer)):
        x = (x,) * num_spatial_dims
    else:
        # Normalize lists and other sequences to a tuple. The tuple
        # concatenation below raises TypeError for a list.
        x = tuple(x)
    if not include_batch_and_channels:
        return x
    if data_format == "channels_last":
        return (1,) + x + (1,)
    return (1, 1) + x


def _pool(
    inputs,
    initial_value,
    reduce_fn,
    pool_size,
    strides=None,
    padding="valid",
):
    """Reduce `inputs` over sliding windows using `lax.reduce_window`.

    Args:
        inputs: input array of rank `N + 2`.
        initial_value: initial value for the reduction.
        reduce_fn: binary reduction of the form `(T, T) -> T`.
        pool_size: sequence of window sizes, one per input dimension.
        strides: sequence of window strides, passed through to
            `lax.reduce_window`.
        padding: "same" or "valid" (lower case).

    Returns:
        NumPy array with the reduction of each window.

    Raises:
        ValueError: if `padding` is not "same" or "valid".
    """
    if padding in ("same", "valid"):
        window_padding = padding.upper()
    else:
        raise ValueError(
            f"Invalid padding '{padding}', must be 'same' or 'valid'."
        )
    reduced = lax.reduce_window(
        inputs,
        initial_value,
        reduce_fn,
        pool_size,
        strides,
        window_padding,
    )
    return np.array(reduced)


def max_pool(
    inputs,
    pool_size,
    strides=None,
    padding="valid",
    data_format=None,
):
    """Max pooling over the spatial dimensions of `inputs`.

    Args:
        inputs: array of rank `N + 2` (batch, channels and N spatial dims).
        pool_size: int or sequence of N ints giving the window size.
        strides: int, sequence of N ints, or None. None means `pool_size`.
        padding: "valid" or "same".
        data_format: passed through `backend.standardize_data_format`.

    Returns:
        NumPy array with the per-window maxima.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the default before expanding. Assigning the already-expanded
    # pool_size would make the second expansion add batch/channel 1s again,
    # giving N + 4 strides.
    if strides is None:
        strides = pool_size
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )
    return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding)


def average_pool(
    inputs,
    pool_size,
    strides,
    padding,
    data_format=None,
):
    """Average pooling over the spatial dimensions of `inputs`.

    Args:
        inputs: array of rank `N + 2` (batch, channels and N spatial dims).
        pool_size: int or sequence of N ints giving the window size.
        strides: int, sequence of N ints, or None. None means `pool_size`.
        padding: "valid" or "same". With "same", each output is divided by
            the number of non-padded input entries in its window.
        data_format: passed through `backend.standardize_data_format`.

    Returns:
        NumPy array with the per-window averages.
    """
    data_format = backend.standardize_data_format(data_format)
    num_spatial_dims = inputs.ndim - 2
    # Resolve the default before expanding. Assigning the already-expanded
    # pool_size would make the second expansion add batch/channel 1s again.
    if strides is None:
        strides = pool_size
    pool_size = _convert_to_spatial_operand(
        pool_size, num_spatial_dims, data_format
    )
    strides = _convert_to_spatial_operand(
        strides, num_spatial_dims, data_format
    )

    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)
    if padding == "valid":
        # Every window is full, so divide by the window volume. Dividing by
        # a Python int keeps pooled's dtype. A NumPy int64 scalar would
        # promote float32 to float64 under NumPy 2 (NEP 50) rules.
        return pooled / int(np.prod(pool_size))
    # Count the valid entries in each window by pooling an array of ones,
    # then divide by that count. This assumes two arrays of the same shape
    # are padded the same way. Axes that are not pooled keep size 1, which
    # avoids needless broadcasting.
    shape = [(a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)]
    window_counts = _pool(
        np.ones(shape, inputs.dtype),
        0.0,
        lax.add,
        pool_size,
        strides,
        padding,
    )
    return pooled / window_counts


def _convert_to_lax_conv_dimension_numbers(
    num_spatial_dims,
    data_format="channels_last",
    transpose=False,
):
    """Build `lax.ConvDimensionNumbers` for the given layout.

    The input and output specs list the axis positions of (batch, channel,
    *spatial) in the actual array layout chosen by `data_format`. The kernel
    spec lists (out-feature, in-feature, *spatial) positions. Without
    `transpose`, the last kernel axis is the output feature. With
    `transpose`, the last two kernel axes swap roles.
    """
    rank = num_spatial_dims + 2
    if data_format == "channels_last":
        io_spec = (0, rank - 1) + tuple(range(1, rank - 1))
    else:
        io_spec = (0, 1) + tuple(range(2, rank))

    second_last, last = rank - 2, rank - 1
    if transpose:
        kernel_spec = (second_last, last)
    else:
        kernel_spec = (last, second_last)
    kernel_spec += tuple(range(rank - 2))

    return lax.ConvDimensionNumbers(
        lhs_spec=io_spec, rhs_spec=kernel_spec, out_spec=io_spec
    )


def conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """N-D convolution via `jax.lax.conv_general_dilated`, returned as NumPy.

    Grouped convolution is inferred. The input channel count divided by the
    kernel's in-channel axis (`kernel.shape[-2]`) gives
    `feature_group_count`.

    Raises:
        ValueError: if the input channel count is not a multiple of the
            kernel's in-channels.
    """
    data_format = backend.standardize_data_format(data_format)
    spatial_rank = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        spatial_rank, data_format, transpose=False
    )
    window_strides = _convert_to_spatial_operand(
        strides,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    rhs_dilation = _convert_to_spatial_operand(
        dilation_rate,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    channel_axis = -1 if data_format == "channels_last" else 1
    channels = inputs.shape[channel_axis]
    kernel_in_channels = kernel.shape[-2]
    if channels % kernel_in_channels > 0:
        raise ValueError(
            "The number of input channels must be evenly divisible by "
            f"kernel's in_channels. Received input channels {channels} and "
            f"kernel in_channels {kernel_in_channels}. "
        )
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    result = jax.lax.conv_general_dilated(
        inputs,
        kernel_array,
        window_strides,
        padding,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=channels // kernel_in_channels,
    )
    return np.array(result)


def depthwise_conv(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Depthwise convolution: one feature group per input channel.

    The kernel is reshaped to `(*kernel.shape[:-2], 1, C * kernel.shape[-1])`,
    where C is the input channel count. It is then run through
    `jax.lax.conv_general_dilated` with `feature_group_count=C`. Presumably
    the kernel is `(*spatial, C, multiplier)`; TODO confirm against callers.
    """
    data_format = backend.standardize_data_format(data_format)
    spatial_rank = inputs.ndim - 2
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        spatial_rank, data_format, transpose=False
    )
    window_strides = _convert_to_spatial_operand(
        strides,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    rhs_dilation = _convert_to_spatial_operand(
        dilation_rate,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    channel_axis = -1 if data_format == "channels_last" else 1
    groups = inputs.shape[channel_axis]
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    grouped_kernel = np.reshape(
        kernel_array,
        kernel.shape[:-2] + (1, groups * kernel.shape[-1]),
    )
    result = jax.lax.conv_general_dilated(
        inputs,
        grouped_kernel,
        window_strides,
        padding,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        feature_group_count=groups,
    )
    return np.array(result)


def separable_conv(
    inputs,
    depthwise_kernel,
    pointwise_kernel,
    strides=1,
    padding="valid",
    data_format=None,
    dilation_rate=1,
):
    """Depthwise convolution followed by a pointwise convolution.

    The depthwise stage uses the given strides, padding and dilation. The
    pointwise stage runs with stride 1 and "valid" padding.
    """
    data_format = backend.standardize_data_format(data_format)
    per_channel = depthwise_conv(
        inputs,
        depthwise_kernel,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilation_rate=dilation_rate,
    )
    return conv(
        per_channel,
        pointwise_kernel,
        strides=1,
        padding="valid",
        data_format=data_format,
        dilation_rate=dilation_rate,
    )


def conv_transpose(
    inputs,
    kernel,
    strides=1,
    padding="valid",
    output_padding=None,
    data_format=None,
    dilation_rate=1,
):
    """Transposed convolution via `jax.lax.conv_transpose`.

    The explicit padding for lax is computed from the Keras-style arguments
    with `compute_conv_transpose_padding_args_for_jax`. It is computed before
    strides and dilation are expanded into per-dimension tuples.
    """
    data_format = backend.standardize_data_format(data_format)
    spatial_rank = inputs.ndim - 2
    explicit_padding = compute_conv_transpose_padding_args_for_jax(
        input_shape=inputs.shape,
        kernel_shape=kernel.shape,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation_rate=dilation_rate,
    )
    dimension_numbers = _convert_to_lax_conv_dimension_numbers(
        spatial_rank, data_format, transpose=False
    )
    window_strides = _convert_to_spatial_operand(
        strides,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    rhs_dilation = _convert_to_spatial_operand(
        dilation_rate,
        spatial_rank,
        data_format,
        include_batch_and_channels=False,
    )
    kernel_array = kernel if is_tensor(kernel) else kernel.numpy()
    result = jax.lax.conv_transpose(
        inputs,
        kernel_array,
        window_strides,
        padding=explicit_padding,
        rhs_dilation=rhs_dilation,
        dimension_numbers=dimension_numbers,
        transpose_kernel=True,
    )
    return np.array(result)


def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """One-hot encode integer `x` along a new axis of size `num_classes`.

    Indices outside `[0, num_classes)` produce an all-zero vector rather than
    raising, consistent with `jax.nn.one_hot`.

    Args:
        x: integer tensor of any shape.
        num_classes: size of the new axis. If falsy, `max(x) + 1` is used.
        axis: position of the new axis in the output.
        dtype: output dtype.
        sparse: must be False. Sparse output is not supported here.

    Returns:
        Dense array of shape `x.shape` with the new axis inserted at `axis`.

    Raises:
        ValueError: if `sparse=True`.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    x = convert_to_tensor(x)
    input_shape = x.shape

    x = x.reshape(-1)
    if not num_classes:
        num_classes = np.max(x) + 1

    batch_size = x.shape[0]
    categorical = np.zeros((batch_size, num_classes), dtype=dtype)
    # Scatter only in-range indices. An index >= num_classes would
    # otherwise raise IndexError.
    valid_indices = (x >= 0) & (x < num_classes)
    categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1

    # First, reshape the array with the extra dimension at the end.
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)

    # Then, move this new dimension to the requested axis.
    if axis != -1:
        categorical = np.moveaxis(categorical, -1, axis)

    return categorical


def multi_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
    """Multi-hot encode: one-hot each index, then max-reduce.

    The reduction runs over axis 1 for inputs of rank > 1 and axis 0
    otherwise.

    Raises:
        ValueError: if `sparse=True`.
    """
    if sparse:
        raise ValueError("Unsupported value `sparse=True` with numpy backend")
    tensor = convert_to_tensor(x)
    reduction_axis = 1 if len(tensor.shape) > 1 else 0
    encoded = one_hot(
        cast(tensor, "int32"), num_classes, axis=axis, dtype=dtype
    )
    return np.max(encoded, axis=reduction_axis)


def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Categorical cross-entropy, reduced over `axis`.

    When `from_logits` is False, `output` is renormalized to sum to 1 along
    `axis` and clipped to `[epsilon, 1 - epsilon]` before the log.

    Raises:
        ValueError: if the shapes differ or the inputs are rank 0.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if len(target.shape) < 1:
        raise ValueError(
            "Arguments `target` and `output` must be at least rank 1. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        eps = backend.epsilon()
        probs = output / np.sum(output, axis, keepdims=True)
        log_prob = np.log(np.clip(probs, eps, 1.0 - eps))
    return -np.sum(target * log_prob, axis=axis)


def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Cross-entropy between integer class labels and class scores.

    When `target` and `output` have the same rank, a trailing size-1 axis
    on `target` is squeezed away. The labels are then one-hot encoded along
    `axis`.

    Raises:
        ValueError: if `output` is rank 0, or if `target.shape` differs from
            `output.shape[:-1]`.
    """
    target = np.array(target, dtype="int32")
    output = np.array(output)
    same_rank = len(target.shape) == len(output.shape)
    if same_rank and target.shape[-1] == 1:
        target = np.squeeze(target, axis=-1)

    if len(output.shape) < 1:
        raise ValueError(
            "Argument `output` must be at least rank 1. "
            "Received: "
            f"output.shape={output.shape}"
        )
    if target.shape != output.shape[:-1]:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape "
            "up until the last dimension: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )
    if from_logits:
        log_prob = log_softmax(output, axis=axis)
    else:
        eps = backend.epsilon()
        probs = output / np.sum(output, axis, keepdims=True)
        log_prob = np.log(np.clip(probs, eps, 1.0 - eps))
    dense_target = one_hot(target, log_prob.shape[axis], axis=axis)
    return -np.sum(dense_target * log_prob, axis=axis)


def binary_crossentropy(target, output, from_logits=False):
    """Elementwise binary cross-entropy (no reduction).

    Logits are passed through `sigmoid` when `from_logits` is True. The
    probabilities are then clipped to `[epsilon, 1 - epsilon]`.

    Raises:
        ValueError: if `target` and `output` shapes differ.
    """
    target = np.array(target)
    output = np.array(output)

    if target.shape != output.shape:
        raise ValueError(
            "Arguments `target` and `output` must have the same shape. "
            "Received: "
            f"target.shape={target.shape}, output.shape={output.shape}"
        )

    if from_logits:
        output = sigmoid(output)

    eps = backend.epsilon()
    output = np.clip(output, eps, 1.0 - eps)
    # In-place accumulation keeps the dtype of the first term.
    log_likelihood = target * np.log(output)
    log_likelihood += (1.0 - target) * np.log(1.0 - output)
    return -log_likelihood


def moments(x, axes, keepdims=False, synchronized=False):
    """Mean and variance of `x` over `axes`.

    float16 has too little range for these statistics, so float16 inputs
    are computed in float32. The results are then clipped to the float16
    range and cast back. Variance is computed as `E[x^2] - E[x]^2`, which is
    fast but less numerically stable than a two-pass formula.

    Raises:
        NotImplementedError: if `synchronized` is True.
    """
    if synchronized:
        raise NotImplementedError(
            "Argument synchronized=True is not supported with NumPy."
        )
    if isinstance(axes, list):
        axes = tuple(axes)

    original_dtype = backend.standardize_dtype(x.dtype)
    upcast = original_dtype == "float16"
    if upcast:
        x = cast(x, "float32")

    mean = np.mean(x, axes, keepdims=True)
    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)

    if not keepdims:
        mean = np.squeeze(mean, axes)
        variance = np.squeeze(variance, axes)
    if upcast:
        # Clip first so the cast to float16 cannot overflow.
        f16 = np.finfo(np.float16)
        mean = cast(np.clip(mean, f16.min, f16.max), original_dtype)
        variance = cast(np.clip(variance, f16.min, f16.max), original_dtype)
    return mean, variance


def batch_normalization(
    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
):
    """Normalize `x` with the given statistics along `axis`.

    Computes `(x - mean) / sqrt(variance + epsilon) * scale + offset`,
    rearranged as `x * inv + (offset - mean * inv)`. `mean`, `variance`,
    `scale` and `offset` are reshaped to broadcast along `axis`.
    """
    broadcast_shape = [1] * len(x.shape)
    broadcast_shape[axis] = mean.shape[0]
    mean = np.reshape(mean, broadcast_shape)
    variance = np.reshape(variance, broadcast_shape)

    inv = 1.0 / np.sqrt(variance + epsilon)
    if scale is not None:
        inv = inv * np.reshape(scale, broadcast_shape)

    shift = -mean * inv
    if offset is not None:
        shift = shift + np.reshape(offset, broadcast_shape)

    return x * inv + shift


def ctc_loss(target, output, target_length, output_length, mask_index=0):
    """Compute the CTC loss of each sequence in a batch.

    NumPy port of optax's `ctc_loss_with_forward_probs`. It runs a forward
    recursion over time in log space on two lattices: `phi` (states on the
    blank class `mask_index`) and `emit` (states on a label).

    Args:
        target: integer labels, unpacked as
            `(batch_size, max_label_length)`.
        output: unnormalized class scores, unpacked as
            `(batch_size, max_input_length, num_classes)`. A log-softmax is
            applied over the last axis.
        target_length: number of valid labels per batch element.
        output_length: number of valid time steps per batch element.
        mask_index: class index used as the blank label.

    Returns:
        Array of shape `(batch_size,)` holding the negated forward log-score
        of each target sequence.
    """
    # Ref: https://github.com/google-deepmind/optax
    # optax.ctc_loss_with_forward_probs
    target = convert_to_tensor(target, dtype="int32")
    output = convert_to_tensor(output)
    target_length = convert_to_tensor(target_length, "int32")
    output_length = convert_to_tensor(output_length, "int32")
    batch_size, max_input_length, num_classes = output.shape
    batch_size, max_label_length = target.shape
    # Finite stand-in for log(0), used to disable transitions.
    log_epsilon = -1e5

    # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss`
    dtype = backend.result_type(output.dtype, "float32")
    output = output.astype(dtype)

    def _lengths_to_paddings(lengths, max_length):
        # Boolean mask, True at positions >= the per-example length.
        indices = np.arange(max_length).reshape(
            (1,) * lengths.ndim + (max_length,)
        )
        lengths = np.expand_dims(lengths, axis=-1)
        elem_valid = indices < lengths
        return np.logical_not(elem_valid)

    # Padding masks as 1.0 (padded) / 0.0 (valid) in the compute dtype.
    target_paddings = _lengths_to_paddings(target_length, max_label_length)
    output_paddings = _lengths_to_paddings(output_length, max_input_length)
    target_paddings = target_paddings.astype(output.dtype)
    output_paddings = output_paddings.astype(output.dtype)

    logprobs = log_softmax(output, axis=-1)
    label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype(
        np.int32
    )

    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].
    repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32)
    repeat = np.pad(repeat, ((0, 0), (0, 1)))

    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]  # [B, T, 1]
    logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]

    # Log-prob of emitting target label n at time t, gathered via one-hot.
    _one_hot = one_hot(target, num_classes=num_classes)  # [B, N, K]
    logprobs_emit = np.einsum("btk,bnk->btn", logprobs, _one_hot)
    logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]

    # phi init is [B, N+1] with only position 0 reachable; emit init is
    # [B, N] with nothing reachable.
    logalpha_phi_init = (
        np.ones((batch_size, max_label_length + 1), dtype=output.dtype)
        * log_epsilon
    )
    logalpha_phi_init[:, 0] = 0.0
    logalpha_emit_init = (
        np.ones((batch_size, max_label_length), dtype=output.dtype)
        * log_epsilon
    )

    def update_phi_score(phi, added_score):
        # Update `phi[:, 1:]` by adding `added_score` in log space.
        return np.concatenate(
            [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1
        )

    def loop_body(prev, x):
        # One time step of the forward recursion. `x` holds this step's
        # (emit log-probs, blank log-probs, padding flag).
        prev_phi, prev_emit = prev
        # emit-to-phi epsilon transition, except if the next label is repetition
        prev_phi_orig = prev_phi
        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)

        logprob_emit, logprob_phi, pad = x

        # phi-to-emit transition
        next_emit = np.logaddexp(
            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit
        )
        # self-loop transition
        next_phi = prev_phi + logprob_phi
        # emit-to-phi blank transition only when the next label is repetition
        next_phi = update_phi_score(
            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)
        )

        # On padded time steps, carry the previous state through unchanged.
        pad = pad.reshape((batch_size, 1))
        next_emit = pad * prev_emit + (1.0 - pad) * next_emit
        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi

        return (next_phi, next_emit), (next_phi, next_emit)

    def np_scan(f, init, xs):
        # Minimal NumPy stand-in for `jax.lax.scan`: iterates over the
        # leading axis of every array in `xs` and stacks each per-step output.
        # NOTE(review): `ys[0]` raises IndexError if there are zero steps.
        carry = init
        ys = []
        for x in zip(*xs):
            carry, y = f(carry, x)
            ys.append(y)
        result = []
        for i in range(len(ys[0])):
            result.append(np.stack([y[i] for y in ys]))
        return carry, result

    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))
    _, (logalpha_phi, logalpha_emit) = np_scan(
        loop_body, (logalpha_phi_init, logalpha_emit_init), xs
    )

    # last row needs to be updated with the last epsilon transition
    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])
    logalpha_phi[-1] = logalpha_phi_last

    # extract per_seq_loss: select the phi state at each label length.
    # [B, N+1]
    _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1)
    per_seq_loss = -np.einsum("bn,bn->b", logalpha_phi_last, _one_hot)
    return per_seq_loss


def _ctc_greedy_decode(
    inputs,
    sequence_lengths,
    merge_repeated=True,
    mask_index=None,
):
    """Greedy (best-path) CTC decoding.

    At each step the highest-scoring class is picked. Steps past a
    sequence's length become blank. Optionally, consecutive repeats are
    merged. Blanks are then dropped (as -1) and the remaining labels are
    compacted to the front.

    Returns:
        `(indices, scores)`. `indices` has a leading axis of size 1 and -1
        padding. `scores` is the negated sum of per-step maxima, with shape
        `(batch_size, 1)`.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths, dtype="int32")
    batch_size, max_length, num_classes = inputs.shape

    if mask_index is None:
        mask_index = num_classes - 1

    best_classes = np.argmax(inputs, axis=-1).astype("int32")
    best_scores = np.max(inputs, axis=-1)

    steps = np.arange(max_length)[None, :]
    beyond_length = steps >= sequence_lengths[:, None]
    best_classes = np.where(beyond_length, mask_index, best_classes)
    best_scores = np.where(beyond_length, 0.0, best_scores)

    if merge_repeated:
        is_repeat = best_classes[:, 1:] == best_classes[:, :-1]
        is_repeat = np.pad(is_repeat, ((0, 0), (1, 0)))
        best_classes = np.where(is_repeat, mask_index, best_classes)

    # Blank labels are reported as -1.
    is_blank = best_classes == mask_index
    best_classes = np.where(is_blank, -1, best_classes)

    # Stable compaction: blanks get the largest sort key so they move last.
    positions = np.tile(np.expand_dims(np.arange(max_length), axis=0), (batch_size, 1))
    positions = np.where(is_blank, max_length, positions)
    order = np.argsort(positions, axis=-1)
    decoded = np.take_along_axis(best_classes, order, axis=-1)

    total_scores = -np.sum(best_scores, axis=1)[:, None]
    return np.expand_dims(decoded, axis=0), total_scores


def _ctc_beam_search_decode(
    inputs,
    sequence_lengths,
    beam_width=100,
    top_paths=1,
    mask_index=None,
):
    """Beam-search CTC decoding, run one batch element at a time.

    Args:
        inputs: unnormalized scores, unpacked as
            `(batch_size, max_seq_len, num_classes)`. A log-softmax is
            applied over the class axis.
        sequence_lengths: valid length per batch element. Time steps at or
            beyond it are skipped.
        beam_width: number of prefixes kept after each step.
        top_paths: number of best paths returned per batch element.
        mask_index: blank class index. Defaults to the last class.

    Returns:
        `(paths, scores)`. `paths` has shape
        `(top_paths, batch_size, max_seq_len)`, with -1 marking unused
        positions. `scores` holds log-space path scores with shape
        `(batch_size, top_paths)`.
    """
    inputs = convert_to_tensor(inputs)
    sequence_lengths = convert_to_tensor(sequence_lengths)

    batch_size, max_seq_len, num_classes = inputs.shape
    inputs = log_softmax(inputs, axis=-1)
    # True for time steps past each sequence's length.
    seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]

    if mask_index is None:
        mask_index = num_classes - 1

    # This is a workaround for the fact that np.argsort does not support
    # the order parameter which is used to break ties when scores are equal.
    # For compatibility with the tensorflow implementation, we flip the inputs
    # and the mask_index, and then flip the classes back to the correct indices
    inputs = np.flip(inputs, axis=2)
    mask_index = num_classes - mask_index - 1

    # Marker for "no label" inside a path.
    _pad = -1

    # Initial beams from the best classes at t=0. Buffers hold 2 * beam_width
    # rows: emit-ending and blank-ending copies of each prefix.
    init_paths = np.full(
        (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32
    )

    num_init_paths = np.min(np.array([num_classes, beam_width]))
    max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:]
    init_classes = np.where(max_classes == mask_index, _pad, max_classes)
    init_paths[:, :num_init_paths, 0] = init_classes

    init_scores = np.full(
        (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype
    )
    init_scores[:, :num_init_paths] = np.take_along_axis(
        inputs[:, 0], max_classes, axis=1
    )
    init_masked = init_paths[:, :, 0] == _pad

    def _extend_paths(paths, scores, masked, x):
        # Extend every beam with every class: a repeat of the path's last
        # label (not separated by a blank) and the blank itself both append
        # `_pad`; each candidate score adds that class's log-prob from `x`.
        paths = np.repeat(paths, num_classes, axis=0)
        scores = np.repeat(scores, num_classes)
        masked = np.repeat(masked, num_classes)

        path_tail_index = np.argmax(paths == _pad, axis=1)
        paths_arange = np.arange(2 * beam_width * num_classes)
        path_tails = paths[paths_arange, path_tail_index - 1]
        path_tails = np.where(path_tail_index == 0, _pad, path_tails)

        classes = np.arange(num_classes)
        classes[mask_index] = _pad
        classes = np.tile(classes, 2 * beam_width)

        prev_masked = masked
        masked = classes == _pad

        masked_repeat = ~prev_masked & (path_tails == classes)
        classes = np.where(masked_repeat, _pad, classes)
        paths[paths_arange, path_tail_index] = classes

        x = np.tile(x, 2 * beam_width)
        scores = scores + x

        return paths, scores, masked

    def _merge_scores(unique_inverse, scores):
        # Log-sum-exp of scores that map to the same unique path, shifted
        # by the global max for numerical stability.
        scores_max = np.max(scores)
        scores_exp = np.exp(scores - scores_max)
        scores = np.zeros_like(scores)
        for i, u in enumerate(unique_inverse):
            scores[u] += scores_exp[i]
        scores = np.log(scores) + scores_max
        return scores

    def _prune_paths(paths, scores, masked):
        # Merge duplicate prefixes and keep the top `beam_width`. They are
        # returned twice: once as emit-ending rows, once as blank-ending rows.
        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        # Some NumPy versions return a 2-D inverse for axis=0; flatten it.
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)

        emit_scores = np.where(masked, -np.inf, scores)
        mask_scores = np.where(masked, scores, -np.inf)

        emit_scores = _merge_scores(unique_inverse, emit_scores)
        mask_scores = _merge_scores(unique_inverse, mask_scores)

        total_scores = np.logaddexp(emit_scores, mask_scores)
        top_indices = np.argsort(total_scores, kind="stable")[-beam_width:]

        paths = paths[top_indices]
        emit_scores = emit_scores[top_indices]
        mask_scores = mask_scores[top_indices]

        paths = np.tile(paths, (2, 1))
        scores = np.concatenate([emit_scores, mask_scores])
        masked = np.concatenate(
            [np.zeros(beam_width, bool), np.ones(beam_width, bool)]
        )

        return paths, scores, masked

    def _decode_step(paths, scores, masked, x):
        paths, scores, masked = _extend_paths(paths, scores, masked, x)
        paths, scores, masked = _prune_paths(paths, scores, masked)
        return paths, scores, masked

    def _step(prev, x):
        # Advance the beams only for time steps within the sequence length.
        paths, scores, masked = prev
        x, seqlen_mask = x
        if not seqlen_mask:
            paths, scores, masked = _decode_step(paths, scores, masked, x)
        return (paths, scores, masked), None

    def _decode_batch(
        init_paths, init_scores, init_masked, inputs, seqlen_mask
    ):
        # Decode a single batch element: scan over t >= 1 (t=0 seeded the
        # beams), merge duplicates, then take the `top_paths` best.
        def np_scan_only_carry(f, init, xs):
            carry = init
            for x in zip(*xs):
                carry, y = f(carry, x)
            return carry, None

        (paths, scores, masked), _ = np_scan_only_carry(
            _step,
            (init_paths, init_scores, init_masked),
            (inputs[1:], seqlen_mask[1:]),
        )

        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)
        pad_size = (2 * num_classes * beam_width) - len(paths)
        if pad_size > 0:
            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)
        paths = paths[: 2 * num_classes * beam_width]
        if len(unique_inverse.shape) >= 2:
            unique_inverse = np.squeeze(unique_inverse, axis=1)
        scores = _merge_scores(unique_inverse, scores)

        top_indices = np.argsort(scores)[-top_paths:][::-1]
        paths = paths[top_indices]
        scores = scores[top_indices]

        return paths, scores

    results = [
        _decode_batch(p, s, m, i, sm)
        for p, s, m, i, sm in zip(
            init_paths, init_scores, init_masked, inputs, seqlen_mask
        )
    ]
    paths = np.stack([r[0] for r in results])
    scores = np.stack([r[1] for r in results])

    # convert classes back to the correct indices
    paths = np.where(paths == _pad, _pad, num_classes - paths - 1)
    paths = np.transpose(paths, [1, 0, 2])
    return paths, scores


def ctc_decode(
    inputs,
    sequence_lengths,
    strategy="greedy",
    beam_width=100,
    top_paths=1,
    merge_repeated=True,
    mask_index=0,
):
    """Decode CTC model outputs using a greedy or beam-search strategy.

    Args:
        inputs: Model outputs to decode.
        sequence_lengths: Sequence lengths, forwarded to the decoder.
        strategy: Either `"greedy"` or `"beam_search"`.
        beam_width: Beam width; only used by `"beam_search"`.
        top_paths: Number of paths to return; only used by `"beam_search"`.
        merge_repeated: Only forwarded to the greedy decoder.
        mask_index: Index of the blank/mask class, forwarded to either
            decoder.

    Returns:
        Whatever the selected decoder (`_ctc_greedy_decode` or
        `_ctc_beam_search_decode`) returns.

    Raises:
        ValueError: If `strategy` is not one of the supported values.
    """
    inputs = convert_to_tensor(inputs)
    # Decode in a floating dtype that is at least as wide as float32.
    inputs = cast(inputs, backend.result_type(inputs.dtype, "float32"))

    if strategy == "greedy":
        return _ctc_greedy_decode(
            inputs,
            sequence_lengths,
            merge_repeated=merge_repeated,
            mask_index=mask_index,
        )
    if strategy == "beam_search":
        # NOTE: `merge_repeated` is not passed to the beam-search decoder.
        return _ctc_beam_search_decode(
            inputs,
            sequence_lengths,
            beam_width=beam_width,
            top_paths=top_paths,
            mask_index=mask_index,
        )
    raise ValueError(
        f"Invalid strategy {strategy}. Supported values are "
        "'greedy' and 'beam_search'."
    )


def psnr(x1, x2, max_val):
    """Compute the peak signal-to-noise ratio (PSNR) between two arrays.

    `PSNR = 20 * log10(max_val) - 10 * log10(mean((x1 - x2) ** 2))`.

    Args:
        x1: First input array (or array-like).
        x2: Second input array; must have the same shape as `x1`.
        max_val: Maximum possible value of the signal (for example `255`
            for 8-bit images, or `1.0` for images scaled to `[0, 1]`).

    Returns:
        A scalar PSNR value. It is `inf` when `x1` and `x2` are identical,
        because the mean squared error is then zero.

    Raises:
        ValueError: If `x1` and `x2` have different shapes.
    """
    # Convert like the other ops in this module so array-likes (lists,
    # scalars) are accepted before `.shape` / `.dtype` are read.
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    if x1.shape != x2.shape:
        raise ValueError(
            f"Input shapes {x1.shape} and {x2.shape} must "
            "match for PSNR calculation. "
        )

    # NumPy integer array arithmetic wraps around (e.g. uint8 `10 - 20` is
    # `246`) and `np.square` can overflow too; bool arrays do not support
    # `-` at all. Promote non-float inputs to float32 before computing the
    # error. Float inputs keep their own dtype.
    if "float" not in backend.standardize_dtype(x1.dtype):
        x1 = x1.astype("float32")
    if "float" not in backend.standardize_dtype(x2.dtype):
        x2 = x2.astype("float32")

    max_val = convert_to_tensor(max_val, dtype=x2.dtype)
    mse = np.mean(np.square(x1 - x2))
    return 20 * np.log10(max_val) - 10 * np.log10(mse)


def _get_large_negative(dtype):
    """Return a large, finite negative scalar of `dtype` for masking logits.

    The base magnitude is 65500.0 for float16 and 3.38953e38 for every other
    dtype (roughly the largest finite bfloat16 value, which float32/float64
    can also represent). It is scaled by -0.7. Presumably the scaling leaves
    headroom so later arithmetic does not overflow to -inf; confirm if this
    matters.
    """
    dtype_name = backend.standardize_dtype(dtype)
    largest = {"float16": 65500.0}.get(dtype_name, 3.38953e38)
    return np.asarray(-0.7 * largest, dtype=dtype_name)


def _apply_masks(logits, mask, is_causal):
    if mask is None and not is_causal:
        return logits

    combined_mask = np.ones_like(logits, dtype=np.bool_)
    if mask is not None:
        combined_mask = np.logical_and(combined_mask, mask)

    if is_causal:
        T, S = logits.shape[2], logits.shape[3]
        mask = np.tril(np.ones((T, S), dtype=np.bool_))
        mask = mask[None, None, :, :]
        combined_mask = np.logical_and(combined_mask, mask)

    padded_logits = np.where(
        combined_mask, logits, _get_large_negative(logits.dtype)
    )
    return padded_logits


def _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):
    """Compute scaled dot-product attention with explicit NumPy einsums.

    Args:
        query: Array laid out as `BTNH`.
        key: Array laid out as `BSNH`. Its dtype is the output dtype.
        value: Array laid out as `BSNH`.
        bias: Optional additive bias, broadcast against the `BNTS` logits.
        mask: Optional mask, broadcast against the logits (see
            `_apply_masks`).
        is_causal: Whether to apply a causal (lower-triangular) mask.
        scale: Multiplier applied to the raw query-key logits.

    Returns:
        The attention output laid out as `BTNH`, cast to `key.dtype`.
    """
    out_dtype = key.dtype
    compute_dtype = np.promote_types(query.dtype, np.float32)

    # `np.einsum` doesn't support bfloat16, so such operands are computed in
    # float32. The attention probabilities are later cast to `key.dtype`,
    # so this one flag also covers the second einsum.
    upcast_bf16 = backend.standardize_dtype(out_dtype) == "bfloat16"
    if upcast_bf16:
        key = key.astype("float32")
        value = value.astype("float32")

    scores = np.einsum("BTNH,BSNH->BNTS", query, key).astype(compute_dtype)
    scores = scores * np.array(scale, dtype=scores.dtype)
    if bias is not None:
        scores = (scores + bias).astype(scores.dtype)
    scores = _apply_masks(scores, mask, is_causal)

    # The softmax itself always runs in fp32.
    weights = softmax(scores.astype(np.float32), axis=-1).astype(out_dtype)
    if upcast_bf16:
        weights = weights.astype("float32")
    return np.einsum("BNTS,BSNH->BTNH", weights, value).astype(out_dtype)


def dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    scale=None,
    is_causal=False,
    flash_attention=None,
    attn_logits_soft_cap=None,
):
    """Scaled dot-product attention for the NumPy backend.

    Uses the `jax.nn.dot_product_attention` layout letters: `B` batch,
    `T` query length, `S` key length, `N` heads, `H` head dim.

    Args:
        query: 4D array laid out as `BTNH`.
        key: 4D array laid out as `BSNH`.
        value: 4D array laid out as `BSNH`.
        bias: Optional additive bias broadcast against the `BNTS` logits.
        mask: Optional mask broadcast against the `BNTS` logits; falsy
            positions are masked out.
        scale: Logit scale. Defaults to `1 / sqrt(key.shape[-1])`.
        is_causal: Whether to apply a causal mask.
        flash_attention: Must be falsy; flash attention is unavailable here.
        attn_logits_soft_cap: Accepted for API compatibility; this backend
            ignores it.

    Returns:
        The attention output laid out as `BTNH`.

    Raises:
        ValueError: If `flash_attention` is requested, or if any of
            `query`, `key`, `value` is not 4D.
    """
    if flash_attention:
        raise ValueError("Flash attention is not supported in numpy backend.")

    # Ref: jax.nn.dot_product_attention
    # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828
    # Not support `query_seq_lengths` and `key_value_seq_lengths` args
    query = convert_to_tensor(query)
    key = convert_to_tensor(key)
    value = convert_to_tensor(value)
    # Validate all three inputs. A non-4D `key`/`value` would otherwise fail
    # later with an opaque unpacking or einsum error.
    if any(len(t.shape) != 4 for t in (query, key, value)):
        raise ValueError(
            "`dot_product_attention` only supports 4D inputs. "
            f"Received: query.shape={query.shape}, key.shape={key.shape}, "
            f"value.shape={value.shape}."
        )
    if scale is None:
        scale = 1.0 / np.sqrt(key.shape[-1])
    return _dot_product_attention_xla(
        query, key, value, bias, mask, is_causal, scale
    )
