# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPU system metadata and associated tooling."""

import collections

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu
from tensorflow.python.util.tf_export import tf_export

# Per-attempt operation timeout used when pinging the master for devices.
_PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 5 min
# Number of retries after a ping timeout: 12 * 24 attempts of up to 5 min each
# adds up to roughly 1 day of waiting.
_RETRY_TIMES = 12 * 24  # 1 day
# Operation timeout for initializing the TPU system to fetch its topology.
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins

_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
# Master addresses that denote an in-process (local) session.
_LOCAL_MASTERS = ('', 'local')


@tf_export('tpu.experimental.TPUSystemMetadata')
class TPUSystemMetadata(
    collections.namedtuple('TPUSystemMetadata', [
        'num_cores',
        'num_hosts',
        'num_of_cores_per_host',
        'topology',
        'devices',
    ])):
  """Describes some metadata about the TPU system.

  Attributes:
    num_cores: integer. Total number of TPU cores in the TPU system.
    num_hosts: integer. Total number of hosts (TPU workers) in the TPU system.
    num_of_cores_per_host: integer. Number of TPU cores per host (TPU worker).
    topology: the physical topology of the TPU system (documented as a
      `tf.tpu.experimental.Topology`), or None when the topology was not
      queried.
    devices: a tuple describing all the devices in the system. When the
      metadata is produced by TPU system discovery, each entry is a device
      attributes object (as returned by `Session.list_devices()`, exposing
      `name` and `device_type`), sorted by job, replica, task, device type and
      device index.
  """

  def __new__(cls, num_cores, num_hosts, num_of_cores_per_host, topology,
              devices):
    # Explicit signature so that every field must be supplied by the caller.
    return super(TPUSystemMetadata,
                 cls).__new__(cls, num_cores, num_hosts, num_of_cores_per_host,
                              topology, devices)


def _list_devices_eagerly():
  """Returns the logical devices visible to the eager context.

  Each device is wrapped in `session_lib._DeviceAttributes` so the result has
  the same element type as `Session.list_devices()` in graph mode.
  """
  return [
      session_lib._DeviceAttributes(  # pylint: disable=protected-access
          device_util.canonicalize(d.name), d.device_type, 0, 0)
      for d in config.list_logical_devices()
  ]


def _list_devices_with_retry(master_address, cluster_def):
  """Lists devices via a session on `master_address`, retrying on timeouts.

  Each attempt uses an operation timeout of `_PINGING_MASTER_TIMEOUT_IN_MS`.
  After a `DeadlineExceededError` the attempt is retried, up to `_RETRY_TIMES`
  retries.

  Args:
    master_address: address of the TensorFlow master to connect to.
    cluster_def: optional `ClusterDef` forwarded into the session config.

  Returns:
    The result of `Session.list_devices()`.

  Raises:
    ValueError: if the master is still unreachable after all retries.
  """
  # TODO(b/120564445): Replace with standard library for retries.
  retry_count = 1
  while True:
    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                 master_address)
    try:
      with ops.Graph().as_default():
        with session_lib.Session(
            master_address,
            config=get_session_config_with_timeout(
                _PINGING_MASTER_TIMEOUT_IN_MS,
                cluster_def)) as sess:
          return sess.list_devices()
    except errors.DeadlineExceededError as e:
      msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
             'not be ready (still scheduling) or the Tensorflow master '
             'address is incorrect: got (%s).' %
             (master_address))

      # TODO(xiejw): For local or grpc master we might not need retry logic
      # here.
      if retry_count > _RETRY_TIMES:
        raise ValueError(msg) from e
      logging.warning('%s', msg)
      logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
      retry_count += 1


def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system.

  Args:
    master_address: address of the TensorFlow master. Used for the graph-mode
      device query, the topology query, and log/error messages.
    cluster_def: optional `ClusterDef` forwarded into session configs.
    query_topology: if True, also initialize the TPU system to fetch its
      topology.

  Returns:
    A `TPUSystemMetadata`. Its `devices` are sorted by (job, replica, task,
    device_type, device_index), and its `topology` is None unless
    `query_topology` is set.

  Raises:
    ValueError: if the master cannot be reached after all retries (graph
      mode), or if the TPU system fails to initialize in time.
    RuntimeError: if hosts expose different numbers of TPU cores, or if
      `query_topology` is set but no TPU cores are found.
  """
  if context.executing_eagerly():
    # We want the output type to match in both eager and session mode.
    devices = _list_devices_eagerly()
  else:
    devices = _list_devices_with_retry(master_address, cluster_def)

  # Group TPU core indices by the task (host) that owns them.
  tpu_core_count = 0
  device_dict = collections.defaultdict(list)
  for device in devices:
    spec = tf_device.DeviceSpec.from_string(device.name)
    if spec.device_type == 'TPU':
      device_dict[spec.task].append(spec.device_index)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = {
        len(core_ids) for core_ids in device_dict.values()}
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'TPU cores on each host is not same. This should not happen!. '
          'devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))

    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata


def _obtain_topology(master_address, cluster_def):
  """Obtains TPU fabric topology by initializing the TPU system.

  Args:
    master_address: address of the TensorFlow master to connect to.
    cluster_def: optional `ClusterDef` forwarded into the session config.

  Returns:
    The value produced by running `tpu.initialize_system()` in a session.
    Presumably this is the serialized `TopologyProto`, per the
    `initialize_system` docs; confirm if relying on it.

  Raises:
    ValueError: if the session raises `DeadlineExceededError`. The session's
      operation timeout is `_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS`.
  """
  logging.info('Initializing TPU system (master: %s) to fetch topology '
               'for model parallelism. This might take a while.',
               master_address)
  try:
    with ops.Graph().as_default():
      session_config = get_session_config_with_timeout(
          _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, cluster_def)
      with session_lib.Session(
          master_address, config=session_config) as sess:
        return sess.run(tpu.initialize_system())
  except errors.DeadlineExceededError as e:
    # Chain the original error so the underlying timeout stays visible.
    raise ValueError(
        'Fail to initialize TPU system with master (%s). '
        'Please double check the TPU system is functional.' % (
            master_address)) from e


def get_session_config_with_timeout(timeout_in_secs, cluster_def):
  """Returns a `ConfigProto` with an operation timeout and cluster definition.

  Args:
    timeout_in_secs: operation timeout in MILLISECONDS. Despite its name, the
      value is stored unchanged in `ConfigProto.operation_timeout_in_ms`, and
      callers in this module pass `*_IN_MS` constants. The name is kept for
      backward compatibility with keyword-argument callers.
    cluster_def: a `ClusterDef` forwarded to `ConfigProto`. Callers in this
      module may pass `None`.

  Returns:
    A `config_pb2.ConfigProto` (not a session).
  """
  config_proto = config_pb2.ConfigProto(
      operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
  return config_proto


def master_job(master, cluster_def):
  """Infers the job name that TPU computations should be placed on.

  Args:
    master: A `string` representing the TensorFlow master to use.
    cluster_def: A ClusterDef object describing the TPU cluster.

  Returns:
    None when `master` is a local master (no job should be specified).
    Otherwise, the inferred job name as a string.

  Raises:
    ValueError: If `tpu_worker` is used as a job name in `cluster_def`, or if
      the job name cannot be inferred automatically (for example, several
      candidate jobs and no coordinator pairing).
  """
  if master in _LOCAL_MASTERS:
    return None

  # Without an explicit cluster layout, fall back to the default TPU job.
  if not (cluster_def and cluster_def.job):
    return _DEFAULT_JOB_NAME

  names = {job.name for job in cluster_def.job}
  if _DEFAULT_JOB_NAME in names:
    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
    raise ValueError('Currently, tpu_worker is not an allowed job name.')

  if len(names) == 1:
    return cluster_def.job[0].name

  # A single worker job paired with the coordinator: pick the worker job.
  if len(names) == 2 and _DEFAULT_COORDINATOR_JOB_NAME in names:
    (worker_job,) = names - {_DEFAULT_COORDINATOR_JOB_NAME}
    return worker_job

  # TODO(b/67716447): Include more sophisticated heuristics.
  raise ValueError('Could not infer TPU job name.')
