# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modules encapsulate building stateful components."""

import re

from tensorflow.python import tf2
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export


@tf_export("Module")
class Module(autotrackable.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. For example, a dense layer in a neural
  network might be implemented as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_dim, output_size, name=None):
  ...     super().__init__(name=name)
  ...     self.w = tf.Variable(
  ...       tf.random.normal([input_dim, output_size]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_size]), name='b')
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  You can use the Dense layer as you would expect:

  >>> d = Dense(input_dim=3, output_size=2)
  >>> d(tf.ones([1, 3]))
  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>


  By subclassing `tf.Module` instead of `object`, any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` properties:

  >>> d.variables
      (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
      dtype=float32)>,
      <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)


  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might want
  to inspect in TensorBoard. You can enter the name scope explicitly using
  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
  with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super().__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(Dense(input_dim=input_size, output_size=size))
  ...         input_size = size
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x

  >>> module = MLP(input_size=5, sizes=[5, 5])
  >>> module.variables
  (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>,
  <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>)
  """

  # AutoTrackable adds object attributes that users will not expect us to
  # include when flattening (these reference dependencies reachable via other
  # object attributes).
  _TF_MODULE_IGNORED_PROPERTIES = frozenset((
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names"
  ))
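  # Subclasses can extend this set to hide additional attributes from
  # `_flatten`, e.g. (hypothetical attribute name, shown only as a sketch):
  #   _TF_MODULE_IGNORED_PROPERTIES = (
  #       Module._TF_MODULE_IGNORED_PROPERTIES | {"_private_cache"})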

  def __init__(self, name=None):
    if name is None:
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
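    # Enter the requested scope once now so that TensorFlow resolves (and, if
    # needed, uniquifies, e.g. "dense" -> "dense_1") the full scope name; the
    # resolved name is what `self.name_scope` re-enters later. In TF2 a scope
    # name ending in "/" re-enters that exact scope rather than creating a
    # new, uniquified one.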
    if tf2.enabled():
      with ops.name_scope_v2(name) as scope_name:
        self._name_scope = ops.name_scope_v2(scope_name)
    else:
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name

  @property
  def name(self):
126    """Returns the name of this module as passed or determined in the ctor.
127
128    NOTE: This is not the same as the `self.name_scope.name` which includes
129    parent module names.
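
    For example (a minimal sketch; the output assumes TF2 eager execution and
    that no other scope named "my_module" already exists):

    >>> m = tf.Module(name="my_module")
    >>> m.name
    'my_module'
    >>> m.name_scope.name
    'my_module/'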
130    """
131    return self._name
132
  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    if tf2.enabled():
      return self._name_scope
    else:
      # In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
      return ops.name_scope(self._scope_name, skip_on_eager=False)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(predicate=_is_variable, expand_composites=True))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(
        self._flatten(predicate=_is_trainable_variable, expand_composites=True))

  @property
  def non_trainable_variables(self):
    """Sequence of non-trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules recursively (breadth
      first).
    """
    return tuple(self._flatten(
        predicate=_is_non_trainable_variable, expand_composites=True))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_is_module))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False,
               expand_composites=False):
219    """Flattened attribute values in sorted order by attribute name.
220
221    Modules are flattened by first walking their attributes in name order.
222    Each attribute value is then flattened to find leaf values. If flatten is
223    applied `recursive`ly and if the leaf is a `Module` it will also be
224    flattened to find leaves. Finally every leaf value is optionally tested
225    against the given `predicate` and finally yielded.
226
227    ```
228    class Foo(tf.Module):
229      def __init__(self):
230        super().__init__()
231        self.x = [tf.constant('a'), tf.constant('b')]
232        self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
233        self.z = tf.constant('e')
234
235      @property
236      def tensors(self):
237        return tuple(self._flatten(predicate=is_tensor, with_path=True))
238
239    foo = Foo()
240    foo.tensors
241    # ==> ((('x', 0),   <tf.Tensor: ...'a'>),
242    #     (('x', 1),   <tf.Tensor: ...'b'>),
243    #     (('y', 'i'), <tf.Tensor: ...'c'>),
244    #     (('y', 'j'), <tf.Tensor: ...'d'>),
245    #     (('z',),     <tf.Tensor: ...'e'>))
246    ```
247
248    `attribute_traversal_key` controls the order object properties are visited.
249    If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).
      expand_composites: If true, then composite tensors are expanded into their
        component tensors.

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)


def _is_variable(obj):
  return isinstance(obj, variables.Variable)


def _is_trainable_variable(obj):
  return _is_variable(obj) and getattr(obj, "trainable", False)


def _is_non_trainable_variable(obj):
  return _is_variable(obj) and not getattr(obj, "trainable", False)


def _is_module(obj):
  return isinstance(obj, Module)


_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$")


def valid_identifier(name):
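  """Returns True if `name` is a valid (ASCII-only) Python identifier.

  Note: stricter than `str.isidentifier`, which also accepts non-ASCII
  identifiers.
  """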
  return bool(_VALID_IDENTIFIER.match(name))


def camel_to_snake(value):
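  """Converts a CamelCase name to snake_case, e.g. "MyDense" -> "my_dense"."""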
  return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()


def _flatten_non_variable_composites_with_tuple_path(structure, path_prefix=()):
344  """Flattens composite tensors with tuple path expect variables."""
  for path, child in nest.flatten_with_tuple_paths(structure):
    if (isinstance(child, composite_tensor.CompositeTensor) and
        not _is_variable(child)):
      # pylint: disable=protected-access
      spec = child._type_spec
      yield from _flatten_non_variable_composites_with_tuple_path(
          spec._to_components(child),
          path_prefix + path + (spec.value_type.__name__,))
      # pylint: enable=protected-access
    else:
      yield path_prefix + path, child


def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
368  """Implementation of `flatten`.
369
370  Args:
371    module: Current module to process.
372    recursive: Whether to recurse into child modules or not.
373    predicate: (Optional) If set then only values matching predicate are
374      yielded. A value of `None` (the default) means no items will be
375      filtered.
376    attribute_traversal_key: (Optional) Method to rekey object attributes
377      before they are sorted. Contract is the same as `key` argument to
378      builtin `sorted` and only applies to object properties.
379    attributes_to_ignore: object attributes to ignored.
380    with_path: (Optional) Whether to include the path to the object as well
381      as the object itself. If `with_path` is `True` then leaves will not be
382      de-duplicated (e.g. if the same leaf instance is reachable via multiple
383      modules then it will be yielded multiple times with different paths).
384    expand_composites: If true, then composite tensors are expanded into their
385      component tensors.
386    module_path: The path to the current module as a tuple.
387    seen: A set containing all leaf IDs seen so far.
388    recursion_stack: A list containing all module IDs associated with the
389      current call stack.
390
391  Yields:
392    Matched leaves with the optional corresponding paths of the current module
393    and optionally all its submodules.
394  """
  module_id = id(module)
  if seen is None:
    seen = set([module_id])

  module_dict = vars(module)
  submodules = []

  if recursion_stack is None:
    recursion_stack = []

  # When calling `_flatten_module` with `with_path=False`, the global lookup
  # table `seen` guarantees the uniqueness of the matched objects.
  # In the case of `with_path=True`, there might be multiple paths associated
  # with the same matched object, so we don't stop traversing according to
  # `seen`, to make sure all these paths are returned.
  # When there are cycles connecting submodules, we break cycles by avoiding
  # following back edges (links pointing to a node in `recursion_stack`).
  if module_id in recursion_stack:
    recursive = False

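  # (For example, `a.child = b` and `b.parent = a` form a cycle: while
  # flattening `a` we reach `a` again via `b.parent`, but do not recurse
  # into it a second time.)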
  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    prop = module_dict[key]
    try:
      if expand_composites:
        leaves = list(_flatten_non_variable_composites_with_tuple_path(prop))
      else:
        leaves = nest.flatten_with_tuple_paths(prop)
    except Exception as cause:  # pylint: disable=broad-except
      raise ValueError("Error processing property {!r} of {!r}".format(
          key, prop)) from cause

    for leaf_path, leaf in leaves:
      leaf_path = (key,) + leaf_path

      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  recursion_stack.append(module_id)

  for submodule_path, submodule in submodules:
    subvalues = _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=submodule_path,
        seen=seen,
        recursion_stack=recursion_stack)

    for subvalue in subvalues:
      # Predicate is already tested for these values.
      yield subvalue

  recursion_stack.pop()
468