xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/constant_op.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations that generate constants.
16
17See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
18"""
19
20# Must be separate from array_ops to avoid a cyclic dependency.
21
22from tensorflow.core.framework import attr_value_pb2
23from tensorflow.core.framework import types_pb2
24from tensorflow.python.eager import context
25from tensorflow.python.eager import execute
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import op_callbacks
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.profiler import trace
32from tensorflow.python.util.tf_export import tf_export
33
34
def _eager_reshape(tensor, shape, ctx):
  """Runs the `Reshape` kernel eagerly.

  Args:
    tensor: An eager Tensor to reshape (required; this is eager-only).
    shape: The target shape; converted to an int32/int64 eager tensor.
    ctx: The eager context to execute in.

  Returns:
    The reshaped eager Tensor.
  """
  element_dtype = tensor._datatype_enum()  # pylint: disable=protected-access
  shape_dtype, (shape_tensor,) = execute.args_to_matching_eager(
      [shape], ctx, [dtypes.int32, dtypes.int64], dtypes.int32)
  op_attrs = ("T", element_dtype, "Tshape", shape_dtype)
  (reshaped,) = execute.execute(
      b"Reshape", 1, inputs=[tensor, shape_tensor], attrs=op_attrs, ctx=ctx)
  return reshaped
45
46
def _eager_fill(dims, value, ctx):
  """Runs the `Fill` kernel eagerly.

  Args:
    dims: The output dimensions; converted to an int32 eager tensor.
    value: An eager Tensor holding the fill value (required; eager-only).
    ctx: The eager context to execute in.

  Returns:
    An eager Tensor of shape `dims` filled with `value`.
  """
  value_dtype = value.dtype.as_datatype_enum
  dims_tensor = convert_to_eager_tensor(dims, ctx, dtypes.int32)
  op_attrs = ("T", value_dtype, "index_type", types_pb2.DT_INT32)
  (filled,) = execute.execute(
      b"Fill", 1, inputs=[dims_tensor, value], attrs=op_attrs, ctx=ctx)
  return filled
56
57
def _eager_identity(tensor, ctx):
  """Runs the `Identity` kernel eagerly on an eager Tensor.

  Args:
    tensor: The eager Tensor to pass through.
    ctx: The eager context to execute in.

  Returns:
    The output eager Tensor of the `Identity` op.
  """
  (output,) = execute.execute(
      b"Identity", 1, inputs=[tensor],
      attrs=("T", tensor.dtype.as_datatype_enum), ctx=ctx)
  return output
64
65
def _eager_const(tensor, ctx):
  """Copies a constant to the current device via the `_EagerConst` op.

  Args:
    tensor: The eager Tensor to copy.
    ctx: The eager context to execute in.

  Returns:
    The output eager Tensor of the `_EagerConst` op.
  """
  (copied,) = execute.execute(
      b"_EagerConst", 1, inputs=[tensor],
      attrs=("T", tensor.dtype.as_datatype_enum), ctx=ctx)
  return copied
72
73
def convert_to_eager_tensor(value, ctx, dtype=None):
  """Converts the given `value` to an `EagerTensor`.

  Note that this function could return cached copies of created constants for
  performance reasons.

  If `value` is already an `EagerTensor` it is returned unchanged (no copy,
  no device change), provided its dtype matches `dtype` when one is given.

  Args:
    value: value to convert to EagerTensor.
    ctx: value of context.context().
    dtype: optional desired dtype of the converted EagerTensor. May be a
      `DType` or anything accepted by `dtypes.as_dtype`.

  Returns:
    EagerTensor created from value, placed on `ctx.device_name`.

  Raises:
    TypeError: if `value` is already an `EagerTensor` whose dtype differs from
      the requested `dtype`.
  """
  if isinstance(value, ops.EagerTensor):
    if dtype is not None and value.dtype != dtype:
      raise TypeError(f"Expected tensor {value} with dtype {dtype!r}, but got "
                      f"dtype {value.dtype!r}.")
    return value
  if dtype is not None:
    try:
      # Fast path for objects that already expose `as_datatype_enum`
      # (e.g. `DType`); anything else goes through `dtypes.as_dtype`.
      dtype = dtype.as_datatype_enum
    except AttributeError:
      dtype = dtypes.as_dtype(dtype).as_datatype_enum
  ctx.ensure_initialized()
  return ops.EagerTensor(value, ctx.device_name, dtype)
103
104
@tf_export(v1=["constant"])
def constant_v1(
    value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  Note: when executing eagerly, this padding is not performed. If `value` has
  a different number of elements than `shape`, it is only accepted when it
  has exactly one element, which is then broadcast to fill `shape`. Any other
  element count raises a `TypeError`.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value: A constant value (or list) of output type `dtype`.
    dtype: The type of the elements of the resulting tensor.
    shape: Optional dimensions of resulting tensor.
    name: Optional name for the tensor.
    verify_shape: Boolean that enables verification of a shape of values.
      When True, `value` is not reshaped or broadcast to match `shape`.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  return _constant_impl(value, dtype, shape, name, verify_shape=verify_shape,
                        allow_broadcast=False)
168
169
@tf_export("constant", v1=[])
def constant(value, dtype=None, shape=None, name="Const"):
  """Creates a constant tensor from a tensor-like object.

  Note: All eager `tf.Tensor` values are immutable (in contrast to
  `tf.Variable`). There is nothing especially _constant_ about the value
  returned from `tf.constant`. This function is not fundamentally different from
  `tf.convert_to_tensor`. The name `tf.constant` comes from the `value` being
  embedded in a `Const` node in the `tf.Graph`. `tf.constant` is useful
  for asserting that the value can be embedded that way.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  >>> # Constant 1-D Tensor from a python list.
  >>> tf.constant([1, 2, 3, 4, 5, 6])
  <tf.Tensor: shape=(6,), dtype=int32,
      numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> # Or a numpy array
  >>> a = np.array([[1, 2, 3], [4, 5, 6]])
  >>> tf.constant(a)
  <tf.Tensor: shape=(2, 3), dtype=int64, numpy=
    array([[1, 2, 3],
           [4, 5, 6]])>

  If `dtype` is specified, the resulting tensor values are cast to the requested
  `dtype`.

  >>> tf.constant([1, 2, 3, 4, 5, 6], dtype=tf.float64)
  <tf.Tensor: shape=(6,), dtype=float64,
      numpy=array([1., 2., 3., 4., 5., 6.])>

  If `shape` is set, the `value` is reshaped to match. Scalars are expanded to
  fill the `shape`:

  >>> tf.constant(0, shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
    array([[0, 0, 0],
           [0, 0, 0]], dtype=int32)>
  >>> tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
    array([[1, 2, 3],
           [4, 5, 6]], dtype=int32)>

  `tf.constant` has no effect if an eager Tensor is passed as the `value`; it
  even transmits gradients:

  >>> v = tf.Variable([0.0])
  >>> with tf.GradientTape() as g:
  ...     loss = tf.constant(v + v)
  >>> g.gradient(loss, v).numpy()
  array([2.], dtype=float32)

  But, since `tf.constant` embeds the value in the `tf.Graph` this fails for
  symbolic tensors:

  >>> with tf.compat.v1.Graph().as_default():
  ...   i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
  ...   t = tf.constant(i)
  Traceback (most recent call last):
  ...
  TypeError: ...

  `tf.constant` will create tensors on the current device. Inputs which are
  already tensors maintain their placements unchanged.

  Related Ops:

  * `tf.convert_to_tensor` is similar but:
    * It has no `shape` argument.
    * Symbolic tensors are allowed to pass through.

    >>> with tf.compat.v1.Graph().as_default():
    ...   i = tf.compat.v1.placeholder(shape=[None, None], dtype=tf.float32)
    ...   t = tf.convert_to_tensor(i)

  * `tf.fill`: differs in a few ways:
    *   `tf.constant` supports arbitrary constants, not just uniform scalar
        Tensors like `tf.fill`.
    *   `tf.fill` creates an Op in the graph that is expanded at runtime, so it
        can efficiently represent large tensors.
    *   Since `tf.fill` does not embed the value, it can produce dynamically
        sized outputs.

  Args:
    value: A constant value (or list) of output type `dtype`.
    dtype: The type of the elements of the resulting tensor.
    shape: Optional dimensions of resulting tensor.
    name: Optional name for the tensor.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported, or if called
      on a symbolic tensor (see the example above).
  """
  return _constant_impl(value, dtype, shape, name, verify_shape=False,
                        allow_broadcast=True)
269
270
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Shared implementation of `constant` and `constant_v1`.

  When executing eagerly, the value is materialized directly as an eager
  tensor (wrapped in a profiler trace if tracing is enabled). Otherwise a
  `Const` op holding the value as a `TensorProto` is added to the default
  graph, and registered op callbacks may replace its output.

  Args:
    value: The value to embed.
    dtype: Optional element dtype.
    shape: Optional shape of the result.
    name: Name of the `Const` op (used in graph mode only).
    verify_shape: Whether `value`'s shape must match `shape` exactly.
    allow_broadcast: Forwarded to `tensor_util.make_tensor_proto` in graph
      mode.

  Returns:
    The constant tensor.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    if not trace.enabled:
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    with trace.Trace("tf.constant"):
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  graph = ops.get_default_graph()
  value_attr = attr_value_pb2.AttrValue()
  proto = tensor_util.make_tensor_proto(
      value, dtype=dtype, shape=shape, verify_shape=verify_shape,
      allow_broadcast=allow_broadcast)
  value_attr.tensor.CopyFrom(proto)
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  op_attrs = {"value": value_attr, "dtype": dtype_attr}
  const_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=op_attrs, name=name)
  result = const_op.outputs[0]

  if not op_callbacks.should_invoke_op_callbacks():
    return result
  # TODO(b/147670703): Once the special-op creation code paths
  # are unified. Remove this callback handling.
  replaced = op_callbacks.invoke_op_callbacks(
      "Const", (), op_attrs, (result,), op_name=name, graph=graph)
  if replaced is None:
    return result
  (result,) = replaced
  return result
300
301
def _constant_eager_impl(ctx, value, dtype, shape, verify_shape):
  """Creates a constant on the current device.

  Args:
    ctx: The current eager context.
    value: Value to convert to an eager tensor.
    dtype: Optional dtype of the converted tensor.
    shape: Optional target shape. If it differs from the converted tensor's
      shape, the tensor is reshaped (when the element counts match) or, if it
      has exactly one element, broadcast to `shape` with `Fill`.
    verify_shape: If True, any difference between `shape` and the converted
      tensor's shape is an error rather than being reshaped/broadcast.

  Returns:
    An eager tensor.

  Raises:
    TypeError: if `value` is an eager tensor whose dtype conflicts with
      `dtype`, if `verify_shape` is set and the shapes differ, or if `shape`
      can be reached neither by reshaping nor by broadcasting one element.
  """
  t = convert_to_eager_tensor(value, ctx, dtype)
  if shape is None:
    return t
  shape = tensor_shape.as_shape(shape)
  if shape == t.shape:
    return t
  if verify_shape:
    raise TypeError(f"Expected Tensor {t} (converted from {value}) with shape "
                    f"{tuple(shape)}, but got shape {tuple(t.shape)}.")
  num_t = t.shape.num_elements()
  # TODO(josh11b): Implement shape -> eager tensor conversion.
  if num_t == shape.num_elements():
    return _eager_reshape(t, shape.as_list(), ctx)
  if num_t == 1:
    if t.dtype == dtypes.bool:
      # We don't have a Fill kernel for bool dtype on GPU. So we first run
      # Fill on CPU and then copy to GPU if needed.
      with ops.device("/device:CPU:0"):
        x = _eager_fill(shape.as_list(), _eager_identity(t, ctx), ctx)
      return _eager_identity(x, ctx)
    return _eager_fill(shape.as_list(), t, ctx)
  # Fixed: the message previously ended with an unbalanced ")".
  raise TypeError("Eager execution of tf.constant with unsupported shape. "
                  f"Tensor {t} (converted from {value}) has {num_t:d} "
                  f"elements, but got `shape` {shape} with "
                  f"{shape.num_elements()} elements.")
330
331
def is_constant(tensor_or_op):
  """Returns whether `tensor_or_op` is, or is produced by, a `Const` op.

  Args:
    tensor_or_op: A `Tensor` (its producing `.op` is inspected) or an op-like
      object with a `.type` attribute.

  Returns:
    True if the op's type is `"Const"`, False otherwise.
  """
  is_tensor = isinstance(tensor_or_op, ops.Tensor)
  producer = tensor_or_op.op if is_tensor else tensor_or_op
  return producer.type == "Const"
338
339
def _constant_tensor_conversion_function(v, dtype=None, name=None,
                                         as_ref=False):
  """Tensor conversion function that embeds `v` with `constant`.

  Args:
    v: The Python value to convert.
    dtype: Optional requested dtype.
    name: Optional name for the resulting tensor.
    as_ref: Ignored; accepted to match the conversion-function signature.

  Returns:
    The tensor produced by `constant(v, dtype=dtype, name=name)`.
  """
  del as_ref  # Unused.
  return constant(v, dtype=dtype, name=name)
344
345
# Register `constant` as the conversion function for Python lists/tuples
# (priority 100) and, as a catch-all, for any `object` (priority 200).
# NOTE(review): per the TF conversion-registry docs, lower priority values are
# tried first, which would make the `object` entry a fallback — confirm.
ops.register_tensor_conversion_function(
    (list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    object, _constant_tensor_conversion_function, 200)
350
351
def _tensor_shape_tensor_conversion_function(s,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  """Function to convert TensorShape to Tensor.

  Args:
    s: The `TensorShape` to convert; it must be fully defined.
    dtype: Optional requested dtype; only `tf.int32` and `tf.int64` are
      accepted. If omitted, `tf.int32` is used unless some dimension is
      `>= 2**31`, in which case `tf.int64` is used.
    name: Optional name for the tensor; defaults to `"shape_as_tensor"`.
    as_ref: Ignored; accepted to match the conversion-function signature.

  Returns:
    A constant tensor holding the dimension sizes of `s`.

  Raises:
    ValueError: if `s` is only partially known, or if `dtype` is `tf.int32`
      and some dimension does not fit in int32.
    TypeError: if `dtype` is neither `tf.int32` nor `tf.int64`.
  """
  _ = as_ref
  if not s.is_fully_defined():
    raise ValueError(
        f"Cannot convert a partially known TensorShape {s} to a Tensor.")
  s_list = s.as_list()
  # int32 tops out at 2**31 - 1; any larger dimension requires int64.
  needs_int64 = any(dim >= 2**31 for dim in s_list)

  if dtype is not None:
    if dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError(f"Cannot convert TensorShape {s} to dtype {dtype}. "
                      "Allowed dtypes are tf.int32 and tf.int64.")
    if dtype == dtypes.int32 and needs_int64:
      raise ValueError(f"Cannot convert TensorShape {s} to dtype int32; "
                       "a dimension is too large. Consider using tf.int64.")
  else:
    dtype = dtypes.int64 if needs_int64 else dtypes.int32
  if name is None:
    name = "shape_as_tensor"
  return constant(s_list, dtype=dtype, name=name)
380
381
# Allow fully defined `TensorShape`s to be converted to 1-D integer tensors.
ops.register_tensor_conversion_function(
    tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)
384
385
def _dimension_tensor_conversion_function(d,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  """Function to convert Dimension to Tensor.

  Args:
    d: The `Dimension` to convert; its value must be known.
    dtype: Optional requested dtype; only `tf.int32` and `tf.int64` are
      accepted. If omitted, `tf.int32` is used unless the value is `>= 2**31`,
      in which case `tf.int64` is used (matching the TensorShape conversion).
    name: Optional name for the tensor; defaults to `"shape_as_tensor"`.
    as_ref: Ignored; accepted to match the conversion-function signature.

  Returns:
    A scalar constant tensor holding `d.value`.

  Raises:
    ValueError: if `d` is unknown, or if `dtype` is `tf.int32` and the value
      does not fit in int32.
    TypeError: if `dtype` is neither `tf.int32` nor `tf.int64`.
  """
  _ = as_ref
  if d.value is None:
    raise ValueError(f"Cannot convert unknown Dimension {d} to a Tensor.")
  # int32 tops out at 2**31 - 1. Previously large values were always sent to
  # int32; now they select int64 by default, like
  # `_tensor_shape_tensor_conversion_function` does for shape dimensions.
  needs_int64 = d.value >= 2**31
  if dtype is not None:
    if dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError(f"Cannot convert Dimension {d} to dtype {dtype}. "
                      "Allowed dtypes are tf.int32 and tf.int64.")
    if dtype == dtypes.int32 and needs_int64:
      raise ValueError(f"Cannot convert Dimension {d} to dtype int32; "
                       "the value is too large. Consider using tf.int64.")
  else:
    dtype = dtypes.int64 if needs_int64 else dtypes.int32
  if name is None:
    name = "shape_as_tensor"
  return constant(d.value, dtype=dtype, name=name)
403
404
# Allow known `Dimension`s to be converted to scalar integer tensors.
ops.register_tensor_conversion_function(
    tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)
407