xref: /aosp_15_r20/external/pytorch/test/onnx/test_custom_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2
3import onnx_test_common
4import pytorch_test_common
5
6import torch
7import torch.utils.cpp_extension
8from torch.onnx import symbolic_helper
9from torch.testing._internal import common_utils
10
11
class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
    """ONNX export tests for custom ``torch.autograd.Function`` subclasses.

    Covers two ways of giving an autograd Function an ONNX translation:

    * a ``symbolic`` staticmethod defined on the Function itself, and
    * a symbolic registered for ``prim::PythonOp`` through
      ``torch.onnx.register_custom_op_symbolic``.
    """

    # Export configuration for this test case; presumably read by
    # onnx_test_common.run_model_test from the test instance (TODO confirm).
    opset_version = 9
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_symbolic(self):
        """Export a Function whose own ``symbolic`` method emits ONNX ``Clip``."""

        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

            @staticmethod
            def symbolic(g, input, scalar):
                # ``min_f`` sets ``scalar`` as the float attribute ``min`` of
                # the emitted Clip node.
                return g.op("Clip", input, min_f=scalar)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply

            def forward(self, x):
                h = self.clip(x, 2)
                return h

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))

    def test_register_op(self):
        """Export symbolic-less Functions via a registered ``prim::PythonOp`` symbolic.

        ``MyClip`` and ``MyRelu`` define no ``symbolic`` method. A single
        custom symbolic registered for ``prim::PythonOp`` therefore handles
        both and dispatches on the Function's ``name``. The registration is
        global to the exporter, so the test removes it again on teardown.
        Without that, it would leak into every later test in the process.
        """

        class MyClip(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, scalar):
                ctx.save_for_backward(input)
                return input.clamp(min=scalar)

        class MyRelu(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                return input.clamp(min=0)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.clip = MyClip.apply
                self.relu = MyRelu.apply

            def forward(self, x):
                h = self.clip(x, 2)
                h = self.relu(h)
                return h

        def symbolic_pythonop(g, *args, **kwargs):
            # ``name`` identifies which autograd Function produced this node.
            name = kwargs["name"]
            if name == "MyClip":
                return g.op("Clip", args[0], min_f=args[1])
            elif name == "MyRelu":
                return g.op("Relu", args[0])
            else:
                return symbolic_helper._unimplemented(
                    "prim::PythonOp", "unknown node kind: " + name
                )

        from torch.onnx import (
            register_custom_op_symbolic,
            unregister_custom_op_symbolic,
        )

        register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
        # The registration mutates process-wide exporter state. Undo it after
        # the test, even when the export below fails, so other tests see the
        # default prim::PythonOp handling.
        self.addCleanup(unregister_custom_op_symbolic, "prim::PythonOp", 1)

        x = torch.randn(2, 3, 4, requires_grad=True)
        model = MyModule()
        onnx_test_common.run_model_test(self, model, input_args=(x,))
83
84
class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
    """ONNX export tests that lower ATen ops to ``com.microsoft`` contrib ops."""

    # Export configuration for this test case; presumably read by
    # onnx_test_common.run_model_test from the test instance (TODO confirm).
    opset_version = 14
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_contrib_op_with_loop(self):
        """Export a scripted loop whose body may call GELU as ``com.microsoft::Gelu``.

        The model is compiled with ``torch.jit.script``, so the Python loop
        and the ``if`` inside it are kept as control flow rather than
        unrolled. The custom ``::gelu`` symbolic is therefore exercised
        inside a loop body. The symbolic is registered globally, so it is
        unregistered again on teardown. Otherwise later GELU exports in the
        same process would also emit the contrib op.
        """

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gelu = torch.nn.GELU(approximate="none")

            def forward(self, x):
                res = []
                res2 = []
                for i in range(x.size(0)):
                    # Only the first iteration (res still empty) goes through
                    # GELU; later iterations reuse the stored res[0].
                    if len(res) > 0:
                        res2.append(res[0])
                    else:
                        res2.append(self.gelu(x[0]))
                    res.append(x[0])
                return torch.stack(res), torch.stack(res2)

        def symbolic_custom_gelu(g, input, approximate):
            # ``approximate`` is accepted to match the call but unused here.
            # The output type is copied from the input, presumably because
            # ONNX shape inference cannot infer it for the contrib op.
            return g.op("com.microsoft::Gelu", input).setType(input.type())

        from torch.onnx import (
            register_custom_op_symbolic,
            unregister_custom_op_symbolic,
        )

        register_custom_op_symbolic("::gelu", symbolic_custom_gelu, 1)
        # Remove the global override after the test, even on failure.
        self.addCleanup(unregister_custom_op_symbolic, "::gelu", 1)

        x = torch.randn(3, 3, 4, requires_grad=True)
        model = torch.jit.script(M())
        onnx_test_common.run_model_test(self, model, input_args=(x,))
117
118
def _run_all_tests() -> None:
    """Hand this module's test cases to PyTorch's shared test runner."""
    common_utils.run_tests()


if __name__ == "__main__":
    _run_all_tests()
121