# mypy: allow-untyped-defs
from abc import ABC, abstractmethod
from typing import Any, Callable

import torch
import torch.ao.nn.qat as nnqat
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.fx._equalize import (
    default_equalization_qconfig,
    EqualizationQConfig,
)
from torch.ao.quantization.fx._model_report.model_report_observer import (
    ModelReportObserver,
)
from torch.ao.quantization.fx.graph_module import GraphModule
from torch.ao.quantization.observer import (
    _is_activation_post_process,
    default_dynamic_quant_observer,
    default_observer,
    default_per_channel_weight_observer,
    default_weight_observer,
    ObserverBase,
)
from torch.ao.quantization.qconfig import (
    _assert_valid_qconfig,
    default_qconfig,
    QConfig,
)


# Names for observer insert keys.
# A detector's determine_observer_insert_points() returns a dict mapping each
# observer fqn to an info dict that uses these strings as its keys (see
# DynamicStaticDetector.determine_observer_insert_points for a producer).
DETECTOR_TARGET_NODE_KEY = "target_node"  # -> fx Node the observer is placed around
DETECTOR_OBS_TO_INSERT_KEY = "observer_to_insert"  # -> observer instance to insert
DETECTOR_IS_POST_OBS_KEY = "is_post_observer"  # -> True: after target node, False: before it
DETECTOR_OBS_ARGS_KEY = "observer_args"  # -> args to pass into the inserted observer


# Mapping related code
class DetectorQConfigInfo:
    r"""
    This class contains the QConfig information for a single module.
    The list of variables / values this contains can grow depending on the
    extensibility of the qconfig mapping feature set but this currently includes:
    - if activation observer is dynamic
    - if weight observer is per channel
    - if input-weight equalization is recommended


    Args:
        module_fqn (str): The fully qualified name (fqn) of the module that this
            information contains info relevant to qconfig for
    """

    def __init__(self, module_fqn: str):
        super().__init__()
        self.module_fqn = module_fqn

        # Recommendation flags a detector may set. Every flag starts as False;
        # a detector only flips the flags it actually has a recommendation for.
        self.is_activation_dynamic = False
        self.is_weight_per_channel = False

        # equalization related options
        self.is_equalization_recommended = False

    def generate_quantization_qconfig(self, module: torch.nn.Module) -> QConfig:
        r"""
        Builds a quantization QConfig for ``module`` from the recommendation flags.

        Candidate configurations are tried in priority order, and the first one
        accepted by ``_assert_valid_qconfig`` for ``module`` is returned:
            1. both recommendations (dynamic activation + per-channel weight)
            2. only the dynamic activation recommendation
            3. only the per-channel weight recommendation
        If no candidate is valid, ``default_qconfig`` is returned.

        Args:
            module (torch.nn.Module): The module we are generating the qconfig for

        Returns the generated quantization QConfig according to what a valid configuration is
        """
        # (dynamic activation recommended, per-channel weight recommended), in priority order
        candidate_recommendations = [
            (self.is_activation_dynamic, self.is_weight_per_channel),
            (self.is_activation_dynamic, False),  # only trying the dynamic rec
            (False, self.is_weight_per_channel),  # only trying the per-channel rec
        ]

        for use_dynamic, use_per_channel in candidate_recommendations:
            activation = (
                default_dynamic_quant_observer if use_dynamic else default_observer
            )
            weight = (
                default_per_channel_weight_observer
                if use_per_channel
                else default_weight_observer
            )
            test_config = QConfig(activation, weight)
            try:
                _assert_valid_qconfig(test_config, module)
            except AssertionError:
                # not a valid configuration for this module, try the next one in priority
                continue
            return test_config

        # no candidate was valid for this module: fall back to the default qconfig
        return default_qconfig

    def generate_equalization_qconfig(self) -> EqualizationQConfig:
        r"""
        This returns the equalization configuration for a module.

        For now, it just returns the default, but as more equalization options become
        possible, this method can get more fleshed out with more nuanced granularity.


        Returns the generated equalization QConfig according to what a valid configuration is
        """
        # in this case, we just return default equalization config
        # we know this is valid because only valid modules would even
        # have this option
        return default_equalization_qconfig


# Adding base class for detectors
class DetectorBase(ABC):
    r"""Base Detector Module
    Any detector class should derive from this class.

    Concrete detectors should follow the same general API, which includes:
    - A method to calculate and return observer insertion points
        - Should return both the fqns and the Observer class to insert
    - A method to return a report based on the detector
        - Should return a str-based report and dict info in Tuple[str,Dict] format
    """

    def __init__(self) -> None:
        super().__init__()
        # no config info is attached to a detector by default
        self.detector_config_info = None

    @abstractmethod
    def determine_observer_insert_points(self, model) -> dict:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict.
            This dict maps string keys to detector specific information
        """

    @abstractmethod
    def get_detector_name(self) -> str:
        r"""Returns the name of the current detector"""

    @abstractmethod
    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """

    def _get_targeting_node(
        self, prepared_fx_model: GraphModule, target_fqn: str
    ) -> torch.fx.node.Node:
        r"""
        Finds the graph node whose target is ``target_fqn``, or the closest ancestor fqn.

        Fusion can hide a module: its fqn may be e.g. ``x.linear.0`` while the graph
        only contains a node for the fused parent ``x.linear``. So when no node
        matches, the last dotted component is dropped (``foo.bar.baz -> foo.bar``)
        and the search is repeated, until a node matches or no ``.`` is left.

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule
            target_fqn (str): The fqn of the layer we are trying to target

        Returns the node object we are trying to add observers around

        Raises:
            ValueError: if neither ``target_fqn`` nor any of its parents is a node target
        """
        candidate_fqn = target_fqn
        while True:
            # first node (in graph order) targeting the current candidate, if any
            matching_node = next(
                (
                    graph_node
                    for graph_node in prepared_fx_model.graph.nodes
                    if graph_node.target == candidate_fqn
                ),
                None,
            )
            if matching_node is not None:
                return matching_node

            # climb one level; with no separator left there is nothing more to try
            parent_fqn, separator, _ = candidate_fqn.rpartition(".")
            if not separator:
                raise ValueError("passed in target_fqn not found in graph's targets.")
            candidate_fqn = parent_fqn

    @abstractmethod
    def generate_detector_report(self, model) -> tuple[str, dict[str, Any]]:
        r"""
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Tuple of two elements:
            Str: string report of the suggested improvements
            Dict: contains useful data collected by the observer pertinent to this report
        """


class PerChannelDetector(DetectorBase):
    r"""This class is used to detect if any Linear or Conv layers in a model utilize per_channel quantization.
    Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

    per_channel quantization can lead to major benefits in the form of accuracy.
    Therefore, if the backend used by the user supports it, it is recommended to use

    Args:
        backend (str, optional): the backend the user wishes to use in production
            Default value is the torch.backends.quantized.engine current at construction time
    """

    # Keys for return dictionary
    BACKEND_KEY = "backend"
    PER_CHAN_SUPPORTED_KEY = "per_channel_quantization_supported"
    PER_CHAN_USED_KEY = "per_channel_quantization_used"

    # Default map for representing supported per channel quantization modules for different backends
    DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES: dict[str, set[Any]] = {
        "fbgemm": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "qnnpack": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "onednn": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
        "x86": {
            nn.Linear,
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nnqat.Linear,
            nnqat.Conv1d,
            nnqat.Conv2d,
            nnqat.Conv3d,
        },
    }

    def __init__(self, backend: "str | None" = None):
        super().__init__()

        # Resolve the default here rather than in the signature: a signature
        # default of torch.backends.quantized.engine would be frozen at import time
        # and ignore any later change to the engine.
        if backend is None:
            backend = torch.backends.quantized.engine

        # store the backend information
        self.backend_chosen = backend
        if self.backend_chosen not in self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES:
            raise ValueError(
                f"Not configured to work with {self.backend_chosen}. Try a different default backend"
            )
        self.supported_modules = self.DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES[
            self.backend_chosen
        ]

    def get_detector_name(self) -> str:
        r"""returns the string name of this detector"""
        return "per_channel_detector"

    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from module fqns (of modules supporting per-channel) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn, module_info in per_channel_info.items():
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # recommend per-channel weights wherever the backend supports them
            per_chan_supported: bool = module_info[self.PER_CHAN_SUPPORTED_KEY]
            detector_qconfig_info.is_weight_per_channel = per_chan_supported
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(self, model: nn.Module) -> dict:
        r"""
        There is no observers inserted for the PerChannelDetector.

        Returns an empty dictionary since no observers are added or needed
        """
        return {}

    def _detect_per_channel_helper(self, model: nn.Module):
        r"""
        Determines if per_channel quantization is supported and used in modules and submodules.

        Only modules whose type is in this backend's supported set get an entry.

        Args:
            model: The current module that is being checked to see if it is per_channel quantizable

        Returns a dictionary mapping each supported module's fqn to a dict with:
            PER_CHAN_SUPPORTED_KEY -> whether per_channel is supported (always True for entries)
            PER_CHAN_USED_KEY -> whether the module's qconfig weight is per_channel
            BACKEND_KEY -> the backend this detector was created for
        """
        # create dict we will return
        per_channel_info: dict = {}

        # isinstance accepts a tuple of types; build it once instead of per module
        supported_types = tuple(self.supported_modules)

        for fqn, module in model.named_modules():
            if not isinstance(module, supported_types):
                continue

            # assert statement for MyPy
            q_config_file = module.qconfig
            assert isinstance(q_config_file, QConfig)

            # Instantiate the weight observer / fake quant WITH its configured args.
            # Calling the bare class (weight.p.func()) would drop the with_args
            # kwargs, so e.g. a per-channel FakeQuantize would be built with its
            # default per-tensor observer and misreported as not per-channel.
            q_or_s_obj = q_config_file.weight()
            assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))

            per_channel_used = False  # will be true if found in qconfig

            if hasattr(q_or_s_obj, "ch_axis"):
                if isinstance(q_or_s_obj, FakeQuantize):
                    # all fake quants have a channel axis, so check is_per_channel
                    if (
                        hasattr(q_or_s_obj, "is_per_channel")
                        and q_or_s_obj.is_per_channel
                    ):
                        per_channel_used = True
                elif isinstance(q_or_s_obj, ObserverBase):
                    # only per-channel observers carry a ch_axis attribute
                    per_channel_used = True
                else:
                    raise ValueError("Should be either observer or fake quant")

            per_channel_info[fqn] = {
                self.PER_CHAN_SUPPORTED_KEY: True,
                self.PER_CHAN_USED_KEY: per_channel_used,
                self.BACKEND_KEY: self.backend_chosen,
            }

        return per_channel_info

    def generate_detector_report(self, model: nn.Module) -> tuple[str, dict[str, Any]]:
        r"""Checks if any Linear or Conv layers in the model utilize per_channel quantization.
        Only Linear and Conv layers can use per_channel as of now so only these two are currently checked.

        Looks at q_config format and backend to determine if per_channel can be utilized.
        Uses the DEFAULT_BACKEND_PER_CHANNEL_SUPPORTED_MODULES structure to determine support

        Args:
            model: The prepared and calibrated model we want to check if using per_channel

        Returns a tuple with two elements:
            String report of potential actions to improve model (if per_channel quantization is available in backend)
            Dictionary mapping per_channel quantizable elements to:
                whether per_channel quantization is supported by the backend
                if it is being utilized in the current model
        """

        # run the helper function to populate the dictionary
        per_channel_info = self._detect_per_channel_helper(model)

        # String to let the user know of further optimizations
        further_optims_str = (
            f"Further Optimizations for backend {self.backend_chosen}: \n"
        )

        optimizations_possible = False
        for fqn, fqn_dict in per_channel_info.items():
            if (
                fqn_dict[self.PER_CHAN_SUPPORTED_KEY]
                and not fqn_dict[self.PER_CHAN_USED_KEY]
            ):
                optimizations_possible = True
                further_optims_str += (
                    f"Module {fqn} can be configured to use per_channel quantization.\n"
                )

        if optimizations_possible:
            further_optims_str += "To use per_channel quantization, make sure the qconfig has a per_channel weight observer."
        else:
            further_optims_str += "No further per_channel optimizations possible."

        # return the string and the dictionary form of same information
        return (further_optims_str, per_channel_info)


class DynamicStaticDetector(DetectorBase):
    r"""
    Determines whether dynamic or static quantization is more appropriate for a given module.

    Takes advantage of the ModelReportObserver that records range information.
    Stationary distribution of data are strictly above tolerance level for the comparison statistic:

        S = average_batch_activation_range/epoch_activation_range

    Nonstationary distributions are below or at the tolerance level for this metric.

    If the distribution of data right after the module is non-stationary, recommend dynamic quantization
        Otherwise recommend static quantization

    Args:
        tolerance (float, optional): The threshold where S metric is stationary above and non-stationary otherwise. Default: 0.5
    """
    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME = "model_report_pre_observer"
    DEFAULT_POST_OBSERVER_NAME = "model_report_post_observer"

    # naming conventions for stationary vs non-stationary data
    STATIONARY_STR = "stationary"
    NON_STATIONARY_STR = "non-stationary"

    # naming for activation
    INPUT_ACTIVATION_PREFIX = "input_activation_"
    OUTPUT_ACTIVATION_PREFIX = "output_activation_"

    # naming conventions for the keys of the return module info
    TOLERANCE_KEY = "dynamic_static_tolerance"
    DEFAULT_DYNAMIC_REC_KEY = "dynamic_recommended"
    PRE_OBS_COMP_STAT_KEY = INPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    POST_OBS_COMP_STAT_KEY = OUTPUT_ACTIVATION_PREFIX + "dynamic_static_comp_stat"
    PRE_OBS_DATA_DIST_KEY = (
        INPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    POST_OBS_DATA_DIST_KEY = (
        OUTPUT_ACTIVATION_PREFIX + "dynamic_static_data_classification"
    )
    IS_CURRENTLY_SUPPORTED_KEY = "is_dynamic_supported"

    # modules that are supported both dynamic and static for this report function
    DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED = {nn.Linear}

    # modules that will be supported soon for both
    DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED = {nn.Conv1d, nn.Conv2d, nn.Conv3d}

    def __init__(self, tolerance=0.5):
        super().__init__()

        # set tolerance level and initialize a set to keep track of useful fqn locations
        self.tolerance = tolerance
        # NOTE(review): not read anywhere in this class; kept since callers may access it
        self.useful_observer_fqns: set[str] = set()

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> dict[str, dict[str, Any]]:
        r"""
        Determines where observers need to be inserted for the Dynamic vs Static detector.
        For this detector, we want to place observers on either side of the supported layers in the model.

        Currently inserts observers for:
            linear layers (DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED)
            conv layers (DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED)

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: dict[str, dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # make sure module is supported (currently or in the future)
            if self._is_supported(module, insert=True):
                # resolve the graph node for this module (may be a fused parent)
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # pre-observer sees the same inputs as the target node
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME
                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

                # post-observer takes the target node's output as its only arg
                post_obs_fqn = fqn + "." + self.DEFAULT_POST_OBSERVER_NAME
                obs_fqn_to_info[post_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(),
                    DETECTOR_IS_POST_OBS_KEY: True,
                    DETECTOR_OBS_ARGS_KEY: (targeted_node,),
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""returns the string name of this detector"""
        return "dynamic_vs_static_detector"

    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from module fqns (of modules with observers around them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        # run the helper function to populate the dictionary
        dynamic_static_info = self._generate_dict_info(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info = {}

        for module_fqn, module_info in dynamic_static_info.items():
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # record whether dynamic (rather than static) quantization is recommended
            dynamic_static_recommended: bool = module_info[self.DEFAULT_DYNAMIC_REC_KEY]
            detector_qconfig_info.is_activation_dynamic = dynamic_static_recommended
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # check to see if module is of a supported type
        is_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
        )

        # check if it will be supported
        future_supported_type = any(
            isinstance(module, x) for x in self.DEFAULT_DYNAMIC_STATIC_FUTURE_SUPPORTED
        )

        # supported
        supported = is_supported_type or future_supported_type

        # this is check for observer insertion
        if insert:
            return supported
        else:
            # this is for report gen and we also need to check if it contains observers
            has_obs = hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME) and hasattr(
                module, self.DEFAULT_POST_OBSERVER_NAME
            )
            return supported and has_obs

    def _classify_distribution(self, stat) -> str:
        r"""Returns STATIONARY_STR if ``stat`` is strictly above the tolerance, else NON_STATIONARY_STR"""
        return self.STATIONARY_STR if stat > self.tolerance else self.NON_STATIONARY_STR

    def _generate_dict_info(self, model: GraphModule) -> dict[str, Any]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decided whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """
        # store modules dynamic vs static information
        module_dynamic_static_info = {}

        # loop through all submodules including nested ones; only supported modules
        # that carry both the pre and post observers are reported on
        for fqn, module in model.named_modules():
            if not self._is_supported(module):
                continue

            # get pre and post observers for the module
            pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
            post_obs = getattr(module, self.DEFAULT_POST_OBSERVER_NAME)

            # get the S statistic for the module's input and output
            pre_stat = pre_obs.get_batch_to_epoch_ratio()
            post_stat = post_obs.get_batch_to_epoch_ratio()

            # dynamic is recommended when the output distribution is non-stationary
            dynamic_recommended = post_stat <= self.tolerance

            # Linear is currently supported; conv only "future" supported
            is_supported_type = any(
                isinstance(module, x)
                for x in self.DEFAULT_DYNAMIC_STATIC_CHECK_SUPPORTED
            )

            # store the set of important information for this module
            module_dynamic_static_info[fqn] = {
                self.TOLERANCE_KEY: self.tolerance,
                self.DEFAULT_DYNAMIC_REC_KEY: dynamic_recommended,
                self.PRE_OBS_COMP_STAT_KEY: pre_stat,
                self.PRE_OBS_DATA_DIST_KEY: self._classify_distribution(pre_stat),
                self.POST_OBS_COMP_STAT_KEY: post_stat,
                self.POST_OBS_DATA_DIST_KEY: self._classify_distribution(post_stat),
                self.IS_CURRENTLY_SUPPORTED_KEY: is_supported_type,
            }

        return module_dynamic_static_info

    def generate_detector_report(
        self, model: GraphModule
    ) -> tuple[str, dict[str, Any]]:
        r"""
        Determines whether dynamic or static quantization is more appropriate for a given module.

        Takes advantage of the ModelReportObserver that records range information.
        Stationary distribution of data are strictly above tolerance level for the comparison statistic:

            S = average_batch_activation_range/epoch_activation_range

        Nonstationary distributions are below or at the tolerance level for this metric.

        If the distribution of data right after the module is non-stationary, recommend dynamic quantization
            Otherwise recommend static quantization

        This will then generate suggestions for dynamic vs static quantization focused around Linear.

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of of whether dynamic or static quantization is recommended for certain modules
            Dictionary mapping modules with ModelReportObservers around them to:
                whether dynamic quantization is recommended
                their S metric of input to module
                whether input to module is stationary or non-stationary
                their S metric of output of module
                whether output of module is stationary or non-stationary
                the tolerance level to decided whether input/output is stationary or non-stationary
                whether it is currently supported or planned for the future
        """

        # get the dictionary of the information to format the string report
        module_dynamic_static_info = self._generate_dict_info(model)

        # report fragments that do not depend on the module being described
        suggestion_string_template = (
            "For module {} it is suggested to use {} quantization because {}.\n"
        )
        reasoning_template = "the distribution of data before {} is {} and the distribution after is {}."
        dynamic_benefit = (
            " You will get more accurate results if you use dynamic quantization"
        )
        static_benefit = (
            " You can increase model efficiency if you use static quantization"
        )
        future_support_str = (
            ". This layer is not yet supported for dynamic quantization"
        )
        # extra advice for a non-stationary input feeding a module with stationary output
        dynamic_per_tensor_string = (
            ". We recommend to add a {} before this module if it is static.".format(
                "dynamic quantize per tensor layer"
            )
        )
        dynamic_per_tensor_reasoning_string = " This is because the input to this module has a non-stationary distribution"

        report_parts = ["Dynamic vs. Static Quantization suggestions: \n"]

        for module_fqn, module_info in module_dynamic_static_info.items():
            if module_info[self.DEFAULT_DYNAMIC_REC_KEY]:
                quantization_type = "dynamic"
                benefit_str = dynamic_benefit
                # conv layers are only planned for dynamic support
                if not module_info[self.IS_CURRENTLY_SUPPORTED_KEY]:
                    benefit_str += future_support_str
            else:
                quantization_type = "static"
                benefit_str = static_benefit

            pre_dist = module_info[self.PRE_OBS_DATA_DIST_KEY]
            post_dist = module_info[self.POST_OBS_DATA_DIST_KEY]
            quantization_reasoning = (
                reasoning_template.format(module_fqn, pre_dist, post_dist) + benefit_str
            )

            # non-stationary input -> module -> stationary output gets a static
            # suggestion, so also recommend a dynamic quantize per tensor before it
            if pre_dist == self.NON_STATIONARY_STR and post_dist == self.STATIONARY_STR:
                quantization_reasoning += (
                    dynamic_per_tensor_string + dynamic_per_tensor_reasoning_string
                )

            report_parts.append(
                suggestion_string_template.format(
                    module_fqn, quantization_type, quantization_reasoning
                )
            )

        if not module_dynamic_static_info:
            report_parts.append(
                "No applicable layers for suggestions. Only linear and conv are valid.\n"
            )

        # return the string as well as the dictionary of information
        return ("".join(report_parts), module_dynamic_static_info)


class InputWeightEqualizationDetector(DetectorBase):
    r"""
    Determines whether input-weight equalization can help improve quantization for certain modules.

    Specifically, this list of modules includes:
        linear
        conv

    Determines whether input-weight equalization is recommended based on the comp stat:
        s_c = sqrt(w_c/W)/sqrt(i_c/I)
        where:
            w_c is range of weight for channel c, W is range of weight over all channels
            i_c is range of input for channel c, I is range of input over all channels

        if threshold <= s_c <= 1 / threshold, input-weight equalization is recommended for channel c.
        At the module level, it is recommended when the fraction of recommended channels is at
        least DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO.

    Args:
        ratio_threshold (float): The threshold for s_c to determine if input-weight equalization is suggested
            Should be between 0 and 1 (both non-inclusive)
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for s_c to determine if input-weight equalization is suggested
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine input weight equalization

    * :attr:`SUPPORTED_MODULES`: This specifies the modules that are supported for input-weight equalization

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    SUPPORTED_MODULES: set[Callable] = {
        nn.Linear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Linear,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
    }

    # names for the pre and post observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # weight / activation prefix for each of the below info
    WEIGHT_PREFIX = "weight_"
    ACTIVATION_PREFIX = "input_activation_"

    # string names for keys of info dictionaries
    PER_CHANNEL_MAX_KEY = "per_channel_max"
    PER_CHANNEL_MIN_KEY = "per_channel_min"
    GLOBAL_MAX_KEY = "global_max"
    GLOBAL_MIN_KEY = "global_min"

    # keys for return dict of recommendations
    RECOMMENDED_KEY = "input_weight_equalization_recommended"
    COMP_METRIC_KEY = "input_weight_channel_comparison_metrics"
    THRESHOLD_KEY = "input_weight_threshold"
    CHANNEL_KEY = "input_weight_channel_axis"

    # default weight and info strings
    WEIGHT_STR = "weight"
    INPUT_STR = "input"

    # default for what ratio we recommend input weight
    DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO = 0.4

    def __init__(self, ratio_threshold: float, ch_axis: int = 1):
        r"""Validates and stores the detector configuration.

        Raises:
            ValueError: if ``ratio_threshold`` is not strictly between 0 and 1
        """
        # ensure passed in inputs are valid
        if ratio_threshold <= 0 or ratio_threshold >= 1:
            raise ValueError("Make sure threshold is > 0 and < 1")

        # initialize attributes based on args
        self.ratio_threshold: float = ratio_threshold
        self.ch_axis: int = ch_axis

    def _is_supported(self, module: nn.Module, insert: bool = False) -> bool:
        r"""Returns whether the given module is supported for observers

        Args
            module: The module to check and ensure is supported
            insert: True if this is check for observer insertion, false if for report gen

        Returns True if the module is supported by observer, False otherwise
        """
        # exact type match on purpose: subclasses of the supported modules are not accepted
        is_supported_type = type(module) in self.SUPPORTED_MODULES

        # this is check for observer insertion
        if insert:
            return is_supported_type

        # this is for report gen and we also need to check if it contains observers
        return is_supported_type and hasattr(module, self.DEFAULT_PRE_OBSERVER_NAME)

    def _compute_equalization_info(self, model: GraphModule) -> dict[str, dict]:
        r"""Runs the shared stats pipeline (input ranges, weight ranges, s_c, per-module dict).

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns the per-module info dictionary produced by ``_generate_dict_info``.
        """
        # find the range of inputs
        input_values: dict[str, dict] = self._extract_input_info(model)

        # find the range of weights
        weight_values: dict[str, dict] = self._extract_weight_info(model)

        # calculate per_channel comparison statistic s_c
        comp_stats: dict[str, torch.Tensor] = self._generate_comparison_values(
            input_values, weight_values
        )

        # generate the return dictionary
        return self._generate_dict_info(input_values, weight_values, comp_stats)

    def _is_module_recommended(self, channel_recs: list[bool]) -> bool:
        r"""Collapses per-channel recommendations into a single module-level decision.

        Equalization is recommended when the fraction of recommended channels is at least
        ``DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO``. An empty channel list is never recommended.

        Args:
            channel_recs (list[bool]): whether each channel would benefit from equalization

        Returns True if input-weight equalization is recommended for the module
        """
        if not channel_recs:
            return False
        return (
            sum(channel_recs) / len(channel_recs)
            >= self.DEFAULT_RECOMMEND_INPUT_WEIGHT_CHANNEL_RATIO
        )

    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each module_fqn relevant
        Args
            model (nn.Module or subclass): model to find observer insertion points

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to:
            A DetectorQConfigInfo with the information to generate a QConfig for a specific module
        """
        input_weight_equalization_info = self._compute_equalization_info(model)

        # we actually have a qconfig info object we are populating
        module_fqn_to_detector_qconfig_info: dict[str, DetectorQConfigInfo] = {}

        for module_fqn, mod_info in input_weight_equalization_info.items():
            # create a detector info instance
            detector_qconfig_info = DetectorQConfigInfo(module_fqn)

            # RECOMMENDED_KEY holds one bool per channel; reduce it to the single bool the
            # qconfig info expects (a non-empty list would otherwise always be truthy),
            # using the same rule as the text report
            detector_qconfig_info.is_equalization_recommended = (
                self._is_module_recommended(mod_info[self.RECOMMENDED_KEY])
            )
            module_fqn_to_detector_qconfig_info[module_fqn] = detector_qconfig_info

        return module_fqn_to_detector_qconfig_info

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> dict[str, dict[str, Any]]:
        r"""Determines where observers need to be inserted for the Input Weight Equalization Detector.
        For this detector, we want to place observers in front of supported layers.

        Currently inserts observers for:
            linear layers
            conv layers

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """

        # observer for this detector is ModelReportObserver
        obs_ctr = ModelReportObserver

        # return dict
        obs_fqn_to_info: dict[str, dict[str, Any]] = {}

        for fqn, module in prepared_fx_model.named_modules():
            # check to see if module is of a supported type
            if self._is_supported(module, insert=True):
                # if it's a supported type, we want to get node and add observer insert locations
                targeted_node = self._get_targeting_node(prepared_fx_model, fqn)

                # add entry for pre-observer
                pre_obs_fqn = fqn + "." + self.DEFAULT_PRE_OBSERVER_NAME

                obs_fqn_to_info[pre_obs_fqn] = {
                    DETECTOR_TARGET_NODE_KEY: targeted_node,
                    DETECTOR_OBS_TO_INSERT_KEY: obs_ctr(ch_axis=self.ch_axis),
                    DETECTOR_IS_POST_OBS_KEY: False,
                    DETECTOR_OBS_ARGS_KEY: targeted_node.args,
                }

        return obs_fqn_to_info

    def get_detector_name(self) -> str:
        r"""Returns the name of this detector"""
        return "input_weight_equalization_detector"

    def _extract_input_info(self, model: GraphModule) -> dict[str, dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the input information for each observer and returns it

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns (str) to a dict with keys:
            "input_activation_per_channel_max" : maps to the per_channel max values
            "input_activation_per_channel_min" : maps to the per_channel min values
            "input_activation_global_max" : maps to the global max recorded
            "input_activation_global_min" : maps to the global min recorded

        Raises:
            ValueError: if a pre-observer has no recorded per-channel values
        """

        # return dictionary mapping observer fqns to desired info
        input_info: dict[str, dict] = {}

        for fqn, module in model.named_modules():
            # only modules that are supported and carry a pre-observer
            if not self._is_supported(module):
                continue

            pre_obs = getattr(module, self.DEFAULT_PRE_OBSERVER_NAME)
            per_channel_max = pre_obs.max_val
            per_channel_min = pre_obs.min_val

            # reducing an empty tensor is meaningless; give an actionable message instead
            if per_channel_max.numel() == 0 or per_channel_min.numel() == 0:
                raise ValueError(
                    f"The pre-observer of module {fqn} has no recorded input data. "
                    "Make sure the model is calibrated before generating the report."
                )

            input_info[fqn] = {
                self.ACTIVATION_PREFIX + self.PER_CHANNEL_MAX_KEY: per_channel_max,
                self.ACTIVATION_PREFIX + self.PER_CHANNEL_MIN_KEY: per_channel_min,
                self.ACTIVATION_PREFIX + self.GLOBAL_MAX_KEY: per_channel_max.max(),
                self.ACTIVATION_PREFIX + self.GLOBAL_MIN_KEY: per_channel_min.min(),
            }

        return input_info

    def _extract_weight_info(self, model: GraphModule) -> dict[str, dict]:
        r"""
        Takes in a calibrated GraphModule and then finds the relevant observers.
        It then extracts the weight information for each layer an observer is attached to.

        Args
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping module fqns (str) to a dict with keys:
            "weight_per_channel_max" : maps to the per_channel max values
            "weight_per_channel_min" : maps to the per_channel min values
            "weight_global_max" : maps to the global max recorded
            "weight_global_min" : maps to the global min recorded
        """
        # return dictionary mapping observer fqns to desired info
        weight_info: dict[str, dict] = {}

        for fqn, module in model.named_modules():
            # only modules that are supported and carry a pre-observer
            if not self._is_supported(module):
                continue

            # we don't need the actual observer, just the module weights; detach so the
            # reported statistics do not carry autograd history
            weight = module.weight.detach()

            # move ch_axis to the front so each row of the flattened view is one channel
            axis_order = list(range(weight.dim()))
            axis_order[0], axis_order[self.ch_axis] = (
                axis_order[self.ch_axis],
                axis_order[0],
            )
            per_channel = torch.flatten(
                weight.permute(axis_order).to(torch.get_default_dtype()), start_dim=1
            )
            min_val, max_val = torch.aminmax(per_channel, dim=1)

            weight_info[fqn] = {
                self.WEIGHT_PREFIX + self.PER_CHANNEL_MAX_KEY: max_val,
                self.WEIGHT_PREFIX + self.PER_CHANNEL_MIN_KEY: min_val,
                self.WEIGHT_PREFIX + self.GLOBAL_MAX_KEY: max_val.max(),
                self.WEIGHT_PREFIX + self.GLOBAL_MIN_KEY: min_val.min(),
            }

        return weight_info

    def _calculate_range_ratio(
        self, info_dict: dict, info_str: str, module_fqn: str
    ) -> torch.Tensor:
        r"""
        Takes in an info dict and calculates the per-channel range ratio (c-th range / global range).

        Args:
            info_dict (dict): A dictionary of either input or weight range info
            info_str (str): A str describing whether currently looking at weight or input info
                Either "weight" or "input"
            module_fqn (str): The fqn of the module we are looking at

        Returns a tensor of values, where each value is the range ratio for a different channel

        Raises:
            ValueError: if the global range is 0 (constant data)
        """
        # get the prefix str
        prefix_str = (
            self.ACTIVATION_PREFIX if info_str == self.INPUT_STR else self.WEIGHT_PREFIX
        )

        per_channel_range = (
            info_dict[prefix_str + self.PER_CHANNEL_MAX_KEY]
            - info_dict[prefix_str + self.PER_CHANNEL_MIN_KEY]
        )
        global_range = (
            info_dict[prefix_str + self.GLOBAL_MAX_KEY]
            - info_dict[prefix_str + self.GLOBAL_MIN_KEY]
        )

        if global_range == 0:
            range_zero_explanation = "We recommend removing this channel as it doesn't provide any useful information."
            raise ValueError(
                f"The range of the {info_str} data for module {module_fqn} is 0, "
                f"which means you have a constant value channel. {range_zero_explanation}"
            )

        return per_channel_range / global_range

    def _generate_comparison_values(
        self, input_info: dict, weight_info: dict
    ) -> dict[str, torch.Tensor]:
        r"""
        Takes in the information on the min and max values of the inputs and weights and:
            Calculates the comp stat for each channel: s_c = sqrt(w_c/W)/sqrt(i_c/I)

        Args:
            input_info (dict): A dict mapping each observer to input range information
            weight_info (dict): A dict mapping each observer to weight range information

        Returns a dict mapping relevant observer fqns (str) to a 1-D tensor.
            Each value is a different s_c value for a different channel

        Raises:
            KeyError: if a module in ``input_info`` is missing from ``weight_info``
            ValueError: if the input channel count is not a multiple of the weight channel count
        """
        # create return dictionary for each observer
        module_fqn_to_channel: dict[str, torch.Tensor] = {}

        # for each module (both passed in dicts should have same keys)
        for module_fqn in input_info:
            # raise error if not in weight info
            if module_fqn not in weight_info:
                raise KeyError(
                    f"Unable to find weight range stats for module {module_fqn}"
                )

            # calculate the ratios of the weight info and input info
            weight_ratio = self._calculate_range_ratio(
                weight_info[module_fqn], self.WEIGHT_STR, module_fqn
            )
            input_ratio = self._calculate_range_ratio(
                input_info[module_fqn], self.INPUT_STR, module_fqn
            )

            # grouped convs give fewer weight channels than input channels; replicate the
            # weight ratio so both line up channel by channel
            weight_channels = len(weight_ratio)
            input_channels = len(input_ratio)
            if weight_channels != input_channels:
                # explicit check instead of assert so it survives ``python -O``
                if input_channels % weight_channels != 0:
                    raise ValueError(
                        "input channels should be divisible by weight channels."
                    )
                rep_factor: int = input_channels // weight_channels

                # weight ratio is (n,), input ratio is (k,), we just repeat weight ratio k // n
                weight_ratio = weight_ratio.repeat(rep_factor)

            # calculate the s metric per channel
            module_fqn_to_channel[module_fqn] = torch.sqrt(weight_ratio) / torch.sqrt(
                input_ratio
            )

        # return compiled observer ratios
        return module_fqn_to_channel

    def _generate_dict_info(
        self, input_info: dict, weight_info: dict, comp_stats: dict
    ) -> dict[str, dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            input_info (dict): A dict mapping each module to input range information
            weight_info (dict): A dict mapping each module to weight range information
            comp_stats (dict): A dict mapping each module to its corresponding comp stat

        Returns a dictionary mapping each module with relevant ModelReportObservers around them to:
            a per-channel list of bools saying whether input weight equalization is recommended
            their s_c metric compared to the threshold
            the threshold used to make the recommendation
            the channel used for recording data
            the input channel range info
            the weight channel range info
        """
        # store modules input weight equalization info
        input_weight_equalization_info: dict[str, dict] = {}

        lower_bound = self.ratio_threshold
        upper_bound = 1 / self.ratio_threshold

        # for each module we add separate set of suggestions
        for module_fqn in input_info:
            mod_comp_stat: torch.Tensor = comp_stats[module_fqn]

            # a channel is recommended when threshold <= s_c <= 1 / threshold
            channel_rec_vals: list[bool] = [
                lower_bound <= float_rep <= upper_bound
                for float_rep in mod_comp_stat.tolist()
            ]

            # build the return dict input, unpacking input and weight range info into it
            input_weight_equalization_info[module_fqn] = {
                self.RECOMMENDED_KEY: channel_rec_vals,
                self.COMP_METRIC_KEY: mod_comp_stat,
                self.THRESHOLD_KEY: self.ratio_threshold,
                self.CHANNEL_KEY: self.ch_axis,
                **input_info[module_fqn],
                **weight_info[module_fqn],
            }

        # return our compiled info for each module
        return input_weight_equalization_info

    def generate_detector_report(
        self, model: GraphModule
    ) -> tuple[str, dict[str, Any]]:
        r"""
        Determines whether input weight equalization is appropriate for a given module.

        Takes advantage of the ModelReport Observer which records per channel information of input range
        It then uses the passed in weight info in conjunction to compute the desired ratio
        Finally, it gives suggestions based on this information for each module of interest

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether input weight equalization is recommended for certain modules
            Dictionary mapping modules of interest to:
                a per-channel list of bools saying whether input weight equalization is recommended
                their s_c metric compared to the threshold
                the threshold used to make the recommendation
                the channel used for recording data
                the input channel range info
                the weight channel range info
        """
        input_weight_equalization_info: dict[str, dict] = (
            self._compute_equalization_info(model)
        )

        # now we can generate report based on this information
        input_weight_string = "Input-Weight Equalization suggestions: \n"

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = (
            "\tWe suggest {} input weight equalization because {}\n"
        )
        use_str = "to use"
        no_use_str = "to not use"
        input_weight_benefit_str = "{}/{} channels would benefit and we expect significant reduction in quantization error."
        input_weight_non_benefit_reasoning = (
            "{}/{} channels benefitting from input-weight equalization being applied."
        )
        input_weight_non_benefit_str = "we don't expect much improvement from input-weight equalization based on {}"

        # compile the suggestion string
        for module_fqn, mod_info in input_weight_equalization_info.items():
            # add the module level description
            input_weight_string += module_suggestion_str.format(
                module_fqn, self.ch_axis
            )

            # gather how many channels would benefit from input weight equalization
            recommendation_per_channel: list[bool] = mod_info[self.RECOMMENDED_KEY]
            num_recs = sum(recommendation_per_channel)
            num_channels = len(recommendation_per_channel)

            if self._is_module_recommended(recommendation_per_channel):
                input_benefit_formatted = input_weight_benefit_str.format(
                    num_recs, num_channels
                )
                input_weight_string += channel_suggestion_str.format(
                    use_str, input_benefit_formatted
                )
            else:
                non_benefit_reason_formatted = (
                    input_weight_non_benefit_reasoning.format(num_recs, num_channels)
                )
                non_benefit_str = input_weight_non_benefit_str.format(
                    non_benefit_reason_formatted
                )
                input_weight_string += channel_suggestion_str.format(
                    no_use_str, non_benefit_str
                )

        # if no modules looked at, amend return string
        if not input_weight_equalization_info:
            input_weight_string += (
                "No applicable layers for suggestions. Only linear and conv valid.\n"
            )

        # return a tuple with the string explanation and the compiled dict info
        return (input_weight_string, input_weight_equalization_info)


class OutlierDetector(DetectorBase):
    r"""
    Determines whether there are significant outliers in activation data around a certain layer.

    This is ideally used in conjunction with information on stationary vs. non-stationary distribution:
        If the data is stationary, and there are significant outliers, then we want to flag them
        We want to do this on a per channel basis for detecting outliers

    Determines whether activation data is flagged as outlier based on if data is stationary and:
        p_r = avg(100th percentile / "reference_percentile"th percentile)
        where:
            p_r is average percentile ratio across all batches in the epoch
            reference_percentile is a percentile values between 0 and 100 exclusive

        if p_r is above some threshold, then we consider the activations to have significant outliers

    Args:
        ratio_threshold (float, optional): The threshold for p_r to determine if there are outliers in activations
            Should be >= 1
            Default: 3.5
        reference_percentile (float, optional): The denominator to find the relative scale of the 100th percentile
            Should be between 0 and 1
            Default: 0.975
        fraction_batches_used_threshold (float, optional): Threshold of fraction of batches per channel to determine outlier
            If fraction is below this, we deem number of samples used to calculate outliers as insignificant and alert user
            regardless of whether we detected outliers or not in channel to take a closer look at channel results
            Should be between 0 and 1
            Default: 0.95
        ch_axis (int, optional): The channel axis being observed to determine input weight equalization
            Default: 1

    * :attr:`ratio_threshold`: The threshold for p_r to determine if there are outliers in activations
        The p_r value (average ratio of 100th percentile/reference_percentile) is compared to ratio_threshold
        If it is significantly greater, then we consider it an outlier
        This threshold was calculated based on the ratio of the percentiles in a normal distribution
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`reference_percentile`: The denominator of the top fraction to find the relative scale of the 100th percentile
        Should be between 0 and 1
        The calculations behind value choice: https://drive.google.com/file/d/1N2wdtXWI-kOH8S7HH4-PYB_NmqzZil4p/view?usp=sharing

    * :attr:`fraction_batches_used_threshold`: The fraction of batches to determine outliers for each channel should be above this
        Some batches may not be used because of 0-based errors, so this is to ensure a good amount of the total batches are used
        Should be between 0 and 1

    * :attr:`ch_axis`: The channel axis being observed to determine outliers

    * :attr:`DEFAULT_PRE_OBSERVER_NAME`: The name of the pre-observer to be inserted for this detector
    """

    # names for the pre observers that are inserted
    DEFAULT_PRE_OBSERVER_NAME: str = "model_report_pre_observer"

    # pre activation prefix
    INPUT_ACTIVATION_PREFIX = "input_activation_"

    # names for dict keys
    OUTLIER_KEY = "outliers_detected"
    NUM_BATCHES_KEY = "outlier_detection_batches_used"
    IS_SUFFICIENT_BATCHES_KEY = "outlier_detection_is_sufficient_batches"
    COMP_METRIC_KEY = "outlier_detection_percentile_ratios"
    RATIO_THRES_KEY = "outlier_detection_ratio_threshold"
    REF_PERCENTILE_KEY = "outlier_detection_reference_percentile"
    CHANNEL_AXIS_KEY = "outlier_detection_channel_axis"
    MAX_VALS_KEY = INPUT_ACTIVATION_PREFIX + "per_channel_max"
    CONSTANT_COUNTS_KEY = "constant_batch_counts"

    def __init__(
        self,
        ratio_threshold: float = 3.5,
        reference_percentile: float = 0.975,
        fraction_batches_used_threshold: float = 0.95,
        ch_axis: int = 1,
    ):
        r"""Validates and stores the outlier detector configuration.

        Args:
            ratio_threshold (float): threshold on the average percentile ratio p_r
            reference_percentile (float): reference percentile, in [0, 1]
            fraction_batches_used_threshold (float): min fraction of batches used per channel, in [0, 1]
            ch_axis (int): channel axis being observed

        Raises:
            ValueError: if ``reference_percentile`` or ``fraction_batches_used_threshold``
                is outside [0, 1]
        """
        # explicit checks (not ``assert``) so validation still runs under ``python -O``
        if not 0 <= reference_percentile <= 1:
            raise ValueError(
                f"reference_percentile must be between 0 and 1, got {reference_percentile}"
            )
        if not 0 <= fraction_batches_used_threshold <= 1:
            raise ValueError(
                "fraction_batches_used_threshold must be between 0 and 1, "
                f"got {fraction_batches_used_threshold}"
            )

        # initialize the variables of interest
        self.ratio_threshold = ratio_threshold
        self.reference_percentile = reference_percentile
        self.fraction_batches_used_threshold = fraction_batches_used_threshold
        self.ch_axis = ch_axis

    def get_detector_name(self) -> str:
        r"""Returns the identifier string used for this detector in reports."""
        detector_name = "outlier_detector"
        return detector_name

    def _supports_insertion(self, module: nn.Module) -> bool:
        r"""Returns whether an observer may be inserted in front of ``module``.

        Only leaf modules (no child modules) that are not themselves observers qualify.

        Args
            module: The module to check and ensure is supported

        Returns True if the module is supported by observer, False otherwise
        """
        # observers are never observed themselves
        if _is_activation_post_process(module):
            return False
        # a leaf module yields no children at all
        first_child = next(module.children(), None)
        return first_child is None

    def get_qconfig_info(self, model) -> dict[str, DetectorQConfigInfo]:
        r"""Returns the DetectorQConfigInfo for each relevant module_fqn.

        Args
            model (nn.Module or subclass): model to find observer insertion points

        The outlier detector does not contribute any qconfig information yet, so the
        returned mapping is always empty.
        """
        no_qconfig_info: dict[str, DetectorQConfigInfo] = {}
        return no_qconfig_info

    def _supports_report_gen(self, module: nn.Module) -> bool:
        r"""Returns whether ``module`` can contribute to report generation.

        A module qualifies exactly when it carries the model report pre-observer attribute.

        Args
            module: The module to check and ensure is supported

        Returns True if the module is supported by observer, False otherwise
        """
        observer_attr_name = self.DEFAULT_PRE_OBSERVER_NAME
        return hasattr(module, observer_attr_name)

    def determine_observer_insert_points(
        self, prepared_fx_model: GraphModule
    ) -> dict[str, dict[str, Any]]:
        r"""Computes the pre-observer insertion points for the Outlier Detector.

        A ModelReportObserver is placed in front of every leaf-level layer (a module with
        no children that is not itself an observer).

        Args:
            prepared_fx_model (GraphModule):  The prepared Fx GraphModule

        Returns a Dict mapping from unique observer fqns (where we want to insert them) to a Dict with:
            key "target_node" -> the node we are trying to observe with this observer (torch.fx.node.Node)
            key "observer_to_insert" -> the observer we wish to insert (ObserverBase)
            key "is_post_observer" -> True if this is meant to be a post-observer for target_node, False if pre-observer
            key "observer_args" -> The arguments that are meant to be passed into the observer
        """
        insert_points: dict[str, dict[str, Any]] = {}

        for module_fqn, submodule in prepared_fx_model.named_modules():
            # skip anything that is not a leaf layer
            if not self._supports_insertion(submodule):
                continue

            node = self._get_targeting_node(prepared_fx_model, module_fqn)
            observer = ModelReportObserver(
                ch_axis=self.ch_axis, comp_percentile=self.reference_percentile
            )

            # pre-observer lives under the module's fqn
            observer_fqn = f"{module_fqn}.{self.DEFAULT_PRE_OBSERVER_NAME}"
            insert_points[observer_fqn] = {
                DETECTOR_TARGET_NODE_KEY: node,
                DETECTOR_OBS_TO_INSERT_KEY: observer,
                DETECTOR_IS_POST_OBS_KEY: False,
                DETECTOR_OBS_ARGS_KEY: node.args,
            }

        return insert_points

    def _calculate_outlier_info(
        self,
        percentile_ratios: torch.Tensor,
        counted_batches: torch.Tensor,
        total_batches: int,
    ) -> dict[str, list[bool]]:
        r"""
        Gives info on whether the percentile ratios calculated would be considered outliers
        Also gives information on whether the collected data is statistically significant to make this claim

        Args:
            percentile_ratios (torch.Tensor): The average percentile_ratios per channel calculated by the observer
            counted_batches (torch.Tensor): The number of batches used for average calculation per tensor
            total_batches (int): The total number of batches that passed through observer in this epoch

        Returns a dictionary mapping:
            "outliers_detected" : list of bools per channel that are true if it is considered an outlier
            "is_sufficient_batches": if o_r was >= fraction_batches_used_threshold:
                where o_r = counted_batches / total_batches
                If total_batches is not positive, every channel is reported as insufficient.
        """
        # get both as flattened lists for easy mapping
        ratios_list: list = percentile_ratios.tolist()
        num_batches_list: list = counted_batches.tolist()

        # calculate whether channels were statistically significant
        if total_batches > 0:
            significant_size = [
                batch_size / total_batches >= self.fraction_batches_used_threshold
                for batch_size in num_batches_list
            ]
        else:
            # no batches observed: nothing is significant, and dividing would raise ZeroDivisionError
            significant_size = [False] * len(num_batches_list)

        # calculate for each channel whether it's an outlier or not based on ratio
        outlier_detected = [ratio > self.ratio_threshold for ratio in ratios_list]

        # return the dictionary with the two lists
        return {
            self.OUTLIER_KEY: outlier_detected,
            self.IS_SUFFICIENT_BATCHES_KEY: significant_size,
        }

    def _generate_info_dict(self, model: GraphModule) -> dict[str, dict]:
        r"""
        Helper function for generate_detector_report that does the generation of the dictionary.
        This process is done as specified in generate_detector_report documentation

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a dict mapping relevant module fqns to:
            whether there were outliers found in activation before
            the number of batches used for each channel
            whether fraction of applicable batches used is above fraction_batches_used_threshold
            their p_r metric compared to the threshold
            the threshold used to make the recommendation
            the reference_percentile used to make the recommendation
            the channel axis used to determine individual channels
            the constant batch counts per channel
            the per channel max values

        The p_r metric reported is a magnitude-normalized copy of the observer's
        average_percentile_ratio: negative ratios are made positive and ratios
        below 1 are replaced by their reciprocal. The observer's own buffer is
        not modified.
        """
        # return dictionary mapping module fqns to desired info
        info_dict: dict[str, dict] = {}

        for fqn, module in model.named_modules():
            # only modules that are supported and carry a pre-observer are reported on
            if not self._supports_report_gen(module):
                continue

            # get pre observer for the module
            pre_obs: ModelReportObserver = getattr(
                module, self.DEFAULT_PRE_OBSERVER_NAME
            )

            # get the number of batches and calculated ratio thresholds
            num_batches: torch.Tensor = pre_obs.percentile_batches_tracked
            channel_batch_cnts: torch.Tensor = pre_obs.constant_channels
            total_batches: int = pre_obs.num_batches_tracked

            # also get the max values
            max_vals: torch.Tensor = pre_obs.max_val

            # A ratio might be negative if the 100th percentile is > 0 while the nth
            # percentile is < 0, which would otherwise never be flagged as an outlier.
            # Since we care about magnitude, take the absolute value, and flip ratios
            # below 1 to their reciprocal. This is computed on a new tensor (abs()
            # allocates) so the observer's running-average buffer is left untouched.
            magnitudes = pre_obs.average_percentile_ratio.abs()
            average_ratios: torch.Tensor = torch.where(
                magnitudes < 1, magnitudes.reciprocal(), magnitudes
            )

            # calculate whether ratios were outliers
            outlier_calcs = self._calculate_outlier_info(
                average_ratios, num_batches, total_batches
            )

            info_dict[fqn] = {
                self.CHANNEL_AXIS_KEY: self.ch_axis,
                self.REF_PERCENTILE_KEY: self.reference_percentile,
                self.RATIO_THRES_KEY: self.ratio_threshold,
                self.COMP_METRIC_KEY: average_ratios,
                self.NUM_BATCHES_KEY: num_batches,
                self.OUTLIER_KEY: outlier_calcs[self.OUTLIER_KEY],
                self.IS_SUFFICIENT_BATCHES_KEY: outlier_calcs[
                    self.IS_SUFFICIENT_BATCHES_KEY
                ],
                self.CONSTANT_COUNTS_KEY: channel_batch_cnts,
                self.MAX_VALS_KEY: max_vals,
            }

        return info_dict

    def generate_detector_report(
        self, model: GraphModule
    ) -> tuple[str, dict[str, Any]]:
        r"""
        Determines whether there are outliers in the activation data preceding each supported module.

        Takes advantage of the ModelReport Observer which records the relevant percentile information

        Args:
            model (GraphModule): The prepared and calibrated GraphModule with inserted ModelReportObservers

        Returns a tuple with two elements:
            String report of whether there are outliers in the activations around certain modules,
                also listing any channels for which constant-value batches were recorded
            Dictionary mapping modules of interest to:
                whether there were outliers found in activation before
                the number of batches used for each channel
                whether fraction of applicable batches used is above fraction_batches_used_threshold
                their p_r metric compared to the threshold
                the threshold used to make the recommendation
                the reference_percentile used to make the recommendation
                the channel axis used to determine individual channels
                the constant batch counts per channel
                the per channel max values
        """
        # generate the information dictionary of outlier information
        info_dict = self._generate_info_dict(model)

        # some strings to be formatted depending on module we are adding
        module_suggestion_str = "For Module {} looked at with axis {}: \n"
        channel_suggestion_str = "\tFor channel {}, we found outliers in the preceding activation data with {}.\n"
        channel_max_value_str = "a max value across all batches of {}"
        note_string = "Note: outlier detection is only reliable for {}. We recommend {} to ensure the most accurate results."
        note_distribution = "stationary distributions"
        note_rec = "running the static vs. dynamic detector to ensure activation data before modules above is stationary"

        # suggestion for constant batch check since that can make it no outliers
        constant_str = "\tFor channel {}, we found {} constant value batches. {}\n"
        constant_suggestion = "We recommend taking a look at the dict and data to see how frequent this occurred and why."

        # collect report fragments and join once at the end instead of repeated +=
        report_parts: list[str] = ["Outlier detection report: \n"]

        # True once any module contributed at least one channel line to the report
        added_module = False

        for module_fqn, mod_info in info_dict.items():
            # the module-level header is emitted lazily, at most once per module
            added_model_desc = False

            for index, outlier_detected in enumerate(mod_info[self.OUTLIER_KEY]):
                if outlier_detected:
                    if not added_model_desc:
                        report_parts.append(
                            module_suggestion_str.format(module_fqn, self.ch_axis)
                        )
                        added_model_desc = True

                    added_module = True
                    max_value_found_str = channel_max_value_str.format(
                        mod_info[self.MAX_VALS_KEY][index]
                    )
                    report_parts.append(
                        channel_suggestion_str.format(index, max_value_found_str)
                    )

                # constant-value batches are reported independently of outlier status
                constant_values_for_channel = mod_info[self.CONSTANT_COUNTS_KEY][index]
                if constant_values_for_channel != 0:
                    if not added_model_desc:
                        report_parts.append(
                            module_suggestion_str.format(module_fqn, self.ch_axis)
                        )
                        added_model_desc = True

                    report_parts.append(
                        constant_str.format(
                            index, constant_values_for_channel, constant_suggestion
                        )
                    )
                    added_module = True

        # if anything was flagged, append the reliability note; else the default response
        if added_module:
            report_parts.append(note_string.format(note_distribution, note_rec))
        else:
            report_parts.append("There were no outliers found in the activations.\n")

        return ("".join(report_parts), info_dict)
