# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Tensorflow -> jitrt compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

specializations = [
    # TODO(ezhulenev): Fix memrefCopy msan warnings to enable these tests.
    # tf_jitrt.Specialization.ENABLED,
    # tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

vectorization = [False, True]

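# A single executor is shared by all tests below: `compile` lowers the MLIR
# functions to jitrt executables and `execute` runs them on NumPy inputs.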
jitrt = tf_jitrt.TfJitRtExecutor()


class TfFunction(test.TestCase):

  def test_func_0(self):
    for specialize in specializations:
      for vectorize in vectorization:
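        # Elementwise tanh(arg0) * (1 - arg2) + arg1 * arg2 over 1x? tensors.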
        mlir_function = """
        func.func @test(%arg0: tensor<1x?xf32>,
                       %arg1: tensor<1x?xf32>,
                       %arg2: tensor<1x?xf32>) -> tensor<1x?xf32> {
          %c = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
               : () -> tensor<f32>
          %0 = "tf.Tanh"(%arg0)
               : (tensor<1x?xf32>) -> tensor<1x?xf32>
          %1 = "tf.Mul"(%arg1, %arg2)
               : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
          %2 = "tf.Sub"(%c, %arg2)
               : (tensor<f32>, tensor<1x?xf32>) -> tensor<1x?xf32>
          %3 = "tf.Mul"(%0, %2)
               : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
          %4 = "tf.AddV2"(%1, %3)
               : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<1x?xf32>
          return %4 : tensor<1x?xf32>
        }"""

        compiled = jitrt.compile(mlir_function, 'test', specialize, vectorize)

        d0 = np.random.randint(128, 256)
        arg0 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)
        arg1 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)
        arg2 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)

        [res] = jitrt.execute(compiled, [arg0, arg1, arg2])

        # Function under test spelled in NumPy
        v0 = np.tanh(arg0)
        v1 = arg1 * arg2
        v2 = 1.0 - arg2
        v3 = v0 * v2
        v4 = v1 + v3

        np.testing.assert_allclose(res, v4, atol=1e-06)

  def test_func_1(self):
    for vectorize in vectorization:
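      # Unranked input with a "rank" constraint; returns both 1 - arg0 and
      # 1 - (1 - arg0) to exercise multiple results.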
      mlir_function = """
        func.func @test(%arg0: tensor<*xf32> {rt.constraint = "rank"})
            -> (tensor<*xf32>, tensor<*xf32>) {
          %c = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
               : () -> tensor<f32>
          %0 = "tf.Sub"(%c, %arg0)
               : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
          %1 = "tf.Sub"(%c, %0)
               : (tensor<f32>, tensor<*xf32>) -> tensor<*xf32>
          return %0, %1 : tensor<*xf32>, tensor<*xf32>
        }"""

      compiled = jitrt.compile(mlir_function, 'test',
                               tf_jitrt.Specialization.ALWAYS, vectorize)

      d0 = np.random.randint(128, 256)
      arg0 = np.random.uniform(1.0, 10.0, size=(1, d0)).astype(np.float32)

      [res0, res1] = jitrt.execute(compiled, [arg0])

      # Function under test spelled in NumPy
      v0 = 1.0 - arg0
      v1 = 1.0 - v0

      np.testing.assert_allclose(res0, v0, atol=0.0)
      np.testing.assert_allclose(res1, v1, atol=0.0)

  def test_func_2(self):
    for vectorize in vectorization:
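      # Unranked arg0 is broadcast against ranked ?x? operands:
      # arg0 * arg1 + arg2 * arg3.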
      mlir_function = """
        func.func @test(%arg0: tensor<*xf32> {rt.constraint = "rank"},
                   %arg1: tensor<?x?xf32>,
                   %arg2: tensor<?x?xf32>,
                   %arg3: tensor<?x?xf32>) -> tensor<*xf32> {
          %0 = "tf.Mul"(%arg0, %arg1)
               : (tensor<*xf32>, tensor<?x?xf32>) -> tensor<*xf32>
          %1 = "tf.Mul"(%arg2, %arg3)
               : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
          %2 = "tf.AddV2"(%0, %1)
               : (tensor<*xf32>, tensor<?x?xf32>) -> tensor<*xf32>
          return %2 : tensor<*xf32>
        }"""

      compiled = jitrt.compile(mlir_function, 'test',
                               tf_jitrt.Specialization.ALWAYS, vectorize)

      d0 = np.random.randint(4, 8)
      d1 = np.random.randint(4, 8)

      arg1 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)
      arg2 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)
      arg3 = np.random.uniform(1.0, 10.0, size=(d0, d1)).astype(np.float32)

      for shape in [(), (d1,), (d0, d1)]:
        arg0 = np.random.uniform(1.0, 10.0, size=shape).astype(np.float32)
        [res] = jitrt.execute(compiled, [arg0, arg1, arg2, arg3])

        # Function under test spelled in NumPy
        v0 = arg0 * arg1
        v1 = arg2 * arg3
        v3 = v0 + v1

        np.testing.assert_allclose(res, v3, atol=0.0)

  def test_func_3(self):
    for vectorize in vectorization:
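      # Scalar i32 ops: returns max(1, arg0) and min(arg1, max(1, arg0)).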
      mlir_function = """
        func.func @test(%arg0: tensor<i32>, %arg1: tensor<i32>)
            -> (tensor<i32>, tensor<i32>) {
          %c = "tf.Const"() {value = dense<1> : tensor<i32>}
               : () -> tensor<i32>
          %0 = "tf.Maximum"(%c, %arg0)
               : (tensor<i32>, tensor<i32>) -> tensor<i32>
          %1 = "tf.Minimum"(%arg1, %0)
               : (tensor<i32>, tensor<i32>) -> tensor<i32>
          return %0, %1 : tensor<i32>, tensor<i32>
        }"""

      compiled = jitrt.compile(mlir_function, 'test',
                               tf_jitrt.Specialization.ALWAYS, vectorize)

      arg0 = np.random.uniform(-100, 100, size=()).astype(np.int32)
      arg1 = np.random.uniform(-100, 100, size=()).astype(np.int32)

      [res0, res1] = jitrt.execute(compiled, [arg0, arg1])

      # Function under test spelled in NumPy
      v0 = np.maximum(1, arg0)
      v1 = np.minimum(arg1, v0)

      np.testing.assert_allclose(res0, v0, atol=0.0)
      np.testing.assert_allclose(res1, v1, atol=0.0)


if __name__ == '__main__':
  test.main()