from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
from keras.src.saving import serialization_lib


@keras_export("keras.layers.Rescaling")
class Rescaling(TFDataLayer):
    """A preprocessing layer which rescales input values to a new range.

    This layer rescales every value of an input (often an image) by multiplying
    by `scale` and adding `offset`.

    For instance:

    1. To rescale an input in the `[0, 255]` range
    to be in the `[0, 1]` range, you would pass `scale=1./255`.

    2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
    you would pass `scale=1./127.5, offset=-1`.

    The rescaling is applied both during training and inference. Inputs can be
    of integer or floating point dtype, and by default the layer will output
    floats.

    **Note:** This layer is safe to use inside a `tf.data` pipeline
    (independently of which backend you're using).

    Args:
        scale: Float, the scale to apply to the inputs.
        offset: Float, the offset to apply to the inputs.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(self, scale, offset=0.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.offset = offset
        self.supports_masking = True

    def call(self, inputs):
        dtype = self.compute_dtype
        scale = self.backend.cast(self.scale, dtype)
        offset = self.backend.cast(self.offset, dtype)
        scale_shape = self.backend.core.shape(scale)
        if (
            len(scale_shape) > 0
            and backend.image_data_format() == "channels_first"
        ):
            scale = self.backend.numpy.reshape(
                scale, scale_shape + (1,) * (3 - len(scale_shape))
            )
        return self.backend.cast(inputs, dtype) * scale + offset

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # `scale` and `offset` might be numpy array.
                "scale": serialization_lib.serialize_keras_object(self.scale),
                "offset": serialization_lib.serialize_keras_object(self.offset),
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = config.copy()
        config["scale"] = serialization_lib.deserialize_keras_object(
            config["scale"], custom_objects=custom_objects
        )
        config["offset"] = serialization_lib.deserialize_keras_object(
            config["offset"], custom_objects=custom_objects
        )
        return cls(**config)
