# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.data service test-cluster with local and remote workers."""

import tempfile

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

_WORKER_SHUTDOWN_QUIET_PERIOD_MS = 100


# pylint: disable=protected-access
class _RemoteWorkerProcess(multi_process_lib.Process):
  """Runs a worker server in a new process to simulate a remote worker."""

  def __init__(self, dispatcher_address, port, worker_tags, pipe_writer):
    super(_RemoteWorkerProcess, self).__init__()
    self._dispatcher_address = dispatcher_address
    self._port = port
    self._worker_tags = worker_tags
    self._pipe_writer = pipe_writer

  def run(self):
    self.start_worker()

  def start_worker(self):
    self._worker = data_service_test_base.TestWorker(
        self._dispatcher_address,
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        port=self._port,
        worker_tags=self._worker_tags)
    self._worker.start()
    self._pipe_writer.send(self._worker.worker_address())
    self._worker.join()


class MultiProcessCluster:
  """tf.data service cluster with local and remote workers.

  Represents a cluster with a dispatcher, `num_local_workers` local workers, and
  `num_remote_workers` remote workers. Remote workers run in separate processes.
  This is useful to test reading from local in-process workers. For example:

  ```
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=3)
  num_elements = 10
  dataset = self.make_distributed_range_dataset(
      num_elements, cluster, target_workers="LOCAL")
  self.assertDatasetProduces(dataset, list(range(num_elements)))
  ```
  """

  def __init__(self,
               num_local_workers,
               num_remote_workers,
               worker_tags=None,
               worker_addresses=None,
               deployment_mode=data_service_pb2.DEPLOYMENT_MODE_COLOCATED):
    self._work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    self._deployment_mode = deployment_mode
    self._start_dispatcher(worker_addresses)
    self._start_local_workers(num_local_workers, worker_tags)
    self._start_remote_workers(num_remote_workers, worker_tags)

  def _start_dispatcher(self, worker_addresses, port=0):
    if port == 0:
      port = test_util.pick_unused_port()
    self._dispatcher = server_lib.DispatchServer(
        service_config_pb2.DispatcherConfig(
            port=port,
            protocol="grpc",
            work_dir=self._work_dir,
            fault_tolerant_mode=True,
            worker_addresses=worker_addresses,
            deployment_mode=self._deployment_mode),
        start=True)

  def _start_local_workers(self, num_workers, worker_tags=None):
    self._local_workers = []
    for _ in range(num_workers):
      self.start_local_worker(worker_tags)

  def _start_remote_workers(self, num_workers, worker_tags=None):
    # List of (worker address, remote worker process) tuples.
    self._remote_workers = []
    for _ in range(num_workers):
      self.start_remote_worker(worker_tags)

  def start_local_worker(self, worker_tags=None):
    worker = data_service_test_base.TestWorker(
        self.dispatcher_address(),
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        port=test_util.pick_unused_port(),
        worker_tags=worker_tags)
    worker.start()
    self._local_workers.append(worker)

  def start_remote_worker(self, worker_tags=None):
    """Runs a tf.data service worker in a remote process."""

    pipe_reader, pipe_writer = multi_process_lib.multiprocessing.Pipe(
        duplex=False)
    worker_process = _RemoteWorkerProcess(
        self.dispatcher_address(),
        port=test_util.pick_unused_port(),
        worker_tags=worker_tags,
        pipe_writer=pipe_writer)
    worker_process.start()
    worker_address = pipe_reader.recv()
    self._remote_workers.append((worker_address, worker_process))

  def restart_dispatcher(self):
    port = int(self.dispatcher_address().split(":")[1])
    self._dispatcher._stop()
    self._start_dispatcher(
        worker_addresses=(self.local_worker_addresses() +
                          self.remote_worker_addresses()),
        port=port)

  def restart_local_workers(self):
    for worker in self._local_workers:
      worker.restart()

  def dispatcher_address(self):
    return self._dispatcher._address

  def local_worker_addresses(self):
    return [worker.worker_address() for worker in self._local_workers]

  def remote_worker_addresses(self):
    return [worker_address for (worker_address, _) in self._remote_workers]

  def _stop(self):
    for worker in self._local_workers:
      worker.stop()
    for (_, worker_process) in self._remote_workers:
      worker_process.kill()
    self._dispatcher._stop()

  def __del__(self):
    self._stop()


def test_main():
  """Main function to be called within `__main__` of a test file."""
  multi_process_lib.test_main()
