# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """
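
  # A minimal usage sketch (illustrative only; the device strings and values
  # below are assumptions, not requirements of the API):
  #
  #   with ops.device("CPU:0"):
  #     v0 = resource_variable_ops.ResourceVariable(1.0)
  #   with ops.device("CPU:1"):
  #     v1 = resource_variable_ops.ResourceVariable(2.0)
  #   packed = PackedDistributedVariable([v0, v1])
  #   with ops.device(packed.devices[0]):
  #     packed.assign_add(3.0)  # op-by-op: updates only the CPU:0 component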

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError(
        "Device %s is not found among the packed devices." % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)
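
  # Lookup sketch (assumes a `packed` variable built as in the example above):
  #
  #   with ops.device(packed.devices[1]):
  #     component = packed.get_var_on_current_device()  # variable on CPU:1
  #   wrapper = packed.on_device(packed.devices[1])     # PackedVarAndDevice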

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    # Op-by-op (eager) execution uses the unpacked handle of the component on
    # the current device; inside a tf.function the packed handle is used so a
    # single op can address all components.
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    # The packed variable is initialized only if all component variables are
    # initialized; AND-reduce the per-device predicates, attaching `name` to
    # the final op.
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
      result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    # Eagerly, apply the update to the component variable on the current
    # device; inside a tf.function, apply it through the packed handle.
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)
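
  # All assign_* and scatter_* methods below funnel through _update. For
  # illustration (hypothetical value):
  #
  #   packed.assign(5.0)
  #   # eagerly          -> get_var_on_current_device().assign(5.0)
  #   # in a tf.function -> BaseResourceVariable.assign on the packed handle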

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)
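
  # The scatter_* methods take an IndexedSlices. A hedged sketch assuming a
  # 1-D packed variable (the imports and values are illustrative):
  #
  #   from tensorflow.python.framework import constant_op, indexed_slices
  #   delta = indexed_slices.IndexedSlices(
  #       values=constant_op.constant([2.0]), indices=[0])
  #   packed.scatter_add(delta)  # adds 2.0 at index 0 on the current device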

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""

  def __init__(self, var, device):
    self._var = var
    self._device = device

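  # Typical use (sketch): pin the packed variable to one device so that reads
  # and writes through this wrapper run there.
  #
  #   pinned = packed.on_device(packed.devices[0])
  #   pinned.assign(7.0)  # executed under ops.device(devices[0])
  #   t = pinned.value()  # reads the component on devices[0]
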
  def __getattr__(self, name):
    # Exceptions raised inside the context manager can cause a reference
    # cycle.[1] The cycle involves the current frame, which holds a reference
    # to the outer frame. TensorFlow, e.g. iterators, relies on object
    # finalizers to clean up resources. Such references prevent the resources
    # from being deleted and can cause leaks and errors. One corner case is
    # that an iterator is kept alive and the garbage collector happens to run
    # after auto control dependencies; the deletion then loses the control
    # dependencies to operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
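# With the registration above, a PackedVarAndDevice is accepted wherever a
# Tensor is expected. A hedged sketch (assumes eager mode and a `packed`
# variable built as in the earlier example):
#
#   pinned = packed.on_device(packed.devices[0])
#   t = ops.convert_to_tensor(pinned)  # calls _dense_var_to_tensor
#   s = math_ops.add(pinned, 1.0)      # implicit conversion during dispatch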