from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility

from .matcher_utils import InternalMatch, SubgraphMatcher


# Names exported by `from <this module> import *`.
__all__ = ["SubgraphMatcherWithNameNodeMap"]


def _split_to_graph_and_name_node_map(
    gm: GraphModule,
) -> tuple[GraphModule, dict[str, Node]]:
    """Detach the trailing name -> node dict from the outputs of ``gm``.

    The output node of ``gm`` is expected to unflatten (via ``gm._out_spec``)
    to a tuple with at least two elements whose last element is a dict. That
    dict is removed from the graph outputs, the pytree output spec stored on
    the graph's codegen is replaced accordingly, and ``gm`` is recompiled.

    Note that ``gm`` is modified in place; the returned module is the same
    object that was passed in.

    Args:
        gm: pattern GraphModule whose last output is the name -> node dict.

    Returns:
        A ``(gm, name_node_map)`` pair, where ``name_node_map`` is the dict
        taken from the last output (empty if no output node was visited).

    Raises:
        AssertionError: if ``gm._out_spec`` is None, or the unflattened output
            is not a tuple of at least two elements ending in a dict.
    """
    from torch.fx.graph import _PyTreeInfo
    from torch.utils._pytree import tree_flatten, tree_unflatten

    name_node_map: dict[str, Node] = {}
    for node in gm.graph.nodes:
        if node.op != "output":
            continue

        assert gm._out_spec is not None
        # The output node holds flattened leaves; rebuild the structure
        # described by the module's output spec.
        unflattened = tree_unflatten(node.args[0], gm._out_spec)
        assert isinstance(
            unflattened, tuple
        ), "Expecting the pattern graph to return a tuple"
        assert (
            len(unflattened) >= 2
        ), "Expecting the pattern graph to have at least two outputs"

        # Last element is the name -> node dict; everything before it is kept
        # (as a list) as the graph's remaining outputs.
        name_node_map = unflattened[-1]
        remaining_outputs = list(unflattened[:-1])
        flat_outputs, remaining_spec = tree_flatten(remaining_outputs)
        assert isinstance(
            name_node_map, dict
        ), "Expecting the input graph to have a dict output as the last element"

        node.args = (flat_outputs,)
        # Keep the original argument names and input spec; only the output
        # spec changes now that the dict is gone.
        previous_info = gm._graph._codegen.pytree_info  # type: ignore[attr-defined]
        updated_info = _PyTreeInfo(
            previous_info.orig_args, previous_info.in_spec, remaining_spec
        )
        gm._graph._codegen.pytree_info = updated_info  # type: ignore[attr-defined]

    gm.recompile()
    return gm, name_node_map


@compatibility(is_backward_compatible=False)
class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
    """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name,
    this requires pattern to have specific format (returning and additional dictionary at the output,
    that has node name as key, and the node in the pattern graph as value, see Example for more details)

    Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during
    initialization since we need to modify the graph (which requires `recompile` the GraphModule)

    Example::
        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            return relu, {"conv": conv, "relu": relu}

        def target_graph(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu *= 2
            return relu

        pattern_gm = export_for_training(pattern, example_inputs).module()
        target_gm = export_for_training(target_graph, example_inputs).module()
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        matches = matcher.match(target_gm)
        for match in matches:
            match.name_node_map["conv"].meta["annotation"] = ...

    """

    def __init__(
        self,
        pattern_gm: GraphModule,
        match_output: bool = False,
        match_placeholder: bool = False,
        remove_overlapping_matches: bool = True,
        ignore_literals: bool = False,
    ) -> None:
        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
        self.name_node_map = name_node_map
        super().__init__(
            pattern_gm.graph,
            match_output,
            match_placeholder,
            remove_overlapping_matches,
            ignore_literals,
        )

    def match(self, graph: Graph) -> list[InternalMatch]:
        """The returned InternalMatch will have name_node_map populated with a map
        from node name (str) to the target node, e.g.
        {"conv": target_conv_ndoe, "relu": target_relu_node}

        this requires the pattern graph returns an additional
        output of node name to node, e.g. instead of:
        ```
        def pattern(...):
            ...
            return relu
        ```
        we should do:
        ```
        def pattern(...):
            ...
            return relu, {"conv": conv, "relu": relu}
        ``` instead
        """
        internal_matches = super().match(graph)
        for internal_match in internal_matches:
            for k, n in self.name_node_map.items():
                internal_match.name_node_map[k] = internal_match.nodes_map[n]
        return internal_matches
