# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module to expose RPC APIs in tensorflow."""

from typing import Optional, Sequence, Union

import tensorflow.distribute.experimental.rpc.kernels.gen_rpc_ops as gen_rpc_ops
from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as rpc_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import none_tensor
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Serializes the type specs of a concrete function's structured outputs.

  Args:
    func: Concrete function whose `structured_outputs` are converted to type
      specs.

  Returns:
    The result of `SerializeToString()` on the encoded nested structure of
    output type specs.
  """
  encoded_specs = nested_structure_coder.encode_structure(
      nest.map_structure(type_spec.type_spec_from_value,
                         func.structured_outputs))
  return encoded_specs.SerializeToString()


def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Serializes the positional input specs of a concrete function.

  Args:
    func: Concrete function whose `structured_input_signature` is inspected.
      Only the first element of that pair is used; the second is ignored.

  Returns:
    The result of `SerializeToString()` on the encoded nested structure of
    the positional argument specs.
  """
  positional_specs, _unused_second = func.structured_input_signature
  return nested_structure_coder.encode_structure(
      positional_specs).SerializeToString()


@tf_export("distribute.experimental.rpc.Server", v1=[])
class Server(object):
  """A Server base class for accepting RPCs for registered tf.functions.

    Functions can be registered on the server and are exposed via RPCs.
  """

  @staticmethod
  def create(rpc_layer, address):
    """Create TF RPC server at given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address where RPC server is hosted.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Server` class.

    Raises:
      ValueError: If rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.

    Example usage:

      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    return GrpcServer(address=address)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering tf.function on server.

    Registered methods can be invoked remotely from clients.

    Args:
      method_name: Name of the tf.function. Clients use this method_name to make
        RPCs.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      NotImplementedError: Always, in this base class. Concrete subclasses
        obtained through `Server.create` provide the implementation.
    """
    # The two literals are concatenated, so the first must end with a space.
    raise NotImplementedError("Please use Server.create to create a "
                              "concrete subclass of Server.")

  def start(self):
    """Starts the RPC server on provided address.

    Server listens for new requests from client, once it is started.

    Raises:
      NotImplementedError: Always, in this base class. Concrete subclasses
        obtained through `Server.create` provide the implementation.
    """
    # The two literals are concatenated, so the first must end with a space.
    raise NotImplementedError("Please use Server.create to create a "
                              "concrete subclass of Server.")


@tf_export("distribute.experimental.rpc.Client", v1=[])
class Client(object):
  """Client class for invoking RPCs to the server."""

  @staticmethod
  def create(rpc_layer, address, name="", timeout_in_ms=0):
    """Create TF RPC client to connect to the given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address of the server to connect the RPC client to.
      name: Name of the RPC Client. You can create multiple clients connecting
        to same server and distinguish them using different names.
      timeout_in_ms: The default timeout to use for outgoing RPCs from client. 0
        indicates no timeout. Exceeding timeout during RPC will raise
        DeadlineExceeded error.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Client` with the following
      dynamically added methods for eagerly created clients:
        * `Registered methods` e.g. multiply(**args):
            If Client is created when executing eagerly, client will request the
            list of registered methods from server during client creation.
            The convenience methods for RPCs will be dynamically added to the
            created Client instance.

            For example, when a server has method "multiply" registered, the
            client object created in eager mode will have 'multiply' method
            available. Users can use client.multiply(..) to make RPC, instead of
            client.call("multiply", ...)

            Both "call" and "multiply" methods are non-blocking i.e. they return
            a StatusOrResult object which should be used to wait for getting
            value or error.

            Along with the above, blocking versions of the registered
            methods are also dynamically added to client instance.
            e.g. multiply_blocking(**args). These methods block till the RPC is
            finished and return response for successful RPC. Otherwise raise
            exception.

            These methods are not available when Client is created inside a
            tf.function.

    Raises:
      ValueError: If rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.
      DeadlineExceeded: In eager mode, if timeout exceeds while creating and
        listing client methods.

    Example usage:
      >>> # Have server already started.
      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

      >>> # Start client
      >>> client = tf.distribute.experimental.rpc.Client.create("grpc",
      ...      address=address, name="test_client")

      >>> a = tf.constant(2, dtype=tf.int32)
      >>> b = tf.constant(3, dtype=tf.int32)

      >>> result = client.call(
      ...    args=[a, b],
      ...    method_name="addition",
      ...    output_specs=tf.TensorSpec((), tf.int32))

      >>> if result.is_ok():
      ...   result.get_value()

      >>> result = client.addition(a, b)

      >>> if result.is_ok():
      ...   result.get_value()

      >>> value = client.addition_blocking(a, b)
    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    # Registered methods are listed from the server only when executing
    # eagerly, so the eager flag is passed through directly.
    return GrpcClient(
        address=address,
        name=name,
        list_registered_methods=context.executing_eagerly(),
        timeout_in_ms=timeout_in_ms)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method for making RPC calls to remote server.

    This invokes RPC to the server, executing the registered method_name
    remotely.

    Args:
      method_name: Remote registered method to invoke.
      args: List of arguments for the registered method.
      output_specs: Output specs for the output from method. For example, if
        the tf.function is:

          @tf.function(input_signature=[
              tf.TensorSpec([], tf.int32),
              tf.TensorSpec([], tf.int32)])
          def multiply_fn(a, b):
            return tf.math.multiply(a, b)

        then output_specs is `tf.TensorSpec((), tf.int32)`. If you have access
        to the tf.function, the output specs can be generated from it by
        calling:

          output_specs = tf.nest.map_structure(
              tf.type_spec_from_value,
              multiply_fn.get_concrete_function().structured_outputs)

        If output_specs are not provided, the response cannot be packed into
        a structure; `StatusOrResult.get_value()` as defined in this module
        then returns None.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      An instance of `StatusOrResult` class with the following available
      methods.
        * `is_ok()`:
            Returns True if RPC was successful.
        * `get_error()`:
            Returns TF error_code and error message for the RPC.
        * `get_value()`:
            Returns the returned value from remote TF function execution
            when RPC is successful.

      Calling any of the above methods will block till RPC is completed and
      result is available.

    Raises:
      NotImplementedError: Always, in this base class; subclasses implement
        the call.
    """
    raise NotImplementedError("Must be implemented in inherited classes.")


class GrpcServer(Server):
  """GrpcServer object encapsulates a resource with GRPC server.

    Functions can be registered locally and are exposed via RPCs. A
    `tf.function` that takes arguments must declare an `input_signature`,
    otherwise `register` rejects it.

    Example:
    ```
    server = rpc_ops.GrpcServer("host:port")
    @tf.function(input_signature=[
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([], tf.int32)])
    def add(a, b):
      return a + b

    server.register("add", add)
    server.start()
    ```
  """

  def __init__(self, address: str):
    """Creates the server resource for the given address.

    Args:
      address: Address passed to the `rpc_server` op.

    Raises:
      NotImplementedError: If not executing eagerly (e.g. inside a
        tf.function).
    """
    # Check the execution mode first so that no server resource op is created
    # when construction is going to be rejected anyway.
    if not context.executing_eagerly():
      raise NotImplementedError("Please create the server outside tf.function.")
    self._server_handle = gen_rpc_ops.rpc_server(address)
    # Ties the lifetime of the server resource to this Python object.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._server_handle, handle_device=self._server_handle.device)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering functions.

    Args:
      method_name: Name under which the function is registered on the server.
      func: A `tf.function` or a ConcreteFunction. A `tf.function` is converted
        with `get_concrete_function()` before registration.

    Raises:
      ValueError: If `func` is a `tf.function` that has arguments but no input
        signature, or if `func` is neither a `tf.function` nor a
        ConcreteFunction.
    """
    if isinstance(func, def_function.Function):
      # Without an input signature a function with arguments cannot be turned
      # into a single concrete function by the argument-less call below.
      if func.function_spec.arg_names and func.input_signature is None:
        raise ValueError("Input signature not specified for the function.")
      concrete_fn = func.get_concrete_function()
    elif isinstance(func, tf_function.ConcreteFunction):
      concrete_fn = func
    else:
      # Python functions
      # TODO(b/186762191): Add an implementation to support python functions.
      raise ValueError("Only TF functions are supported with Register method")

    gen_rpc_ops.rpc_server_register(
        self._server_handle,
        method_name=method_name,
        captured_inputs=concrete_fn.captured_inputs,
        input_specs=get_input_specs_from_function(concrete_fn),
        output_specs=get_output_specs_from_function(concrete_fn),
        f=concrete_fn)

  def start(self):
    """Starts GRPC server."""
    gen_rpc_ops.rpc_server_start(self._server_handle)


class GrpcClient(Client):
  """Client wrapper to connect to remote RPC server using GRPC.

  If Client is created with (list_registered_methods=True):
  1. Input and output specs for the methods registered till this point will be
     fetched from Server.
  2. Convenience methods are added to invoke registered methods directly from
     client.
  For example:
    To call a server method `add`,
    client.add(a, b) or client.add_blocking(a, b) can be used instead of
    client.call(method_name="add", args=[a, b], output_specs=[..])

  Prerequisite for using list_registered_methods=True:
   1. Server should be already started with the registered methods.
   2. Client must be created in Eager mode.
  """

  def __init__(self,
               address: str,
               name: str = "",
               list_registered_methods=False,
               timeout_in_ms=0):
    """Creates the client resource and adds methods for listed RPCs.

    Args:
      address: Address of the server, passed to the `rpc_client` op.
      name: Shared name for the client resource.
      list_registered_methods: Passed to the `rpc_client` op; every method the
        op returns gets a non-blocking and a `_blocking` convenience method on
        this instance.
      timeout_in_ms: Timeout passed to the `rpc_client` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    # Check the execution mode first so that no client resource op is created
    # when construction is going to be rejected anyway.
    if not context.executing_eagerly():
      raise NotImplementedError(
          "Client creation is supported only in eager mode.")
    self._client_handle, methods = gen_rpc_ops.rpc_client(
        shared_name=name,
        server_address=address,
        list_registered_methods=list_registered_methods,
        timeout_in_ms=timeout_in_ms)
    # Ties the lifetime of the client resource to this Python object.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._client_handle, handle_device=self._client_handle.device)
    self._server_address = address
    # Maps registered method name -> decoded output specs.
    self._method_registry = {}
    for method in methods.numpy():
      m = rpc_pb2.RegisteredMethod()
      m.ParseFromString(method)
      output_specs = nested_structure_coder.decode_proto(m.output_specs)
      input_specs = nested_structure_coder.decode_proto(m.input_specs)
      self._method_registry[m.method] = output_specs
      # TODO(ishark): Perhaps doc string can also be taken as input during
      # function registration.
      doc_string = "RPC Call for " + m.method + " method to server " + address
      self._add_method(m.method, output_specs, input_specs, self._client_handle,
                       doc_string)

  def _add_method(self, method_name, output_specs, input_specs, client_handle,
                  doc_string):
    """Method to add RPC methods to the client object.

    Adds two attributes to this instance: `<method_name>`, which issues the RPC
    and returns a `StatusOrResult`, and `<method_name>_blocking`, which returns
    the value or raises the error reported for the RPC.

    Args:
      method_name: Name of the registered method on the server.
      output_specs: Structure used to pack the values of the response.
      input_specs: Structure the positional arguments are validated against;
        validation is skipped when this is empty.
      client_handle: Client resource handle used for the RPC.
      doc_string: Docstring assigned to both generated methods.
    """

    def validate_and_get_flat_inputs(*args):
      # `args` is always a tuple here, so only the structure needs checking.
      if input_specs:
        nest.assert_same_structure(args, input_specs)
      return nest.flatten(args)

    def call_wrapper(*args, timeout_in_ms=0):
      status_or, deleter = gen_rpc_ops.rpc_call(
          client_handle,
          args=validate_and_get_flat_inputs(*args),
          method_name=method_name,
          timeout_in_ms=timeout_in_ms)
      return StatusOrResult(status_or, deleter, output_specs)

    def call_blocking_wrapper(*args, timeout_in_ms=0):
      # Issue the RPC exactly as the non-blocking variant does, then resolve
      # the result.
      status_or = call_wrapper(*args, timeout_in_ms=timeout_in_ms)
      if status_or.is_ok():
        return status_or.get_value()
      error_code, error_msg = status_or.get_error()
      raise errors.exception_type_from_error_code(error_code.numpy())(
          None, None, error_msg.numpy())

    setattr(self, method_name, call_wrapper)
    call_wrapper.__doc__ = doc_string

    blocking_method_name = method_name + "_blocking"
    setattr(self, blocking_method_name, call_blocking_wrapper)
    call_blocking_wrapper.__doc__ = doc_string

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method to invoke remote registered functions on the connected server.

    Server should be started before making an RPC Call.

    Args:
      method_name: Registered method to invoke on Server.
      args: Input arguments for the method. They are flattened before being
        sent; `None` is treated as no arguments.
      output_specs: Output specs for the output from method.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
       used.

    Returns:
      StatusOrResult object. This function issues the RPC call to server, it
      does not block for the duration of RPC. Please call is_ok, get_error or
      get_value methods on the returned object to block till RPC finishes.
    """
    if args is None:
      args = []
    status_or, deleter = gen_rpc_ops.rpc_call(
        self._client_handle,
        args=nest.flatten(args),
        method_name=method_name,
        timeout_in_ms=timeout_in_ms)
    return StatusOrResult(status_or, deleter, output_specs)


class StatusOrResult(object):
  """Wraps the status and result handle returned by an RPC call."""

  def __init__(self, status_or, deleter, output_specs=None):
    """Stores the RPC handles; the status itself is fetched on first use.

    Args:
      status_or: Handle returned by the `rpc_call` op.
      deleter: Deleter returned alongside `status_or`.
      output_specs: Optional structure used to pack the values of the response.
    """
    self._status_or = status_or
    self._deleter = deleter
    self._output_specs = output_specs
    # Both stay None until `_check_status` has fetched them.
    self._error_code: dtypes.int64 = None
    self._error_message: dtypes.string = None

  def _check_status(self):
    """Fetches error code and message once and caches them on the instance."""
    if self._error_code is not None:
      return
    status = gen_rpc_ops.rpc_check_status(self._status_or)
    self._error_code, self._error_message = status

  def __del__(self):
    # Delete the resource under the mode that is active right now: an eager
    # scope when executing eagerly, a graph scope otherwise.
    mode_scope = (
        context.eager_mode
        if context.executing_eagerly() else context.graph_mode)
    with mode_scope():
      gen_rpc_ops.delete_rpc_future_resource(
          handle=self._status_or, deleter=self._deleter)

  def is_ok(self):
    """Returns True if RPC is successful, otherwise returns False.

    This call will block for RPC result.
    """
    self._check_status()
    ok_code = constant_op.constant(0, dtype=dtypes.int64)
    return math_ops.equal(self._error_code, ok_code)

  def get_error(self):
    """Returns (TF Error Code, Error Message) from RPC Response.

    This call will block for RPC result.
    """
    self._check_status()
    return self._error_code, self._error_message

  def get_value(self):
    """Returns the returned response value from RPC Call when RPC is successful.

    The returned value is tensors in the output_specs format as returned from
    the RPC call. When no output specs were given, or they are a
    `NoneTensorSpec`, the value is still fetched but None is returned.

    This call will block for RPC result.
    """
    self._check_status()
    specs = self._output_specs
    return_none = specs is None or isinstance(specs,
                                              none_tensor.NoneTensorSpec)
    # No dtypes are requested when there is nothing to pack.
    flat_output_dtypes = (
        [] if return_none else [s.dtype for s in nest.flatten(specs)])

    result = gen_rpc_ops.rpc_get_value(self._status_or, Tout=flat_output_dtypes)
    if return_none:
      return None
    return nest.pack_sequence_as(specs, result)
