xref: /aosp_15_r20/external/pytorch/test/onnx/test_pytorch_jit_onnx.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2import onnxruntime
3import pytorch_test_common
4from pytorch_test_common import skipIfNoCuda
5
6import torch
7from torch.onnx import verification
8from torch.onnx._globals import GLOBALS
9from torch.testing._internal import common_utils
10
11
12def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version):
13    r"""
14    This function exports torch::jit::Graph object
15    to serialized ONNX ModelProto.
16    This function is for testing purpose.
17    It only keeps the essential parts for IR graph conversions.
18    It also does not interact with actual PyTorch modules nor
19    PyTorch tensor inputs.
20    """
21
22    GLOBALS.export_onnx_opset_version = opset_version
23    graph = torch.onnx.utils._optimize_graph(
24        graph, operator_export_type, params_dict={}
25    )
26    proto, _, _, _ = graph._export_onnx(
27        {},
28        opset_version,
29        {},
30        False,
31        operator_export_type,
32        False,
33        False,
34        {},
35        True,
36        "",
37        {},
38    )
39    return proto
40
41
class _TestJITIRToONNX:
    """Abstract base class for test cases.

    Intentionally not a sub-class of unittest.TestCase so that unittest / pytest
    don't run it directly. Concrete test classes are created by MakeTestCase(),
    which copies this class's attributes into a new
    pytorch_test_common.ExportTestCase subclass with ``opset_version`` set.
    """

    opset_version = -1  # Sub-classes must override
    ort_providers = ["CPUExecutionProvider"]
    check_shape = True
    check_dtype = True
    ignore_none = True  # True for tracing, and False for scripting

    def run_test(self, graph_ir, example_inputs, parse_tensor_constants=False):
        """Check that ``graph_ir`` produces the same outputs in the JIT and in ONNX Runtime.

        The IR text is parsed into a graph and run with the JIT interpreter to
        obtain reference outputs. The same graph is exported to ONNX at
        ``self.opset_version``, executed by ONNX Runtime with
        ``self.ort_providers``, and both sets of outputs are handed to
        ``verification._compare_onnx_pytorch_outputs`` (rtol=1e-3, atol=1e-7,
        plus the class-level shape / dtype / None options).

        Args:
            graph_ir: TorchScript IR text accepted by ``torch._C.parse_ir``.
            example_inputs: Tuple of inputs fed to both the JIT interpreter and
                ONNX Runtime.
            parse_tensor_constants: Forwarded to ``torch._C.parse_ir``;
                test_where_constants sets it for its ``{0.}`` tensor constant.
        """
        graph = torch._C.parse_ir(graph_ir, parse_tensor_constants)
        jit_outs = torch._C._jit_interpret_graph(graph, example_inputs)

        onnx_proto = _jit_graph_to_onnx_model(
            graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version
        )
        ort_sess = onnxruntime.InferenceSession(
            onnx_proto, providers=self.ort_providers
        )
        ort_outs = verification._run_onnx(ort_sess, example_inputs)

        options = verification.VerificationOptions(
            rtol=1e-3,
            atol=1e-7,
            check_shape=self.check_shape,
            check_dtype=self.check_dtype,
            ignore_none=self.ignore_none,
            acceptable_error_percentage=None,
        )
        verification._compare_onnx_pytorch_outputs(
            ort_outs,
            jit_outs,
            options,
        )

    def test_example_ir(self):
        graph_ir = """
        graph(%1 : Float(2, 3),
              %2 : Float(2, 3)):
          %3 : int = prim::Constant[value=1]()
          %4 : Float(2, 3) = aten::add(%1, %2, %3)
          return (%4)
        """
        a = torch.randn(2, 3)
        b = torch.randn(2, 3)
        self.run_test(graph_ir, (a, b))

    def test_where_constants(self):
        # The third operand of aten::where is a Double scalar tensor constant,
        # which is why parse_tensor_constants=True is passed below.
        graph_ir = """
        graph(%0 : Bool(8, device=cpu),
              %1 : Float(8, device=cpu)):
          %3 : Double(device=cpu) = prim::Constant[value={0.}]()
          %4 : Float(8) = aten::where(%0, %1, %3)
          return (%4)
        """
        a = torch.zeros(8, dtype=torch.bool)
        b = torch.zeros(8)
        self.run_test(graph_ir, (a, b), parse_tensor_constants=True)

    def test_add_sub_with_graph_inputs(self):
        # The alpha argument (%3) is a graph input rather than a constant.
        for op in ["add", "sub", "rsub"]:
            graph_ir = f"""
            graph(%1 : Float(2, 3),
                  %2 : Float(2, 3),
                  %3 : int):
              %4 : Float(2, 3) = aten::{op}(%1, %2, %3)
              return (%4)
            """
            a = torch.randn(2, 3)
            b = torch.randn(2, 3)
            # subTest so a failure reports which op broke.
            with self.subTest(op=op):
                self.run_test(graph_ir, (a, b, 2))

    def test_native_layer_norm(self):
        graph_ir = """
        graph(%x : Float(2, 3, 2),
              %w : Float(3, 2),
              %b : Float(3, 2)):
          %5 : int = prim::Constant[value=3]()
          %6 : int = prim::Constant[value=2]()
          %7 : int[] = prim::ListConstruct(%5, %6)
          %10 : float = prim::Constant[value=1.0000000000000001e-05]()
          %11 : Float(2, 3, 2), %12 : Float(2, 1, 1), %13 : Float(2, 1, 1) = aten::native_layer_norm(%x, %7, %w, %b, %10)
          return (%11, %12, %13)
        """
        x = torch.randn(2, 3, 2)
        w = torch.randn(3, 2)
        b = torch.randn(3, 2)
        self.run_test(graph_ir, (x, w, b))

    def test_convolution(self):
        # No bias, stride/dilation [1, 1], padding/output_padding [0, 0],
        # non-transposed, groups=1.
        graph_ir = """
        graph(%1 : Tensor,
              %2 : Tensor):
          %3 : NoneType = prim::Constant()
          %4 : int[] = prim::Constant[value=[1, 1]]()
          %5 : int[] = prim::Constant[value=[0, 0]]()
          %6 : bool = prim::Constant[value=0]()
          %7 : int = prim::Constant[value=1]()
          %8 : Tensor = aten::convolution(%1, %2, %3, %4, %5, %4, %6, %5, %7)
          return (%8)
        """
        x = torch.randn(8, 1, 5, 5)
        w = torch.randn(4, 1, 3, 3)
        self.run_test(graph_ir, (x, w))

    def test_log_softmax(self):
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=0]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2)
        self.run_test(graph_ir, (x,))

    @skipIfNoCuda
    def test_log_softmax_half_to_float(self):
        # half_to_float=1 with a float16 CUDA input.
        graph_ir = """
        graph(%x: Tensor):
          %half_to_float: bool = prim::Constant[value=1]()
          %dim: int = prim::Constant[value=1]()
          %y = aten::_log_softmax(%x, %dim, %half_to_float)
          return (%y)
        """
        x = torch.randn(5, 2).half().to("cuda")
        self.run_test(graph_ir, (x,))

    def test_native_dropout(self):
        # Dropout probability is 0.0 with training=1, so the result should not
        # depend on the random mask.
        graph_ir = """
        graph(%1 : Float(2, 3)):
          %2 : float = prim::Constant[value=0.0]()
          %training : bool = prim::Constant[value=1]()
          %3 : Tensor, %4 : Tensor = aten::native_dropout(%1, %2, %training)
          return (%3, %4)
        """
        a = torch.randn(2, 3)
        self.run_test(graph_ir, (a,))
185
186
187def MakeTestCase(opset_version: int) -> type:
188    name = f"TestJITIRToONNX_opset{opset_version}"
189    return type(
190        str(name),
191        (pytorch_test_common.ExportTestCase,),
192        dict(_TestJITIRToONNX.__dict__, opset_version=opset_version),
193    )
194
195
# Concrete test class collected by unittest / pytest; exports at ONNX opset 14.
TestJITIRToONNX_opset14 = MakeTestCase(opset_version=14)

if __name__ == "__main__":
    common_utils.run_tests()
200