# Owner(s): ["module: onnx"]

import onnx_test_common
import pytorch_test_common

import torch
import torch.utils.cpp_extension
from torch.onnx import symbolic_helper
from torch.testing._internal import common_utils


class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
    """ONNX export tests for modules that call custom ``torch.autograd.Function``s."""

    opset_version = 9
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_symbolic(self):
        """A Function that defines its own ``symbolic`` staticmethod exports via it."""

        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

            @staticmethod
            def symbolic(g, input, scalar):
                # Lower to an ONNX Clip node with `scalar` as the float `min` attribute.
                return g.op("Clip", input, min_f=scalar)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply

            def forward(self, x):
                h = self.clip(x, 2)
                return h

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))

    def test_register_op(self):
        """Functions without ``symbolic`` export through a ``prim::PythonOp`` override."""

        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

        class MyRelu(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply
                self.relu = MyRelu.apply

            def forward(self, x):
                h = self.clip(x, 2)
                h = self.relu(h)
                return h

        def symbolic_pythonop(g, *args, **kwargs):
            # kwargs["name"] is matched against the Function class names above
            # to choose which ONNX op to emit for each PythonOp node.
            name = kwargs["name"]
            if name == "MyClip":
                return g.op("Clip", args[0], min_f=args[1])
            elif name == "MyRelu":
                return g.op("Relu", args[0])
            else:
                return symbolic_helper._unimplemented(
                    "prim::PythonOp", "unknown node kind: " + name
                )

        from torch.onnx import (
            register_custom_op_symbolic,
            unregister_custom_op_symbolic,
        )

        register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
        # Custom symbolic registrations outlive this test method; remove the
        # override afterwards so it cannot leak into other tests in the process.
        self.addCleanup(unregister_custom_op_symbolic, "prim::PythonOp", 1)

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))


class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
    """ONNX export tests that map aten ops to ``com.microsoft`` contrib ops."""

    opset_version = 14
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_contrib_op_with_loop(self):
        """A scripted loop containing GELU exports with a custom contrib-op symbolic."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="none")

            def forward(self, x):
                res = []
                res2 = []
                for i in range(x.size(0)):
                    if len(res) > 0:
                        res2.append(res[0])
                    else:
                        res2.append(self.gelu(x[0]))
                    res.append(x[0])
                return torch.stack(res), torch.stack(res2)

        def symbolic_custom_gelu(g, input, approximate):
            # Emit com.microsoft::Gelu and copy the input's type onto the output.
            return g.op("com.microsoft::Gelu", input).setType(input.type())

        from torch.onnx import (
            register_custom_op_symbolic,
            unregister_custom_op_symbolic,
        )

        register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1)
        # Undo the override after the test; otherwise every later GELU export
        # in this process would be lowered to the contrib op.
        self.addCleanup(unregister_custom_op_symbolic, "::gelu", 1)

        x = torch.randn(3, 3, 4, requires_grad=True)
        model = torch.jit.script(M())
        onnx_test_common.run_model_test(self, model, input_args=(x,))


if __name__ == "__main__":
    common_utils.run_tests()