# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""

import functools
import weakref

import numpy as np

from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors.
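
  Example:

  A minimal, illustrative sketch (the `Input`/`Dense` choice here is
  arbitrary):

  ```python
  import tensorflow as tf

  a = tf.keras.Input(shape=(4,))
  b = tf.keras.layers.Dense(2)(a)
  # Walks `b`'s Keras history back to the graph inputs: returns `[a]`.
  tf.keras.utils.get_source_inputs(b)
  ```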
42  """
43  if not hasattr(tensor, '_keras_history'):
44    return tensor
45
  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data, str) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(('The %s argument of layer %s received an invalid '
                      'value %s. Allowed values are: %s.') %
                     (arg_name, layer_name, input_data, allowed_args))

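# A short usage sketch (the layer/arg names below are hypothetical):
#
#   validate_string_arg('same', ('valid', 'same'),
#                       layer_name='my_conv', arg_name='padding')  # passes
#   validate_string_arg('smae', ('valid', 'same'),
#                       layer_name='my_conv', arg_name='padding')  # ValueError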

def count_params(weights):
  """Count the total number of scalars composing the weights.

  Args:
      weights: An iterable containing the weights on which to compute params.

  Returns:
      The total number of scalars composing the weights.
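
  Example:

  A small sketch, assuming a freshly built `Dense` layer:

  ```python
  import tensorflow as tf

  layer = tf.keras.layers.Dense(3)
  layer.build((None, 4))
  count_params(layer.weights)  # 4 * 3 kernel weights + 3 biases -> 15
  ```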
96  """
97  unique_weights = {id(w): w for w in weights}.values()
98  weight_shapes = [w.shape.as_list() for w in unique_weights]
99  standardized_weight_shapes = [
100      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
101  ]
102  return int(sum(np.prod(p) for p in standardized_weight_shapes))


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Args:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes).
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.33, .55, .67, 1.]`.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
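
  Example:

  A sketch of capturing the summary as text through a custom `print_fn`:

  ```python
  lines = []
  print_summary(model, print_fn=lines.append)
  # `lines` now holds one string per summary row.
  ```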
120  """
121  if print_fn is None:
122    print_fn = print
123
124  if model.__class__.__name__ == 'Sequential':
125    sequential_like = True
126  elif not model._is_graph_network:
127    # We treat subclassed models as a simple sequence of layers, for logging
128    # purposes.
129    sequential_like = True
130  else:
131    sequential_like = True
132    nodes_by_depth = model._nodes_by_depth.values()
133    nodes = []
134    for v in nodes_by_depth:
135      if (len(v) > 1) or (len(v) == 1 and
136                          len(nest.flatten(v[0].keras_inputs)) > 1):
137        # if the model has multiple nodes
138        # or if the nodes have multiple inbound_layers
139        # the model is no longer sequential
140        sequential_like = False
141        break
142      nodes += v
143    if sequential_like:
144      # search for shared layers
145      for layer in model.layers:
146        flag = False
147        for node in layer._inbound_nodes:
148          if node in nodes:
149            if flag:
150              sequential_like = False
151              break
152            else:
153              flag = True
154        if not sequential_like:
155          break
156

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
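    # With the default positions and line length, the column stops land at
    # characters [32, 53, 65, 98] (i.e. [.33, .55, .67, 1.] scaled by 98).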
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      # Truncate each field at its column stop, then pad to it with spaces.
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)
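  # E.g. print_row(['dense (Dense)', '(None, 4)', '20'], positions) emits one
  # line, with each field truncated at and padded to its column stop.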

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Args:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it to "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
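
  Example:

  A rough sketch: porting a "channels_last" model whose final conv feature
  map was `7 x 7 x 512` over to "channels_first". Note that the shape tuple
  is given in the target format, matching how this function unpacks it
  (`dense_layer` stands in for the `Dense` layer that followed the
  `Flatten`):

  ```python
  convert_dense_weights_data_format(
      dense_layer, (512, 7, 7), target_data_format='channels_first')
  ```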
294  """
295  assert target_data_format in {'channels_last', 'channels_first'}
296  kernel, bias = dense.get_weights()
297  for i in range(kernel.shape[1]):
298    if target_data_format == 'channels_first':
299      c, h, w = previous_feature_map_shape
300      original_fm_shape = (h, w, c)
301      ki = kernel[:, i].reshape(original_fm_shape)
302      ki = np.transpose(ki, (2, 0, 1))  # last -> first
303    else:
304      h, w, c = previous_feature_map_shape
305      original_fm_shape = (c, h, w)
306      ki = kernel[:, i].reshape(original_fm_shape)
307      ki = np.transpose(ki, (1, 2, 0))  # first -> last
308    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
309  dense.set_weights([kernel, bias])
310
311
def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))

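# A quick sketch of the distinction (`MyLayer` is hypothetical):
#
#   is_builtin_layer(tf.keras.layers.Dense(1))  # True: has its own export.
#   class MyLayer(tf.keras.layers.Layer):
#     pass
#   is_builtin_layer(MyLayer())  # False: only inherits `keras.layers.Layer`.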

def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead (while still allowing the decorated property to be lazily computed).
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```
  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped

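# A minimal usage sketch (`compute_signature` is a hypothetical helper):
#
#   class Net:
#
#     @property
#     @cached_per_instance
#     def signature(self):
#       return compute_signature(self)  # Computed once per instance.
#
# Because the cache is a WeakKeyDictionary keyed on the instance, cached
# entries do not keep instances alive.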

def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = set()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in existing:
      continue
    existing.add(id(obj))
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])
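

# A minimal sketch of the filtering behavior:
#
#   dense = tf.keras.layers.Dense(1)
#   list(filter_empty_layer_containers([dense, dense, tf.keras.layers.Dense]))
#   # -> [dense]; the duplicate is dropped by identity, and the bare class
#   # (caught by the `isinstance(obj, type)` check, with no `.layers`) is
#   # filtered out.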
435