# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An analysis that determines the reach of a function definition.

A function definition is said to reach a statement if that function may exist
(and therefore may be called) when that statement executes.
"""

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer


class Definition(object):
  """Definition objects describe a unique definition of a function."""

  def __init__(self, def_node):
    self.def_node = def_node


class _NodeState(object):
  """Abstraction for the state of the CFG walk for reaching definition analysis.

  This is a value type. Only implements the strictly necessary operators.

  Attributes:
    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
        their possible definitions
  """

  def __init__(self, init_from=None):
    if init_from:
      self.value = set(init_from)
    else:
      self.value = set()

  def __eq__(self, other):
    return self.value == other.value

  def __ne__(self, other):
    return self.value != other.value

  def __or__(self, other):
    assert isinstance(other, _NodeState)
    result = _NodeState(self.value)
    result.value.update(other.value)
    return result

  def __add__(self, value):
    result = _NodeState(self.value)
    result.value.add(value)
    return result

  def __repr__(self):
    return 'NodeState[%s]=%s' % (id(self), repr(self.value))


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that determines reaching definitions at statement level."""

  def __init__(self, graph, external_defs):
    super(Analyzer, self).__init__(graph)
    # This allows communicating that nodes have extra reaching definitions,
    # e.g. those that a function closes over.
    self.external_defs = external_defs

  def init_state(self, _):
    return _NodeState()

  def visit_node(self, node):
    prev_defs_out = self.out[node]

    if node is self.graph.entry:
      defs_in = _NodeState(self.external_defs)
    else:
      defs_in = prev_defs_out

    for n in node.prev:
      defs_in |= self.out[n]

    defs_out = defs_in
    if isinstance(node.ast_node, (gast.Lambda, gast.FunctionDef)):
      defs_out += node.ast_node

    self.in_[node] = defs_in
    self.out[node] = defs_out

    return prev_defs_out != defs_out


class TreeAnnotator(transformer.Base):
  """AST visitor that annotates each symbol name with its reaching definitions.

  Simultaneously, the visitor runs the dataflow analysis on each function node,
  accounting for the effect of closures. For example:

    def foo():
      def f():
        pass
      def g():
        # `def f` reaches here
  """

  def __init__(self, source_info, graphs):
    super(TreeAnnotator, self).__init__(source_info)
    self.graphs = graphs
    self.allow_skips = False
    self.current_analyzer = None

  def _proces_function(self, node):
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    if (self.current_analyzer is not None
        and node in self.current_analyzer.graph.index):
      cfg_node = self.current_analyzer.graph.index[node]
      defined_in = self.current_analyzer.in_[cfg_node].value
    else:
      defined_in = ()

    analyzer = Analyzer(subgraph, defined_in)
    analyzer.visit_forward()

    self.current_analyzer = analyzer
    node = self.generic_visit(node)
    self.current_analyzer = parent_analyzer
    return node

  def visit_FunctionDef(self, node):
    return self._proces_function(node)

  def visit_Lambda(self, node):
    return self._proces_function(node)

  def visit(self, node):
    # This can happen before entering the top level function
    if (self.current_analyzer is not None
        and node in self.current_analyzer.graph.index):
      cfg_node = self.current_analyzer.graph.index[node]
      anno.setanno(node, anno.Static.DEFINED_FNS_IN,
                   self.current_analyzer.in_[cfg_node].value)

    extra_node = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST, default=None)
    if extra_node is not None:
      cfg_node = self.current_analyzer.graph.index[extra_node]
      anno.setanno(extra_node, anno.Static.DEFINED_FNS_IN,
                   self.current_analyzer.in_[cfg_node].value)

    return super(TreeAnnotator, self).visit(node)


def resolve(node, source_info, graphs):
  """Resolves reaching definitions for each symbol.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
  Returns:
    ast.AST
  """
  visitor = TreeAnnotator(source_info, graphs)
  node = visitor.visit(node)
  return node
