# Owner(s): ["module: onnx"]
from __future__ import annotations

import os
import sys

import torch
import torch.onnx
from torch.testing._internal import common_utils
from torch.utils import _pytree as torch_pytree


# Make onnx_test_common importable when this file is run from the test directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import onnx_test_common


class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
    def _compare_onnx_and_torch_exported_program(
        self,
        torch_exported_program,
        onnx_exported_program,
        input_args,
        input_kwargs=None,
        rtol=1e-03,
        atol=1e-07,
    ):
        # avoid mutable default argument
        if input_kwargs is None:
            input_kwargs = {}

        # NOTE: The ONNX program holds a reference (not a copy) to the original
        # model, including its state_dict. It must therefore run before the
        # torch program; otherwise the torch program's forward() could mutate
        # buffers in the shared state_dict that ONNXProgram.__call__() reads.
        onnx_outputs = onnx_exported_program(*input_args, **input_kwargs)
        if isinstance(torch_exported_program, torch.export.ExportedProgram):
            torch_outputs = torch_exported_program.module()(*input_args, **input_kwargs)
        else:
            torch_outputs = torch_exported_program(*input_args, **input_kwargs)

        if isinstance(torch_outputs, torch.Tensor):
            torch_outputs = [torch_outputs]

        if len(torch_outputs) != len(onnx_outputs):
            raise AssertionError(
                f"Expected {len(torch_outputs)} outputs, got {len(onnx_outputs)}"
            )
        for torch_output, onnx_output in zip(torch_outputs, onnx_outputs):
            torch.testing.assert_close(
                torch_output, torch.tensor(onnx_output), rtol=rtol, atol=atol
            )
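
    # A minimal usage sketch of the helper above (the module and shapes are
    # hypothetical): export with torch.export.export, convert the resulting
    # ExportedProgram with torch.onnx.dynamo_export, then compare both
    # programs on fresh inputs:
    #
    #   ep = torch.export.export(Model(), (torch.randn(2, 3),))
    #   onnx_prog = torch.onnx.dynamo_export(ep, torch.randn(2, 3))
    #   self._compare_onnx_and_torch_exported_program(
    #       ep, onnx_prog, input_args=(torch.randn(2, 3),)
    #   )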

    def test_exported_program_with_dynamic_input(self):
        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        x = torch.randn(2, 3, 4, dtype=torch.float)
        dim0 = torch.export.Dim("dim0")
        exported_program = torch.export.export(
            Model(), (x,), dynamic_shapes={"x": {0: dim0}}
        )
        onnx_program = torch.onnx.dynamo_export(exported_program, x)

        # inputs with a different dim0
        y = torch.randn(3, 3, 4, dtype=torch.float)
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(y,)
        )

    def test_exported_program_as_input_from_file(self):
        import tempfile

        class Model(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        x = torch.randn(1, 1, 2, dtype=torch.float)
        exported_program = torch.export.export(Model(), args=(x,))
        onnx_program = torch.onnx.dynamo_export(exported_program, x)

        with tempfile.NamedTemporaryFile(suffix=".pte") as f:
            torch.export.save(exported_program, f.name)
            # Delete the in-memory program to ensure we are loading from file.
            del exported_program
            loaded_exported_program = torch.export.load(f.name)

        self._compare_onnx_and_torch_exported_program(
            loaded_exported_program, onnx_program, input_args=(x,)
        )

    def test_exported_program_with_specialized_input_during_tracing(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        f = Foo()

        tensor_input = torch.ones(7, 5)
        dim0_x = torch.export.Dim("dim0_x", min=6)
        dynamic_shapes = {"x": {0: dim0_x}, "y": None}
        # input y is specialized to 5 during tracing
        exported_program = torch.export.export(
            f, (tensor_input, 5), dynamic_shapes=dynamic_shapes
        )
        onnx_program = torch.onnx.dynamo_export(exported_program, tensor_input, 5)

        # inputs with a different dim0
        additional_tensor_input = torch.ones(8, 5)
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(additional_tensor_input, 5)
        )

    def test_onnx_program_supports_retraced_graph(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(1))

            def forward(self, x):
                self.buf.add_(1)
                return x.sum() + self.buf.sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.zeros(1))
                self.bar = Bar()

            def forward(self, x):
                self.buf.add_(1)
                bar = self.bar(x)
                self.bar.buf.add_(2)
                return bar.sum() + self.buf.sum()

        tensor_input = torch.ones(5, 5)
        exported_program = torch.export.export(Foo(), (tensor_input,))

        dim0_x = torch.export.Dim("dim0_x")
        # NOTE: If the input is an ExportedProgram, dynamic_shapes must be
        # specified as a tuple.
        reexported_program = torch.export.export(
            exported_program.module(), (tensor_input,), dynamic_shapes=({0: dim0_x},)
        )
        reexported_onnx_program = torch.onnx.dynamo_export(
            reexported_program, tensor_input
        )

        additional_tensor_input = torch.ones(7, 5)
        self._compare_onnx_and_torch_exported_program(
            reexported_program,
            reexported_onnx_program,
            input_args=(additional_tensor_input,),
        )

    def test_onnx_program_supports_none_arg_name_in_dynamic(self):
        class Foo(torch.nn.Module):
            def forward(self, a, b):
                return a.sum() + b.sum()

        foo = Foo()

        dim = torch.export.Dim("dim")
        exported_program = torch.export.export(
            foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, {0: dim})
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program, torch.randn(4, 4), torch.randn(4, 4)
        )

        test_inputs = (
            torch.randn(4, 4),
            torch.randn(7, 4),
        )
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs
        )

    def test_onnx_program_supports_non_arg_name_with_kwarg(self):
        class Foo(torch.nn.Module):
            def forward(self, a, b, kw1, kw2):
                return a.sum() + b.sum() + kw1.sum() - kw2.sum()

        foo = Foo()

        dim = torch.export.Dim("dim")
        dim_for_kw1 = torch.export.Dim("dim_for_kw1")
        exported_program = torch.export.export(
            foo,
            (torch.randn(4, 4), torch.randn(4, 4)),
            {"kw2": torch.ones(4, 4), "kw1": torch.zeros(4, 4)},
            # We specify dynamism on the first kwarg (kw1) even though the user
            # passed the kwargs in a different order.
            dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program,
            torch.randn(4, 4),
            torch.randn(4, 4),
            kw2=torch.ones(4, 4),
            kw1=torch.zeros(4, 4),
        )

        test_inputs = (torch.randn(4, 4), torch.randn(7, 4))
        test_kwargs = {"kw2": torch.ones(4, 4), "kw1": torch.zeros(9, 4)}
        # This should work even if the kwarg order is flipped.
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs, test_kwargs
        )
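
    # A sketch of the kwarg handling exercised by the test above: the
    # dynamic_shapes entries follow the order of the parameters in forward()'s
    # signature, not the order in which the caller passes the kwargs:
    #
    #   def forward(self, a, b, kw1, kw2): ...
    #   dynamic_shapes = (None, {0: dim}, {0: dim_for_kw1}, None)
    #   #                  a      b         kw1              kw2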

    def test_exported_program_as_input_lifting_buffers_mutation(self):
        for persistent in (True, False):

            class CustomModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.register_buffer(
                        "my_buffer", torch.tensor(4.0), persistent=persistent
                    )

                def forward(self, x, b):
                    output = x + b
                    (
                        self.my_buffer.add_(1.0) + 3.0
                    )  # Mutate buffer through in-place addition
                    return output

            input_x = torch.rand((3, 3), dtype=torch.float32)
            input_b = torch.randn(3, 3)
            model = CustomModule()

            dim = torch.export.Dim("dim")
            exported_program = torch.export.export(
                model,
                (
                    input_x,
                    input_b,
                ),
                dynamic_shapes=({0: dim}, {0: dim}),
            )
            onnx_program = torch.onnx.dynamo_export(exported_program, input_x, input_b)

            # inputs with a different dim0
            additional_inputs_x = torch.rand((4, 3), dtype=torch.float32)
            additional_inputs_b = torch.randn(4, 3)
            self._compare_onnx_and_torch_exported_program(
                exported_program,
                onnx_program,
                (
                    additional_inputs_x,
                    additional_inputs_b,
                ),
            )

    def test_onnx_program_supports_non_arg_name_with_container_type(self):
        class Foo(torch.nn.Module):
            def forward(self, a, b):
                return a[0].sum() + a[1].sum() + b.sum()

        foo = Foo()

        inp_a = (torch.randn(4, 4), torch.randn(4, 4))
        inp_b = torch.randn(4, 4)
        inp = (inp_a, inp_b)

        count = 0

        def dynamify_inp(x):
            # Mark the second tensor leaf, a[1], as dynamic
            nonlocal count
            if count == 1:
                dim = torch.export.Dim("dim", min=3)
                count += 1
                return {0: dim}
            count += 1
            return None

        dynamic_shapes = torch_pytree.tree_map(dynamify_inp, inp)
        exported_program = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
        onnx_program = torch.onnx.dynamo_export(exported_program, inp_a, inp_b)

        # NOTE: Be careful with the input format: test inputs must use the same
        # structure the model was exported with.
        test_inputs = ((torch.randn(4, 4), torch.randn(6, 4)), torch.randn(4, 4))
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, test_inputs
        )

    def test_onnx_program_supports_lazy_module_kwargs(self):
        class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, *args, **kwargs):
                pass

            def forward(self, x, y):
                return x + y

        m = LazyModule()
        dim = torch.export.Dim("dim")
        dynamic_shapes = ({0: dim}, {0: dim})
        exported_program = torch.export.export(
            m,
            (),
            {"x": torch.randn(3, 3), "y": torch.randn(3, 3)},
            dynamic_shapes=dynamic_shapes,
        )
        onnx_program = torch.onnx.dynamo_export(
            exported_program, x=torch.randn(3, 3), y=torch.randn(3, 3)
        )

        # NOTE: A model must be fed inputs in the same format it was
        # exported with.
        inputs = {"x": torch.randn(6, 3), "y": torch.randn(6, 3)}
        self._compare_onnx_and_torch_exported_program(
            exported_program, onnx_program, input_args=(), input_kwargs=inputs
        )


if __name__ == "__main__":
    common_utils.run_tests()
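
# NOTE: common_utils.run_tests() wires this file into PyTorch's test harness,
# so the suite can be run directly, e.g. `python <this_file>.py` (file name
# assumed here); standard unittest filtering flags such as -k should also apply.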