# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Live variable analysis.

See https://en.wikipedia.org/wiki/Live_variable_analysis for a definition of
the following idioms: live variable, live in, live out, which are used
throughout this file.

This analysis attaches the following:
 * symbols that are live at the exit of control flow statements
 * symbols that are live at the entry of control flow statements

Requires activity analysis.
"""

import gast

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import annos


class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  For every CFG node this visitor maintains two symbol sets: `self.in_[node]`
  (live in) and `self.out[node]` (live out). They are recomputed by
  `visit_node` using the usual dataflow equation:

    live_in = gen | (live_out - kill)

  where `gen` is taken from the symbols read by the node and `kill` from the
  symbols it modifies or deletes, as recorded in the node's
  `anno.Static.SCOPE` annotation.
  """

  def __init__(self, graph, include_annotations):
    """Creates the analyzer.

    Args:
      graph: the CFG to analyze; forwarded unchanged to the superclass.
      include_annotations: Bool, when False the symbols listed in a scope's
        `annotations` are excluded from the symbols considered read.
    """
    super(Analyzer, self).__init__(graph)
    self.include_annotations = include_annotations

  def init_state(self, _):
    """Returns the initial (empty) set of live symbols for a node."""
    return set()

  def lamba_check(self, fn_ast_node):
    """Returns True if `fn_ast_node` is a lambda.

    Lambda functions are assumed to be used only in the place where they are
    defined, and not later, so `visit_node` skips them when accounting for
    closures. The method name (sic) is kept as is for compatibility.

    Args:
      fn_ast_node: the AST node of a function definition.
    Returns:
      Bool
    """
    return isinstance(fn_ast_node, gast.Lambda)

  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of a single CFG node.

    Args:
      node: the CFG node to process; its `next` and `ast_node` attributes are
        read.
    Returns:
      Bool, True if the live-in set of `node` changed during this visit.
    """
    prev_live_in = self.in_[node]

    # Live out is the union of the live-in sets of all successors. It does not
    # depend on the node's own scope, so it is computed once for both cases.
    live_out = set()
    for n in node.next:
      live_out |= self.in_[n]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      gen = node_scope.read
      if not self.include_annotations:
        # Build a new set rather than using `-=`: an in-place difference would
        # remove the annotation symbols from the scope's own `read` set, which
        # belongs to the activity analysis and is not ours to modify.
        gen = gen - node_scope.annotations
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. whether x needs to be added if x.y is live. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified | node_scope.deleted

      live_in = gen | (live_out - kill)

      reaching_functions = anno.getanno(
          node.ast_node, anno.Static.DEFINED_FNS_IN)
      for fn_ast_node in reaching_functions:
        if self.lamba_check(fn_ast_node):
          continue
        fn_scope = anno.getanno(fn_ast_node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
        # Any closure of a reaching function definition is conservatively
        # considered live.
        live_in |= (fn_scope.read - fn_scope.bound)

    else:
      assert self.can_ignore(node), (node.ast_node, node)

      # Nodes without a scope neither read nor write symbols, so liveness
      # passes through them unchanged.
      live_in = live_out

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in


class TreeAnnotator(transformer.Base):
  """Runs liveness analysis on each of the functions defined in the AST.

  If a function defines other local functions, those will have separate CFGs.
  However, dataflow analysis needs to tie up these CFGs to properly emulate the
  effect of closures. In the case of liveness, the parent function's live
  variables must account for the variables that are live at the entry of each
  subfunction. For example:

    def foo():
      # baz is live from here on
      def bar():
        print(baz)

  This analyzer runs liveness analysis on each individual function, accounting
  for the effect above. The results are attached to the AST nodes as
  `anno.Static.LIVE_VARS_IN` and `anno.Static.LIVE_VARS_OUT` annotations.
  """

  def __init__(self, source_info, graphs, include_annotations):
    """Creates the annotator.

    Args:
      source_info: forwarded unchanged to the superclass.
      graphs: mapping looked up by function (or lambda) AST node to obtain the
        CFG of that function.
      include_annotations: Bool, forwarded to each `Analyzer` created.
    """
    super().__init__(source_info)
    self.graphs = graphs
    self.include_annotations = include_annotations
    self.allow_skips = False
    # Analyzer of the function currently being traversed; None outside of any
    # function.
    self.current_analyzer = None

  def visit(self, node):
    """Visits `node`, then records live-in symbols for statements in the CFG."""
    node = super().visit(node)

    # Read the analyzer only after the traversal above, since visiting a nested
    # function swaps it temporarily.
    analyzer = self.current_analyzer
    if analyzer is None or not isinstance(node, gast.stmt):
      return node

    cfg_index = analyzer.graph.index
    if node in cfg_index:
      live_in = frozenset(analyzer.in_[cfg_index[node]])
      anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node

  def _analyze_function(self, node, is_lambda):
    """Runs the analysis on one function's CFG and annotates its body.

    Args:
      node: the function or lambda AST node; used as key into `self.graphs`.
      is_lambda: Bool, currently unused.
    Returns:
      The node returned by `generic_visit`.
    """
    enclosing_analyzer = self.current_analyzer

    fn_analyzer = Analyzer(self.graphs[node], self.include_annotations)
    fn_analyzer.visit_reverse()

    # The function's own analyzer is active only while its body is traversed.
    self.current_analyzer = fn_analyzer
    node = self.generic_visit(node)
    self.current_analyzer = enclosing_analyzer

    return node

  def visit_Lambda(self, node):
    return self._analyze_function(node, is_lambda=True)

  def visit_FunctionDef(self, node):
    return self._analyze_function(node, is_lambda=False)

  def _block_statement_live_out(self, node):
    """Annotates `node` with the symbols live in any of its CFG successors."""
    analyzer = self.current_analyzer
    live_out = frozenset(
        symbol
        for successor in analyzer.graph.stmt_next[node]
        for symbol in analyzer.in_[successor])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node

  def _block_statement_live_in(self, node, entry_node):
    """Annotates `node` with the symbols live at `entry_node`.

    Args:
      node: the block statement to annotate.
      entry_node: the AST node from which the live-in set is taken. It is
        either present in the CFG index or already carries a LIVE_VARS_IN
        annotation.
    Returns:
      The annotated `node`.
    """
    analyzer = self.current_analyzer
    cfg_index = analyzer.graph.index
    if entry_node not in cfg_index:
      # No CFG node of its own: reuse the annotation that was attached to the
      # entry node earlier.
      assert anno.hasanno(entry_node, anno.Static.LIVE_VARS_IN), (
          'If not matching a CFG node, must be a block statement:'
          ' {}'.format(entry_node))
      live_in = anno.getanno(entry_node, anno.Static.LIVE_VARS_IN)
    else:
      live_in = frozenset(analyzer.in_[cfg_index[entry_node]])
    anno.setanno(node, anno.Static.LIVE_VARS_IN, live_in)
    return node

  def visit_If(self, node):
    node = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(node, node.test)

  def visit_For(self, node):
    node = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(node, node.iter)

  def visit_While(self, node):
    node = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(node, node.test)

  def visit_Try(self, node):
    node = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(node, node.body[0])

  def visit_ExceptHandler(self, node):
    node = self._block_statement_live_out(self.generic_visit(node))
    return self._block_statement_live_in(node, node.body[0])

  def visit_With(self, node):
    # Only the live-in set is recorded for with statements.
    node = self.generic_visit(node)
    return self._block_statement_live_in(node, node.items[0])

  def visit_Expr(self, node):
    """Annotates an expression statement with its live-out symbols."""
    node = self.generic_visit(node)
    analyzer = self.current_analyzer
    live_out = frozenset(analyzer.out[analyzer.graph.index[node]])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, live_out)
    return node


# TODO(mdan): Investigate the possibility of removing include_annotations.
def resolve(node, source_info, graphs, include_annotations=True):
  """Annotates an AST with the results of live variable analysis.

  Runs a `TreeAnnotator` over `node`, which attaches the symbols that are live
  at the entry and at the exit of control flow statements.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
    include_annotations: Bool, whether type annotations should be included in
      the analysis.
  Returns:
    ast.AST
  """
  annotator = TreeAnnotator(source_info, graphs, include_annotations)
  return annotator.visit(node)
