# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reaching definition analysis.

This analysis attaches a set of Definition objects to each symbol, one
for each distinct definition that may reach it. The Definition objects are
mutable and may be used by subsequent analyses to further annotate data like
static type and value information.
The analysis also attaches the set of the symbols defined at the entry of
control flow statements.

Requires activity analysis.
"""

import weakref

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """A single, unique definition of a variable.

  No equality operators are defined, so two Definition objects compare equal
  only when they are the same object. To use a subclass instead, pass a
  suitable factory function to resolve.

  Attributes:
    param_of: None when created. The Analyzer in this module replaces it with
      a weakref.ref to the matching entry of the scope's params when the
      definition is created for a parameter.
    directives: Dict, optional definition annotations
  """

  def __init__(self):
    self.param_of = None
    self.directives = {}

  def __repr__(self):
    # The object id keeps distinct definitions distinguishable in debug output.
    return '{}[{:d}]'.format(type(self).__name__, id(self))


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
        their possible definitions
  """

  def __init__(self, init_from=None):
    if init_from:
      if isinstance(init_from, _NodeState):
        self.value = {
            s: set(other_infos) for s, other_infos in init_from.value.items()
        }
      elif isinstance(init_from, dict):
        self.value = {s: set((init_from[s],)) for s in init_from}
      else:
        assert False, init_from
    else:
      self.value = {}

  def __eq__(self, other):
    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
      return False
    ret = all(self.value[s] == other.value[s] for s in self.value)
    return ret

  def __ne__(self, other):
    return not self.__eq__(other)

  def __or__(self, other):
    assert isinstance(other, _NodeState)
    result = _NodeState(self)
    for s, other_infos in other.value.items():
      if s in result.value:
        result.value[s].update(other_infos)
      else:
        result.value[s] = set(other_infos)
    return result

  def __sub__(self, other):
    assert isinstance(other, set)
    result = _NodeState(self)
    for s in other:
      result.value.pop(s, None)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level.

  Attributes:
    gen_map: Dict mapping each visited CFG node that carries a scope annotation
      to the _NodeState holding the definitions that node generates. Entries
      are created the first time a node is visited and reused afterwards.
  """

  def __init__(self, graph, definition_factory):
    """Creates the analyzer.

    Args:
      graph: the CFG to analyze; forwarded to cfg.GraphVisitor
      definition_factory: Callable[[], Definition], called once for every
        definition that a node generates
    """
    self._definition_factory = definition_factory
    super(Analyzer, self).__init__(graph)
    self.gen_map = {}

  def init_state(self, _):
    """Returns the initial (empty) state used for every node."""
    return _NodeState()

  def visit_node(self, node):
    """Recomputes the definitions reaching into and out of a CFG node.

    Args:
      node: the CFG node being visited

    Returns:
      bool, True if the node's output state differs from the one recorded
      before this visit.
    """
    prev_defs_out = self.out[node]

    # The definitions reaching this node are the union of everything that
    # leaves its predecessors.
    defs_in = _NodeState()
    for n in node.prev:
      defs_in |= self.out[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      # The definition objects created by each node must be singletons because
      # their ids are used in equality checks.
      if node not in self.gen_map:
        node_symbols = {}
        # Every binding operation (assign, nonlocal, global, etc.) counts as a
        # definition, with the exception of del, which only deletes without
        # creating a new variable.
        newly_defined = ((node_scope.bound | node_scope.globals) -
                         node_scope.deleted)
        for s in newly_defined:
          def_ = self._definition_factory()
          node_symbols[s] = def_
        # Every param receives a definition. Params are not necessarily
        # considered as "modified".
        for s, p in node_scope.params.items():
          def_ = self._definition_factory()
          def_.param_of = weakref.ref(p)
          node_symbols[s] = def_
        self.gen_map[node] = _NodeState(node_symbols)

      # out = gen | (in - kill): symbols modified or deleted here no longer
      # carry their incoming definitions.
      gen = self.gen_map[node]
      kill = node_scope.modified | node_scope.deleted
      defs_out = gen | (defs_in - kill)

    else:
      assert self.can_ignore(node), (node.ast_node, node)
      defs_out = defs_in

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  The dataflow analysis itself is run from this visitor as well, once for each
  function node that is encountered, accounting for the effect of closures.
  For example:

    def foo():
      bar = 1
      def baz():
        # bar = 1 reaches here
  """

  def __init__(self, source_info, graphs, definition_factory):
    super(TreeAnnotator, self).__init__(source_info)
    self.allow_skips = False
    self.definition_factory = definition_factory
    self.graphs = graphs
    # Analyzer of the function currently being walked; None outside functions.
    self.current_analyzer = None
    # CFG node of the statement currently being walked, if any.
    self.current_cfg_node = None

  def visit_FunctionDef(self, node):
    """Analyzes the function's CFG, then annotates its args and body."""
    enclosing_analyzer = self.current_analyzer

    function_analyzer = Analyzer(self.graphs[node], self.definition_factory)
    function_analyzer.visit_forward()

    # Recursively process any remaining subfunctions, using this function's
    # analysis results while inside it.
    self.current_analyzer = function_analyzer
    node.args = self.visit(node.args)
    node.body = self.visit_block(node.body)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Name(self, node):
    """Attaches the reaching definitions to a name node."""
    analyzer = self.current_analyzer
    if analyzer is None:
      # Names may appear outside function defs - for example in class
      # definitions.
      return node

    cfg_node = self.current_cfg_node
    assert cfg_node is not None, ('name node, %s, outside of any statement?'
                                  % node.id)

    qn = anno.getanno(node, anno.Basic.QN)
    # Loads use the definitions entering the statement; every other context
    # uses the definitions leaving it.
    if isinstance(node.ctx, gast.Load):
      states = analyzer.in_
    else:
      states = analyzer.out
    anno.setanno(node, anno.Static.DEFINITIONS,
                 tuple(states[cfg_node].value.get(qn, ())))

    return node

  def _aggregate_predecessors_defined_in(self, node):
    """Annotates a statement with the symbols defined by its predecessors."""
    analyzer = self.current_analyzer
    # Iterating a state's value dict yields its symbols, so the union collects
    # every symbol defined on exit from any predecessor.
    defined_in = frozenset().union(
        *(analyzer.out[p].value for p in analyzer.graph.stmt_prev[node]))
    anno.setanno(node, anno.Static.DEFINED_VARS_IN, defined_in)

  def _annotate_then_descend(self, node):
    """Annotates a control flow statement, then visits its children."""
    self._aggregate_predecessors_defined_in(node)
    return self.generic_visit(node)

  def visit_If(self, node):
    return self._annotate_then_descend(node)

  def visit_For(self, node):
    self._aggregate_predecessors_defined_in(node)

    # Manually accounting for the shortcoming described in
    # cfg.AstToCfg.visit_For: the loop target is visited with the CFG node of
    # the iterated expression as the current one.
    saved_cfg_node = self.current_cfg_node
    self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
    node.target = self.visit(node.target)
    self.current_cfg_node = saved_cfg_node

    node.iter = self.visit(node.iter)
    node.body = self.visit_block(node.body)
    node.orelse = self.visit_block(node.orelse)

    return node

  def visit_While(self, node):
    return self._annotate_then_descend(node)

  def visit_Try(self, node):
    return self._annotate_then_descend(node)

  def visit_ExceptHandler(self, node):
    self._aggregate_predecessors_defined_in(node)
    # TODO(mdan): Also track the exception type / name symbols.
    node.body = self.visit_block(node.body)
    return node

  def visit(self, node):
    """Visits a node, tracking the CFG node it corresponds to, if any."""
    saved_cfg_node = self.current_cfg_node

    analyzer = self.current_analyzer
    if analyzer is not None and node in analyzer.graph.index:
      self.current_cfg_node = analyzer.graph.index[node]
    node = super(TreeAnnotator, self).visit(node)

    # Restore on the way out so that siblings see their own statement's node.
    self.current_cfg_node = saved_cfg_node
    return node


def resolve(node, source_info, graphs, definition_factory=Definition):
  """Annotates the symbols of an AST with their reaching definitions.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    definition_factory: Callable[[], Definition]
  Returns:
    ast.AST, the value returned by TreeAnnotator.visit for node
  """
  return TreeAnnotator(source_info, graphs, definition_factory).visit(node)
