xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/engine/input_spec.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=g-classes-have-attributes
17"""Contains the InputSpec class."""
18
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_spec
22from tensorflow.python.keras import backend
23from tensorflow.python.util import nest
24from tensorflow.python.util.tf_export import keras_export
25from tensorflow.python.util.tf_export import tf_export
26
27
@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
      When the rank of `shape` is known, `ndim` is inferred from it and the
      `ndim` argument is ignored.
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: If a key of `axes` cannot be converted to an integer.
    ValueError: If the largest key of `axes` is not a valid axis for the
      rank given by `ndim` (or by `max_ndim` when `ndim` is not set).

  Example:

  ```python
  class MyLayer(Layer):
      def __init__(self):
          super(MyLayer, self).__init__()
          # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
          # and raise an appropriate error message otherwise.
          self.input_spec = InputSpec(
              shape=(None, 28, 28, 1),
              allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      shape = None
    else:
      shape = tuple(shape.as_list())
    if shape is not None:
      # A shape with known rank determines `ndim`; the argument is ignored.
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Compare against None rather than relying on truthiness: `ndim=0` is a
      # valid (scalar) rank and must not fall back to `max_ndim`, which may
      # itself be None and would make the subtraction below raise TypeError.
      rank = self.ndim if self.ndim is not None else self.max_ndim
      max_dim = rank - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    # Numeric/shape fields are compared against None (not truthiness) so that
    # legitimate zero values, e.g. `ndim=0` or `shape=()`, are displayed.
    spec = []
    if self.dtype:
      spec.append('dtype=' + str(self.dtype))
    if self.shape is not None:
      spec.append('shape=' + str(self.shape))
    for attr in ('ndim', 'max_ndim', 'min_ndim'):
      value = getattr(self, attr)
      if value is not None:
        spec.append('%s=%s' % (attr, value))
    if self.axes:
      spec.append('axes=' + str(self.axes))
    if self.allow_last_axis_squeeze:
      spec.append('allow_last_axis_squeeze=True')
    if self.name:
      spec.append('name=' + str(self.name))
    return 'InputSpec(%s)' % ', '.join(spec)

  def get_config(self):
    """Returns a dict of constructor arguments that recreates this spec.

    Every `__init__` argument is included so that
    `InputSpec.from_config(spec.get_config())` validates inputs exactly like
    `spec` (including `allow_last_axis_squeeze` and `name`).
    """
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes,
        'allow_last_axis_squeeze': self.allow_last_axis_squeeze,
        'name': self.name}

  @classmethod
  def from_config(cls, config):
    """Creates an InputSpec from a dict produced by `get_config`."""
    return cls(**config)
130
131
def to_tensor_shape(spec):
  """Builds a tf.TensorShape describing what an InputSpec accepts.

  Resolution order:
    * If `spec.shape` is set, that shape is returned.
    * Otherwise, if `spec.ndim` is set, a shape of that rank is returned whose
      dimensions are None except those pinned by `spec.axes`.
    * Otherwise, a TensorShape of unknown rank (`TensorShape(None)`) is
      returned.

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  if spec.ndim is None:
    return tensor_shape.TensorShape(None)
  dims = [None] * spec.ndim
  # `InputSpec.__init__` always stores `axes` as a dict (possibly empty).
  for axis, size in spec.axes.items():
    dims[axis] = size
  return tensor_shape.TensorShape(dims)
153
154
def _dims_compatible(expected_dims, actual_dims):
  """Returns True if both dim sequences have equal length and agree when known.

  A None entry on either side is compatible with any value.
  """
  if len(expected_dims) != len(actual_dims):
    return False
  return all(expected is None or actual is None or expected == actual
             for expected, actual in zip(expected_dims, actual_dims))


def _may_be_one(dim):
  """Returns True if a dimension is 1 or unknown (None)."""
  return dim is None or dim == 1


def _shape_compatible(spec_shape, input_dims, allow_last_axis_squeeze):
  """Returns True if `input_dims` satisfies `spec_shape`.

  With `allow_last_axis_squeeze`, an input with one extra trailing axis that
  may be 1 (rank N+1), or with one fewer axis when the spec's trailing axis may
  be 1 (rank N-1), is also accepted. Any other rank difference is rejected.
  """
  if _dims_compatible(spec_shape, input_dims):
    return True
  if not allow_last_axis_squeeze:
    return False
  if input_dims and _may_be_one(input_dims[-1]):
    if _dims_compatible(spec_shape, input_dims[:-1]):
      return True
  if spec_shape and _may_be_one(spec_shape[-1]):
    if _dims_compatible(spec_shape[:-1], input_dims):
      return True
  return False


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Inputs whose rank is unknown cannot be shape-checked and are skipped; the
  remaining inputs are still validated.

  Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
          structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
          input tensors. If `inputs` is a dict and every spec has a `name`,
          inputs are matched to specs by those names.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      TypeError: if an input has no `shape` attribute.
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided
    names = [spec.name for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Normalize to a TensorShape: `x.shape` may be a plain tuple (e.g. for
    # NumPy arrays), which has neither `.rank` nor `.as_list()`.
    shape = tensor_shape.TensorShape(x.shape)
    if shape.rank is None:
      # Nothing rank- or shape-related can be verified for this input. Skip
      # it, but keep checking the remaining inputs.
      continue
    ndim = shape.rank
    shape_as_list = shape.as_list()
    error_prefix = ('Input ' + str(input_index) + ' of layer ' + layer_name +
                    ' is incompatible with the layer: ')

    # Check ndim. With `allow_last_axis_squeeze` the rank may legitimately
    # differ by one; that case is validated together with the shape below.
    if (spec.ndim is not None and not spec.allow_last_axis_squeeze and
        ndim != spec.ndim):
      raise ValueError(error_prefix + 'expected ndim=' + str(spec.ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' + display_shape(shape))
    if spec.max_ndim is not None and ndim > spec.max_ndim:
      raise ValueError(error_prefix + 'expected max_ndim=' +
                       str(spec.max_ndim) + ', found ndim=' + str(ndim) +
                       '. Full shape received: ' + display_shape(shape))
    if spec.min_ndim is not None and ndim < spec.min_ndim:
      raise ValueError(error_prefix + 'expected min_ndim=' +
                       str(spec.min_ndim) + ', found ndim=' + str(ndim) +
                       '. Full shape received: ' + display_shape(shape))
    # Check dtype.
    if spec.dtype is not None and x.dtype.name != spec.dtype:
      raise ValueError(error_prefix + 'expected dtype=' + str(spec.dtype) +
                       ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    if spec.axes:
      for axis, value in spec.axes.items():
        if hasattr(value, 'value'):
          value = value.value
        if value is None:
          continue
        axis = int(axis)
        # An axis beyond the input's rank cannot match; report it as a
        # mismatch instead of failing with an IndexError on the lookup.
        axis_in_range = -ndim <= axis < ndim
        if (not axis_in_range) or shape_as_list[axis] not in {value, None}:
          raise ValueError(
              error_prefix + 'expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape. `_shape_compatible` also enforces the rank, which matters
    # when `allow_last_axis_squeeze` disabled the ndim check above.
    if spec.shape is not None and not _shape_compatible(
        spec.shape, shape_as_list, spec.allow_last_axis_squeeze):
      raise ValueError('Input ' + str(input_index) +
                       ' is incompatible with layer ' + layer_name +
                       ': expected shape=' + str(spec.shape) +
                       ', found shape=' + display_shape(shape))
271
272
def display_shape(shape):
  """Formats a shape for error messages, e.g. `(None, 28, 28)`.

  Args:
    shape: an object exposing `as_list()`, such as a `tf.TensorShape`.

  Returns:
    The string form of the shape's dimensions as a tuple.
  """
  dims = shape.as_list()
  return str(tuple(dims))
275
276
def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: an InputSpec. Any other value yields a TensorSpec whose shape
      is None.
    default_dtype: dtype used when `input_spec` does not set one; when falsy,
      `backend.floatx()` is used instead.

  Returns:
    a `tensor_spec.TensorSpec`.
  """
  fallback_dtype = default_dtype or backend.floatx()
  if not isinstance(input_spec, InputSpec):
    return tensor_spec.TensorSpec(None, fallback_dtype)
  resolved_dtype = input_spec.dtype or fallback_dtype
  return tensor_spec.TensorSpec(to_tensor_shape(input_spec), resolved_dtype)
283  return tensor_spec.TensorSpec(None, default_dtype)
284