xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/python_tests/tf_log1p_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.Log1p JIT compilation."""
16
17import numpy as np
18
19from tensorflow.compiler.mlir.tfrt.jit.python_binding import tf_jitrt
20from tensorflow.python.platform import test
21
# JIT specialization modes; test_log1p compiles and checks the kernel once per
# mode, in this order.
specializations = [
    getattr(tf_jitrt.Specialization, mode_name)
    for mode_name in ("ENABLED", "DISABLED", "ALWAYS")
]
27
28
def log1p_1d():
  """Returns MLIR text for a `log1p` function over a dynamic rank-1 f32 tensor."""
  mlir_module = """
  func.func @log1p(%arg0: tensor<?xf32>) -> tensor<?xf32> {
    %0 = "tf.Log1p"(%arg0): (tensor<?xf32>) -> tensor<?xf32>
    func.return %0 : tensor<?xf32>
  }"""
  return mlir_module
35
36
def log1p_2d():
  """Returns MLIR text for a `log1p` function over a dynamic rank-2 f32 tensor."""
  mlir_module = """
  func.func @log1p(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %0 = "tf.Log1p"(%arg0): (tensor<?x?xf32>) -> tensor<?x?xf32>
    func.return %0 : tensor<?x?xf32>
  }"""
  return mlir_module
43
44
# Executor that test_log1p uses to compile and execute the kernels.
jitrt = tf_jitrt.TfJitRtExecutor()


def test_log1p(fn, rank):
  """Compiles an MLIR tf.Log1p kernel and checks its results against np.log1p.

  The kernel is compiled once per entry in `specializations`. Each compiled
  kernel is then executed 100 times on float32 inputs:
  - Every dimension size is drawn from [0, 10), so empty tensors can occur.
  - Values are drawn uniformly from [0, 10).

  Args:
    fn: Zero-argument callable returning MLIR module text. The function named
      `log1p` in that module is the one compiled.
    rank: Number of dimensions of the randomly generated inputs; it should
      match the rank of the `log1p` argument in the module returned by `fn`.
  """
  for specialize in specializations:
    compiled = jitrt.compile(fn(), "log1p", specialize)

    for _ in range(100):
      # `size=rank` yields a 1-D array holding `rank` dimension sizes.
      shape = np.random.randint(0, 10, size=rank)
      arg = np.random.uniform(0, 10.0, size=shape).astype(np.float32)

      [res] = jitrt.execute(compiled, [arg])
      # Name the specialization mode and input shape in the failure message;
      # without it, a mismatch could have come from any mode or shape.
      np.testing.assert_allclose(
          res,
          np.log1p(arg),
          atol=1e-06,
          err_msg=f"specialization={specialize}, shape={arg.shape}",
      )
58
59
class TfLog1PTest(test.TestCase):
  """Checks JIT-compiled tf.Log1p on rank-1 and rank-2 inputs."""

  def test_1d(self):
    """Rank-1 kernel over random 1-D float32 inputs."""
    test_log1p(fn=log1p_1d, rank=1)

  def test_2d(self):
    """Rank-2 kernel over random 2-D float32 inputs."""
    test_log1p(fn=log1p_2d, rank=2)
67
68
if __name__ == "__main__":
  # Seed NumPy's global RNG so the random shapes and values drawn in
  # test_log1p are the same on every direct run of this file.
  np.random.seed(seed=0)
  test.main()
72