xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/python_tests/tf_logical_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for logical operations JIT compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# Each test compiles its MLIR module once per `tf_jitrt.Specialization`
# mode listed here (ENABLED, DISABLED, ALWAYS, in that order).
specializations = [
    getattr(tf_jitrt.Specialization, mode_name)
    for mode_name in ("ENABLED", "DISABLED", "ALWAYS")
]
27
28
def logical_op_1d(op_name):
  """Returns MLIR text for a rank-1 boolean binary TF op wrapped in `@test`.

  Args:
    op_name: TF op name without the "tf." prefix, e.g. "LogicalAnd".

  Returns:
    Source of a `func.func @test` that applies `tf.<op_name>` to two
    `tensor<?xi1>` arguments and returns the `tensor<?xi1>` result.
  """
  ty = "tensor<?xi1>"
  mlir_lines = [
      "",  # The emitted text starts with a newline.
      f"  func.func @test(%arg0: {ty}, %arg1: {ty}) -> {ty} {{",
      f'    %0 = "tf.{op_name}"(%arg0, %arg1)',
      f"        : ({ty}, {ty}) -> {ty}",
      f"    func.return %0 : {ty}",
      "  }",
  ]
  return "\n".join(mlir_lines)
36
37
# Shared executor: every call to `test_logical_op` compiles and runs its
# module through this single instance.
jitrt = tf_jitrt.TfJitRtExecutor()


def test_logical_op(mlir_blob, reference_fn, rank):
  """Checks a JIT-compiled binary boolean op against a NumPy reference.

  The module is compiled once per entry in `specializations`. Each compiled
  module is then executed on 100 pairs of random boolean arrays. Every
  dimension of an array's shape is drawn from [0, 100), so empty arrays are
  included.

  Args:
    mlir_blob: MLIR source text defining a function named "test" that takes
      two boolean arrays and returns a single result.
    reference_fn: NumPy function computing the expected result from the two
      input arrays, e.g. `np.logical_and`.
    rank: Number of dimensions of the random inputs; presumably must match
      the tensor rank declared in `mlir_blob` (TODO confirm).

  Raises:
    AssertionError: If a compiled result differs from `reference_fn`.
  """
  for specialize in specializations:
    compiled = jitrt.compile(mlir_blob, "test", specialize)

    for trial in range(100):
      # `size=rank` yields a 1-D array holding `rank` dimension sizes.
      shape = np.random.randint(0, 100, size=rank)
      arg0 = np.random.choice([True, False], size=shape)
      arg1 = np.random.choice([True, False], size=shape)

      [res] = jitrt.execute(compiled, [arg0, arg1])
      # Inputs are random, so report the mode, shape and trial; without them
      # a mismatch cannot be traced back to the configuration that produced
      # it.
      msg = f"specialize={specialize}, shape={shape.tolist()}, trial={trial}"
      np.testing.assert_equal(res, reference_fn(arg0, arg1), err_msg=msg)


# `test_logical_op` is a parameterized helper, not a test case. Stop pytest
# from collecting it because of its `test_` prefix; it has no fixtures for
# its arguments.
test_logical_op.__test__ = False
52
53
class TfLogicalOpsTest(test.TestCase):
  """JIT-compiles rank-1 TF logical ops and compares them with NumPy."""

  def _check_1d(self, op_name, reference_fn):
    """Runs `test_logical_op` on the rank-1 module for TF op `op_name`."""
    test_logical_op(
        mlir_blob=logical_op_1d(op_name), reference_fn=reference_fn, rank=1)

  def test_logical_and_1d(self):
    """`tf.LogicalAnd` must agree with `np.logical_and`."""
    self._check_1d("LogicalAnd", np.logical_and)

  def test_logical_or_1d(self):
    """`tf.LogicalOr` must agree with `np.logical_or`."""
    self._check_1d("LogicalOr", np.logical_or)
61
62
if __name__ == "__main__":
  # Fix NumPy's global RNG so the randomly drawn shapes and boolean inputs
  # are the same on every run of this script.
  np.random.seed(seed=0)
  test.main()
65  test.main()
66