# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation of functions with multiple results."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Every test below is compiled and executed once per specialization mode.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]

# Shared executor used to compile and run every test function.
jitrt = tf_jitrt.TfJitRtExecutor()


class MultipleResultsTest(test.TestCase):

  def _compile_and_check(self, mlir_function, expected_offsets):
    """Compiles `mlir_function` under every specialization and checks results.

    Each specialization runs inside its own `subTest`. A failure then reports
    which specialization broke, and the remaining ones are still checked.

    Args:
      mlir_function: MLIR source defining `@test`, which takes a single
        `tensor<?xf32>` argument.
      expected_offsets: For each result, in order, the constant that the
        result is expected to equal `arg0` plus.
    """
    for specialize in specializations:
      with self.subTest(specialize=specialize):
        compiled = jitrt.compile(mlir_function, 'test', specialize)

        # A random length exercises the dynamic `?` dimension. Zero input
        # keeps every expected value exact.
        d0 = np.random.randint(1, 10)
        arg0 = np.zeros(d0, np.float32)

        results = list(jitrt.execute(compiled, [arg0]))
        # Same arity check the original list unpacking performed implicitly.
        self.assertEqual(len(results), len(expected_offsets))
        for res, offset in zip(results, expected_offsets):
          # Results are exact, so no absolute tolerance is allowed.
          np.testing.assert_allclose(res, arg0 + offset, atol=0.0)

  def test_two_results(self):
    """Two distinct results: `arg0 + 1` and `arg0 + 2`."""
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
        %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
             : () -> tensor<f32>
        %1 = "tf.AddV2"(%arg0, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        %2 = "tf.AddV2"(%1, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        func.return %1, %2 : tensor<?xf32>, tensor<?xf32>
      }"""
    self._compile_and_check(mlir_function, [1.0, 2.0])

  def test_three_results(self):
    """Three distinct results: `arg0 + 1`, `arg0 + 2` and `arg0 + 3`."""
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>) ->
          (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
        %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
             : () -> tensor<f32>
        %1 = "tf.AddV2"(%arg0, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        %2 = "tf.AddV2"(%1, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        %3 = "tf.AddV2"(%2, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        func.return %1, %2, %3 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
      }"""
    self._compile_and_check(mlir_function, [1.0, 2.0, 3.0])

  def test_same_tensor_returned_twice(self):
    """The same SSA value (`arg0 + 1`) is returned as both results."""
    mlir_function = """
      func.func @test(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
        %0 = "tf.Const"() { value = dense<1.0> : tensor<f32> }
             : () -> tensor<f32>
        %1 = "tf.AddV2"(%arg0, %0)
             : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
        func.return %1, %1 : tensor<?xf32>, tensor<?xf32>
      }"""
    self._compile_and_check(mlir_function, [1.0, 1.0])


if __name__ == '__main__':
  np.random.seed(0)
  test.main()