#include <torch/csrc/jit/frontend/tracer.h>

#include <ATen/Backtrace.h>
#include <ATen/ScalarOps.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Dict.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/utils/variadic.h>
#include <torch/custom_class.h>

#include <memory>
#include <sstream>
#include <string>

namespace torch::jit::tracer {

////////////////////////////////////////////////////////////////////////////////
// Recording the traces
////////////////////////////////////////////////////////////////////////////////
namespace detail {

template <typename T>
void genericAddInput(Node* n, T value) {
  Value* v = n->owningGraph()->insertConstant(value);
  recordSourceLocation(v->node());
  n->addInput(v);
}

template <typename T>
void genericAddOptionalInput(
    Node* n,
    const char* name,
    const std::optional<T>& value) {
  if (value) {
    jit::tracer::addInputs(n, name, *value);
  } else {
    Graph* g = n->owningGraph();
    Value* none = g->insertNode(g->createNone())->output();
    n->addInput(none);
  }
}

template <typename T>
void badArgType(const T& v) {
  AT_ERROR(
      "Found an unsupported argument type in the JIT tracer: ",
      c10::demangle_type<T>(),
      ". File a bug report.");
}

thread_local std::shared_ptr<TracingState> tracing_state;
} // namespace detail

static std::atomic<bool> tracer_state_warn_mode{true};

std::atomic<bool>& getTracerStateWarnMode() {
  return tracer_state_warn_mode;
}

std::function<void()> pauseTracing() {
  // NOLINTNEXTLINE
  std::shared_ptr<tracer::TracingState> state = getTracingState();
  tracer::setTracingState(nullptr);

  return [state]() { tracer::setTracingState(state); };
}
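
// Example (illustrative sketch): pauseTracing() hands back a resume callback,
// so a caller can bracket work that should not be recorded in the trace:
//
//   auto resume = torch::jit::tracer::pauseTracing();
//   // ... run code that must not appear in the traced graph ...
//   resume();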

void delValueTrace(const IValue& var) {
  getTracingState()->delValue(var);
}
void TracingState::delValue(const IValue& var) {
  for (const auto i : c10::irange(env_stack.size())) {
    auto& value_map = env_stack.at(env_stack.size() - 1 - i);
    auto it = value_map.find(var);
    if (it == value_map.end()) {
      continue;
    }
    value_map.erase(it);
  }
}

// Given an IValue 'var', return the 'node' which represents the instruction
// which computes the value of this variable in the IR.
// Here, we interpret untraced variables as constants that are just embedded
// in the graph. This is useful to handle code which does things like this
// (from torch.autograd.variable, now moved to C++):
//
//    def mm(self, matrix):
//      output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
//      return Addmm.apply(output, self, matrix, 0, 1, True)
//
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
// update on, but subsequently ignores it because the alpha scaling factor is
// zero. This is one of the cases where a Variable can be created inside of a
// trace, and if we treat it as a constant, everything will work out.
Value* getValueTrace(const IValue& var) {
  return getTracingState()->getValue(var);
}
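
// For example (illustrative): a tensor created inside the traced function
// that does not depend on any trace input, e.g.
//
//   at::Tensor offset = at::zeros({3});  // not derived from a traced input
//
// is baked into the graph as a prim::Constant the first time
// getValueTrace(offset) is reached, as described in the comment above.
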
static Value* getOptTensorValueTrace(const std::optional<at::Tensor>& var) {
  return getValueTrace(IValue(var));
}
Value* TracingState::getValue(const IValue& var) {
  // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
  // arguments
  if (var.isTensorList()) {
    return graph
        ->insertNode(graph->createList(
            TensorType::get(),
            fmap(
                var.toTensorVector(),
                [&](const IValue& val) { return getValue(val); })))
        ->output();
  } else if (var.isTuple()) {
    return graph
        ->insertNode(graph->createTuple(fmap(
            var.toTupleRef().elements(),
            [&](const IValue& val) { return getValue(val); })))
        ->output();
  } else if (var.isGenericDict()) {
    auto dict = var.toGenericDict();
    TypePtr key_type = dict.keyType();
    TypePtr value_type = dict.valueType();
    std::vector<Value*> keys;
    std::vector<Value*> values;
    for (const auto& entry : dict) {
      keys.emplace_back(getValue(entry.key()));
      values.emplace_back(getValue(entry.value()));
    }
    auto dict_node = graph->createDict(key_type, value_type, keys, values);
    return graph->insertNode(dict_node)->output();
  }
  if (var.isTensor()) {
    auto& ten = var.toTensor();
    if (!ten.defined()) {
      Node* n = graph->createNone();
      return graph->insertNode(n)->output();
    }
    for (const auto i : c10::irange(env_stack.size())) {
      auto& value_map = env_stack.at(env_stack.size() - 1 - i);
      auto it = value_map.find(var);
      if (it == value_map.end()) {
        continue;
      }
      if (!it->second->hasDebugName()) {
        auto unique_name = getTracingState()->lookup_var_name_fn(ten);
        if (!unique_name.empty()) {
          it->second->setDebugName(unique_name);
        }
      }
      return it->second;
    }

    // Didn't find it. Bake in a constant
    if (ten.requires_grad()) {
      pauseTracing();
      std::ostringstream oss;
      oss << "Cannot insert a Tensor that requires grad as a constant. "
          << "Consider making it a parameter or input, or detaching the gradient\n"
          << "Tensor:\n"
          << ten;
      throw std::runtime_error(oss.str());
    }

    Value* constant = graph->insertConstant(ten);
    recordSourceLocation(constant->node());
    constant->inferTypeFrom(ten);
    auto it = env_stack.back().emplace(var, constant);
    return it.first->second;
  } else if (var.isFuture() || var.isObject()) {
    for (const auto i : c10::irange(env_stack.size())) {
      auto& future_map = env_stack.at(env_stack.size() - 1 - i);
      auto it = future_map.find(var);
      if (it == future_map.end()) {
        continue;
      }
      return it->second;
    }

    // Find torchbind classes
    if (isCustomClass(var)) {
      auto obj = Object(var.toObject());
      auto qualname = obj.type()->name();
      auto custom_class_type = getCustomClass(qualname->qualifiedName());
      if (custom_class_type) {
        auto capsule = var.toObject()->getAttr("capsule");
        for (const auto i : c10::irange(env_stack.size())) {
          auto& value_map = env_stack.at(env_stack.size() - 1 - i);
          auto it = value_map.find(capsule);
          if (it == value_map.end()) {
            continue;
          }
          return it->second;
        }
      }
    }

    std::ostringstream oss;
    if (var.isFuture()) {
      oss << "Tried to trace Future or Object that the tracer was not aware of.";
    } else {
      oss << "Tried to trace " << var
          << " but it is not part of the active trace. Modules that are called during a trace"
          << " must be registered as submodules of the thing being traced.";
    }
    throw std::runtime_error(oss.str());
  } else {
    // If the values are non-tensors, we try to create constants
    // and bake those constants into the traced graph
    auto constant = tryInsertConstant(*graph, var);
    if (constant) {
      recordSourceLocation(constant.value()->node());
      return *constant;
    }
    std::ostringstream os;
    os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
       << "The below value could not be materialized as a constant:\n"
       << var;
    throw std::runtime_error(os.str());
  }
}
bool TracingState::hasValue(const IValue& var) const {
  for (const auto& frame : env_stack) {
    if (frame.count(var)) {
      return true;
    }
  }
  return false;
}

Value* TracingState::getOutput(const IValue& iv, size_t i) {
  bool tracing_mode_strict = getTracingState()->strict;
  if (iv.isTensor()) {
    const at::Tensor& var = iv.toTensor();
    if (!var.defined()) {
      Node* n = graph->createNone();
      return graph->insertNode(n)->output();
    }

    auto& value_map = getTracingState()->env_stack.back();
    auto it = value_map.find(iv);
    if (it == value_map.end()) {
      std::ostringstream os;
      os << "output " << i << " (" << var
         << ") of traced region did not have observable "
         << "data dependence with trace inputs; this probably indicates your "
            "program "
         << "cannot be understood by the tracer.";
      throw std::runtime_error(os.str());
    }
    return it->second;
  } else if (iv.isTensorList()) {
    if (tracing_mode_strict) {
      tracer::warn(
          "Encountering a list at the output of the tracer", STRICT_TRACER_MSG);
    }
    return graph
        ->insertNode(graph->createList(
            TensorType::get(),
            fmap(
                iv.toTensorVector(),
                [&](const IValue& ival) { return getOutput(ival, i); })))
        ->output();
  } else if (iv.isTuple()) {
    const auto& tuple = iv.toTupleRef().elements();
    auto tuple_node = graph->createTuple(
        fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); }));
    graph->insertNode(tuple_node);
    return tuple_node->output();
  } else if (iv.isGenericDict()) {
    if (tracing_mode_strict) {
      throw std::runtime_error(
          "Encountering a dict at the output of the tracer" +
          std::string(STRICT_TRACER_MSG));
    }
    auto dict = iv.toGenericDict();
    TypePtr key_type = dict.keyType();
    TypePtr value_type = dict.valueType();

    bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) ||
        key_type->isSubtypeOf(*TensorType::get());
    bool value_type_valid = value_type->isSubtypeOf(*TensorType::get());

    // Support tuple values that contain only tensors
    if (value_type->isSubtypeOf(*AnyTupleType::get())) {
      value_type_valid = true;
      for (const auto& type : value_type->containedTypes()) {
        if (!type->isSubtypeOf(*TensorType::get())) {
          value_type_valid = false;
          break;
        }
      }
    }

    if (!key_type_valid || !value_type_valid) {
      std::ostringstream os;
      os << "output " << i << " (" << dict << ") of traced region "
         << "cannot be understood by the tracer, only outputs matching "
         << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
         << "can be a dictionary output of a traced function";
      throw std::runtime_error(os.str());
    }
    std::vector<Value*> keys;
    std::vector<Value*> values;
    for (const auto& entry : dict) {
      keys.emplace_back(getValue(entry.key()));
      values.emplace_back(getOutput(entry.value(), i));
    }
    auto dict_node = graph->createDict(key_type, value_type, keys, values);
    graph->insertNode(dict_node);
    return dict_node->output();
  } else {
    AT_ERROR(
        "Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions");
  }
}
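
// Example (illustrative) of a dict output that passes the check above: a
// traced function returning c10::Dict<std::string, at::Tensor> (Dict[str,
// Tensor] in Python), or one whose values are tuples containing only tensors,
// traced with strict=false.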

Node* TracingState::createNode(c10::Symbol op_name, size_t num_outputs) {
  return graph->create(op_name, num_outputs);
}

void TracingState::insertNode(Node* node) {
  graph->insertNode(node);
}

// XXX: this function mutates input
static IValue addInput(
    const std::shared_ptr<TracingState>& state,
    const IValue& input,
    const TypePtr& type,
    Value* value) {
  value->setType(type);
  if (type->isSubtypeOf(*TensorType::get())) {
    auto input_tensor = input.toTensor();
    auto name = Variable(input_tensor).name();
    if (state->hasValue(input)) {
      input_tensor = input_tensor.view(input_tensor.sizes());
    }
    if (!value->hasDebugName()) {
      value->setDebugName(name);
    }
    state->setValue(input_tensor, value);
    return input_tensor;
  } else if (auto tuple_type = type->cast<TupleType>()) {
    auto unpack_node =
        state->graph->insertNode(state->graph->createTupleUnpack(value));
    auto elem_values = unpack_node->outputs();
    auto elem_types = tuple_type->elements();
    auto tuple = input.toTuple();
    const auto& elems = tuple->elements();
    size_t num_elems = elems.size();
    AT_ASSERT(
        elem_values.size() == num_elems && elem_types.size() == num_elems);
    for (const auto i : c10::irange(num_elems)) {
      tuple->unsafeSetElement(
          i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
    }
    return tuple;
  } else if (auto dict_type = type->cast<DictType>()) {
    auto dict = input.toGenericDict();

    // Unpack the dict values statically (one __getitem__ per key)
    for (const auto& entry : dict) {
      const IValue& key = entry.key();
      auto static_key = state->graph->insertConstant(key);
      auto static_value =
          state->graph->insert(aten::__getitem__, {value, static_key});
      recordSourceLocation(static_value->node());
      dict.insert_or_assign(
          entry.key(),
          addInput(
              state, entry.value(), dict_type->getValueType(), static_value));
    }

    return dict;
  } else if (auto list_type = type->cast<ListType>()) {
    size_t num_elems = input.isList() ? input.toListRef().size()
                                      : input.toTensorVector().size();
    auto list_unpack = state->graph->insertNode(
        state->graph->createListUnpack(value, num_elems));
    auto unpack_outputs = list_unpack->outputs();

    if (input.isTensorList()) {
      auto elems = input.toTensorList();
      for (const auto i : c10::irange(num_elems)) {
        elems[i] = addInput(
                       state,
                       elems.get(i),
                       list_type->getElementType(),
                       unpack_outputs[i])
                       .toTensor();
      }
      return elems;
    } else {
      auto elems = input.toList();
      for (const auto i : c10::irange(num_elems)) {
        elems[i] = addInput(
            state,
            elems.get(i),
            list_type->getElementType(),
            unpack_outputs[i]);
      }
      return elems;
    }
  } else {
    AT_ERROR(
        "Only tensors or (possibly nested) dict or tuples of tensors can be "
        "inputs to traced functions. Got ",
        type->repr_str());
  }
}

static void gatherParametersAndBuffers(
    const std::shared_ptr<TracingState>& state,
    Value* self_value,
    const Module& self,
    const std::string& prefix) {
  Graph& g = *self_value->owningGraph();

  state->setValue(self._ivalue(), self_value);

  auto self_ty = self.type();
  for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
    auto qualname = prefix + "." + s.name;
    Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
                                ->s_(attr::scope, qualname)
                                ->output()
                                ->setType(s.value.type());
    if (s.value.type()->isSubtypeOf(*TensorType::get())) {
      addInput(state, s.value, s.value.type(), trace_get_attr);
    }
    if (isCustomClass(s.value)) {
      tracer::setValueTrace(s.value, trace_get_attr);
    }

    auto attr_type = self_ty->getAttribute(s.name);
    // Skip parameters and buffers that are behind an `InterfaceType`, because
    // it is illegal for an InterfaceType to expose any attributes, and these
    // attributes should never be used/exposed outside of the InterfaceType'd
    // module anyway.
    if (attr_type->is_module() &&
        attr_type->kind() != TypeKind::InterfaceType) {
      gatherParametersAndBuffers(
          state, trace_get_attr, Module(s.value.toObject()), qualname);
    }
  }
}

std::pair<std::shared_ptr<TracingState>, Stack> trace(
    Stack inputs,
    const std::function<Stack(Stack)>& traced_fn,
    std::function<std::string(const Variable&)> var_name_lookup_fn,
    bool strict,
    bool force_outplace,
    Module* self,
    const std::vector<std::string>& argument_names) {
  try {
    // Start tracing, treating 'inputs' as inputs to the trace, which can be
    // varied on subsequent invocations of the trace. Any other variables
    // will be treated as constants.
    if (isTracing()) {
      AT_ERROR("Tracing can't be nested");
    }
    auto state = std::make_shared<TracingState>();
    setTracingState(state);

    // If we are tracing a module, make sure the module's parameters are in
    // the map and mapped to accesses on the `self` object.
    if (self) {
      Value* self_value = state->graph->insertInput(0, "self")->setType(
          self->_ivalue()->type());
      gatherParametersAndBuffers(state, self_value, *self, {"__module"});
    }

    // When enough argument name hints are provided, use them as debug names
    // for traced function/modules.
    // Here argument_names is allowed to have more names than needed because
    // some arguments may have valid default values, therefore they don't need
    // example inputs.
    if (argument_names.size() >= inputs.size()) {
      for (size_t i = 0, e = inputs.size(); i < e; ++i) {
        IValue& input = inputs[i];
        input = addInput(
            state,
            input,
            input.type(),
            state->graph->addInput(argument_names[i]));
      }
    } else {
      for (IValue& input : inputs) {
        input = addInput(state, input, input.type(), state->graph->addInput());
      }
    }

    auto graph = state->graph;

    getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
    getTracingState()->strict = strict;
    getTracingState()->force_outplace = force_outplace;

    // Invoke the traced function
    auto out_stack = traced_fn(inputs);

    // Exit a trace, treating 'out_stack' as the outputs of the trace. These
    // are the variables whose values will be computed upon subsequent
    // invocations of the trace.
    size_t i = 0;
    for (auto& output : out_stack) {
      // NB: The stack is in "reverse" order, so when we pass the diagnostic
      // number we need to flip it based on size.
      state->graph->registerOutput(
          state->getOutput(output, out_stack.size() - i));
      i++;
    }
    setTracingState(nullptr);

    if (getInlineEverythingMode()) {
      Inline(*graph);
    }
    FixupTraceScopeBlocks(graph, self);
    NormalizeOps(graph);
    return {state, out_stack};
  } catch (...) {
    tracer::abandon();
    throw;
  }
}
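
// Example (illustrative sketch) of driving trace() directly; in practice this
// entry point is reached via torch.jit.trace, and the no-op name lookup shown
// here is just a stand-in:
//
//   Stack inputs{at::randn({2, 3})};
//   auto [state, outputs] = torch::jit::tracer::trace(
//       std::move(inputs),
//       [](Stack stack) {
//         // run the function being traced and return its outputs
//         return stack;
//       },
//       [](const autograd::Variable&) { return std::string(); },
//       /*strict=*/true,
//       /*force_outplace=*/false,
//       /*self=*/nullptr,
//       /*argument_names=*/{});
//   // state->graph now holds the recorded IR for 'outputs'.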

// Abort tracing. Used to reset the state in case of errors.
void abandon() {
  setTracingState(nullptr);
}

void setValueTrace(const IValue& v, Value* value) {
  return getTracingState()->setValue(v, value);
}
void TracingState::setValue(const IValue& v, Value* value) {
  if (v.isTensor()) {
    auto& var = v.toTensor();
    AT_ASSERT(var.defined());
    env_stack.back()[v] = value;

    // If the value comes from a CallFunction or CallMethod, it may not have
    // shape information attached. For debuggability, we enhance the type
    // information by assigning the concrete value's type to the jit::Value.
    if (auto tensor_type = value->type()->cast<TensorType>()) {
      if (!tensor_type->isComplete()) {
        value->inferTypeFrom(var);
      }
    }
  } else if (v.isTensorList()) {
    auto outputs = v.toTensorList();
    Node* unpack_node =
        graph->insertNode(graph->createListUnpack(value, outputs.size()));
    for (const auto i : c10::irange(outputs.size())) {
      setValue(outputs.get(i), unpack_node->outputs()[i]);
    }
  } else if (v.isTuple()) {
    const auto& outputs = v.toTupleRef().elements();
    Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
    for (const auto i : c10::irange(outputs.size())) {
      setValue(outputs[i], unpack_node->outputs()[i]);
    }
  } else if (v.isList()) {
    auto elements = v.toListRef();
    Node* unpack_node =
        graph->insertNode(graph->createListUnpack(value, elements.size()));
    for (const auto i : c10::irange(elements.size())) {
      setValue(elements[i], unpack_node->outputs()[i]);
    }
  } else if (isCustomClass(v)) {
    auto capsule = v.toObject()->getAttr("capsule");
    env_stack.back()[capsule] = value;
  } else if (v.isFuture() || v.isObject()) {
    env_stack.back()[v] = value;
  } else if (v.isGenericDict()) {
    auto dict = v.toGenericDict();
    TypePtr key_type = dict.keyType();
    TypePtr value_type = dict.valueType();
    for (const auto& entry : dict) {
      auto static_key = graph->insertConstant(entry.key());
      auto static_value = graph->insert(aten::__getitem__, {value, static_key});
      setValue(entry.value(), static_value);
    }
  } else {
    std::ostringstream os;
    os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
       << "Supported types are tensor, tensor list, and tuple of tensors.";
    throw std::runtime_error(os.str());
  }
}

void addInputs(Node* n, const char* name, int64_t value) {
  using ArgumentStash = jit::tracer::ArgumentStash;
  if (ArgumentStash::hasValue(name)) {
    Value* v = ArgumentStash::popValue(name);
    n->addInput(v);
  } else {
    detail::genericAddInput(n, value);
  }
}

void addInputs(Node* n, const char* name, const c10::SymInt& value) {
  addInputs(n, name, value.guard_int(__FILE__, __LINE__));
}

void addInputs(Node* n, const char* name, std::optional<int64_t> value) {
  using ArgumentStash = jit::tracer::ArgumentStash;
  if (ArgumentStash::hasValue(name)) {
    Value* v = ArgumentStash::popValue(name);
    n->addInput(v);
  } else if (value) {
    detail::genericAddInput(n, *value);
  } else {
    Graph* g = n->owningGraph();
    Value* none = g->insertNode(g->createNone())->output();
    n->addInput(none);
  }
}
void addInputs(Node* n, const char* name, bool value) {
  detail::genericAddInput(n, value);
}
void addInputs(Node* n, const char* name, const std::optional<bool>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(Node* n, const char* name, double value) {
  detail::genericAddInput(n, value);
}
void addInputs(Node* n, const char* name, const std::optional<double>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(Node* n, const char* name, const at::Scalar& value) {
  using ArgumentStash = jit::tracer::ArgumentStash;
  if (ArgumentStash::hasValue(name)) {
    Value* v = ArgumentStash::popValue(name);
    n->addInput(v);
  } else {
    detail::genericAddInput(n, value);
  }
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::Scalar>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(Node* n, const char* name, const c10::string_view value) {
  detail::genericAddInput(n, std::string(value));
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<c10::string_view>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(Node* n, const char* name, const at::Tensor& value) {
  n->addInput(getValueTrace(value));
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::Tensor>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::Generator>& value) {
  Graph* g = n->owningGraph();

  if (value.has_value() && value->defined()) {
    detail::genericAddInput(n, *value);
  } else {
    Value* undef_gen = g->insertNode(g->createNone())->output();
    n->addInput(undef_gen);
  }
}
void addInputs(Node* n, const char* name, at::Device value) {
  detail::genericAddInput(n, value);
}
void addInputs(Node* n, const char* name, c10::Stream stream) {
  detail::genericAddInput(n, c10::IValue(stream));
}
void addInputs(Node* n, const char* name, at::Layout value) {
  detail::genericAddInput(n, static_cast<int64_t>(value));
}
void addInputs(Node* n, const char* name, at::ScalarType value) {
  detail::genericAddInput(n, static_cast<int64_t>(value));
}
void addInputs(Node* n, const char* name, at::MemoryFormat value) {
  detail::genericAddInput(n, static_cast<int64_t>(value));
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::MemoryFormat>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::Layout>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::Device>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(
    Node* n,
    const char* name,
    std::optional<at::DimnameList> value) {
  TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer");
}
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::ScalarType>& value) {
  detail::genericAddOptionalInput(n, name, value);
}
void addInputs(
    Node* n,
    const char* name,
    at::ArrayRef<at::Tensor> value,
    bool allow_undefined) {
  addInputs(n, name, at::ITensorListRef(value), allow_undefined);
}
void addInputs(
    Node* n,
    const char* name,
    const std::vector<at::Tensor>& value,
    bool allow_undefined) {
  addInputs(n, name, at::ITensorListRef(value), allow_undefined);
}
void addInputs(
    Node* n,
    const char* name,
    at::ITensorListRef value,
    bool allow_undefined) {
  Graph* g = n->owningGraph();
  Node* list_node = nullptr;
  if (allow_undefined) {
    // If undefined tensors are allowed, create a list of optional tensors.
    list_node = g->insertNode(
        g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace)));
  } else {
    list_node = g->insertNode(
        g->createList(TensorType::get(), fmap(value, getValueTrace)));
  }
  n->addInput(list_node->output());
}
TORCH_API void addInputs(
    Node* n,
    const char* name,
    const List<std::optional<at::Tensor>>& value) {
  Graph* g = n->owningGraph();
  Node* list_node = nullptr;
  list_node = g->insertNode(g->createList(
      OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
  n->addInput(list_node->output());
}
void addInputs(
    Node* n,
    const char* name,
    ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
    const ClassTypePtr& class_type) {
  Graph* g = n->owningGraph();
  Node* list_node =
      g->insertNode(g->createList(class_type, fmap(value, getValueTrace)));
  n->addInput(list_node->output());
}

void addInputs(Node* n, const char* name, at::IntArrayRef value) {
  using ArgumentStash = jit::tracer::ArgumentStash;
  std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)
      ? ArgumentStash::popIntArrayRef(name)
      : ArgumentStash::IntArrayRefTrace(value.size());

  auto& g = getTracingState()->graph;
  for (const auto i : c10::irange(info.size())) {
    if (info[i] != nullptr)
      continue;
    info[i] = g->insertConstant(value[i]);
    recordSourceLocation(info[i]->node());
  }
  for (jit::Value* v : info) {
    if (*v->type() != *jit::IntType::get()) {
      throw std::runtime_error(
          "Type mismatch in setposattr for IntArrayRef. Check that your program "
          "is valid without tracing, and please file a bug report if it is.");
    }
  }
  n->addInput(
      g->insertNode(g->createList(jit::IntType::get(), info))->output());
}

void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
  addInputs(n, name, C10_AS_INTARRAYREF_SLOW(value));
}

void addInputs(Node* n, const char* name, std::optional<c10::SymInt> value) {
  addInputs(
      n,
      name,
      value.has_value()
          ? std::make_optional(value->guard_int(__FILE__, __LINE__))
          : std::nullopt);
}

void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::IntArrayRef>& opt_value) {
  detail::genericAddOptionalInput(n, name, opt_value);
}

void addInputs(
    Node* n,
    const char* name,
    const at::OptionalIntArrayRef& opt_value) {
  if (opt_value.has_value()) {
    jit::tracer::addInputs(n, name, *opt_value);
  } else {
    Graph* g = n->owningGraph();
    Value* none = g->insertNode(g->createNone())->output();
    n->addInput(none);
  }
}

void addInputs(
    Node* n,
    const char* name,
    const at::OptionalSymIntArrayRef& opt_value) {
  if (opt_value.has_value()) {
    jit::tracer::addInputs(n, name, *opt_value);
  } else {
    Graph* g = n->owningGraph();
    Value* none = g->insertNode(g->createNone())->output();
    n->addInput(none);
  }
}

void addInputs(Node* n, const char* name, ArrayRef<double> value) {
  std::vector<Value*> info;
  auto& g = getTracingState()->graph;
  for (double elt : value) {
    info.push_back(g->insertConstant(elt));
    recordSourceLocation(info.back()->node());
  }
  n->addInput(
      g->insertNode(g->createList(jit::FloatType::get(), info))->output());
}

void addInputs(
    Node* n,
    const char* name,
    const std::optional<c10::ArrayRef<double>>& opt_value) {
  detail::genericAddOptionalInput(n, name, opt_value);
}

void addInputs(
    Node* n,
    const char* name,
    const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
  Value* v = getValueTrace(obj);
  n->addInput(v);
}

void addOutput(Node* node, const at::Tensor& output) {
  setOutput(node->addOutput(), output);
}

void setOutput(Value* value, const at::Tensor& output) {
  if (output.defined()) {
    value->inferTypeFrom(output);
    setValueTrace(output, value);
  }
}

void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
  Value* value = node->addOutput()->setType(ListType::ofTensors());
  Graph* graph = node->owningGraph();
  Node* unpack_node = graph->insertNode(
      graph->create(prim::ListUnpack, {value}, outputs.size()));
  for (const auto i : c10::irange(outputs.size())) {
    Value* output_val = unpack_node->outputs()[i];
    output_val->inferTypeFrom(outputs[i]);
    setValueTrace(outputs[i], output_val);
  }
}

void addOutput(Node* node, const c10::List<at::Tensor>& outputs) {
  return addOutput(node, outputs.vec());
}

void addOutput(
    Node* node,
    const c10::intrusive_ptr<c10::ivalue::Object>& output) {
  Value* output_val = node->addOutput();
  output_val->inferTypeFrom(output);
  setValueTrace(output, output_val);
}

const std::shared_ptr<TracingState>& getTracingState() {
  return detail::tracing_state;
}

void setTracingState(std::shared_ptr<TracingState> state) {
  at::tracer::impl::set_dispatch_enabled(state != nullptr);
  detail::tracing_state = std::move(state);
}

TracingState::TracingState() : graph(new Graph()), env_stack{Frame()} {}

TracingState::~TracingState() = default;

autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
  auto& tracing_state = getTracingState();
  auto& graph = tracing_state->graph;

  Variable size_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    size_var = scalar_to_tensor(at::Scalar(var.size(dim)));
  }
  auto* value = getValueTrace(var);
  auto dim_val = graph->insertConstant(dim);
  recordSourceLocation(dim_val->node());
  auto* node = graph->insertNode(graph->create(aten::size, {value, dim_val}));
  recordSourceLocation(node);
  node->output()->setType(jit::IntType::get());

  auto ten =
      graph->insertNode(graph->createNumToTensor(node->output()))->output();
  setValueTrace(size_var, ten);
  return size_var;
}
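
// Example (illustrative): tracing wrappers can route `tensor.size(dim)`
// through getSizeOf so the value stays dynamic in the graph rather than being
// baked in as a constant:
//
//   autograd::Variable d0 = torch::jit::tracer::getSizeOf(input, /*dim=*/0);
//   // The trace now contains aten::size(%input, 0) followed by NumToTensor,
//   // and d0 is mapped to that traced value.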

autograd::Variable getNumelOf(const autograd::Variable& var) {
  auto& tracing_state = getTracingState();
  auto& graph = tracing_state->graph;

  Variable numel_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    numel_var = scalar_to_tensor(at::Scalar(var.numel()));
  }
  auto* value = getValueTrace(var);
  auto* node = graph->insertNode(graph->create(Symbol::aten("numel"), {value}));
  recordSourceLocation(node);
  node->output()->setType(jit::IntType::get());

  auto ten =
      graph->insertNode(graph->createNumToTensor(node->output()))->output();
  setValueTrace(numel_var, ten);
  return numel_var;
}

void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
  auto& state = getTracingState();
  if (state && state->force_outplace == false) {
    // If we're not converting in-place ops to out-of-place, this check is
    // unnecessary
    return;
  }
  auto aliases = tensor.storage().use_count();
  if (isTracing() && aliases > 1) {
    std::stringstream ss;
    ss << "There are " << aliases
       << " live references to the data region being modified when tracing in-place operator "
       << name
       << ". This might cause the trace to be incorrect, because all other views "
       << "that also reference this data will not reflect this change in the trace! "
       << "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
       << "are outputs of torch.split), this might still be safe.";
    warn(ss.str().c_str());
  }
}
void ensureUniqueIfOutOfPlaced(
    const char* name,
    const std::optional<at::Tensor>& tensor) {
  ensureUniqueIfOutOfPlaced(name, tensor.has_value() ? *tensor : at::Tensor());
}

////////////////////////////////////////////////////////////////////////////////
// Argument stash
////////////////////////////////////////////////////////////////////////////////
thread_local ArgumentStash ArgumentStash::stash;

void ArgumentStash::stashIntArrayRefElem(
    const std::string& arg_name,
    size_t size,
    size_t idx,
    const Variable& var) {
  // TODO: check type?
  if (!isTracing())
    return;
  IntArrayRefTrace& list_trace =
      stash.intlists.emplace(arg_name, size).first->second;
  AT_ASSERT(size == list_trace.size());
  AT_ASSERT(idx < list_trace.size());
  AT_ASSERT(list_trace[idx] == nullptr);

  Value* ten = getValueTrace(var);
  auto& g = *ten->owningGraph();
  WithInsertPoint guard(ten->node()->next());
  auto prim = g.insert(aten::Int, {ten});
  list_trace[idx] = prim;
}

void ArgumentStash::stashValue(
    const std::string& arg_name,
    size_t idx,
    const Variable& var,
    const TypePtr& type) {
  if (!isTracing())
    return;

  Value* ten = getValueTrace(var);
  WithInsertPoint guard(ten->node()->next());
  auto& g = *ten->owningGraph();

  if (type == IntType::get()) {
    ten = g.insert(aten::Int, {ten});
  } else if (type == FloatType::get()) {
    ten = g.insert(aten::Float, {ten});
  } else if (type == NumberType::get()) {
    ten = g.insert(aten::ScalarImplicit, {ten});
  }

  stash.values.emplace(arg_name, ten);
}
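
// Illustrative note: the stash is how scalar arguments that were computed from
// traced tensors stay dynamic. A dispatching wrapper might, for example, stash
// the value before calling the op so that the matching addInputs() overload
// pops it instead of baking in a constant:
//
//   jit::tracer::ArgumentStash::stashValue(
//       "dim", /*idx=*/0, dim_tensor, jit::IntType::get());
//   // A later addInputs(node, "dim", <int64_t value>) pops the stashed Value.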

////////////////////////////////////////////////////////////////////////////////
// Stack trace recording
////////////////////////////////////////////////////////////////////////////////
// No Python present, so we just do not record source information.
static void defaultRecordSourceLocation(Node* n) {}
std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
    defaultRecordSourceLocation);
void recordSourceLocation(Node* n) {
  return record_source_location.load()(n);
}
void setRecordSourceLocation(void (*v)(Node*)) {
  record_source_location.store(v);
}

static std::vector<StackEntry> defaultPythonCallstack() {
  return std::vector<StackEntry>();
}
std::atomic<decltype(&defaultPythonCallstack)> python_callstack_fn(
    defaultPythonCallstack);
std::vector<StackEntry> pythonCallstack() {
  return python_callstack_fn.load()();
}
void setPythonCallstack(std::vector<StackEntry> (*v)()) {
  python_callstack_fn.store(v);
}

static void defaultWarn(const std::string& str) {
  TORCH_WARN(str);
}
std::atomic<warn_fn_type> warn_callback{defaultWarn};

const char* WARN_PYTHON_DATAFLOW =
    " might cause the trace to be incorrect. We can't record the data flow of "
    "Python values, so this value will be treated as a constant in the future. "
    "This means that the trace might not generalize to other inputs!";
const char* WARN_CONSTRUCTOR =
    " results are registered as constants in the trace. You can safely ignore this "
    "warning if you use this function to create tensors out of constant variables "
    "that would be the same every time you call this function. In any other case, "
    "this might cause the trace to be incorrect.";
const char* WARN_RESIZE =
    " can't be represented in the JIT at the moment, so we won't connect any uses of "
    "this value with its current trace. If you happen to use it again, it will show "
    "up as a constant in the graph. Consider using `view` or `reshape` to make "
    "it traceable.";
const char* STRICT_TRACER_MSG =
    " might cause the trace to be incorrect, this is only valid if the container "
    "structure does not change based on the module's inputs. Consider using a constant "
    "container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a "
    "`NamedTuple` instead). If you absolutely need this and know the side effects, pass "
    "strict=False to trace() to allow this behavior.";
// XXX: _kind can be nullptr
void _do_warn(const char* _reason, const char* _kind) {
  std::string reason{_reason};
  std::string kind{_kind ? _kind : ""};
  std::ostringstream s;
  s << reason << kind;
  warn_callback.load()(s.str());
}

void setWarn(warn_fn_type fn) {
  warn_callback.store(fn);
}
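
// Example (illustrative sketch): hosts can redirect tracer warnings through
// setWarn. warn_fn_type is stored in a std::atomic above, i.e. it is a plain
// function pointer, so a capture-less lambda converts to it:
//
//   torch::jit::tracer::setWarn([](const std::string& msg) {
//     std::cerr << "[tracer] " << msg << std::endl;
//   });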
} // namespace torch::jit::tracer