# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Various classes representing TPU distributed values.
16
17Note that the tests are in values_test.py .
18
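For example, a variable created under a `tf.distribute.TPUStrategy` scope is
wrapped in one of the distributed-variable classes below; which one is used
depends on the variable's synchronization/aggregation settings and on whether
variable policies are enabled. A minimal sketch, assuming `resolver` is an
already-configured `tf.distribute.cluster_resolver.TPUClusterResolver`:

  strategy = tf.distribute.TPUStrategy(resolver)
  with strategy.scope():
    v = tf.Variable(1.0)  # e.g. a TPUMirroredVariable.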
19"""

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


_scatter_error_msg = ("{op_name} is only supported for a distributed "
                      "variable (a variable created within a "
                      "`tf.distribute.Strategy` scope) with NONE "
                      "aggregation, got: {aggregation}.")


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the
    # variables correctly since in eager mode different variables can have
    # the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          f"`TPUVariableMixin.{name}` not accessible within a TPU context.")

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
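      # Inside a tpu.rewrite() (and not in eager mode): ask the enclosing TPU
      # context for a single replicated handle covering all per-replica
      # handles, passing the packed variable instead when one is available.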
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._common_name,
                                                   self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()


class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

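  # Every op below delegates to this variable's policy object (a
  # TPUOnWritePolicy or TPUOnReadPolicy), except when saving a
  # non-distributed checkpoint, in which case the primary component
  # variable is used directly.
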
  def _is_mirrored(self):
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def _is_replicated_or_sharded_to_logical_cores(self):
    """Returns whether each of the underlying variables is replicated or sharded to logical cores.

    If True, the handles of the underlying variables are not available outside
    a TPU context.
    """
    return isinstance(self._primary,
                      tpu_replicated_variable.TPUReplicatedVariable)

  @property
  def device(self):
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_util.enclosing_tpu_context() is None):
      return self._primary.device
    return super(TPUMirroredVariable, self).device

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
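    # Three cases: (1) outside a TPU context, variables replicated/sharded to
    # logical cores are updated one component at a time via `_update`; (2)
    # inside a TPU context with NONE aggregation, the raw resource op runs
    # directly on the replicated handle; (3) otherwise fall back to the
    # module-level `assign_sub`, which goes through `_update`.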
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_sub_fn = lambda v, *a, **ka: v.assign_sub(*a, **ka)
      return self._update(
          update_fn=assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_add_fn = lambda v, *a, **ka: v.assign_add(*a, **ka)
      return self._update(
          update_fn=assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_fn = lambda v, *a, **ka: v.assign(*a, **ka)
      return self._update(
          update_fn=assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError

  def _is_mirrored(self):
    return True


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)

  def _is_mirrored(self):
    return False

# Common assign methods shared between OnWrite and Mirrored variables. Each
# routes the update through `var._update`, which applies the raw assign op to
# the component variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


class TPUOnWritePolicy(values.OnWritePolicy):
415  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.
416
417  This policy is created when `synchronization` is set to
418  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
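
  A minimal sketch (assuming `strategy` is an existing
  `tf.distribute.TPUStrategy`) of a variable that ends up with this policy:

    with strategy.scope():
      v = tf.Variable(
          0.0, synchronization=tf.VariableSynchronization.ON_WRITE)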
419  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = tpu_util.make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
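    # Inside a TPU context, apply the scatter op directly to the replicated
    # handle; this is only well-defined when no cross-replica aggregation is
    # requested, hence the NONE check below. Outside a TPU context, route the
    # update through `_update` so each component variable is modified.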
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
                             "scatter_sub", var, sparse_delta, use_locking,
                             name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
                             "scatter_add", var, sparse_delta, use_locking,
                             name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
                             "scatter_max", var, sparse_delta, use_locking,
                             name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
                             "scatter_min", var, sparse_delta, use_locking,
                             name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
                             "scatter_mul", var, sparse_delta, use_locking,
                             name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
                             "scatter_div", var, sparse_delta, use_locking,
                             name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
                             "scatter_update", var, sparse_delta, use_locking,
                             name)

  def _is_mirrored(self):
    return True


class TPUOnReadPolicy(values.OnReadPolicy):
532  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.
533
534  This policy is created when `synchronization` is set to
535  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
536  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
537  `MEAN` or `ONLY_FIRST_REPLICA`when creating a `tf.Variable` in `tf.distribute`
538  scope.
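
  A minimal sketch (assuming `strategy` is an existing
  `tf.distribute.TPUStrategy`) of a variable that ends up with this policy:

    with strategy.scope():
      counter = tf.Variable(
          0.0,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.SUM)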
539  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError