xref: /aosp_15_r20/external/pytorch/test/onnx/torch_export/test_torch_export_with_onnxruntime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: onnx"]
2from __future__ import annotations
3
4import os
5import sys
6
7import torch
8import torch.onnx
9from torch.testing._internal import common_utils
10from torch.utils import _pytree as torch_pytree
11
12
13sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14import onnx_test_common
15
16
class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
    """Tests exporting ``torch.export.ExportedProgram`` objects to ONNX.

    Each test builds an ``ExportedProgram`` with ``torch.export.export``,
    converts it with ``torch.onnx.dynamo_export``, then checks that the ONNX
    program and the torch program produce matching outputs. The check often
    uses inputs whose dynamic dimensions differ from the ones used at export
    time.
    """

    def _compare_onnx_and_torch_exported_program(
        self,
        torch_exported_program,
        onnx_exported_program,
        input_args,
        input_kwargs=None,
        rtol=1e-03,
        atol=1e-07,
    ):
        """Run both programs on the same inputs and assert their outputs match.

        Args:
            torch_exported_program: Reference program. A
                ``torch.export.ExportedProgram`` is run through ``.module()``.
                Any other value is called directly.
            onnx_exported_program: Program returned by
                ``torch.onnx.dynamo_export``, called with the same inputs.
            input_args: Positional inputs passed to both programs.
            input_kwargs: Keyword inputs passed to both programs. ``None``
                means no keyword inputs.
            rtol: Relative tolerance forwarded to ``torch.testing.assert_close``.
            atol: Absolute tolerance forwarded to ``torch.testing.assert_close``.

        Raises:
            AssertionError: If the number of outputs differs, or if any output
                pair is not close within ``rtol``/``atol``.
        """
        # Avoid a mutable default argument.
        if input_kwargs is None:
            input_kwargs = {}

        # NOTE: ONNXProgram holds a reference (not copy) to the original ref_model, including its state_dict.
        # Thus, ONNXProgram() must run before ref_model() to prevent ref_model.forward() from changing the state_dict.
        # Otherwise, the ref_model can change buffers on state_dict which would be used by ONNXProgram.__call__()
        onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
        if isinstance(torch_exported_program, torch.export.ExportedProgram):
            torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
        else:
            torch_outputs = torch_exported_program(*input_args, **input_kwargs)

        # Flatten the reference outputs into their leaves. A single tensor,
        # a tuple/list, or a nested container can then all be compared
        # pairwise against the ONNX outputs, which this helper treats as a
        # flat sequence. A bare tensor becomes [tensor], exactly as before.
        torch_outputs = torch_pytree.tree_leaves(torch_outputs)

        if len(torch_outputs) != len(onnx_outputs):
            raise AssertionError(
                f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}"
            )
        for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
            torch.testing.assert_close(
                torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
            )

    def test_exported_program_with_dynamic_input(self):
        # Export with a dynamic leading dim, then run with a different size.
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        x = torch.randn(2, 3, 4, dtype=torch.float)
        dim0 = torch.export.Dim("dim0")
        exported_program = torch.export.export(
            Model(), (x,), dynamic_shapes={"x": {0: dim0}}
        )
        onnx_program = torch.onnx.dynamo_export(exported_program, x)

        # different dim inputs
        y = torch.randn(3, 3, 4, dtype=torch.float)
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(y,)
        )

    def test_exported_program_as_input_from_file(self):
        # Round-trip the ExportedProgram through torch.export.save/load and
        # compare the reloaded program against the ONNX program.
        import tempfile

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        x = torch.randn(1, 1, 2, dtype=torch.float)
        exported_program = torch.export.export(Model(), args=(x,))
        onnx_program = torch.onnx.dynamo_export(exported_program, x)

        # Write into a temporary *directory* instead of reusing the name of a
        # still-open NamedTemporaryFile. Reopening an open NamedTemporaryFile
        # by name is not supported on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = os.path.join(tmp_dir, "exported_program.pt2")
            torch.export.save(exported_program, path)
            # Delete the in-memory program so the comparison below can only
            # use the copy loaded back from disk.
            del exported_program
            loaded_exported_program = torch.export.load(path)

        self._compare_onnx_and_torch_exported_program(
            loaded_exported_program, onnx_program, input_args=(x,)
        )

    def test_exported_program_with_specialized_input_during_tracing(self):
        # A non-tensor argument (y=5) is fixed at trace time while x keeps a
        # dynamic leading dim (min=6).
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        f = Foo()

        tensor_input = torch.ones(7, 5)
        dim0_x = torch.export.Dim("dim0_x", min=6)
        dynamic_shapes = {"x": {0: dim0_x}, "y": None}
        # specialized input y to 5 during tracing
        exported_program = torch.export.export(
            f, (tensor_input, 5), dynamic_shapes=dynamic_shapes
        )
        onnx_program = torch.onnx.dynamo_export(exported_program, tensor_input, 5)

        # different dim inputs
        additional_tensor_input = torch.ones(8, 5)
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(additional_tensor_input, 5)
        )

    def test_onnx_program_supports_retraced_graph(self):
        # Export a module that mutates buffers (including a submodule's
        # buffer), re-export the resulting .module() with a dynamic dim,
        # and export that program to ONNX.
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(1))

            def forward(self, x):
                self.buf.add_(1)
                return x.sum() + self.buf.sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.zeros(1))
                self.bar = Bar()

            def forward(self, x):
                self.buf.add_(1)
                bar = self.bar(x)
                self.bar.buf.add_(2)
                return bar.sum() + self.buf.sum()

        tensor_input = torch.ones(5, 5)
        exported_program = torch.export.export(Foo(), (tensor_input,))

        dim0_x = torch.export.Dim("dim0_x")
        # NOTE: When re-exporting the module obtained from an ExportedProgram
        # (exported_program.module()), dynamic_shapes is given positionally
        # as a tuple.
        reexported_program = torch.export.export(
            exported_program.module(), (tensor_input,), dynamic_shapes=({0: dim0_x},)
        )
        reexported_onnx_program = torch.onnx.dynamo_export(
            reexported_program, tensor_input
        )

        additional_tensor_input = torch.ones(7, 5)
        self._compare_onnx_and_torch_exported_program(
            reexported_program,
            reexported_onnx_program,
            input_args=(additional_tensor_input,),
        )

    def test_onnx_program_supports_none_arg_name_in_dynamic(self):
        # dynamic_shapes given as a tuple where the first argument is static
        # (None) and only the second has a dynamic leading dim.
        class Foo(torch.nn.Module):
            def forward(self, a, b):
                return a.sum() + b.sum()

        foo = Foo()

        dim = torch.export.Dim("dim")
        exported_program = torch.export.export(
            foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, {0: dim})
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program, torch.randn(4, 4), torch.randn(4, 4)
        )

        test_inputs = (
            torch.randn(4, 4),
            torch.randn(7, 4),
        )
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs
        )

    # NOTE(review): "suppors" is a typo. The method name is kept unchanged so
    # existing test IDs (and any skip/xfail lists keyed on them) still match.
    def test_onnx_program_suppors_non_arg_name_with_kwarg(self):
        # Mix of positional args and kwargs, with dynamic_shapes given as a
        # tuple in signature order (a, b, kw1, kw2).
        class Foo(torch.nn.Module):
            def forward(self, a, b, kw1, kw2):
                return a.sum() + b.sum() + kw1.sum() - kw2.sum()

        foo = Foo()

        dim = torch.export.Dim("dim")
        dim_for_kw1 = torch.export.Dim("dim_for_kw1")
        exported_program = torch.export.export(
            foo,
            (torch.randn(4, 4), torch.randn(4, 4)),
            {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)},
            # We are specifying dynamism on the first kwarg even though user passed in
            # different order
            dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program,
            torch.randn(4, 4),
            torch.randn(4, 4),
            kw2=torch.ones(4, 4),
            kw1=torch.zeros(4, 4),
        )

        test_inputs = (torch.randn(4, 4), torch.randn(7, 4))
        test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)}
        # This should work even if the kwarg order are flipped.
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs, test_kwargs
        )

    def test_exported_program_as_input_lifting_buffers_mutation(self):
        # Buffer mutation during forward, checked for both persistent and
        # non-persistent buffers. Each variant runs as its own subTest so a
        # failure reports which one broke.
        for persistent in (True, False):
            with self.subTest(persistent=persistent):

                class CustomModule(torch.nn.Module):
                    def __init__(self) -> None:
                        super().__init__()
                        self.register_buffer(
                            "my_buffer", torch.tensor(4.0), persistent=persistent
                        )

                    def forward(self, x, b):
                        output = x + b
                        # Mutate the buffer through in-place addition. The
                        # "+ 3.0" result is intentionally discarded.
                        _ = self.my_buffer.add_(1.0) + 3.0
                        return output

                input_x = torch.rand((3, 3), dtype=torch.float32)
                input_b = torch.randn(3, 3)
                model = CustomModule()

                dim = torch.export.Dim("dim")
                exported_program = torch.export.export(
                    model,
                    (
                        input_x,
                        input_b,
                    ),
                    dynamic_shapes=({0: dim}, {0: dim}),
                )
                onnx_program = torch.onnx.dynamo_export(
                    exported_program, input_x, input_b
                )

                # different dim inputs
                additional_inputs_x = torch.rand((4, 3), dtype=torch.float32)
                additional_inputs_b = torch.randn(4, 3)
                self._compare_onnx_and_torch_exported_program(
                    exported_program,
                    onnx_program,
                    (
                        additional_inputs_x,
                        additional_inputs_b,
                    ),
                )

    def test_onnx_program_supports_non_arg_name_with_container_type(self):
        # A positional argument that is itself a tuple of tensors, where
        # only one tensor inside the tuple is dynamic.
        class Foo(torch.nn.Module):
            def forward(self, a, b):
                return a[0].sum() + a[1].sum() + b.sum()

        foo = Foo()

        inp_a = (torch.randn(4, 4), torch.randn(4, 4))
        inp_b = torch.randn(4, 4)
        inp = (inp_a, inp_b)

        # dynamic_shapes mirrors the input structure ((a[0], a[1]), b).
        # Only dim 0 of the second tensor in ``a`` (a[1]) is marked dynamic.
        dim = torch.export.Dim("dim", min=3)
        dynamic_shapes = ((None, {0: dim}), None)
        exported_program = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
        onnx_program = torch.onnx.dynamo_export(exported_program, inp_a, inp_b)

        # NOTE: Careful with the input format. The input format should be
        # consistent with how the model is exported.
        test_inputs = ((torch.randn(4, 4), torch.randn(6, 4)), torch.randn(4, 4))
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs
        )

    def test_onnx_program_supports_lazy_module_kwargs(self):
        # A LazyModuleMixin module (with a no-op initialize_parameters)
        # exported with keyword-only inputs.
        class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, *args, **kwargs):
                pass

            def forward(self, x, y):
                return x + y

        m = LazyModule()
        dim = torch.export.Dim("dim")
        dynamic_shapes = ({0: dim}, {0: dim})
        exported_program = torch.export.export(
            m,
            (),
            {"x": torch.randn(3, 3), "y": torch.randn(3, 3)},
            dynamic_shapes=dynamic_shapes,
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program, x=torch.randn(3, 3), y=torch.randn(3, 3)
        )

        # NOTE: The model must be fed inputs in the same format (here:
        # keyword arguments) that was used when it was exported.
        inputs = {"x": torch.randn(6, 3), "y": torch.randn(6, 3)}
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(), input_kwargs=inputs
        )
316
317
318if __name__ == "__main__":
319    common_utils.run_tests()
320