import warnings

from keras.src import backend
from keras.src import layers
from keras.src.api_export import keras_export
from keras.src.applications import imagenet_utils
from keras.src.models import Functional
from keras.src.ops import operation_utils
from keras.src.utils import file_utils

# Base URL of the pretrained MobileNetV3 weight files; the weight file name
# is appended to it when downloading.
BASE_WEIGHT_PATH = (
    "https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/"
)
# Expected hex digests of the pretrained weight files, keyed by
# "<model_type>[_minimalistic]_224_<alpha>_float".
# Each value is (hash of the weights with the classification top,
#                hash of the "_no_top_v2" weights without the top).
WEIGHTS_HASHES = {
    "large_224_0.75_float": (
        "765b44a33ad4005b3ac83185abf1d0eb",
        "40af19a13ebea4e2ee0c676887f69a2e",
    ),
    "large_224_1.0_float": (
        "59e551e166be033d707958cf9e29a6a7",
        "07fb09a5933dd0c8eaafa16978110389",
    ),
    "large_minimalistic_224_1.0_float": (
        "675e7b876c45c57e9e63e6d90a36599c",
        "ec5221f64a2f6d1ef965a614bdae7973",
    ),
    "small_224_0.75_float": (
        "cb65d4e5be93758266aa0a7f2c6708b7",
        "ebdb5cc8e0b497cd13a7c275d475c819",
    ),
    "small_224_1.0_float": (
        "8768d4c2e7dee89b9d02b2d03d65d862",
        "d3e8ec802a04aa4fc771ee12a9a9b836",
    ),
    "small_minimalistic_224_1.0_float": (
        "99cd97fb2fcdad2bf028eb838de69e37",
        "cde8136e733e811080d9fcd8a252f7e4",
    ),
}


BASE_DOCSTRING = """Instantiates the {name} architecture.

Reference:
- [Searching for MobileNetV3](
    https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)

The following table describes the performance of MobileNetV3 models:
------------------------------------------------------------------------
MACs stands for Multiply Adds

|Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)|
|---|---|---|---|---|
| mobilenet_v3_large_1.0_224              | 217 | 5.4 |   75.6   |   51.2  |
| mobilenet_v3_large_0.75_224             | 155 | 4.0 |   73.3   |   39.8  |
| mobilenet_v3_large_minimalistic_1.0_224 | 209 | 3.9 |   72.3   |   44.1  |
| mobilenet_v3_small_1.0_224              | 66  | 2.9 |   68.1   |   15.8  |
| mobilenet_v3_small_0.75_224             | 44  | 2.4 |   65.4   |   12.8  |
| mobilenet_v3_small_minimalistic_1.0_224 | 65  | 2.0 |   61.9   |   12.2  |

For image classification use cases, see
[this page for detailed examples](
https://keras.io/api/applications/#usage-examples-for-image-classification-models).

For transfer learning use cases, make sure to read the
[guide to transfer learning & fine-tuning](
https://keras.io/guides/transfer_learning/).

Note: each Keras Application expects a specific kind of input preprocessing.
For MobileNetV3, by default input preprocessing is included as a part of the
model (as a `Rescaling` layer), and thus
`keras.applications.mobilenet_v3.preprocess_input` is actually a
pass-through function. In this use case, MobileNetV3 models expect their
inputs to be float tensors of pixels with values in the `[0, 255]` range.
At the same time, preprocessing as a part of the model (i.e. `Rescaling`
layer) can be disabled by setting the `include_preprocessing` argument to
`False`. With preprocessing disabled, MobileNetV3 models expect their inputs
to be float tensors of pixels with values in the `[-1, 1]` range.

Args:
    input_shape: Optional shape tuple, to be specified if you would
        like to use a model with an input image resolution that is not
        `(224, 224, 3)`.
        It should have exactly 3 input channels.
        You can also omit this option if you would like
        to infer input_shape from an input_tensor.
        If you choose to include both input_tensor and input_shape, then
        input_shape will be used if they match; if the shapes
        do not match, an error is raised.
        E.g. `(160, 160, 3)` would be one valid value.
        Inputs smaller than 32x32 are rejected.
    alpha: controls the width of the network. This is known as the
        depth multiplier in the MobileNetV3 paper, but the name is kept for
        consistency with MobileNetV1 in Keras.
        - If `alpha < 1.0`, proportionally decreases the number
            of filters in each layer.
        - If `alpha > 1.0`, proportionally increases the number
            of filters in each layer.
        - If `alpha == 1.0`, the default number of filters from the paper
            is used at each layer.
        When loading ImageNet weights, `alpha` must be `0.75` or `1.0`
        (only `1.0` for minimalistic models).
    minimalistic: Boolean. In addition to large and small models, this
        module also contains so-called minimalistic models. These models
        have the same per-layer dimensions as MobileNetV3; however, they
        don't utilize any of the advanced blocks (squeeze-and-excite units,
        hard-swish, and 5x5 convolutions).
        While these models are less efficient on CPU, they
        are much more performant on GPU/DSP.
    include_top: Boolean, whether to include the fully-connected
        layer at the top of the network. Defaults to `True`.
    weights: String, one of `None` (random initialization),
        `"imagenet"` (pre-training on ImageNet),
        or the path to the weights file to be loaded.
    input_tensor: Optional Keras tensor (i.e. output of
        `layers.Input()`)
        to use as image input for the model.
    pooling: String, optional pooling mode for feature extraction
        when `include_top` is `False`.
        - `None` means that the output of the model
            will be the 4D tensor output of the
            last convolutional block.
        - `"avg"` means that global average pooling
            will be applied to the output of the
            last convolutional block, and thus
            the output of the model will be a
            2D tensor.
        - `"max"` means that global max pooling will
            be applied.
    classes: Integer, optional number of classes to classify images
        into, only to be specified if `include_top` is `True`, and
        if no `weights` argument is specified. Must be 1000 when
        `weights="imagenet"` and `include_top=True`.
    dropout_rate: Float, fraction of the input units to drop before the
        final classification layer. Only used when `include_top=True`.
    classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.
        When loading pretrained weights, `classifier_activation` can only
        be `None` or `"softmax"`.
    include_preprocessing: Boolean, whether to include the preprocessing
        layer (`Rescaling`) at the bottom of the network. Defaults to `True`.
    name: String, the name of the model.

Call arguments:
    inputs: A floating point `numpy.array` or backend-native tensor,
        4D with 3 color channels, with values in the range `[0, 255]`
        if `include_preprocessing` is `True` and in the range `[-1, 1]`
        otherwise.

Returns:
    A model instance.
"""


def MobileNetV3(
    stack_fn,
    last_point_ch,
    input_shape=None,
    alpha=1.0,
    model_type="large",
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True,
    name=None,
):
    """Builds a MobileNetV3 network around a variant-specific block stack.

    Shared implementation behind `MobileNetV3Small` and `MobileNetV3Large`;
    the public arguments are documented on those functions.

    Args:
        stack_fn: Callable `(x, kernel, activation, se_ratio) -> x` that
            appends the variant's inverted residual blocks to `x`.
        last_point_ch: Number of filters of the `conv_2` 1x1 convolution
            in the classification head; scaled up when `alpha > 1.0`.
        model_type: `"large"` or `"small"`; only used to build the name of
            the pretrained weights file.
        Remaining arguments: see `MobileNetV3Small` / `MobileNetV3Large`.

    Returns:
        A `Functional` model.

    Raises:
        ValueError: if `weights` is invalid; if `classes != 1000` with
            ImageNet weights and a top; if `input_tensor` is not a Keras
            tensor or disagrees with `input_shape`; if the input is smaller
            than 32x32; or if no ImageNet weights exist for the requested
            `alpha` / `minimalistic` combination.
    """
    if not (weights in {"imagenet", None} or file_utils.exists(weights)):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` (random initialization), `imagenet` "
            "(pre-training on ImageNet), "
            "or the path to the weights file to be loaded.  "
            f"Received weights={weights}"
        )

    if weights == "imagenet" and include_top and classes != 1000:
        raise ValueError(
            'If using `weights="imagenet"` with `include_top` '
            "as true, `classes` should be 1000.  "
            f"Received classes={classes}"
        )

    if include_top:
        # Validate up front so a bad activation fails before any layers are
        # built.
        imagenet_utils.validate_activation(classifier_activation, weights)

    data_format = backend.image_data_format()

    if input_tensor is not None:
        if not backend.is_keras_tensor(input_tensor):
            raise ValueError(
                f"input_tensor specified: {input_tensor} is not a keras "
                f"tensor. Received type(input_tensor)={type(input_tensor)}"
            )
        # Non-batch dimensions of the provided tensor.
        tensor_shape = tuple(input_tensor.shape[1:])
        if input_shape is None:
            # Infer (rows, cols) from the tensor; channels are always 3.
            if data_format == "channels_first":
                input_shape = (3, tensor_shape[1], tensor_shape[2])
            else:
                input_shape = (tensor_shape[0], tensor_shape[1], 3)
        elif len(input_shape) != len(tensor_shape) or any(
            expected is not None and actual is not None and expected != actual
            for expected, actual in zip(input_shape, tensor_shape)
        ):
            # Both were given: they must agree wherever both are known.
            raise ValueError(
                "`input_shape` must match the non-batch dimensions of "
                f"`input_tensor`. Received input_shape={input_shape}, "
                f"input_tensor.shape={input_tensor.shape}"
            )
    elif input_shape is None:
        # Neither was given: accept any spatial size with 3 channels.
        if data_format == "channels_last":
            input_shape = (None, None, 3)
        else:
            input_shape = (3, None, None)

    if data_format == "channels_last":
        row_axis, col_axis = 0, 1
    else:
        row_axis, col_axis = 1, 2
    rows = input_shape[row_axis]
    cols = input_shape[col_axis]
    if rows and cols and (rows < 32 or cols < 32):
        raise ValueError(
            "Input size must be at least 32x32; Received `input_shape="
            f"{input_shape}`"
        )
    if weights == "imagenet":
        # Pretrained weights exist only for these width multipliers.
        supported_alphas = (1.0,) if minimalistic else (0.75, 1.0)
        if alpha not in supported_alphas:
            raise ValueError(
                "If imagenet weights are being loaded, "
                "alpha can be one of `0.75`, `1.0` for non minimalistic "
                "or `1.0` for minimalistic only."
            )

        if rows != cols or rows != 224:
            warnings.warn(
                "`input_shape` is undefined or non-square, "
                "or `rows` is not 224. "
                "Weights for input shape (224, 224) will be "
                "loaded as the default.",
                stacklevel=2,
            )

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        # Already validated above as a Keras tensor.
        img_input = input_tensor

    channel_axis = 1 if data_format == "channels_first" else -1

    if minimalistic:
        # Minimalistic variants use 3x3 kernels, ReLU and no SE blocks.
        kernel = 3
        activation = relu
        se_ratio = None
    else:
        kernel = 5
        activation = hard_swish
        se_ratio = 0.25

    x = img_input
    if include_preprocessing:
        # Maps pixel values from [0, 255] to [-1, 1].
        x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(x)
    x = layers.Conv2D(
        16,
        kernel_size=3,
        strides=(2, 2),
        padding="same",
        use_bias=False,
        name="conv",
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_bn"
    )(x)
    x = activation(x)

    x = stack_fn(x, kernel, activation, se_ratio)

    last_conv_ch = _depth(x.shape[channel_axis] * 6)

    # If the width multiplier is greater than 1, increase the number of
    # output channels of the head as well.
    if alpha > 1.0:
        last_point_ch = _depth(last_point_ch * alpha)
    x = layers.Conv2D(
        last_conv_ch,
        kernel_size=1,
        padding="same",
        use_bias=False,
        name="conv_1",
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis, epsilon=1e-3, momentum=0.999, name="conv_1_bn"
    )(x)
    x = activation(x)
    if include_top:
        x = layers.GlobalAveragePooling2D(keepdims=True)(x)
        x = layers.Conv2D(
            last_point_ch,
            kernel_size=1,
            padding="same",
            use_bias=True,
            name="conv_2",
        )(x)
        x = activation(x)

        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate)(x)
        x = layers.Conv2D(
            classes, kernel_size=1, padding="same", name="logits"
        )(x)
        x = layers.Flatten()(x)
        x = layers.Activation(
            activation=classifier_activation, name="predictions"
        )(x)
    else:
        if pooling == "avg":
            x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
        elif pooling == "max":
            x = layers.GlobalMaxPooling2D(name="max_pool")(x)
    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = operation_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = Functional(inputs, x, name=name)

    # Load weights.
    if weights == "imagenet":
        # `float()` so that an integer `alpha=1` maps to the "1.0" key
        # instead of a non-existent "1" key.
        model_name = "{}{}_224_{}_float".format(
            model_type, "_minimalistic" if minimalistic else "", float(alpha)
        )
        if include_top:
            file_name = "weights_mobilenet_v3_" + model_name + ".h5"
            file_hash = WEIGHTS_HASHES[model_name][0]
        else:
            file_name = "weights_mobilenet_v3_" + model_name + "_no_top_v2.h5"
            file_hash = WEIGHTS_HASHES[model_name][1]
        weights_path = file_utils.get_file(
            file_name,
            BASE_WEIGHT_PATH + file_name,
            cache_subdir="models",
            file_hash=file_hash,
        )
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model


@keras_export("keras.applications.MobileNetV3Small")
def MobileNetV3Small(
    input_shape=None,
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True,
    name="MobileNetV3Small",
):
    # The public docstring is attached from `BASE_DOCSTRING` further below.

    def stack_fn(x, kernel, activation, se_ratio):
        # One row per inverted residual block, in graph order. Columns:
        # expansion, base filters (scaled by `alpha` through `_depth`),
        # kernel size, stride, squeeze-and-excite ratio, activation.
        specs = (
            (1, 16, 3, 2, se_ratio, relu),
            (72.0 / 16, 24, 3, 2, None, relu),
            (88.0 / 24, 24, 3, 1, None, relu),
            (4, 40, kernel, 2, se_ratio, activation),
            (6, 40, kernel, 1, se_ratio, activation),
            (6, 40, kernel, 1, se_ratio, activation),
            (3, 48, kernel, 1, se_ratio, activation),
            (3, 48, kernel, 1, se_ratio, activation),
            (6, 96, kernel, 2, se_ratio, activation),
            (6, 96, kernel, 1, se_ratio, activation),
            (6, 96, kernel, 1, se_ratio, activation),
        )
        for block_id, spec in enumerate(specs):
            expansion, base_filters, size, stride, ratio, act = spec
            x = _inverted_res_block(
                x,
                expansion,
                _depth(base_filters * alpha),
                size,
                stride,
                ratio,
                act,
                block_id,
            )
        return x

    return MobileNetV3(
        stack_fn,
        1024,
        input_shape=input_shape,
        alpha=alpha,
        model_type="small",
        minimalistic=minimalistic,
        include_top=include_top,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        pooling=pooling,
        dropout_rate=dropout_rate,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        name=name,
    )


@keras_export("keras.applications.MobileNetV3Large")
def MobileNetV3Large(
    input_shape=None,
    alpha=1.0,
    minimalistic=False,
    include_top=True,
    weights="imagenet",
    input_tensor=None,
    classes=1000,
    pooling=None,
    dropout_rate=0.2,
    classifier_activation="softmax",
    include_preprocessing=True,
    name="MobileNetV3Large",
):
    # The public docstring is attached from `BASE_DOCSTRING` further below.

    def stack_fn(x, kernel, activation, se_ratio):
        # One row per inverted residual block, in graph order. Columns:
        # expansion, base filters (scaled by `alpha` through `_depth`),
        # kernel size, stride, squeeze-and-excite ratio, activation.
        specs = (
            (1, 16, 3, 1, None, relu),
            (4, 24, 3, 2, None, relu),
            (3, 24, 3, 1, None, relu),
            (3, 40, kernel, 2, se_ratio, relu),
            (3, 40, kernel, 1, se_ratio, relu),
            (3, 40, kernel, 1, se_ratio, relu),
            (6, 80, 3, 2, None, activation),
            (2.5, 80, 3, 1, None, activation),
            (2.3, 80, 3, 1, None, activation),
            (2.3, 80, 3, 1, None, activation),
            (6, 112, 3, 1, se_ratio, activation),
            (6, 112, 3, 1, se_ratio, activation),
            (6, 160, kernel, 2, se_ratio, activation),
            (6, 160, kernel, 1, se_ratio, activation),
            (6, 160, kernel, 1, se_ratio, activation),
        )
        for block_id, spec in enumerate(specs):
            expansion, base_filters, size, stride, ratio, act = spec
            x = _inverted_res_block(
                x,
                expansion,
                _depth(base_filters * alpha),
                size,
                stride,
                ratio,
                act,
                block_id,
            )
        return x

    return MobileNetV3(
        stack_fn,
        1280,
        input_shape=input_shape,
        alpha=alpha,
        model_type="large",
        minimalistic=minimalistic,
        include_top=include_top,
        weights=weights,
        input_tensor=input_tensor,
        classes=classes,
        pooling=pooling,
        dropout_rate=dropout_rate,
        classifier_activation=classifier_activation,
        include_preprocessing=include_preprocessing,
        name=name,
    )


# Both public builders share one docstring template, specialized by name.
for _public_fn, _arch_name in (
    (MobileNetV3Small, "MobileNetV3Small"),
    (MobileNetV3Large, "MobileNetV3Large"),
):
    _public_fn.__doc__ = BASE_DOCSTRING.format(name=_arch_name)
del _public_fn, _arch_name


def relu(x):
    """Apply a freshly created, unnamed `ReLU` layer to `x`."""
    relu_layer = layers.ReLU()
    return relu_layer(x)


def hard_sigmoid(x):
    """Piecewise-linear sigmoid approximation: `relu6(x + 3) * (1 / 6)`."""
    # Build the capped ReLU layer before forming `x + 3.0`, matching the
    # original evaluation order.
    relu6 = layers.ReLU(6.0)
    clipped = relu6(x + 3.0)
    return clipped * (1.0 / 6.0)


def hard_swish(x):
    """Apply the built-in `"hard_swish"` activation through a layer."""
    swish_layer = layers.Activation(activation="hard_swish")
    return swish_layer(x)


# `_depth` is adapted from the original TensorFlow models repository.
# It rounds a channel count to the nearest multiple of `divisor` (8 by
# default), never returns less than `min_value`, and never rounds down by
# more than 10%. Original source:
# https://github.com/tensorflow/models/blob/master/research/
# slim/nets/mobilenet/mobilenet.py


def _depth(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def _se_block(inputs, filters, se_ratio, prefix):
    """Squeeze-and-excitation: rescale `inputs` channel-wise by a learned gate.

    The gate is computed from the global average of `inputs` through a 1x1
    conv bottleneck (`_depth(filters * se_ratio)` channels, ReLU), a 1x1 conv
    back to `filters` channels, and `hard_sigmoid`.

    Args:
        inputs: Feature map to recalibrate.
        filters: Channel count of the gate; expected to equal the channel
            count of `inputs` so the final multiply lines up.
        se_ratio: Bottleneck ratio of the squeeze convolution.
        prefix: Name prefix for every named layer created here.

    Returns:
        `inputs` multiplied element-wise by the gate.
    """
    squeezed = layers.GlobalAveragePooling2D(
        keepdims=True, name=prefix + "squeeze_excite_avg_pool"
    )(inputs)
    bottleneck_width = _depth(filters * se_ratio)
    bottleneck = layers.Conv2D(
        bottleneck_width,
        kernel_size=1,
        padding="same",
        name=prefix + "squeeze_excite_conv",
    )(squeezed)
    bottleneck = layers.ReLU(name=prefix + "squeeze_excite_relu")(bottleneck)
    gate = layers.Conv2D(
        filters,
        kernel_size=1,
        padding="same",
        name=prefix + "squeeze_excite_conv_1",
    )(bottleneck)
    gate = hard_sigmoid(gate)
    return layers.Multiply(name=prefix + "squeeze_excite_mul")([inputs, gate])


def _inverted_res_block(
    x, expansion, filters, kernel_size, stride, se_ratio, activation, block_id
):
    """Append one MobileNetV3 inverted residual block to `x`.

    The block consists of: an optional 1x1 expansion conv + batch norm +
    activation (skipped when `block_id` is 0), a depthwise conv + batch
    norm + activation, an optional squeeze-and-excite stage, and a 1x1
    projection conv + batch norm with no activation. A residual `Add` with
    the block input is appended when `stride == 1` and the input channel
    count equals `filters`.

    Args:
        x: Input feature map.
        expansion: Multiplier on the input channel count giving the expanded
            width (rounded with `_depth`). The expansion conv is skipped for
            `block_id == 0`, but this value still sizes the SE stage.
        filters: Number of output channels of the projection conv.
        kernel_size: Depthwise kernel size.
        stride: Depthwise stride. Stride 2 uses explicit zero padding and a
            "valid" depthwise conv; stride 1 uses "same" padding.
        se_ratio: Squeeze-and-excite bottleneck ratio; falsy disables SE.
        activation: Callable applied after the expansion and depthwise
            batch norms.
        block_id: Block index used in layer names. 0 skips the expansion
            conv and uses the bare "expanded_conv_" name prefix.

    Returns:
        Output tensor of the block.
    """
    channel_axis = 1 if backend.image_data_format() == "channels_first" else -1
    shortcut = x
    prefix = "expanded_conv_"
    infilters = x.shape[channel_axis]
    # NOTE(review): layer names built from `prefix` presumably need to stay
    # stable for the pretrained weight files — confirm before renaming.
    if block_id:
        # Expand
        prefix = f"expanded_conv_{block_id}_"
        x = layers.Conv2D(
            _depth(infilters * expansion),
            kernel_size=1,
            padding="same",
            use_bias=False,
            name=prefix + "expand",
        )(x)
        x = layers.BatchNormalization(
            axis=channel_axis,
            epsilon=1e-3,
            momentum=0.999,
            name=prefix + "expand_bn",
        )(x)
        x = activation(x)

    # NOTE(review): only strides 1 and 2 appear to be supported; any other
    # stride would get "valid" padding without the explicit pre-padding.
    if stride == 2:
        # Explicit padding paired with the "valid" depthwise conv below;
        # presumably mirrors TF-style asymmetric "same" padding — confirm.
        x = layers.ZeroPadding2D(
            padding=imagenet_utils.correct_pad(x, kernel_size),
            name=prefix + "depthwise_pad",
        )(x)
    x = layers.DepthwiseConv2D(
        kernel_size,
        strides=stride,
        padding="same" if stride == 1 else "valid",
        use_bias=False,
        name=prefix + "depthwise",
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis,
        epsilon=1e-3,
        momentum=0.999,
        name=prefix + "depthwise_bn",
    )(x)
    x = activation(x)

    if se_ratio:
        # NOTE(review): for block_id == 0 there is no expansion conv, so this
        # assumes `_depth(infilters * expansion)` equals the actual channel
        # count of `x` (true for the stacks in this file; verify for others).
        x = _se_block(x, _depth(infilters * expansion), se_ratio, prefix)

    # Linear projection (no activation after the batch norm).
    x = layers.Conv2D(
        filters,
        kernel_size=1,
        padding="same",
        use_bias=False,
        name=prefix + "project",
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis,
        epsilon=1e-3,
        momentum=0.999,
        name=prefix + "project_bn",
    )(x)

    # Identity shortcut only when spatial size and channel count are unchanged.
    if stride == 1 and infilters == filters:
        x = layers.Add(name=prefix + "add")([shortcut, x])
    return x


@keras_export("keras.applications.mobilenet_v3.preprocess_input")
def preprocess_input(x, data_format=None):
    """Return the input unchanged; kept for backward compatibility.

    MobileNetV3 models rescale their inputs inside the model (a `Rescaling`
    layer, enabled by `include_preprocessing=True`), so callers no longer
    need to normalize images themselves. This function performs no work and
    exists only so code written against the older API keeps working.

    Args:
        x: A floating point `numpy.array` or a tensor.
        data_format: Optional data format of the image tensor/array.
            Accepted for API compatibility only; it has no effect here.
            Defaults to `None`.

    Returns:
        `x`, unmodified.
    """
    del data_format  # Unused: nothing depends on the layout for a no-op.
    return x


@keras_export("keras.applications.mobilenet_v3.decode_predictions")
def decode_predictions(preds, top=5):
    # Thin alias of the shared ImageNet decoder; its docstring is reused.
    decoded = imagenet_utils.decode_predictions(preds, top=top)
    return decoded


decode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__
