1# Owner(s): ["module: onnx"] 2import onnxruntime 3import pytorch_test_common 4from pytorch_test_common import skipIfNoCuda 5 6import torch 7from torch.onnx import verification 8from torch.onnx._globals import GLOBALS 9from torch.testing._internal import common_utils 10 11 12def _jit_graph_to_onnx_model(graph, operator_export_type, opset_version): 13 r""" 14 This function exports torch::jit::Graph object 15 to serialized ONNX ModelProto. 16 This function is for testing purpose. 17 It only keeps the essential parts for IR graph conversions. 18 It also does not interact with actual PyTorch modules nor 19 PyTorch tensor inputs. 20 """ 21 22 GLOBALS.export_onnx_opset_version = opset_version 23 graph = torch.onnx.utils._optimize_graph( 24 graph, operator_export_type, params_dict={} 25 ) 26 proto, _, _, _ = graph._export_onnx( 27 {}, 28 opset_version, 29 {}, 30 False, 31 operator_export_type, 32 False, 33 False, 34 {}, 35 True, 36 "", 37 {}, 38 ) 39 return proto 40 41 42class _TestJITIRToONNX: 43 """Abstract base class for test cases. 44 45 Intentionally not a sub-class of unittest.TestCase so that unittest / pytest 46 don't run it directly. unitest.TestCase is mixed in as another base class when 47 creating concrete sub-types. See MakeTestCase(). 48 """ 49 50 opset_version = -1 # Sub-classes must override 51 ort_providers = ["CPUExecutionProvider"] 52 check_shape = True 53 check_dtype = True 54 ignore_none = True # True for tracing, and Flase for scripting 55 56 def run_test(self, graph_ir, example_inputs, parse_tensor_constants=False): 57 graph = torch._C.parse_ir(graph_ir, parse_tensor_constants) 58 jit_outs = torch._C._jit_interpret_graph(graph, example_inputs) 59 60 onnx_proto = _jit_graph_to_onnx_model( 61 graph, torch.onnx.OperatorExportTypes.ONNX, self.opset_version 62 ) 63 ort_sess = onnxruntime.InferenceSession( 64 onnx_proto, providers=self.ort_providers 65 ) 66 ort_outs = verification._run_onnx(ort_sess, example_inputs) 67 68 options = verification.VerificationOptions( 69 rtol=1e-3, 70 atol=1e-7, 71 check_shape=self.check_shape, 72 check_dtype=self.check_dtype, 73 ignore_none=self.ignore_none, 74 acceptable_error_percentage=None, 75 ) 76 verification._compare_onnx_pytorch_outputs( 77 ort_outs, 78 jit_outs, 79 options, 80 ) 81 82 def test_example_ir(self): 83 graph_ir = """ 84 graph(%1 : Float(2, 3), 85 %2 : Float(2, 3)): 86 %3 : int = prim::Constant[value=1]() 87 %4 : Float(2, 3) = aten::add(%1, %2, %3) 88 return (%4) 89 """ 90 a = torch.randn(2, 3) 91 b = torch.randn(2, 3) 92 self.run_test(graph_ir, (a, b)) 93 94 def test_where_constants(self): 95 graph_ir = """ 96 graph(%0 : Bool(8, device=cpu), 97 %1 : Float(8, device=cpu)): 98 %3 : Double(device=cpu) = prim::Constant[value={0.}]() 99 %4 : Float(8) = aten::where(%0, %1, %3) 100 return (%4) 101 """ 102 a = torch.zeros(8, dtype=bool) 103 b = torch.zeros(8) 104 self.run_test(graph_ir, (a, b), parse_tensor_constants=True) 105 106 def test_add_sub_with_graph_inputs(self): 107 for op in ["add", "sub", "rsub"]: 108 graph_ir = f""" 109 graph(%1 : Float(2, 3), 110 %2 : Float(2, 3), 111 %3 : int): 112 %4 : Float(2, 3) = aten::{op}(%1, %2, %3) 113 return (%4) 114 """ 115 a = torch.randn(2, 3) 116 b = torch.randn(2, 3) 117 self.run_test(graph_ir, (a, b, 2)) 118 119 def test_native_layer_norm(self): 120 graph_ir = """ 121 graph(%x : Float(2, 3, 2), 122 %w : Float(3, 2), 123 %b : Float(3, 2)): 124 %5 : int = prim::Constant[value=3]() 125 %6 : int = prim::Constant[value=2]() 126 %7 : int[] = prim::ListConstruct(%5, %6) 127 %10 : float = prim::Constant[value=1.0000000000000001e-05]() 128 %11 : Float(2, 3, 2), %12 : Float(2, 1, 1), %13 : Float(2, 1, 1) = aten::native_layer_norm(%x, %7, %w, %b, %10) 129 return (%11, %12, %13) 130 """ 131 x = torch.randn(2, 3, 2) 132 w = torch.randn(3, 2) 133 b = torch.randn(3, 2) 134 self.run_test(graph_ir, (x, w, b)) 135 136 def test_convolution(self): 137 graph_ir = """ 138 graph(%1 : Tensor, 139 %2 : Tensor): 140 %3 : NoneType = prim::Constant() 141 %4 : int[] = prim::Constant[value=[1, 1]]() 142 %5 : int[] = prim::Constant[value=[0, 0]]() 143 %6 : bool = prim::Constant[value=0]() 144 %7 : int = prim::Constant[value=1]() 145 %8 : Tensor = aten::convolution(%1, %2, %3, %4, %5, %4, %6, %5, %7) 146 return (%8) 147 """ 148 x = torch.randn(8, 1, 5, 5) 149 w = torch.randn(4, 1, 3, 3) 150 self.run_test(graph_ir, (x, w)) 151 152 def test_log_softmax(self): 153 graph_ir = """ 154 graph(%x: Tensor): 155 %half_to_float: bool = prim::Constant[value=0]() 156 %dim: int = prim::Constant[value=1]() 157 %y = aten::_log_softmax(%x, %dim, %half_to_float) 158 return (%y) 159 """ 160 x = torch.randn(5, 2) 161 self.run_test(graph_ir, (x,)) 162 163 @skipIfNoCuda 164 def test_log_softmax_half_to_float(self): 165 graph_ir = """ 166 graph(%x: Tensor): 167 %half_to_float: bool = prim::Constant[value=1]() 168 %dim: int = prim::Constant[value=1]() 169 %y = aten::_log_softmax(%x, %dim, %half_to_float) 170 return (%y) 171 """ 172 x = torch.randn(5, 2).half().to("cuda") 173 self.run_test(graph_ir, (x,)) 174 175 def test_native_dropout(self): 176 graph_ir = """ 177 graph(%1 : Float(2, 3)): 178 %2 : float = prim::Constant[value=0.0]() 179 %training : bool = prim::Constant[value=1]() 180 %3 : Tensor, %4 : Tensor = aten::native_dropout(%1, %2, %training) 181 return (%3, %4) 182 """ 183 a = torch.randn(2, 3) 184 self.run_test(graph_ir, (a,)) 185 186 187def MakeTestCase(opset_version: int) -> type: 188 name = f"TestJITIRToONNX_opset{opset_version}" 189 return type( 190 str(name), 191 (pytorch_test_common.ExportTestCase,), 192 dict(_TestJITIRToONNX.__dict__, opset_version=opset_version), 193 ) 194 195 196TestJITIRToONNX_opset14 = MakeTestCase(14) 197 198if __name__ == "__main__": 199 common_utils.run_tests() 200