# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""

import functools
import weakref

import numpy as np

from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
    tensor: The tensor to start from.
    layer: Origin layer of the tensor. Will be
      determined via tensor._keras_history if not provided.
    node_index: Origin node index of the tensor.

  Returns:
    List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data, str) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(('The %s argument of layer %s received an invalid '
                      'value %s. Allowed values are: %s.') %
                     (arg_name, layer_name, input_data, allowed_args))
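

# Example (illustrative sketch, kept as a comment so nothing runs at import
# time, and assuming `import tensorflow as tf`): for a small functional graph,
# `get_source_inputs` walks the tensor's `_keras_history` back to the Input
# layer. The variable names are hypothetical.
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = tf.keras.layers.Dense(2)(inputs)
#   get_source_inputs(outputs)  # -> [inputs]
#
# Likewise, `validate_string_arg` is typically called from a layer constructor
# to check a string-valued argument; an unknown string such as 'circular'
# would raise a ValueError listing the allowed values.
#
#   validate_string_arg('same', allowable_strings={'valid', 'same'},
#                       layer_name='conv1', arg_name='padding')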


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Args:
    weights: An iterable containing the weights on which to compute params

  Returns:
    The total number of scalars composing the weights
  """
  unique_weights = {id(w): w for w in weights}.values()
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Args:
    model: Keras model instance.
    line_length: Total length of printed lines
      (e.g. set this to adapt the display to different
      terminal window sizes).
    positions: Relative or absolute positions of log elements in each line.
      If not provided, defaults to `[.33, .55, .67, 1.]`.
    print_fn: Print function to use.
      It will be called on each line of the summary.
      You can set it to a custom function
      in order to capture the string summary.
      It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].keras_inputs)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
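

# Example (illustrative sketch, as a comment): `print_summary` is the helper
# behind `Model.summary()`. Because every line goes through `print_fn`, the
# summary can be captured as a string instead of written to stdout. The
# `model` variable is hypothetical.
#
#   lines = []
#   print_summary(model, line_length=80, print_fn=lines.append)
#   summary_text = '\n'.join(lines)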


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Args:
    dense: The target `Dense` layer.
    previous_feature_map_shape: A shape tuple of 3 integers,
      e.g. `(512, 7, 7)`. The shape of the convolutional
      feature map right before the `Flatten` layer that
      came before the target `Dense` layer.
    target_data_format: One of "channels_last", "channels_first".
      Set it to "channels_last"
      if converting a "channels_first" model to "channels_last",
      or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])
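

# Example (illustrative sketch, as a comment): converting a "channels_last"
# convnet to "channels_first" when the feature map entering the `Flatten`
# layer was 7x7 with 512 channels, matching the docstring's example shape.
# The `dense_layer` variable is hypothetical.
#
#   convert_dense_weights_data_format(
#       dense_layer,
#       previous_feature_map_shape=(512, 7, 7),
#       target_data_format='channels_first')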


def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily
  computed.) Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection.
  (And modifying it to use weakrefs as keys raises the lookup overhead to
  ~0.4 us.) As a result, the WeakKeyDictionary implementation below turns out
  to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = set()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in existing:
      continue
    existing.add(id(obj))
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])
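

# Example (illustrative sketch, as a comment): `filter_empty_layer_containers`
# is a generator that yields Layer-like objects (anything exposing `_is_layer`
# that is not a class), drops duplicates by object identity, and expands
# non-layer containers through their `.layers` attribute. The names are
# hypothetical.
#
#   unique_layers = list(
#       filter_empty_layer_containers([dense_a, dense_a, tracked_container]))
#   # `dense_a` is yielded once; `tracked_container` contributes the layers in
#   # its `.layers` list (if any) rather than itself.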