# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Specialization modes exercised by `test_func_0`.
specializations = [
    # TODO(ezhulenev): Fix memrefCopy msan warnings to enable these tests.
    # tf_jitrt.Specialization.ENABLED,
    # tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

# Every test compiles its function once with vectorization off and once on.
vectorization = [False, True]

# Single executor instance shared by all tests in this module.
jitrt = tf_jitrt.TfJitRtExecutor()


class TfFunction(test.TestCase):
  """Compiles small TF-dialect MLIR functions and checks them against NumPy.

  Each test compiles an MLIR function with the module-level `jitrt` executor,
  executes it on randomly generated inputs and compares the results with the
  same computation spelled in NumPy. Every configuration runs inside
  `self.subTest(...)` so that a failure names the configuration that broke.
  """

  def test_func_0(self):
    """Checks `tanh(arg0) * (1 - arg2) + arg1 * arg2` on `1x?` f32 tensors."""
    mlir_function = """
      func.func @test(%arg0: tensor<1x?xf32>,
                      %arg1: tensor<1x?xf32>,
                      %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
        %c = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
             : () -> tensor<f32>
        %0 = "tf.Tanh"(%arg0)
             : (tensor<1x?xf32>) -> tensor<1x?xf32>
        %1 = "tf.Mul"(%arg1, %arg2)
             : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
        %2 = "tf.Sub"(%c, %arg2)
             : (tensor<f32>, tensor<1x?xf32>) -> tensor<1x?xf32>
        %3 = "tf.Mul"(%0, %2)
             : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
        %4 = "tf.AddV2"(%1, %3)
             : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
        return %4 : tensor<1x?xf32>
      }"""

    for specialize in specializations:
      for vectorize in vectorization:
        with self.subTest(specialize=specialize, vectorize=vectorize):
          compiled = jitrt.compile(mlir_function, 'test', specialize,
                                   vectorize)

          # The dynamic dimension is picked at random for every configuration.
          d0 = np.random.randint(128, 256)
          arg0 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)
          arg1 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)
          arg2 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)

          [res] = jitrt.execute(compiled, [arg0, arg1, arg2])

          # Function under test spelled in NumPy
          v0 = np.tanh(arg0)
          v1 = arg1 * arg2
          v2 = 1.0 - arg2
          v3 = v0 * v2
          v4 = v1 + v3

          np.testing.assert_allclose(res, v4, atol=1e-06)

  def test_func_1(self):
    """Checks a two-result function taking one unranked f32 argument.

    The argument carries the `rt.constraint = "rank"` attribute; the function
    returns `1 - arg0` and `1 - (1 - arg0)`.
    """
    mlir_function = """
      func.func @test(%arg0: tensor<*xf32> {rt.constraint = "rank"})
          -> (tensor<*xf32>, tensor<*xf32>) {
        %c = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
             : () -> tensor<f32>
        %0 = "tf.Sub"(%c, %arg0)
             : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
        %1 = "tf.Sub"(%c, %0)
             : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
        return %0, %1 : tensor<*xf32>, tensor<*xf32>
      }"""

    for vectorize in vectorization:
      with self.subTest(vectorize=vectorize):
        compiled = jitrt.compile(mlir_function, 'test',
                                 tf_jitrt.Specialization.ALWAYS, vectorize)

        d0 = np.random.randint(128, 256)
        arg0 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)

        [res0, res1] = jitrt.execute(compiled, [arg0])

        # Function under test spelled in NumPy
        v0 = 1.0 - arg0
        v1 = 1.0 - v0

        # atol=0.0 leaves only assert_allclose's relative tolerance in effect.
        np.testing.assert_allclose(res0, v0, atol=0.0)
        np.testing.assert_allclose(res1, v1, atol=0.0)

  def test_func_2(self):
    """Checks `arg0 * arg1 + arg2 * arg3` with an unranked `arg0`.

    `arg1`..`arg3` are `(d0, d1)` f32 arrays, while `arg0` is fed as a scalar,
    a `(d1,)` vector and a `(d0, d1)` matrix, so the same compiled function is
    executed with inputs of rank 0, 1 and 2.
    """
    mlir_function = """
      func.func @test(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                      %arg1: tensor<?x?xf32>,
                      %arg2: tensor<?x?xf32>,
                      %arg3: tensor<?x?xf32>) -> tensor<*xf32> {
        %0 = "tf.Mul"(%arg0, %arg1)
             : (tensor<*xf32>, tensor<?x?xf32>) -> tensor<*xf32>
        %1 = "tf.Mul"(%arg2, %arg3)
             : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
        %2 = "tf.AddV2"(%0, %1)
             : (tensor<*xf32>, tensor<?x?xf32>) -> tensor<*xf32>
        return %2 : tensor<*xf32>
      }"""

    for vectorize in vectorization:
      compiled = jitrt.compile(mlir_function, 'test',
                               tf_jitrt.Specialization.ALWAYS, vectorize)

      d0 = np.random.randint(4, 8)
      d1 = np.random.randint(4, 8)

      arg1 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)
      arg2 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)
      arg3 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)

      # `(d1,)` is an explicit 1-tuple: a bare `(d1)` is just the int `d1`.
      for shape in [(), (d1,), (d0, d1)]:
        with self.subTest(vectorize=vectorize, shape=shape):
          arg0 = np.random.uniform(1.0, 10.0, size=shape).astype(np.float32)
          [res] = jitrt.execute(compiled, [arg0, arg1, arg2, arg3])

          # Function under test spelled in NumPy
          v0 = arg0 * arg1
          v1 = arg2 * arg3
          v2 = v0 + v1

          np.testing.assert_allclose(res, v2, atol=0.0)

  def test_func_3(self):
    """Checks `max(1, arg0)` and `min(arg1, max(1, arg0))` on i32 scalars."""
    mlir_function = """
      func.func @test(%arg0: tensor<i32>, %arg1: tensor<i32>)
          -> (tensor<i32>, tensor<i32>) {
        %c = "tf.Const"() {value = dense<1> : tensor<i32>}
             : () -> tensor<i32>
        %0 = "tf.Maximum"(%c, %arg0)
             : (tensor<i32>, tensor<i32>) -> tensor<i32>
        %1 = "tf.Minimum"(%arg1, %0)
             : (tensor<i32>, tensor<i32>) -> tensor<i32>
        return %0, %1 : tensor<i32>, tensor<i32>
      }"""

    for vectorize in vectorization:
      with self.subTest(vectorize=vectorize):
        compiled = jitrt.compile(mlir_function, 'test',
                                 tf_jitrt.Specialization.ALWAYS, vectorize)

        # Rank-0 inputs: uniform floats in [-100, 100) converted to int32.
        arg0 = np.random.uniform(-100, 100, size=()).astype(np.int32)
        arg1 = np.random.uniform(-100, 100, size=()).astype(np.int32)

        [res0, res1] = jitrt.execute(compiled, [arg0, arg1])

        # Function under test spelled in NumPy
        v0 = np.maximum(1, arg0)
        v1 = np.minimum(arg1, v0)

        np.testing.assert_allclose(res0, v0, atol=0.0)
        np.testing.assert_allclose(res1, v1, atol=0.0)


if __name__ == '__main__':
  test.main()