1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch.testing._internal.common_utils import TestCase 5 6 7class TestAtenPow(TestCase): 8 def test_aten_pow_zero_negative_exponent(self): 9 """ 10 1. Testing a = int, b = int 11 """ 12 13 @torch.jit.script 14 def fn_int_int(a: int, b: int): 15 return a**b 16 17 # Existing correct behaviors of aten::pow 18 self.assertEqual(fn_int_int(2, 1), 2**1) 19 self.assertEqual(fn_int_int(2, 0), 2**0) 20 self.assertEqual(fn_int_int(2, -2), 2 ** (-2)) 21 self.assertEqual(fn_int_int(-2, 2), (-2) ** 2) 22 self.assertEqual(fn_int_int(-2, 0), (-2) ** 0) 23 self.assertEqual(fn_int_int(-2, -2), (-2) ** (-2)) 24 self.assertEqual(fn_int_int(-2, -1), (-2) ** (-1)) 25 self.assertEqual(fn_int_int(0, 2), 0**1) 26 self.assertEqual(fn_int_int(0, 0), 0**0) 27 # zero base and negative exponent case that should trigger RunTimeError 28 self.assertRaises(RuntimeError, fn_int_int, 0, -2) 29 30 """ 31 2. Testing a = int, b = float 32 """ 33 34 @torch.jit.script 35 def fn_int_float(a: int, b: float): 36 return a**b 37 38 # Existing correct behaviors of aten::pow 39 self.assertEqual(fn_int_float(2, 2.5), 2**2.5) 40 self.assertEqual(fn_int_float(2, -2.5), 2 ** (-2.5)) 41 self.assertEqual(fn_int_float(2, -0.0), 2 ** (-0.0)) 42 self.assertEqual(fn_int_float(2, 0.0), 2 ** (0.0)) 43 self.assertEqual(fn_int_float(-2, 2.0), (-2) ** 2.0) 44 self.assertEqual(fn_int_float(-2, -2.0), (-2) ** (-2.0)) 45 self.assertEqual(fn_int_float(-2, -3.0), (-2) ** (-3.0)) 46 self.assertEqual(fn_int_float(-2, -0.0), (-2) ** (-0.0)) 47 self.assertEqual(fn_int_float(-2, 0.0), (-2) ** (0.0)) 48 self.assertEqual(fn_int_float(0, 2.0), 0**2.0) 49 self.assertEqual(fn_int_float(0, 0.5), 0**0.5) 50 self.assertEqual(fn_int_float(0, 0.0), 0**0.0) 51 self.assertEqual(fn_int_float(0, -0.0), 0 ** (-0.0)) 52 # zero base and negative exponent case that should trigger RunTimeError 53 self.assertRaises(RuntimeError, fn_int_float, 0, -2.5) 54 55 """ 56 3. Testing a = float, b = int 57 """ 58 59 @torch.jit.script 60 def fn_float_int(a: float, b: int): 61 return a**b 62 63 # Existing correct behaviors of aten::pow 64 self.assertEqual(fn_float_int(2.5, 2), 2.5**2) 65 self.assertEqual(fn_float_int(2.5, -2), 2.5 ** (-2)) 66 self.assertEqual(fn_float_int(2.5, -0), 2.5 ** (-0)) 67 self.assertEqual(fn_float_int(2.5, 0), 2.5**0) 68 self.assertEqual(fn_float_int(-2.5, 2), 2.5**2) 69 self.assertEqual(fn_float_int(-2.5, -2), (-2.5) ** (-2)) 70 self.assertEqual(fn_float_int(-2.5, -3), (-2.5) ** (-3)) 71 self.assertEqual(fn_float_int(-2.5, -0), (-2.5) ** (-0)) 72 self.assertEqual(fn_float_int(-2.5, 0), (-2.5) ** 0) 73 self.assertEqual(fn_float_int(0.0, 2), 0**2) 74 self.assertEqual(fn_float_int(0.0, 0), 0**0) 75 self.assertEqual(fn_float_int(0.0, -0), 0 ** (-0)) 76 # zero base and negative exponent case that should trigger RunTimeError 77 self.assertRaises(RuntimeError, fn_float_int, 0.0, -2) 78 79 """ 80 4. Testing a = float, b = float 81 """ 82 83 @torch.jit.script 84 def fn_float_float(a: float, b: float): 85 return a**b 86 87 # Existing correct behaviors of aten::pow 88 self.assertEqual(fn_float_float(2.5, 2.0), 2.5**2.0) 89 self.assertEqual(fn_float_float(2.5, -2.0), 2.5 ** (-2.0)) 90 self.assertEqual(fn_float_float(2.5, -0.0), 2.5 ** (-0.0)) 91 self.assertEqual(fn_float_float(2.5, 0.0), 2.5**0.0) 92 self.assertEqual(fn_float_float(-2.5, 2.0), 2.5**2.0) 93 self.assertEqual(fn_float_float(-2.5, -2.0), (-2.5) ** (-2.0)) 94 self.assertEqual(fn_float_float(-2.5, -3.0), (-2.5) ** (-3.0)) 95 self.assertEqual(fn_float_float(-2.5, -0.0), (-2.5) ** (-0.0)) 96 self.assertEqual(fn_float_float(-2.5, 0.0), (-2.5) ** 0.0) 97 self.assertEqual(fn_float_float(0.0, 2.0), 0.0**2.0) 98 self.assertEqual(fn_float_float(0.0, 0.0), 0.0**0.0) 99 self.assertEqual(fn_float_float(0.0, -0.0), 0.0 ** (-0.0)) 100 # zero base and negative exponent case that should trigger RunTimeError 101 self.assertRaises(RuntimeError, fn_float_float, 0.0, -2.0) 102