xref: /aosp_15_r20/external/pytorch/torch/jit/_trace.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Tracing.
3
4This module contains functionality to support the JIT's tracing frontend, notably:
5    * torch.jit.trace
6    * torch.jit.trace_module
7
8This is not intended to be imported directly; please use the exposed
9functionalities in `torch.jit`.
10"""
11
12import contextlib
13import copy
14import functools
15import inspect
16import os
17import re
18import warnings
19from enum import Enum
20from typing import Any, Callable, Dict, List, Optional, Set, TypeVar
21from typing_extensions import ParamSpec
22
23import torch
24from torch._jit_internal import (
25    _get_model_id,
26    _qualified_name,
27    get_callable_argument_names,
28    is_scripting,
29)
30from torch.autograd import function
31from torch.jit._script import _CachedForward, script, ScriptModule
32from torch.jit._state import _enabled, _python_cu
33from torch.nn import Module
34from torch.testing._comparison import default_tolerances
35
36
# Short module-level aliases for the C++ tracer bindings that turn a nested
# argument structure into a flat list plus a descriptor, and back again
# (NOTE(review): exact descriptor format lives in torch._C — not visible here).
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten

# Generic helpers for annotating callables elsewhere in this module.
R = TypeVar("R", covariant=True)  # return type (always covariant)
P = ParamSpec("P")  # parameter specification of a wrapped callable
42
43
44def _create_interpreter_name_lookup_fn(frames_up=1):
45    def _get_interpreter_name_for_var(var):
46        frame = inspect.currentframe()
47        if not frame:
48            raise RuntimeError("failed to inspect frame")
49
50        i = 0
51        while i < frames_up + 1:
52            frame = frame.f_back
53            if not frame:
54                raise RuntimeError("failed to get frame")
55            i += 1
56
57        f_locals = frame.f_locals
58        f_globals = frame.f_globals
59
60        for k, v in f_locals.items():
61            if isinstance(v, torch.Tensor) and var is v:
62                return k if k != "self" else ""
63        return ""
64
65    return _get_interpreter_name_for_var
66
67
68def _unique_state_dict(module, keep_vars=False):
69    # since Parameter.detach() always creates a new torch.Tensor instance,
70    # id(v) doesn't work with it. So we always get the Parameter or Buffer
71    # as values, and deduplicate the params using Parameters and Buffers
72    state_dict = module.state_dict(keep_vars=True)
73    filtered_dict = type(state_dict)()
74    seen_ids: Set[int] = set()
75    for k, v in state_dict.items():
76        if id(v) in seen_ids:
77            continue
78        seen_ids.add(id(v))
79        if keep_vars:
80            filtered_dict[k] = v
81        else:
82            filtered_dict[k] = v.detach()
83    return filtered_dict
84
85
class ONNXTracedModule(torch.nn.Module):
    """Module wrapper whose ``forward`` traces ``inner`` and returns the graph.

    ``forward`` returns ``(graph, output)`` and, depending on the constructor
    flags, a third element holding either cloned inputs or input states.
    """

    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        """Store the callable to trace and the tracing options.

        Args:
            inner: Module or arbitrary callable to be traced.
            strict: Forwarded to the tracer unchanged.
            force_outplace: Forwarded to the tracer unchanged.
            return_inputs: If true, ``forward`` also returns clones of the
                arguments the tracer passed to the wrapped callable.
            return_inputs_states: If true (and ``return_inputs`` is false),
                ``forward`` also returns a pair of the unflattened inputs
                captured before and after running ``inner``.
        """
        super().__init__()
        # inner may be a Module, or it may be an arbitrary callable
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid a special casing functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states

    def forward(self, *args: torch.Tensor):
        """Trace ``self.inner`` on ``args`` and return the graph with the outputs.

        Returns:
            ``(graph, out)``; ``(graph, out, cloned_inputs)`` when
            ``return_inputs`` is set; otherwise ``(graph, out, input_states)``
            when ``return_inputs_states`` is set.
        """
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())

        # Filled in by ``wrapper`` while the tracer runs it; read back below.
        ret_inputs = []
        inputs_states = []
        outs = []

        def wrapper(*args):
            """Callable handed to the tracer; receives ``in_vars + module_state``."""
            # Only the leading ``len(in_vars)`` arguments are user inputs; the
            # rest are the module state appended in the tracer call below.
            in_args: List[torch.Tensor] = []
            for i in range(len(in_vars)):
                if not isinstance(args[i], torch.Tensor):
                    raise RuntimeError("Expected Tensor argument")
                in_args.append(args[i])

            # Rebuild the caller's original (possibly nested) argument structure.
            trace_inputs = _unflatten(in_args, in_desc)

            if self._return_inputs:
                # Clones every tracer argument, module state included.
                ret_inputs.append(
                    tuple(x.clone(memory_format=torch.preserve_format) for x in args)
                )
            if self._return_inputs_states:
                # Snapshot of the inputs structure taken before ``inner`` runs.
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                # Pair the pre-run snapshot with the inputs object passed to
                # ``inner`` (which ``inner`` may have modified).
                inputs_states[0] = (inputs_states[0], trace_inputs)
            out_vars, _ = _flatten(outs)
            if len(out_vars) == 1:
                return out_vars[0]
            else:
                return tuple(out_vars)

        # ``out`` (the tracer's own view of the outputs) is not used; the
        # structured result recorded in ``outs`` is returned instead.
        graph, out = torch._C._create_graph_by_tracing(
            wrapper,
            in_vars + module_state,
            _create_interpreter_name_lookup_fn(),
            self.strict,
            self._force_outplace,
        )

        # ``return_inputs`` takes precedence when both flags are set.
        if self._return_inputs:
            return graph, outs[0], ret_inputs[0]
        if self._return_inputs_states:
            return graph, outs[0], inputs_states[0]
        else:
            return graph, outs[0]
153
154
155def _clone_inputs(args):
156    def clone_input(a):
157        if a is None:
158            return None
159        elif isinstance(a, torch.Tensor):
160            # TODO: figure out one liner to .clone() and set requires_grad
161            v = (
162                a.detach()
163                .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
164                .requires_grad_(a.requires_grad)
165            )
166            if a.grad is not None:
167                v.grad = clone_input(v.grad)
168            return v
169        else:
170            return a.clone(memory_format=torch.preserve_format)
171
172    return function._nested_map(
173        lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
174    )(args)
175
176
# This is purely for developer debugging.  We are not going to advertise it.
# Each flag is False when the variable is unset and otherwise the raw string
# from the environment, so any non-empty value (even "0") counts as enabled.
_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False)
181
182
@contextlib.contextmanager
def _time(trace_name, name, time=True):
    """Context manager that prints the CUDA time spent inside its body.

    Timing happens only if it was requested (through ``_JIT_TIME`` or the
    ``time`` argument) and CUDA is available; otherwise the body simply runs.
    """
    timing_requested = _JIT_TIME or time
    if not timing_requested or not torch.cuda.is_available():
        yield
        return

    cuda_stream = torch.cuda.current_stream()
    begin_event = torch.cuda.Event(enable_timing=True)
    finish_event = torch.cuda.Event(enable_timing=True)
    cuda_stream.record_event(begin_event)
    try:
        yield
    finally:
        # Record and wait for the closing event even if the body raised, so
        # that a measurement is always reported.
        cuda_stream.record_event(finish_event)
        finish_event.synchronize()
        elapsed_ms = begin_event.elapsed_time(finish_event)
        print(f"{trace_name} {name} time: {elapsed_ms} ms")
198
199
def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Check that a JIT compiled model matches its uncompiled behavior, forwards and backwards.

    The model is run twice on ``args`` — once with its trace cache cleared and
    once requiring a cache hit — and the outputs and input gradients of both
    runs are compared.

    If your model returns multiple outputs,
    you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified.  The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked.  By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function.  Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on.  This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.

    Raises:
        TypeError: if ``model`` is not a compiled function.
        ValueError: if the default loss is used with a multi-output model.
        RuntimeError: if the compiled path is not taken, or results differ.
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore[attr-defined]
        raise TypeError(
            "Cannot verify an uncompiled module.  Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    args = args if isinstance(args, tuple) else (args,)

    saved_args = _clone_inputs(args)
    if is_module:
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        """Run ``model`` forwards and backwards; return detached outputs and grads."""
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        if force_trace:
            model.clear_cache()
        hits_before = model.hits if assert_compiled else None
        out = model(*args)
        if assert_compiled and model.hits == hits_before:
            raise RuntimeError("failed to use the compiled function")
        out = out if isinstance(out, tuple) else (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                f"Model returns {len(out)} outputs, but default loss function "
                "(torch.sum) can only handle a single output"
            )
        out_vars, _ = _flatten(out)
        saved_outs = []
        for out_var in out_vars:
            saved_outs.append(
                out_var.detach().clone(memory_format=torch.preserve_format)
            )
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = []
        for grad in grads:
            saved_grads.append(
                grad.detach().clone(memory_format=torch.preserve_format)
            )
        return (saved_outs, saved_grads)

    # The reference run happens with a forked RNG state.
    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    # Undo any state changes made by the first run before the second one.
    if is_module:
        model.load_state_dict(saved_state)  # type: ignore[possibly-undefined]
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)
289
290
291def _verify_equal(xs, ys):
292    for x, y in zip(xs, ys):
293        if x.sub(y).abs().max() > 1e-6:
294            raise RuntimeError("JIT and real computation mismatch")
295
296
def indent(s):
    """Return ``s`` with a tab prepended to every line, joined by newlines."""
    return "\n".join(f"\t{line}" for line in s.splitlines())
299
300
class TracingCheckError(Exception):
    """Error raised when a trace fails its sanity checks.

    The full text is assembled from an optional extra message, an optional
    graph diff and an optional tensor-constant comparison report, and is
    available afterwards as ``self.message``.
    """

    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        sections = ["Tracing failed sanity checks!\n"]
        if extra_msg is not None:
            sections.append(extra_msg + "\n")
        if graph_diff_error is not None:
            sections.append("ERROR: Graphs differed across invocations!\n")
            sections.append(indent(graph_diff_error) + "\n")
        if tensor_compare_error is not None:
            sections.append(
                "ERROR: Tensor-valued Constant nodes differed in value "
                "across invocations. This often indicates that the tracer has"
                " encountered untraceable code.\n"
            )
            sections.append(indent(tensor_compare_error) + "\n")
        self.message = "".join(sections)
        super().__init__(self.message)
317
318
# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(
    check_inputs,
    func,
    traced_func,
    check_tolerance,
    strict,
    force_outplace,
    is_trace_module,
    _module_class,
    example_inputs_is_kwarg=False,
):
    """Re-trace ``func`` on each check input and compare against ``traced_func``.

    For every entry of ``check_inputs`` a fresh trace is produced, the outputs
    of the original trace are compared with the outputs of ``func`` and of the
    fresh trace (mismatches are reported as ``TracerWarning``), and the graphs
    and tensor constants of both traces are compared.

    Args:
        check_inputs: Iterable of inputs. When ``is_trace_module`` is true each
            entry is a mapping from method name to that method's inputs.
        func: The original Python callable (or bound method) that was traced.
        traced_func: The trace being validated.
        check_tolerance: Relative tolerance used for output comparison.
        strict: Forwarded to the re-trace.
        force_outplace: Forwarded to the re-trace as ``_force_outplace``.
        is_trace_module: Whether to re-trace with ``torch.jit.trace_module``.
        _module_class: Forwarded to the re-trace.
        example_inputs_is_kwarg: Whether inputs are keyword-argument dicts.

    Raises:
        TracingCheckError: if running any of the callables fails, or if the
            graphs or their tensor constants differ between the two traces.
    """
    # Note: tracing is independent of optimizations, which consume the trace
    for inputs in check_inputs:
        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        if is_trace_module:
            # Clone per-method inputs so the re-trace cannot disturb the
            # originals, which are still needed for the runs further below.
            copied_dict = {}
            for name, data in inputs.items():
                copied_dict[name] = _clone_inputs(data)
            check_mod = torch.jit.trace_module(
                getattr(func, "__self__", func),
                copied_dict,
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
                _compilation_unit=torch._C.CompilationUnit(),
                example_inputs_is_kwarg=example_inputs_is_kwarg,
                _store_inputs=False,
            )
            check_mod_func = check_mod._c._get_method(traced_func.name)
            inputs = inputs[traced_func.name]
            # Operator precedence: this reads as
            # ``is_tensor or (is_dict and not example_inputs_is_kwarg)``.
            if (
                isinstance(inputs, (torch.Tensor))
                or isinstance(inputs, dict)
                and not example_inputs_is_kwarg
            ):
                inputs = (inputs,)
        else:
            if example_inputs_is_kwarg:
                check_mod = torch.jit.trace(
                    func,
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    example_kwarg_inputs=_clone_inputs(inputs),
                    _store_inputs=False,
                )
            else:
                check_mod = torch.jit.trace(
                    func,
                    _clone_inputs(inputs),
                    check_trace=False,
                    strict=strict,
                    _force_outplace=force_outplace,
                    _module_class=_module_class,
                    _store_inputs=False,
                )
            check_mod_func = check_mod

        def graph_diagnostic_info():
            """Return ``(graph_diff_errors, tensor_compare_errors)``; each is a str or None."""
            # Normalise both graphs the same way (canonicalize, inline, drop
            # shape info, strip name-mangling prefixes) before comparing text.
            mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
            torch._C._jit_pass_inline(mod_canonicalized)
            torch._C._jit_pass_erase_shape_information(mod_canonicalized)
            mod_str = str(mod_canonicalized)
            mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
            check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
            torch._C._jit_pass_inline(check_canonicalized)
            torch._C._jit_pass_erase_shape_information(check_canonicalized)
            check_str = str(check_canonicalized)
            check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)

            graph_diff_errors = None
            if mod_str != check_str:
                import difflib

                graph_diff = difflib.ndiff(
                    mod_str.splitlines(True), check_str.splitlines(True)
                )
                graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"

                for n_mod, n_check in zip(
                    mod_canonicalized.nodes(), check_canonicalized.nodes()
                ):
                    if str(n_mod) != str(n_check):
                        graph_diff_errors += "First diverging operator:\n"
                        node_diff = difflib.ndiff(
                            str(n_mod).splitlines(True), str(n_check).splitlines(True)
                        )
                        source_printout = (
                            "Node diff:\n" + indent("".join(node_diff)) + "\n"
                        )
                        mod_stack = n_mod.sourceRange()
                        if mod_stack:
                            source_printout += (
                                "Trace source location:\n" + indent(mod_stack) + "\n"
                            )
                        check_stack = n_check.sourceRange()
                        if check_stack:
                            source_printout += (
                                "Check source location:\n" + indent(check_stack) + "\n"
                            )
                        graph_diff_errors += source_printout

                        break  # For now, only print out the first pair of nodes that diverges

            tensor_compare_errors = None
            # Check Tensor-valued constant nodes
            for n_mod, n_check in zip(
                mod_canonicalized.nodes(), check_canonicalized.nodes()
            ):
                if n_mod.kind() != n_check.kind():
                    break  # Graphs have already diverged

                if n_mod.kind() == "prim::Constant" and not (
                    n_mod.mustBeNone() or n_check.mustBeNone()
                ):
                    if not n_mod.hasAttribute("value"):
                        continue
                    # Only constants whose "value" attribute kind is "t" on
                    # both sides are compared.
                    if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
                        continue

                    mod_tensor_val = n_mod.t("value")
                    check_tensor_val = n_check.t("value")

                    try:
                        torch.testing.assert_close(
                            mod_tensor_val, check_tensor_val, equal_nan=True
                        )
                    except (RuntimeError, AssertionError) as e:
                        if tensor_compare_errors is None:
                            tensor_compare_errors = ""
                        tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
                        compare_stack = n_mod.sourceRange()
                        if compare_stack:
                            tensor_compare_errors += (
                                "Source Location:\n" + indent(compare_stack) + "\n"
                            )
                        tensor_compare_errors += "Comparison exception: " + indent(
                            str(e)
                        )

                        break  # For now, only print the first diverging pair

            return graph_diff_errors, tensor_compare_errors

        def wrap_retval(x):
            """Wrap a non-tuple return value in a one-element tuple."""
            return x if isinstance(x, tuple) else (x,)

        def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
            """Run ``mod`` on ``inputs`` and return only its tensor outputs.

            Any exception is re-raised as ``TracingCheckError`` carrying the
            graph diagnostics; ``running_what`` names the callable in the text.
            """
            try:
                if isinstance(inputs, dict) and example_inputs_is_kwarg:
                    outs = wrap_retval(mod(**inputs))
                else:
                    outs = wrap_retval(mod(*_clone_inputs(inputs)))
                outs = [out for out in outs if isinstance(out, torch.Tensor)]
                return outs
            except Exception as e:
                graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
                msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
                raise TracingCheckError(
                    graph_diff_errors,
                    tensor_compare_errors,
                    extra_msg=msg,
                ) from e

        # One-element list used as a mutable flag shared with the closure below.
        has_warned = [False]

        def maybe_warn_nondeterministic():
            """Warn once (per check input) if the trace has nondeterministic nodes."""
            if has_warned[0]:
                return
            has_warned[0] = True
            nondeterm_ops = [
                op for op in traced_func.graph.nodes() if op.isNondeterministic()
            ]
            if len(nondeterm_ops) > 0:
                nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
                nondeterministic_ops_warning += (
                    "Did you forget call .eval() on your model? Nodes:\n"
                )
                # At most the first 20 offending nodes are listed.
                nondeterministic_ops_warning += "\n".join(
                    [indent(str(op)) for op in nondeterm_ops][:20]
                )
                nondeterministic_ops_warning += (
                    "\nThis may cause errors in trace checking. To disable trace checking,"
                    " pass check_trace=False to torch.jit.trace()"
                )
                # NOTE(review): stacklevel presumably chosen so the warning is
                # attributed to the user's call site — depends on call nesting.
                warnings.warn(
                    nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
                )

        def compare_outputs(original, reference, match_what):
            """Compare paired outputs; warn on each mismatch and return overall success."""
            all_ok = True
            for i, (orig, ref) in enumerate(zip(original, reference)):
                try:
                    # Bring quantized / mkldnn tensors into a form that
                    # assert_close is called on below.
                    if orig.is_quantized:
                        orig = orig.dequantize()
                    if ref.is_quantized:
                        ref = ref.dequantize()
                    if orig.is_mkldnn:
                        orig = orig.to_dense()
                    if ref.is_mkldnn:
                        ref = ref.to_dense()
                    if ref.is_complex() or orig.is_complex():
                        torch.testing.assert_close(
                            orig.to(torch.cdouble),
                            ref.to(torch.cdouble),
                            rtol=check_tolerance,
                            atol=default_tolerances(orig, ref)[1],
                            equal_nan=True,
                        )
                    else:
                        if orig.is_mps or ref.is_mps:
                            # MPS tensors are compared as float rather than double.
                            torch.testing.assert_close(
                                orig.float(),
                                ref.float(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                        elif getattr(orig, "is_nested", None) or getattr(
                            ref, "is_nested", None
                        ):
                            assert getattr(orig, "is_nested", None) == getattr(
                                ref, "is_nested", None
                            )
                            # Nested tensors are compared component by component.
                            for t_orig, t_ref in zip(orig.unbind(), ref.unbind()):
                                torch.testing.assert_close(
                                    t_orig.double(),
                                    t_ref.double(),
                                    rtol=check_tolerance,
                                    atol=default_tolerances(t_orig, t_ref)[1],
                                    equal_nan=True,
                                )
                        else:
                            torch.testing.assert_close(
                                orig.double(),
                                ref.double(),
                                rtol=check_tolerance,
                                atol=default_tolerances(orig, ref)[1],
                                equal_nan=True,
                            )
                except AssertionError as e:
                    maybe_warn_nondeterministic()
                    warnings.warn(
                        "Output nr "
                        + str(i + 1)
                        + ". of the traced function does not match "
                        "the corresponding output of the "
                        + match_what
                        + ". Detailed error:\n"
                        + str(e),
                        category=TracerWarning,
                        stacklevel=4,
                    )
                    all_ok = False

            return all_ok

        traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
        fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
        # The repeated trace is only run and compared when the trace already
        # agrees with the Python function.
        if compare_outputs(traced_outs, fn_outs, "Python function"):
            check_outs = run_mod_and_filter_tensor_outputs(
                check_mod_func, inputs, "repeated trace"
            )
            compare_outputs(traced_outs, check_outs, "repeated trace")

        # Output mismatches only warn; graph or constant differences are fatal.
        diag_info = graph_diagnostic_info()
        if any(info is not None for info in diag_info):
            raise TracingCheckError(*diag_info)
593
594
class TracerWarning(Warning):
    """Warning category used for diagnostics emitted during tracing."""

    @staticmethod
    def ignore_lib_warnings():
        """Install filters silencing tracer warnings raised from library code."""
        # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
        warnings.filterwarnings(
            action="ignore",
            category=TracerWarning,
            module="torch.(?!jit)",
        )
        # Additionally ignore any warning whose message starts with this text.
        warnings.filterwarnings(action="ignore", message="torch::jit::fuser::cuda")
603
604
# We ignore the tracer warnings coming from inside the library, because all our shape
# checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
# NOTE(review): presumably routes tracer warnings through Python's `warnings`
# machinery so the filters installed above apply — confirm against torch._C.
torch._C._tracer_warn_use_python()
609
610
def make_tuple(example_inputs):
    """Normalise example inputs to a tuple.

    Tuples are returned unchanged, a single Tensor or dict is wrapped in a
    one-element tuple, and any other value is converted with ``tuple()``.
    """
    if isinstance(example_inputs, tuple):
        return example_inputs
    if isinstance(example_inputs, torch.Tensor) or isinstance(example_inputs, dict):
        return (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    return tuple(example_inputs)
618
619
def make_module(mod, _module_class, _compilation_unit):
    """Turn ``mod`` into a module object suitable for tracing.

    * A ``ScriptModule`` is returned unchanged.
    * A module with exported methods is converted through
      ``torch.jit._recursive.create_script_module``.
    * Anything else is wrapped in ``_module_class`` (``TopLevelTracedModule``
      when no class is given), which receives ``_compilation_unit``.
    """
    if isinstance(mod, ScriptModule):
        return mod

    if torch._jit_internal.module_has_exports(mod):
        return torch.jit._recursive.create_script_module(
            mod,
            torch.jit._recursive.make_stubs_from_exported_methods,
            share_types=False,
            is_tracing=True,
        )

    wrapper_cls = TopLevelTracedModule if _module_class is None else _module_class
    return wrapper_cls(mod, _compilation_unit=_compilation_unit)
632
633
def wrap_check_inputs(check_inputs):
    """Wrap each check input in a ``{"forward": ...}`` dict; ``None`` passes through."""
    if check_inputs is not None:
        return [dict(forward=single_input) for single_input in check_inputs]
    return None
639
640
641def analyze_ts_result_with_export_result(export, trace):
642    import torch.utils._pytree as pytree
643
644    flat_export = pytree.tree_leaves(export)
645    flat_trace = pytree.tree_leaves(trace)
646
647    for orig, loaded in zip(flat_export, flat_trace):
648        if orig.layout != loaded.layout:
649            return False
650        # mkldnn is not supported for torch.allclose
651        if orig.layout == torch._mkldnn:  # type: ignore[attr-defined]
652            return True
653        if type(orig) != type(loaded):
654            return False
655
656        if isinstance(orig, torch._subclasses.FakeTensor):
657            # Skip for FakeTensor.
658            return True
659        elif isinstance(orig, torch.Tensor):
660            if orig.dtype != loaded.dtype:
661                return False
662            if not torch.allclose(orig, loaded):
663                return False
664        else:
665            if orig != loaded:
666                return False
667    return True
668
669
def _trace_impl(
    func,
    example_inputs=None,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
    example_kwarg_inputs=None,
    _store_inputs=True,
):
    """Trace ``func`` and return the traced result.

    * A ``ScriptModule`` is returned unchanged, with a warning.
    * An ``nn.Module``, or the bound ``forward`` of one, is delegated to
      ``trace_module``.
    * Any other bound method of an ``nn.Module`` is rejected.
    * Everything else is traced as a free function and, when ``check_trace``
      is true, validated with ``_check_trace``.

    Raises:
        RuntimeError: when a module is traced with neither ``example_inputs``
            nor a dict ``example_kwarg_inputs``.
        AttributeError: when ``func`` is a non-``forward`` module method.
    """
    if isinstance(func, torch.jit.ScriptModule):
        # it is hard to trace it because the forward method on ScriptModule is already defined, so it
        # would result in an error.
        warnings.warn(
            "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
        )
        return func

    uses_kwarg_inputs = isinstance(example_kwarg_inputs, dict)

    # Work out whether the request is really about a module's ``forward``.
    module_to_trace = None
    if isinstance(func, torch.nn.Module):
        module_to_trace = func
    elif (
        hasattr(func, "__self__")
        and isinstance(func.__self__, torch.nn.Module)
        and func.__name__ == "forward"
    ):
        module_to_trace = func.__self__

    if module_to_trace is not None:
        if example_inputs is None:
            if not uses_kwarg_inputs:
                raise RuntimeError("example_kwarg_inputs should be a dict")
            example_inputs = example_kwarg_inputs
        return trace_module(
            module_to_trace,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
            example_inputs_is_kwarg=uses_kwarg_inputs,
            _store_inputs=_store_inputs,
        )

    if example_kwarg_inputs is None:
        if isinstance(example_inputs, (torch.Tensor, dict)):
            # Special case for common case of passing a single Tensor
            example_inputs = (example_inputs,)
        elif not isinstance(example_inputs, tuple):
            # done primarily so that weird iterables fail here and not pybind11 code
            example_inputs = tuple(example_inputs)

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
        raise AttributeError(
            "trace doesn't support compiling individual module's functions.\n"
            "Please use trace_module"
        )

    name = _qualified_name(func)
    if uses_kwarg_inputs:
        # From here on the keyword dict doubles as the default check input.
        example_inputs = example_kwarg_inputs
        traced = torch._C._create_function_from_trace_with_dict(
            name,
            func,
            example_kwarg_inputs,
            var_lookup_fn,
            strict,
            _force_outplace,
            get_callable_argument_names(func),
        )
    else:
        traced = torch._C._create_function_from_trace(
            name,
            func,
            example_inputs,
            var_lookup_fn,
            strict,
            _force_outplace,
            get_callable_argument_names(func),
        )

    # Check the trace against new traces created from user-specified inputs,
    # falling back to the example inputs when none were given.
    if check_trace:
        inputs_to_check = (
            check_inputs if check_inputs is not None else [example_inputs]
        )
        _check_trace(
            inputs_to_check,
            func,
            traced,
            check_tolerance,
            strict,
            _force_outplace,
            False,
            _module_class,
            example_inputs_is_kwarg=uses_kwarg_inputs,
        )

    # Allow torch.compile() to inline
    traced._torchdynamo_inline = func  # type: ignore[attr-defined]
    return traced
806
807
808class _ExportType(str, Enum):
809    DIRECT_EXPORT = "DIRECT_EXPORT"
810    TRACE_AND_EXPORT = "TRACE_AND_EXPORT"
811    SOURCE_TO_SOURCE = "SOURCE_TO_SOURCE"
812
813    def __str__(self) -> str:
814        return self.value
815
816
817class _ExportOutcome(str, Enum):
818    SUCCESS = "SUCCESS"
819    FAILED_TO_EXPORT = "FAILED_TO_EXPORT"
820    FAILED_TO_RUN = "FAILED_TO_RUN"
821    ACCURACY_ERROR = "ACCURACY_ERROR"
822
823    def __str__(self) -> str:
824        return self.value
825
826
827def trace(
828    func,
829    example_inputs=None,
830    optimize=None,
831    check_trace=True,
832    check_inputs=None,
833    check_tolerance=1e-5,
834    strict=True,
835    _force_outplace=False,
836    _module_class=None,
837    _compilation_unit=_python_cu,
838    example_kwarg_inputs=None,
839    _store_inputs=True,
840):
841    r"""
842    Trace a function and return an executable  or :class:`ScriptFunction` that will be optimized using just-in-time compilation.
843
844    Tracing is ideal for code that operates only on
845    ``Tensor``\\s and lists, dictionaries, and
846    tuples of ``Tensor``\\s.
847
848    Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
849    existing module or Python function into a TorchScript
850    :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
851    inputs, and we run the function, recording the operations performed on all
852    the tensors.
853
854    * The resulting recording of a standalone function produces `ScriptFunction`.
855    * The resulting recording of `nn.Module.forward` or `nn.Module` produces
856      `ScriptModule`.
857
858    This module also contains any parameters that the original
859    module had as well.
860
861    Warning:
862        Tracing only correctly records functions and modules which are not data
863        dependent (e.g., do not have conditionals on data in tensors) and do not have
864        any untracked external dependencies (e.g., perform input/output or
865        access global variables). Tracing only records operations done when the given
866        function is run on the given tensors. Therefore, the returned
867        `ScriptModule` will always run the same traced graph on any input. This
868        has some important implications when your module is expected to run
869        different sets of operations, depending on the input and/or the module
870        state. For example,
871
872        * Tracing will not record any control-flow like if-statements or loops.
873          When this control-flow is constant across your module, this is fine
874          and it often inlines the control-flow decisions. But sometimes the
875          control-flow is actually part of the model itself. For instance, a
876          recurrent network is a loop over the (possibly dynamic) length of an
877          input sequence.
878        * In the returned :class:`ScriptModule`, operations that have different
879          behaviors in ``training`` and ``eval`` modes will always behave as if
880          it is in the mode it was in during tracing, no matter which mode the
881          `ScriptModule` is in.
882
883        In cases like these, tracing would not be appropriate and
884        :func:`scripting <torch.jit.script>` is a better choice. If you trace
885        such models, you may silently get incorrect results on subsequent
886        invocations of the model. The tracer will try to emit warnings when
887        doing something that may cause an incorrect trace to be produced.
888
889    Args:
890        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
891            that will be run with `example_inputs`. `func` arguments and return
892            values  must be tensors or (possibly nested) tuples that contain
893            tensors. When a module is passed `torch.jit.trace`, only the
894            ``forward`` method is run and traced (see :func:`torch.jit.trace
895            <torch.jit.trace_module>` for details).
896
897    Keyword arguments:
898        example_inputs (tuple or torch.Tensor or None, optional): A tuple of example
899            inputs that will be passed to the function while tracing.
900            Default: ``None``. Either this argument or ``example_kwarg_inputs``
901            should be specified. The resulting trace can be run with inputs of
902            different types and shapes assuming the traced operations support those
903            types and shapes. `example_inputs` may also be a single Tensor in which
904            case it is automatically wrapped in a tuple. When the value is None,
905            ``example_kwarg_inputs`` should be specified.
906
907        check_trace (``bool``, optional): Check if the same inputs run through
908            traced code produce the same outputs. Default: ``True``. You might want
909            to disable this if, for example, your network contains non-
910            deterministic ops or if you are sure that the network is correct despite
911            a checker failure.
912
913        check_inputs (list of tuples, optional): A list of tuples of input
914            arguments that should be used to check the trace against what is
915            expected. Each tuple is equivalent to a set of input arguments that
916            would be specified in ``example_inputs``. For best results, pass in
917            a set of checking inputs representative of the space of shapes and
918            types of inputs you expect the network to see.  If not specified,
919            the original ``example_inputs`` are used for checking
920        check_tolerance (float, optional): Floating-point comparison tolerance
921            to use in the checker procedure.  This can be used to relax the
922            checker strictness in the event that results diverge numerically
923            for a known reason, such as operator fusion.
924        strict (``bool``, optional): run the tracer in a strict mode or not
925            (default: ``True``). Only turn this off when you want the tracer to
926            record your mutable container types (currently ``list``/``dict``)
927            and you are sure that the container you are using in your
928            problem is a ``constant`` structure and does not get used as
929            control flow (if, for) conditions.
930        example_kwarg_inputs (dict, optional): This parameter is a pack of keyword
931            arguments of example inputs that will be passed to the function while
932            tracing. Default: ``None``. Either this argument or ``example_inputs``
933            should be specified. The dict will be unpacking by the arguments name
934            of the traced function. If the keys of the dict don't not match with
935            the traced function's arguments name, a runtime exception will be raised.
936
937    Returns:
938        If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
939        a :class:`ScriptModule` object with a single ``forward`` method
940        containing the traced code.  The returned `ScriptModule` will
941        have the same set of sub-modules and parameters as the original
942        ``nn.Module``.  If ``func`` is a standalone function, ``trace``
943        returns `ScriptFunction`.
944
945    Example (tracing a function):
946
947    .. testcode::
948
949        import torch
950
951        def foo(x, y):
952            return 2 * x + y
953
954        # Run `foo` with the provided inputs and record the tensor operations
955        traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
956
957        # `traced_foo` can now be run with the TorchScript interpreter or saved
958        # and loaded in a Python-free environment
959
960    Example (tracing an existing module)::
961
962        import torch
963        import torch.nn as nn
964
965        class Net(nn.Module):
966            def __init__(self) -> None:
967                super().__init__()
968                self.conv = nn.Conv2d(1, 1, 3)
969
970            def forward(self, x):
971                return self.conv(x)
972
973        n = Net()
974        example_weight = torch.rand(1, 1, 3, 3)
975        example_forward_input = torch.rand(1, 1, 3, 3)
976
977        # Trace a specific method and construct `ScriptModule` with
978        # a single `forward` method
979        module = torch.jit.trace(n.forward, example_forward_input)
980
981        # Trace a module (implicitly traces `forward`) and construct a
982        # `ScriptModule` with a single `forward` method
983        module = torch.jit.trace(n, example_forward_input)
984
985    """
986    if not _enabled:
987        return func
988    if optimize is not None:
989        warnings.warn(
990            "`optimize` is deprecated and has no effect. "
991            "Use `with torch.jit.optimized_execution()` instead",
992            FutureWarning,
993            stacklevel=2,
994        )
995
996    from torch._utils_internal import (
997        check_if_torch_exportable,
998        log_torch_jit_trace_exportability,
999        log_torchscript_usage,
1000    )
1001
1002    traced_func = _trace_impl(
1003        func,
1004        example_inputs,
1005        optimize,
1006        check_trace,
1007        check_inputs,
1008        check_tolerance,
1009        strict,
1010        _force_outplace,
1011        _module_class,
1012        _compilation_unit,
1013        example_kwarg_inputs,
1014        _store_inputs,
1015    )
1016    log_torchscript_usage("trace", model_id=_get_model_id(traced_func))
1017
1018    if check_if_torch_exportable():
1019        from torch._export.converter import TS2EPConverter
1020        from torch.export._trace import (
1021            _convert_ts_to_export_experimental,
1022            _process_jit_trace_inputs_for_export,
1023        )
1024
1025        traced_func_for_export = _trace_impl(
1026            func,
1027            example_inputs=example_inputs,
1028            optimize=optimize,
1029            check_trace=False,
1030            check_inputs=check_inputs,
1031            check_tolerance=check_tolerance,
1032            strict=strict,
1033            _force_outplace=_force_outplace,
1034            _module_class=_module_class,
1035            _compilation_unit=_compilation_unit,
1036            example_kwarg_inputs=example_kwarg_inputs,
1037            _store_inputs=_store_inputs,
1038        )
1039
1040        export_args, _ = _process_jit_trace_inputs_for_export(
1041            example_inputs, example_kwarg_inputs
1042        )
1043
1044        def _log_exportability(func_to_export, export_func, export_args, export_type):
1045            try:
1046                traced_result = func_to_export(*export_args)
1047            except Exception as e:
1048                _ = e
1049                log_torch_jit_trace_exportability(
1050                    "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
1051                )
1052                return
1053
1054            try:
1055                ep_module = export_func(func_to_export, export_args)
1056            except Exception as e:
1057                log_torch_jit_trace_exportability(
1058                    "trace",
1059                    str(export_type),
1060                    str(_ExportOutcome.FAILED_TO_EXPORT),
1061                    str(e),
1062                )
1063                return
1064
1065            try:
1066                export = ep_module(*export_args)
1067            except Exception as e:
1068                log_torch_jit_trace_exportability(
1069                    "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e)
1070                )
1071                return
1072
1073            if not analyze_ts_result_with_export_result(export, traced_result):
1074                log_torch_jit_trace_exportability(
1075                    "trace",
1076                    str(export_type),
1077                    str(_ExportOutcome.ACCURACY_ERROR),
1078                    "accuracy error",
1079                )
1080                return
1081
1082            log_torch_jit_trace_exportability(
1083                "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded"
1084            )
1085
1086        def _direct_export_and_lower(func, export_args):
1087            return torch.export.export(func, export_args, strict=False).module()
1088
1089        def _convert_ts_to_export_source_to_source(func, export_args):
1090            return TS2EPConverter(func, export_args).convert().module()
1091
1092        # torch.jit.trace is noop when the original module is torch.jit.ScriptModule
1093        if not isinstance(traced_func_for_export, torch.jit.ScriptModule):
1094            _log_exportability(
1095                traced_func_for_export,
1096                _direct_export_and_lower,
1097                export_args,
1098                _ExportType.DIRECT_EXPORT,
1099            )
1100
1101        _log_exportability(
1102            traced_func_for_export,
1103            _convert_ts_to_export_experimental,
1104            export_args,
1105            _ExportType.TRACE_AND_EXPORT,
1106        )
1107        _log_exportability(
1108            traced_func_for_export,
1109            _convert_ts_to_export_source_to_source,
1110            export_args,
1111            _ExportType.SOURCE_TO_SOURCE,
1112        )
1113
1114    return traced_func
1115
1116
# Module-level map that ``trace_module`` installs for the duration of a trace
# (the root module stored under "__module", each submodule mapped to its
# qualified name) and resets to its previous value afterwards.
_trace_module_map: Optional[Dict[Any, Any]] = None
1118
1119
def trace_module(
    mod,
    inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
    example_inputs_is_kwarg=False,
    _store_inputs=True,
):
    """
    Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation.

    When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only
    the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of
    method names to example inputs to trace (see the ``inputs`` argument below).

    See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing.

    Args:
        mod (torch.nn.Module):  A ``torch.nn.Module`` containing methods whose names are
                                specified in ``inputs``. The given methods will be compiled
                                as a part of a single `ScriptModule`.
        inputs (dict):  A dict containing sample inputs indexed by method names in ``mod``.
                                The inputs will be passed to methods whose names correspond to inputs'
                                keys while tracing.
                                ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
    Keyword arguments:
        check_trace (``bool``, optional): Check if the same inputs run through
                                      traced code produce the same outputs. Default: ``True``. You might want
                                      to disable this if, for example, your network contains non-
                                      deterministic ops or if you are sure that the network is correct despite
                                      a checker failure.

        check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
                                                 to check the trace against what is expected. Each dict
                                                 is equivalent to a set of input arguments that would
                                                 be specified in ``inputs``. For best results, pass in a
                                                 set of checking inputs representative of the space of
                                                 shapes and types of inputs you expect the network to see.
                                                 If not specified, the original ``inputs`` are used for checking.
        check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
                                           This can be used to relax the checker strictness in the event that
                                           results diverge numerically for a known reason, such as operator fusion.
        example_inputs_is_kwarg (``bool``, optional): Indicates whether each example input is a
                                           pack of keyword arguments. Default: ``False``.

    Returns:
        A :class:`ScriptModule` object with one method per key of ``inputs`` containing the traced code.
        The returned :class:`ScriptModule` will have the same set of
        sub-modules and parameters as ``mod``.

    Raises:
        AttributeError: if ``mod`` is not a ``torch.nn.Module`` or ``inputs`` is not a dict.
        NameError: if ``example_inputs_is_kwarg`` is set and a keyword in an example
            input dict is not an argument name of the traced method.

    Example (tracing a module with multiple methods)::

        import torch
        import torch.nn as nn

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

            def weighted_kernel_sum(self, weight):
                return weight * self.conv.weight


        n = Net()
        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)

        # Trace a specific method and construct `ScriptModule` with
        # a single `forward` method
        module = torch.jit.trace(n.forward, example_forward_input)

        # Trace a module (implicitly traces `forward`) and construct a
        # `ScriptModule` with a single `forward` method
        module = torch.jit.trace(n, example_forward_input)

        # Trace specific methods on a module (specified in `inputs`), constructs
        # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
        inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
        module = torch.jit.trace_module(n, inputs)

    """
    # With the JIT disabled, tracing is a no-op and the module is returned as-is.
    if not _enabled:
        return mod

    if optimize is not None:
        deprecation_msg = (
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead"
        )
        # stacklevel=2 attributes the warning to the caller of ``trace_module``.
        warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    if not isinstance(mod, torch.nn.Module):
        raise AttributeError("expected torch.nn.Module as the first argument")
    if not isinstance(inputs, dict):
        raise AttributeError("expected a dictionary of (method_name, input) pairs")

    # The module map is global state: remember the current value so it can be
    # put back no matter how tracing ends.
    previous_module_map = torch.jit._trace._trace_module_map
    try:
        qualname_map: Dict[Any, Any] = {}

        def register_children(parent, parent_qualname):
            # Depth-first walk assigning "<parent>.<child>" names to submodules.
            for child_name, child in parent.named_children():
                child_qualname = f"{parent_qualname}.{child_name}"
                qualname_map[child] = child_qualname
                register_children(child, child_qualname)

        qualname_map["__module"] = mod
        torch.jit._trace._trace_module_map = qualname_map
        register_children(mod, "__module")

        module = make_module(mod, _module_class, _compilation_unit)

        for method_name, example_inputs in inputs.items():
            bound_method = getattr(mod, method_name)
            argument_names = get_callable_argument_names(bound_method)
            # "forward" is a special case because we need to trace
            # `Module.__call__`, which sets up some extra tracing, but uses
            # argument names of the real `Module.forward` method.
            func = mod if method_name == "forward" else bound_method

            if isinstance(example_inputs, dict) and example_inputs_is_kwarg:
                # Reject keyword names that the traced method does not accept.
                for key in example_inputs:
                    if key not in argument_names:
                        valid_arguments = f"[{','.join(argument_names)}]"
                        raise NameError(
                            f"""'{key}' is not in forward() method's arguments,
                         valid arguments name are {valid_arguments}"""
                        )
                create_method = module._c._create_method_from_trace_with_dict
            else:
                example_inputs = make_tuple(example_inputs)
                create_method = module._c._create_method_from_trace

            create_method(
                method_name,
                func,
                example_inputs,
                var_lookup_fn,
                strict,
                _force_outplace,
                argument_names,
                _store_inputs,
            )

            check_trace_method = module._c._get_method(method_name)

            # Check the trace against new traces created from user-specified inputs
            if check_trace:
                inputs_to_check = check_inputs if check_inputs is not None else [inputs]
                _check_trace(
                    inputs_to_check,
                    func,
                    check_trace_method,
                    check_tolerance,
                    strict,
                    _force_outplace,
                    True,
                    _module_class,
                    example_inputs_is_kwarg=example_inputs_is_kwarg,
                )
    finally:
        torch.jit._trace._trace_module_map = previous_module_map

    return module
1321
1322
def is_tracing():
    """Report whether the caller is running under ``torch.jit.trace``.

    Returns ``False`` whenever ``is_scripting()`` is true; otherwise returns
    the tracer state reported by ``torch._C._is_tracing()``, i.e. ``True``
    while code is being traced and ``False`` otherwise.
    """
    if is_scripting():
        # Scripting takes precedence: never report tracing in that mode.
        return False
    tracing_now = torch._C._is_tracing()
    return tracing_now
1332
1333
class TracedModule(ScriptModule):
    """Script-module wrapper built from an ``nn.Module`` during tracing.

    Construction copies the parameters, buffers, script-object attributes and
    (recursively wrapped) submodules of the original module into a temporary
    module, compiles that with ``create_script_module``, and afterwards
    forwards attribute reads and writes to the compiled module.
    """

    _disable_script_meta = True

    def __init__(self, orig, id_set=None, _compilation_unit=None):
        # XXX: orig can be a nn.Module or a function!
        super().__init__()
        assert isinstance(orig, torch.nn.Module)

        # Copy a subset of `orig` to a temporary nn.Module.
        # This is a way to customize what will actually get compiled by create_script_module
        # NOTE: the ``id_set`` argument is not consulted; a fresh set is always used.
        seen_tensors = set()

        # This allows us to preserve the original module's qualified name by defining a new
        # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name
        # we have a special case that will look up this attribute to override whatever qualname
        # we would get from the python type system
        class QualnameWrapper(torch.nn.Module):
            pass

        QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name(  # type: ignore[attr-defined]
            type(orig)
        )

        tmp_module = QualnameWrapper()

        def check_unique(tensor):
            # A tensor seen twice means it is registered under more than one name.
            if tensor in seen_tensors:
                raise ValueError(
                    "TracedModules don't support parameter sharing between modules"
                )
            seen_tensors.add(tensor)

        tmp_module.training = orig.training

        # Parameters first, then buffers; ``None`` entries are skipped.
        for registry_name in ("_parameters", "_buffers"):
            source = getattr(orig, registry_name)
            target = getattr(tmp_module, registry_name)
            for name, tensor in source.items():
                if tensor is not None:
                    target[name] = tensor
                    check_unique(tensor)

        # Carry over script-object attributes that are not parameters or buffers.
        for name, val in orig.__dict__.items():
            if not torch._C._jit_is_script_object(val):
                continue
            if name in orig._parameters or name in orig._buffers:
                continue
            setattr(tmp_module, name, val)

        if orig._backward_hooks:
            raise ValueError(
                "Modules that have backward hooks assigned can't be compiled: "
                + str(orig)
            )

        for name, submodule in orig._modules.items():
            if submodule is not None:
                tmp_module._modules[name] = make_module(
                    submodule, TracedModule, _compilation_unit=None
                )

        def no_methods(module):
            # No methods are compiled up front.
            return ()

        script_module = torch.jit._recursive.create_script_module(
            tmp_module, no_methods, share_types=False, is_tracing=True
        )

        # Written straight into ``__dict__`` so the forwarding ``__setattr__``
        # below is not involved.
        self.__dict__["_name"] = type(orig).__name__
        self.__dict__["_actual_script_module"] = script_module
        # Drop the wrapper's own copies so lookups fall through to ``__getattr__``.
        for stale in ("_parameters", "_buffers", "_modules", "training"):
            delattr(self, stale)

    def forward(self, *args, **kwargs):
        raise RuntimeError("Trace submodules cannot be called.")

    def __getattr__(self, attr):
        # Forward to the compiled module once it exists; before that, use the
        # regular module lookup.
        if "_actual_script_module" in self.__dict__:
            return getattr(self._actual_script_module, attr)
        return super().__getattr__(attr)

    def __setattr__(self, attr, value):
        # Same forwarding rule as ``__getattr__``.
        if "_actual_script_module" in self.__dict__:
            setattr(self._actual_script_module, attr, value)
            return None
        return super().__setattr__(attr, value)

    def _get_name(self):
        return self._name

    def extra_repr(self):
        return f"original_name={self._name}"
1424
1425
1426class TopLevelTracedModule(TracedModule):
1427    forward: Callable[..., Any] = _CachedForward()  # type: ignore[assignment]
1428
1429    def _reconstruct(self, cpp_module):
1430        """
1431        Re-construct an instance of TopLevelTracedModule using an instance of a C++ module.
1432
1433        Args:
1434            cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
1435        """
1436        self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
1437
1438
1439def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]:
1440    @functools.wraps(fn)
1441    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
1442        if not is_tracing():
1443            # Not tracing, don't do anything
1444            return fn(*args, **kwargs)
1445
1446        compiled_fn: Callable[P, R] = script(wrapper.__original_fn)  # type: ignore[attr-defined]
1447        return compiled_fn(*args, **kwargs)
1448
1449    wrapper.__original_fn = fn  # type: ignore[attr-defined]
1450    wrapper.__script_if_tracing_wrapper = True  # type: ignore[attr-defined]
1451
1452    return wrapper
1453
1454
1455def _get_trace_graph(
1456    f,
1457    args=(),
1458    kwargs=None,
1459    strict=True,
1460    _force_outplace=False,
1461    return_inputs=False,
1462    _return_inputs_states=False,
1463):
1464    """Return a tuple on tracing a function or model.
1465
1466    .. warning::
1467        This function is internal-only and should only be used by the ONNX
1468        exporter. If you are trying to get a graph through tracing, please go
1469        through the public API instead::
1470
1471            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
1472            trace_graph = trace.graph
1473
1474    Trace a function or model, returning a tuple consisting of the both the
1475    *trace* of an execution, as well as the original return value. If return_inputs,
1476    also returns the trace inputs as part of the tuple
1477
1478    Tracing is guaranteed not to change the semantics of the function/module
1479    that is traced.
1480
1481    Args:
1482        f (torch.nn.Module or function): the function or module
1483            to be traced.
1484        args (tuple or Tensor): the positional arguments to pass to the
1485            function/module to be traced.  A non-tuple is assumed to
1486            be a single positional argument to be passed to the model.
1487        kwargs (dict): the keyword arguments to pass to the function/module
1488            to be traced.
1489
1490    Example (trace a cell):
1491
1492    .. testcode::
1493
1494        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
1495    """
1496    if kwargs is None:
1497        kwargs = {}
1498    if not isinstance(args, tuple):
1499        args = (args,)
1500    outs = ONNXTracedModule(
1501        f, strict, _force_outplace, return_inputs, _return_inputs_states
1502    )(*args, **kwargs)
1503    return outs
1504