1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch.testing import FileCheck 5from torch.testing._internal.jit_utils import JitTestCase 6 7 8if __name__ == "__main__": 9 raise RuntimeError( 10 "This test file is not meant to be run directly, use:\n\n" 11 "\tpython test/test_jit.py TESTNAME\n\n" 12 "instead." 13 ) 14 15 16class TestOpDecompositions(JitTestCase): 17 def test_op_decomposition(self): 18 def foo(x): 19 return torch.var(x, unbiased=True) 20 21 # TODO: more robust testing 22 foo_s = torch.jit.script(foo) 23 FileCheck().check("aten::var").run(foo_s.graph) 24 torch._C._jit_pass_run_decompositions(foo_s.graph) 25 inp = torch.rand([10, 10]) 26 self.assertEqual(foo(inp), foo_s(inp)) 27 FileCheck().check_not("aten::var").run(foo_s.graph) 28 29 def test_registered_decomposition(self): 30 @torch.jit.script 31 def foo(x): 32 return torch.square(x) 33 34 @torch.jit.script 35 def square_decomp(x): 36 return torch.pow(x, 2) 37 38 torch.jit._register_decomposition( 39 torch.ops.aten.square.default, square_decomp.graph 40 ) 41 torch._C._jit_pass_run_decompositions(foo.graph) 42 FileCheck().check_not("aten::square").check("aten::pow").run(foo.graph) 43 x = torch.rand([4]) 44 self.assertEqual(foo(x), torch.square(x)) 45