xref: /aosp_15_r20/external/pytorch/test/jit/test_aten_pow.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import torch
4from torch.testing._internal.common_utils import TestCase
5
6
7class TestAtenPow(TestCase):
8    def test_aten_pow_zero_negative_exponent(self):
9        """
10        1. Testing a = int, b = int
11        """
12
13        @torch.jit.script
14        def fn_int_int(a: int, b: int):
15            return a**b
16
17        # Existing correct behaviors of aten::pow
18        self.assertEqual(fn_int_int(2, 1), 2**1)
19        self.assertEqual(fn_int_int(2, 0), 2**0)
20        self.assertEqual(fn_int_int(2, -2), 2 ** (-2))
21        self.assertEqual(fn_int_int(-2, 2), (-2) ** 2)
22        self.assertEqual(fn_int_int(-2, 0), (-2) ** 0)
23        self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2))
24        self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1))
25        self.assertEqual(fn_int_int(0, 2), 0**1)
26        self.assertEqual(fn_int_int(0, 0), 0**0)
27        # zero base and negative exponent case that should trigger RunTimeError
28        self.assertRaises(RuntimeError, fn_int_int, 0, -2)
29
30        """
31        2. Testing a = int, b = float
32        """
33
34        @torch.jit.script
35        def fn_int_float(a: int, b: float):
36            return a**b
37
38        # Existing correct behaviors of aten::pow
39        self.assertEqual(fn_int_float(2, 2.5), 2**2.5)
40        self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5))
41        self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0))
42        self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0))
43        self.assertEqual(fn_int_float(-2, 2.0), (-2) ** 2.0)
44        self.assertEqual(fn_int_float(-2, -2.0), (-2) ** (-2.0))
45        self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0))
46        self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0))
47        self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0))
48        self.assertEqual(fn_int_float(0, 2.0), 0**2.0)
49        self.assertEqual(fn_int_float(0, 0.5), 0**0.5)
50        self.assertEqual(fn_int_float(0, 0.0), 0**0.0)
51        self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0))
52        # zero base and negative exponent case that should trigger RunTimeError
53        self.assertRaises(RuntimeError, fn_int_float, 0, -2.5)
54
55        """
56        3. Testing a = float, b = int
57        """
58
59        @torch.jit.script
60        def fn_float_int(a: float, b: int):
61            return a**b
62
63        # Existing correct behaviors of aten::pow
64        self.assertEqual(fn_float_int(2.5, 2), 2.5**2)
65        self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2))
66        self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0))
67        self.assertEqual(fn_float_int(2.5, 0), 2.5**0)
68        self.assertEqual(fn_float_int(-2.5, 2), 2.5**2)
69        self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2))
70        self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3))
71        self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0))
72        self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0)
73        self.assertEqual(fn_float_int(0.0, 2), 0**2)
74        self.assertEqual(fn_float_int(0.0, 0), 0**0)
75        self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0))
76        # zero base and negative exponent case that should trigger RunTimeError
77        self.assertRaises(RuntimeError, fn_float_int, 0.0, -2)
78
79        """
80        4. Testing a = float, b = float
81        """
82
83        @torch.jit.script
84        def fn_float_float(a: float, b: float):
85            return a**b
86
87        # Existing correct behaviors of aten::pow
88        self.assertEqual(fn_float_float(2.5, 2.0), 2.5**2.0)
89        self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0))
90        self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0))
91        self.assertEqual(fn_float_float(2.5, 0.0), 2.5**0.0)
92        self.assertEqual(fn_float_float(-2.5, 2.0), 2.5**2.0)
93        self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0))
94        self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0))
95        self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0))
96        self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0)
97        self.assertEqual(fn_float_float(0.0, 2.0), 0.0**2.0)
98        self.assertEqual(fn_float_float(0.0, 0.0), 0.0**0.0)
99        self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0))
100        # zero base and negative exponent case that should trigger RunTimeError
101        self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0)
102