import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export


@keras_export("keras.visualization.draw_segmentation_masks")
def draw_segmentation_masks(
    images,
    segmentation_masks,
    num_classes=None,
    color_mapping=None,
    alpha=0.8,
    blend=True,
    ignore_index=-1,
    data_format=None,
):
    """Draws segmentation masks on images.

    The function overlays segmentation masks on the input images. Every pixel
    whose class index lies in `[0, num_classes)` is painted with that class's
    color; when `blend=True` the colors are mixed with the images using the
    specified alpha value.

    Args:
        images: A batch of images as a 4D tensor or NumPy array. Shape
            should be `(batch_size, height, width, channels)`, or
            `(batch_size, channels, height, width)` for `"channels_first"`.
        segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor
            or NumPy array of an integer dtype. Shape should be
            `(batch_size, height, width)` or `(batch_size, height, width, 1)`
            (`(batch_size, 1, height, width)` for `"channels_first"`). The
            values are 0-based class indices in `[0, num_classes)`. Class 0
            is drawn like any other class (the default palette maps it to
            black); pass `ignore_index=0` to leave those pixels untouched.
            Values outside `[0, num_classes)` are not drawn.
        num_classes: The number of segmentation classes. If `None`, it is
            inferred as the maximum value in `segmentation_masks` plus one.
        color_mapping: A dictionary mapping class indices to RGB colors.
            If `None`, a default color palette is generated. It must contain
            every integer key from 0 up to `num_classes - 1`.
        alpha: The opacity of the segmentation masks, expected to be in the
            range `[0, 1]`. Only used when `blend=True`.
        blend: Whether to blend the masks with the input image using the
            `alpha` value. If `False`, the masks are drawn directly on the
            images without blending. Defaults to `True`.
        ignore_index: The class index to ignore. Mask pixels with this value
            keep their original image values, with or without blending.
            Defaults to -1.
        data_format: Image data format, either `"channels_last"` or
            `"channels_first"`. Defaults to the `image_data_format` value found
            in your Keras config file at `~/.keras/keras.json`. If you never
            set it, then it will be `"channels_last"`.

    Returns:
        A NumPy array of the images with the segmentation masks overlaid, in
        `(batch_size, height, width, channels)` layout regardless of
        `data_format`. The dtype is `uint8` when `blend=True` and `float32`
        otherwise.

    Raises:
        ValueError: If `images` is not a 4D tensor or NumPy array, or if
            `segmentation_masks` does not have one of the supported shapes.
        TypeError: If the input `segmentation_masks` is not an integer type.
        KeyError: If `color_mapping` lacks a class index in
            `[0, num_classes)`.
    """
    data_format = data_format or backend.image_data_format()
    images_shape = ops.shape(images)
    if len(images_shape) != 4:
        raise ValueError(
            "`images` must be batched 4D tensor. "
            f"Received: images.shape={images_shape}"
        )

    channel_axis = 1 if data_format == "channels_first" else -1
    masks_shape = ops.shape(segmentation_masks)
    if len(masks_shape) == 3:
        # Give 3D masks a singleton channel axis so both documented mask
        # ranks go through the same code path below.
        segmentation_masks = ops.expand_dims(
            segmentation_masks, axis=channel_axis
        )
    elif len(masks_shape) != 4 or masks_shape[channel_axis] != 1:
        raise ValueError(
            "`segmentation_masks` must have shape (batch_size, height, width) "
            "or include a single channel axis of size 1. "
            f"Received: segmentation_masks.shape={masks_shape}"
        )

    if data_format == "channels_first":
        images = ops.transpose(images, (0, 2, 3, 1))
        segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))
    images = ops.convert_to_tensor(images, dtype="float32")
    segmentation_masks = ops.convert_to_tensor(segmentation_masks)

    if not backend.is_int_dtype(segmentation_masks.dtype):
        dtype = backend.standardize_dtype(segmentation_masks.dtype)
        raise TypeError(
            "`segmentation_masks` must be in integer dtype. "
            f"Received: segmentation_masks.dtype={dtype}"
        )

    # Painting and blending below are done with NumPy arrays.
    image_array = ops.convert_to_numpy(images)
    class_ids = ops.convert_to_numpy(segmentation_masks)[..., 0]

    if num_classes is None:
        # Class indices are 0-based, so the largest index needs `max + 1`
        # slots; without the `+ 1` the highest class would never be drawn.
        num_classes = int(class_ids.max()) + 1
    num_classes = max(int(num_classes), 0)
    if color_mapping is None:
        colors = _generate_color_palette(num_classes)
    else:
        colors = [color_mapping[i] for i in range(num_classes)]

    # Look each pixel's color up in a (num_classes, channels) table instead
    # of materializing a one-hot tensor of shape (batch, h, w, num_classes).
    # Indices outside [0, num_classes) are left unpainted.
    painted = image_array.copy()
    if num_classes > 0:
        color_table = np.asarray(colors, dtype=painted.dtype)
        drawable = (class_ids >= 0) & (class_ids < num_classes)
        painted[drawable] = color_table[class_ids[drawable]]

    if blend:
        painted = image_array * (1 - alpha) + painted * alpha
    # Pixels labelled `ignore_index` keep their original values in both modes.
    keep_original = class_ids == ignore_index
    outputs = np.where(keep_original[..., None], image_array, painted)
    if blend:
        outputs = outputs.astype("uint8")
    return outputs


def _generate_color_palette(num_classes):
    palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
    return [((i * palette) % 255).tolist() for i in range(num_classes)]
