# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

import copy

from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


_GRPC_PREFIX = "grpc://"
_LOCAL_MASTERS = ("", "local")


@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```
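
  A list of addresses can also be passed, in which case each host becomes a
  task of the same job. A minimal sketch (the second address is hypothetical):
  ```python
  tf.config.experimental_connect_to_host(
      ["exampleaddr.com:9876", "exampleaddr2.com:9876"])
  ```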

  Args:
    remote_host: a single remote server address, or a list of addresses, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None or empty.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

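  # Accept either a single address or a list of addresses; flatten them and
  # strip any "grpc://" scheme prefix before building the ClusterSpec.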
  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)


@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
83  """Connects to the given cluster.
84
85  Will make devices on the cluster available to use. Note that calling this more
86  than once will work, but will invalidate any tensor handles on the old remote
87  devices.
88
89  If the given local job name is not present in the cluster specification, it
90  will be automatically added, using an unused port on the localhost.
91
92  Device filters can be specified to isolate groups of remote tasks to avoid
93  undesired accesses between workers. Workers accessing resources or launching
94  ops / functions on filtered remote devices will result in errors (unknown
95  devices). For any remote task, if no device filter is present, all cluster
96  devices will be visible; if any device filter is specified, it can only
97  see devices matching at least one filter. Devices on the task itself are
98  always visible. Device filters can be particially specified.

  For example, for a cluster set up for parameter server training, the following
  device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```
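
  As a simpler example, a minimal sketch of connecting with a plain
  `ClusterSpec` (the worker addresses below are hypothetical):

  ```python
  cluster_spec = tf.train.ClusterSpec(
      {"worker": ["worker0.example.com:8470", "worker1.example.com:8470"]})
  tf.config.experimental_connect_to_cluster(cluster_spec)

  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    # Ops created here run on the first worker task.
    x = tf.ones([2, 2])
  ```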

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, so that the master
      becomes the default device on which ops run. It has no effect if a
      cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device filters
      for the remote tasks in the cluster.
131  """
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
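  # Fall back to the platform's default communication protocol when none is
  # given.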
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

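  # Work on a deep copy of the cluster def so that mutations below (such as
  # adding the local job) do not alter the caller's ClusterSpec.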
  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")

  # Check whether the server def has changed. We need to do the check before the
  # local job is added to the cluster.
  is_server_def_changed = False
  current_server_def = context.get_server_def()
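  # A local job may have been auto-added by a previous connect call; drop it
  # from the current server def so the comparison below only considers the
  # user-provided cluster.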
  if current_server_def and job_name not in cluster_spec.jobs:
    for i, job in enumerate(current_server_def.cluster.job):
      if job.name == job_name:
        del current_server_def.cluster.job[i]
  if (current_server_def is None or current_server_def.cluster != cluster_def or
      current_server_def.job_name != job_name or
      current_server_def.task_index != task_index):
    is_server_def_changed = True

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  if context.context().coordination_service is None:
    # Maybe enable coordination service for the communication protocol
    coordination_service = remote_utils.coordination_service_type(protocol)
    if coordination_service:
      context.context().configure_coordination_service(coordination_service)

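  # Assemble the server def for this task from the (possibly augmented) cluster
  # definition, protocol, and optional device filters.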
  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

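  # Reset the context with the new server def if it changed; otherwise update
  # the existing server def in place.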
  if is_server_def_changed:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s
239