import tensorflow as tf

from keras.src.backend import config
from keras.src.backend import standardize_dtype
from keras.src.backend.common import dtypes
from keras.src.backend.tensorflow.core import cast
from keras.src.backend.tensorflow.core import convert_to_tensor


def segment_sum(data, segment_ids, num_segments=None, sorted=False):
    """Sum the entries of `data` that share the same segment id.

    Args:
        data: Tensor to reduce; `segment_ids` indexes its leading entries.
        segment_ids: Integer tensor of segment indices.
        num_segments: Optional number of output segments. Only allowed when
            `sorted=False`. If `None`, it is inferred as
            `max(segment_ids) + 1`, which is also the output size of the
            sorted path (`tf.math.segment_sum`).
        sorted: If `True`, `segment_ids` is assumed sorted and the sorted TF
            kernel is used; `num_segments` must then be `None`.

    Returns:
        Tensor of per-segment sums.

    Raises:
        ValueError: If `sorted=True` and `num_segments` is given.
    """
    if sorted:
        if num_segments is not None:
            raise ValueError(
                "Argument `num_segments` cannot be set when sorted is True "
                "when using the tensorflow backend. "
                f"Received: num_segments={num_segments}, sorted={sorted}."
            )
        return tf.math.segment_sum(data, segment_ids)
    if num_segments is None:
        # Infer from the largest id, not the number of distinct ids: with
        # non-contiguous ids (e.g. [0, 2]) the distinct count is too small
        # and the highest segments would be out of range. `tf.maximum`
        # keeps an empty `segment_ids` at 0 segments (the max of an empty
        # integer tensor is the dtype's minimum value).
        num_segments = tf.maximum(tf.reduce_max(segment_ids) + 1, 0)
    return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)


def segment_max(data, segment_ids, num_segments=None, sorted=False):
    """Take the maximum of the entries of `data` sharing a segment id.

    Args:
        data: Tensor to reduce; `segment_ids` indexes its leading entries.
        segment_ids: Integer tensor of segment indices.
        num_segments: Optional number of output segments. Only allowed when
            `sorted=False`. If `None`, it is inferred as
            `max(segment_ids) + 1`, which is also the output size of the
            sorted path (`tf.math.segment_max`).
        sorted: If `True`, `segment_ids` is assumed sorted and the sorted TF
            kernel is used; `num_segments` must then be `None`.

    Returns:
        Tensor of per-segment maxima.

    Raises:
        ValueError: If `sorted=True` and `num_segments` is given.
    """
    if sorted:
        if num_segments is not None:
            raise ValueError(
                "Argument `num_segments` cannot be set when sorted is True "
                "when using the tensorflow backend. "
                f"Received: num_segments={num_segments}, sorted={sorted}."
            )
        return tf.math.segment_max(data, segment_ids)
    if num_segments is None:
        # Infer from the largest id, not the number of distinct ids: with
        # non-contiguous ids (e.g. [0, 2]) the distinct count is too small
        # and the highest segments would be out of range. `tf.maximum`
        # keeps an empty `segment_ids` at 0 segments (the max of an empty
        # integer tensor is the dtype's minimum value).
        num_segments = tf.maximum(tf.reduce_max(segment_ids) + 1, 0)
    return tf.math.unsorted_segment_max(data, segment_ids, num_segments)


def top_k(x, k, sorted=True):
    """Return the `k` largest entries of `x` via `tf.math.top_k`.

    `sorted` is forwarded unchanged; the TF result (values and indices) is
    returned as-is.
    """
    result = tf.math.top_k(x, k=k, sorted=sorted)
    return result


def in_top_k(targets, predictions, k):
    """Check whether each target is among the top-`k` predictions.

    Delegates directly to `tf.math.in_top_k`.
    """
    membership = tf.math.in_top_k(targets=targets, predictions=predictions, k=k)
    return membership


def logsumexp(x, axis=None, keepdims=False):
    """Compute `log(sum(exp(x)))` over `axis` with `tf.math.reduce_logsumexp`.

    `axis=None` reduces over all axes; `keepdims` is forwarded unchanged.
    """
    reduced = tf.math.reduce_logsumexp(
        input_tensor=x, axis=axis, keepdims=keepdims
    )
    return reduced


def qr(x, mode="reduced"):
    """QR decomposition of `x` via `tf.linalg.qr`.

    Args:
        x: Input matrix (or batch of matrices).
        mode: `"reduced"` (economy-size factors) or `"complete"` (full
            factors, i.e. `full_matrices=True`).

    Returns:
        The `(q, r)` result of `tf.linalg.qr`.

    Raises:
        ValueError: If `mode` is not one of the two supported values.
    """
    supported_modes = {"reduced", "complete"}
    if mode not in supported_modes:
        raise ValueError(
            "`mode` argument value not supported. "
            "Expected one of {'reduced', 'complete'}. "
            f"Received: mode={mode}"
        )
    if mode != "reduced":
        return tf.linalg.qr(x, full_matrices=True)
    return tf.linalg.qr(x)


def extract_sequences(x, sequence_length, sequence_stride):
    """Slice `x` into overlapping frames along its last axis.

    Frames have length `sequence_length` and start every `sequence_stride`
    samples. The end is not padded, so a trailing partial frame is dropped.
    """
    frame_kwargs = {
        "frame_length": sequence_length,
        "frame_step": sequence_stride,
        "axis": -1,
        "pad_end": False,
    }
    return tf.signal.frame(x, **frame_kwargs)


def _get_complex_tensor_from_tuple(x):
    """Build a complex tensor from a `(real, imag)` pair.

    Args:
        x: A tuple or list of exactly two tensor-likes holding the real and
            imaginary parts. Both must have the same shape and a floating
            dtype.

    Returns:
        The complex tensor `tf.dtypes.complex(real, imag)`.

    Raises:
        ValueError: If `x` is not a length-2 tuple/list, the two parts
            differ in shape, or either part is not floating point.
    """
    if not isinstance(x, (tuple, list)) or len(x) != 2:
        raise ValueError(
            "Input `x` should be a tuple of two tensors - real and imaginary. "
            f"Received: x={x}"
        )
    # `convert_to_tensor` does not support passing complex tensors. We separate
    # the input out into real and imaginary and convert them separately.
    real, imag = x
    real = convert_to_tensor(real)
    imag = convert_to_tensor(imag)
    # Both parts must line up element-wise.
    if real.shape != imag.shape:
        raise ValueError(
            "Input `x` should be a tuple of two tensors - real and imaginary. "
            "Both the real and imaginary parts should have the same shape. "
            f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}"
        )
    # `tf.dtypes.complex` only accepts floating-point parts.
    if not real.dtype.is_floating or not imag.dtype.is_floating:
        raise ValueError(
            "At least one tensor in input `x` is not of type float. "
            f"Received: x={x}."
        )
    return tf.dtypes.complex(real, imag)


def fft(x):
    """One-dimensional FFT over the last axis of a `(real, imag)` pair.

    Returns the transform split back into a `(real, imag)` tuple.
    """
    spectrum = tf.signal.fft(_get_complex_tensor_from_tuple(x))
    real_part = tf.math.real(spectrum)
    imag_part = tf.math.imag(spectrum)
    return real_part, imag_part


def fft2(x):
    """Two-dimensional FFT over the last two axes of a `(real, imag)` pair.

    Returns the transform split back into a `(real, imag)` tuple.
    """
    spectrum = tf.signal.fft2d(_get_complex_tensor_from_tuple(x))
    real_part = tf.math.real(spectrum)
    imag_part = tf.math.imag(spectrum)
    return real_part, imag_part


def ifft2(x):
    """Two-dimensional inverse FFT over the last two axes.

    Computed with the forward transform using the identity
    `ifft2(z) = conj(fft2(conj(z))) / (H * W)`, where `H` and `W` are the
    sizes of the last two axes.

    Args:
        x: Tuple or list `(real, imag)` of floating tensor-likes with equal
            shapes.

    Returns:
        Tuple `(real, imag)` of the inverse transform.

    Raises:
        ValueError: If `x` is not a length-2 tuple/list. Shape and dtype
            mismatches are reported by `fft2`.
    """
    if not isinstance(x, (tuple, list)) or len(x) != 2:
        raise ValueError(
            "Input `x` should be a tuple of two tensors - real and imaginary. "
            f"Received: x={x}"
        )
    # Convert up front so plain Python lists are accepted: they have no
    # `.dtype` and do not support unary minus.
    real = convert_to_tensor(x[0])
    imag = convert_to_tensor(x[1])
    h = cast(tf.shape(real)[-2], real.dtype)
    w = cast(tf.shape(real)[-1], real.dtype)
    scale = h * w
    # fft2 of the conjugate, then conjugate again and normalize.
    fft_real, fft_imag = fft2((real, -imag))
    return fft_real / scale, -fft_imag / scale


def rfft(x, fft_length=None):
    """Real-input FFT over the last axis of `x`.

    `fft_length`, when given, is wrapped in a list as `tf.signal.rfft`
    expects. The result is returned as a `(real, imag)` tuple.
    """
    lengths = None if fft_length is None else [fft_length]
    spectrum = tf.signal.rfft(x, fft_length=lengths)
    real_part = tf.math.real(spectrum)
    imag_part = tf.math.imag(spectrum)
    return real_part, imag_part


def irfft(x, fft_length=None):
    """Inverse of `rfft`: map a `(real, imag)` spectrum to a real signal.

    `fft_length`, when given, is wrapped in a list as `tf.signal.irfft`
    expects.
    """
    spectrum = _get_complex_tensor_from_tuple(x)
    lengths = None if fft_length is None else [fft_length]
    return tf.signal.irfft(spectrum, fft_length=lengths)


def stft(
    x, sequence_length, sequence_stride, fft_length, window="hann", center=True
):
    """Short-time Fourier transform of a real signal along its last axis.

    Args:
        x: Real-valued signal tensor-like; must be `float32` or `float64`.
        sequence_length: Length of each frame (the window length).
        sequence_stride: Hop between the starts of consecutive frames.
        fft_length: FFT size; must be `>= sequence_length`. The window is
            zero-padded (left pad `(fft_length - sequence_length) // 2`) to
            this length.
        window: `"hann"`, `"hamming"`, a 1-D tensor of length
            `sequence_length`, or `None` (no window function is passed to
            `tf.signal.stft`).
        center: If `True`, reflect-pad the last axis of `x` by
            `fft_length // 2` on both sides before framing.

    Returns:
        Tuple `(real, imag)` of the STFT.

    Raises:
        TypeError: If `x` is not `float32`/`float64`.
        ValueError: If `fft_length < sequence_length`, `window` is an
            unsupported string, or `window` has the wrong shape.
    """
    # Convert first so inputs without a `.dtype` (e.g. Python lists) are
    # accepted; the dtype check then sees the converted tensor's dtype.
    x = convert_to_tensor(x)
    if standardize_dtype(x.dtype) not in {"float32", "float64"}:
        raise TypeError(
            "Invalid input type. Expected `float32` or `float64`. "
            f"Received: input type={x.dtype}"
        )
    if fft_length < sequence_length:
        raise ValueError(
            "`fft_length` must equal or larger than `sequence_length`. "
            f"Received: sequence_length={sequence_length}, "
            f"fft_length={fft_length}"
        )
    if isinstance(window, str) and window not in {"hann", "hamming"}:
        raise ValueError(
            "If a string is passed to `window`, it must be one of "
            f'`"hann"`, `"hamming"`. Received: window={window}'
        )

    if center:
        pad_width = [(0, 0) for _ in range(len(x.shape))]
        pad_width[-1] = (fft_length // 2, fft_length // 2)
        x = tf.pad(x, pad_width, mode="reflect")

    # Split the zero padding that grows the window to `fft_length`.
    l_pad = (fft_length - sequence_length) // 2
    r_pad = fft_length - sequence_length - l_pad

    if window is not None:
        if window == "hann":
            win_array = tf.signal.hann_window(
                sequence_length, periodic=True, dtype=x.dtype
            )
        elif window == "hamming":
            win_array = tf.signal.hamming_window(
                sequence_length, periodic=True, dtype=x.dtype
            )
        else:
            win_array = convert_to_tensor(window, dtype=x.dtype)
        if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length:
            raise ValueError(
                "The shape of `window` must be equal to [sequence_length]. "
                f"Received: window shape={win_array.shape}"
            )
        win_array = tf.pad(win_array, [[l_pad, r_pad]])

        # `tf.signal.stft` expects a window factory; return the precomputed
        # (already padded) window regardless of the requested length.
        def win(frame_length, dtype):
            return win_array

    else:
        win = None

    result = tf.signal.stft(
        x,
        frame_length=(sequence_length + l_pad + r_pad),
        frame_step=sequence_stride,
        fft_length=fft_length,
        window_fn=win,
    )
    return tf.math.real(result), tf.math.imag(result)


def istft(
    x,
    sequence_length,
    sequence_stride,
    fft_length,
    length=None,
    window="hann",
    center=True,
):
    """Inverse short-time Fourier transform.

    Args:
        x: Tuple `(real, imag)` of floating tensors holding the STFT.
        sequence_length: Frame (window) length used by the forward STFT.
        sequence_stride: Hop between consecutive frames.
        fft_length: FFT size; must be `>= sequence_length`.
        length: Optional number of output samples to keep after `start`.
        window: `"hann"`, `"hamming"`, a 1-D tensor of length
            `sequence_length`, or `None`.
        center: If truthy, drop `fft_length // 2` samples from each end,
            undoing the centering padding of the forward STFT.

    Returns:
        The reconstructed real signal.

    Raises:
        ValueError: If `x` is not a valid `(real, imag)` pair,
            `fft_length < sequence_length`, `window` is an unsupported
            string, or `window` has the wrong shape.
    """
    complex_input = _get_complex_tensor_from_tuple(x)
    dtype = tf.math.real(complex_input).dtype

    if fft_length < sequence_length:
        raise ValueError(
            "`fft_length` must equal or larger than `sequence_length`. "
            f"Received: sequence_length={sequence_length}, "
            f"fft_length={fft_length}"
        )
    # Reject unknown window names instead of silently using hamming.
    if isinstance(window, str) and window not in {"hann", "hamming"}:
        raise ValueError(
            "If a string is passed to `window`, it must be one of "
            f'`"hann"`, `"hamming"`. Received: window={window}'
        )

    expected_output_len = fft_length + sequence_stride * (
        tf.shape(complex_input)[-2] - 1
    )
    l_pad = (fft_length - sequence_length) // 2
    r_pad = fft_length - sequence_length - l_pad

    if window is not None:
        if window == "hann":
            win_array = tf.signal.hann_window(
                sequence_length, periodic=True, dtype=dtype
            )
        elif window == "hamming":
            win_array = tf.signal.hamming_window(
                sequence_length, periodic=True, dtype=dtype
            )
        else:
            win_array = convert_to_tensor(window, dtype=dtype)
        if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length:
            raise ValueError(
                "The shape of `window` must be equal to [sequence_length]. "
                f"Received: window shape={win_array.shape}"
            )
        win_array = tf.pad(win_array, [[l_pad, r_pad]])
        win = tf.signal.inverse_stft_window_fn(
            sequence_stride, lambda frame_length, dtype: win_array
        )
    else:
        win = None

    x = tf.signal.inverse_stft(
        complex_input,
        frame_length=(sequence_length + l_pad + r_pad),
        frame_step=sequence_stride,
        fft_length=fft_length,
        window_fn=win,
    )

    half = fft_length // 2
    start = half if center else 0
    if length is not None:
        end = start + length
    elif center:
        # `-half` would be `0` (an empty slice) when `fft_length == 1`.
        end = -half if half else None
    else:
        end = expected_output_len
    return x[..., start:end]


def rsqrt(x):
    """Element-wise reciprocal square root, `1 / sqrt(x)`."""
    result = tf.math.rsqrt(x=x)
    return result


def erf(x):
    """Element-wise Gauss error function via `tf.math.erf`."""
    result = tf.math.erf(x=x)
    return result


def erfinv(x):
    """Element-wise inverse error function via `tf.math.erfinv`."""
    result = tf.math.erfinv(x=x)
    return result


def solve(a, b):
    """Solve the linear system `a @ result = b` with `tf.linalg.solve`.

    Both operands are converted to tensors first (`a` before `b`).
    """
    matrix, rhs = convert_to_tensor(a), convert_to_tensor(b)
    return tf.linalg.solve(matrix, rhs)


def norm(x, ord=None, axis=None, keepdims=False):
    """Vector or matrix norm, modelled on `jax.numpy.linalg.norm`.

    Args:
        x: Input tensor-like. `int64` inputs are cast to `config.floatx()`;
            other dtypes are promoted with `float` via `dtypes.result_type`.
        ord: Norm order. With one reduced axis: `None`/`"euclidean"`, any
            number, `inf`, `-inf` or `0`. With two reduced axes:
            `None`/`"fro"`/`"euclidean"`, `"nuc"`, `1`, `-1`, `2`, `-2`,
            `inf`, `-inf`. `None` means the 2-norm for one axis and the
            Frobenius norm for two.
        axis: `None` (all axes), an int, or a tuple/list of one or two ints.
        keepdims: If `True`, reduced axes are kept with size 1.

    Returns:
        Tensor of norms.

    Raises:
        ValueError: For an unsupported `ord`, or when `ord` is given and more
            than two axes would be reduced.
    """
    from keras.src.backend.tensorflow.numpy import moveaxis

    x = convert_to_tensor(x)
    ndim = x.shape.rank

    if standardize_dtype(x.dtype) == "int64":
        dtype = config.floatx()
    else:
        dtype = dtypes.result_type(x.dtype, float)
    x = cast(x, dtype)

    if axis is None and ord is None:
        # Like NumPy/JAX: the 2-norm of the flattened input, for any rank.
        # For 1-D/2-D input this equals the euclidean/Frobenius norm.
        return tf.linalg.norm(
            x, ord="euclidean", axis=None, keepdims=keepdims
        )

    if axis is None:
        axis = tuple(range(ndim))
    elif isinstance(axis, int):
        axis = (axis,)

    axis = axis[0] if len(axis) == 1 else axis
    num_axes = 1 if isinstance(axis, int) else len(axis)

    if ord is None:
        if num_axes == 1:
            ord = "euclidean"
        elif num_axes == 2:
            ord = "fro"

    # Fast path to utilize `tf.linalg.norm`
    if (num_axes == 1 and ord in ("euclidean", 1, 2, float("inf"))) or (
        num_axes == 2 and ord in ("euclidean", "fro", 1, 2, float("inf"))
    ):
        return tf.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)

    # Ref: jax.numpy.linalg.norm
    # Any remaining string `ord` is invalid for a vector norm (including
    # "fro"/"nuc"); it falls through to the ValueError below.
    if num_axes == 1 and not isinstance(ord, str):
        if ord == float("-inf"):
            return tf.math.reduce_min(
                tf.math.abs(x), axis=axis, keepdims=keepdims
            )
        elif ord == 0:
            # Count of non-zero entries.
            return tf.math.reduce_sum(
                tf.cast(tf.not_equal(x, 0), dtype=x.dtype),
                axis=axis,
                keepdims=keepdims,
            )
        else:
            ord = convert_to_tensor(ord, dtype=x.dtype)
            out = tf.math.reduce_sum(
                tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims
            )
            return tf.pow(out, 1.0 / ord)
    elif num_axes == 2 and ord in ("nuc", float("-inf"), -2, -1):
        row_axis, col_axis = axis[0], axis[1]
        row_axis = row_axis + ndim if row_axis < 0 else row_axis
        col_axis = col_axis + ndim if col_axis < 0 else col_axis
        if ord == float("-inf"):
            # Smallest absolute row sum; the first reduction removes
            # `col_axis`, which shifts `row_axis` down when it came after.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            x = tf.math.reduce_min(
                tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
                axis=row_axis,
                keepdims=keepdims,
            )
        elif ord == -1:
            # Smallest absolute column sum (mirror of the -inf case).
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            x = tf.math.reduce_min(
                tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),
                axis=col_axis,
                keepdims=keepdims,
            )
        else:
            # "nuc" / -2 work on singular values of the trailing matrices.
            x = moveaxis(x, axis, (-2, -1))
            if ord == -2:
                x = tf.math.reduce_min(
                    tf.linalg.svd(x, compute_uv=False), axis=-1
                )
            else:
                x = tf.math.reduce_sum(
                    tf.linalg.svd(x, compute_uv=False), axis=-1
                )
            if keepdims:
                # Re-insert the reduced axes at their normalized positions,
                # lowest first, so each index is valid for the growing rank
                # (raw axes may be negative or given in descending order).
                for reduced_axis in sorted((row_axis, col_axis)):
                    x = tf.expand_dims(x, reduced_axis)
        return x

    if num_axes == 1:
        raise ValueError(
            f"Invalid `ord` argument for vector norm. Received: ord={ord}"
        )
    elif num_axes == 2:
        raise ValueError(
            f"Invalid `ord` argument for matrix norm. Received: ord={ord}"
        )
    else:
        raise ValueError(f"Invalid axis values. Received: axis={axis}")


def logdet(x):
    """Log-determinant of `x` via `tf.linalg.logdet`.

    NOTE(review): `tf.linalg.logdet` is documented for Hermitian
    positive-definite matrices; callers presumably pass such input.
    """
    return tf.linalg.logdet(convert_to_tensor(x))
