# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the `Node` class."""

import collections
import copy
import json
import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util import nest

# Marker written in place of a layer name when serializing an element of the
# first call argument that is a constant rather than a Keras tensor.
_CONSTANT_VALUE = '_CONSTANT_VALUE'


class Node:
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Constructing a `Node` has side effects: it appends itself to
  `layer._inbound_nodes`, appends itself to `_outbound_nodes` of every layer
  that produced one of its Keras-tensor inputs, and sets `_keras_history` on
  every flattened output.

  Args:
      layer: The Layer for the Layer.__call__ this node represents.
      call_args: The positional arguments the Layer was called with.
      call_kwargs: The keyword arguments the Layer was called with.
      outputs: The outputs of the Layer.__call__
  """

  def __init__(self,
               layer,
               call_args=None,
               call_kwargs=None,
               outputs=None):
    call_args = [] if call_args is None else call_args
    call_kwargs = {} if call_kwargs is None else call_kwargs
    outputs = [] if outputs is None else outputs

    self.layer = layer
    # A node created without any call arguments represents an `Input`.
    self.is_input = not call_args and not call_kwargs

    # These arguments are user-provided. Copy the structures here so that
    # future user modifications do not affect the node's metadata.
    # We copy using map_structure rather than python's shallow or deep copy,
    # because the args can be data structures (so shallow copy is
    # insufficient), but individual values might not support copy.copy
    # or be too expensive to deep copy.
    call_args = nest.map_structure(lambda t: t, call_args)
    call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
    self.outputs = nest.map_structure(lambda t: t, outputs)
    self.call_args = call_args
    self.call_kwargs = call_kwargs

    # Cached for performance.
    self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
    # Used to avoid expensive `nest` operations in the most common case.
    self._single_positional_tensor_passed = (not self.call_kwargs and len(
        self.call_args) == 1 and tensor_util.is_tf_type(self.call_args[0]))

    if not ops.executing_eagerly_outside_functions():
      # Create TensorFlowOpLayers if needed (in TF1)
      for obj in self._flat_arguments:
        if (isinstance(obj, ops.Tensor) and
            base_layer_utils.needs_keras_history(
                obj, ignore_call_context=True)):
          base_layer_utils.create_keras_history(obj)

    # Record which flattened arguments are Keras tensors. The string id is the
    # key later used by `map_arguments` to look up computed values, and the
    # index is the position to substitute into the flattened arguments.
    self._keras_inputs = []
    self._keras_inputs_ids_and_indices = []
    for i, ele in enumerate(self._flat_arguments):
      if is_keras_tensor(ele):
        self._keras_inputs.append(ele)
        kt_id = str(id(ele))
        kt_index = i
        self._keras_inputs_ids_and_indices.append((kt_id, kt_index))

    # Wire up Node to Layers.
    self.layer._inbound_nodes.append(self)
    for kt in self.keras_inputs:
      inbound_layer = kt._keras_history.layer
      if inbound_layer is not None:  # `None` for `Input` tensors.
        inbound_layer._outbound_nodes.append(self)

    # Set metadata on outputs. This node was just appended, so its index is
    # the last position in `_inbound_nodes`.
    node_index = len(self.layer._inbound_nodes) - 1
    for i, tensor in enumerate(nest.flatten(outputs)):
      tensor._keras_history = KerasHistory(
          layer=layer, node_index=node_index, tensor_index=i)

    # Cached for performance.
    self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
    self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]

  @property
  def keras_inputs(self):
    """Tensors input to this node that can be traced back to a `keras.Input`."""
    return self._keras_inputs

  @property
  def parent_nodes(self):
    """Returns all the `Node`s whose output this node immediately depends on."""
    node_deps = []
    for kt in self.keras_inputs:
      layer = kt._keras_history.layer
      node_index = kt._keras_history.node_index
      if layer is not None:  # `None` for `Input` tensors.
        node_deps.append(layer._inbound_nodes[node_index])
    return node_deps

  def iterate_inbound(self):
    """Yields tuples representing the data inbound from other nodes.

    Yields:
      tuples like: (inbound_layer, node_index, tensor_index, tensor).
    """
    for kt in self.keras_inputs:
      keras_history = kt._keras_history
      layer = keras_history.layer
      node_index = keras_history.node_index
      tensor_index = keras_history.tensor_index
      yield layer, node_index, tensor_index, kt

  def map_arguments(self, tensor_dict):
    """Maps Keras Tensors to computed Tensors using `tensor_dict`.

    Args:
      tensor_dict: Mapping from `str(id(keras_tensor))` to a list of computed
        values. One value is popped from the list for each use, so this call
        consumes entries from `tensor_dict`.

    Returns:
      An `(args, kwargs)` pair with the same structure as the original call,
      with every Keras tensor replaced by its computed value.
    """
    if self._single_positional_tensor_passed:
      # Performance optimization for most common case.
      kt_id, _ = self._keras_inputs_ids_and_indices[0]
      return (tensor_dict[kt_id].pop(),), {}
    else:
      # Shallow-copy so the cached flat arguments are not overwritten.
      flat_arguments = copy.copy(self._flat_arguments)
      for kt_id, kt_index in self._keras_inputs_ids_and_indices:
        flat_arguments[kt_index] = tensor_dict[kt_id].pop()

      args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs),
                                           flat_arguments)
      return args, kwargs

  def serialize(self, make_node_key, node_conversion_map):
    """Serializes `Node` for Functional API's `get_config`.

    Args:
      make_node_key: Callable `(layer_name, node_index) -> key` used to look
        up entries in `node_conversion_map`.
      node_conversion_map: Mapping from node key to the node index to write
        out; missing keys default to 0.

    Returns:
      The serialized inbound data for the first call argument, with all other
      arguments folded in as serialized kwargs.

    Raises:
      TypeError: If the non-first arguments are not JSON-serializable.
    """
    # Serialization still special-cases first argument.
    args, kwargs = self.call_args, self.call_kwargs
    inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)

    # Treat everything other than first argument as a kwarg.
    arguments = dict(zip(self.layer._call_fn_args[1:], args))
    arguments.update(kwargs)
    kwargs = arguments

    def _serialize_keras_tensor(t):
      """Serializes a single Tensor passed to `call`."""
      if hasattr(t, '_keras_history'):
        kh = t._keras_history
        node_index = kh.node_index
        node_key = make_node_key(kh.layer.name, node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        return [kh.layer.name, new_node_index, kh.tensor_index]

      if isinstance(t, np.ndarray):
        return t.tolist()

      if isinstance(t, ops.Tensor):
        return backend.get_value(t).tolist()

      return t

    kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
    try:
      # Probe serializability now so the failure names the offending layer.
      json.dumps(kwargs, default=json_utils.get_json_type)
    except TypeError as e:
      kwarg_types = nest.map_structure(type, kwargs)
      raise TypeError('Layer ' + self.layer.name +
                      ' was passed non-JSON-serializable arguments. ' +
                      'Arguments had types: ' +
                      str(kwarg_types) + '. They cannot be serialized out '
                      'when saving the model.') from e

    # `kwargs` is added to each Tensor in the first arg. This should be
    # changed in a future version of the serialization format.
    def serialize_first_arg_tensor(t):
      if is_keras_tensor(t):
        kh = t._keras_history
        node_index = kh.node_index
        node_key = make_node_key(kh.layer.name, node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
      else:
        # If an element in the first call argument did not originate as a
        # keras tensor and is a constant value, we save it using the format
        # ['_CONSTANT_VALUE', -1, serialized_tensor_or_python_constant]
        # (potentially including serialized kwargs in an optional 4th
        # argument).
        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
      return tf_utils.ListWrapper(data)

    data = nest.map_structure(serialize_first_arg_tensor, inputs)
    if (not nest.is_nested(data) and
        not self.layer._preserve_input_structure_in_config):
      data = [data]
    data = tf_utils.convert_inner_node_data(data)
    return data

  #############################################################
  # Properties for Backwards compatibility.
  # These only check the first input argument
  # As nodes are internal, they may be removed in the future.
  #############################################################

  @property
  def input_tensors(self):
    """The first positional call argument (or `[outputs]` for input nodes)."""
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.call_args[0]

  @property
  def output_tensors(self):
    """The node outputs (wrapped in a list for input nodes)."""
    if self.is_input:
      # Input nodes wrap their outputs in a list, mirroring `input_tensors`.
      return [self.outputs]
    return self.outputs

  @property
  def input_shapes(self):
    """Shapes of `input_tensors`; a one-element list/tuple is unwrapped."""
    input_tensors = self.input_tensors
    input_shapes = nest.map_structure(backend.int_shape, input_tensors)
    # Unwrap only when the *input structure* is a single-element list/tuple.
    # Checking `len(input_shapes)` instead would misfire for a lone rank-1
    # tensor (whose shape tuple has length 1) and for one-entry dicts.
    if (not self.is_input and isinstance(input_tensors, (list, tuple)) and
        len(input_tensors) == 1):
      return input_shapes[0]
    return input_shapes

  @property
  def output_shapes(self):
    """Shapes of `output_tensors`, with the same nesting."""
    return nest.map_structure(backend.int_shape, self.output_tensors)

  @property
  def outbound_layer(self):
    """The layer whose call this node represents."""
    return self.layer

  @property
  def inbound_layers(self):
    """Layers that produced the tensors in the first positional argument."""
    if self.is_input:
      return []
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
                                        self.call_args[0])
    return inbound_layers


class KerasHistory(
    collections.namedtuple('KerasHistory',
                           ['layer', 'node_index', 'tensor_index'])):
  """Tracks the Layer call that created a Tensor, for Keras Graph Networks.

  During construction of Keras Graph Networks, this metadata is added to
  each Tensor produced as the output of a Layer, starting with an
  `InputLayer`. This allows Keras to track how each Tensor was produced, and
  this information is later retraced by the `keras.engine.Network` class to
  reconstruct the Keras Graph Network.

  Attributes:
    layer: The Layer that produced the Tensor.
    node_index: The specific call to the Layer that produced this Tensor. Layers
      can be called multiple times in order to share weights. A new node is
      created every time a Layer is called.
    tensor_index: The output index for this Tensor. Always zero if the Layer
      that produced this Tensor only has one output. Nested structures of
      Tensors are deterministically assigned an index via `nest.flatten`.
  """
  # Added to maintain memory and performance characteristics of `namedtuple`
  # while subclassing.
  __slots__ = ()


def is_keras_tensor(obj):
  """Returns True if `obj` carries Keras history metadata."""
  return hasattr(obj, '_keras_history')