# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modules encapsulate building stateful components."""

import re

from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export


@tf_export("Module")
class Module(autotrackable.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. For example a dense layer in a neural
  network might be implemented as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_dim, output_size, name=None):
  ...     super().__init__(name=name)
  ...     self.w = tf.Variable(
  ...       tf.random.normal([input_dim, output_size]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_size]), name='b')
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  You can use the Dense layer as you would expect:

  >>> d = Dense(input_dim=3, output_size=2)
  >>> d(tf.ones([1, 3]))
  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>


  By subclassing `tf.Module` instead of `object` any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` property:

  >>> d.variables
  (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
  dtype=float32)>,
  <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)


  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might
  want to inspect in TensorBoard. You can enter the name scope explicitly
  using `with self.name_scope:` or you can annotate methods (apart from
  `__init__`) with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super().__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(Dense(input_dim=input_size, output_size=size))
  ...         input_size = size
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x

  >>> module = MLP(input_size=5, sizes=[5, 5])
  >>> module.variables
  (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
  dtype=float32)>,
  <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
  dtype=float32)>)
  """

  # AutoTrackable adds object attributes that users will not expect us to
  # include when flattening (these reference dependencies reachable via other
  # object attributes).
  _TF_MODULE_IGNORED_PROPERTIES = frozenset((
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names"
  ))

  def __init__(self, name=None):
    if name is None:
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
    if tf2.enabled():
      # Enter the scope once so the (possibly uniquified) full scope name is
      # captured, then keep a re-usable name_scope_v2 built from it.
      with ops.name_scope_v2(name) as scope_name:
        self._name_scope = ops.name_scope_v2(scope_name)
    else:
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the `self.name_scope.name` which includes
    parent module names.
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    if tf2.enabled():
      return self._name_scope
    else:
      # In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
      return ops.name_scope(self._scope_name, skip_on_eager=False)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_is_variable, expand_composites=True))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(
        self._flatten(predicate=_is_trainable_variable, expand_composites=True))

  @property
  def non_trainable_variables(self):
    """Sequence of non-trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(
        predicate=_is_non_trainable_variable, expand_composites=True))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_is_module))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False,
               expand_composites=False):
    """Flattened attribute values in sorted order by attribute name.

    Modules are flattened by first walking their attributes in name order.
    Each attribute value is then flattened to find leaf values. If flatten is
    applied `recursive`ly and if the leaf is a `Module` it will also be
    flattened to find leaves. Finally every leaf value is optionally tested
    against the given `predicate` and finally yielded.

    ```
    class Foo(tf.Module):
      def __init__(self):
        super().__init__()
        self.x = [tf.constant('a'), tf.constant('b')]
        self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
        self.z = tf.constant('e')

      @property
      def tensors(self):
        return tuple(self._flatten(predicate=is_tensor, with_path=True))

    foo = Foo()
    foo.tensors
    # ==> ((('x', 0),   <tf.Tensor: ...'a'>),
    #      (('x', 1),   <tf.Tensor: ...'b'>),
    #      (('y', 'i'), <tf.Tensor: ...'c'>),
    #      (('y', 'j'), <tf.Tensor: ...'d'>),
    #      (('z',),     <tf.Tensor: ...'e'>))
    ```

    `attribute_traversal_key` controls the order object properties are visited.
    If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).
      expand_composites: If true, then composite tensors are expanded into
        their component tensors.

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)


def _is_variable(obj):
  return isinstance(obj, variables.Variable)


def _is_trainable_variable(obj):
  return _is_variable(obj) and getattr(obj, "trainable", False)


def _is_non_trainable_variable(obj):
  return _is_variable(obj) and not getattr(obj, "trainable", False)


def _is_module(obj):
  return isinstance(obj, Module)

_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
# NOTE: anchor with `\Z` rather than `$`: `$` also matches just before a
# trailing newline, which would incorrectly accept names like "foo\n" that
# are not valid Python identifiers.
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*\Z")


def valid_identifier(name):
  """Returns True if `name` is a valid Python-identifier-like module name."""
  return bool(_VALID_IDENTIFIER.match(name))


def camel_to_snake(value):
  """Converts a CamelCase name to snake_case (e.g. "MyModule" -> "my_module")."""
  return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()


def _flatten_non_variable_composites_with_tuple_path(structure, path_prefix=()):
  """Flattens composite tensors with tuple path except variables."""
  for path, child in nest.flatten_with_tuple_paths(structure):
    if (isinstance(child, composite_tensor.CompositeTensor) and
        not _is_variable(child)):
      # pylint: disable=protected-access
      spec = child._type_spec
      yield from _flatten_non_variable_composites_with_tuple_path(
          spec._to_components(child),
          path_prefix + path + (spec.value_type.__name__,))
      # pylint: enable=protected-access
    else:
      yield path_prefix + path, child


def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
  """Implementation of `flatten`.

  Args:
    module: Current module to process.
    recursive: Whether to recurse into child modules or not.
    predicate: (Optional) If set then only values matching predicate are
      yielded. A value of `None` (the default) means no items will be
      filtered.
    attribute_traversal_key: (Optional) Method to rekey object attributes
      before they are sorted. Contract is the same as `key` argument to
      builtin `sorted` and only applies to object properties.
    attributes_to_ignore: object attributes to ignored.
    with_path: (Optional) Whether to include the path to the object as well
      as the object itself. If `with_path` is `True` then leaves will not be
      de-duplicated (e.g. if the same leaf instance is reachable via multiple
      modules then it will be yielded multiple times with different paths).
    expand_composites: If true, then composite tensors are expanded into their
      component tensors.
    module_path: The path to the current module as a tuple.
    seen: A set containing all leaf IDs seen so far.
    recursion_stack: A list containing all module IDs associated with the
      current call stack.

  Yields:
    Matched leaves with the optional corresponding paths of the current module
    and optionally all its submodules.
  """
  module_id = id(module)
  if seen is None:
    seen = set([module_id])

  module_dict = vars(module)
  submodules = []

  if recursion_stack is None:
    recursion_stack = []

  # When calling `_flatten_module` with `with_path=False`, the global lookup
  # table `seen` guarantees the uniqueness of the matched objects.
  # In the case of `with_path=True`, there might be multiple paths associated
  # with the same predicate, so we don't stop traversing according to `seen`
  # to make sure all these paths are returned.
  # When there are cycles connecting submodules, we break cycles by avoiding
  # following back edges (links pointing to a node in `recursion_stack`).
  if module_id in recursion_stack:
    recursive = False

  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    prop = module_dict[key]
    try:
      if expand_composites:
        leaves = list(_flatten_non_variable_composites_with_tuple_path(prop))
      else:
        leaves = nest.flatten_with_tuple_paths(prop)
    except Exception as cause:  # pylint: disable=broad-except
      raise ValueError("Error processing property {!r} of {!r}".format(
          key, prop)) from cause

    for leaf_path, leaf in leaves:
      leaf_path = (key,) + leaf_path

      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  recursion_stack.append(module_id)

  for submodule_path, submodule in submodules:
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=submodule_path,
        seen=seen,
        recursion_stack=recursion_stack)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue

  recursion_stack.pop()