1# mypy: allow-untyped-defs 2"""Tracing. 3 4This module contains functionality to support the JIT's tracing frontend, notably: 5 * torch.jit.trace 6 * torch.jit.trace_module 7 8This is not intended to be imported directly; please use the exposed 9functionalities in `torch.jit`. 10""" 11 12import contextlib 13import copy 14import functools 15import inspect 16import os 17import re 18import warnings 19from enum import Enum 20from typing import Any, Callable, Dict, List, Optional, Set, TypeVar 21from typing_extensions import ParamSpec 22 23import torch 24from torch._jit_internal import ( 25 _get_model_id, 26 _qualified_name, 27 get_callable_argument_names, 28 is_scripting, 29) 30from torch.autograd import function 31from torch.jit._script import _CachedForward, script, ScriptModule 32from torch.jit._state import _enabled, _python_cu 33from torch.nn import Module 34from torch.testing._comparison import default_tolerances 35 36 37_flatten = torch._C._jit_flatten 38_unflatten = torch._C._jit_unflatten 39 40R = TypeVar("R", covariant=True) # return type (always covariant) 41P = ParamSpec("P") 42 43 44def _create_interpreter_name_lookup_fn(frames_up=1): 45 def _get_interpreter_name_for_var(var): 46 frame = inspect.currentframe() 47 if not frame: 48 raise RuntimeError("failed to inspect frame") 49 50 i = 0 51 while i < frames_up + 1: 52 frame = frame.f_back 53 if not frame: 54 raise RuntimeError("failed to get frame") 55 i += 1 56 57 f_locals = frame.f_locals 58 f_globals = frame.f_globals 59 60 for k, v in f_locals.items(): 61 if isinstance(v, torch.Tensor) and var is v: 62 return k if k != "self" else "" 63 return "" 64 65 return _get_interpreter_name_for_var 66 67 68def _unique_state_dict(module, keep_vars=False): 69 # since Parameter.detach() always creates a new torch.Tensor instance, 70 # id(v) doesn't work with it. So we always get the Parameter or Buffer 71 # as values, and deduplicate the params using Parameters and Buffers 72 state_dict = module.state_dict(keep_vars=True) 73 filtered_dict = type(state_dict)() 74 seen_ids: Set[int] = set() 75 for k, v in state_dict.items(): 76 if id(v) in seen_ids: 77 continue 78 seen_ids.add(id(v)) 79 if keep_vars: 80 filtered_dict[k] = v 81 else: 82 filtered_dict[k] = v.detach() 83 return filtered_dict 84 85 86class ONNXTracedModule(torch.nn.Module): 87 def __init__( 88 self, 89 inner, 90 strict=True, 91 force_outplace=False, 92 return_inputs=False, 93 return_inputs_states=False, 94 ): 95 super().__init__() 96 # inner may be a Module, or it may be an arbitrary callable 97 # If it's a Module, we get its parameters automatically, which lets 98 # us avoid a special casing functions versus modules. 99 self.inner = inner 100 self.strict = strict 101 self._force_outplace = force_outplace 102 self._return_inputs = return_inputs 103 self._return_inputs_states = return_inputs_states 104 105 def forward(self, *args: torch.Tensor): 106 in_vars, in_desc = _flatten(args) 107 # NOTE: use full state, because we need it for BatchNorm export 108 # This differs from the compiler path, which doesn't support it at the moment. 109 module_state = list(_unique_state_dict(self, keep_vars=True).values()) 110 111 ret_inputs = [] 112 inputs_states = [] 113 outs = [] 114 115 def wrapper(*args): 116 in_args: List[torch.Tensor] = [] 117 for i in range(len(in_vars)): 118 if not isinstance(args[i], torch.Tensor): 119 raise RuntimeError("Expected Tensor argument") 120 in_args.append(args[i]) 121 122 trace_inputs = _unflatten(in_args, in_desc) 123 124 if self._return_inputs: 125 ret_inputs.append( 126 tuple(x.clone(memory_format=torch.preserve_format) for x in args) 127 ) 128 if self._return_inputs_states: 129 inputs_states.append(_unflatten(in_args, in_desc)) 130 outs.append(self.inner(*trace_inputs)) 131 if self._return_inputs_states: 132 inputs_states[0] = (inputs_states[0], trace_inputs) 133 out_vars, _ = _flatten(outs) 134 if len(out_vars) == 1: 135 return out_vars[0] 136 else: 137 return tuple(out_vars) 138 139 graph, out = torch._C._create_graph_by_tracing( 140 wrapper, 141 in_vars + module_state, 142 _create_interpreter_name_lookup_fn(), 143 self.strict, 144 self._force_outplace, 145 ) 146 147 if self._return_inputs: 148 return graph, outs[0], ret_inputs[0] 149 if self._return_inputs_states: 150 return graph, outs[0], inputs_states[0] 151 else: 152 return graph, outs[0] 153 154 155def _clone_inputs(args): 156 def clone_input(a): 157 if a is None: 158 return None 159 elif isinstance(a, torch.Tensor): 160 # TODO: figure out one liner to .clone() and set requires_grad 161 v = ( 162 a.detach() 163 .clone(memory_format=None if a.is_mkldnn else torch.preserve_format) 164 .requires_grad_(a.requires_grad) 165 ) 166 if a.grad is not None: 167 v.grad = clone_input(v.grad) 168 return v 169 else: 170 return a.clone(memory_format=torch.preserve_format) 171 172 return function._nested_map( 173 lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors" 174 )(args) 175 176 177# This is purely for developer debugging. We are not going to advertise it. 178_JIT_TIME = os.environ.get("PYTORCH_JIT_TIME", False) # CUDA-only timing 179_JIT_DISABLE = os.environ.get("PYTORCH_JIT_DISABLE", False) 180_JIT_STATS = os.environ.get("PYTORCH_JIT_STATS", False) 181 182 183@contextlib.contextmanager 184def _time(trace_name, name, time=True): 185 if (not _JIT_TIME and not time) or not torch.cuda.is_available(): 186 yield 187 return 188 stream = torch.cuda.current_stream() 189 start = torch.cuda.Event(enable_timing=True) 190 end = torch.cuda.Event(enable_timing=True) 191 stream.record_event(start) 192 try: 193 yield 194 finally: 195 stream.record_event(end) 196 end.synchronize() 197 print(f"{trace_name} {name} time: {start.elapsed_time(end)} ms") 198 199 200def verify(model, args, loss_fn=torch.sum, devices=None): 201 """ 202 Verify that a JIT compiled model has the same behavior as its uncompiled version along with its backwards pass. 203 204 If your model returns multiple outputs, 205 you must also specify a `loss_fn` to produce a loss for which 206 the backwards will be computed. 207 208 This function has side-effects (e.g., it executes your model / saves and loads 209 parameters), so don't expect the model to come out exactly the same as what 210 you passed in. 211 212 Args: 213 model (compiled torch.nn.Module or function): the module/function to be 214 verified. The module/function definition MUST have been decorated with 215 `@torch.jit.compile`. 216 args (tuple or Tensor): the positional arguments to pass to the 217 compiled function/module to be verified. A non-tuple is assumed to 218 be a single positional argument to be passed to the model. 219 loss_fn (function, optional): the loss function to be applied to 220 the output of the model, before backwards is invoked. By default, 221 we assume that a model returns a single result, and we :func:`torch.sum` 222 before calling backwards; if this is inappropriate, you can pass your 223 own loss function. Note that if a model returns a tuple of results, 224 these are passed as separate positional arguments to `loss_fn`. 225 devices (iterable of device IDs, optional): the GPU devices which the 226 compiled module will be run on. This determines the RNG state we 227 must save when running both compiled and uncompiled versions of the model. 228 """ 229 # TODO: In principle, we track device information in our trace, so it 230 # should be possible to check if our execution actually obeyed the 'devices' 231 # the user provided. 232 233 # TODO: Consider adding a utility function to torch.jit to test 234 # for this case 235 if not isinstance(model, torch._C.CompiledFunction): # type: ignore[attr-defined] 236 raise TypeError( 237 "Cannot verify an uncompiled module. Add @torch.jit.compile to compile it" 238 ) 239 is_module = isinstance(model, Module) 240 241 if not isinstance(args, tuple): 242 args = (args,) 243 244 saved_args = _clone_inputs(args) 245 if is_module: 246 saved_state = copy.deepcopy(model.state_dict()) 247 248 def run_fwd_bwd(args, force_trace=False, assert_compiled=False): 249 params = list(model.parameters()) if is_module else [] 250 in_vars, _ = _flatten((args, params)) 251 # We use a special API to reset the trace and compile it from scratch. 252 compiled_fn = model 253 if force_trace: 254 compiled_fn.clear_cache() 255 if assert_compiled: 256 hits = compiled_fn.hits 257 out = model(*args) 258 if assert_compiled and compiled_fn.hits == hits: # type: ignore[possibly-undefined] 259 raise RuntimeError("failed to use the compiled function") 260 if not isinstance(out, tuple): 261 out = (out,) 262 if loss_fn == torch.sum and len(out) != 1: 263 raise ValueError( 264 f"Model returns {len(out)} outputs, but default loss function " 265 "(torch.sum) can only handle a single output" 266 ) 267 out_vars, _ = _flatten(out) 268 saved_outs = [ 269 v.detach().clone(memory_format=torch.preserve_format) for v in out_vars 270 ] 271 loss = loss_fn(*out) 272 grads = torch.autograd.grad([loss], in_vars) 273 # TODO: I'm not sure if the clone here is necessary but it is safer 274 saved_grads = [ 275 v.detach().clone(memory_format=torch.preserve_format) for v in grads 276 ] 277 return (saved_outs, saved_grads) 278 279 with torch.random.fork_rng(devices, _caller="torch.jit.verify"): 280 uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True) 281 assert model.has_trace_for(*args) 282 283 if is_module: 284 model.load_state_dict(saved_state) # type: ignore[possibly-undefined] 285 compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True) 286 287 _verify_equal(uncompiled_outs, compiled_outs) 288 _verify_equal(uncompiled_grads, compiled_grads) 289 290 291def _verify_equal(xs, ys): 292 for x, y in zip(xs, ys): 293 if x.sub(y).abs().max() > 1e-6: 294 raise RuntimeError("JIT and real computation mismatch") 295 296 297def indent(s): 298 return "\n".join(["\t" + line for line in s.splitlines()]) 299 300 301class TracingCheckError(Exception): 302 def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None): 303 self.message = "Tracing failed sanity checks!\n" 304 if extra_msg is not None: 305 self.message += extra_msg + "\n" 306 if graph_diff_error is not None: 307 self.message += "ERROR: Graphs differed across invocations!\n" 308 self.message += indent(graph_diff_error) + "\n" 309 if tensor_compare_error is not None: 310 self.message += ( 311 "ERROR: Tensor-valued Constant nodes differed in value " 312 "across invocations. This often indicates that the tracer has" 313 " encountered untraceable code.\n" 314 ) 315 self.message += indent(tensor_compare_error) + "\n" 316 super().__init__(self.message) 317 318 319# Check the traced module against a set of user-provided validation inputs 320@torch.no_grad() 321def _check_trace( 322 check_inputs, 323 func, 324 traced_func, 325 check_tolerance, 326 strict, 327 force_outplace, 328 is_trace_module, 329 _module_class, 330 example_inputs_is_kwarg=False, 331): 332 # Note: tracing is independent of optimizations, which consume the trace 333 for inputs in check_inputs: 334 if isinstance(inputs, torch.Tensor): 335 inputs = (inputs,) 336 337 if is_trace_module: 338 copied_dict = {} 339 for name, data in inputs.items(): 340 copied_dict[name] = _clone_inputs(data) 341 check_mod = torch.jit.trace_module( 342 getattr(func, "__self__", func), 343 copied_dict, 344 check_trace=False, 345 strict=strict, 346 _force_outplace=force_outplace, 347 _module_class=_module_class, 348 _compilation_unit=torch._C.CompilationUnit(), 349 example_inputs_is_kwarg=example_inputs_is_kwarg, 350 _store_inputs=False, 351 ) 352 check_mod_func = check_mod._c._get_method(traced_func.name) 353 inputs = inputs[traced_func.name] 354 if ( 355 isinstance(inputs, (torch.Tensor)) 356 or isinstance(inputs, dict) 357 and not example_inputs_is_kwarg 358 ): 359 inputs = (inputs,) 360 else: 361 if example_inputs_is_kwarg: 362 check_mod = torch.jit.trace( 363 func, 364 check_trace=False, 365 strict=strict, 366 _force_outplace=force_outplace, 367 _module_class=_module_class, 368 example_kwarg_inputs=_clone_inputs(inputs), 369 _store_inputs=False, 370 ) 371 else: 372 check_mod = torch.jit.trace( 373 func, 374 _clone_inputs(inputs), 375 check_trace=False, 376 strict=strict, 377 _force_outplace=force_outplace, 378 _module_class=_module_class, 379 _store_inputs=False, 380 ) 381 check_mod_func = check_mod 382 383 def graph_diagnostic_info(): 384 mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph) 385 torch._C._jit_pass_inline(mod_canonicalized) 386 torch._C._jit_pass_erase_shape_information(mod_canonicalized) 387 mod_str = str(mod_canonicalized) 388 mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str) 389 check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph) 390 torch._C._jit_pass_inline(check_canonicalized) 391 torch._C._jit_pass_erase_shape_information(check_canonicalized) 392 check_str = str(check_canonicalized) 393 check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str) 394 395 graph_diff_errors = None 396 if mod_str != check_str: 397 import difflib 398 399 graph_diff = difflib.ndiff( 400 mod_str.splitlines(True), check_str.splitlines(True) 401 ) 402 graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n" 403 404 for n_mod, n_check in zip( 405 mod_canonicalized.nodes(), check_canonicalized.nodes() 406 ): 407 if str(n_mod) != str(n_check): 408 graph_diff_errors += "First diverging operator:\n" 409 node_diff = difflib.ndiff( 410 str(n_mod).splitlines(True), str(n_check).splitlines(True) 411 ) 412 source_printout = ( 413 "Node diff:\n" + indent("".join(node_diff)) + "\n" 414 ) 415 mod_stack = n_mod.sourceRange() 416 if mod_stack: 417 source_printout += ( 418 "Trace source location:\n" + indent(mod_stack) + "\n" 419 ) 420 check_stack = n_check.sourceRange() 421 if check_stack: 422 source_printout += ( 423 "Check source location:\n" + indent(check_stack) + "\n" 424 ) 425 graph_diff_errors += source_printout 426 427 break # For now, only print out the first pair of nodes that diverges 428 429 tensor_compare_errors = None 430 # Check Tensor-valued constant nodes 431 for n_mod, n_check in zip( 432 mod_canonicalized.nodes(), check_canonicalized.nodes() 433 ): 434 if n_mod.kind() != n_check.kind(): 435 break # Graphs have already diverged 436 437 if n_mod.kind() == "prim::Constant" and not ( 438 n_mod.mustBeNone() or n_check.mustBeNone() 439 ): 440 if not n_mod.hasAttribute("value"): 441 continue 442 if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t": 443 continue 444 445 mod_tensor_val = n_mod.t("value") 446 check_tensor_val = n_check.t("value") 447 448 try: 449 torch.testing.assert_close( 450 mod_tensor_val, check_tensor_val, equal_nan=True 451 ) 452 except (RuntimeError, AssertionError) as e: 453 if tensor_compare_errors is None: 454 tensor_compare_errors = "" 455 tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n" 456 compare_stack = n_mod.sourceRange() 457 if compare_stack: 458 tensor_compare_errors += ( 459 "Source Location:\n" + indent(compare_stack) + "\n" 460 ) 461 tensor_compare_errors += "Comparison exception: " + indent( 462 str(e) 463 ) 464 465 break # For now, only print the first diverging pair 466 467 return graph_diff_errors, tensor_compare_errors 468 469 def wrap_retval(x): 470 return x if isinstance(x, tuple) else (x,) 471 472 def run_mod_and_filter_tensor_outputs(mod, inputs, running_what): 473 try: 474 if isinstance(inputs, dict) and example_inputs_is_kwarg: 475 outs = wrap_retval(mod(**inputs)) 476 else: 477 outs = wrap_retval(mod(*_clone_inputs(inputs))) 478 outs = [out for out in outs if isinstance(out, torch.Tensor)] 479 return outs 480 except Exception as e: 481 graph_diff_errors, tensor_compare_errors = graph_diagnostic_info() 482 msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}" 483 raise TracingCheckError( 484 graph_diff_errors, 485 tensor_compare_errors, 486 extra_msg=msg, 487 ) from e 488 489 has_warned = [False] 490 491 def maybe_warn_nondeterministic(): 492 if has_warned[0]: 493 return 494 has_warned[0] = True 495 nondeterm_ops = [ 496 op for op in traced_func.graph.nodes() if op.isNondeterministic() 497 ] 498 if len(nondeterm_ops) > 0: 499 nondeterministic_ops_warning = "Trace had nondeterministic nodes. " 500 nondeterministic_ops_warning += ( 501 "Did you forget call .eval() on your model? Nodes:\n" 502 ) 503 nondeterministic_ops_warning += "\n".join( 504 [indent(str(op)) for op in nondeterm_ops][:20] 505 ) 506 nondeterministic_ops_warning += ( 507 "\nThis may cause errors in trace checking. To disable trace checking," 508 " pass check_trace=False to torch.jit.trace()" 509 ) 510 warnings.warn( 511 nondeterministic_ops_warning, category=TracerWarning, stacklevel=5 512 ) 513 514 def compare_outputs(original, reference, match_what): 515 all_ok = True 516 for i, (orig, ref) in enumerate(zip(original, reference)): 517 try: 518 if orig.is_quantized: 519 orig = orig.dequantize() 520 if ref.is_quantized: 521 ref = ref.dequantize() 522 if orig.is_mkldnn: 523 orig = orig.to_dense() 524 if ref.is_mkldnn: 525 ref = ref.to_dense() 526 if ref.is_complex() or orig.is_complex(): 527 torch.testing.assert_close( 528 orig.to(torch.cdouble), 529 ref.to(torch.cdouble), 530 rtol=check_tolerance, 531 atol=default_tolerances(orig, ref)[1], 532 equal_nan=True, 533 ) 534 else: 535 if orig.is_mps or ref.is_mps: 536 torch.testing.assert_close( 537 orig.float(), 538 ref.float(), 539 rtol=check_tolerance, 540 atol=default_tolerances(orig, ref)[1], 541 equal_nan=True, 542 ) 543 elif getattr(orig, "is_nested", None) or getattr( 544 ref, "is_nested", None 545 ): 546 assert getattr(orig, "is_nested", None) == getattr( 547 ref, "is_nested", None 548 ) 549 for t_orig, t_ref in zip(orig.unbind(), ref.unbind()): 550 torch.testing.assert_close( 551 t_orig.double(), 552 t_ref.double(), 553 rtol=check_tolerance, 554 atol=default_tolerances(t_orig, t_ref)[1], 555 equal_nan=True, 556 ) 557 else: 558 torch.testing.assert_close( 559 orig.double(), 560 ref.double(), 561 rtol=check_tolerance, 562 atol=default_tolerances(orig, ref)[1], 563 equal_nan=True, 564 ) 565 except AssertionError as e: 566 maybe_warn_nondeterministic() 567 warnings.warn( 568 "Output nr " 569 + str(i + 1) 570 + ". of the traced function does not match " 571 "the corresponding output of the " 572 + match_what 573 + ". Detailed error:\n" 574 + str(e), 575 category=TracerWarning, 576 stacklevel=4, 577 ) 578 all_ok = False 579 580 return all_ok 581 582 traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace") 583 fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function") 584 if compare_outputs(traced_outs, fn_outs, "Python function"): 585 check_outs = run_mod_and_filter_tensor_outputs( 586 check_mod_func, inputs, "repeated trace" 587 ) 588 compare_outputs(traced_outs, check_outs, "repeated trace") 589 590 diag_info = graph_diagnostic_info() 591 if any(info is not None for info in diag_info): 592 raise TracingCheckError(*diag_info) 593 594 595class TracerWarning(Warning): 596 @staticmethod 597 def ignore_lib_warnings(): 598 # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace 599 warnings.filterwarnings( 600 "ignore", category=TracerWarning, module="torch.(?!jit)" 601 ) 602 warnings.filterwarnings("ignore", "torch::jit::fuser::cuda") 603 604 605# We ignore the tracer warnings coming form inside the library, because all our shape 606# checks in nn will trigger them. 607TracerWarning.ignore_lib_warnings() 608torch._C._tracer_warn_use_python() 609 610 611def make_tuple(example_inputs): 612 if isinstance(example_inputs, (torch.Tensor, dict)): 613 return (example_inputs,) 614 # done primarily so that weird iterables fail here and not pybind11 code 615 if not isinstance(example_inputs, tuple): 616 return tuple(example_inputs) 617 return example_inputs 618 619 620def make_module(mod, _module_class, _compilation_unit): 621 if isinstance(mod, ScriptModule): 622 return mod 623 elif torch._jit_internal.module_has_exports(mod): 624 infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods 625 return torch.jit._recursive.create_script_module( 626 mod, infer_methods_stubs_fn, share_types=False, is_tracing=True 627 ) 628 else: 629 if _module_class is None: 630 _module_class = TopLevelTracedModule 631 return _module_class(mod, _compilation_unit=_compilation_unit) 632 633 634def wrap_check_inputs(check_inputs): 635 if check_inputs is None: 636 return None 637 638 return [{"forward": c} for c in check_inputs] 639 640 641def analyze_ts_result_with_export_result(export, trace): 642 import torch.utils._pytree as pytree 643 644 flat_export = pytree.tree_leaves(export) 645 flat_trace = pytree.tree_leaves(trace) 646 647 for orig, loaded in zip(flat_export, flat_trace): 648 if orig.layout != loaded.layout: 649 return False 650 # mkldnn is not supported for torch.allclose 651 if orig.layout == torch._mkldnn: # type: ignore[attr-defined] 652 return True 653 if type(orig) != type(loaded): 654 return False 655 656 if isinstance(orig, torch._subclasses.FakeTensor): 657 # Skip for FakeTensor. 658 return True 659 elif isinstance(orig, torch.Tensor): 660 if orig.dtype != loaded.dtype: 661 return False 662 if not torch.allclose(orig, loaded): 663 return False 664 else: 665 if orig != loaded: 666 return False 667 return True 668 669 670def _trace_impl( 671 func, 672 example_inputs=None, 673 optimize=None, 674 check_trace=True, 675 check_inputs=None, 676 check_tolerance=1e-5, 677 strict=True, 678 _force_outplace=False, 679 _module_class=None, 680 _compilation_unit=_python_cu, 681 example_kwarg_inputs=None, 682 _store_inputs=True, 683): 684 if isinstance(func, torch.jit.ScriptModule): 685 # it is hard to trace it because the forward method on ScriptModule is already defined, so it 686 # would result in an error. 687 warnings.warn( 688 "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is." 689 ) 690 return func 691 692 if isinstance(func, torch.nn.Module): 693 if example_inputs is None: 694 if isinstance(example_kwarg_inputs, dict): 695 example_inputs = example_kwarg_inputs 696 else: 697 raise RuntimeError("example_kwarg_inputs should be a dict") 698 return trace_module( 699 func, 700 {"forward": example_inputs}, 701 None, 702 check_trace, 703 wrap_check_inputs(check_inputs), 704 check_tolerance, 705 strict, 706 _force_outplace, 707 _module_class, 708 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), 709 _store_inputs=_store_inputs, 710 ) 711 if ( 712 hasattr(func, "__self__") 713 and isinstance(func.__self__, torch.nn.Module) 714 and func.__name__ == "forward" 715 ): 716 if example_inputs is None: 717 if isinstance(example_kwarg_inputs, dict): 718 example_inputs = example_kwarg_inputs 719 else: 720 raise RuntimeError("example_kwarg_inputs should be a dict") 721 return trace_module( 722 func.__self__, 723 {"forward": example_inputs}, 724 None, 725 check_trace, 726 wrap_check_inputs(check_inputs), 727 check_tolerance, 728 strict, 729 _force_outplace, 730 _module_class, 731 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), 732 _store_inputs=_store_inputs, 733 ) 734 735 # Special case for common case of passing a single Tensor 736 if ( 737 isinstance(example_inputs, (torch.Tensor, dict)) 738 and example_kwarg_inputs is None 739 ): 740 example_inputs = (example_inputs,) 741 # done primarily so that weird iterables fail here and not pybind11 code 742 elif example_kwarg_inputs is None and not isinstance(example_inputs, tuple): 743 example_inputs = tuple(example_inputs) 744 745 var_lookup_fn = _create_interpreter_name_lookup_fn(0) 746 747 if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module): 748 raise AttributeError( 749 "trace doesn't support compiling individual module's functions.\n" 750 "Please use trace_module" 751 ) 752 753 name = _qualified_name(func) 754 if isinstance(example_kwarg_inputs, dict): 755 example_inputs = example_kwarg_inputs 756 traced = torch._C._create_function_from_trace_with_dict( 757 name, 758 func, 759 example_kwarg_inputs, 760 var_lookup_fn, 761 strict, 762 _force_outplace, 763 get_callable_argument_names(func), 764 ) 765 else: 766 traced = torch._C._create_function_from_trace( 767 name, 768 func, 769 example_inputs, 770 var_lookup_fn, 771 strict, 772 _force_outplace, 773 get_callable_argument_names(func), 774 ) 775 776 # Check the trace against new traces created from user-specified inputs 777 if check_trace: 778 if check_inputs is not None: 779 _check_trace( 780 check_inputs, 781 func, 782 traced, 783 check_tolerance, 784 strict, 785 _force_outplace, 786 False, 787 _module_class, 788 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), 789 ) 790 else: 791 _check_trace( 792 [example_inputs], 793 func, 794 traced, 795 check_tolerance, 796 strict, 797 _force_outplace, 798 False, 799 _module_class, 800 example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict), 801 ) 802 803 # Allow torch.compile() to inline 804 traced._torchdynamo_inline = func # type: ignore[attr-defined] 805 return traced 806 807 808class _ExportType(str, Enum): 809 DIRECT_EXPORT = "DIRECT_EXPORT" 810 TRACE_AND_EXPORT = "TRACE_AND_EXPORT" 811 SOURCE_TO_SOURCE = "SOURCE_TO_SOURCE" 812 813 def __str__(self) -> str: 814 return self.value 815 816 817class _ExportOutcome(str, Enum): 818 SUCCESS = "SUCCESS" 819 FAILED_TO_EXPORT = "FAILED_TO_EXPORT" 820 FAILED_TO_RUN = "FAILED_TO_RUN" 821 ACCURACY_ERROR = "ACCURACY_ERROR" 822 823 def __str__(self) -> str: 824 return self.value 825 826 827def trace( 828 func, 829 example_inputs=None, 830 optimize=None, 831 check_trace=True, 832 check_inputs=None, 833 check_tolerance=1e-5, 834 strict=True, 835 _force_outplace=False, 836 _module_class=None, 837 _compilation_unit=_python_cu, 838 example_kwarg_inputs=None, 839 _store_inputs=True, 840): 841 r""" 842 Trace a function and return an executable or :class:`ScriptFunction` that will be optimized using just-in-time compilation. 843 844 Tracing is ideal for code that operates only on 845 ``Tensor``\\s and lists, dictionaries, and 846 tuples of ``Tensor``\\s. 847 848 Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an 849 existing module or Python function into a TorchScript 850 :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example 851 inputs, and we run the function, recording the operations performed on all 852 the tensors. 853 854 * The resulting recording of a standalone function produces `ScriptFunction`. 855 * The resulting recording of `nn.Module.forward` or `nn.Module` produces 856 `ScriptModule`. 857 858 This module also contains any parameters that the original 859 module had as well. 860 861 Warning: 862 Tracing only correctly records functions and modules which are not data 863 dependent (e.g., do not have conditionals on data in tensors) and do not have 864 any untracked external dependencies (e.g., perform input/output or 865 access global variables). Tracing only records operations done when the given 866 function is run on the given tensors. Therefore, the returned 867 `ScriptModule` will always run the same traced graph on any input. This 868 has some important implications when your module is expected to run 869 different sets of operations, depending on the input and/or the module 870 state. For example, 871 872 * Tracing will not record any control-flow like if-statements or loops. 873 When this control-flow is constant across your module, this is fine 874 and it often inlines the control-flow decisions. But sometimes the 875 control-flow is actually part of the model itself. For instance, a 876 recurrent network is a loop over the (possibly dynamic) length of an 877 input sequence. 878 * In the returned :class:`ScriptModule`, operations that have different 879 behaviors in ``training`` and ``eval`` modes will always behave as if 880 it is in the mode it was in during tracing, no matter which mode the 881 `ScriptModule` is in. 882 883 In cases like these, tracing would not be appropriate and 884 :func:`scripting <torch.jit.script>` is a better choice. If you trace 885 such models, you may silently get incorrect results on subsequent 886 invocations of the model. The tracer will try to emit warnings when 887 doing something that may cause an incorrect trace to be produced. 888 889 Args: 890 func (callable or torch.nn.Module): A Python function or `torch.nn.Module` 891 that will be run with `example_inputs`. `func` arguments and return 892 values must be tensors or (possibly nested) tuples that contain 893 tensors. When a module is passed `torch.jit.trace`, only the 894 ``forward`` method is run and traced (see :func:`torch.jit.trace 895 <torch.jit.trace_module>` for details). 896 897 Keyword arguments: 898 example_inputs (tuple or torch.Tensor or None, optional): A tuple of example 899 inputs that will be passed to the function while tracing. 900 Default: ``None``. Either this argument or ``example_kwarg_inputs`` 901 should be specified. The resulting trace can be run with inputs of 902 different types and shapes assuming the traced operations support those 903 types and shapes. `example_inputs` may also be a single Tensor in which 904 case it is automatically wrapped in a tuple. When the value is None, 905 ``example_kwarg_inputs`` should be specified. 906 907 check_trace (``bool``, optional): Check if the same inputs run through 908 traced code produce the same outputs. Default: ``True``. You might want 909 to disable this if, for example, your network contains non- 910 deterministic ops or if you are sure that the network is correct despite 911 a checker failure. 912 913 check_inputs (list of tuples, optional): A list of tuples of input 914 arguments that should be used to check the trace against what is 915 expected. Each tuple is equivalent to a set of input arguments that 916 would be specified in ``example_inputs``. For best results, pass in 917 a set of checking inputs representative of the space of shapes and 918 types of inputs you expect the network to see. If not specified, 919 the original ``example_inputs`` are used for checking 920 check_tolerance (float, optional): Floating-point comparison tolerance 921 to use in the checker procedure. This can be used to relax the 922 checker strictness in the event that results diverge numerically 923 for a known reason, such as operator fusion. 924 strict (``bool``, optional): run the tracer in a strict mode or not 925 (default: ``True``). Only turn this off when you want the tracer to 926 record your mutable container types (currently ``list``/``dict``) 927 and you are sure that the container you are using in your 928 problem is a ``constant`` structure and does not get used as 929 control flow (if, for) conditions. 930 example_kwarg_inputs (dict, optional): This parameter is a pack of keyword 931 arguments of example inputs that will be passed to the function while 932 tracing. Default: ``None``. Either this argument or ``example_inputs`` 933 should be specified. The dict will be unpacking by the arguments name 934 of the traced function. If the keys of the dict don't not match with 935 the traced function's arguments name, a runtime exception will be raised. 936 937 Returns: 938 If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns 939 a :class:`ScriptModule` object with a single ``forward`` method 940 containing the traced code. The returned `ScriptModule` will 941 have the same set of sub-modules and parameters as the original 942 ``nn.Module``. If ``func`` is a standalone function, ``trace`` 943 returns `ScriptFunction`. 944 945 Example (tracing a function): 946 947 .. testcode:: 948 949 import torch 950 951 def foo(x, y): 952 return 2 * x + y 953 954 # Run `foo` with the provided inputs and record the tensor operations 955 traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) 956 957 # `traced_foo` can now be run with the TorchScript interpreter or saved 958 # and loaded in a Python-free environment 959 960 Example (tracing an existing module):: 961 962 import torch 963 import torch.nn as nn 964 965 class Net(nn.Module): 966 def __init__(self) -> None: 967 super().__init__() 968 self.conv = nn.Conv2d(1, 1, 3) 969 970 def forward(self, x): 971 return self.conv(x) 972 973 n = Net() 974 example_weight = torch.rand(1, 1, 3, 3) 975 example_forward_input = torch.rand(1, 1, 3, 3) 976 977 # Trace a specific method and construct `ScriptModule` with 978 # a single `forward` method 979 module = torch.jit.trace(n.forward, example_forward_input) 980 981 # Trace a module (implicitly traces `forward`) and construct a 982 # `ScriptModule` with a single `forward` method 983 module = torch.jit.trace(n, example_forward_input) 984 985 """ 986 if not _enabled: 987 return func 988 if optimize is not None: 989 warnings.warn( 990 "`optimize` is deprecated and has no effect. " 991 "Use `with torch.jit.optimized_execution()` instead", 992 FutureWarning, 993 stacklevel=2, 994 ) 995 996 from torch._utils_internal import ( 997 check_if_torch_exportable, 998 log_torch_jit_trace_exportability, 999 log_torchscript_usage, 1000 ) 1001 1002 traced_func = _trace_impl( 1003 func, 1004 example_inputs, 1005 optimize, 1006 check_trace, 1007 check_inputs, 1008 check_tolerance, 1009 strict, 1010 _force_outplace, 1011 _module_class, 1012 _compilation_unit, 1013 example_kwarg_inputs, 1014 _store_inputs, 1015 ) 1016 log_torchscript_usage("trace", model_id=_get_model_id(traced_func)) 1017 1018 if check_if_torch_exportable(): 1019 from torch._export.converter import TS2EPConverter 1020 from torch.export._trace import ( 1021 _convert_ts_to_export_experimental, 1022 _process_jit_trace_inputs_for_export, 1023 ) 1024 1025 traced_func_for_export = _trace_impl( 1026 func, 1027 example_inputs=example_inputs, 1028 optimize=optimize, 1029 check_trace=False, 1030 check_inputs=check_inputs, 1031 check_tolerance=check_tolerance, 1032 strict=strict, 1033 _force_outplace=_force_outplace, 1034 _module_class=_module_class, 1035 _compilation_unit=_compilation_unit, 1036 example_kwarg_inputs=example_kwarg_inputs, 1037 _store_inputs=_store_inputs, 1038 ) 1039 1040 export_args, _ = _process_jit_trace_inputs_for_export( 1041 example_inputs, example_kwarg_inputs 1042 ) 1043 1044 def _log_exportability(func_to_export, export_func, export_args, export_type): 1045 try: 1046 traced_result = func_to_export(*export_args) 1047 except Exception as e: 1048 _ = e 1049 log_torch_jit_trace_exportability( 1050 "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" 1051 ) 1052 return 1053 1054 try: 1055 ep_module = export_func(func_to_export, export_args) 1056 except Exception as e: 1057 log_torch_jit_trace_exportability( 1058 "trace", 1059 str(export_type), 1060 str(_ExportOutcome.FAILED_TO_EXPORT), 1061 str(e), 1062 ) 1063 return 1064 1065 try: 1066 export = ep_module(*export_args) 1067 except Exception as e: 1068 log_torch_jit_trace_exportability( 1069 "trace", str(export_type), str(_ExportOutcome.FAILED_TO_RUN), str(e) 1070 ) 1071 return 1072 1073 if not analyze_ts_result_with_export_result(export, traced_result): 1074 log_torch_jit_trace_exportability( 1075 "trace", 1076 str(export_type), 1077 str(_ExportOutcome.ACCURACY_ERROR), 1078 "accuracy error", 1079 ) 1080 return 1081 1082 log_torch_jit_trace_exportability( 1083 "trace", str(export_type), str(_ExportOutcome.SUCCESS), "succeeded" 1084 ) 1085 1086 def _direct_export_and_lower(func, export_args): 1087 return torch.export.export(func, export_args, strict=False).module() 1088 1089 def _convert_ts_to_export_source_to_source(func, export_args): 1090 return TS2EPConverter(func, export_args).convert().module() 1091 1092 # torch.jit.trace is noop when the original module is torch.jit.ScriptModule 1093 if not isinstance(traced_func_for_export, torch.jit.ScriptModule): 1094 _log_exportability( 1095 traced_func_for_export, 1096 _direct_export_and_lower, 1097 export_args, 1098 _ExportType.DIRECT_EXPORT, 1099 ) 1100 1101 _log_exportability( 1102 traced_func_for_export, 1103 _convert_ts_to_export_experimental, 1104 export_args, 1105 _ExportType.TRACE_AND_EXPORT, 1106 ) 1107 _log_exportability( 1108 traced_func_for_export, 1109 _convert_ts_to_export_source_to_source, 1110 export_args, 1111 _ExportType.SOURCE_TO_SOURCE, 1112 ) 1113 1114 return traced_func 1115 1116 1117_trace_module_map: Optional[Dict[Any, Any]] = None 1118 1119 1120def trace_module( 1121 mod, 1122 inputs, 1123 optimize=None, 1124 check_trace=True, 1125 check_inputs=None, 1126 check_tolerance=1e-5, 1127 strict=True, 1128 _force_outplace=False, 1129 _module_class=None, 1130 _compilation_unit=_python_cu, 1131 example_inputs_is_kwarg=False, 1132 _store_inputs=True, 1133): 1134 """ 1135 Trace a module and return an executable :class:`ScriptModule` that will be optimized using just-in-time compilation. 1136 1137 When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`, only 1138 the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary of 1139 method names to example inputs to trace (see the ``inputs``) argument below. 1140 1141 See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing. 1142 1143 Args: 1144 mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are 1145 specified in ``inputs``. The given methods will be compiled 1146 as a part of a single `ScriptModule`. 1147 inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. 1148 The inputs will be passed to methods whose names correspond to inputs' 1149 keys while tracing. 1150 ``{ 'forward' : example_forward_input, 'method2': example_method2_input}`` 1151 Keyword arguments: 1152 check_trace (``bool``, optional): Check if the same inputs run through 1153 traced code produce the same outputs. Default: ``True``. You might want 1154 to disable this if, for example, your network contains non- 1155 deterministic ops or if you are sure that the network is correct despite 1156 a checker failure. 1157 1158 check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used 1159 to check the trace against what is expected. Each tuple 1160 is equivalent to a set of input arguments that would 1161 be specified in ``inputs``. For best results, pass in a 1162 set of checking inputs representative of the space of 1163 shapes and types of inputs you expect the network to see. 1164 If not specified, the original ``inputs`` are used for checking 1165 check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure. 1166 This can be used to relax the checker strictness in the event that 1167 results diverge numerically for a known reason, such as operator fusion. 1168 example_inputs_is_kwarg (``bool``, optional): This parameter indicate whether the example inputs is a pack 1169 pack of keyword arguments. Default: ``False``. 1170 1171 Returns: 1172 A :class:`ScriptModule` object with a single ``forward`` method containing the traced code. 1173 When ``func`` is a ``torch.nn.Module``, the returned :class:`ScriptModule` will have the same set of 1174 sub-modules and parameters as ``func``. 1175 1176 Example (tracing a module with multiple methods):: 1177 1178 import torch 1179 import torch.nn as nn 1180 1181 class Net(nn.Module): 1182 def __init__(self) -> None: 1183 super().__init__() 1184 self.conv = nn.Conv2d(1, 1, 3) 1185 1186 def forward(self, x): 1187 return self.conv(x) 1188 1189 def weighted_kernel_sum(self, weight): 1190 return weight * self.conv.weight 1191 1192 1193 n = Net() 1194 example_weight = torch.rand(1, 1, 3, 3) 1195 example_forward_input = torch.rand(1, 1, 3, 3) 1196 1197 # Trace a specific method and construct `ScriptModule` with 1198 # a single `forward` method 1199 module = torch.jit.trace(n.forward, example_forward_input) 1200 1201 # Trace a module (implicitly traces `forward`) and construct a 1202 # `ScriptModule` with a single `forward` method 1203 module = torch.jit.trace(n, example_forward_input) 1204 1205 # Trace specific methods on a module (specified in `inputs`), constructs 1206 # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods 1207 inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} 1208 module = torch.jit.trace_module(n, inputs) 1209 1210 """ 1211 if not _enabled: 1212 return mod 1213 if optimize is not None: 1214 warnings.warn( 1215 "`optimize` is deprecated and has no effect. " 1216 "Use `with torch.jit.optimized_execution()` instead", 1217 FutureWarning, 1218 stacklevel=2, 1219 ) 1220 1221 var_lookup_fn = _create_interpreter_name_lookup_fn(0) 1222 1223 if not isinstance(mod, torch.nn.Module): 1224 raise AttributeError("expected torch.nn.Module as the first argument") 1225 1226 if not isinstance(inputs, dict): 1227 raise AttributeError("expected a dictionary of (method_name, input) pairs") 1228 1229 old_module_map = torch.jit._trace._trace_module_map 1230 try: 1231 trace_module_map: Dict[Any, Any] = {} 1232 1233 def register_submods(mod, prefix): 1234 for name, child in mod.named_children(): 1235 submod_qualname = prefix + "." + name 1236 trace_module_map[child] = submod_qualname 1237 register_submods(child, submod_qualname) 1238 1239 trace_module_map["__module"] = mod 1240 torch.jit._trace._trace_module_map = trace_module_map 1241 register_submods(mod, "__module") 1242 1243 module = make_module(mod, _module_class, _compilation_unit) 1244 1245 for method_name, example_inputs in inputs.items(): 1246 if method_name == "forward": 1247 # "forward" is a special case because we need to trace 1248 # `Module.__call__`, which sets up some extra tracing, but uses 1249 # argument names of the real `Module.forward` method. 1250 func = mod 1251 forward_method = getattr(mod, method_name) 1252 argument_names = get_callable_argument_names(forward_method) 1253 else: 1254 func = getattr(mod, method_name) 1255 argument_names = get_callable_argument_names(func) 1256 1257 if isinstance(example_inputs, dict) and example_inputs_is_kwarg: 1258 # Raise exception when the user provided key names are not aligned with forward() method's arguments' name/ 1259 for key in example_inputs: 1260 if key not in argument_names: 1261 valid_arguments = "[" + ",".join(argument_names) + "]" 1262 raise NameError( 1263 f"""'{key}' is not in forward() method's arguments, 1264 valid arguments name are {valid_arguments}""" 1265 ) 1266 module._c._create_method_from_trace_with_dict( 1267 method_name, 1268 func, 1269 example_inputs, 1270 var_lookup_fn, 1271 strict, 1272 _force_outplace, 1273 argument_names, 1274 _store_inputs, 1275 ) 1276 else: 1277 example_inputs = make_tuple(example_inputs) 1278 module._c._create_method_from_trace( 1279 method_name, 1280 func, 1281 example_inputs, 1282 var_lookup_fn, 1283 strict, 1284 _force_outplace, 1285 argument_names, 1286 _store_inputs, 1287 ) 1288 1289 check_trace_method = module._c._get_method(method_name) 1290 1291 # Check the trace against new traces created from user-specified inputs 1292 if check_trace: 1293 if check_inputs is not None: 1294 _check_trace( 1295 check_inputs, 1296 func, 1297 check_trace_method, 1298 check_tolerance, 1299 strict, 1300 _force_outplace, 1301 True, 1302 _module_class, 1303 example_inputs_is_kwarg=example_inputs_is_kwarg, 1304 ) 1305 else: 1306 _check_trace( 1307 [inputs], 1308 func, 1309 check_trace_method, 1310 check_tolerance, 1311 strict, 1312 _force_outplace, 1313 True, 1314 _module_class, 1315 example_inputs_is_kwarg=example_inputs_is_kwarg, 1316 ) 1317 finally: 1318 torch.jit._trace._trace_module_map = old_module_map 1319 1320 return module 1321 1322 1323def is_tracing(): 1324 """Return a boolean value. 1325 1326 Returns ``True`` in tracing (if a function is called during the 1327 tracing of code with ``torch.jit.trace``) and ``False`` otherwise. 1328 """ 1329 if is_scripting(): 1330 return False 1331 return torch._C._is_tracing() 1332 1333 1334class TracedModule(ScriptModule): 1335 _disable_script_meta = True 1336 1337 def __init__(self, orig, id_set=None, _compilation_unit=None): 1338 # XXX: orig can be a nn.Module or a function! 1339 super().__init__() 1340 assert isinstance(orig, torch.nn.Module) 1341 1342 # Copy a subset of `orig` to a temporary nn.Module. 1343 # This is a way to customize what will actually get compiled by create_script_module 1344 id_set = set() 1345 1346 # This allows us to preserve the original module's qualified name by defining a new 1347 # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name 1348 # we have a special case that will look up this attribute to override whatever qualname 1349 # we would get from the python type system 1350 class QualnameWrapper(torch.nn.Module): 1351 pass 1352 1353 QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( # type: ignore[attr-defined] 1354 type(orig) 1355 ) 1356 1357 tmp_module = QualnameWrapper() 1358 1359 def check_unique(param): 1360 if param in id_set: 1361 raise ValueError( 1362 "TracedModules don't support parameter sharing between modules" 1363 ) 1364 id_set.add(param) 1365 1366 tmp_module.training = orig.training 1367 1368 for name, param in orig._parameters.items(): 1369 if param is not None: 1370 tmp_module._parameters[name] = param 1371 check_unique(param) 1372 for name, buf in orig._buffers.items(): 1373 if buf is not None: 1374 tmp_module._buffers[name] = buf 1375 check_unique(buf) 1376 for name, val in orig.__dict__.items(): 1377 if ( 1378 torch._C._jit_is_script_object(val) 1379 and name not in orig._parameters 1380 and name not in orig._buffers 1381 ): 1382 setattr(tmp_module, name, val) 1383 1384 if orig._backward_hooks: 1385 raise ValueError( 1386 "Modules that have backward hooks assigned can't be compiled: " 1387 + str(orig) 1388 ) 1389 1390 for name, submodule in orig._modules.items(): 1391 if submodule is None: 1392 continue 1393 tmp_module._modules[name] = make_module( 1394 submodule, TracedModule, _compilation_unit=None 1395 ) 1396 1397 script_module = torch.jit._recursive.create_script_module( 1398 tmp_module, lambda module: (), share_types=False, is_tracing=True 1399 ) 1400 1401 self.__dict__["_name"] = type(orig).__name__ 1402 self.__dict__["_actual_script_module"] = script_module 1403 for name in ("_parameters", "_buffers", "_modules", "training"): 1404 delattr(self, name) 1405 1406 def forward(self, *args, **kwargs): 1407 raise RuntimeError("Trace submodules cannot be called.") 1408 1409 def __getattr__(self, attr): 1410 if "_actual_script_module" not in self.__dict__: 1411 return super().__getattr__(attr) 1412 return getattr(self._actual_script_module, attr) 1413 1414 def __setattr__(self, attr, value): 1415 if "_actual_script_module" not in self.__dict__: 1416 return super().__setattr__(attr, value) 1417 setattr(self._actual_script_module, attr, value) 1418 1419 def _get_name(self): 1420 return self._name 1421 1422 def extra_repr(self): 1423 return f"original_name={self._name}" 1424 1425 1426class TopLevelTracedModule(TracedModule): 1427 forward: Callable[..., Any] = _CachedForward() # type: ignore[assignment] 1428 1429 def _reconstruct(self, cpp_module): 1430 """ 1431 Re-construct an instance of TopLevelTracedModule using an instance of a C++ module. 1432 1433 Args: 1434 cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. 1435 """ 1436 self.__dict__["_actual_script_module"]._reconstruct(cpp_module) 1437 1438 1439def _script_if_tracing(fn: Callable[P, R]) -> Callable[P, R]: 1440 @functools.wraps(fn) 1441 def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: 1442 if not is_tracing(): 1443 # Not tracing, don't do anything 1444 return fn(*args, **kwargs) 1445 1446 compiled_fn: Callable[P, R] = script(wrapper.__original_fn) # type: ignore[attr-defined] 1447 return compiled_fn(*args, **kwargs) 1448 1449 wrapper.__original_fn = fn # type: ignore[attr-defined] 1450 wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined] 1451 1452 return wrapper 1453 1454 1455def _get_trace_graph( 1456 f, 1457 args=(), 1458 kwargs=None, 1459 strict=True, 1460 _force_outplace=False, 1461 return_inputs=False, 1462 _return_inputs_states=False, 1463): 1464 """Return a tuple on tracing a function or model. 1465 1466 .. warning:: 1467 This function is internal-only and should only be used by the ONNX 1468 exporter. If you are trying to get a graph through tracing, please go 1469 through the public API instead:: 1470 1471 trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) 1472 trace_graph = trace.graph 1473 1474 Trace a function or model, returning a tuple consisting of the both the 1475 *trace* of an execution, as well as the original return value. If return_inputs, 1476 also returns the trace inputs as part of the tuple 1477 1478 Tracing is guaranteed not to change the semantics of the function/module 1479 that is traced. 1480 1481 Args: 1482 f (torch.nn.Module or function): the function or module 1483 to be traced. 1484 args (tuple or Tensor): the positional arguments to pass to the 1485 function/module to be traced. A non-tuple is assumed to 1486 be a single positional argument to be passed to the model. 1487 kwargs (dict): the keyword arguments to pass to the function/module 1488 to be traced. 1489 1490 Example (trace a cell): 1491 1492 .. testcode:: 1493 1494 trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) 1495 """ 1496 if kwargs is None: 1497 kwargs = {} 1498 if not isinstance(args, tuple): 1499 args = (args,) 1500 outs = ONNXTracedModule( 1501 f, strict, _force_outplace, return_inputs, _return_inputs_states 1502 )(*args, **kwargs) 1503 return outs 1504