// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/tracer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/Dimname.h>
#include <ATen/core/class_type.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/utils/variadic.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
19 namespace torch::jit {
20 struct Node;
21 struct Value;
22 struct Graph;
23 struct Module;
24 
25 namespace tracer {
26 
27 using ::c10::ivalue::Shared;
28 
29 using ::c10::IValue;
30 using ::c10::ivalue::Future;
31 
32 using ::c10::ArrayRef;
33 using ::c10::TupleType;
34 using ::c10::TupleTypePtr;
35 using ::c10::ivalue::ConstantString;
36 
37 using torch::autograd::Variable;
38 using variable_list = std::vector<Variable>;
39 
40 TORCH_API std::atomic<bool>& getTracerStateWarnMode();
41 
42 struct TORCH_API TracingState
43     : public std::enable_shared_from_this<TracingState> {
44   TracingState();
45   ~TracingState();
46 
47   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
48   std::shared_ptr<Graph> graph;
49   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
50   bool warn = getTracerStateWarnMode();
51   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
52   bool strict = true;
53   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
54   bool force_outplace = false;
55   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
56   std::function<std::string(const Variable& var)> lookup_var_name_fn =
57       [](const Variable& var) { return ""; };
58 
enterFrameTracingState59   void enterFrame() {
60     env_stack.emplace_back();
61   }
62 
leaveFrameTracingState63   void leaveFrame() {
64     env_stack.pop_back();
65   }
66 
67   void setValue(const IValue& v, Value* value);
68   void delValue(const IValue& var);
69   Value* getValue(const IValue& var);
70   Value* getOutput(const IValue& var, size_t i);
71   bool hasValue(const IValue& var) const;
72 
73   Node* createNode(c10::Symbol op_name, size_t num_outputs);
74   void insertNode(Node* node);
75 
76  private:
77   using WeakIValue = at::WeakIValue;
78 
79   struct WeakIValueHasher {
operatorTracingState::WeakIValueHasher80     size_t operator()(const WeakIValue& t) const {
81       return t.hash();
82     }
83   };
84 
85   struct WeakIValueEq {
operatorTracingState::WeakIValueEq86     bool operator()(const WeakIValue& t1, const WeakIValue& t2) const {
87       return t1.isSameIdentity(t2);
88     }
89   };
90 
91   using Frame =
92       std::unordered_map<WeakIValue, Value*, WeakIValueHasher, WeakIValueEq>;
93   std::vector<Frame> env_stack;
94 };
95 
96 // This is meant to be used as a thread local place, where we can store extra
97 // info that gets lost when we call into ATen from Python bindings. One example
98 // for when this happens is when we get an IntArrayRef argument with e.g. sizes
99 // for view. When tracing, those might be tensors, which let us encode extra
100 // data dependencies, but once they get to the ATen call where we actually have
101 // the tracing logic, they get converted into a raw IntArrayRef, and we loose
102 // all information. To prevent this, we temporarily stash it in here.
103 struct ArgumentStash {
104   struct IntArrayRefTrace : std::vector<Value*> {
IntArrayRefTraceArgumentStash::IntArrayRefTrace105     IntArrayRefTrace(size_t size) : std::vector<Value*>(size, nullptr) {}
106   };
107 
emptyArgumentStash108   static bool empty() {
109     return stash.intlists.empty();
110   }
111 
112   TORCH_API static void stashIntArrayRefElem(
113       const std::string& arg_name,
114       size_t size,
115       size_t idx,
116       const Variable& var);
117 
hasIntArrayRefArgumentStash118   static bool hasIntArrayRef(const std::string& arg_name) {
119     return stash.intlists.count(arg_name) > 0;
120   }
121 
popIntArrayRefArgumentStash122   static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) {
123     auto info = std::move(stash.intlists.at(arg_name));
124     stash.intlists.erase(arg_name);
125     return info;
126   }
127 
128   // Value stashing: Use these methods to stash arguments which correspond
129   // to regular Value*'s in the graph. i.e. they don't require special
130   // handling like in the case of IntArrayRefs
131   TORCH_API static void stashValue(
132       const std::string& arg_name,
133       size_t idx,
134       const Variable& var,
135       const c10::TypePtr& type = nullptr);
136 
hasValueArgumentStash137   static bool hasValue(const std::string& arg_name) {
138     return stash.values.count(arg_name) > 0;
139   }
140 
popValueArgumentStash141   static Value* popValue(const std::string& arg_name) {
142     auto info = stash.values.at(arg_name);
143     stash.values.erase(arg_name);
144     return info;
145   }
146 
147  private:
148   static thread_local ArgumentStash stash;
149   std::unordered_map<std::string, IntArrayRefTrace> intlists;
150   std::unordered_map<std::string, Value*> values;
151 };
152 
153 // Retrieve or set the current tracing state. Returns a nullptr if tracing is
154 // disabled.
155 TORCH_API const std::shared_ptr<TracingState>& getTracingState();
156 TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
157 
isTracing()158 inline bool isTracing() {
159   return static_cast<bool>(getTracingState());
160 }
161 
162 using warn_fn_type = void (*)(const std::string& msg);
163 TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
164 TORCH_API extern const char* WARN_CONSTRUCTOR;
165 TORCH_API extern const char* WARN_RESIZE;
166 TORCH_API extern const char* STRICT_TRACER_MSG;
167 TORCH_API void _do_warn(const char* _reason, const char* _kind);
168 inline void warn(const char* _reason, const char* _kind = nullptr) {
169   if (const auto& state = getTracingState()) {
170     if (!state->warn)
171       return;
172     _do_warn(_reason, _kind);
173   }
174 }
175 TORCH_API void setWarn(warn_fn_type fn);
176 
177 struct TORCH_API NoWarn {
NoWarnNoWarn178   NoWarn() : state(getTracingState()) {
179     if (state) {
180       prev = state->warn;
181       state->warn = false;
182     }
183   }
~NoWarnNoWarn184   ~NoWarn() {
185     if (state) {
186       state->warn = prev;
187     }
188   }
189   std::shared_ptr<TracingState> state;
190   bool prev{false};
191 };
192 
193 struct WithNestedTracingFrame {
WithNestedTracingFrameWithNestedTracingFrame194   WithNestedTracingFrame() {
195     getTracingState()->enterFrame();
196   }
197 
~WithNestedTracingFrameWithNestedTracingFrame198   ~WithNestedTracingFrame() {
199     getTracingState()->leaveFrame();
200   }
201 };
202 TORCH_API void recordSourceLocation(Node* n);
203 TORCH_API void setRecordSourceLocation(void (*v)(Node*));
204 
205 TORCH_API std::vector<StackEntry> pythonCallstack();
206 TORCH_API void setPythonCallstack(std::vector<StackEntry> (*v)());
207 
208 // Having finished adding a new 'node' to the graph IR 'setValueTrace'
209 // associates this node with an output variable, so that further operations
210 // involving this variable know which node in the IR to reference.
211 TORCH_API void setValueTrace(const IValue& v, Value* value);
212 
213 TORCH_API void delValueTrace(const IValue& var);
214 
215 TORCH_API std::function<void()> pauseTracing();
216 
217 TORCH_API Value* getValueTrace(const IValue& var);
218 
219 TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
220     Stack inputs,
221     const std::function<Stack(Stack)>& traced_fn,
222     std::function<std::string(const Variable&)> var_name_lookup_fn,
223     bool strict = true,
224     bool force_outplace = false,
225     Module* self = nullptr,
226     const std::vector<std::string>& argument_names = {});
227 
228 TORCH_API void abandon();
229 
230 // NB: those serve both as an intermediate steps in addInputs below,
231 // as well as the overloads that terminate template recursion
232 TORCH_API void addInputs(Node* n, const char* name, int64_t value);
233 TORCH_API void addInputs(Node* n, const char* name, const c10::SymInt& value);
234 TORCH_API void addInputs(
235     Node* n,
236     const char* name,
237     std::optional<int64_t> value);
238 TORCH_API void addInputs(Node* n, const char* name, bool value);
239 TORCH_API void addInputs(
240     Node* n,
241     const char* name,
242     const std::optional<bool>& value);
243 TORCH_API void addInputs(Node* n, const char* name, double value);
244 TORCH_API void addInputs(
245     Node* n,
246     const char* name,
247     const std::optional<double>& value);
248 TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
249 TORCH_API void addInputs(
250     Node* n,
251     const char* name,
252     const std::optional<at::Scalar>& value);
253 TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
254 TORCH_API void addInputs(
255     Node* n,
256     const char* name,
257     const std::optional<at::Tensor>& value);
258 TORCH_API void addInputs(Node* n, const char* name, ArrayRef<int64_t> value);
259 TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value);
260 TORCH_API void addInputs(
261     Node* n,
262     const char* name,
263     std::optional<c10::SymInt> value);
264 TORCH_API void addInputs(
265     Node* n,
266     const char* name,
267     const std::optional<ArrayRef<int64_t>>& value);
268 TORCH_API void addInputs(
269     Node* n,
270     const char* name,
271     const at::OptionalIntArrayRef& opt_value);
272 TORCH_API void addInputs(
273     Node* n,
274     const char* name,
275     const at::OptionalSymIntArrayRef& opt_value);
276 TORCH_API void addInputs(
277     Node* n,
278     const char* name,
279     ArrayRef<at::Tensor> value,
280     bool allow_undefined = false);
281 TORCH_API void addInputs(
282     Node* n,
283     const char* name,
284     const std::vector<at::Tensor>& value,
285     bool allow_undefined = false);
286 TORCH_API void addInputs(
287     Node* n,
288     const char* name,
289     at::ITensorListRef value,
290     bool allow_undefined = false);
291 TORCH_API void addInputs(
292     Node* n,
293     const char* name,
294     const List<std::optional<at::Tensor>>& value);
295 TORCH_API void addInputs(
296     Node* n,
297     const char* name,
298     ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
299     const c10::ClassTypePtr& class_type);
300 TORCH_API void addInputs(Node* n, const char* name, ArrayRef<double> value);
301 TORCH_API void addInputs(
302     Node* n,
303     const char* name,
304     const std::optional<ArrayRef<double>>& value);
305 TORCH_API void addInputs(
306     Node* n,
307     const char* name,
308     const c10::string_view value);
309 TORCH_API void addInputs(
310     Node* n,
311     const char* name,
312     const std::optional<c10::string_view>& value);
313 TORCH_API void addInputs(Node* n, const char* name, at::Device value);
314 TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream);
315 TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
316 TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
317 TORCH_API void addInputs(
318     Node* n,
319     const char* name,
320     const std::optional<at::ScalarType>& value);
321 TORCH_API void addInputs(
322     Node* n,
323     const char* name,
324     const std::optional<at::Device>& value);
325 TORCH_API void addInputs(
326     Node* n,
327     const char* name,
328     const std::optional<at::Layout>& value);
329 TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
330 TORCH_API void addInputs(
331     Node* n,
332     const char* name,
333     std::optional<at::DimnameList> value);
334 TORCH_API void addInputs(
335     Node* n,
336     const char* name,
337     const std::optional<at::MemoryFormat>& value);
338 TORCH_API void addInputs(
339     Node* n,
340     const char* name,
341     const std::optional<at::Generator>& value);
342 
addInputs(Node * n,const char * name,const std::vector<bool> & value)343 inline void addInputs(
344     Node* n,
345     const char* name,
346     const std::vector<bool>& value) {
347   AT_ERROR("Tracing a list of bool type is currently not supported!");
348 }
349 
350 template <typename T>
addInputs(Node * n,const char * name,ArrayRef<T> value)351 void addInputs(Node* n, const char* name, ArrayRef<T> value) {
352   AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
353 }
354 template <typename K, typename V>
addInputs(Node * n,const char * name,const std::unordered_map<K,V> & value)355 void addInputs(
356     Node* n,
357     const char* name,
358     const std::unordered_map<K, V>& value) {
359   AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
360 }
361 
362 template <size_t N>
addInputs(Node * n,const char * name,std::array<bool,N> value)363 void addInputs(Node* n, const char* name, std::array<bool, N> value) {
364   throw std::runtime_error(
365       "Found an unsupported argument type in the JIT tracer. File a bug report.");
366 }
367 
368 TORCH_API void addInputs(
369     Node* n,
370     const char* name,
371     const c10::intrusive_ptr<c10::ivalue::Object>& obj);
372 
373 TORCH_API void ensureUniqueIfOutOfPlaced(
374     const char* name,
375     const at::Tensor& tensor);
376 TORCH_API void ensureUniqueIfOutOfPlaced(
377     const char* name,
378     const std::optional<at::Tensor>& tensor);
379 
380 template <
381     typename T,
382     typename = std::enable_if_t<
383         (!std::is_convertible_v<std::decay_t<T>, at::TensorList> &&
384          !std::is_convertible_v<std::decay_t<T>, c10::List<at::Tensor>> &&
385          !std::is_convertible_v<std::decay_t<T>, at::Tensor> &&
386          !std::is_convertible_v<
387              std::decay_t<T>,
388              c10::intrusive_ptr<c10::ivalue::Object>>)>>
addOutput(Node * node,T &&)389 void addOutput(Node* node, T&&) {
390   AT_ERROR(
391       "Found an unsupported argument type ",
392       c10::demangle_type<T>(),
393       " in the JIT tracer. File a bug report.");
394 }
395 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
396 TORCH_API void setOutput(Value* value, const at::Tensor& output);
397 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
398 TORCH_API void addOutput(Node* node, const c10::List<at::Tensor>& list);
399 TORCH_API void addOutput(
400     Node* node,
401     const c10::intrusive_ptr<c10::ivalue::Object>& output);
402 
403 TORCH_API autograd::Variable getSizeOf(
404     const autograd::Variable& var,
405     int64_t dim);
406 
407 TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
408 
409 } // namespace tracer
410 } // namespace torch::jit
411