# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

import copy

from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


_GRPC_PREFIX = "grpc://"
_LOCAL_MASTERS = ("", "local")


@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: a single remote server address, or a list of addresses, in
      host-port format.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None or empty.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)


@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.
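
  For example, a minimal connection to a two-worker cluster could look like the
  sketch below (the worker addresses are hypothetical placeholders; substitute
  the hosts of your own cluster):
  ```python
  cluster_spec = tf.train.ClusterSpec(
      {"worker": ["worker0.example.com:2222", "worker1.example.com:2222"]})
  tf.config.experimental_connect_to_cluster(cluster_spec)

  with tf.device("/job:worker/replica:0/task:0/device:CPU:0"):
    # Ops created in this scope are placed on the first remote worker.
    x = tf.ones([2, 2])
  ```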

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, the task can only
  see devices matching at least one filter. Devices on the task itself are
  always visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the
  following device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified,
      will use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed,
      automatically enter the master task device scope, which means the master
      becomes the default device on which to run ops. It has no effect if a
      cluster spec is passed. Will throw an error if the caller is already in
      some device scope.
    cluster_device_filters: an instance of
      `tf.train.experimental.ClusterDeviceFilters` that specifies device
      filters for the remote tasks in the cluster.
  """
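  # Overview of the steps below:
  #   1. Resolve a ClusterSpec and a communication protocol from the arguments.
  #   2. Build a ServerDef, adding the local job on an unused port if it is not
  #      already part of the cluster and, if none is configured yet, maybe
  #      enabling the coordination service for the protocol.
  #   3. Apply the ServerDef via set_server_def when it changed, otherwise via
  #      update_server_def.
  #   4. If make_master_device_default is True and a resolver with a master was
  #      passed, enter the master task's device scope.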
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")

  # Check whether the server def has changed. We need to do the check before
  # the local job is added to the cluster.
  is_server_def_changed = False
  current_server_def = context.get_server_def()
  if current_server_def and job_name not in cluster_spec.jobs:
    for i, job in enumerate(current_server_def.cluster.job):
      if job.name == job_name:
        del current_server_def.cluster.job[i]
  if (current_server_def is None or current_server_def.cluster != cluster_def or
      current_server_def.job_name != job_name or
      current_server_def.task_index != task_index):
    is_server_def_changed = True

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  if context.context().coordination_service is None:
    # Maybe enable coordination service for the communication protocol
    coordination_service = remote_utils.coordination_service_type(protocol)
    if coordination_service:
      context.context().configure_coordination_service(coordination_service)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=context.context().config,
      cluster_device_filters=cluster_device_filters)

  if is_server_def_changed:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name = None
    master_task_id = None
    for job_name in cluster_spec.jobs:
      for task_id in cluster_spec.task_indices(job_name):
        task_address = cluster_spec.task_address(job_name, task_id)
        if master in task_address or task_address in master:
          master_job_name = job_name
          master_task_id = task_id
          break

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      ops.device(master_device).__enter__()


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s