xref: /aosp_15_r20/external/pytorch/test/onnx/verify.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport difflib
2*da0073e9SAndroid Build Coastguard Workerimport io
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport onnx
7*da0073e9SAndroid Build Coastguard Workerimport onnx.helper
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch
10*da0073e9SAndroid Build Coastguard Workerimport torch.jit
11*da0073e9SAndroid Build Coastguard Workerimport torch.onnx
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
def colonize(msg, sep=": "):
    """
    Turn an optional message into a prefix.

    Returns ``msg + sep`` when ``msg`` is truthy; otherwise returns the empty
    string, so a ``None`` or ``""`` message contributes nothing.
    """
    return msg + sep if msg else ""
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
class Errors:
    """
    An error-collecting object which supports error recovery.

    It is intended to be used like a context manager:

    >>> with Errors("Top-level error message") as errs:
    ...     ...

    Failed ``check*`` calls (and ``addErr``) record an error and let execution
    continue.  Failed ``require*`` calls (and ``failWith``/``fail``) record an
    error and then raise a private short-circuit exception, which is swallowed
    by the nearest ``errs.recover()`` block.  When the ``with`` block exits
    with at least one recorded error, a single ``AssertionError`` is raised
    whose message is ``msg`` followed by every recorded error.
    """

    def __init__(self, msg, rtol=1e-3, atol=1e-5):
        """
        Args:
            msg: headline placed above the collected errors in the final
                ``AssertionError``.
            rtol: relative tolerance used by the ``*AlmostEqual`` checks.
            atol: absolute tolerance used by the ``*AlmostEqual`` checks.
        """
        self.msg = msg
        self.errors = []  # recorded messages, with context already appended
        self.context = []  # stack of messages pushed by addErrCtxt
        self.rtol = rtol
        self.atol = atol

        # Allocated upon instance creation so that multiple Errors objects can
        # be used (even nested) without one's recover() swallowing another's
        # short-circuit.
        class ShortCircuit(Exception):
            pass

        self.exc_class = ShortCircuit

    def requireAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (within self.rtol / self.atol);
        aborts execution if they are not.
        """
        self.almostEqualAndThen(x, y, msg, self.failWith)

    def checkAlmostEqual(self, x, y, msg=None):
        """
        Test that x and y are nearly equal (within self.rtol / self.atol),
        but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call "failIfErrs"
        at some later point in time.
        """
        self.almostEqualAndThen(x, y, msg, self.addErr)

    def _assertAndThen(self, assertion, msg, k):
        """
        Run ``assertion``, a zero-argument callable that raises
        ``AssertionError`` on mismatch (e.g. a ``numpy.testing`` check).  On
        failure, invoke continuation "k" with the assertion's message,
        prefixed by ``msg`` when one is given.
        """
        try:
            assertion()
        except AssertionError as e:
            # numpy's messages begin with a newline; drop it so the text
            # follows the "msg: " prefix directly.
            failure = str(e).lstrip()
        else:
            return
        # Invoked outside the ``except`` clause so that a short-circuit raised
        # by "k" is not chained onto the numpy AssertionError.
        k(colonize(msg) + failure)

    def almostEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireAlmostEqual" and "checkAlmostEqual".
        Upon failure, invokes continuation "k" with the error message.

        At the moment, only tests on "numpy.ndarray" are supported; NaNs in
        matching positions compare equal.

        Raises:
            RuntimeError: if x and y are not both numpy.ndarray.
        """
        if not (isinstance(x, np.ndarray) and isinstance(y, np.ndarray)):
            raise RuntimeError("Unsupported almost equal test")
        self._assertAndThen(
            lambda: np.testing.assert_allclose(
                x, y, rtol=self.rtol, atol=self.atol, equal_nan=True, verbose=True
            ),
            msg,
            k,
        )

    def requireEqual(self, x, y, msg=None):
        """
        Test that x and y are equal; aborts execution if they are not.
        """
        self.equalAndThen(x, y, msg, self.failWith)

    def checkEqual(self, x, y, msg=None):
        """
        Test that x and y are equal, but continue execution even if they are not equal.

        To prevent error cascades, you should remember to call "failIfErrs"
        at some later point in time.
        """
        self.equalAndThen(x, y, msg, self.addErr)

    # Bit-for-bit accuracy test
    def equalAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireEqual" and "checkEqual".  Upon failure,
        invokes continuation "k" with the error message.

        Supports pairs of numpy.ndarray (compared element-wise), pairs of
        onnx.TensorProto (names compared, then contents compared as arrays),
        and any other values comparable with ``!=``.
        """
        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
            self._assertAndThen(lambda: np.testing.assert_equal(x, y), msg, k)
        elif isinstance(x, onnx.TensorProto) and isinstance(y, onnx.TensorProto):
            self.equalAndThen(x.name, y.name, msg, k)
            # Use numpy for the comparison
            t1 = onnx.numpy_helper.to_array(x)
            t2 = onnx.numpy_helper.to_array(y)
            new_msg = f"{colonize(msg)}In embedded parameter '{x.name}'"
            self.equalAndThen(t1, t2, new_msg, k)
        elif x != y:
            # TODO: Better algorithm for lists
            sx = str(x)
            sy = str(y)
            if len(sx) > 40 or len(sy) > 40 or "\n" in sx or "\n" in sy:
                # long form
                header = colonize(msg, ":\n")
                rule = "=" * 50
                k(
                    f"\n{header}The value\n{rule}\n{sx}\n{rule}\n\n"
                    f"does not equal\n\n{rule}\n{sy}\n{rule}"
                )
            else:
                k(f"{colonize(msg)}{sx} != {sy}")

    def requireMultiLineEqual(self, x, y, msg=None):
        """
        Test that long, multi-line strings x and y are equal;
        aborts execution if they are not.
        """
        self.multiLineEqualAndThen(x, y, msg, self.failWith)

    def multiLineEqualAndThen(self, x, y, msg, k):
        """
        Helper for implementing "requireMultiLineEqual".  Upon failure,
        invokes continuation "k" with the error message, which includes an
        ndiff of the two strings.
        """
        if msg is None:
            msg = "Strings are not equal"
        if x != y:
            diff = difflib.ndiff(x.splitlines(True), y.splitlines(True))
            k(colonize(msg, ":\n\n") + "".join(diff))

    def addErr(self, msg):
        """
        Add an error to the error context, but continue executing.

        Each active addErrCtxt message is appended as a bullet, innermost
        context first.
        """
        # TODO: instead of immediately concatenating the context in the msg,
        # attach it as metadata and make a decision how to format it later.
        ctx = "".join(
            "\n\n  * " + "\n    ".join(c.splitlines()) for c in reversed(self.context)
        )
        self.errors.append(msg + ctx)

    def fail(self):
        """
        Immediately fail and short-circuit to the next recovery context.

        NB: It is an error to "fail" without having added any errors to
        the error context.
        """
        raise self.exc_class

    def failWith(self, msg):
        """
        Add an error to the error context, and then short-circuit.
        """
        self.addErr(msg)
        self.fail()

    def failIfErrs(self):
        """
        If there are any errors in the error context, short-circuit.

        This is used to prevent error cascades.
        """
        if self.errors:
            self.fail()

    def recover(self):
        """
        Returns a context manager which can be used to recover in case of
        an error.  Example usage:

        >>> with errs.recover():
        ...     ...
        """
        parent_self = self

        class Recover:
            def __enter__(self):
                pass

            def __exit__(self, exc_type, exc_value, traceback):
                # Suppress only this Errors object's own short-circuit; any
                # other exception keeps propagating.
                return exc_type is parent_self.exc_class

        return Recover()

    def addErrCtxt(self, msg):
        """
        Returns a context manager which encloses a fragment of code with
        an extra contextual message, e.g., where an error occurred, or a hint
        applicable to all errors in the area.  Example usage:

        >>> with errs.addErrCtxt("Some text"):
        ...     ...
        """
        parent_self = self

        class AddContext:
            def __enter__(self):
                parent_self.context.append(msg)

            def __exit__(self, exc_type, exc_value, traceback):
                parent_self.context.pop()

        return AddContext()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Any recorded error wins: report everything collected so far.
        if self.errors:
            errors_msg = "\n\n".join("ERROR: " + x for x in self.errors)
            final_msg = "{}\n{}\n{}".format(self.msg, "-" * 70, errors_msg)
            raise AssertionError(final_msg)
        if exc_type is self.exc_class:
            raise RuntimeError("ShortCircuit was raised, but no errors were recorded")
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker
228*da0073e9SAndroid Build Coastguard Workerdef verify(
229*da0073e9SAndroid Build Coastguard Worker    model,
230*da0073e9SAndroid Build Coastguard Worker    args,
231*da0073e9SAndroid Build Coastguard Worker    backend,
232*da0073e9SAndroid Build Coastguard Worker    verbose=False,
233*da0073e9SAndroid Build Coastguard Worker    training=torch.onnx.TrainingMode.EVAL,
234*da0073e9SAndroid Build Coastguard Worker    rtol=1e-3,
235*da0073e9SAndroid Build Coastguard Worker    atol=1e-7,
236*da0073e9SAndroid Build Coastguard Worker    test_args=2,
237*da0073e9SAndroid Build Coastguard Worker    do_constant_folding=True,
238*da0073e9SAndroid Build Coastguard Worker    opset_version=None,
239*da0073e9SAndroid Build Coastguard Worker    keep_initializers_as_inputs=True,
240*da0073e9SAndroid Build Coastguard Worker    add_node_names=False,
241*da0073e9SAndroid Build Coastguard Worker    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
242*da0073e9SAndroid Build Coastguard Worker    input_names=None,
243*da0073e9SAndroid Build Coastguard Worker    dynamic_axes=None,
244*da0073e9SAndroid Build Coastguard Worker    remained_onnx_input_idx=None,
245*da0073e9SAndroid Build Coastguard Worker):
246*da0073e9SAndroid Build Coastguard Worker    """
247*da0073e9SAndroid Build Coastguard Worker    Export a model into ONNX, import it into a specified ONNX backend, and then
248*da0073e9SAndroid Build Coastguard Worker    on a few random inputs verify that PyTorch and the backend produced the same
249*da0073e9SAndroid Build Coastguard Worker    results.  Requires onnx to be installed.
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker    This function may spuriously fail: some operators are implemented with
252*da0073e9SAndroid Build Coastguard Worker    different numerical precision in an ONNX backend, in which case an unstable
253*da0073e9SAndroid Build Coastguard Worker    network (e.g., Inception) may blow up these numerical instabilities.  This
254*da0073e9SAndroid Build Coastguard Worker    situation is less likely to happen if your model has been trained.  However,
255*da0073e9SAndroid Build Coastguard Worker    if this is not the case, you may have found a bug!  Please report it to the
256*da0073e9SAndroid Build Coastguard Worker    PyTorch developers.  You can also debug the issue yourself by removing
257*da0073e9SAndroid Build Coastguard Worker    suffixes of operators from your model until verification passes.
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker    For reproducibility, we recommend explicitly setting PyTorch's seed before
260*da0073e9SAndroid Build Coastguard Worker    invoking this function.
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker    Args:
263*da0073e9SAndroid Build Coastguard Worker        model (torch.nn.Module): the model to be exported and verified
264*da0073e9SAndroid Build Coastguard Worker        args (tuple of arguments): the inputs to
265*da0073e9SAndroid Build Coastguard Worker            the model, e.g., such that ``model(*args)`` is a valid
266*da0073e9SAndroid Build Coastguard Worker            invocation of the model.  Any non-Variable arguments will
267*da0073e9SAndroid Build Coastguard Worker            be hard-coded into the exported model; any Variable arguments
268*da0073e9SAndroid Build Coastguard Worker            will become inputs of the exported model, in the order they
269*da0073e9SAndroid Build Coastguard Worker            occur in args.  If args is a Variable, this is equivalent
270*da0073e9SAndroid Build Coastguard Worker            to having called it with a 1-ary tuple of that Variable.
271*da0073e9SAndroid Build Coastguard Worker            (Note: passing keyword arguments to the model is not currently
272*da0073e9SAndroid Build Coastguard Worker            supported.  Give us a shout if you need it.)
273*da0073e9SAndroid Build Coastguard Worker        backend (onnx.backend module): ONNX backend to verify with
274*da0073e9SAndroid Build Coastguard Worker        verbose (bool, default False): if specified, we will print out a debug
275*da0073e9SAndroid Build Coastguard Worker            description of the trace being exported.
276*da0073e9SAndroid Build Coastguard Worker        training (bool, default False): export the model in training mode.  At
277*da0073e9SAndroid Build Coastguard Worker            the moment, ONNX is oriented towards exporting models for inference
278*da0073e9SAndroid Build Coastguard Worker            only, so you will generally not need to set this to True.
279*da0073e9SAndroid Build Coastguard Worker        rtol (float, default 1e-3): relative precision required
280*da0073e9SAndroid Build Coastguard Worker        test_args (int or iterable of args, default 2):
281*da0073e9SAndroid Build Coastguard Worker            either an integer specifying the number
282*da0073e9SAndroid Build Coastguard Worker            of random arguments to generate, or an iterable producing arguments
283*da0073e9SAndroid Build Coastguard Worker            to test under.
284*da0073e9SAndroid Build Coastguard Worker        opset_version (int, default None): the opset version of the model to
285*da0073e9SAndroid Build Coastguard Worker            export. If not specified, the default value in symboli_helper will
286*da0073e9SAndroid Build Coastguard Worker            be used in utils._export().
287*da0073e9SAndroid Build Coastguard Worker        operator_export_type (enum, default OperatorExportTypes.ONNX): the operator
288*da0073e9SAndroid Build Coastguard Worker            export type to use when exporting the model. The default value converts
289*da0073e9SAndroid Build Coastguard Worker            all operators to ONNX ops.
290*da0073e9SAndroid Build Coastguard Worker        input_names (list of string): list of input names.
291*da0073e9SAndroid Build Coastguard Worker        dynamic_axes (dict of (string, list)): dynamic_axes.
292*da0073e9SAndroid Build Coastguard Worker        remained_onnx_input_idx (list of int, default None): The remained ONNX input index.
293*da0073e9SAndroid Build Coastguard Worker    """
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker    def _nested_map(condition, fn, condition_msg=None):
296*da0073e9SAndroid Build Coastguard Worker        def _map(obj):
297*da0073e9SAndroid Build Coastguard Worker            if condition(obj):
298*da0073e9SAndroid Build Coastguard Worker                return fn(obj)
299*da0073e9SAndroid Build Coastguard Worker            elif obj is None:
300*da0073e9SAndroid Build Coastguard Worker                return None
301*da0073e9SAndroid Build Coastguard Worker            elif isinstance(obj, (list, tuple)):
302*da0073e9SAndroid Build Coastguard Worker                return type(obj)(_map(x) for x in obj)
303*da0073e9SAndroid Build Coastguard Worker            else:
304*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
305*da0073e9SAndroid Build Coastguard Worker                    "Auto nesting doesn't know how to process "
306*da0073e9SAndroid Build Coastguard Worker                    "an input object of type "
307*da0073e9SAndroid Build Coastguard Worker                    + torch.typename(obj)
308*da0073e9SAndroid Build Coastguard Worker                    + (
309*da0073e9SAndroid Build Coastguard Worker                        ". Accepted types: "
310*da0073e9SAndroid Build Coastguard Worker                        + condition_msg
311*da0073e9SAndroid Build Coastguard Worker                        + ", or lists/tuples of them"
312*da0073e9SAndroid Build Coastguard Worker                        if condition_msg
313*da0073e9SAndroid Build Coastguard Worker                        else ""
314*da0073e9SAndroid Build Coastguard Worker                    )
315*da0073e9SAndroid Build Coastguard Worker                )
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        return _map
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker    def _iter_filter(condition, allow_unknown=False, condition_msg=None):
320*da0073e9SAndroid Build Coastguard Worker        def _iter(obj):
321*da0073e9SAndroid Build Coastguard Worker            if condition(obj):
322*da0073e9SAndroid Build Coastguard Worker                yield obj
323*da0073e9SAndroid Build Coastguard Worker            elif obj is None:
324*da0073e9SAndroid Build Coastguard Worker                return
325*da0073e9SAndroid Build Coastguard Worker            elif isinstance(obj, (list, tuple)):
326*da0073e9SAndroid Build Coastguard Worker                for o in obj:
327*da0073e9SAndroid Build Coastguard Worker                    yield from _iter(o)
328*da0073e9SAndroid Build Coastguard Worker            elif allow_unknown:
329*da0073e9SAndroid Build Coastguard Worker                yield obj
330*da0073e9SAndroid Build Coastguard Worker            else:
331*da0073e9SAndroid Build Coastguard Worker                raise ValueError(
332*da0073e9SAndroid Build Coastguard Worker                    "Auto nesting doesn't know how to process "
333*da0073e9SAndroid Build Coastguard Worker                    "an input object of type "
334*da0073e9SAndroid Build Coastguard Worker                    + torch.typename(obj)
335*da0073e9SAndroid Build Coastguard Worker                    + (
336*da0073e9SAndroid Build Coastguard Worker                        ". Accepted types: "
337*da0073e9SAndroid Build Coastguard Worker                        + condition_msg
338*da0073e9SAndroid Build Coastguard Worker                        + ", or lists/tuples of them"
339*da0073e9SAndroid Build Coastguard Worker                        if condition_msg
340*da0073e9SAndroid Build Coastguard Worker                        else ""
341*da0073e9SAndroid Build Coastguard Worker                    )
342*da0073e9SAndroid Build Coastguard Worker                )
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        return _iter
345*da0073e9SAndroid Build Coastguard Worker
346*da0073e9SAndroid Build Coastguard Worker    def is_tensor(o):
347*da0073e9SAndroid Build Coastguard Worker        return isinstance(o, torch.Tensor)
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    _iter_tensors = _iter_filter(is_tensor, condition_msg="Tensors")
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker    def randomize_arg(arg):
352*da0073e9SAndroid Build Coastguard Worker        new_data = arg.data.clone()
353*da0073e9SAndroid Build Coastguard Worker        # For now, don't try randomizing non-float tensors; these
354*da0073e9SAndroid Build Coastguard Worker        # are likely to be things like indices, where just randomly
355*da0073e9SAndroid Build Coastguard Worker        # spattering some longs is unlikely to work.  One way we could
356*da0073e9SAndroid Build Coastguard Worker        # make this work is to apply a random permutation or something.
357*da0073e9SAndroid Build Coastguard Worker        if arg.is_floating_point():
358*da0073e9SAndroid Build Coastguard Worker            new_data.uniform_()
359*da0073e9SAndroid Build Coastguard Worker        return torch.autograd.Variable(new_data, requires_grad=arg.requires_grad)
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker    randomize_args = _nested_map(is_tensor, randomize_arg)
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker    def backend_args(args):
364*da0073e9SAndroid Build Coastguard Worker        # TODO: onnx should accept iterables
365*da0073e9SAndroid Build Coastguard Worker        return tuple(v.data.cpu().numpy() for v in _iter_tensors(args))
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker    def load_bytes(b):
368*da0073e9SAndroid Build Coastguard Worker        b.seek(0)
369*da0073e9SAndroid Build Coastguard Worker        x = onnx.load(b)
370*da0073e9SAndroid Build Coastguard Worker        # doc_string has stack traces - let's remove them to make comparison
371*da0073e9SAndroid Build Coastguard Worker        # sane
372*da0073e9SAndroid Build Coastguard Worker        onnx.helper.strip_doc_string(x)
373*da0073e9SAndroid Build Coastguard Worker        return x
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker    # Special case for common case of passing a single Tensor
376*da0073e9SAndroid Build Coastguard Worker    if isinstance(args, torch.Tensor):
377*da0073e9SAndroid Build Coastguard Worker        args = (args,)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker    with torch.onnx.select_model_mode_for_export(model, training):
380*da0073e9SAndroid Build Coastguard Worker        proto_bytes = io.BytesIO()
381*da0073e9SAndroid Build Coastguard Worker        torch_out = torch.onnx.utils._export(
382*da0073e9SAndroid Build Coastguard Worker            model,
383*da0073e9SAndroid Build Coastguard Worker            args,
384*da0073e9SAndroid Build Coastguard Worker            proto_bytes,
385*da0073e9SAndroid Build Coastguard Worker            verbose=verbose,
386*da0073e9SAndroid Build Coastguard Worker            do_constant_folding=do_constant_folding,
387*da0073e9SAndroid Build Coastguard Worker            opset_version=opset_version,
388*da0073e9SAndroid Build Coastguard Worker            keep_initializers_as_inputs=keep_initializers_as_inputs,
389*da0073e9SAndroid Build Coastguard Worker            add_node_names=add_node_names,
390*da0073e9SAndroid Build Coastguard Worker            operator_export_type=operator_export_type,
391*da0073e9SAndroid Build Coastguard Worker            input_names=input_names,
392*da0073e9SAndroid Build Coastguard Worker            dynamic_axes=dynamic_axes,
393*da0073e9SAndroid Build Coastguard Worker        )
394*da0073e9SAndroid Build Coastguard Worker        if isinstance(model, torch.jit.ScriptModule):
395*da0073e9SAndroid Build Coastguard Worker            torch_out = model(*args)
396*da0073e9SAndroid Build Coastguard Worker        proto = load_bytes(proto_bytes)
397*da0073e9SAndroid Build Coastguard Worker        prepared = backend.prepare(proto)
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker        def run(args, remained_onnx_input_idx):
400*da0073e9SAndroid Build Coastguard Worker            alt_proto_bytes = io.BytesIO()
401*da0073e9SAndroid Build Coastguard Worker            torch_out = torch.onnx.utils._export(
402*da0073e9SAndroid Build Coastguard Worker                model,
403*da0073e9SAndroid Build Coastguard Worker                args,
404*da0073e9SAndroid Build Coastguard Worker                alt_proto_bytes,
405*da0073e9SAndroid Build Coastguard Worker                verbose=verbose,
406*da0073e9SAndroid Build Coastguard Worker                do_constant_folding=do_constant_folding,
407*da0073e9SAndroid Build Coastguard Worker                opset_version=opset_version,
408*da0073e9SAndroid Build Coastguard Worker                keep_initializers_as_inputs=keep_initializers_as_inputs,
409*da0073e9SAndroid Build Coastguard Worker                add_node_names=add_node_names,
410*da0073e9SAndroid Build Coastguard Worker                operator_export_type=operator_export_type,
411*da0073e9SAndroid Build Coastguard Worker                input_names=input_names,
412*da0073e9SAndroid Build Coastguard Worker                dynamic_axes=dynamic_axes,
413*da0073e9SAndroid Build Coastguard Worker            )
414*da0073e9SAndroid Build Coastguard Worker            if isinstance(model, torch.jit.ScriptModule):
415*da0073e9SAndroid Build Coastguard Worker                torch_out = model(*args)
416*da0073e9SAndroid Build Coastguard Worker            alt_proto = load_bytes(alt_proto_bytes)
417*da0073e9SAndroid Build Coastguard Worker            if proto.SerializeToString() != alt_proto.SerializeToString():
418*da0073e9SAndroid Build Coastguard Worker                # OK, let's try to figure out what happened.
419*da0073e9SAndroid Build Coastguard Worker                msg = "When I exported your model with different inputs, the result was different."
420*da0073e9SAndroid Build Coastguard Worker                if not verbose:
421*da0073e9SAndroid Build Coastguard Worker                    msg += "\n(To get more information, run torch.onnx.verify(..., verbose=True))"
422*da0073e9SAndroid Build Coastguard Worker                with Errors(msg, rtol=rtol, atol=atol) as errs:
423*da0073e9SAndroid Build Coastguard Worker                    # First, check if we have the same number of parameters, and
424*da0073e9SAndroid Build Coastguard Worker                    # that they"re the same order.  If they don"t, something has *really* gone wrong.
425*da0073e9SAndroid Build Coastguard Worker                    initializer_order_hint = (
426*da0073e9SAndroid Build Coastguard Worker                        "This is really strange! The second time I exported your model,\n"
427*da0073e9SAndroid Build Coastguard Worker                        "it had a different set of parameters.  Are you assigning Parameters\n"
428*da0073e9SAndroid Build Coastguard Worker                        "in the forward() of your model definition?"
429*da0073e9SAndroid Build Coastguard Worker                    )
430*da0073e9SAndroid Build Coastguard Worker                    with errs.addErrCtxt(initializer_order_hint):
431*da0073e9SAndroid Build Coastguard Worker                        errs.requireEqual(
432*da0073e9SAndroid Build Coastguard Worker                            [x.name for x in proto.graph.initializer],
433*da0073e9SAndroid Build Coastguard Worker                            [x.name for x in alt_proto.graph.initializer],
434*da0073e9SAndroid Build Coastguard Worker                            msg="Parameters list differs",
435*da0073e9SAndroid Build Coastguard Worker                        )
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker                    # Now check if the embedded parameters are actually the same
438*da0073e9SAndroid Build Coastguard Worker                    initializer_hint = (
439*da0073e9SAndroid Build Coastguard Worker                        "A difference in embedded parameters usually means that\n"
440*da0073e9SAndroid Build Coastguard Worker                        "your model is updating parameters/buffers even in inference\n"
441*da0073e9SAndroid Build Coastguard Worker                        "mode.  Look for a buggy nn.Module which isn't respecting train().\n"
442*da0073e9SAndroid Build Coastguard Worker                    )
443*da0073e9SAndroid Build Coastguard Worker                    with errs.recover(), errs.addErrCtxt(initializer_hint):
444*da0073e9SAndroid Build Coastguard Worker                        for x, y in zip(
445*da0073e9SAndroid Build Coastguard Worker                            proto.graph.initializer, alt_proto.graph.initializer
446*da0073e9SAndroid Build Coastguard Worker                        ):
447*da0073e9SAndroid Build Coastguard Worker                            errs.checkEqual(x, y)
448*da0073e9SAndroid Build Coastguard Worker
449*da0073e9SAndroid Build Coastguard Worker                    # Next, check if the model structure lines up.
450*da0073e9SAndroid Build Coastguard Worker                    structure_hint = (
451*da0073e9SAndroid Build Coastguard Worker                        "A difference in model structure usually means that\n"
452*da0073e9SAndroid Build Coastguard Worker                        "your model has dynamic control flow.  These models are not\n"
453*da0073e9SAndroid Build Coastguard Worker                        "currently supported by the exporter."
454*da0073e9SAndroid Build Coastguard Worker                    )
455*da0073e9SAndroid Build Coastguard Worker                    with errs.recover(), errs.addErrCtxt(structure_hint):
456*da0073e9SAndroid Build Coastguard Worker                        # Delete initializers since we already tested them
457*da0073e9SAndroid Build Coastguard Worker                        stripped_proto = onnx.ModelProto()
458*da0073e9SAndroid Build Coastguard Worker                        stripped_proto.CopyFrom(proto)
459*da0073e9SAndroid Build Coastguard Worker                        del stripped_proto.graph.initializer[:]
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker                        stripped_alt_proto = onnx.ModelProto()
462*da0073e9SAndroid Build Coastguard Worker                        stripped_alt_proto.CopyFrom(alt_proto)
463*da0073e9SAndroid Build Coastguard Worker                        del stripped_alt_proto.graph.initializer[:]
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker                        # Compare the printable graph representations first
466*da0073e9SAndroid Build Coastguard Worker                        errs.requireMultiLineEqual(
467*da0073e9SAndroid Build Coastguard Worker                            onnx.helper.printable_graph(stripped_proto.graph),
468*da0073e9SAndroid Build Coastguard Worker                            onnx.helper.printable_graph(stripped_alt_proto.graph),
469*da0073e9SAndroid Build Coastguard Worker                        )
470*da0073e9SAndroid Build Coastguard Worker
471*da0073e9SAndroid Build Coastguard Worker                        # Compare the actual protobuf text formats now (not
472*da0073e9SAndroid Build Coastguard Worker                        # very user-friendly!)
473*da0073e9SAndroid Build Coastguard Worker                        errs.requireMultiLineEqual(
474*da0073e9SAndroid Build Coastguard Worker                            str(stripped_proto), str(stripped_alt_proto)
475*da0073e9SAndroid Build Coastguard Worker                        )
476*da0073e9SAndroid Build Coastguard Worker
477*da0073e9SAndroid Build Coastguard Worker                        # One last ditch effort, using built-in equality on
478*da0073e9SAndroid Build Coastguard Worker                        # protobufs
479*da0073e9SAndroid Build Coastguard Worker                        errs.requireEqual(stripped_proto, stripped_alt_proto)
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker                    errs.failIfErrs()
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker                    # At this point, we should have figured out why the binary
484*da0073e9SAndroid Build Coastguard Worker                    # protobufs differed, and short-circuited out of this code
485*da0073e9SAndroid Build Coastguard Worker                    # with a helpful error message.  But what if we didn't?
486*da0073e9SAndroid Build Coastguard Worker                    # We better still try to give a good error message in this
487*da0073e9SAndroid Build Coastguard Worker                    # case.  We EXPECT these requires to fail.  If they don't,
488*da0073e9SAndroid Build Coastguard Worker                    # that is a bug in verify
489*da0073e9SAndroid Build Coastguard Worker                    errs.requireEqual(proto, alt_proto)
490*da0073e9SAndroid Build Coastguard Worker                    errs.requireEqual(
491*da0073e9SAndroid Build Coastguard Worker                        proto_bytes.getvalue(), alt_proto_bytes.getvalue()
492*da0073e9SAndroid Build Coastguard Worker                    )
493*da0073e9SAndroid Build Coastguard Worker                    raise AssertionError
494*da0073e9SAndroid Build Coastguard Worker
495*da0073e9SAndroid Build Coastguard Worker            # TODO: test that the traced model also returns the same thing...
496*da0073e9SAndroid Build Coastguard Worker            run_helper(torch_out, args, remained_onnx_input_idx)
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker        # Factored out so we can avoid one run of the model
499*da0073e9SAndroid Build Coastguard Worker        def run_helper(torch_out, args, remained_onnx_input_idx):
500*da0073e9SAndroid Build Coastguard Worker            onnx_input = backend_args(args)
501*da0073e9SAndroid Build Coastguard Worker            if remained_onnx_input_idx is not None:
502*da0073e9SAndroid Build Coastguard Worker                input_onnx = []
503*da0073e9SAndroid Build Coastguard Worker                for idx in remained_onnx_input_idx:
504*da0073e9SAndroid Build Coastguard Worker                    input_onnx.append(onnx_input[idx])
505*da0073e9SAndroid Build Coastguard Worker                onnx_input = tuple(input_onnx)
506*da0073e9SAndroid Build Coastguard Worker            backend_out = prepared.run(onnx_input)
507*da0073e9SAndroid Build Coastguard Worker            if isinstance(torch_out, torch.Tensor):
508*da0073e9SAndroid Build Coastguard Worker                torch_out = (torch_out,)
509*da0073e9SAndroid Build Coastguard Worker            torch_out, _ = torch.jit._flatten(torch_out)
510*da0073e9SAndroid Build Coastguard Worker            # NB: onnx backend NEVER returns bare numpy array
511*da0073e9SAndroid Build Coastguard Worker            msg = "ONNX backend returned different results from PyTorch"
512*da0073e9SAndroid Build Coastguard Worker            result_hint = (
513*da0073e9SAndroid Build Coastguard Worker                "If you are not using trained parameters, a difference in results\n"
514*da0073e9SAndroid Build Coastguard Worker                "could mean that your network is numerically unstable.  Otherwise\n"
515*da0073e9SAndroid Build Coastguard Worker                "it indicates a bug in PyTorch/ONNX; please file a bug report."
516*da0073e9SAndroid Build Coastguard Worker            )
517*da0073e9SAndroid Build Coastguard Worker            with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
518*da0073e9SAndroid Build Coastguard Worker                result_hint
519*da0073e9SAndroid Build Coastguard Worker            ):
520*da0073e9SAndroid Build Coastguard Worker                for i, (x, y) in enumerate(zip(torch_out, backend_out)):
521*da0073e9SAndroid Build Coastguard Worker                    errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")
522*da0073e9SAndroid Build Coastguard Worker
523*da0073e9SAndroid Build Coastguard Worker        run_helper(torch_out, args, remained_onnx_input_idx)
524*da0073e9SAndroid Build Coastguard Worker
525*da0073e9SAndroid Build Coastguard Worker        if isinstance(test_args, int):
526*da0073e9SAndroid Build Coastguard Worker            for i in range(test_args):
527*da0073e9SAndroid Build Coastguard Worker                run(randomize_args(args), remained_onnx_input_idx)
528*da0073e9SAndroid Build Coastguard Worker        else:
529*da0073e9SAndroid Build Coastguard Worker            for test_arg in test_args:
530*da0073e9SAndroid Build Coastguard Worker                run(test_arg, remained_onnx_input_idx)
531