xref: /aosp_15_r20/external/tensorflow/tensorflow/python/dlpack/dlpack_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DLPack functions."""
16from absl.testing import parameterized
17import numpy as np
18
19
20from tensorflow.python.dlpack import dlpack
21from tensorflow.python.eager import context
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.platform import test
27from tensorflow.python.ops import array_ops
28
# Signed and unsigned integer element types covered by the round-trip test.
int_dtypes = [
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
]
# Floating-point element types (half, single and double precision).
float_dtypes = [np.float16, np.float32, np.float64]
# Complex element types.
complex_dtypes = [np.complex64, np.complex128]
# Every element type exercised by the parameterized round-trip test; bfloat16
# is taken from TensorFlow's `dtypes` module rather than from NumPy.
dlpack_dtypes = [*int_dtypes, *float_dtypes, dtypes.bfloat16, *complex_dtypes]

# Shapes to round-trip: a scalar, small ranks, and zero-size tensors.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]
38
39
def FormatShapeAndDtype(shape, dtype):
  """Builds the test-name suffix for one (shape, dtype) combination.

  Args:
    shape: Iterable of dimension sizes.
    dtype: Element type; rendered with `str()`.

  Returns:
    "_" + str(dtype) + "[" + comma-separated dims + "]", e.g. "_x[2,3]".
  """
  dims = ",".join(str(dim) for dim in shape)
  return "_%s[%s]" % (dtype, dims)
42
43
def GetNamedTestParameters():
  """Returns the `named_parameters` entries for the round-trip test.

  Produces one dict per (dtype, shape) pair from `dlpack_dtypes` and
  `testcase_shapes`, iterating dtypes in the outer position. Each dict has
  the keys "testcase_name", "dtype" and "shape".
  """
  return [
      dict(testcase_name=FormatShapeAndDtype(shape, dtype),
           dtype=dtype,
           shape=shape)
      for dtype in dlpack_dtypes
      for shape in testcase_shapes
  ]
54
55
class DLPackTest(parameterized.TestCase, test.TestCase):
  """Tests for `dlpack.to_dlpack` / `dlpack.from_dlpack`."""

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """A tensor keeps its values and device across a DLPack round trip."""
    # Use a private generator instead of reseeding NumPy's global RNG so this
    # test does not perturb other tests; RandomState(42) produces the same
    # values as np.random.seed(42) followed by np.random.randint.
    np_array = np.random.RandomState(42).randint(0, 10, shape)
    # Copy to GPU if available.
    tf_tensor = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
    tf_tensor_device = tf_tensor.device
    tf_tensor_dtype = tf_tensor.dtype
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # Releasing the source tensor must not invalidate the capsule.
    del tf_tensor
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)
    if tf_tensor_dtype == dtypes.int32:
      # int32 tensor is always on cpu for now
      self.assertEqual(tf_tensor2.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")
    else:
      self.assertEqual(tf_tensor_device, tf_tensor2.device)

  def testTensorsCanBeConsumedOnceOnly(self):
    """A DLPack capsule may be passed to `from_dlpack` only once."""
    np_array = np.random.RandomState(42).randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    del tf_tensor  # The capsule must stay usable after the source is released.
    _ = dlpack.from_dlpack(dlcapsule)

    # A second consumption of the same capsule must be rejected.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(dlcapsule)

  def testDLPackFromWithoutContextInitialization(self):
    """`from_dlpack` works after the eager context has been reset."""
    tf_tensor = constant_op.constant(1)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # Resetting the context doesn't cause an error.
    context._reset_context()
    _ = dlpack.from_dlpack(dlcapsule)

  def testUnsupportedTypeToDLPack(self):
    """`to_dlpack` rejects a qint16 tensor as not supported by DLPack."""
    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
      _ = dlpack.to_dlpack(tf_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """`to_dlpack` rejects a plain Python object with InvalidArgumentError."""
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "The argument to `to_dlpack` must be a TF tensor, not Python object"):
      dlpack.to_dlpack([1])
113
114
if __name__ == "__main__":
  # Turn on eager execution before handing control to the TensorFlow test
  # runner (presumably because the DLPack conversions under test operate on
  # concrete eager tensors -- TODO confirm).
  ops.enable_eager_execution()
  test.main()
117  test.main()
118