1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Python interface for creating dataset servers.""" 16 17import collections 18 19# pylint: disable=invalid-import-order,g-bad-import-order, unused-import 20from tensorflow.core.protobuf import service_config_pb2 21from tensorflow.python import pywrap_tensorflow 22from tensorflow.python.data.experimental.service import _pywrap_server_lib 23from tensorflow.python.data.experimental.service import _pywrap_utils 24from tensorflow.python.util.tf_export import tf_export 25 26 27def _get_time_or_placeholder(value): 28 """Modifies time-based config values to account for special behaviors.""" 29 30 # Servers interpret time values of 0 to mean "choose a reasonable 31 # default". However, the Python API uses `None` for this, and allows 0 as a 32 # normal value. To account for this, if a user explicitly configures the 33 # interval/timeout to 0, we interpret it to mean "a very small number", and 34 # replace it with 1. 35 if value == 0: 36 return 1 37 # `None` indicates that the user wants to leave the behavior to the runtime. 38 if value is None: 39 return 0 40 return value 41 42 43@tf_export("data.experimental.service.DispatcherConfig") 44class DispatcherConfig( 45 collections.namedtuple("DispatcherConfig", [ 46 "port", "protocol", "work_dir", "fault_tolerant_mode", 47 "worker_addresses", "job_gc_check_interval_ms", "job_gc_timeout_ms" 48 ])): 49 """Configuration class for tf.data service dispatchers. 50 51 Fields: 52 port: Specifies the port to bind to. A value of 0 indicates that the server 53 may bind to any available port. 54 protocol: The protocol to use for communicating with the tf.data service, 55 e.g. "grpc". 56 work_dir: A directory to store dispatcher state in. This 57 argument is required for the dispatcher to be able to recover from 58 restarts. 59 fault_tolerant_mode: Whether the dispatcher should write its state to a 60 journal so that it can recover from restarts. Dispatcher state, including 61 registered datasets and created jobs, is synchronously written to the 62 journal before responding to RPCs. If `True`, `work_dir` must also be 63 specified. 64 worker_addresses: If the job uses auto-sharding, it needs to specify a fixed 65 list of worker addresses that will register with the dispatcher. The 66 worker addresses should be in the format `"host"` or `"host:port"`, where 67 `"port"` is an integer, named port, or `%port%` to match any port. 68 job_gc_check_interval_ms: How often the dispatcher should scan through to 69 delete old and unused jobs, in milliseconds. If not set, the runtime will 70 select a reasonable default. A higher value will reduce load on the 71 dispatcher, while a lower value will reduce the time it takes for the 72 dispatcher to garbage collect expired jobs. 73 job_gc_timeout_ms: How long a job needs to be unused before it becomes a 74 candidate for garbage collection, in milliseconds. A value of -1 indicates 75 that jobs should never be garbage collected. If not set, the runtime will 76 select a reasonable default. A higher value will cause jobs to stay around 77 longer with no consumers. This is useful if there is a large gap in 78 time between when consumers read from the job. A lower value will reduce 79 the time it takes to reclaim the resources from expired jobs. 80 """ 81 82 def __new__(cls, 83 port=0, 84 protocol=None, 85 work_dir=None, 86 fault_tolerant_mode=False, 87 worker_addresses=None, 88 job_gc_check_interval_ms=None, 89 job_gc_timeout_ms=None): 90 if protocol is None: 91 protocol = _pywrap_utils.TF_DATA_DefaultProtocol() 92 job_gc_check_interval_ms = _get_time_or_placeholder( 93 job_gc_check_interval_ms) 94 job_gc_timeout_ms = _get_time_or_placeholder(job_gc_timeout_ms) 95 return super(DispatcherConfig, 96 cls).__new__(cls, port, protocol, work_dir, 97 fault_tolerant_mode, worker_addresses, 98 job_gc_check_interval_ms, job_gc_timeout_ms) 99 100 101@tf_export("data.experimental.service.DispatchServer", v1=[]) 102class DispatchServer: 103 """An in-process tf.data service dispatch server. 104 105 A `tf.data.experimental.service.DispatchServer` coordinates a cluster of 106 `tf.data.experimental.service.WorkerServer`s. When the workers start, they 107 register themselves with the dispatcher. 108 109 >>> dispatcher = tf.data.experimental.service.DispatchServer() 110 >>> dispatcher_address = dispatcher.target.split("://")[1] 111 >>> worker = tf.data.experimental.service.WorkerServer( 112 ... tf.data.experimental.service.WorkerConfig( 113 ... dispatcher_address=dispatcher_address)) 114 >>> dataset = tf.data.Dataset.range(10) 115 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 116 ... processing_mode="parallel_epochs", service=dispatcher.target)) 117 >>> print(list(dataset.as_numpy_iterator())) 118 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 119 120 When starting a dedicated tf.data dispatch process, use join() to block 121 indefinitely after starting up the server. 122 123 ``` 124 dispatcher = tf.data.experimental.service.DispatchServer( 125 tf.data.experimental.service.DispatcherConfig(port=5050)) 126 dispatcher.join() 127 ``` 128 129 To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and 130 `fault_tolerant_mode` like below: 131 132 ``` 133 dispatcher = tf.data.experimental.service.DispatchServer( 134 tf.data.experimental.service.DispatcherConfig( 135 port=5050, 136 work_dir="gs://my-bucket/dispatcher/work_dir", 137 fault_tolerant_mode=True)) 138 ``` 139 """ 140 141 def __init__(self, config=None, start=True): 142 """Creates a new dispatch server. 143 144 Args: 145 config: (Optional.) A `tf.data.experimental.service.DispatcherConfig` 146 configration. If `None`, the dispatcher will use default 147 configuration values. 148 start: (Optional.) Boolean, indicating whether to start the server after 149 creating it. Defaults to True. 150 """ 151 config = config or DispatcherConfig() 152 if config.fault_tolerant_mode and not config.work_dir: 153 raise ValueError( 154 "Cannot enable fault tolerant mode without configuring a work dir. " 155 "Make sure to set `work_dir` in the `config` object passed to " 156 "`DispatcherServer`.") 157 self._config = config 158 if isinstance(config, service_config_pb2.DispatcherConfig): 159 config_proto = config 160 else: 161 config_proto = service_config_pb2.DispatcherConfig( 162 port=config.port, 163 protocol=config.protocol, 164 work_dir=config.work_dir, 165 fault_tolerant_mode=config.fault_tolerant_mode, 166 worker_addresses=config.worker_addresses, 167 job_gc_check_interval_ms=config.job_gc_check_interval_ms, 168 job_gc_timeout_ms=config.job_gc_timeout_ms) 169 self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer( 170 config_proto.SerializeToString()) 171 if start: 172 self._server.start() 173 174 def start(self): 175 """Starts this server. 176 177 >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False) 178 >>> dispatcher.start() 179 180 Raises: 181 tf.errors.OpError: Or one of its subclasses if an error occurs while 182 starting the server. 183 """ 184 self._server.start() 185 186 def join(self): 187 """Blocks until the server has shut down. 188 189 This is useful when starting a dedicated dispatch process. 190 191 ``` 192 dispatcher = tf.data.experimental.service.DispatchServer( 193 tf.data.experimental.service.DispatcherConfig(port=5050)) 194 dispatcher.join() 195 ``` 196 197 Raises: 198 tf.errors.OpError: Or one of its subclasses if an error occurs while 199 joining the server. 200 """ 201 self._server.join() 202 203 @property 204 def target(self): 205 """Returns a target that can be used to connect to the server. 206 207 >>> dispatcher = tf.data.experimental.service.DispatchServer() 208 >>> dataset = tf.data.Dataset.range(10) 209 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 210 ... processing_mode="parallel_epochs", service=dispatcher.target)) 211 212 The returned string will be in the form protocol://address, e.g. 213 "grpc://localhost:5050". 214 """ 215 return "{0}://localhost:{1}".format(self._config.protocol, 216 self._server.bound_port()) 217 218 def _stop(self): 219 """Stops the server. 220 221 Raises: 222 tf.errors.OpError: Or one of its subclasses if an error occurs while 223 stopping the server. 224 """ 225 self._server.stop() 226 227 def __del__(self): 228 self._stop() 229 230 @property 231 def _address(self): 232 """Returns the address of the server. 233 234 The returned string will be in the form address:port, e.g. "localhost:1000". 235 """ 236 return "localhost:{0}".format(self._server.bound_port()) 237 238 def _num_workers(self): 239 """Returns the number of workers registered with the dispatcher.""" 240 return self._server.num_workers() 241 242 243@tf_export("data.experimental.service.WorkerConfig") 244class WorkerConfig( 245 collections.namedtuple("WorkerConfig", [ 246 "dispatcher_address", "worker_address", "port", "protocol", 247 "heartbeat_interval_ms", "dispatcher_timeout_ms", 248 "data_transfer_protocol" 249 ])): 250 """Configuration class for tf.data service dispatchers. 251 252 Fields: 253 dispatcher_address: Specifies the address of the dispatcher. 254 worker_address: Specifies the address of the worker server. This address is 255 passed to the dispatcher so that the dispatcher can tell clients how to 256 connect to this worker. 257 port: Specifies the port to bind to. A value of 0 indicates that the worker 258 can bind to any available port. 259 protocol: A string indicating the protocol to be used by the worker to 260 connect to the dispatcher. E.g. "grpc". 261 heartbeat_interval_ms: How often the worker should heartbeat to the 262 dispatcher, in milliseconds. If not set, the runtime will select a 263 reasonable default. A higher value will reduce the load on the dispatcher, 264 while a lower value will reduce the time it takes to reclaim resources 265 from finished jobs. 266 dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the 267 dispatcher before giving up and reporting an error. Defaults to 1 hour. 268 data_transfer_protocol: A string indicating the protocol to be used by the 269 worker to transfer data to the client. E.g. "grpc". 270 """ 271 272 def __new__(cls, 273 dispatcher_address, 274 worker_address=None, 275 port=0, 276 protocol=None, 277 heartbeat_interval_ms=None, 278 dispatcher_timeout_ms=None, 279 data_transfer_protocol=None): 280 if worker_address is None: 281 worker_address = "localhost:%port%" 282 if protocol is None: 283 protocol = _pywrap_utils.TF_DATA_DefaultProtocol() 284 if data_transfer_protocol is None: 285 data_transfer_protocol = ( 286 _pywrap_utils.TF_DATA_DefaultDataTransferProtocol()) 287 heartbeat_interval_ms = _get_time_or_placeholder(heartbeat_interval_ms) 288 dispatcher_timeout_ms = _get_time_or_placeholder(dispatcher_timeout_ms) 289 290 return super(WorkerConfig, 291 cls).__new__(cls, dispatcher_address, worker_address, port, 292 protocol, heartbeat_interval_ms, 293 dispatcher_timeout_ms, data_transfer_protocol) 294 295 296@tf_export("data.experimental.service.WorkerServer", v1=[]) 297class WorkerServer: 298 """An in-process tf.data service worker server. 299 300 A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset` 301 processing for user-defined datasets, and provides the resulting elements over 302 RPC. A worker is associated with a single 303 `tf.data.experimental.service.DispatchServer`. 304 305 >>> dispatcher = tf.data.experimental.service.DispatchServer() 306 >>> dispatcher_address = dispatcher.target.split("://")[1] 307 >>> worker = tf.data.experimental.service.WorkerServer( 308 ... tf.data.experimental.service.WorkerConfig( 309 ... dispatcher_address=dispatcher_address)) 310 >>> dataset = tf.data.Dataset.range(10) 311 >>> dataset = dataset.apply(tf.data.experimental.service.distribute( 312 ... processing_mode="parallel_epochs", service=dispatcher.target)) 313 >>> print(list(dataset.as_numpy_iterator())) 314 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 315 316 When starting a dedicated tf.data worker process, use join() to block 317 indefinitely after starting up the server. 318 319 ``` 320 worker = tf.data.experimental.service.WorkerServer( 321 port=5051, dispatcher_address="localhost:5050") 322 worker.join() 323 ``` 324 """ 325 326 def __init__(self, config, start=True): 327 """Creates a new worker server. 328 329 Args: 330 config: A `tf.data.experimental.service.WorkerConfig` configration. 331 start: (Optional.) Boolean, indicating whether to start the server after 332 creating it. Defaults to True. 333 """ 334 if config.dispatcher_address is None: 335 raise ValueError( 336 "Must specify a `dispatcher_address` in the `config` passed " 337 "to `WorkerServer`.") 338 if isinstance(config, service_config_pb2.WorkerConfig): 339 config_proto = config 340 else: 341 config_proto = service_config_pb2.WorkerConfig( 342 dispatcher_address=config.dispatcher_address, 343 worker_address=config.worker_address, 344 port=config.port, 345 protocol=config.protocol, 346 heartbeat_interval_ms=config.heartbeat_interval_ms, 347 dispatcher_timeout_ms=config.dispatcher_timeout_ms, 348 data_transfer_protocol=None) 349 self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer( 350 config_proto.SerializeToString()) 351 if start: 352 self._server.start() 353 354 def start(self): 355 """Starts this server. 356 357 Raises: 358 tf.errors.OpError: Or one of its subclasses if an error occurs while 359 starting the server. 360 """ 361 self._server.start() 362 363 def join(self): 364 """Blocks until the server has shut down. 365 366 This is useful when starting a dedicated worker process. 367 368 ``` 369 worker_server = tf.data.experimental.service.WorkerServer( 370 port=5051, dispatcher_address="localhost:5050") 371 worker_server.join() 372 ``` 373 374 This method currently blocks forever. 375 376 Raises: 377 tf.errors.OpError: Or one of its subclasses if an error occurs while 378 joining the server. 379 """ 380 self._server.join() 381 382 def _stop(self): 383 """Stops the server. 384 385 Raises: 386 tf.errors.OpError: Or one of its subclasses if an error occurs while 387 stopping the server. 388 """ 389 self._server.stop() 390 391 def __del__(self): 392 self._stop() 393 394 @property 395 def _address(self): 396 """Returns the address of the server. 397 398 The returned string will be in the form address:port, e.g. "localhost:1000". 399 """ 400 return "localhost:{0}".format(self._server.bound_port()) 401 402 def _num_tasks(self): 403 """Returns the number of tasks currently being executed on the worker.""" 404 return self._server.num_tasks() 405