xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/python_tests/multiple_results_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for JIT compilation of functions with multiple results."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Specialization modes that every test below compiles its function under.
specializations = [tf_jitrt.Specialization.ENABLED,
                   tf_jitrt.Specialization.DISABLED,
                   tf_jitrt.Specialization.ALWAYS]

# One executor instance, reused by all tests in this module.
jitrt = tf_jitrt.TfJitRtExecutor()
29
30
class MultipleResultsTest(test.TestCase):
  """Tests JIT-compiled functions that return more than one tensor.

  Each test compiles its MLIR function once per entry in `specializations`
  and executes it on a zero-filled float32 vector whose length is drawn at
  random from [1, 9].
  """

  def _compile_and_check(self, mlir_function, expected_offsets):
    """Compiles `mlir_function` for each specialization and checks results.

    Args:
      mlir_function: MLIR source defining a function named `test` that takes
        a single `tensor<?xf32>` argument.
      expected_offsets: One number per expected result; result `i` must be
        equal to the input plus `expected_offsets[i]`.
    """
    for specialize in specializations:
      # A sub-test per specialization keeps the remaining modes running after
      # a failure and names the failing mode in the test report.
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        d0 = np.random.randint(1, 10)
        arg0 = np.zeros(d0, np.float32)

        results = list(jitrt.execute(compiled, [arg0]))
        self.assertEqual(len(results), len(expected_offsets))
        for index, (result, offset) in enumerate(
            zip(results, expected_offsets)):
          np.testing.assert_allclose(
              result, arg0 + offset, atol=0.0,
              err_msg='result #%d, input length %d' % (index, d0))

  def test_two_results(self):
    """Returns `arg0 + 1` and `arg0 + 2` as two separate results."""
    mlir_function = """
        func.func @test(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
          %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
               : () -> tensor<f32>
          %1 = "tf.AddV2"(%arg0, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          %2 = "tf.AddV2"(%1, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          func.return %1, %2 : tensor<?xf32>, tensor<?xf32>
        }"""

    self._compile_and_check(mlir_function, [1.0, 2.0])

  def test_three_results(self):
    """Returns `arg0 + 1`, `arg0 + 2` and `arg0 + 3` as three results."""
    mlir_function = """
        func.func @test(%arg0: tensor<?xf32>) ->
            (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
          %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
               : () -> tensor<f32>
          %1 = "tf.AddV2"(%arg0, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          %2 = "tf.AddV2"(%1, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          %3 = "tf.AddV2"(%2, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          func.return %1, %2, %3 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
        }"""

    self._compile_and_check(mlir_function, [1.0, 2.0, 3.0])

  def test_same_tensor_returned_twice(self):
    """Returns the same value `arg0 + 1` in both result positions."""
    mlir_function = """
        func.func @test(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
          %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
               : () -> tensor<f32>
          %1 = "tf.AddV2"(%arg0, %0)
               : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
          func.return %1, %1 : tensor<?xf32>, tensor<?xf32>
        }"""

    self._compile_and_check(mlir_function, [1.0, 1.0])
100
101
if __name__ == '__main__':
  # Seed NumPy's global RNG so the random input lengths drawn by the tests
  # are the same on every run of this script.
  np.random.seed(seed=0)
  test.main()
105