xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TracerMode.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/impl/LocalDispatchKeySet.h>
4 #include <c10/macros/Export.h>
5 #include <c10/macros/Macros.h>
6 
7 // NOTE [Tracing Mode Switches]
8 //
9 // Historically, tracing function was controlled by two switches:
10 //
11 // - `AutoDispatchBelowADInplaceOrView` guard
12 //
13 //    Tracing function used to be script-generated inside `VariableType_*.cpp`
14 //    kernels, sharing the same `Autograd` dispatch key with autograd function.
15 //    Therefore, before tracing function was moved out of VariableType,
16 //    `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a
17 //    side effect of disabling `Autograd` dispatching.
18 //
19 // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
20 //
21 //    It stores tracing data in a `TracingState` object in TLS. If the
22 //    `TracingState` object in TLS is `null`, then tracing is paused.
23 //
//    The `TracingState` object is created in `tracer::trace()` - the main
//    entry point of tracing. It's temporarily set to `null` inside
//    generated VariableType (now TraceType) to bypass tracing for intermediate
//    ops (ops being called by other ops). After the intermediate op call
//    finishes, it's set back to the original `TracingState` object.
29 //
//    The `TracingState` object in TLS can also be read/written via its Python
//    binding in `python_tracer.cpp` and via the `get/setTracingState()` C++
//    APIs, which are also exposed as `TORCH_API`.
33 //
34 // Two new switches were introduced since tracing function was moved out of
35 // VariableType:
36 //
37 // - `tracer::impl::set_dispatch_enabled()` API
38 //
39 //    Unlike the special `Autograd` dispatch key which is included in dispatch
40 //    key set by default, `Tracer` dispatch key is off by default. The
41 //    dispatching switch can be toggled via this new API.
42 //
43 // - `tracer::impl::NoTracerDispatchMode` guard
44 //
45 //    It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
46 //    after tracing was moved out of VariableType.
47 //
48 // Before tracing function was moved out of VariableType, tracing was enabled
49 // when the following conditions are satisfied:
50 //
51 //    1) `TracingState` object in TLS != null;
52 //       - Either inside the execution scope of `tracer::trace()`, or
53 //       - Eagerly called `setTracingState()` with non-null object.
54 //    2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
55 //
56 // After:
57 //
58 //    1) `TracingState` object in TLS != null;
59 //    2) Has called `tracer::impl::set_dispatch_enabled(true)`;
//    3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
61 //
62 // [TODOs]
63 //
// - `setTracingState()` vs. `tracer::impl::set_dispatch_enabled()`
65 //
66 //   Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
67 //   to keep the semantics exactly the same as before - it's confusing to keep
68 //   both switches, though. We should consider simplifying/limiting the exposed
69 //   `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
70 //   these two can be unified.
71 //
// - `AutoDispatchBelowADInplaceOrView` vs.
//   `tracer::impl::NoTracerDispatchMode`
74 //
//   We don't always need to set both guards together to keep the semantics
//   unchanged. For the following use cases of
//   `AutoDispatchBelowADInplaceOrView` we don't need to set the new tracer
//   guard:
78 //
79 //   * Script-generated VariableType kernels. The guard is not necessary as
80 //     tracing is already disabled explicitly by `setTracingState(null)` in
81 //     generated TraceType kernels - we could keep it as is or use the new guard
82 //     instead.
83 //
84 //   * Custom ops. Will be handled by fallback kernel for `Tracer`.
85 //
//   * Functions that are not likely to be called in a tracing context (no
//     Python binding / not an operator), e.g.: all mobile forward() wrappers,
//     test binaries, etc.
89 //
90 //   * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
91 //     It's not necessary as tracing is off by default.
92 //
//   For the remaining cases we might need both:
94 //
95 //   * Functions that might be reachable from eager mode python (especially
96 //     factory methods), e.g.:
97 //     `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
98 //     Without the new guard it will add `aten::empty` to the traced graph.
99 //
100 //   * Some manually maintained functions, e.g.:
101 //     `torch/csrc/autograd/VariableTypeManual.cpp`.
102 //     Set the new guard if it's not obvious whether `setTracingState(null)`
103 //     has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
104 //     guard.
105 //
//   We might need to tweak the usage of the new guard to optimize/fix things.
//   It should only affect the correctness of tracing, because the guard is
//   essentially a no-op when the master `setTracingState()` switch is off.
110 
// TODO: move this from `at::` to `torch::jit::` after
// `aten/src/ATen/cpp_custom_type_hack.h` is removed.
113 
114 namespace at::tracer::impl {
115 
is_dispatch_enabled()116 inline bool is_dispatch_enabled() {
117   return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
118       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
119 }
120 
set_dispatch_enabled(bool enabled)121 inline void set_dispatch_enabled(bool enabled) {
122   TORCH_INTERNAL_ASSERT(
123       !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
124       "Cannot enable tracing within the scope of NoTracerDispatchMode!");
125   c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
126 }
127 
128 struct NoTracerDispatchMode {
129   c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
130 };
131 
132 } // namespace at::tracer::impl
133