# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing TPU distributed values.

Note that the tests are in values_test.py.
"""

from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


_scatter_error_msg = ("{op_name} is only supported for distributed "
                      "variables (variables created within a "
                      "`tf.distribute.Strategy` scope) with NONE "
                      "aggregation, got: {aggregation}.")


class TPUVariableMixin(object):
  """Mixin for TPU variables."""

  def __init__(self, *args, **kwargs):
    super(TPUVariableMixin, self).__init__(*args, **kwargs)

    # Handle ID is needed for `get_replicated_var_handle` to cache the
    # variables correctly since in eager mode different variables can have
    # the same name.
    if ops.executing_eagerly_outside_functions():
      self._handle_id = self._common_name + "_" + str(id(self._primary))
    else:
      self._handle_id = self._common_name

  def __getattr__(self, name):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).__getattr__(name)
    else:
      raise AttributeError(
          f"`TPUVariableMixin.{name}` not accessible within a TPU context.")

  def get(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).get()
    else:
      raise NotImplementedError(
          "`TPUVariableMixin.get()` is not supported within a TPU context.")

  def _get_as_operand(self):
    return self.read_value()

  def _is_mirrored(self):
    raise NotImplementedError(
        "`TPUVariableMixin._is_mirrored()` must be implemented by subclasses.")

  @property
  def handle(self):
    """The handle by which this variable can be accessed."""
    # If we're in a tpu.rewrite(), return the replicated handle.
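    # When the per-replica values have been packed into a single packed
    # variable, its handle is used instead of the individual per-replica
    # handles (see the `is_packed` branch below).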
    tpu_context = tpu_util.enclosing_tpu_context()
    if tpu_context is None or context.executing_eagerly():
      var = self._get_on_device_or_primary()
      if isinstance(var, packed.PackedVarAndDevice):
        return var.on_device_handle()
      else:
        return var.handle
    else:
      is_packed = self._packed_var is not None
      val = self._values
      if is_packed:
        val = [self._packed_var]

      return tpu_context.get_replicated_var_handle(self._common_name,
                                                   self._handle_id, val,
                                                   self._is_mirrored(),
                                                   is_packed)

  @property
  def device(self):
    return self.handle.device

  def _read_variable_op(self):
    """Reads the value of this variable."""
    if self.trainable:
      tape.variable_accessed(self)

    handle = self.handle
    if getattr(handle, "is_packed", False):
      # Add a device scope for a packed variable handle.
      with ops.device(self._get_on_device_or_primary().device):
        return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
    else:
      return gen_resource_variable_ops.read_variable_op(handle, self.dtype)

  def read_value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).read_value()
    else:
      return self._read_variable_op()

  def value(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self).value()
    else:
      return self._read_variable_op()

  def _as_graph_element(self):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
    else:
      return None

  @property
  def op(self):
    if values_util.is_saving_non_distributed():
      return self._primary.op
    return values.DistributedVarOp(self._primary.op.name,
                                   self._primary.op.graph,
                                   self._primary.op.traceback,
                                   self._primary.op.type)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUVariableMixin, self)._dense_var_to_tensor(
          dtype=dtype, name=name, as_ref=as_ref)
    # pylint: enable=protected-access
    elif dtype is not None and dtype != self.dtype:
      return math_ops.cast(self.read_value(), dtype)
    else:
      return self.handle if as_ref else self.read_value()


class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
  """DistributedVariable subclass for TPUStrategy."""

  def _is_mirrored(self):
    return self._policy._is_mirrored()  # pylint: disable=protected-access

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_sub(value, use_locking, name, read_value)
    return self._policy.assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign_add(value, use_locking, name, read_value)
    return self._policy.assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if values_util.is_saving_non_distributed():
      return self._primary.assign(value, use_locking, name, read_value)
    return self._policy.assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(sparse_delta, use_locking, name)
    return self._policy.scatter_sub(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(sparse_delta, use_locking, name)
    return self._policy.scatter_add(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(sparse_delta, use_locking, name)
    return self._policy.scatter_mul(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(sparse_delta, use_locking, name)
    return self._policy.scatter_div(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(sparse_delta, use_locking, name)
    return self._policy.scatter_min(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(sparse_delta, use_locking, name)
    return self._policy.scatter_max(
        self, sparse_delta, use_locking=use_locking, name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(sparse_delta, use_locking, name)
    return self._policy.scatter_update(
        self, sparse_delta, use_locking=use_locking, name=name)


class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
  """Holds a map from replica to TPU variables whose values are kept in sync."""

  def _is_replicated_or_sharded_to_logical_cores(self):
    """Returns whether each of the underlying variables is replicated or sharded to logical cores.

    If True, the handles of the underlying variables are not available
    outside a TPU context.
    """
    return isinstance(self._primary,
                      tpu_replicated_variable.TPUReplicatedVariable)

  @property
  def device(self):
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_util.enclosing_tpu_context() is None):
      return self._primary.device
    return super(TPUMirroredVariable, self).device

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_sub_fn = lambda v, *a, **ka: v.assign_sub(*a, **ka)
      return self._update(
          update_fn=assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_add_fn = lambda v, *a, **ka: v.assign_add(*a, **ka)
      return self._update(
          update_fn=assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, value, use_locking=False, name=None, read_value=True):
    tpu_context = tpu_util.enclosing_tpu_context()
    if (self._is_replicated_or_sharded_to_logical_cores() and
        tpu_context is None):
      assign_fn = lambda v, *a, **ka: v.assign(*a, **ka)
      return self._update(
          update_fn=assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)

    if (tpu_context and
        self.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        self, value, use_locking=use_locking, name=name, read_value=read_value)

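  # Scatter ops on TPU mirrored variables are only supported when saving a
  # non-distributed checkpoint, where they fall through to the primary
  # variable; otherwise they raise NotImplementedError.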
  def scatter_sub(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_sub(*args, **kwargs)
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_add(*args, **kwargs)
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_max(*args, **kwargs)
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_min(*args, **kwargs)
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_mul(*args, **kwargs)
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_div(*args, **kwargs)
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    if values_util.is_saving_non_distributed():
      return self._primary.scatter_update(*args, **kwargs)
    raise NotImplementedError

  def _is_mirrored(self):
    return True


class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
  """Holds a map from replica to variables whose values are reduced on save."""

  def assign_sub(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(self, *args,
                                                            **kwargs)

  def assign_add(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(self, *args,
                                                            **kwargs)

  def assign(self, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return values.SyncOnReadVariable.assign(self, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(self, *args, **kwargs)

  def _is_mirrored(self):
    return False


# Common methods shared by OnWrite and Mirrored variables.
def assign_sub(var, value, use_locking=False, name=None, read_value=True):
  assign_sub_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_sub_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_add(var, value, use_locking=False, name=None, read_value=True):
  assign_add_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_add_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign(var, value, use_locking=False, name=None, read_value=True):
  assign_fn = tpu_util.make_raw_assign_fn(
      gen_resource_variable_ops.assign_variable_op)
  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


class TPUOnWritePolicy(values.OnWritePolicy):
  """Policy defined for `tf.VariableSynchronization.ON_WRITE` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.AUTO` or `tf.VariableSynchronization.ON_WRITE`.
  """

  def assign_sub(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_sub(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign_add(self,
                 var,
                 value,
                 use_locking=False,
                 name=None,
                 read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign_add(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def assign(self, var, value, use_locking=False, name=None, read_value=True):
    if (tpu_util.enclosing_tpu_context() and
        var.aggregation == variable_scope.VariableAggregation.NONE):
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              var,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)
    return assign(
        var, value, use_locking=use_locking, name=name, read_value=read_value)

  def _scatter_xxx(self,
                   raw_scatter_xxx_fn,
                   op_name,
                   var,
                   sparse_delta,
                   use_locking=False,
                   name=None):
    scatter_xxx_fn = tpu_util.make_raw_scatter_xxx_fn(raw_scatter_xxx_fn)
    if tpu_util.enclosing_tpu_context():
      if self._aggregation != variable_scope.VariableAggregation.NONE:
        raise NotImplementedError(
            _scatter_error_msg.format(
                op_name=op_name, aggregation=self._aggregation))
      return scatter_xxx_fn(
          var, sparse_delta=sparse_delta, use_locking=use_locking, name=name)
    else:
      return var._update(  # pylint: disable=protected-access
          update_fn=scatter_xxx_fn,
          value=sparse_delta,
          use_locking=use_locking,
          name=name)

  def scatter_sub(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_sub,
                             "scatter_sub", var, sparse_delta, use_locking,
                             name)

  def scatter_add(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_add,
                             "scatter_add", var, sparse_delta, use_locking,
                             name)

  def scatter_max(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_max,
                             "scatter_max", var, sparse_delta, use_locking,
                             name)

  def scatter_min(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_min,
                             "scatter_min", var, sparse_delta, use_locking,
                             name)

  def scatter_mul(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_mul,
                             "scatter_mul", var, sparse_delta, use_locking,
                             name)

  def scatter_div(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_div,
                             "scatter_div", var, sparse_delta, use_locking,
                             name)

  def scatter_update(self, var, sparse_delta, use_locking=False, name=None):
    return self._scatter_xxx(gen_resource_variable_ops.resource_scatter_update,
                             "scatter_update", var, sparse_delta, use_locking,
                             name)

  def _is_mirrored(self):
    return True


class TPUOnReadPolicy(values.OnReadPolicy):
  """Policy defined for `tf.VariableSynchronization.ON_READ` synchronization.

  This policy is created when `synchronization` is set to
  `tf.VariableSynchronization.ON_READ` and `aggregation` is set to any of the
  values allowed by the `tf.VariableAggregation` enum, such as `NONE`, `SUM`,
  `MEAN` or `ONLY_FIRST_REPLICA`, when creating a `tf.Variable` in a
  `tf.distribute` scope.
  """

  def assign_sub(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(var, *args,
                                                            **kwargs)

  def assign_add(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(var, *args,
                                                            **kwargs)

  def assign(self, var, *args, **kwargs):
    if tpu_util.enclosing_tpu_context() is None:
      return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(var, *args, **kwargs)

  def _is_mirrored(self):
    return False

  def scatter_sub(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_add(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_max(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_min(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_mul(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_div(self, *args, **kwargs):
    raise NotImplementedError

  def scatter_update(self, *args, **kwargs):
    raise NotImplementedError
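

# Usage sketch (illustration only, not part of this module's API): variables
# created under a `tf.distribute.TPUStrategy` scope are wrapped by the classes
# above. Per the docstrings, `synchronization=ON_WRITE` (or `AUTO`) selects
# `TPUOnWritePolicy`, while `synchronization=ON_READ` selects
# `TPUOnReadPolicy`. The function below is hypothetical, is never called here,
# and assumes a TPU reachable through `TPUClusterResolver`.
def _example_tpu_variable_policies():
  """Hypothetical usage sketch; assumes a reachable TPU."""
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  with strategy.scope():
    # ON_WRITE (what AUTO resolves to here): replicas are kept in sync on
    # every assignment; backed by `TPUOnWritePolicy`.
    weights = tf.Variable(
        1.0, synchronization=tf.VariableSynchronization.ON_WRITE)
    # ON_READ: each replica updates its own copy, and the copies are
    # aggregated when read outside a replica context (e.g. on save);
    # backed by `TPUOnReadPolicy`.
    total = tf.Variable(
        0.0,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)
  return weights, total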