# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for multi-process testing."""

import multiprocessing
import os
import platform
import sys
import unittest
from absl import app
from absl import logging

from tensorflow.python.eager import test


def is_oss():
  """Returns whether the test is run under OSS.

  The check is a heuristic on the command line: the test is treated as OSS
  when the program path in `sys.argv[0]` contains the substring 'bazel'.

  Returns:
    True if `sys.argv` is non-empty and 'bazel' occurs in `sys.argv[0]`,
    False otherwise.
  """
  if not sys.argv:
    return False
  program_path = sys.argv[0]
  return 'bazel' in program_path


def _is_enabled():
  """Returns whether multi-process test support is enabled.

  Support is disabled in any of the following situations:
    * The test runs under OSS (see `is_oss()`) and at least one command line
      argument starts with '--tpu'.
    * The interpreter is Python 3.8 running on Linux (b/171242147).
    * The platform is win32.

  Returns:
    True if multi-process support should be used, False otherwise.
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors.
  has_tpu_arg = any(arg.startswith('--tpu') for arg in sys.argv)
  if is_oss() and has_tpu_arg:
    return False
  # `sys.version_info` is a five-field tuple (major, minor, micro,
  # releaselevel, serial), so comparing the whole tuple to `(3, 8)` is never
  # equal. Compare only the (major, minor) prefix.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'


class _AbslProcess:
  """Mixin whose `run` is executed through `absl.app.run`.

  Intended to be combined with a process class that supplies `run`; see how
  it is mixed into the process classes defined in this module.
  """

  def __init__(self, *args, **kwargs):
    """Initializes the base class, then reroutes `run` through absl.

    Args:
      *args: Positional arguments forwarded to the next class in the MRO.
      **kwargs: Keyword arguments forwarded to the next class in the MRO.
    """
    super().__init__(*args, **kwargs)
    # Monkey-patch that is carried over into the spawned process by pickle.
    # The original bound `run` is stashed first so the wrapper can reach it.
    original_run = self.run
    self._run_impl = original_run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Invokes the original `run` as the main function of `app.run`."""

    def absl_main(unused_argv):
      del unused_argv  # The wrapped `run` takes no arguments.
      return self._run_impl()

    app.run(absl_main)


# Select the `Process` implementation exported by this module, decided once at
# import time by `_is_enabled()`.
if _is_enabled():

  # NOTE: the class definitions below read `multiprocessing.context` from the
  # stdlib module, so they must come before the module-level name
  # `multiprocessing` is rebound further down.
  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Combines `_AbslProcess` with `multiprocessing.context.ForkServerProcess`.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    """A forkserver context whose `Process` is `AbslForkServerProcess`."""
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # From here on the module-level name `multiprocessing` refers to this context
  # instance rather than to the stdlib module; later references in this file
  # (e.g. `multiprocessing.get_context()`) go through the context.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  # This branch is taken whenever `_is_enabled()` returns False, which covers
  # more than Windows; the skip message below mentions only Windows.
  class Process(object):
    """A process that skips test (until windows is supported)."""

    def __init__(self, *args, **kwargs):
      # Arguments are accepted for signature compatibility and then discarded;
      # constructing this class always skips the current test.
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')


# Set to True by `test_main()`; read back through `initialized()`.
_test_main_called = False


def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly.

  When `sys.argv[0]` ends with '.py', the binary path is guessed from
  `sys.argv[0]` and the `TEST_TARGET` environment variable, and `sys.argv[0]`
  is overwritten with the guessed path before it is handed to
  multiprocessing.

  Raises:
    RuntimeError: If the binary path cannot be determined. This includes the
      case where `TEST_TARGET` is not set in the environment.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns the guessed binary path under `package_root`, or None."""
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if 'bazel-out' not in sys.argv[0] or package_root not in sys.argv[0]:
        return None
      test_target = os.environ.get('TEST_TARGET')
      if not test_target:
        # Without a target label there is nothing to derive the binary name
        # from. Report "no guess" so the caller raises the documented
        # RuntimeError instead of leaking a KeyError.
        return None
      # Guess the binary path under bazel. For target
      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
      # argv[0] is in the form of
      # /.../tensorflow/python/distribute/input_lib_test.py
      # and the binary is
      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
      package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
      # Drop the leading '//' and turn 'pkg:name' into 'pkg/name'.
      binary = test_target[2:].replace(':', '/', 1)
      possible_path = os.path.join(package_root_base, package_root, binary)
      logging.info('Guessed test binary path: %s', possible_path)
      if os.access(possible_path, os.X_OK):
        return possible_path
      return None

    # Try the TensorFlow workspace first, then fall back to the Keras one.
    path = guess_path('org_tensorflow') or guess_path('org_keras')
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])


def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op.

  A process is considered spawned when its command line contains a '-c' flag
  (after `sys.argv[0]`) whose following argument starts with
  'from multiprocessing.'. In that case `sys.argv` is truncated to its first
  element, the command is executed, and the process exits with status 0.
  Otherwise the function returns without touching `sys.argv`.
  """
  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  try:
    # Start at index 1 so the program name itself is never taken for the flag
    # and the index always matches the occurrence that is being inspected.
    flag_index = sys.argv.index('-c', 1)
  except ValueError:
    return
  if flag_index + 1 >= len(sys.argv):
    # A trailing '-c' carries no command, so this is not a spawned process.
    return
  cmd = sys.argv[flag_index + 1]
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.


def test_main():
  """Entry point that a test file calls from its `__main__` section.

  Marks the module as initialized, sets `TF_FORCE_GPU_ALLOW_GROWTH` to 'true'
  in the environment, and then hands control to `test.main()`. When
  multi-process support is enabled, the spawn executable is configured first,
  and a process that was itself spawned by multiprocessing runs its spawn
  task and exits instead of reaching `test.main()`.
  """
  global _test_main_called
  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
  _test_main_called = True

  multi_process_supported = _is_enabled()
  if multi_process_supported:
    _set_spawn_exe_path()
    # Does not return when this process was spawned by multiprocessing.
    _if_spawn_run_and_exit()

  # Reached only by a process that is not a multiprocessing spawn.
  test.main()


def initialized():
  """Returns whether the module is initialized.

  Returns:
    True once `test_main()` has been called in this process, False otherwise.
  """
  return _test_main_called
