1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed values.""" 16 17import copy 18import weakref 19 20from tensorflow.python.distribute import device_util 21from tensorflow.python.distribute import distribute_lib 22from tensorflow.python.distribute import distribution_strategy_context as ds_context 23from tensorflow.python.distribute import packed_distributed_variable as packed 24from tensorflow.python.distribute import reduce_util 25from tensorflow.python.distribute import values_util 26from tensorflow.python.eager import context 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import ops 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.framework import type_spec 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import resource_variable_ops 35from tensorflow.python.ops import variable_scope as vs 36from tensorflow.python.ops import variables as variables_lib 37from tensorflow.python.trackable import base as trackable 38from tensorflow.python.training.saving import saveable_object 39from tensorflow.python.types import core 40from tensorflow.python.types import distribute as ds_types 
from tensorflow.python.types import trace


def _on_write_update_replica(var, update_fn, value, **kwargs):
  """Updates variables with ON_WRITE synchronization in replica context.

  Args:
    var: The distributed variable being updated.
    update_fn: A callable applying the update to a single component variable.
    value: The value to update the variable with.
    **kwargs: Additional arguments forwarded to `update_fn`.

  Returns:
    The result of applying `update_fn`, grouped across replicas.
  """
  # With no aggregation requested, update only this replica's component.
  if var.aggregation == vs.VariableAggregation.NONE:
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  # Two code paths: when the strategy opts out of merge_call, aggregate
  # directly in replica context via an all-reduce; otherwise fall back to
  # merge_call, which aggregates and updates in cross-replica context.
  if not ds_context.get_strategy().extended._use_merge_call():  # pylint: disable=protected-access
    # Don't allow MEAN with non float dtype, since it may cause unexpected
    # precision loss. Python3 and NumPy automatically upcast integers to
    # float in division, but we should always preserve the type.
    if var.aggregation == vs.VariableAggregation.MEAN and (
        not var.dtype.is_floating) and tensor_util.is_tf_type(value):
      raise ValueError(
          "Cannot update non-float variables with "
          "tf.VariableAggregation.MEAN aggregation in replica context. "
          "Either change the variable dtype to float or update it in "
          "cross-replica context.")

    aggregated_value = apply_aggregation_replica_context(
        value, var.aggregation, var)
    # Replica-context updates of this form cannot be saved; flag the graph.
    values_util.mark_as_unsaveable()

    return ds_context.get_replica_context()._update(  # pylint: disable=protected-access
        var,
        update_fn,
        args=(aggregated_value,),
        kwargs=kwargs,
        group=True)

  else:

    def merge_fn(strategy, value, **kwargs):
      """Aggregate values and update all variables in cross replica context."""
      # Don't allow MEAN with non float dtype, since it may cause unexpected
      # precision loss. Python3 and NumPy automatically upcast integers to
      # float in division, but we should always preserve the type.
      #
      # Note that to be backward compatible we allow the case when the value
      # is *always* the same on each replica. I.E. value is not a
      # PerReplica. Refer to regroup() to see how values are grouped.
      if var.aggregation == vs.VariableAggregation.MEAN and (
          not var.dtype.is_floating) and isinstance(value, PerReplica):
        raise ValueError(
            "Cannot update non-float variables with "
            "tf.VariableAggregation.MEAN aggregation in replica context. "
            "Either change the variable dtype to float or update it in "
            "cross-replica context.")

      assert strategy == var.distribute_strategy
      v = values_util.apply_aggregation(strategy, value, var.aggregation, var)
      return var._update_cross_replica(update_fn, v, **kwargs)  # pylint: disable=protected-access

    return ds_context.get_replica_context().merge_call(
        merge_fn, args=(value,), kwargs=kwargs)


def apply_aggregation_replica_context(value, aggregation, destinations):
  """Aggregate `value` to `destinations` as specified by `aggregation`.

  Args:
    value: The value to aggregate. Non-tensor Python values are returned
      unchanged.
    aggregation: A `tf.VariableAggregation` value selecting the reduction.
    destinations: Destination for the broadcast when `aggregation` is
      ONLY_FIRST_REPLICA (typically the variable being updated).

  Returns:
    The aggregated (or broadcast) value.

  Raises:
    TypeError: If `value` is a `DistributedValues` instance.
  """
  # if it is a python literal, return without aggregation
  if isinstance(value, DistributedValues):
    raise TypeError(
        "Cannot use DistributedValues to update variables in replica context.")
  if not tensor_util.is_tf_type(value):
    return value

  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    # Switch to cross-replica context to broadcast
    def merge_fn(strategy, value):
      return strategy.extended.broadcast_to(
          strategy.experimental_local_results(value)[0],
          destinations=destinations)

    return ds_context.get_replica_context().merge_call(merge_fn, args=(value,))

  else:
    # Map the variable aggregation onto the corresponding reduce op and
    # all-reduce directly within the replica context.
    reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
    aggregated_value = ds_context.get_strategy(  # pylint: disable=protected-access
    ).extended._replica_ctx_all_reduce(reduce_op, value)
    return aggregated_value


class DistributedValues(ds_types.DistributedValues):
  """Base class for representing distributed values."""

  def __init__(self, values):
    """Should only be called by subclass __init__."""
    # Stored as a tuple so the component list is immutable after construction.
    self._values = tuple(values)

  def _get(self):
    """Returns the value for the current device or raises a ValueError."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Not inside a replica context: delegate to the subclass's
      # cross-replica accessor (may raise NotImplementedError).
      return self._get_cross_replica()
    else:
      return self._values[replica_id]

  def _get_cross_replica(self):
    raise NotImplementedError(
        "DistributedValues._get_cross_replica should be implemented by "
        "sub-classes which support cross-replica accesses.")

  def _get_on_device_or_primary(self):
    """Returns value in same replica or device if possible, else the _primary."""
    replica_id = values_util.get_current_replica_id_as_int()
    if replica_id is None:
      # Try to find a value on the current device.
      current_device = device_util.canonicalize(device_util.current())
      for value in self._values:
        if device_util.canonicalize(value.device) == current_device:
          return value
      return self._primary
    else:
      return self._values[replica_id]

  @property
  def _primary(self):
    """Returns a representative component."""
    return self._values[0]

  @property
  def _devices(self):
    # Device string of each component, in replica order.
    return tuple(v.device for v in self._values)

  def __str__(self):
    debug_str = ",\n".join(
        "  %d: %s" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ",\n".join(
        "  %d: %r" % (i, v) for i, v in enumerate(self._values))
    return "%s:{\n%s\n}" % (self.__class__.__name__, debug_repr)


# NOTE(josh11b,apassos): It would be great if we could inspect the values this was
# initialized with and use that to generate the overloaded operators here.
# Unfortunately, Python's rules for special methods don't allow this, see
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# "if a class defines a method named __getitem__(), and x is an instance of
# this class, then x[i] is roughly equivalent to type(x).__getitem__(x, i)."
# In particular, these special methods don't go through __getattr__, and
# it will only use those methods if they are defined in the class, not the
# object.
class DistributedDelegate(DistributedValues):
  """A map from device to values; acts as the same type as the values."""

  def __getattr__(self, name):
    # The '_use_resource_variables' and the attrs starts with '_self' are used
    # for restoring the saved_model proto, and '_attribute_sentinel' is used for
    # Layer tracking. At the point these attrs are queried, the variable has not
    # been initialized. Thus it should not query those of the underlying
    # components.
    if name.startswith("_self_") or name in ("_use_resource_variables",
                                             "_attribute_sentinel",
                                             "_distributed_container"):
      return super(DistributedDelegate, self).__getattr__(name)

    # This allows copy.copy(DistributedDelegate). When copying an object,
    # copy.copy doesn't invoke its __init__ method, instead it makes a new
    # empty object, then copies the attributes over. copy.copy looks for
    # attributes like "__getstate__" in case the object implements its custom
    # copying. Since DistributedDelegate doesn't have those attributes defined,
    # __getattr__ will be invoked, which tries to access "_values" attributes,
    # but that doesn't exist either because this is an empty object, and again
    # __getattr__ is invoked, leading to an infinite recursion.
    if name == "_values":
      raise AttributeError()

    # TODO(priyag): This needs to be made robust against pitfalls from mix use
    # __getattr__ and @property. See b/120402273.
    return getattr(self._get(), name)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values

  def _get_as_operand(self):
    """Returns the value for operations for the current device.

    Some implementations, e.g. `TPUMirroredVariable`, are not able to return the
    value type within a replica context. They can, however, return a value that
    can be used by the operations below.
    """
    return self._get()

  # Each operator below forwards to the current replica's underlying value,
  # so a DistributedDelegate can be used in expressions wherever that value
  # can. (Special methods bypass __getattr__, so they must be defined here.)
  # pylint: disable=multiple-statements
  def __add__(self, o):
    return self._get_as_operand() + o

  def __radd__(self, o):
    return o + self._get_as_operand()

  def __sub__(self, o):
    return self._get_as_operand() - o

  def __rsub__(self, o):
    return o - self._get_as_operand()

  def __mul__(self, o):
    return self._get_as_operand() * o

  def __rmul__(self, o):
    return o * self._get_as_operand()

  def __truediv__(self, o):
    return self._get_as_operand() / o

  def __rtruediv__(self, o):
    return o / self._get_as_operand()

  def __floordiv__(self, o):
    return self._get_as_operand() // o

  def __rfloordiv__(self, o):
    return o // self._get_as_operand()

  def __mod__(self, o):
    return self._get_as_operand() % o

  def __rmod__(self, o):
    return o % self._get_as_operand()

  def __lt__(self, o):
    return self._get_as_operand() < o

  def __le__(self, o):
    return self._get_as_operand() <= o

  def __gt__(self, o):
    return self._get_as_operand() > o

  def __ge__(self, o):
    return self._get_as_operand() >= o

  def __and__(self, o):
    return self._get_as_operand() & o

  def __rand__(self, o):
    return o & self._get_as_operand()

  def __or__(self, o):
    return self._get_as_operand() | o

  def __ror__(self, o):
    return o | self._get_as_operand()

  def __xor__(self, o):
    return self._get_as_operand() ^ o

  def __rxor__(self, o):
    return o ^ self._get_as_operand()

  def __getitem__(self, o):
    return self._get_as_operand()[o]

  def __pow__(self, o, modulo=None):
    return pow(self._get_as_operand(), o, modulo)

  def __rpow__(self, o):
    return pow(o, self._get_as_operand())

  def __invert__(self):
    return ~self._get_as_operand()

  def __neg__(self):
    return -self._get_as_operand()

  def __abs__(self):
    return abs(self._get_as_operand())

  # The remaining operators delegate to the underlying value's own method
  # when it exists; a missing method is reported via the NotImplemented
  # sentinel so Python can try the reflected operation on the other operand.
  def __div__(self, o):
    try:
      return self._get_as_operand().__div__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rdiv__(self, o):
    try:
      return self._get_as_operand().__rdiv__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __matmul__(self, o):
    try:
      return self._get_as_operand().__matmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  def __rmatmul__(self, o):
    try:
      return self._get_as_operand().__rmatmul__(o)
    except AttributeError:
      # See https://docs.python.org/3/library/constants.html#NotImplemented
      return NotImplemented

  # TODO(josh11b): Even more operator overloads.
class PerReplica(DistributedValues, composite_tensor.CompositeTensor,
                 ds_types.PerReplica):
  """Holds a map from replica to unsynchronized values."""

  @property
  def _type_spec(self):
    # Describe this composite by the spec of each per-replica component.
    component_specs = tuple(
        type_spec.type_spec_from_value(v) for v in self._values)
    return PerReplicaSpec(*component_specs)

  @property
  def values(self):
    """Returns the per replica values."""
    return self._values


def _per_replica_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Converts a `PerReplica` to a `Tensor` for the current replica.

  Only valid inside a replica context; elsewhere a descriptive error is
  raised instead of silently picking a component.
  """
  del name  # Unused: the component tensor is returned as-is.
  if dtype is not None and not dtype.is_compatible_with(var.dtype):
    raise ValueError(
        "Incompatible type conversion requested to type {!r} for variable "
        "of type {!r}".format(dtype.name, var.dtype.name))
  if as_ref:
    raise NotImplementedError(
        "PerReplica doesn't support being used as a reference.")
  if ds_context.in_cross_replica_context() or not ds_context.has_strategy():
    raise ValueError("It looks like you are using a PerReplica object while "
                     "not inside a replica context, which is not supported. "
                     "Try running your op or function inside a replica context "
                     "by using `strategy.run`")
  # Replica context: hand back this replica's component.
  replica_id = values_util.get_current_replica_id_as_int()
  return var.values[replica_id]


# Register a conversion function to provide a useful error message when users
# try to use PerReplica values in the wrong contexts
ops.register_tensor_conversion_function(PerReplica, _per_replica_to_tensor)


class PerReplicaSpec(type_spec.TypeSpec):
  """Type specification for a `PerReplica`."""

  __slots__ = ["_value_specs"]

  @property
  def value_type(self):
    return PerReplica

  def __init__(self, *value_specs):
    self._value_specs = tuple(value_specs)

  def _serialize(self):
    return self._value_specs

  @property
  def _component_specs(self):
    return self._value_specs

  def _to_components(self, value):
    # Disallow flattening when executing with multiple in-sync replicas.
    replica_context = ds_context.get_replica_context()
    in_multi_replica_sync = (
        replica_context is not None and
        replica_context.num_replicas_in_sync > 1)
    if in_multi_replica_sync:
      raise ValueError(
          "Flattening a PerReplica to components is not supported in replica "
          "context.")
    return value._values  # pylint: disable=protected-access

  def _from_components(self, tensor_list):
    return PerReplica(tensor_list)


# Note that unlike PerReplica, Mirrored values inherit from
# DistributedDelegate and so can be used directly in cross-replica mode.
# TODO(tomhennigan) Should this extend CompositeTensor?
class Mirrored(DistributedDelegate, ds_types.Mirrored):
  """Holds a map from replica to values which are kept in sync."""

  def _get_cross_replica(self):
    # Mirrored components are kept in sync, so any local component
    # (preferring one on the current device) is a valid cross-replica view.
    return self._get_on_device_or_primary()

  def _as_graph_element(self):
    obj = self._get()
    # Defer to the component's own graph-element conversion when available;
    # otherwise use the component value itself.
    conv_fn = getattr(obj, "_as_graph_element", None)
    if conv_fn and callable(conv_fn):
      return conv_fn()
    return obj


class DistributedVarOp(object):
  """A class that looks like `tf.Operation`.

  Stands in for a component variable's `op` in cross-replica context,
  carrying the attributes (`name`, `graph`, `traceback`, `type`) that
  callers of `var.op` commonly read.
  """

  def __init__(self, name, graph, traceback, typ):
    self.name = name
    self.graph = graph
    self.traceback = traceback
    self.type = typ

  def __eq__(self, o):
    if not isinstance(o, self.__class__):
      # Per the Python data model, an unsupported comparison returns the
      # NotImplemented sentinel (letting `==` fall back to the reflected
      # operation or identity) instead of raising NotImplementedError.
      return NotImplemented
    return (self.name == o.name and self.graph == o.graph and
            self.traceback == o.traceback and self.type == o.type)

  def __hash__(self):
    # Hash over the same attributes __eq__ compares; `traceback` is
    # converted to a tuple since lists are unhashable.
    return hash((self.name, self.graph, tuple(self.traceback), self.type))


# TODO(b/209081027): Remove this once Variable is a CompositeTensor.
class DistributedVariableTraceType(trace.TraceType):
  """TraceType of DistributedVariable objects.

  Two instances compare equal (and trace to the same type) iff their
  variables have the same shape and dtype; the variable itself is retained
  only to serve as the placeholder value.
  """

  def __init__(self, distributed_variable):
    self.distributed_variable = distributed_variable
    # Equality/hashing is based solely on (shape, dtype), not identity.
    self.components = (tuple(distributed_variable.shape.as_list()),
                       distributed_variable.dtype)

  def is_subtype_of(self, other):
    # No subtype structure: a type is only a subtype of an equal type.
    return self == other

  def most_specific_common_supertype(self, others):
    # A common supertype exists only when every other type is equal to this.
    return self if all(self == other for other in others) else None

  def _placeholder_value(self):
    return self.distributed_variable

  def __hash__(self) -> int:
    return hash(self.components)

  def __eq__(self, other) -> bool:
    if not isinstance(other, DistributedVariableTraceType):
      return False

    return self.components == other.components


class DistributedVariable(DistributedDelegate, variables_lib.Variable,
                          core.Tensor):
  """Holds a map from replica to variables."""

  def __init__(self, strategy, values, aggregation, var_policy=None):
    # MEAN aggregation over an integer dtype would silently lose precision,
    # so it is rejected up front.
    if (aggregation == variables_lib.VariableAggregation.MEAN and
        not values[0].dtype.is_floating):
      raise ValueError(
          "creating distributed tf.Variable with aggregation=MEAN and a "
          "non-floating dtype is not supported, please use a different "
          "aggregation or dtype")
    self._distribute_strategy = strategy
    self._aggregation = aggregation
    super(DistributedVariable, self).__init__(values)
    self._common_name = self._primary.name.split(":")[0]
    # Use a weakref to make it easy to map from the contained values
    # to the container without introducing a reference cycle.
    for v in values:
      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access

    # Packed variable is used to reduce the overhead of function execution.
    # For a DistributedVariable, only one variable handle is captured into a
    # function graph. It's only supported in eager mode.
500 if ops.executing_eagerly_outside_functions() and getattr( 501 strategy, "_enable_packed_variable_in_eager_mode", False): 502 name = "%s/packed/" % self._common_name 503 self._packed_var = packed.PackedDistributedVariable(values, name=name) 504 else: 505 self._packed_var = None 506 507 # tf.keras keeps track of variables initialized using this attribute. When 508 # tf.keras gets the default session, it initializes all uninitialized vars. 509 # We need to make _keras_initialized a member of DistributedVariable because 510 # without this it will use `__getattr__` which will delegate to a component 511 # variable. 512 self._keras_initialized = False 513 # Typically, a `DistributedVariable`'s initializer is composed of the 514 # initializers of the components variables. However, in some cases, such as 515 # when restoring from a checkpoint, we may set the _initializer_op 516 # property on the entire `DistributedVariable`. 517 self._initializer_op = None 518 # Set a VariablePolicy which decides how we replicate/aggregate the given 519 # variable. 520 self._policy = var_policy 521 522 def __deepcopy__(self, memo): 523 """Perform a deepcopy of the `DistributedVariable`. 524 525 Unlike the deepcopy of a regular tf.Variable, this keeps the original 526 strategy and devices of the `DistributedVariable`. To avoid confusion 527 with the behavior of deepcopy on a regular `Variable` (which does 528 copy into new devices), we only allow a deepcopy of a `DistributedVariable` 529 within its originating strategy scope. 530 531 Args: 532 memo: The memoization object for `deepcopy`. 533 534 Returns: 535 A deep copy of the current `DistributedVariable`. 536 537 Raises: 538 RuntimeError: If trying to deepcopy into a different strategy. 
539 """ 540 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 541 new_values = [] 542 543 for value in self._values: 544 with ops.device(value.device): 545 new_values.append(copy.deepcopy(value, memo)) 546 547 copied_variable = type(self)( 548 strategy=self._distribute_strategy, 549 values=new_values, 550 aggregation=self._aggregation, 551 var_policy=copy.deepcopy(self._policy, memo)) 552 553 memo[id(self)] = copied_variable 554 555 return copied_variable 556 557 def _use_packed_variable(self): 558 # Don't use packed variable when under a SaveContext to avoid explicit 559 # device placement on variable consuming ops. 560 return self._packed_var is not None and ( 561 not values_util.is_saving_non_distributed()) 562 563 def is_initialized(self, name=None): 564 """Identifies if all the component variables are initialized. 565 566 Args: 567 name: Name of the final `logical_and` op. 568 569 Returns: 570 The op that evaluates to True or False depending on if all the 571 component variables are initialized. 572 """ 573 if values_util.is_saving_non_distributed(): 574 return self._primary.is_initialized() 575 if self._use_packed_variable(): 576 return self._packed_var.is_initialized() 577 result = self._primary.is_initialized() 578 # We iterate through the list of values except the last one to allow us to 579 # name the final `logical_and` op the same name that is passed by the user 580 # to the `is_initialized` op. For distributed variables, the 581 # `is_initialized` op is a `logical_and` op. 
582 for v in self._values[1:-1]: 583 result = math_ops.logical_and(result, v.is_initialized()) 584 result = math_ops.logical_and( 585 result, self._values[-1].is_initialized(), name=name) 586 return result 587 588 @property 589 def initializer(self): 590 if values_util.is_saving_non_distributed(): 591 return self._primary.initializer 592 if self._initializer_op: 593 init_op = self._initializer_op 594 else: 595 # return grouped ops of all the var initializations of component values of 596 # the mirrored variable 597 init_op = control_flow_ops.group( 598 tuple(v.initializer for v in self._values)) 599 return init_op 600 601 def initialized_value(self): 602 return self._get_on_device_or_primary().initialized_value() 603 604 @property 605 def initial_value(self): 606 return self._get_on_device_or_primary().initial_value 607 608 @property 609 def constraint(self): 610 return self._primary.constraint 611 612 @property 613 def graph(self): 614 return self._primary.graph 615 616 @property 617 def _shared_name(self): 618 return self._common_name 619 620 @property 621 def _unique_id(self): 622 return self._primary._unique_id # pylint: disable=protected-access 623 624 @property 625 def _graph_key(self): 626 """Lets Optimizers know which graph this variable is from.""" 627 return self._primary._graph_key # pylint: disable=protected-access 628 629 @property 630 def name(self): 631 return self._primary.name 632 633 @property 634 def dtype(self): 635 return self._primary.dtype 636 637 @property 638 def shape(self): 639 return self._primary.shape 640 641 @property 642 def synchronization(self): 643 return self._primary.synchronization 644 645 @property 646 def aggregation(self): 647 return self._aggregation 648 649 @property 650 def _packed_variable(self): 651 if self._use_packed_variable(): 652 return self._packed_var 653 return None 654 655 @property 656 def handle(self): 657 if values_util.is_saving_non_distributed(): 658 return self._primary.handle 659 replica_id = 
values_util.get_current_replica_id_as_int() 660 if replica_id is None: 661 raise ValueError( 662 "DistributedVariable.handle is not available outside the replica " 663 "context or a `tf.distribute.Strategy.update()` call.") 664 else: 665 if self._use_packed_variable(): 666 return self._packed_var.handle 667 return self._values[replica_id].handle 668 669 def eval(self, session=None): 670 return self._get_on_device_or_primary().eval(session) 671 672 @property 673 def _save_slice_info(self): 674 return self._primary._save_slice_info # pylint: disable=protected-access 675 676 def _get_save_slice_info(self): 677 return self._primary._get_save_slice_info() # pylint: disable=protected-access 678 679 def _set_save_slice_info(self, save_slice_info): 680 for v in self._values: 681 v._set_save_slice_info(save_slice_info) # pylint: disable=protected-access 682 683 @property 684 def device(self): 685 return self._get_on_device_or_primary().device 686 687 @property 688 def trainable(self): 689 return self._primary.trainable 690 691 @property 692 def distribute_strategy(self): 693 return self._distribute_strategy 694 695 def get_shape(self): 696 return self._primary.get_shape() 697 698 def to_proto(self, export_scope=None): 699 return self._primary.to_proto(export_scope=export_scope) 700 701 @property 702 def op(self): 703 if values_util.is_saving_non_distributed(): 704 return self._primary.op 705 # We want cross-replica code that does some var.op.X calls 706 # to work (even if the current device isn't in self._devices), but 707 # other uses of var.op in a cross-replica context to fail. 
708 if ds_context.in_cross_replica_context(): 709 return DistributedVarOp(self._primary.op.name, self._primary.op.graph, 710 self._primary.op.traceback, self._primary.op.type) 711 return self._get().op 712 713 @property 714 def _in_graph_mode(self): 715 return self._primary._in_graph_mode # pylint: disable=protected-access 716 717 def _get_replica(self, replica_id): 718 """Returns the value on a device with the given replica_id.""" 719 if self._use_packed_variable(): 720 return self._packed_var.on_device(self._devices[replica_id]) 721 return self._values[replica_id] 722 723 def _get(self): 724 """Returns the value for the current device or raises a ValueError.""" 725 if values_util.is_saving_non_distributed(): 726 return self._primary 727 replica_id = values_util.get_current_replica_id_as_int() 728 if replica_id is None: 729 return self._get_cross_replica() 730 else: 731 return self._get_replica(replica_id) 732 733 def _get_on_device_or_primary(self): 734 """Returns value in same replica or device if possible, else the _primary.""" 735 if values_util.is_saving_non_distributed(): 736 return self._primary 737 replica_id = values_util.get_current_replica_id_as_int() 738 if replica_id is None: 739 # Try to find a value on the current device. 
740 current_device = device_util.canonicalize(device_util.current()) 741 for i, value in enumerate(self._values): 742 if device_util.canonicalize(value.device) == current_device: 743 return self._get_replica(i) 744 return self._get_replica(0) 745 else: 746 return self._get_replica(replica_id) 747 748 def read_value(self): 749 if values_util.is_saving_non_distributed(): 750 return self._primary.read_value() 751 with ds_context.enter_or_assert_strategy(self._distribute_strategy): 752 return array_ops.identity(self._get()) 753 754 def value(self): 755 if values_util.is_saving_non_distributed(): 756 return self._primary.value() 757 if self._policy: 758 return self._policy.value(self) 759 return self._get_on_device_or_primary().value() 760 761 def numpy(self): 762 if context.executing_eagerly(): 763 return self.read_value().numpy() 764 else: 765 raise NotImplementedError("DistributedVariable.numpy() is only available " 766 "when eager execution is enabled.") 767 768 def assign_sub(self, value, use_locking=False, name=None, read_value=True): 769 if values_util.is_saving_non_distributed(): 770 return self._primary.assign_sub(value, use_locking, name, read_value) 771 if self._policy: 772 return self._policy.assign_sub( 773 self, 774 value, 775 use_locking=use_locking, 776 name=name, 777 read_value=read_value) 778 return values_util.on_write_assign_sub( 779 self, value, use_locking=use_locking, name=name, read_value=read_value) 780 781 def assign_add(self, value, use_locking=False, name=None, read_value=True): 782 if values_util.is_saving_non_distributed(): 783 return self._primary.assign_add(value, use_locking, name, read_value) 784 if self._policy: 785 return self._policy.assign_add( 786 self, 787 value, 788 use_locking=use_locking, 789 name=name, 790 read_value=read_value) 791 return values_util.on_write_assign_add( 792 self, value, use_locking=use_locking, name=name, read_value=read_value) 793 794 def assign(self, value, use_locking=False, name=None, read_value=True): 
    # NOTE(review): tail of `assign` — the `def` line is above this chunk.
    # Dispatch order (shared by assign/scatter methods below): when saving a
    # non-distributed copy act on the primary component only; otherwise defer
    # to the variable policy if one is set; otherwise fall back to the
    # default ON_WRITE helper in values_util.
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    if self._policy:
      return self._policy.assign(
          self,
          value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    return values_util.on_write_assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_sub`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_sub(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_add`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_add(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_mul`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_mul(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_div`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_div(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_min`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_min(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_max`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_max(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """See `tf.Variable.scatter_update`."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    if self._policy:
      return self._policy.scatter_update(
          self, sparse_delta, use_locking=use_locking, name=name)
    return values_util.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)

  def __tf_tracing_type__(self, _):
    # TraceType protocol hook: lets tf.function tracing treat structurally
    # equivalent DistributedVariables as the same argument type.
    return DistributedVariableTraceType(self)

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    DistributedVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _DistributedVariableSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _as_graph_element(self):
    # Graph-mode hook; only meaningful with a policy (or when saving a
    # non-distributed copy, in which case the primary component answers).
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    if self._policy:
      return self._policy._as_graph_element(self)  # pylint: disable=protected-access

    raise NotImplementedError(
        "DistributedVariable._as_graph_element requires a valid "
        "VariablePolicy. Please set the policy via the `var_policy` argument "
        "in the constructor, or override this method in sub-classes which "
        "support cross-replica accesses.")

  def _get_cross_replica(self):
    # Returns the value to use in cross-replica context; delegated to the
    # policy since the semantics depend on synchronization/aggregation.
    if values_util.is_saving_non_distributed():
      return self._primary
    if self._policy:
      return self._policy._get_cross_replica(self)  # pylint: disable=protected-access

    raise NotImplementedError(
        "DistributedVariable._get_cross_replica requires a valid "
        "VariablePolicy. Please set the policy via the `var_policy` argument "
        "in the constructor, or override this method in sub-classes which "
        "support cross-replica accesses.")

  def _update_cross_replica(self, update_fn, value, **kwargs):
    """Applies updates across replicas.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    # Cross-replica writes make the saved value ill-defined, so the variable
    # is flagged unsaveable before the update is dispatched.
    values_util.mark_as_unsaveable()
    return self.distribute_strategy.extended.update(
        self, update_fn, args=(value,), kwargs=kwargs, group=True)

  def _update_replica(self, update_fn, value, **kwargs):
    """Applies updates in one replica.

    Args:
      update_fn: A callable to update the variable. It should have the same
        signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: remaining arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.
    """
    if self._policy:
      return self._policy._update_replica(self, update_fn, value, **kwargs)  # pylint: disable=protected-access
    raise NotImplementedError(
        "DistributedVariable._update_replica requires a valid VariablePolicy. "
        "Please set the policy via the `var_policy` argument in the "
        "constructor, or override this method in sub-classes which support "
        "cross-replica accesses.")

  def _update(self, update_fn, value, **kwargs):
    """Applies updates depending on the context.

    The method calls `_update_replica` in replica context,
    `_update_cross_replica` in cross replica context, and `update_fn` in update
    context.

    If `read_value` is True, the method returns the updated Variable. If
    `read_value` is False, the method returns the update `tf.Operation`.

    Args:
      update_fn: A callable to pass to `strategy.extended.update` to update the
        variable. It should have the same signature as `Variable.assign()`.
      value: value to be passed to `update_fn`.
      **kwargs: keyword arguments to `update_fn`.

    Returns:
      Updated variable or `tf.Operation`.

    """
    if values_util.is_saving_non_distributed():
      return update_fn(self._primary, value, **kwargs)
    with ds_context.enter_or_assert_strategy(self.distribute_strategy):
      if ds_context.in_cross_replica_context():
        update_replica_id = distribute_lib.get_update_replica_id()
        if update_replica_id is not None:
          # Inside `strategy.extended.update`: apply directly to the replica
          # component being updated.
          replica_value = self._get_replica(update_replica_id)
          return update_fn(replica_value, value, **kwargs)
        return self._update_cross_replica(update_fn, value, **kwargs)
      else:
        values_util.assert_replica_context(self.distribute_strategy)
        return self._update_replica(update_fn, value, **kwargs)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check."""
    pass

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)

  def _map_resources(self, save_options):
    """For implementing `Trackable`."""
    # Initialize for self._primary first, so that obj_map[self._primary] and
    # resource_map[self._primary.handle] contain mapped values.
    obj_map, resource_map = self._primary._map_resources(save_options)  # pylint:disable=protected-access
    for v in [v for v in self._values if v != self._primary]:

      if (save_options.experimental_variable_policy  # pylint:disable=protected-access
          ._expand_distributed_variables()):
        # Expanded saving: each component variable gets its own mapping.
        v_obj_map, v_resource_map = v._map_resources(save_options)  # pylint:disable=protected-access
        obj_map.update(v_obj_map)
        resource_map.update(v_resource_map)
      else:
        # Collapsed saving: all components alias the primary's mapping.
        obj_map[v] = obj_map[self._primary]
        resource_map[v.handle] = resource_map[self._primary.handle]
    obj_map[self] = obj_map[self._primary]
    resource_map[self] = resource_map[self._primary.handle]
    if self._packed_var is not None:
      resource_map[self._packed_var.packed_handle] = resource_map[
          self._primary.handle]
    return obj_map, resource_map

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    resource_variable_ops.write_object_proto_for_resource_variable(
        self, proto, options)
    # Only mirrored-style policies record per-component information.
    if self._policy:
      if self._policy._is_mirrored():  # pylint: disable=protected-access
        self._policy._write_object_proto(self, proto, options)  # pylint: disable=protected-access

  @property
  def is_distributed_variable(self):
    # Duck-typing flag checked by distribution-aware code paths.
    return True

  def __tf_experimental_restore_capture__(
      self, concrete_function, internal_capture):
    # Re-binds this variable as a capture of a restored concrete function.
    concrete_function.graph.capture_distributed_variable(self, internal_capture)
    return self


# We extend from `saveable_object.SaveableObject` instead of
# `saveable_object_util.ResourceVariableSaveable` since we need to read the
# value of ON_READ variables when saving. `SaveableObject` provides a way to
# specify the function to run to get the value of the variable or tensor at
# saving time. We can use this for both ON_READ and ON_WRITE variables.
# TODO(b/164586507): Consolidate ON_WRITE and ON_READ saving/restoring logic
# if possible.
class _DistributedVariableSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a DistributedVariable."""

  def __init__(self, distributed_variable, primary_variable, name):
    self._distributed_variable = distributed_variable
    # The policy supplies both the tensor to save and the restore ops, so a
    # policy-less variable cannot be wrapped by this saveable.
    if not self._distributed_variable._policy:
      raise ValueError(
          "The VariablePolicy of the argument `distributed_variable` must be "
          "set to create a _DistributedVariableSaveable. Please set it via "
          "the `var_policy` argument in the constructor of DistributedVariable."
      )
    tensor, spec = distributed_variable._policy.get_saveable(
        distributed_variable, primary_variable, name)
    super(_DistributedVariableSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return self._distributed_variable._policy.get_restore_ops(  # pylint: disable=protected-access
        self._distributed_variable, tensor)


class _MirroredSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a MirroredVariable."""

  def __init__(self, mirrored_variable, primary_variable, name):
    self._mirrored_variable = mirrored_variable
    tensor, spec = values_util.get_on_write_saveable(self._mirrored_variable,
                                                     primary_variable, name)
    super(_MirroredSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_write_restore_ops(self._mirrored_variable, tensor)


class MirroredVariable(DistributedVariable, Mirrored):
  """Holds a map from replica to variables whose values are kept in sync."""

  def _update_replica(self, update_fn, value, **kwargs):
    return _on_write_update_replica(self, update_fn, value, **kwargs)

  def scatter_min(self, *args, **kwargs):
    """See `tf.Variable.scatter_min`; only NONE/ONLY_FIRST_REPLICA aggregation."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_min", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_min(*args, **kwargs)

  def scatter_max(self, *args, **kwargs):
    """See `tf.Variable.scatter_max`; only NONE/ONLY_FIRST_REPLICA aggregation."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_max", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_max(*args, **kwargs)

  def scatter_update(self, *args, **kwargs):
    """See `tf.Variable.scatter_update`; only NONE/ONLY_FIRST_REPLICA aggregation."""
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_update", aggregation=self._aggregation))
    return super(MirroredVariable, self).scatter_update(*args, **kwargs)

  def _get_cross_replica(self):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(Mirrored._get_cross_replica(self))

  def _as_graph_element(self):
    return self._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    MirroredVariables.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _write_object_proto(self, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    super(MirroredVariable, self)._write_object_proto(proto, options)
    values_util.write_object_proto(self, proto, options)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # TODO(b/154017756): Make _dense_var_to_tensor consistent between ON_READ
    # and ON_WRITE.
    # Try to avoid assignments to and other mutations of MirroredVariable
    # state except through a DistributionStrategy.extended.update() or any of
    # the `assign*` and `scatter*` calls.
    if as_ref:
      # A TF 1.x case where the variable is a boolean variable and used like:
      # tf.cond(v, true_fn, false_fn).
      raise ValueError(
          "You may be using variable created under distribute strategy in TF "
          "1.x control flows. Try explicitly converting the variable to Tensor "
          "using variable.read_value(), or switch to TF 2.x.")
    return ops.convert_to_tensor(
        self._get(), dtype=dtype, name=name, as_ref=as_ref)


class _SyncOnReadSaveable(saveable_object.SaveableObject):
  """Class for defining how to restore a SyncOnReadVariable."""

  def __init__(self, sync_on_read_variable, name):
    self._sync_on_read_variable = sync_on_read_variable
    tensor, spec = values_util.get_on_read_saveable(
        sync_on_read_variable, sync_on_read_variable._primary, name)

    super(_SyncOnReadSaveable, self).__init__(tensor, spec, name)

  def restore(self, restored_tensors, restored_shapes):
    """Restore the same value into all variables."""
    tensor, = restored_tensors
    return values_util.get_on_read_restore_ops(
        self._sync_on_read_variable, tensor,
        self._sync_on_read_variable.aggregation)


class SyncOnReadVariable(DistributedVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def _update_replica(self, update_fn, value, **kwargs):
    # ON_READ updates apply locally; aggregation happens at read time.
    return update_fn(self._get_on_device_or_primary(), value, **kwargs)

  def _get(self):
    """Returns the value of SyncOnReadVariable based on surrounding context.

    If called under a non-default replica-context, returns the corresponding
    variable on that replica.
    If called under default replica-context or cross-replica context, returns
    the synced value.
    """
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return super(SyncOnReadVariable, self)._get()

  # TODO(b/154017756): Make assign behavior in cross replica context consistent
  # with MirroredVariable.
  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    """See `tf.Variable.assign_sub`; cross-replica assigns mark unsaveable."""
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_sub(value, use_locking, name, read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    """See `tf.Variable.assign_add`; cross-replica assigns mark unsaveable."""
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable,
                     self).assign_add(value, use_locking, name, read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    """See `tf.Variable.assign`; cross-replica assigns mark unsaveable."""
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            self, value, read_value=read_value)
      else:
        return super(SyncOnReadVariable, self).assign(value, use_locking, name,
                                                      read_value)

  def _scatter_not_implemented(self, method):
    # Shared raiser for all scatter_* methods below: scatter ops are not
    # defined for ON_READ variables.
    raise NotImplementedError(
        f"Variables with `synchronization=ON_READ` doesn't support `{method}`")

  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    self._scatter_not_implemented("scatter_update")

  def value(self):
    """Returns the (possibly aggregated) value; see `_get` for the contexts."""
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "call `variable.value()` inside variable_sync_on_read_context is not "
          "supported")
    if values_util.is_saving_non_distributed():
      return self._primary.value()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return self._get_replica(0).value()
        return self._get_cross_replica()
      else:
        # _get_on_device_or_primary() returns a Variable.
        return self._get_on_device_or_primary().value()

  def read_value(self):
    if ds_context.in_variable_sync_on_read_context():
      raise NotImplementedError(
          "call `variable.read_value()` inside variable_sync_on_read_context is"
          " not supported")
    return super().read_value()

  def _get_cross_replica(self):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      # Consider returning a tensor value here to make the return value of
      # _get_cross_replica consistent.
      return self._get_replica(0)
    if self._aggregation == vs.VariableAggregation.SUM:
      # A SUM-aggregated read is not a stable checkpoint value.
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      return self._distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          self,
          axis=None)

  def _as_graph_element(self):
    if values_util.is_saving_non_distributed():
      return self._primary._as_graph_element()  # pylint: disable=protected-access
    # pylint: disable=protected-access
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(self._get_cross_replica())
    return self._get()._as_graph_element()

  def _gather_saveables_for_checkpoint(self):
    """Overrides Trackable method.

    This allows both name-based and object-based save and restore of
    `SyncOnReadVariable`s.

    Returns:
      A dictionary mapping attribute names to `SaveableObject` factories.
    """

    def _saveable_factory(name=self._common_name):
      return _SyncOnReadSaveable(self, name)

    return {trackable.VARIABLE_VALUE_KEY: _saveable_factory}

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a SyncOnReadVariable to a tensor."""
    if values_util.is_saving_non_distributed():
      return ops.convert_to_tensor(
          self._primary, dtype=dtype, name=name, as_ref=as_ref)
    with ds_context.enter_or_assert_strategy(self._distribute_strategy):
      replica_context = ds_context.get_replica_context()
      if (replica_context is not None and
          ds_context.in_variable_sync_on_read_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return ops.convert_to_tensor(
              self._get_replica(0), dtype=dtype, name=name, as_ref=as_ref)
        if self._aggregation == vs.VariableAggregation.SUM:
          values_util.mark_as_unsaveable()
        # pylint: disable=protected-access
        reduced = (
            replica_context.strategy.extended._replica_ctx_all_reduce(
                reduce_util.ReduceOp.from_variable_aggregation(
                    self._aggregation),
                self._get().read_value()))
        return ops.convert_to_tensor(
            reduced, dtype=dtype, name=name, as_ref=as_ref)

      return ops.convert_to_tensor(
          self._get(), dtype=dtype, name=name, as_ref=as_ref)


# Register conversion functions which read the value of the variable,
# allowing instances of the class to be used as tensors.
# DistributedVariable
def _tensor_conversion_distributed_var(var,
                                       dtype=None,
                                       name=None,
                                       as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(DistributedVariable,
                                        _tensor_conversion_distributed_var)


# MirroredVariables
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(MirroredVariable,
                                        _tensor_conversion_mirrored)


# Mirrored Values
def _tensor_conversion_mirrored_val(value, dtype=None, name=None, as_ref=False):
  return ops.convert_to_tensor(
      value._get(), dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(Mirrored,
                                        _tensor_conversion_mirrored_val)


# SyncOnReadVariables
def _tensor_conversion_sync_on_read(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


ops.register_tensor_conversion_function(SyncOnReadVariable,
                                        _tensor_conversion_sync_on_read)


class VariablePolicy(object):
  """Policy defining synchronization and aggregation of a distributed variable.

  Given `synchronization` and `aggregation` parameters set on a `tf.Variable`
  during variable creation within `tf.distribute` scope, `tf.distribute` creates
  an appropriate policy object and assigns it to the distributed variable. All
  variable operations are delegated to the respective policy object.
  """

  def __init__(self, aggregation):
    # The tf.VariableAggregation value governing how writes/reads combine.
    self._aggregation = aggregation

  def value(self):
    raise NotImplementedError(
        "VariablePolicy.value should be overriden by sub-classes.")

  def _is_mirrored(self):
    raise NotImplementedError(
        "VariablePolicy._is_mirrored should be overriden by sub-classes.")

  def _as_graph_element(self, _):
    raise NotImplementedError(
        "VariablePolicy._as_graph_element should be overriden by sub-classes.")

  def _get_cross_replica(self, var):
    raise NotImplementedError(
        "VariablePolicy._get_cross_replica should be overriden by sub-classes.")

  def _update_replica(self, var, update_fn, value, **kwargs):
    raise NotImplementedError(
        "VariablePolicy._update_replica should be overriden by sub-classes.")


class OnReadPolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA` when creating a `tf.Variable` in
  `tf.distribute` scope.
  """

  def _is_mirrored(self):
    return False

  def value(self, var):
    """Returns the variable's value, aggregating in cross-replica context."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
          return var._get_replica(0).value()  # pylint: disable=protected-access
        return var._get_cross_replica()  # pylint: disable=protected-access
      else:
        return var._get_on_device_or_primary().value()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if ds_context.in_cross_replica_context():
        return ops.convert_to_tensor(var._get_cross_replica())  # pylint: disable=protected-access
    return var._get()._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    if self._aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
      return var._get_replica(0)  # pylint: disable=protected-access
    if self._aggregation == vs.VariableAggregation.SUM:
      # A SUM-aggregated read is not a stable checkpoint value.
      values_util.mark_as_unsaveable()
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      return var.distribute_strategy.reduce(
          reduce_util.ReduceOp.from_variable_aggregation(self._aggregation),
          var,
          axis=None)

  def _update_replica(self, var, update_fn, value, **kwargs):
    # ON_READ updates apply locally; aggregation happens at read time.
    return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access

  def _scatter_not_implemented(self, method):
    # Shared raiser for the scatter_* stubs below.
    raise NotImplementedError(f"ON_READ variables doesn't support `{method}` "
                              "in cross replica context")

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Subtracts a value from this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        # Cross-replica writes to an ON_READ variable make the saved value
        # ill-defined, so flag the variable unsaveable first.
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_sub_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_sub(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    """Adds a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_add_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign_add(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    """Assigns a value to this variable."""
    with ds_context.enter_or_assert_strategy(var.distribute_strategy):
      if (ds_context.in_cross_replica_context() and
          not values_util.in_replica_update_context()):
        values_util.mark_as_unsaveable()
        return values_util.on_read_assign_cross_replica(
            var, value, read_value=read_value)
      else:
        return values_util.on_write_assign(
            var,
            value,
            use_locking=use_locking,
            name=name,
            read_value=read_value)

  def scatter_sub(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_sub")

  def scatter_add(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_add")

  def scatter_mul(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_mul")

  def scatter_div(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_div")

  def scatter_min(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_min")

  def scatter_max(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_max")

  def scatter_update(self, *args, **kwargs):
    del args, kwargs
    self._scatter_not_implemented("scatter_update")

  def get_saveable(self, var, primary_var, name):
    """Create a saveable object for the given variable."""
    return values_util.get_on_read_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Restore the same value into all variables."""
    return values_util.get_on_read_restore_ops(var, tensor, self._aggregation)


class OnWritePolicy(VariablePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when the following `synchronization` and `aggregation`
  parameters are specified when creating a `tf.Variable` in `tf.distribute`
  scope and `synchronization` is equal to `tf.VariableSynchronization.ON_WRITE`
  or `tf.VariableSynchronization.AUTO`.
  """

  def _is_mirrored(self):
    return True

  def value(self, var):
    return var._get_on_device_or_primary().value()  # pylint: disable=protected-access

  def _as_graph_element(self, var):
    return var._get_on_device_or_primary()._as_graph_element()  # pylint: disable=protected-access

  def _get_cross_replica(self, var):
    # Return identity, to avoid directly exposing the variable to the user and
    # allowing it to be modified by mistake.
    return array_ops.identity(var._get_on_device_or_primary())  # pylint: disable=protected-access

  def _update_replica(self, var, update_fn, value, **kwargs):
    if var.aggregation == variables_lib.VariableAggregation.NONE:
      return update_fn(var._get_on_device_or_primary(), value, **kwargs)  # pylint: disable=protected-access
    return _on_write_update_replica(var, update_fn, value, **kwargs)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    return values_util.on_write_assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    return values_util.on_write_assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_sub(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_add(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_mul(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return values_util.scatter_div(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    # min/max/update are only well-defined when replicas do not need
    # aggregation of conflicting writes.
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_min", aggregation=self._aggregation))
    return values_util.scatter_min(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_max", aggregation=self._aggregation))
    return values_util.scatter_max(
        var, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    if (self._aggregation != vs.VariableAggregation.ONLY_FIRST_REPLICA and
        self._aggregation != vs.VariableAggregation.NONE):
      raise NotImplementedError(
          values_util.scatter_error_msg.format(
              op_name="scatter_update", aggregation=self._aggregation))
    return values_util.scatter_update(
        var, sparse_delta, use_locking=use_locking, name=name)

  def get_saveable(self, var, primary_var, name):
    """Saveable ops for AUTO variables."""
    return values_util.get_on_write_saveable(var, primary_var, name)

  def get_restore_ops(self, var, tensor):
    """Restore the same value into all variables."""
    return values_util.get_on_write_restore_ops(var, tensor)

  def _write_object_proto(self, var, proto, options):
    """Update a SavedObject proto for the caller.

    If a DistributedVariable object supports this method, it will be called when
    saving with a pre-built `SavedObject` proto representing the object, plus an
    instance of `SaveOptions`. This method is then free to modify that proto
    instance.

    `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
    write out information about their components to the
    `experimental_distributed_variable_components` field of a
    `SavedVariable` (depending on the `SaveOptions` variable policy).

    Args:
      var: A DistributedVariable object.
      proto: A pre-built `SavedObject` proto for this object. It is assumed this
        will be a `SavedVariable` instance.
      options: A `SaveOptions` instance.
    """
    values_util.write_object_proto(var, proto, options)


class PerWorkerResource():
  """A per-worker CapturableResource class for non-ParameterServer strategy.

  Resources that populate `host_to_resources` should be instances of classes
  subclassing CapturableResource, although currently it's only used and tested
  for StaticHashTable with TPUStrategy.
  """

  def __init__(self, strategy, host_to_resources):
    distribute_lib.distribution_strategy_input_api_counter.get_cell(
        "PerWorkerResource", "TPUDistributedLookupTable").increase_by(1)
    self._strategy = strategy
    self._host_to_resources = host_to_resources

  def __getattribute__(self, name):
    # Forward everything except this class's own plumbing to the resource on
    # the local worker, so the wrapper is transparent to callers.
    if name not in ("__init__", "__getattribute__", "_host_to_resources",
                    "_strategy", "local_resource"):
      return getattr(self.local_resource(), name)
    return super(PerWorkerResource, self).__getattribute__(name)

  def __setattr__(self, name, value):
    # Mirror __getattribute__: only the wrapper's own fields live on self.
    if name not in ("_strategy", "_host_to_resources"):
      return setattr(self.local_resource(), name, value)
    return super(PerWorkerResource, self).__setattr__(name, value)

  def local_resource(self):
    """Returns the resource on the local worker."""
    current_device = device_util.canonicalize(device_util.current())
    host_device = device_util.canonicalize(
        device_util.get_host_for_device(current_device))
    return self._host_to_resources.get(
        host_device,
self._host_to_resources[next(iter(self._host_to_resources))]) 1766