xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/handle_data_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for copying and setting HandleData on resource/variant tensors."""
16
17from tensorflow.python.client import pywrap_tf_session
18from tensorflow.python.framework import dtypes
19from tensorflow.python.framework import ops
20
21
# Module-level alias for `ops.get_resource_handle_data`. `copy_handle_data`
# below calls it to read HandleData from a source tensor that is not an
# `ops.EagerTensor`.
get_resource_handle_data = ops.get_resource_handle_data
23
24
def copy_handle_data(source_t, target_t):
  """Copies HandleData from `source_t` onto `target_t` when applicable.

  The CppShapeInferenceResult::HandleData proto records the shapes and types of
  the element tensors held by a resource- or variant-typed tensor. That
  information must be carried across function boundaries, e.g. when a
  placeholder is captured or when a function tensor is returned as an output.
  Otherwise the element tensors end up with unknown shapes: a TensorList variant
  captured as a placeholder would yield elements of unknown shape when popped.

  The copy is skipped unless all of the following hold:
    * `target_t` has dtype resource or variant;
    * the source HandleData is not None;
    * its `is_set` is truthy;
    * its `shape_and_type` is non-empty.

  Args:
    source_t: The tensor to read HandleData from. Eager tensors are read via
      their `_handle_data` attribute; all others via `get_resource_handle_data`.
    target_t: The tensor to write HandleData to.
  """
  # HandleData only matters for resource and variant tensors.
  is_handle_like = (target_t.dtype == dtypes.resource or
                    target_t.dtype == dtypes.variant)
  if not is_handle_like:
    return

  # pylint: disable=protected-access
  data = (source_t._handle_data if isinstance(source_t, ops.EagerTensor)
          else get_resource_handle_data(source_t))
  # pylint: enable=protected-access

  # Nothing useful to propagate if the source carries no populated HandleData.
  if data is None or not data.is_set or not data.shape_and_type:
    return
  set_handle_data(target_t, data)
50
51
def set_handle_data(target_t, handle_data):
  """Attaches `handle_data` to `target_t`.

  For an `ops.EagerTensor` the value is stored directly on the tensor's
  `_handle_data` attribute. For any other tensor, `handle_data` is serialized
  with `SerializeToString()` and written into the tensor's graph through
  `pywrap_tf_session.SetHandleShapeAndType`.

  Args:
    target_t: The tensor that receives the HandleData.
    handle_data: The HandleData to attach. When `target_t` is not eager it must
      provide `SerializeToString()`; presumably a
      CppShapeInferenceResult.HandleData proto -- confirm against callers.
  """
  # pylint: disable=protected-access
  if isinstance(target_t, ops.EagerTensor):
    target_t._handle_data = handle_data
  else:
    with target_t.graph._c_graph.get() as c_graph:
      tf_output = target_t._as_tf_output()
      serialized = handle_data.SerializeToString()
      pywrap_tf_session.SetHandleShapeAndType(c_graph, tf_output, serialized)
  # pylint: enable=protected-access
60  # pylint: enable=protected-access
61