1 #pragma once
2
#include <ATen/core/Dimname.h>
#include <ATen/core/class_type.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>

#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/utils/variadic.h>

#include <array>
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
18
19 namespace torch::jit {
// JIT IR and module types that this header only uses through (smart)
// pointers, so their full definitions are not needed here.
struct Graph;
struct Module;
struct Node;
struct Value;
24
25 namespace tracer {
26
27 using ::c10::ivalue::Shared;
28
29 using ::c10::IValue;
30 using ::c10::ivalue::Future;
31
32 using ::c10::ArrayRef;
33 using ::c10::TupleType;
34 using ::c10::TupleTypePtr;
35 using ::c10::ivalue::ConstantString;
36
37 using torch::autograd::Variable;
38 using variable_list = std::vector<Variable>;
39
40 TORCH_API std::atomic<bool>& getTracerStateWarnMode();
41
42 struct TORCH_API TracingState
43 : public std::enable_shared_from_this<TracingState> {
44 TracingState();
45 ~TracingState();
46
47 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
48 std::shared_ptr<Graph> graph;
49 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
50 bool warn = getTracerStateWarnMode();
51 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
52 bool strict = true;
53 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
54 bool force_outplace = false;
55 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
56 std::function<std::string(const Variable& var)> lookup_var_name_fn =
57 [](const Variable& var) { return ""; };
58
enterFrameTracingState59 void enterFrame() {
60 env_stack.emplace_back();
61 }
62
leaveFrameTracingState63 void leaveFrame() {
64 env_stack.pop_back();
65 }
66
67 void setValue(const IValue& v, Value* value);
68 void delValue(const IValue& var);
69 Value* getValue(const IValue& var);
70 Value* getOutput(const IValue& var, size_t i);
71 bool hasValue(const IValue& var) const;
72
73 Node* createNode(c10::Symbol op_name, size_t num_outputs);
74 void insertNode(Node* node);
75
76 private:
77 using WeakIValue = at::WeakIValue;
78
79 struct WeakIValueHasher {
operatorTracingState::WeakIValueHasher80 size_t operator()(const WeakIValue& t) const {
81 return t.hash();
82 }
83 };
84
85 struct WeakIValueEq {
operatorTracingState::WeakIValueEq86 bool operator()(const WeakIValue& t1, const WeakIValue& t2) const {
87 return t1.isSameIdentity(t2);
88 }
89 };
90
91 using Frame =
92 std::unordered_map<WeakIValue, Value*, WeakIValueHasher, WeakIValueEq>;
93 std::vector<Frame> env_stack;
94 };
95
// This is meant to be used as a thread-local place where we can store extra
// info that gets lost when we call into ATen from Python bindings. One example
// of when this happens is when we get an IntArrayRef argument with e.g. sizes
// for view. When tracing, those might be tensors, which let us encode extra
// data dependencies, but once they get to the ATen call where we actually have
// the tracing logic, they get converted into a raw IntArrayRef, and we lose
// all information. To prevent this, we temporarily stash it in here.
103 struct ArgumentStash {
104 struct IntArrayRefTrace : std::vector<Value*> {
IntArrayRefTraceArgumentStash::IntArrayRefTrace105 IntArrayRefTrace(size_t size) : std::vector<Value*>(size, nullptr) {}
106 };
107
emptyArgumentStash108 static bool empty() {
109 return stash.intlists.empty();
110 }
111
112 TORCH_API static void stashIntArrayRefElem(
113 const std::string& arg_name,
114 size_t size,
115 size_t idx,
116 const Variable& var);
117
hasIntArrayRefArgumentStash118 static bool hasIntArrayRef(const std::string& arg_name) {
119 return stash.intlists.count(arg_name) > 0;
120 }
121
popIntArrayRefArgumentStash122 static IntArrayRefTrace popIntArrayRef(const std::string& arg_name) {
123 auto info = std::move(stash.intlists.at(arg_name));
124 stash.intlists.erase(arg_name);
125 return info;
126 }
127
128 // Value stashing: Use these methods to stash arguments which correspond
129 // to regular Value*'s in the graph. i.e. they don't require special
130 // handling like in the case of IntArrayRefs
131 TORCH_API static void stashValue(
132 const std::string& arg_name,
133 size_t idx,
134 const Variable& var,
135 const c10::TypePtr& type = nullptr);
136
hasValueArgumentStash137 static bool hasValue(const std::string& arg_name) {
138 return stash.values.count(arg_name) > 0;
139 }
140
popValueArgumentStash141 static Value* popValue(const std::string& arg_name) {
142 auto info = stash.values.at(arg_name);
143 stash.values.erase(arg_name);
144 return info;
145 }
146
147 private:
148 static thread_local ArgumentStash stash;
149 std::unordered_map<std::string, IntArrayRefTrace> intlists;
150 std::unordered_map<std::string, Value*> values;
151 };
152
153 // Retrieve or set the current tracing state. Returns a nullptr if tracing is
154 // disabled.
155 TORCH_API const std::shared_ptr<TracingState>& getTracingState();
156 TORCH_API void setTracingState(std::shared_ptr<TracingState> state);
157
isTracing()158 inline bool isTracing() {
159 return static_cast<bool>(getTracingState());
160 }
161
162 using warn_fn_type = void (*)(const std::string& msg);
163 TORCH_API extern const char* WARN_PYTHON_DATAFLOW;
164 TORCH_API extern const char* WARN_CONSTRUCTOR;
165 TORCH_API extern const char* WARN_RESIZE;
166 TORCH_API extern const char* STRICT_TRACER_MSG;
167 TORCH_API void _do_warn(const char* _reason, const char* _kind);
168 inline void warn(const char* _reason, const char* _kind = nullptr) {
169 if (const auto& state = getTracingState()) {
170 if (!state->warn)
171 return;
172 _do_warn(_reason, _kind);
173 }
174 }
175 TORCH_API void setWarn(warn_fn_type fn);
176
177 struct TORCH_API NoWarn {
NoWarnNoWarn178 NoWarn() : state(getTracingState()) {
179 if (state) {
180 prev = state->warn;
181 state->warn = false;
182 }
183 }
~NoWarnNoWarn184 ~NoWarn() {
185 if (state) {
186 state->warn = prev;
187 }
188 }
189 std::shared_ptr<TracingState> state;
190 bool prev{false};
191 };
192
193 struct WithNestedTracingFrame {
WithNestedTracingFrameWithNestedTracingFrame194 WithNestedTracingFrame() {
195 getTracingState()->enterFrame();
196 }
197
~WithNestedTracingFrameWithNestedTracingFrame198 ~WithNestedTracingFrame() {
199 getTracingState()->leaveFrame();
200 }
201 };
202 TORCH_API void recordSourceLocation(Node* n);
203 TORCH_API void setRecordSourceLocation(void (*v)(Node*));
204
205 TORCH_API std::vector<StackEntry> pythonCallstack();
206 TORCH_API void setPythonCallstack(std::vector<StackEntry> (*v)());
207
208 // Having finished adding a new 'node' to the graph IR 'setValueTrace'
209 // associates this node with an output variable, so that further operations
210 // involving this variable know which node in the IR to reference.
211 TORCH_API void setValueTrace(const IValue& v, Value* value);
212
213 TORCH_API void delValueTrace(const IValue& var);
214
215 TORCH_API std::function<void()> pauseTracing();
216
217 TORCH_API Value* getValueTrace(const IValue& var);
218
219 TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> trace(
220 Stack inputs,
221 const std::function<Stack(Stack)>& traced_fn,
222 std::function<std::string(const Variable&)> var_name_lookup_fn,
223 bool strict = true,
224 bool force_outplace = false,
225 Module* self = nullptr,
226 const std::vector<std::string>& argument_names = {});
227
228 TORCH_API void abandon();
229
230 // NB: those serve both as an intermediate steps in addInputs below,
231 // as well as the overloads that terminate template recursion
232 TORCH_API void addInputs(Node* n, const char* name, int64_t value);
233 TORCH_API void addInputs(Node* n, const char* name, const c10::SymInt& value);
234 TORCH_API void addInputs(
235 Node* n,
236 const char* name,
237 std::optional<int64_t> value);
238 TORCH_API void addInputs(Node* n, const char* name, bool value);
239 TORCH_API void addInputs(
240 Node* n,
241 const char* name,
242 const std::optional<bool>& value);
243 TORCH_API void addInputs(Node* n, const char* name, double value);
244 TORCH_API void addInputs(
245 Node* n,
246 const char* name,
247 const std::optional<double>& value);
248 TORCH_API void addInputs(Node* n, const char* name, const at::Scalar& value);
249 TORCH_API void addInputs(
250 Node* n,
251 const char* name,
252 const std::optional<at::Scalar>& value);
253 TORCH_API void addInputs(Node* n, const char* name, const at::Tensor& value);
254 TORCH_API void addInputs(
255 Node* n,
256 const char* name,
257 const std::optional<at::Tensor>& value);
258 TORCH_API void addInputs(Node* n, const char* name, ArrayRef<int64_t> value);
259 TORCH_API void addInputs(Node* n, const char* name, c10::SymIntArrayRef value);
260 TORCH_API void addInputs(
261 Node* n,
262 const char* name,
263 std::optional<c10::SymInt> value);
264 TORCH_API void addInputs(
265 Node* n,
266 const char* name,
267 const std::optional<ArrayRef<int64_t>>& value);
268 TORCH_API void addInputs(
269 Node* n,
270 const char* name,
271 const at::OptionalIntArrayRef& opt_value);
272 TORCH_API void addInputs(
273 Node* n,
274 const char* name,
275 const at::OptionalSymIntArrayRef& opt_value);
276 TORCH_API void addInputs(
277 Node* n,
278 const char* name,
279 ArrayRef<at::Tensor> value,
280 bool allow_undefined = false);
281 TORCH_API void addInputs(
282 Node* n,
283 const char* name,
284 const std::vector<at::Tensor>& value,
285 bool allow_undefined = false);
286 TORCH_API void addInputs(
287 Node* n,
288 const char* name,
289 at::ITensorListRef value,
290 bool allow_undefined = false);
291 TORCH_API void addInputs(
292 Node* n,
293 const char* name,
294 const List<std::optional<at::Tensor>>& value);
295 TORCH_API void addInputs(
296 Node* n,
297 const char* name,
298 ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
299 const c10::ClassTypePtr& class_type);
300 TORCH_API void addInputs(Node* n, const char* name, ArrayRef<double> value);
301 TORCH_API void addInputs(
302 Node* n,
303 const char* name,
304 const std::optional<ArrayRef<double>>& value);
305 TORCH_API void addInputs(
306 Node* n,
307 const char* name,
308 const c10::string_view value);
309 TORCH_API void addInputs(
310 Node* n,
311 const char* name,
312 const std::optional<c10::string_view>& value);
313 TORCH_API void addInputs(Node* n, const char* name, at::Device value);
314 TORCH_API void addInputs(Node* n, const char* name, c10::Stream stream);
315 TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
316 TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);
317 TORCH_API void addInputs(
318 Node* n,
319 const char* name,
320 const std::optional<at::ScalarType>& value);
321 TORCH_API void addInputs(
322 Node* n,
323 const char* name,
324 const std::optional<at::Device>& value);
325 TORCH_API void addInputs(
326 Node* n,
327 const char* name,
328 const std::optional<at::Layout>& value);
329 TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value);
330 TORCH_API void addInputs(
331 Node* n,
332 const char* name,
333 std::optional<at::DimnameList> value);
334 TORCH_API void addInputs(
335 Node* n,
336 const char* name,
337 const std::optional<at::MemoryFormat>& value);
338 TORCH_API void addInputs(
339 Node* n,
340 const char* name,
341 const std::optional<at::Generator>& value);
342
addInputs(Node * n,const char * name,const std::vector<bool> & value)343 inline void addInputs(
344 Node* n,
345 const char* name,
346 const std::vector<bool>& value) {
347 AT_ERROR("Tracing a list of bool type is currently not supported!");
348 }
349
350 template <typename T>
addInputs(Node * n,const char * name,ArrayRef<T> value)351 void addInputs(Node* n, const char* name, ArrayRef<T> value) {
352 AT_ERROR("Tracing a list of arbitrary type is currently not supported!");
353 }
354 template <typename K, typename V>
addInputs(Node * n,const char * name,const std::unordered_map<K,V> & value)355 void addInputs(
356 Node* n,
357 const char* name,
358 const std::unordered_map<K, V>& value) {
359 AT_ERROR("Tracing a dict of arbitrary types is currently not supported!");
360 }
361
362 template <size_t N>
addInputs(Node * n,const char * name,std::array<bool,N> value)363 void addInputs(Node* n, const char* name, std::array<bool, N> value) {
364 throw std::runtime_error(
365 "Found an unsupported argument type in the JIT tracer. File a bug report.");
366 }
367
368 TORCH_API void addInputs(
369 Node* n,
370 const char* name,
371 const c10::intrusive_ptr<c10::ivalue::Object>& obj);
372
373 TORCH_API void ensureUniqueIfOutOfPlaced(
374 const char* name,
375 const at::Tensor& tensor);
376 TORCH_API void ensureUniqueIfOutOfPlaced(
377 const char* name,
378 const std::optional<at::Tensor>& tensor);
379
380 template <
381 typename T,
382 typename = std::enable_if_t<
383 (!std::is_convertible_v<std::decay_t<T>, at::TensorList> &&
384 !std::is_convertible_v<std::decay_t<T>, c10::List<at::Tensor>> &&
385 !std::is_convertible_v<std::decay_t<T>, at::Tensor> &&
386 !std::is_convertible_v<
387 std::decay_t<T>,
388 c10::intrusive_ptr<c10::ivalue::Object>>)>>
addOutput(Node * node,T &&)389 void addOutput(Node* node, T&&) {
390 AT_ERROR(
391 "Found an unsupported argument type ",
392 c10::demangle_type<T>(),
393 " in the JIT tracer. File a bug report.");
394 }
395 TORCH_API void addOutput(Node* node, const at::Tensor& tensor);
396 TORCH_API void setOutput(Value* value, const at::Tensor& output);
397 TORCH_API void addOutput(Node* node, const std::vector<at::Tensor>& list);
398 TORCH_API void addOutput(Node* node, const c10::List<at::Tensor>& list);
399 TORCH_API void addOutput(
400 Node* node,
401 const c10::intrusive_ptr<c10::ivalue::Object>& output);
402
403 TORCH_API autograd::Variable getSizeOf(
404 const autograd::Variable& var,
405 int64_t dim);
406
407 TORCH_API autograd::Variable getNumelOf(const autograd::Variable& var);
408
409 } // namespace tracer
410 } // namespace torch::jit
411