# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from absl.testing import parameterized
import numpy as np


from tensorflow.python.dlpack import dlpack
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.ops import array_ops

int_dtypes = [
    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
    np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
complex_dtypes = [np.complex64, np.complex128]
# Every dtype the round-trip test exercises; bfloat16 has no plain NumPy
# scalar type, so the TF dtype object is listed directly.
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16] + complex_dtypes

# Includes a scalar shape and zero-sized shapes to cover edge cases.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]


def FormatShapeAndDtype(shape, dtype):
  """Returns a test-name suffix such as `_float32[2,3]`.

  Args:
    shape: Tuple of ints describing the tensor shape.
    dtype: A NumPy scalar type or a TF `DType` (see `dlpack_dtypes`).

  Returns:
    A string of the form `_<dtype name>[<comma-separated dims>]`.
  """
  # `dtypes.as_dtype(...).name` yields a short canonical name ("int8",
  # "bfloat16", ...) instead of a repr such as "<class 'numpy.int8'>", which
  # would put spaces and quotes into the generated test method names.
  dtype_name = dtypes.as_dtype(dtype).name
  return "_{}[{}]".format(dtype_name, ",".join(map(str, shape)))


def GetNamedTestParameters():
  """Builds one named parameter dict per (dtype, shape) combination.

  Returns:
    A list of dicts with keys `testcase_name`, `dtype` and `shape`, ordered
    by dtype first and then by shape, suitable for
    `parameterized.named_parameters`.
  """
  return [{
      "testcase_name": FormatShapeAndDtype(shape, dtype),
      "dtype": dtype,
      "shape": shape
  } for dtype in dlpack_dtypes for shape in testcase_shapes]


class DLPackTest(parameterized.TestCase, test.TestCase):
  """Exercises `dlpack.to_dlpack` / `dlpack.from_dlpack`."""

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """Exporting then importing a tensor preserves values and device."""
    np.random.seed(42)
    np_array = np.random.randint(0, 10, shape)
    # copy to gpu if available
    tf_tensor = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
    tf_tensor_device = tf_tensor.device
    tf_tensor_dtype = tf_tensor.dtype
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule must remain usable after the source tensor is released.
    del tf_tensor
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)
    if tf_tensor_dtype == dtypes.int32:
      # int32 tensor is always on cpu for now
      self.assertEqual(tf_tensor2.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")
    else:
      self.assertEqual(tf_tensor_device, tf_tensor2.device)

  def testTensorsCanBeConsumedOnceOnly(self):
    """A DLPack capsule can be passed to `from_dlpack` only once."""
    np.random.seed(42)
    np_array = np.random.randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule must remain usable after the source tensor is released.
    del tf_tensor
    _ = dlpack.from_dlpack(dlcapsule)

    # A second import of the same capsule must be rejected.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(dlcapsule)

  def testDLPackFromWithoutContextInitialization(self):
    """A capsule exported before a context reset can still be imported."""
    tf_tensor = constant_op.constant(1)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # Resetting the context doesn't cause an error.
    context._reset_context()
    _ = dlpack.from_dlpack(dlcapsule)

  def testUnsupportedTypeToDLPack(self):
    """Exporting a dtype that DLPack does not support raises an error."""
    # Build the tensor outside the assertion so that only the `to_dlpack`
    # call can satisfy the expected failure.
    tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      dlpack.to_dlpack(tf_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """Passing a non-tensor Python object to `to_dlpack` is rejected."""
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "The argument to `to_dlpack` must be a TF tensor, not Python object"):
      dlpack.to_dlpack([1])


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()