1*da0073e9SAndroid Build Coastguard Workerimport difflib 2*da0073e9SAndroid Build Coastguard Workerimport io 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport numpy as np 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport onnx 7*da0073e9SAndroid Build Coastguard Workerimport onnx.helper 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerimport torch.jit 11*da0073e9SAndroid Build Coastguard Workerimport torch.onnx 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Workerdef colonize(msg, sep=": "): 15*da0073e9SAndroid Build Coastguard Worker if not msg: 16*da0073e9SAndroid Build Coastguard Worker return "" 17*da0073e9SAndroid Build Coastguard Worker else: 18*da0073e9SAndroid Build Coastguard Worker return msg + sep 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerclass Errors: 22*da0073e9SAndroid Build Coastguard Worker """ 23*da0073e9SAndroid Build Coastguard Worker An error-collecting object which supports error recovery. 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker It is intended to be used like a context manager: 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker >>> with Errors("Top-level error message") as errs: 28*da0073e9SAndroid Build Coastguard Worker >>> ... 
29*da0073e9SAndroid Build Coastguard Worker """ 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker def __init__(self, msg, rtol=1e-3, atol=1e-5): 32*da0073e9SAndroid Build Coastguard Worker self.msg = msg 33*da0073e9SAndroid Build Coastguard Worker self.errors = [] 34*da0073e9SAndroid Build Coastguard Worker self.context = [] 35*da0073e9SAndroid Build Coastguard Worker self.rtol = rtol 36*da0073e9SAndroid Build Coastguard Worker self.atol = atol 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker # Allocated upon instance creation so that multiple Errors 39*da0073e9SAndroid Build Coastguard Worker # can be used 40*da0073e9SAndroid Build Coastguard Worker class ShortCircuit(Exception): 41*da0073e9SAndroid Build Coastguard Worker pass 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker self.exc_class = ShortCircuit 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker def requireAlmostEqual(self, x, y, msg=None): 46*da0073e9SAndroid Build Coastguard Worker """ 47*da0073e9SAndroid Build Coastguard Worker Test that x and y are nearly equal (equal within self.rtol 48*da0073e9SAndroid Build Coastguard Worker precision); aborts execution if they are not. 49*da0073e9SAndroid Build Coastguard Worker """ 50*da0073e9SAndroid Build Coastguard Worker self.almostEqualAndThen(x, y, msg, self.failWith) 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker def checkAlmostEqual(self, x, y, msg=None): 53*da0073e9SAndroid Build Coastguard Worker """ 54*da0073e9SAndroid Build Coastguard Worker Test that x and y are nearly equal (equal within self.rtol 55*da0073e9SAndroid Build Coastguard Worker precision), but continue execution even if they are not equal. 
56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker To prevent error cascades, you should remember to call "failIfErrs" 58*da0073e9SAndroid Build Coastguard Worker at some later point in time. 59*da0073e9SAndroid Build Coastguard Worker """ 60*da0073e9SAndroid Build Coastguard Worker self.almostEqualAndThen(x, y, msg, self.addErr) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker def almostEqualAndThen(self, x, y, msg, k): 63*da0073e9SAndroid Build Coastguard Worker """ 64*da0073e9SAndroid Build Coastguard Worker Helper for implementing "requireAlmostEqual" and "checkAlmostEqual". 65*da0073e9SAndroid Build Coastguard Worker Upon failure, invokes continuation "k" with the error message. 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker At the moment, only tests on "numpy.ndarray" are supported. 68*da0073e9SAndroid Build Coastguard Worker """ 69*da0073e9SAndroid Build Coastguard Worker if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 70*da0073e9SAndroid Build Coastguard Worker np.testing.assert_allclose( 71*da0073e9SAndroid Build Coastguard Worker x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True 72*da0073e9SAndroid Build Coastguard Worker ) 73*da0073e9SAndroid Build Coastguard Worker else: 74*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Unsupported almost equal test") 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def requireEqual(self, x, y, msg=None): 77*da0073e9SAndroid Build Coastguard Worker """ 78*da0073e9SAndroid Build Coastguard Worker Test that x and y are equal; aborts execution if they are not. 
79*da0073e9SAndroid Build Coastguard Worker """ 80*da0073e9SAndroid Build Coastguard Worker self.equalAndThen(x, y, msg, self.failWith) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker def checkEqual(self, x, y, msg=None): 83*da0073e9SAndroid Build Coastguard Worker """ 84*da0073e9SAndroid Build Coastguard Worker Test that x and y are equal, but continue execution even if they are not equal. 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker To prevent error cascades, you should remember to call "failIfErrs" 87*da0073e9SAndroid Build Coastguard Worker at some later point in time. 88*da0073e9SAndroid Build Coastguard Worker """ 89*da0073e9SAndroid Build Coastguard Worker self.equalAndThen(x, y, msg, self.addErr) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker # Bit-for-bit accuracy test 92*da0073e9SAndroid Build Coastguard Worker def equalAndThen(self, x, y, msg, k): 93*da0073e9SAndroid Build Coastguard Worker """ 94*da0073e9SAndroid Build Coastguard Worker Helper for implementing "requireEqual" and "checkEqual". Upon failure, 95*da0073e9SAndroid Build Coastguard Worker invokes continuation "k" with the error message. 
96*da0073e9SAndroid Build Coastguard Worker """ 97*da0073e9SAndroid Build Coastguard Worker if isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto): 98*da0073e9SAndroid Build Coastguard Worker self.equalAndThen(x.name, y.name, msg, k) 99*da0073e9SAndroid Build Coastguard Worker # Use numpy for the comparison 100*da0073e9SAndroid Build Coastguard Worker t1 = onnx.numpy_helper.to_array(x) 101*da0073e9SAndroid Build Coastguard Worker t2 = onnx.numpy_helper.to_array(y) 102*da0073e9SAndroid Build Coastguard Worker new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'" 103*da0073e9SAndroid Build Coastguard Worker self.equalAndThen(t1, t2, new_msg, k) 104*da0073e9SAndroid Build Coastguard Worker elif isinstance(x, np.ndarray) and isinstance(y, np.ndarray): 105*da0073e9SAndroid Build Coastguard Worker np.testing.assert_equal(x, y) 106*da0073e9SAndroid Build Coastguard Worker else: 107*da0073e9SAndroid Build Coastguard Worker if x != y: 108*da0073e9SAndroid Build Coastguard Worker # TODO: Better algorithm for lists 109*da0073e9SAndroid Build Coastguard Worker sx = str(x) 110*da0073e9SAndroid Build Coastguard Worker sy = str(y) 111*da0073e9SAndroid Build Coastguard Worker if len(sx) > 40 or len(sy) > 40 or "\n" in sx or "\n" in sy: 112*da0073e9SAndroid Build Coastguard Worker # long form 113*da0073e9SAndroid Build Coastguard Worker l = "=" * 50 114*da0073e9SAndroid Build Coastguard Worker k( 115*da0073e9SAndroid Build Coastguard Worker "\n{}The value\n{}\n{}\n{}\n\ndoes not equal\n\n{}\n{}\n{}".format( 116*da0073e9SAndroid Build Coastguard Worker colonize(msg, ":\n"), l, sx, l, l, sy, l 117*da0073e9SAndroid Build Coastguard Worker ) 118*da0073e9SAndroid Build Coastguard Worker ) 119*da0073e9SAndroid Build Coastguard Worker else: 120*da0073e9SAndroid Build Coastguard Worker k(f"{colonize(msg)}{sx} != {sy}") 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def requireMultiLineEqual(self, x, y, msg=None): 
123*da0073e9SAndroid Build Coastguard Worker """ 124*da0073e9SAndroid Build Coastguard Worker Test that long, multi-line strings x and y are equal; 125*da0073e9SAndroid Build Coastguard Worker aborts execution if they are not. 126*da0073e9SAndroid Build Coastguard Worker """ 127*da0073e9SAndroid Build Coastguard Worker self.multiLineEqualAndThen(x, y, msg, self.failWith) 128*da0073e9SAndroid Build Coastguard Worker 129*da0073e9SAndroid Build Coastguard Worker def multiLineEqualAndThen(self, x, y, msg, k): 130*da0073e9SAndroid Build Coastguard Worker """ 131*da0073e9SAndroid Build Coastguard Worker Helper for implementing "requireMultiLineEqual". Upon failure, 132*da0073e9SAndroid Build Coastguard Worker invokes continuation "k" with the error message. 133*da0073e9SAndroid Build Coastguard Worker """ 134*da0073e9SAndroid Build Coastguard Worker if msg is None: 135*da0073e9SAndroid Build Coastguard Worker msg = "Strings are not equal" 136*da0073e9SAndroid Build Coastguard Worker if x != y: 137*da0073e9SAndroid Build Coastguard Worker diff = difflib.ndiff(x.splitlines(True), y.splitlines(True)) 138*da0073e9SAndroid Build Coastguard Worker k("{}{}".format(colonize(msg, ":\n\n"), "".join(diff))) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker def addErr(self, msg): 141*da0073e9SAndroid Build Coastguard Worker """ 142*da0073e9SAndroid Build Coastguard Worker Add an error to the error context, but continue executing. 143*da0073e9SAndroid Build Coastguard Worker """ 144*da0073e9SAndroid Build Coastguard Worker # TODO: instead of immediately concatenating the context in the msg, 145*da0073e9SAndroid Build Coastguard Worker # attach it as metadata and make a decision how to format it later. 
146*da0073e9SAndroid Build Coastguard Worker msg_w_ctx = msg 147*da0073e9SAndroid Build Coastguard Worker for c in reversed(self.context): 148*da0073e9SAndroid Build Coastguard Worker msg += "\n\n * " + "\n ".join(c.splitlines()) 149*da0073e9SAndroid Build Coastguard Worker self.errors.append(msg) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def fail(self): 152*da0073e9SAndroid Build Coastguard Worker """ 153*da0073e9SAndroid Build Coastguard Worker Immediately fail and short-circuit to the next recovery context. 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker NB: It is an error to "fail" without having added any errors to 156*da0073e9SAndroid Build Coastguard Worker the error context. 157*da0073e9SAndroid Build Coastguard Worker """ 158*da0073e9SAndroid Build Coastguard Worker raise self.exc_class 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker def failWith(self, msg): 161*da0073e9SAndroid Build Coastguard Worker """ 162*da0073e9SAndroid Build Coastguard Worker Add an error to the error context, and then short-circuit. 163*da0073e9SAndroid Build Coastguard Worker """ 164*da0073e9SAndroid Build Coastguard Worker self.addErr(msg) 165*da0073e9SAndroid Build Coastguard Worker self.fail() 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker def failIfErrs(self): 168*da0073e9SAndroid Build Coastguard Worker """ 169*da0073e9SAndroid Build Coastguard Worker If there are any errors in the error context, short-circuit. 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker This is used to prevent error cascades. 
172*da0073e9SAndroid Build Coastguard Worker """ 173*da0073e9SAndroid Build Coastguard Worker if self.errors: 174*da0073e9SAndroid Build Coastguard Worker self.fail() 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker def recover(self): 177*da0073e9SAndroid Build Coastguard Worker """ 178*da0073e9SAndroid Build Coastguard Worker Returns a context manager which can be used to recover in case of 179*da0073e9SAndroid Build Coastguard Worker an error. Example usage: 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker >>> with errs.recover(): 182*da0073e9SAndroid Build Coastguard Worker >>> ... 183*da0073e9SAndroid Build Coastguard Worker """ 184*da0073e9SAndroid Build Coastguard Worker parent_self = self 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker class Recover: 187*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 188*da0073e9SAndroid Build Coastguard Worker pass 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker def __exit__(self, exc_type, exc_value, traceback): 191*da0073e9SAndroid Build Coastguard Worker if exc_type == parent_self.exc_class: 192*da0073e9SAndroid Build Coastguard Worker return True 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker return Recover() 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Worker def addErrCtxt(self, msg): 197*da0073e9SAndroid Build Coastguard Worker """ 198*da0073e9SAndroid Build Coastguard Worker Returns a context manager which encloses a fragment of code with 199*da0073e9SAndroid Build Coastguard Worker an extra contextual message, e.g., where an error occurred, or a hint 200*da0073e9SAndroid Build Coastguard Worker applicable to all errors in the area. 
Example usage: 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker >>> with errs.addErrCtx("Some text"): 203*da0073e9SAndroid Build Coastguard Worker >>> ... 204*da0073e9SAndroid Build Coastguard Worker """ 205*da0073e9SAndroid Build Coastguard Worker parent_self = self 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker class AddContext: 208*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 209*da0073e9SAndroid Build Coastguard Worker parent_self.context.append(msg) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker def __exit__(self, exc_type, exc_value, traceback): 212*da0073e9SAndroid Build Coastguard Worker parent_self.context.pop() 213*da0073e9SAndroid Build Coastguard Worker 214*da0073e9SAndroid Build Coastguard Worker return AddContext() 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 217*da0073e9SAndroid Build Coastguard Worker return self 218*da0073e9SAndroid Build Coastguard Worker 219*da0073e9SAndroid Build Coastguard Worker def __exit__(self, exc_type, exc_value, traceback): 220*da0073e9SAndroid Build Coastguard Worker if self.errors: 221*da0073e9SAndroid Build Coastguard Worker errors_msg = "\n\n".join("ERROR: " + x for x in self.errors) 222*da0073e9SAndroid Build Coastguard Worker final_msg = "{}\n{}\n{}".format(self.msg, "-" * 70, errors_msg) 223*da0073e9SAndroid Build Coastguard Worker raise AssertionError(final_msg) 224*da0073e9SAndroid Build Coastguard Worker if exc_type == self.exc_class: 225*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("ShortCircuit was raised, but no errors were recorded") 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Workerdef verify( 229*da0073e9SAndroid Build Coastguard Worker model, 230*da0073e9SAndroid Build Coastguard Worker args, 
231*da0073e9SAndroid Build Coastguard Worker backend, 232*da0073e9SAndroid Build Coastguard Worker verbose=False, 233*da0073e9SAndroid Build Coastguard Worker training=torch.onnx.TrainingMode.EVAL, 234*da0073e9SAndroid Build Coastguard Worker rtol=1e-3, 235*da0073e9SAndroid Build Coastguard Worker atol=1e-7, 236*da0073e9SAndroid Build Coastguard Worker test_args=2, 237*da0073e9SAndroid Build Coastguard Worker do_constant_folding=True, 238*da0073e9SAndroid Build Coastguard Worker opset_version=None, 239*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs=True, 240*da0073e9SAndroid Build Coastguard Worker add_node_names=False, 241*da0073e9SAndroid Build Coastguard Worker operator_export_type=torch.onnx.OperatorExportTypes.ONNX, 242*da0073e9SAndroid Build Coastguard Worker input_names=None, 243*da0073e9SAndroid Build Coastguard Worker dynamic_axes=None, 244*da0073e9SAndroid Build Coastguard Worker remained_onnx_input_idx=None, 245*da0073e9SAndroid Build Coastguard Worker): 246*da0073e9SAndroid Build Coastguard Worker """ 247*da0073e9SAndroid Build Coastguard Worker Export a model into ONNX, import it into a specified ONNX backend, and then 248*da0073e9SAndroid Build Coastguard Worker on a few random inputs verify that PyTorch and the backend produced the same 249*da0073e9SAndroid Build Coastguard Worker results. Requires onnx to be installed. 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker This function may spuriously fail: some operators are implemented with 252*da0073e9SAndroid Build Coastguard Worker different numerical precision in an ONNX backend, in which case an unstable 253*da0073e9SAndroid Build Coastguard Worker network (e.g., Inception) may blow up these numerical instabilities. This 254*da0073e9SAndroid Build Coastguard Worker situation is less likely to happen if your model has been trained. However, 255*da0073e9SAndroid Build Coastguard Worker if this is not the case, you may have found a bug! 
Please report it to the 256*da0073e9SAndroid Build Coastguard Worker PyTorch developers. You can also debug the issue yourself by removing 257*da0073e9SAndroid Build Coastguard Worker suffixes of operators from your model until verification passes. 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker For reproducibility, we recommend explicitly setting PyTorch's seed before 260*da0073e9SAndroid Build Coastguard Worker invoking this function. 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker Args: 263*da0073e9SAndroid Build Coastguard Worker model (torch.nn.Module): the model to be exported and verified 264*da0073e9SAndroid Build Coastguard Worker args (tuple of arguments): the inputs to 265*da0073e9SAndroid Build Coastguard Worker the model, e.g., such that ``model(*args)`` is a valid 266*da0073e9SAndroid Build Coastguard Worker invocation of the model. Any non-Variable arguments will 267*da0073e9SAndroid Build Coastguard Worker be hard-coded into the exported model; any Variable arguments 268*da0073e9SAndroid Build Coastguard Worker will become inputs of the exported model, in the order they 269*da0073e9SAndroid Build Coastguard Worker occur in args. If args is a Variable, this is equivalent 270*da0073e9SAndroid Build Coastguard Worker to having called it with a 1-ary tuple of that Variable. 271*da0073e9SAndroid Build Coastguard Worker (Note: passing keyword arguments to the model is not currently 272*da0073e9SAndroid Build Coastguard Worker supported. Give us a shout if you need it.) 273*da0073e9SAndroid Build Coastguard Worker backend (onnx.backend module): ONNX backend to verify with 274*da0073e9SAndroid Build Coastguard Worker verbose (bool, default False): if specified, we will print out a debug 275*da0073e9SAndroid Build Coastguard Worker description of the trace being exported. 
276*da0073e9SAndroid Build Coastguard Worker training (bool, default False): export the model in training mode. At 277*da0073e9SAndroid Build Coastguard Worker the moment, ONNX is oriented towards exporting models for inference 278*da0073e9SAndroid Build Coastguard Worker only, so you will generally not need to set this to True. 279*da0073e9SAndroid Build Coastguard Worker rtol (float, default 1e-3): relative precision required 280*da0073e9SAndroid Build Coastguard Worker test_args (int or iterable of args, default 2): 281*da0073e9SAndroid Build Coastguard Worker either an integer specifying the number 282*da0073e9SAndroid Build Coastguard Worker of random arguments to generate, or an iterable producing arguments 283*da0073e9SAndroid Build Coastguard Worker to test under. 284*da0073e9SAndroid Build Coastguard Worker opset_version (int, default None): the opset version of the model to 285*da0073e9SAndroid Build Coastguard Worker export. If not specified, the default value in symboli_helper will 286*da0073e9SAndroid Build Coastguard Worker be used in utils._export(). 287*da0073e9SAndroid Build Coastguard Worker operator_export_type (enum, default OperatorExportTypes.ONNX): the operator 288*da0073e9SAndroid Build Coastguard Worker export type to use when exporting the model. The default value converts 289*da0073e9SAndroid Build Coastguard Worker all operators to ONNX ops. 290*da0073e9SAndroid Build Coastguard Worker input_names (list of string): list of input names. 291*da0073e9SAndroid Build Coastguard Worker dynamic_axes (dict of (string, list)): dynamic_axes. 292*da0073e9SAndroid Build Coastguard Worker remained_onnx_input_idx (list of int, default None): The remained ONNX input index. 
293*da0073e9SAndroid Build Coastguard Worker """ 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker def _nested_map(condition, fn, condition_msg=None): 296*da0073e9SAndroid Build Coastguard Worker def _map(obj): 297*da0073e9SAndroid Build Coastguard Worker if condition(obj): 298*da0073e9SAndroid Build Coastguard Worker return fn(obj) 299*da0073e9SAndroid Build Coastguard Worker elif obj is None: 300*da0073e9SAndroid Build Coastguard Worker return None 301*da0073e9SAndroid Build Coastguard Worker elif isinstance(obj, (list, tuple)): 302*da0073e9SAndroid Build Coastguard Worker return type(obj)(_map(x) for x in obj) 303*da0073e9SAndroid Build Coastguard Worker else: 304*da0073e9SAndroid Build Coastguard Worker raise ValueError( 305*da0073e9SAndroid Build Coastguard Worker "Auto nesting doesn't know how to process " 306*da0073e9SAndroid Build Coastguard Worker "an input object of type " 307*da0073e9SAndroid Build Coastguard Worker + torch.typename(obj) 308*da0073e9SAndroid Build Coastguard Worker + ( 309*da0073e9SAndroid Build Coastguard Worker ". 
Accepted types: " 310*da0073e9SAndroid Build Coastguard Worker + condition_msg 311*da0073e9SAndroid Build Coastguard Worker + ", or lists/tuples of them" 312*da0073e9SAndroid Build Coastguard Worker if condition_msg 313*da0073e9SAndroid Build Coastguard Worker else "" 314*da0073e9SAndroid Build Coastguard Worker ) 315*da0073e9SAndroid Build Coastguard Worker ) 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker return _map 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker def _iter_filter(condition, allow_unknown=False, condition_msg=None): 320*da0073e9SAndroid Build Coastguard Worker def _iter(obj): 321*da0073e9SAndroid Build Coastguard Worker if condition(obj): 322*da0073e9SAndroid Build Coastguard Worker yield obj 323*da0073e9SAndroid Build Coastguard Worker elif obj is None: 324*da0073e9SAndroid Build Coastguard Worker return 325*da0073e9SAndroid Build Coastguard Worker elif isinstance(obj, (list, tuple)): 326*da0073e9SAndroid Build Coastguard Worker for o in obj: 327*da0073e9SAndroid Build Coastguard Worker yield from _iter(o) 328*da0073e9SAndroid Build Coastguard Worker elif allow_unknown: 329*da0073e9SAndroid Build Coastguard Worker yield obj 330*da0073e9SAndroid Build Coastguard Worker else: 331*da0073e9SAndroid Build Coastguard Worker raise ValueError( 332*da0073e9SAndroid Build Coastguard Worker "Auto nesting doesn't know how to process " 333*da0073e9SAndroid Build Coastguard Worker "an input object of type " 334*da0073e9SAndroid Build Coastguard Worker + torch.typename(obj) 335*da0073e9SAndroid Build Coastguard Worker + ( 336*da0073e9SAndroid Build Coastguard Worker ". 
Accepted types: " 337*da0073e9SAndroid Build Coastguard Worker + condition_msg 338*da0073e9SAndroid Build Coastguard Worker + ", or lists/tuples of them" 339*da0073e9SAndroid Build Coastguard Worker if condition_msg 340*da0073e9SAndroid Build Coastguard Worker else "" 341*da0073e9SAndroid Build Coastguard Worker ) 342*da0073e9SAndroid Build Coastguard Worker ) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker return _iter 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker def is_tensor(o): 347*da0073e9SAndroid Build Coastguard Worker return isinstance(o, torch.Tensor) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors") 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker def randomize_arg(arg): 352*da0073e9SAndroid Build Coastguard Worker new_data = arg.data.clone() 353*da0073e9SAndroid Build Coastguard Worker # For now, don't try randomizing non-float tensors; these 354*da0073e9SAndroid Build Coastguard Worker # are likely to be things like indices, where just randomly 355*da0073e9SAndroid Build Coastguard Worker # spattering some longs is unlikely to work. One way we could 356*da0073e9SAndroid Build Coastguard Worker # make this work is to apply a random permutation or something. 
357*da0073e9SAndroid Build Coastguard Worker if arg.is_floating_point(): 358*da0073e9SAndroid Build Coastguard Worker new_data.uniform_() 359*da0073e9SAndroid Build Coastguard Worker return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker randomize_args = _nested_map(is_tensor, randomize_arg) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker def backend_args(args): 364*da0073e9SAndroid Build Coastguard Worker # TODO: onnx should accept iterables 365*da0073e9SAndroid Build Coastguard Worker return tuple(v.data.cpu().numpy() for v in _iter_tensors(args)) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker def load_bytes(b): 368*da0073e9SAndroid Build Coastguard Worker b.seek(0) 369*da0073e9SAndroid Build Coastguard Worker x = onnx.load(b) 370*da0073e9SAndroid Build Coastguard Worker # doc_string has stack traces - let's remove them to make comparison 371*da0073e9SAndroid Build Coastguard Worker # sane 372*da0073e9SAndroid Build Coastguard Worker onnx.helper.strip_doc_string(x) 373*da0073e9SAndroid Build Coastguard Worker return x 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker # Special case for common case of passing a single Tensor 376*da0073e9SAndroid Build Coastguard Worker if isinstance(args, torch.Tensor): 377*da0073e9SAndroid Build Coastguard Worker args = (args,) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker with torch.onnx.select_model_mode_for_export(model, training): 380*da0073e9SAndroid Build Coastguard Worker proto_bytes = io.BytesIO() 381*da0073e9SAndroid Build Coastguard Worker torch_out = torch.onnx.utils._export( 382*da0073e9SAndroid Build Coastguard Worker model, 383*da0073e9SAndroid Build Coastguard Worker args, 384*da0073e9SAndroid Build Coastguard Worker proto_bytes, 
385*da0073e9SAndroid Build Coastguard Worker verbose=verbose, 386*da0073e9SAndroid Build Coastguard Worker do_constant_folding=do_constant_folding, 387*da0073e9SAndroid Build Coastguard Worker opset_version=opset_version, 388*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs=keep_initializers_as_inputs, 389*da0073e9SAndroid Build Coastguard Worker add_node_names=add_node_names, 390*da0073e9SAndroid Build Coastguard Worker operator_export_type=operator_export_type, 391*da0073e9SAndroid Build Coastguard Worker input_names=input_names, 392*da0073e9SAndroid Build Coastguard Worker dynamic_axes=dynamic_axes, 393*da0073e9SAndroid Build Coastguard Worker ) 394*da0073e9SAndroid Build Coastguard Worker if isinstance(model, torch.jit.ScriptModule): 395*da0073e9SAndroid Build Coastguard Worker torch_out = model(*args) 396*da0073e9SAndroid Build Coastguard Worker proto = load_bytes(proto_bytes) 397*da0073e9SAndroid Build Coastguard Worker prepared = backend.prepare(proto) 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Worker def run(args, remained_onnx_input_idx): 400*da0073e9SAndroid Build Coastguard Worker alt_proto_bytes = io.BytesIO() 401*da0073e9SAndroid Build Coastguard Worker torch_out = torch.onnx.utils._export( 402*da0073e9SAndroid Build Coastguard Worker model, 403*da0073e9SAndroid Build Coastguard Worker args, 404*da0073e9SAndroid Build Coastguard Worker alt_proto_bytes, 405*da0073e9SAndroid Build Coastguard Worker verbose=verbose, 406*da0073e9SAndroid Build Coastguard Worker do_constant_folding=do_constant_folding, 407*da0073e9SAndroid Build Coastguard Worker opset_version=opset_version, 408*da0073e9SAndroid Build Coastguard Worker keep_initializers_as_inputs=keep_initializers_as_inputs, 409*da0073e9SAndroid Build Coastguard Worker add_node_names=add_node_names, 410*da0073e9SAndroid Build Coastguard Worker operator_export_type=operator_export_type, 411*da0073e9SAndroid Build Coastguard Worker 
input_names=input_names, 412*da0073e9SAndroid Build Coastguard Worker dynamic_axes=dynamic_axes, 413*da0073e9SAndroid Build Coastguard Worker ) 414*da0073e9SAndroid Build Coastguard Worker if isinstance(model, torch.jit.ScriptModule): 415*da0073e9SAndroid Build Coastguard Worker torch_out = model(*args) 416*da0073e9SAndroid Build Coastguard Worker alt_proto = load_bytes(alt_proto_bytes) 417*da0073e9SAndroid Build Coastguard Worker if proto.SerializeToString() != alt_proto.SerializeToString(): 418*da0073e9SAndroid Build Coastguard Worker # OK, let's try to figure out what happened. 419*da0073e9SAndroid Build Coastguard Worker msg = "When I exported your model with different inputs, the result was different." 420*da0073e9SAndroid Build Coastguard Worker if not verbose: 421*da0073e9SAndroid Build Coastguard Worker msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))" 422*da0073e9SAndroid Build Coastguard Worker with Errors(msg, rtol=rtol, atol=atol) as errs: 423*da0073e9SAndroid Build Coastguard Worker # First, check if we have the same number of parameters, and 424*da0073e9SAndroid Build Coastguard Worker # that they"re the same order. If they don"t, something has *really* gone wrong. 425*da0073e9SAndroid Build Coastguard Worker initializer_order_hint = ( 426*da0073e9SAndroid Build Coastguard Worker "This is really strange! The second time I exported your model,\n" 427*da0073e9SAndroid Build Coastguard Worker "it had a different set of parameters. Are you assigning Parameters\n" 428*da0073e9SAndroid Build Coastguard Worker "in the forward() of your model definition?" 
429*da0073e9SAndroid Build Coastguard Worker ) 430*da0073e9SAndroid Build Coastguard Worker with errs.addErrCtxt(initializer_order_hint): 431*da0073e9SAndroid Build Coastguard Worker errs.requireEqual( 432*da0073e9SAndroid Build Coastguard Worker [x.name for x in proto.graph.initializer], 433*da0073e9SAndroid Build Coastguard Worker [x.name for x in alt_proto.graph.initializer], 434*da0073e9SAndroid Build Coastguard Worker msg="Parameters list differs", 435*da0073e9SAndroid Build Coastguard Worker ) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker # Now check if the embedded parameters are actually the same 438*da0073e9SAndroid Build Coastguard Worker initializer_hint = ( 439*da0073e9SAndroid Build Coastguard Worker "A difference in embedded parameters usually means that\n" 440*da0073e9SAndroid Build Coastguard Worker "your model is updating parameters/buffers even in inference\n" 441*da0073e9SAndroid Build Coastguard Worker "mode. Look for a buggy nn.Module which isn't respecting train().\n" 442*da0073e9SAndroid Build Coastguard Worker ) 443*da0073e9SAndroid Build Coastguard Worker with errs.recover(), errs.addErrCtxt(initializer_hint): 444*da0073e9SAndroid Build Coastguard Worker for x, y in zip( 445*da0073e9SAndroid Build Coastguard Worker proto.graph.initializer, alt_proto.graph.initializer 446*da0073e9SAndroid Build Coastguard Worker ): 447*da0073e9SAndroid Build Coastguard Worker errs.checkEqual(x, y) 448*da0073e9SAndroid Build Coastguard Worker 449*da0073e9SAndroid Build Coastguard Worker # Next, check if the model structure lines up. 450*da0073e9SAndroid Build Coastguard Worker structure_hint = ( 451*da0073e9SAndroid Build Coastguard Worker "A difference in model structure usually means that\n" 452*da0073e9SAndroid Build Coastguard Worker "your model has dynamic control flow. These models are not\n" 453*da0073e9SAndroid Build Coastguard Worker "currently supported by the exporter." 
454*da0073e9SAndroid Build Coastguard Worker ) 455*da0073e9SAndroid Build Coastguard Worker with errs.recover(), errs.addErrCtxt(structure_hint): 456*da0073e9SAndroid Build Coastguard Worker # Delete initializers since we already tested them 457*da0073e9SAndroid Build Coastguard Worker stripped_proto = onnx.ModelProto() 458*da0073e9SAndroid Build Coastguard Worker stripped_proto.CopyFrom(proto) 459*da0073e9SAndroid Build Coastguard Worker del stripped_proto.graph.initializer[:] 460*da0073e9SAndroid Build Coastguard Worker 461*da0073e9SAndroid Build Coastguard Worker stripped_alt_proto = onnx.ModelProto() 462*da0073e9SAndroid Build Coastguard Worker stripped_alt_proto.CopyFrom(alt_proto) 463*da0073e9SAndroid Build Coastguard Worker del stripped_alt_proto.graph.initializer[:] 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker # Compare the printable graph representations first 466*da0073e9SAndroid Build Coastguard Worker errs.requireMultiLineEqual( 467*da0073e9SAndroid Build Coastguard Worker onnx.helper.printable_graph(stripped_proto.graph), 468*da0073e9SAndroid Build Coastguard Worker onnx.helper.printable_graph(stripped_alt_proto.graph), 469*da0073e9SAndroid Build Coastguard Worker ) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker # Compare the actual protobuf text formats now (not 472*da0073e9SAndroid Build Coastguard Worker # very user-friendly!) 
473*da0073e9SAndroid Build Coastguard Worker errs.requireMultiLineEqual( 474*da0073e9SAndroid Build Coastguard Worker str(stripped_proto), str(stripped_alt_proto) 475*da0073e9SAndroid Build Coastguard Worker ) 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker # One last ditch effort, using built-in equality on 478*da0073e9SAndroid Build Coastguard Worker # protobufs 479*da0073e9SAndroid Build Coastguard Worker errs.requireEqual(stripped_proto, stripped_alt_proto) 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker errs.failIfErrs() 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker # At this point, we should have figured out why the binary 484*da0073e9SAndroid Build Coastguard Worker # protobufs differed, and short-circuited out of this code 485*da0073e9SAndroid Build Coastguard Worker # with a helpful error message. But what if we didn't? 486*da0073e9SAndroid Build Coastguard Worker # We better still try to give a good error message in this 487*da0073e9SAndroid Build Coastguard Worker # case. We EXPECT these requires to fail. If they don't, 488*da0073e9SAndroid Build Coastguard Worker # that is a bug in verify 489*da0073e9SAndroid Build Coastguard Worker errs.requireEqual(proto, alt_proto) 490*da0073e9SAndroid Build Coastguard Worker errs.requireEqual( 491*da0073e9SAndroid Build Coastguard Worker proto_bytes.getvalue(), alt_proto_bytes.getvalue() 492*da0073e9SAndroid Build Coastguard Worker ) 493*da0073e9SAndroid Build Coastguard Worker raise AssertionError 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker # TODO: test that the traced model also returns the same thing... 
run_helper(torch_out, args, remained_onnx_input_idx)

# Factored out so we can avoid one run of the model
def run_helper(torch_out, args, remained_onnx_input_idx):
    """
    Run the prepared ONNX backend on ``args`` and compare its outputs
    against ``torch_out``.

    Args:
        torch_out: the PyTorch result for ``args``. A bare tensor is wrapped
            in a 1-tuple; the result is then flattened with
            ``torch.jit._flatten`` so that it lines up with the backend's
            flat output sequence.
        args: the inputs that produced ``torch_out``. They are converted with
            ``backend_args`` before being fed to ``prepared.run``
            (presumably tensor -> numpy conversion; see ``backend_args``).
        remained_onnx_input_idx: if not None, the indices of the converted
            inputs to keep, in that order; all other inputs are dropped
            before running the backend.

    All mismatches (output count and per-output values within
    ``rtol``/``atol``) are collected in an ``Errors`` context together with
    ``result_hint``, and are reported when that context exits.
    """
    onnx_input = backend_args(args)
    if remained_onnx_input_idx is not None:
        # Keep only the inputs the exported graph still takes, in the
        # order given by remained_onnx_input_idx.
        onnx_input = tuple(onnx_input[idx] for idx in remained_onnx_input_idx)
    backend_out = prepared.run(onnx_input)
    if isinstance(torch_out, torch.Tensor):
        torch_out = (torch_out,)
    torch_out, _ = torch.jit._flatten(torch_out)
    # NB: onnx backend NEVER returns bare numpy array
    msg = "ONNX backend returned different results from PyTorch"
    result_hint = (
        "If you are not using trained parameters, a difference in results\n"
        "could mean that your network is numerically unstable. Otherwise\n"
        "it indicates a bug in PyTorch/ONNX; please file a bug report."
    )
    with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
        result_hint
    ):
        # zip() below stops at the shorter sequence, so without this check a
        # missing or extra backend output would silently go unreported.
        errs.checkEqual(
            len(torch_out), len(backend_out), msg="Number of outputs differs"
        )
        for i, (x, y) in enumerate(zip(torch_out, backend_out)):
            errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")

run_helper(torch_out, args, remained_onnx_input_idx)

if isinstance(test_args, int):
    for i in range(test_args):
        run(randomize_args(args), remained_onnx_input_idx)
else:
    for test_arg in test_args:
        run(test_arg, remained_onnx_input_idx)