# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of ClusterResolvers for GCE instance groups."""

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export


_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False


@tf_export('distribute.cluster_resolver.GCEClusterResolver')
class GCEClusterResolver(ClusterResolver):
  """ClusterResolver for Google Compute Engine.

  This is an implementation of cluster resolvers for the Google Compute Engine
  instance group platform. By specifying a project, zone, and instance group,
  this will retrieve the IP address of all the instances within the instance
  group and return a ClusterResolver object suitable for use for distributed
  TensorFlow.

  Note: this cluster resolver cannot retrieve `task_type`, `task_id` or
  `rpc_layer`. To use it with some distribution strategies like
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`, you will need to
  specify `task_type` and `task_id` in the constructor.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=0)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = GCEClusterResolver("my-project", "us-west1-a",
                                          "my-instance-group",
                                          task_type="worker", task_id=1)
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               project,
               zone,
               instance_group,
               port=8470,
               task_type='worker',
               task_id=0,
               rpc_layer='grpc',
               credentials='default',
               service=None):
    """Creates a new GCEClusterResolver object.

    This takes in a few parameters and creates a GCEClusterResolver object. It
    will then use these parameters to query the GCE API for the IP addresses of
    each instance in the instance group.

    Args:
      project: Name of the GCE project.
      zone: Zone of the GCE instance group (a zone such as "us-west1-a", not a
        region).
      instance_group: Name of the GCE instance group.
      port: Port of the listening TensorFlow server (default: 8470).
      task_type: Name of the TensorFlow job this GCE instance group of VM
        instances belong to.
      task_id: The task index for this particular VM, within the GCE
        instance group. In particular, every single instance should be assigned
        a unique ordinal index within an instance group manually so that they
        can be distinguished from each other.
      rpc_layer: The RPC layer TensorFlow should use to communicate across
        instances.
      credentials: GCE Credentials. If nothing is specified, this defaults to
        GoogleCredentials.get_application_default() when this resolver builds
        its own service object.
      service: The GCE API object returned by the googleapiclient.discovery
        function. (Default: discovery.build('compute', 'v1')). If you specify a
        custom service object, then the credentials parameter will be ignored.

    Raises:
      ImportError: If `service` is not provided and the googleapiclient is not
        installed.
    """
    self._project = project
    self._zone = zone
    self._instance_group = instance_group
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._port = port
    self._credentials = credentials

    if service is None:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'GCE cluster resolver')
      # Only resolve application-default credentials when we build the service
      # ourselves. A caller-supplied service ignores `credentials`, so looking
      # them up in that case could only fail spuriously (e.g. when no
      # application-default credentials are configured).
      if credentials == 'default':
        self._credentials = GoogleCredentials.get_application_default()
      self._service = discovery.build(
          'compute', 'v1',
          credentials=self._credentials)
    else:
      self._service = service

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest instance group info.

    This returns a ClusterSpec object for use based on information from the
    specified instance group. We will retrieve the information from the GCE APIs
    every time this method is called.

    Only instances in the RUNNING state are listed. Each entry is the first
    network interface's internal IP joined with the configured port, and all
    entries are placed under a single job named after `task_type`.

    Returns:
      A ClusterSpec containing host information retrieved from GCE. The job's
      task list is empty if the group has no RUNNING instances.
    """
    request_body = {'instanceState': 'RUNNING'}
    request = self._service.instanceGroups().listInstances(
        project=self._project,
        zone=self._zone,
        instanceGroup=self._instance_group,
        body=request_body,
        orderBy='name')

    worker_list = []

    # Walk every page of results; listInstances_next returns None once the
    # last page has been consumed.
    while request is not None:
      response = request.execute()

      # The API omits 'items' entirely for a page with no instances (e.g. an
      # empty group or no RUNNING instances), so it must not be indexed.
      for instance in response.get('items', []):
        # 'instance' holds the instance URL; its last path segment is the name.
        instance_name = instance['instance'].split('/')[-1]

        instance_details = self._service.instances().get(
            project=self._project,
            zone=self._zone,
            instance=instance_name).execute()
        ip_address = instance_details['networkInterfaces'][0]['networkIP']
        worker_list.append('%s:%s' % (ip_address, self._port))

      request = self._service.instanceGroups().listInstances_next(
          previous_request=request,
          previous_response=response)

    # NOTE: this sorts the 'ip:port' strings lexicographically (so '10.0.0.10'
    # precedes '10.0.0.2'); task indices follow this order.
    worker_list.sort()
    return ClusterSpec({self._task_type: worker_list})

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the address of the given task, optionally prefixed by an RPC layer.

    Args:
      task_type: Job name to look up; defaults to this resolver's `task_type`.
      task_id: Task index to look up; defaults to this resolver's `task_id`.
      rpc_layer: RPC layer prefix; defaults to this resolver's `rpc_layer`.

    Returns:
      '<rpc_layer>://<address>' when an RPC layer is set, the bare address
      otherwise, or '' when the task type or id cannot be determined. The
      address is looked up via a fresh `cluster_spec()` call.
    """
    task_type = task_type if task_type is not None else self._task_type
    task_id = task_id if task_id is not None else self._task_id

    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      if rpc_layer or self._rpc_layer:
        return '%s://%s' % (rpc_layer or self._rpc_layer, master)
      else:
        return master

    return ''

  @property
  def task_type(self):
    """The job name this instance group's tasks belong to."""
    return self._task_type

  @property
  def task_id(self):
    """The task index of this VM within the instance group."""
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    # task_type also names the job in cluster_spec(), so it is fixed at
    # construction time.
    raise RuntimeError(
        'You cannot reset the task_type of the GCEClusterResolver after it has '
        'been created.')

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def rpc_layer(self):
    """The RPC layer used as the prefix in `master()` addresses."""
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer