# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating dataset servers."""

import collections

# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export


def _get_time_or_placeholder(value):
  """Modifies time-based config values to account for special behaviors."""

  # Servers interpret time values of 0 to mean "choose a reasonable
  # default". However, the Python API uses `None` for this, and allows 0 as a
  # normal value. To account for this, if a user explicitly configures the
  # interval/timeout to 0, we interpret it to mean "a very small number", and
  # replace it with 1.
  if value == 0:
    return 1
  # `None` indicates that the user wants to leave the behavior to the runtime.
  if value is None:
    return 0
  return value

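# For illustration, the mapping implemented by `_get_time_or_placeholder`
# (values in milliseconds):
#
#   _get_time_or_placeholder(None) == 0     # let the runtime choose a default
#   _get_time_or_placeholder(0) == 1        # "as small as possible", i.e. 1ms
#   _get_time_or_placeholder(5000) == 5000  # explicit values pass through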

@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple("DispatcherConfig", [
        "port", "protocol", "work_dir", "fault_tolerant_mode",
        "worker_addresses", "job_gc_check_interval_ms", "job_gc_timeout_ms"
    ])):
  """Configuration class for tf.data service dispatchers.

  Fields:
    port: Specifies the port to bind to. A value of 0 indicates that the server
      may bind to any available port.
    protocol: The protocol to use for communicating with the tf.data service,
      e.g. "grpc".
    work_dir: A directory to store dispatcher state in. This argument is
      required for the dispatcher to recover from restarts.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts. Dispatcher state, including
      registered datasets and created jobs, is synchronously written to the
      journal before responding to RPCs. If `True`, `work_dir` must also be
      specified.
    worker_addresses: If the job uses auto-sharding, this specifies a fixed
      list of worker addresses that will register with the dispatcher. The
      worker addresses should be in the format `"host"` or `"host:port"`, where
      `"port"` is an integer, named port, or `%port%` to match any port.
    job_gc_check_interval_ms: How often the dispatcher should scan through its
      jobs to delete old and unused ones, in milliseconds. If not set, the
      runtime will select a reasonable default. A higher value will reduce load
      on the dispatcher, while a lower value will reduce the time it takes for
      the dispatcher to garbage collect expired jobs.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds. A value of -1 indicates
      that jobs should never be garbage collected. If not set, the runtime will
      select a reasonable default. A higher value will cause jobs to stay around
      longer with no consumers, which is useful when there are long gaps between
      reads from the same job. A lower value will reduce the time it takes to
      reclaim the resources from expired jobs.
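
  For example, a fault-tolerant dispatcher for an auto-sharded job might be
  configured as below (addresses and values are illustrative, not defaults):

  ```
  config = tf.data.experimental.service.DispatcherConfig(
      port=5050,
      work_dir="gs://my-bucket/dispatcher/work_dir",
      fault_tolerant_mode=True,
      worker_addresses=["worker-0:%port%", "worker-1:%port%"],
      job_gc_timeout_ms=600000)  # garbage-collect jobs unused for 10 minutes
  ```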
80  """
81
82  def __new__(cls,
83              port=0,
84              protocol=None,
85              work_dir=None,
86              fault_tolerant_mode=False,
87              worker_addresses=None,
88              job_gc_check_interval_ms=None,
89              job_gc_timeout_ms=None):
90    if protocol is None:
91      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
92    job_gc_check_interval_ms = _get_time_or_placeholder(
93        job_gc_check_interval_ms)
94    job_gc_timeout_ms = _get_time_or_placeholder(job_gc_timeout_ms)
95    return super(DispatcherConfig,
96                 cls).__new__(cls, port, protocol, work_dir,
97                              fault_tolerant_mode, worker_addresses,
98                              job_gc_check_interval_ms, job_gc_timeout_ms)
99
100
101@tf_export("data.experimental.service.DispatchServer", v1=[])
102class DispatchServer:
103  """An in-process tf.data service dispatch server.
104
105  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
106  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
107  register themselves with the dispatcher.
108
109  >>> dispatcher = tf.data.experimental.service.DispatchServer()
110  >>> dispatcher_address = dispatcher.target.split("://")[1]
111  >>> worker = tf.data.experimental.service.WorkerServer(
112  ...     tf.data.experimental.service.WorkerConfig(
113  ...     dispatcher_address=dispatcher_address))
114  >>> dataset = tf.data.Dataset.range(10)
115  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
116  ...     processing_mode="parallel_epochs", service=dispatcher.target))
117  >>> print(list(dataset.as_numpy_iterator()))
118  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
119
120  When starting a dedicated tf.data dispatch process, use join() to block
121  indefinitely after starting up the server.
122
123  ```
124  dispatcher = tf.data.experimental.service.DispatchServer(
125      tf.data.experimental.service.DispatcherConfig(port=5050))
126  dispatcher.join()
127  ```
128
129  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
130  `fault_tolerant_mode` like below:
131
132  ```
133  dispatcher = tf.data.experimental.service.DispatchServer(
134      tf.data.experimental.service.DispatcherConfig(
135          port=5050,
136          work_dir="gs://my-bucket/dispatcher/work_dir",
137          fault_tolerant_mode=True))
138  ```
139  """
140
141  def __init__(self, config=None, start=True):
142    """Creates a new dispatch server.
143
144    Args:
145      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
146        configration. If `None`, the dispatcher will use default
147        configuration values.
148      start: (Optional.) Boolean, indicating whether to start the server after
149        creating it. Defaults to True.
150    """
151    config = config or DispatcherConfig()
152    if config.fault_tolerant_mode and not config.work_dir:
153      raise ValueError(
154          "Cannot enable fault tolerant mode without configuring a work dir. "
155          "Make sure to set `work_dir` in the `config` object passed to "
156          "`DispatcherServer`.")
157    self._config = config
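    # `config` may be given either as the Python `DispatcherConfig` defined
    # above or as an already-built `service_config_pb2.DispatcherConfig`
    # proto; either way, normalize to the serialized proto that the C++
    # server bindings expect.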
    if isinstance(config, service_config_pb2.DispatcherConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.DispatcherConfig(
          port=config.port,
          protocol=config.protocol,
          work_dir=config.work_dir,
          fault_tolerant_mode=config.fault_tolerant_mode,
          worker_addresses=config.worker_addresses,
          job_gc_check_interval_ms=config.job_gc_check_interval_ms,
          job_gc_timeout_ms=config.job_gc_timeout_ms)
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()


@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms",
        "data_transfer_protocol"
    ])):
  """Configuration class for tf.data service workers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher.
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
    protocol: A string indicating the protocol to be used by the worker to
      connect to the dispatcher. E.g. "grpc".
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set, the runtime will select a
      reasonable default. A higher value will reduce the load on the dispatcher,
      while a lower value will reduce the time it takes to reclaim resources
      from finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
    data_transfer_protocol: A string indicating the protocol to be used by the
      worker to transfer data to the client. E.g. "grpc".
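
  For example, a worker that connects to a local dispatcher might be
  configured as below (addresses and values are illustrative):

  ```
  config = tf.data.experimental.service.WorkerConfig(
      dispatcher_address="localhost:5050",
      worker_address="localhost:%port%",  # %port% matches the bound port
      heartbeat_interval_ms=30000)  # heartbeat every 30 seconds
  ```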
270  """
271
272  def __new__(cls,
273              dispatcher_address,
274              worker_address=None,
275              port=0,
276              protocol=None,
277              heartbeat_interval_ms=None,
278              dispatcher_timeout_ms=None,
279              data_transfer_protocol=None):
280    if worker_address is None:
281      worker_address = "localhost:%port%"
282    if protocol is None:
283      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
284    if data_transfer_protocol is None:
285      data_transfer_protocol = (
286          _pywrap_utils.TF_DATA_DefaultDataTransferProtocol())
287    heartbeat_interval_ms = _get_time_or_placeholder(heartbeat_interval_ms)
288    dispatcher_timeout_ms = _get_time_or_placeholder(dispatcher_timeout_ms)
289
290    return super(WorkerConfig,
291                 cls).__new__(cls, dispatcher_address, worker_address, port,
292                              protocol, heartbeat_interval_ms,
293                              dispatcher_timeout_ms, data_transfer_protocol)
294
295
296@tf_export("data.experimental.service.WorkerServer", v1=[])
297class WorkerServer:
298  """An in-process tf.data service worker server.
299
300  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
301  processing for user-defined datasets, and provides the resulting elements over
302  RPC. A worker is associated with a single
303  `tf.data.experimental.service.DispatchServer`.
304
305  >>> dispatcher = tf.data.experimental.service.DispatchServer()
306  >>> dispatcher_address = dispatcher.target.split("://")[1]
307  >>> worker = tf.data.experimental.service.WorkerServer(
308  ...     tf.data.experimental.service.WorkerConfig(
309  ...         dispatcher_address=dispatcher_address))
310  >>> dataset = tf.data.Dataset.range(10)
311  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
312  ...     processing_mode="parallel_epochs", service=dispatcher.target))
313  >>> print(list(dataset.as_numpy_iterator()))
314  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
315
316  When starting a dedicated tf.data worker process, use join() to block
317  indefinitely after starting up the server.
318
319  ```
320  worker = tf.data.experimental.service.WorkerServer(
321      port=5051, dispatcher_address="localhost:5050")
322  worker.join()
323  ```
324  """
325
326  def __init__(self, config, start=True):
327    """Creates a new worker server.
328
329    Args:
330      config: A `tf.data.experimental.service.WorkerConfig` configration.
331      start: (Optional.) Boolean, indicating whether to start the server after
332        creating it. Defaults to True.
333    """
334    if config.dispatcher_address is None:
335      raise ValueError(
336          "Must specify a `dispatcher_address` in the `config` passed "
337          "to `WorkerServer`.")
    if isinstance(config, service_config_pb2.WorkerConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.WorkerConfig(
          dispatcher_address=config.dispatcher_address,
          worker_address=config.worker_address,
          port=config.port,
          protocol=config.protocol,
          heartbeat_interval_ms=config.heartbeat_interval_ms,
          dispatcher_timeout_ms=config.dispatcher_timeout_ms,
          # Forward the data transfer protocol resolved in
          # `WorkerConfig.__new__` rather than dropping it.
          data_transfer_protocol=config.data_transfer_protocol)
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        tf.data.experimental.service.WorkerConfig(
            port=5051, dispatcher_address="localhost:5050"))
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()
405