xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/node.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Contains the `Node` class."""
18
19import collections
20import copy
21import json
22import numpy as np
23
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.keras import backend
27from tensorflow.python.keras.engine import base_layer_utils
28from tensorflow.python.keras.saving.saved_model import json_utils
29from tensorflow.python.keras.utils import tf_utils
30from tensorflow.python.util import nest
31
# Marker written into the first slot (where a Keras tensor entry stores its
# layer name) of a serialized first-call-argument entry when that element is
# not a Keras tensor: `[_CONSTANT_VALUE, -1, serialized_value, kwargs]`.
# Produced by `Node.serialize`.
_CONSTANT_VALUE = '_CONSTANT_VALUE'
33
34
class Node:
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Constructing a `Node` wires it into the layer graph as a side effect:
  the node is appended to `layer._inbound_nodes` and to the
  `_outbound_nodes` of every layer that produced one of its Keras-tensor
  inputs. Every flattened output is also tagged with a `KerasHistory`
  (stored as `_keras_history`) pointing back at this node.

  A node built with neither positional nor keyword call arguments is treated
  as an input node (`is_input` is True).

  Args:
      layer: The Layer for the Layer.__call__ this node represents.
      call_args: The positional arguments the Layer was called with.
      call_kwargs: The keyword arguments the Layer was called with.
      outputs: The outputs of the Layer.__call__
  """

  def __init__(self,
               layer,
               call_args=None,
               call_kwargs=None,
               outputs=None):
    call_args = [] if call_args is None else call_args
    call_kwargs = {} if call_kwargs is None else call_kwargs
    outputs = [] if outputs is None else outputs

    self.layer = layer
    # No call arguments at all means this node only carries outputs.
    self.is_input = not call_args and not call_kwargs

    # These arguments are user-provided. Copy the structures here so that
    # future user modifications do not affect the node's metadata.
    # We copy using map_structure rather than python's shallow or deep copy,
    # because the args can be data structures (so shallow copy is
    # insufficient), but individual values might not support copy.copy
    # or be too expensive to deep copy.
    self.call_args = nest.map_structure(lambda t: t, call_args)
    self.call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
    self.outputs = nest.map_structure(lambda t: t, outputs)

    # Cached for performance.
    self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
    # Used to avoid expensive `nest` operations in the most common case.
    self._single_positional_tensor_passed = (not self.call_kwargs and len(
        self.call_args) == 1 and tensor_util.is_tf_type(self.call_args[0]))

    if not ops.executing_eagerly_outside_functions():
      # Create TensorFlowOpLayers if needed (in TF1)
      for obj in self._flat_arguments:
        if (isinstance(obj, ops.Tensor) and
            base_layer_utils.needs_keras_history(
                obj, ignore_call_context=True)):
          base_layer_utils.create_keras_history(obj)

    # Remember each Keras tensor among the flat arguments together with its
    # flat position, keyed by `str(id(tensor))`. `map_arguments` uses these
    # pairs to substitute computed values back into the argument structure.
    # This must run after the TF1 branch above, which may attach history.
    self._keras_inputs = []
    self._keras_inputs_ids_and_indices = []
    for i, ele in enumerate(self._flat_arguments):
      if is_keras_tensor(ele):
        self._keras_inputs.append(ele)
        self._keras_inputs_ids_and_indices.append((str(id(ele)), i))

    # Wire up Node to Layers.
    self.layer._inbound_nodes.append(self)
    for kt in self.keras_inputs:
      inbound_layer = kt._keras_history.layer
      if inbound_layer is not None:  # `None` for `Input` tensors.
        inbound_layer._outbound_nodes.append(self)

    # Set metadata on outputs. This node was just appended, so its index in
    # `layer._inbound_nodes` is the last one.
    node_index = len(self.layer._inbound_nodes) - 1
    for i, tensor in enumerate(nest.flatten(outputs)):
      tensor._keras_history = KerasHistory(
          layer=layer, node_index=node_index, tensor_index=i)

    # Cached for performance.
    self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
    self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]

  @property
  def keras_inputs(self):
    """Tensors input to this node that can be traced back to a `keras.Input`."""
    return self._keras_inputs

  @property
  def parent_nodes(self):
    """Returns all the `Node`s whose output this node immediately depends on.

    Keras inputs whose history has no layer (`layer is None`) are skipped.
    """
    node_deps = []
    for kt in self.keras_inputs:
      layer = kt._keras_history.layer
      node_index = kt._keras_history.node_index
      if layer is not None:  # `None` for `Input` tensors.
        node_deps.append(layer._inbound_nodes[node_index])
    return node_deps

  def iterate_inbound(self):
    """Yields tuples representing the data inbound from other nodes.

    Yields:
      tuples like: (inbound_layer, node_index, tensor_index, tensor).
    """
    for kt in self.keras_inputs:
      keras_history = kt._keras_history
      layer = keras_history.layer
      node_index = keras_history.node_index
      tensor_index = keras_history.tensor_index
      yield layer, node_index, tensor_index, kt

  def map_arguments(self, tensor_dict):
    """Maps Keras Tensors to computed Tensors using `tensor_dict`.

    Args:
      tensor_dict: Mapping from `str(id(keras_tensor))` to a container
        supporting `pop()` (presumably a list of computed values). One value
        is popped for every occurrence of that Keras tensor in the call
        arguments.

    Returns:
      An `(args, kwargs)` tuple mirroring `call_args` / `call_kwargs`, with
      every Keras tensor replaced by its popped computed value.
    """
    if self._single_positional_tensor_passed:
      # Performance optimization for most common case.
      kt_id, _ = self._keras_inputs_ids_and_indices[0]
      return (tensor_dict[kt_id].pop(),), {}
    else:
      flat_arguments = copy.copy(self._flat_arguments)
      for kt_id, kt_index in self._keras_inputs_ids_and_indices:
        flat_arguments[kt_index] = tensor_dict[kt_id].pop()

      args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs),
                                           flat_arguments)
      return args, kwargs

  def serialize(self, make_node_key, node_conversion_map):
    """Serializes `Node` for Functional API's `get_config`.

    Args:
      make_node_key: Callable taking `(layer_name, node_index)` and returning
        the key used to look up `node_conversion_map`.
      node_conversion_map: Mapping from node key to the node index written
        out; keys missing from it are serialized with index 0.

    Returns:
      The first call argument with every element serialized as
      `[layer_name, node_index, tensor_index, kwargs]` (or
      `[_CONSTANT_VALUE, -1, value, kwargs]` for non-Keras-tensor elements).
      A non-nested result is wrapped in a list unless
      `layer._preserve_input_structure_in_config` is set. The result is then
      passed through `tf_utils.convert_inner_node_data`.

    Raises:
      TypeError: If the call arguments other than the first one are not
        JSON-serializable.
    """
    # Serialization still special-cases first argument.
    args, kwargs = self.call_args, self.call_kwargs
    inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)

    # Treat everything other than first argument as a kwarg.
    arguments = dict(zip(self.layer._call_fn_args[1:], args))
    arguments.update(kwargs)
    kwargs = arguments

    def _serialize_keras_tensor(t):
      """Serializes a single Tensor passed to `call`."""
      if is_keras_tensor(t):
        kh = t._keras_history
        node_key = make_node_key(kh.layer.name, kh.node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        return [kh.layer.name, new_node_index, kh.tensor_index]

      if isinstance(t, np.ndarray):
        return t.tolist()

      if isinstance(t, ops.Tensor):
        return backend.get_value(t).tolist()

      return t

    kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
    try:
      # Only used as a serializability check; the result is discarded.
      json.dumps(kwargs, default=json_utils.get_json_type)
    except TypeError as e:
      kwarg_types = nest.map_structure(type, kwargs)
      raise TypeError('Layer ' + self.layer.name +
                      ' was passed non-JSON-serializable arguments. ' +
                      'Arguments had types: ' +
                      str(kwarg_types) + '. They cannot be serialized out '
                      'when saving the model.') from e

    # `kwargs` is added to each Tensor in the first arg. This should be
    # changed in a future version of the serialization format.
    def serialize_first_arg_tensor(t):
      if is_keras_tensor(t):
        kh = t._keras_history
        node_index = kh.node_index
        node_key = make_node_key(kh.layer.name, node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
      else:
        # If an element in the first call argument did not originate as a
        # keras tensor and is a constant value, we save it using the format
        # ['_CONSTANT_VALUE', -1, serialized_tensor_or_python_constant]
        # (potentially including serialized kwargs in an optional 4th
        # argument).
        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
      return tf_utils.ListWrapper(data)

    data = nest.map_structure(serialize_first_arg_tensor, inputs)
    if (not nest.is_nested(data) and
        not self.layer._preserve_input_structure_in_config):
      data = [data]
    data = tf_utils.convert_inner_node_data(data)
    return data

  #############################################################
  # Properties for Backwards compatibility.
  # These only check the first input argument
  # As nodes are internal, they may be removed in the future.
  #############################################################

  @property
  def input_tensors(self):
    """The first positional call argument (`[outputs]` for input nodes)."""
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.call_args[0]

  @property
  def output_tensors(self):
    """The node outputs (wrapped in a list for input nodes)."""
    if self.is_input:
      # Input nodes wrap their outputs in a list, mirroring `input_tensors`.
      return [self.outputs]
    return self.outputs

  @property
  def input_shapes(self):
    """Shapes of `input_tensors`, unwrapping a one-element list/tuple."""
    input_tensors = self.input_tensors
    input_shapes = nest.map_structure(backend.int_shape, input_tensors)
    # Only unwrap when the first argument is itself a list/tuple of inputs.
    # For a single tensor, `input_shapes` is that tensor's own shape, so a
    # rank-1 shape such as `(None,)` must not be unwrapped to its first
    # dimension. Likewise, a single-entry dict must not be indexed with `[0]`.
    if (not self.is_input and isinstance(input_tensors, (list, tuple)) and
        len(input_shapes) == 1):
      return input_shapes[0]
    return input_shapes

  @property
  def output_shapes(self):
    """Shapes of `output_tensors`, with the same nesting."""
    return nest.map_structure(backend.int_shape, self.output_tensors)

  @property
  def outbound_layer(self):
    """The layer whose call this node represents."""
    return self.layer

  @property
  def inbound_layers(self):
    """Layers that produced the first call argument (`[]` for input nodes).

    Every element of the first call argument is expected to carry
    `_keras_history`; the result mirrors that argument's nesting.
    """
    if self.is_input:
      return []
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
                                        self.call_args[0])
    return inbound_layers
261
262
class KerasHistory(
    collections.namedtuple(
        'KerasHistory', ('layer', 'node_index', 'tensor_index'))):
  """Immutable record of which Layer call produced a given Tensor.

  While a Keras graph network is being built, an instance of this record is
  attached to every Tensor that comes out of a Layer, starting with an
  `InputLayer`. The records are later followed backwards by the
  `keras.engine.Network` class to reconstruct the Keras Graph Network.

  Attributes:
    layer: The Layer whose call produced the Tensor.
    node_index: Which call of `layer` produced the Tensor. A Layer may be
      called several times to share weights, and every call creates a new
      node.
    tensor_index: Position of the Tensor among the call's outputs. It is zero
      for a Layer with a single output; nested output structures are numbered
      deterministically in `nest.flatten` order.
  """

  # An empty `__slots__` prevents a per-instance `__dict__`, so this subclass
  # keeps the memory footprint and speed of the plain `namedtuple`.
  __slots__ = ()
286
287
def is_keras_tensor(obj):
  """Reports whether `obj` carries Keras history metadata.

  An object counts as a Keras tensor exactly when reading its
  `_keras_history` attribute succeeds. Layer outputs receive that attribute
  when a `Node` is constructed for the call that produced them.

  Args:
    obj: Any object.

  Returns:
    True if `obj._keras_history` exists, False if reading it raises
    `AttributeError`.
  """
  try:
    _ = obj._keras_history
  except AttributeError:
    return False
  return True
290