# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for logical operations JIT compilation."""

import numpy as np

from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
from tensorflow.python.platform import test

# Specialization modes to exercise; every test compiles its MLIR module once
# per mode listed here.
specializations = [
    tf_jitrt.Specialization.ENABLED,
    tf_jitrt.Specialization.DISABLED,
    tf_jitrt.Specialization.ALWAYS,
]


def logical_op_1d(op_name):
  """Builds an MLIR module applying the binary op `tf.<op_name>`.

  Args:
    op_name: TF op name without the `tf.` prefix, e.g. "LogicalAnd".

  Returns:
    MLIR source text defining `func.func @test`, which takes two dynamically
    sized `tensor<?xi1>` arguments and returns the op's `tensor<?xi1>` result.
  """
  # Doubled braces are literal `{`/`}` inside the f-string.
  return f"""
  func.func @test(%arg0: tensor<?xi1>, %arg1: tensor<?xi1>) -> tensor<?xi1> {{
    %0 = "tf.{op_name}"(%arg0, %arg1)
         : (tensor<?xi1>, tensor<?xi1>) -> tensor<?xi1>
    func.return %0 : tensor<?xi1>
  }}"""


# Executor shared by all tests to compile and run the MLIR modules.
jitrt = tf_jitrt.TfJitRtExecutor()


def test_logical_op(mlir_blob, reference_fn, rank):
  """Checks a JIT-compiled logical op against a NumPy reference.

  For every mode in `specializations` the module is compiled once and then
  executed on 100 random pairs of boolean arrays. Each array has `rank`
  dimensions, each of extent drawn from [0, 100), so empty inputs occur too.

  Args:
    mlir_blob: MLIR source containing a function named `test` that takes two
      boolean tensors and returns one.
    reference_fn: NumPy callable producing the expected result from the two
      input arrays, e.g. `np.logical_and`.
    rank: Number of dimensions of the generated inputs.
  """
  for specialize in specializations:
    compiled = jitrt.compile(mlir_blob, "test", specialize)

    for _ in range(100):
      # One random extent per dimension; the resulting 1-D array is the shape.
      shape = np.random.randint(0, 100, size=rank)
      arg0 = np.random.choice([True, False], size=shape)
      arg1 = np.random.choice([True, False], size=shape)

      [res] = jitrt.execute(compiled, [arg0, arg1])
      # Report the failing mode and shape; otherwise a mismatch among the
      # 300 random trials is hard to reproduce.
      np.testing.assert_equal(
          res,
          reference_fn(arg0, arg1),
          err_msg=f"specialization={specialize}, shape={shape.tolist()}")


class TfLogicalOpsTest(test.TestCase):
  """JIT-compiles TF logical ops and compares them with NumPy."""

  def test_logical_and_1d(self):
    test_logical_op(logical_op_1d("LogicalAnd"), np.logical_and, 1)

  def test_logical_or_1d(self):
    test_logical_op(logical_op_1d("LogicalOr"), np.logical_or, 1)


if __name__ == "__main__":
  # Fixed seed so the random shapes and values are reproducible across runs.
  np.random.seed(0)
  test.main()