# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Standard functions for creating slots.

A slot is a `Variable` created with the same first m dimensions as a primary
variable or `Tensor`. A slot is always scoped in the namespace of the primary
object and typically has the same device and type.

Slots are typically used as accumulators to track values associated with
the primary object:

```python
# Optimizers can create a slot for each variable to track accumulators
accumulators = {var: create_zeros_slot(var, "momentum") for var in vs}
for var in vs:
  apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)

# Slots can also be used for moving averages
mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
```
"""
# pylint: disable=g-bad-name

from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` instead of a callable initializer, the shape
  # is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get "Adam" as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ":0".
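    # (The "[:-2]" slice below strips the ":0" output suffix from slot.name;
    # this assumes the slot variable's op exposes a single output.)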
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support the slot's shape differing from the primary's shape. For
    # example, if the primary's shape is [10, 20, 30], the slot's shape may
    # be None, [], [10], [10, 20] or [10, 20, 30]:
    #   - shape None or [10, 20, 30]: give the slot the primary's slice_info.
    #   - shape []: don't set the slot's slice_info.
    #   - shape [10] or [10, 20]: set the slot's slice_info according to ndims.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary if the slot variable has the
  # same rank as the primary.
  def _has_same_rank(primary_shape, slot_shape):
    return (primary_shape.rank is not None and slot_shape.rank is not None and
            primary_shape.rank == slot_shape.rank)

  if copy_xla_sharding and _has_same_rank(primary.shape, slot.shape):
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot


def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True, the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set the primary's name + '/' + name as the default name, so the scope name
  # of the optimizer can be shared when `reuse` is True. Meanwhile, when
  # `reuse` is False and the same name has been used before, the scope name
  # gets a '_N' suffix for unique identification.
  validate_shape = val.get_shape().is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            val,
            "",
            validate_shape,
            None,
            None,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          val,
          "",
          validate_shape,
          None,
          None,
          copy_xla_sharding=copy_xla_sharding)


def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`. The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True, the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as the default name, so the scope name
  # of the optimizer can be shared when `reuse` is True. Meanwhile, when
  # `reuse` is False and the same name has been used before, the scope name
  # gets a '_N' suffix for unique identification.
  validate_shape = shape.is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            initializer,
            "",
            validate_shape,
            shape,
            dtype,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          initializer,
          "",
          validate_shape,
          shape,
          dtype,
          copy_xla_sharding=copy_xla_sharding)


def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Create a slot initialized to 0 with same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. Defaults to the type of `primary`.
    colocate_with_primary: Boolean. If True, the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True, also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer()
    return create_slot_with_initializer(
        primary,
        initializer,
        slot_shape,
        dtype,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
  else:
    if isinstance(primary, variables.Variable):
      slot_shape = array_ops.shape(primary.initialized_value())
    else:
      slot_shape = array_ops.shape(primary)
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(
        primary,
        val,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
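
# A minimal usage sketch (illustrative only: the variable `v` is hypothetical,
# and TF1-style graph-mode variable creation is assumed):
#
#   v = variable_scope.get_variable("v", shape=[2, 3])
#   acc = create_zeros_slot(v, "accumulator")
#   # `acc` is a non-trainable variable scoped under the primary's name
#   # (e.g. "v/accumulator"), with the same shape and dtype as `v`.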