# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Conversion to A-normal form.

The general idea of A-normal form is that every intermediate value is
explicitly named with a variable.  For more, see
https://en.wikipedia.org/wiki/A-normal_form.

The specific converters used here are based on Python AST semantics as
documented at https://greentreesnakes.readthedocs.io/en/latest/.
"""

import collections

import gast

from tensorflow.python.autograph.pyct import gast_util
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer


# TODO(mdan): Replace with naming.Namer.
class DummyGensym:
  """Hands out names of the form `<stem>_<n>`, with n counting up from 1001.

  A single counter is shared across all stems of one instance.  No attempt is
  made to avoid clashes with names already present in the code being
  transformed.
  """

  def __init__(self):
    # Count of names issued so far.  A proper implementation would also need
    # to take into account:
    #   * ctx.info.namespace
    #   * all the symbols defined in the AST
    #   * the symbols generated so far
    self._idx = 0

  def new_name(self, stem='tmp'):
    """Returns `stem` joined by an underscore to the next counter value."""
    self._idx = self._idx + 1
    suffix = str(1000 + self._idx)
    return '_'.join((stem, suffix))


def REPLACE(_1, _2, _3):  # pylint: disable=invalid-name
  """Configuration directive: always replace the matched child with a variable.

  Called as `directive(parent, field, child)`; the arguments are ignored.
  """
  return True


def LEAVE(_1, _2, _3):  # pylint: disable=invalid-name
  """Configuration directive: never replace the matched child with a variable.

  Called as `directive(parent, field, child)`; the arguments are ignored.
  """
  return False


# Wildcard sentinel.  In any slot of an `ASTEdgePattern` it matches anything;
# used in place of a whole pattern in a configuration, it matches every edge.
ANY = object()


class ASTEdgePattern(collections.namedtuple(
    'ASTEdgePattern', 'parent field child')):
  """Describes a family of AST edges `parent --field--> child`.

  Slots:
  - `parent`: a type (or tuple of types) the parent node must be an instance
    of,
  - `field`: the field name, compared with `==`,
  - `child`: a type (or tuple of types) the child node must be an instance of.
  An edge matches only when all three slots match.

  Any slot may hold the sentinel `anf.ANY`, which matches everything.  Nodes
  come from the `gast` library, not the standard `ast` module, which matters
  for the `isinstance` checks.
  """
  __slots__ = ()

  def matches(self, parent, field, child):
    """Returns whether the edge (parent, field, child) fits this pattern."""
    parent_ok = self.parent is ANY or isinstance(parent, self.parent)
    if not parent_ok:
      return False
    field_ok = self.field is ANY or field == self.field
    if not field_ok:
      return False
    return self.child is ANY or isinstance(child, self.child)


class AnfTransformer(transformer.Base):
  """Performs the conversion to A-normal form (ANF)."""

  # The algorithm is a postorder recursive tree walk.  Any given node A may, in
  # general, require creation of a series B of Assign statements, which compute
  # and explicitly name the intermediate values needed to compute the value of
  # A.  If A was already a statement, it can be replaced with the sequence B +
  # [A].  If A was an expression, B needs to be propagated up the tree until a
  # statement is encountered.  Since the `ast.NodeTransformer` framework makes
  # no provision for subtraversals returning side information, this class
  # accumulates the sequence B in an instance variable.

  # The only other subtlety is that some Python statements (like `if`) have both
  # expression fields (`test`) and statement list fields (`body` and `orelse`).
  # Any additional assignments needed to name all the intermediate values in the
  # `test` can be prepended to the `if` node, but assignments produced by
  # processing the `body` and the `orelse` need to be kept together with them,
  # and not accidentally lifted out of the `if`.

  def __init__(self, ctx, config):
    """Creates an ANF transformer.

    Args:
      ctx: transformer.Context
      config: Configuration: a list of `(pattern, directive)` pairs as
        described in the docstring of `transform`, or None to use the default
        configuration (replace every expression except literals).

    Raises:
      AssertionError: if `config` is None and the installed gast version is
        recognized as neither `gast_util.GAST2` nor `gast_util.GAST3`.
    """
    super(AnfTransformer, self).__init__(ctx)
    if config is None:
      # These could be pulled out, but are generally considered to already be in
      # A-normal form.  Thus they are left in by default, but could be pulled
      # out if the configuration calls for it.
      if gast_util.GAST2:
        literal_node_types = (
            gast.Num, gast.Str, gast.Bytes, gast.NameConstant,
            gast.Name  # Name is here to cover True, False, and None in Python 2
        )
      elif gast_util.GAST3:
        literal_node_types = (
            gast.Constant,
            gast.Name  # Name is here to cover True, False, and None in Python 2
        )
      else:
        # An explicit raise (unlike `assert False`) still fires under
        # `python -O`; otherwise `literal_node_types` would be unbound below.
        raise AssertionError('Unsupported gast version.')

      self._overrides = [
          (ASTEdgePattern(ANY, ANY, literal_node_types), LEAVE),
          (ASTEdgePattern(ANY, ANY, gast.expr), REPLACE)]
    else:
      self._overrides = config
    # Source of fresh temporary variable names.
    self._gensym = DummyGensym()
    # Assignments extracted from expressions, waiting to be emitted ahead of
    # the statement that encloses them.
    self._pending_statements = []

  def _consume_pending_statements(self):
    """Returns the queued assignments and resets the queue to empty."""
    ans = self._pending_statements
    self._pending_statements = []
    return ans

  def _add_pending_statement(self, stmt):
    """Queues `stmt` to be emitted before the enclosing statement."""
    self._pending_statements.append(stmt)

  def _match(self, pattern, parent, field, child):
    """Returns whether `pattern` (an ASTEdgePattern or ANY) matches the edge."""
    if pattern is ANY:
      return True
    else:
      return pattern.matches(parent, field, child)

  def _should_transform(self, parent, field, child):
    """Applies the first matching configuration rule to the given edge."""
    for pat, result in self._overrides:
      if self._match(pat, parent, field, child):
        return result(parent, field, child)
    # Fell off the end of the pattern list: do not transform
    return False

  def _do_transform_node(self, node):
    """Queues `tmp = node` and returns a reference to the fresh name `tmp`."""
    temp_name = self._gensym.new_name()
    temp_assign = templates.replace(
        'temp_name = expr', temp_name=temp_name, expr=node)[0]
    self._add_pending_statement(temp_assign)
    answer = templates.replace('temp_name', temp_name=temp_name)[0]
    return answer

  def _ensure_node_in_anf(self, parent, field, node):
    """Puts `node` in A-normal form, by replacing it with a variable if needed.

    The exact definition of A-normal form is given by the configuration.  The
    parent and the incoming field name are only needed because the configuration
    may be context-dependent.

    Expressions in `Store` or `Del` context (assignment and deletion targets)
    are never replaced, whatever the configuration says, because replacing a
    target with a temporary would read it instead of binding it.

    Args:
      parent: An AST node, the parent of `node`.
      field: The field name under which `node` is the child of `parent`.
      node: An AST node, potentially to be replaced with a variable reference.
        May also be None or a list of nodes.

    Returns:
      node: An AST node; the argument if transformation was not necessary,
        or the new variable reference if it was.  A list argument yields a list.
    """
    if node is None:
      return node
    if _is_trivial(node):
      return node
    if isinstance(node, list):
      # If something's field was actually a list, e.g., variadic arguments.
      return [self._ensure_node_in_anf(parent, field, n) for n in node]
    if isinstance(getattr(node, 'ctx', None), (gast.Store, gast.Del)):
      # Targets such as `with f() as self.x`, `with f() as (a, b)` or
      # `async for a, b in xs` must stay in place: hoisting them into
      # `tmp = target` would break the binding.
      return node
    if isinstance(node, gast.keyword):
      node.value = self._ensure_node_in_anf(parent, field, node.value)
      return node
    if isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      # These nodes aren't really extractable in their own right, but their
      # subnodes might be.  Propagate the parent and field name to the child
      # nodes, instead of querying the configuration for children of, e.g.,
      # gast.Starred.
      return self._ensure_fields_in_anf(node, parent, field)
    if self._should_transform(parent, field, node):
      return self._do_transform_node(node)
    else:
      return node

  def _ensure_fields_in_anf(self, node, parent=None, super_field=None):
    """Applies `_ensure_node_in_anf` to every field of `node`, in place.

    Args:
      node: The AST node whose fields are processed.
      parent: If given, reported to the configuration as the parent instead of
        `node` (used for pass-through nodes such as Starred and withitem).
      super_field: If given, reported as the field name instead of the actual
        field of `node`.

    Returns:
      `node`, with its fields updated.
    """
    for field in node._fields:
      if field.startswith('__'):
        continue
      parent_supplied = node if parent is None else parent
      field_supplied = field if super_field is None else super_field
      setattr(node, field, self._ensure_node_in_anf(
          parent_supplied, field_supplied, getattr(node, field)))
    return node

  def _visit_strict_statement(self, node, children_ok_to_transform=True):
    """Visits a statement; returns extracted assignments followed by it."""
    assert not self._pending_statements
    node = self.generic_visit(node)
    if children_ok_to_transform:
      self._ensure_fields_in_anf(node)
    results = self._consume_pending_statements()
    results.append(node)
    return results

  def _visit_trivial_only_statement(self, node, msg):
    """Visits a statement that must not need any extracted assignments.

    Raises:
      ValueError: with `msg`, if any assignment would have to be extracted.
    """
    assert not self._pending_statements
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    if self._pending_statements:
      raise ValueError(msg)
    else:
      return node

  def _visit_strict_expression(self, node):
    """Visits an expression, leaving extracted assignments queued."""
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    return node

  def _visit_trivial_only_expression(self, node, msg):
    """Visits an expression that must not need any extracted assignments.

    Raises:
      ValueError: with `msg`, if any assignment would have to be extracted.
    """
    k = len(self._pending_statements)
    node = self.generic_visit(node)
    self._ensure_fields_in_anf(node)
    # This check relies on there being no opportunities to consume pending
    # statements while traversing children of an expression.
    if len(self._pending_statements) != k:
      raise ValueError(msg)
    else:
      return node

  # Note on code order: These are listed in the same order as the grammar
  # elements on https://github.com/serge-sans-paille/gast

  # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.

  def visit_Return(self, node):
    return self._visit_strict_statement(node)

  def visit_Delete(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_Assign(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_AugAssign(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  def visit_Print(self, node):
    return self._visit_strict_statement(node)

  def visit_For(self, node):
    assert not self._pending_statements
    # It's important to visit node.iter first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.iter)
    node.iter = self._ensure_node_in_anf(node, 'iter', node.iter)
    iter_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.iter, but that is correct because by
    # this point the node.iter link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.iter, as
    # then it may be large and will be uselessly transformed again.  This
    # behavior is what causes the documented effect that configuration callables
    # may be invoked more than once on the same links; if the code is rewritten
    # not to do that (anywhere), the docstring of `transform` should be updated.
    node = self.generic_visit(node)
    assert not self._pending_statements
    iter_stmts.append(node)
    return iter_stmts

  def visit_AsyncFor(self, node):
    msg = ('Nontrivial AsyncFor nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_statement(node, msg)

  def visit_While(self, node):
    assert not self._pending_statements
    self.visit(node.test)
    node.test = self._ensure_node_in_anf(node, 'test', node.test)
    if self._pending_statements:
      msg = ('While with nontrivial test not supported yet '
             '(need to avoid precomputing the test).')
      raise ValueError(msg)
    # If traversing node.test yielded no statements extracted, the generic visit
    # will do the right thing.
    return self.generic_visit(node)

  def visit_If(self, node):
    assert not self._pending_statements
    # It's important to visit node.test first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.test)
    node.test = self._ensure_node_in_anf(node, 'test', node.test)
    condition_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.test, but that is correct because by
    # this point the node.test link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.test, as
    # then it may be large and will be uselessly transformed again.  This
    # happens in several places.
    node = self.generic_visit(node)
    assert not self._pending_statements
    condition_stmts.append(node)
    return condition_stmts

  def visit_With(self, node):
    assert not self._pending_statements
    # It's important to visit node.items first, because any statements created
    # thereby need to live outside the body.
    for item in node.items:
      self.visit(item)
    node.items = [self._ensure_node_in_anf(node, 'items', n)
                  for n in node.items]
    contexts_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.items, but that is correct because by
    # this point the node.items link has been checked.  It may be somewhat
    # expensive if the configuration didn't call for transforming node.items, as
    # then it may be large and will be uselessly transformed again.  This
    # happens in several places.
    node = self.generic_visit(node)
    assert not self._pending_statements
    contexts_stmts.append(node)
    return contexts_stmts

  def visit_AsyncWith(self, node):
    msg = ('Nontrivial AsyncWith nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_statement(node, msg)

  def visit_Raise(self, node):
    return self._visit_strict_statement(node)

  # Try should be correct by default.

  def visit_Assert(self, node):
    msg = ('Nontrivial Assert nodes not supported yet '
           '(need to avoid computing the test when assertions are off, and '
           'avoid computing the irritant when the assertion does not fire).')
    return self._visit_trivial_only_statement(node, msg)

  # Import and ImportFrom should be correct by default.

  def visit_Exec(self, node):
    return self._visit_strict_statement(node)

  # Global and Nonlocal should be correct by default.

  def visit_Expr(self, node):
    return self._visit_strict_statement(node, children_ok_to_transform=False)

  # Pass, Break, and Continue should be correct by default.

  def visit_BoolOp(self, node):
    msg = ('Nontrivial BoolOp nodes not supported yet '
           '(need to preserve short-circuiting semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_BinOp(self, node):
    return self._visit_strict_expression(node)

  def visit_UnaryOp(self, node):
    return self._visit_strict_expression(node)

  def visit_Lambda(self, node):
    msg = ('Nontrivial Lambda nodes not supported '
           '(cannot insert statements into lambda bodies).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_IfExp(self, node):
    msg = ('Nontrivial IfExp nodes not supported yet '
           '(need to convert to If statement, to evaluate branches lazily '
           'and insert statements into them).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Dict(self, node):
    return self._visit_strict_expression(node)

  def visit_Set(self, node):
    return self._visit_strict_expression(node)

  def visit_ListComp(self, node):
    msg = ('ListComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_SetComp(self, node):
    msg = ('SetComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_DictComp(self, node):
    msg = ('DictComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_GeneratorExp(self, node):
    msg = ('GeneratorExp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_Await(self, node):
    msg = ('Nontrivial Await nodes not supported yet '
           '(need to think through the semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Yield(self, node):
    return self._visit_strict_expression(node)

  def visit_YieldFrom(self, node):
    msg = ('Nontrivial YieldFrom nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Compare(self, node):
    if len(node.ops) > 1:
      msg = ('Multi-ary compare nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self._visit_strict_expression(node)

  def visit_Call(self, node):
    return self._visit_strict_expression(node)

  def visit_Repr(self, node):
    msg = ('Nontrivial Repr nodes not supported yet '
           '(need to research their syntax and semantics).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_FormattedValue(self, node):
    msg = ('Nontrivial FormattedValue nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_JoinedStr(self, node):
    msg = ('Nontrivial JoinedStr nodes not supported yet '
           '(need to unit-test them in Python 2).')
    return self._visit_trivial_only_expression(node, msg)

  def visit_Attribute(self, node):
    return self._visit_strict_expression(node)

  def visit_Subscript(self, node):
    return self._visit_strict_expression(node)

  # Starred and Name are correct by default, because the right thing to do is to
  # just recur.

  def visit_List(self, node):
    node = self.generic_visit(node)
    # Elements of a Store-context list are assignment targets; leave them.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_in_anf(node)
    return node

  def visit_Tuple(self, node):
    node = self.generic_visit(node)
    # Elements of a Store-context tuple are assignment targets; leave them.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_in_anf(node)
    return node


def _is_py2_name_constant(node):
  """Returns whether `node` is a gast Name spelling True, False or None.

  Python 2 parses these constants as plain names, so they need to be told
  apart from ordinary variable references.
  """
  if not isinstance(node, gast.Name):
    return False
  return node.id in ('True', 'False', 'None')


def _is_trivial(node):
  """Tells whether `node` is too 'trivial' to be pulled into its own variable.

  Trivial values are those that can't meaningfully be bound by an assignment
  statement: variable names, raw Python values stored in AST fields, operator
  tokens, expression contexts and the Ellipsis literal.  Getting this right
  across Python and gast versions is fiddly, because the way constants are
  parsed keeps changing.

  Args:
    node: An AST node (or raw field value) to check for triviality.

  Returns:
    trivial: A Python `bool` indicating whether the node is trivial.
  """
  trivial_node_types = (
      # Variable references.
      gast.Name,
      # Plain Python values that appear directly as AST field values.
      bool, str,
      # Binary operators.
      gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow,
      gast.LShift, gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd,
      gast.FloorDiv,
      # Unary operators.
      gast.Invert, gast.Not, gast.UAdd, gast.USub,
      # Comparison operators.
      gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE,
      gast.Is, gast.IsNot, gast.In, gast.NotIn,
      # Other leaf nodes that make no sense on their own.
      gast.expr_context,
  )
  has_trivial_type = isinstance(node, trivial_node_types)
  # Python 2's True/False/None are Names but are literals, not variables.
  if has_trivial_type and not _is_py2_name_constant(node):
    return True
  return bool(gast_util.is_ellipsis(node))


def transform(node, ctx, config=None):
  """Converts the given node to A-normal form (ANF).

  The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form

  The specific converters used here are based on Python AST semantics as
  documented at https://greentreesnakes.readthedocs.io/en/latest/.

  What exactly should be considered A-normal form for any given programming
  language is not completely obvious.  The transformation defined here is
  therefore configurable as to which syntax to replace with a fresh variable and
  which to leave be.  The configuration is intentionally flexible enough to
  define very precise variable insertion transformations, should that be
  desired.

  The configuration is a list of syntax rules, each of which is a 2-tuple:
  - An `ASTEdgePattern` (q.v.) defining a type of AST edge, and
  - A replacement directive saying whether to transform children of such edges.
  The special object `anf.ANY` may be used as a pattern that matches all edges.

  Each replacement directive is one of three possible things:
  - The object `anf.REPLACE`, meaning "Replace this child node with a variable",
  - The object `anf.LEAVE`, meaning "Do not replace this child node with a
    variable", or
  - A Python callable.  If a callable, it is called with the parent node, the
    field name, and the child node, and must compute a boolean indicating
    whether to transform the child node or not.  The callable is free to use
    whatever context information it chooses.  The callable may be invoked more
    than once on the same link, and must produce the same answer each time.

  The syntax rules are tested in order, and the first match governs.  If no rule
  matches, the node is not transformed.

  The above rules notwithstanding,
  - Variable references are never replaced with (fresh) variables, as that would
    accomplish nothing.
  - The left-hand children of Assign and AugAssign nodes, and the children of
    Delete nodes, are never replaced with variables, as that would break their
    semantics.
  - The right-hand children of Assign nodes are never replaced with variables,
    as the original assignment would still have to be present in the result
    to define the new variable.  (That is, there's no point in transforming
    `x = sin(y)` into `tmp = sin(y); x = tmp`.)
  - The right-hand children of AugAssign nodes are never replaced with variables
    either, but only because the difference from Assign was considered a
    potential source of confusion (and it would have been slightly awkward in
    the code to treat the RHS differently than the LHS).
  - Various special-purpose AST nodes are not exposed to the configuration, lest
    the transform produce invalid syntax like, e.g., `tmp = +; x = 1 tmp 2`.

  For example, the configuration
  ```python
  [(anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.expr), anf.REPLACE)]
  ```
  gives explicit fresh names to all expressions regardless of context (except as
  outlined above), whereas
  ```python
  [(anf.ASTEdgePattern(gast.If, "test", anf.ANY), anf.REPLACE)]
  ```
  only transforms the conditionals of `if` statements (but not, e.g., `while`).

  If no configuration is supplied, the default behavior is to transform all
  expressions except literal constants, which is defined as a configuration as
  ```python
  # With gast library versions before 0.3 (gast_util.GAST2):
  literals = (gast.Num, gast.Str, gast.Bytes, gast.NameConstant, gast.Name)
  # With gast_util.GAST3:
  literals = (gast.Constant, gast.Name)

  [(anf.ASTEdgePattern(anf.ANY, anf.ANY, literals), anf.LEAVE),
   (anf.ASTEdgePattern(anf.ANY, anf.ANY, gast.expr), anf.REPLACE)]
  ```
  (`gast.Name` is listed to cover Python 2's `True`, `False` and `None`;
  ordinary variable references are never replaced anyway.)

  Args:
    node: The node to transform.
    ctx: transformer.Context.  TODO(mdan): What information does this
      argument provide?
    config: Optional ANF configuration.  If omitted, ANF replaces all
      expressions except literal constants.

  Returns:
    The transformed node.  A statement that needed assignments extracted (e.g.
    a `Return`) comes back as a list: the new assignments followed by the
    statement itself.

  Raises:
    ValueError: if `node` contains constructs whose conversion is not
      supported, e.g. comprehensions, chained comparisons, or nontrivial
      BoolOp, IfExp, Lambda or `while` test expressions.
  """
  return AnfTransformer(ctx, config).visit(node)
