from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.attention.attention import Attention


@keras_export("keras.layers.AdditiveAttention")
class AdditiveAttention(Attention):
    """Additive attention layer, a.k.a. Bahdanau-style attention.

    Inputs are a list with 2 or 3 elements:
    1. A `query` tensor of shape `(batch_size, Tq, dim)`.
    2. A `value` tensor of shape `(batch_size, Tv, dim)`.
    3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none
        supplied, `value` will be used as `key`.

    The calculation follows the steps:
    1. Calculate attention scores using `query` and `key` with shape
        `(batch_size, Tq, Tv)` as a non-linear sum
        `scores = reduce_sum(tanh(query + key), axis=-1)`.
    2. Use scores to calculate a softmax distribution with shape
        `(batch_size, Tq, Tv)`.
    3. Use the softmax distribution to create a linear combination of `value`
        with shape `(batch_size, Tq, dim)`.

    Args:
        use_scale: If `True`, will create a scalar variable to scale the
            attention scores.
        dropout: Float between 0 and 1. Fraction of the units to drop for the
            attention scores. Defaults to `0.0`.

    Call arguments:
        inputs: List of the following tensors:
            - `query`: Query tensor of shape `(batch_size, Tq, dim)`.
            - `value`: Value tensor of shape `(batch_size, Tv, dim)`.
            - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. If
                not given, will use `value` for both `key` and `value`, which is
                the most common case.
        mask: List of the following tensors:
            - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.
                If given, the output will be zero at the positions where
                `mask==False`.
            - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.
                If given, will apply the mask such that values at positions
                 where `mask==False` do not contribute to the result.
        return_attention_scores: bool, it `True`, returns the attention scores
            (after masking and softmax) as an additional output argument.
        training: Python boolean indicating whether the layer should behave in
            training mode (adding dropout) or in inference mode (no dropout).
        use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds
            a mask such that position `i` cannot attend to positions `j > i`.
            This prevents the flow of information from the future towards the
            past. Defaults to `False`.

    Output:
        Attention outputs of shape `(batch_size, Tq, dim)`.
        (Optional) Attention scores after masking and softmax with shape
            `(batch_size, Tq, Tv)`.
    """

    def __init__(
        self,
        use_scale=True,
        dropout=0.0,
        **kwargs,
    ):
        super().__init__(use_scale=use_scale, dropout=dropout, **kwargs)

    def build(self, input_shape):
        self._validate_inputs(input_shape)
        dim = input_shape[0][-1]
        self.scale = None
        if self.use_scale:
            self.scale = self.add_weight(
                name="scale",
                shape=[dim],
                initializer="glorot_uniform",
                dtype=self.dtype,
                trainable=True,
            )

    def _calculate_scores(self, query, key):
        """Calculates attention scores as a nonlinear sum of query and key.

        Args:
            query: Query tensor of shape `(batch_size, Tq, dim)`.
            key: Key tensor of shape `(batch_size, Tv, dim)`.

        Returns:
            Tensor of shape `(batch_size, Tq, Tv)`.
        """
        # Reshape tensors to enable broadcasting.
        # Reshape into [batch_size, Tq, 1, dim].
        q_reshaped = ops.expand_dims(query, axis=-2)
        # Reshape into [batch_size, 1, Tv, dim].
        k_reshaped = ops.expand_dims(key, axis=-3)
        scale = self.scale if self.use_scale else 1.0
        return ops.sum(scale * ops.tanh(q_reshaped + k_reshaped), axis=-1)

    def get_config(self):
        base_config = super().get_config()
        del base_config["score_mode"]
        return base_config
