xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/tracer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/tracer.h>
2 
3 #include <ATen/Backtrace.h>
4 #include <ATen/ScalarOps.h>
5 #include <ATen/TracerMode.h>
6 #include <ATen/core/Dict.h>
7 #include <ATen/core/functional.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10 #include <torch/csrc/autograd/engine.h>
11 #include <torch/csrc/autograd/function.h>
12 #include <torch/csrc/autograd/variable.h>
13 #include <torch/csrc/jit/api/module.h>
14 #include <torch/csrc/jit/ir/constants.h>
15 #include <torch/csrc/jit/ir/ir.h>
16 #include <torch/csrc/jit/passes/dead_code_elimination.h>
17 #include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
18 #include <torch/csrc/jit/passes/inliner.h>
19 #include <torch/csrc/jit/passes/lower_tuples.h>
20 #include <torch/csrc/jit/passes/normalize_ops.h>
21 #include <torch/csrc/jit/passes/remove_expands.h>
22 #include <torch/csrc/utils/variadic.h>
23 #include <torch/custom_class.h>
24 
25 #include <memory>
26 #include <sstream>
27 #include <string>
28 
29 namespace torch::jit::tracer {
30 
31 ////////////////////////////////////////////////////////////////////////////////
32 // Recording the traces
33 ////////////////////////////////////////////////////////////////////////////////
34 namespace detail {
35 
// Inserts `value` into the graph as a constant node, records a source
// location on it for debugging, and wires it up as an input of `n`.
template <typename T>
void genericAddInput(Node* n, T value) {
  Value* v = n->owningGraph()->insertConstant(value);
  recordSourceLocation(v->node());
  n->addInput(v);
}
42 
43 template <typename T>
genericAddOptionalInput(Node * n,const char * name,const std::optional<T> & value)44 void genericAddOptionalInput(
45     Node* n,
46     const char* name,
47     const std::optional<T>& value) {
48   if (value) {
49     jit::tracer::addInputs(n, name, *value);
50   } else {
51     Graph* g = n->owningGraph();
52     Value* none = g->insertNode(g->createNone())->output();
53     n->addInput(none);
54   }
55 }
56 
// Fails loudly when the tracer encounters an argument type for which there
// is no addInputs overload. Never returns (AT_ERROR throws).
template <typename T>
void badArgType(const T& v) {
  AT_ERROR(
      "Found an unsupported argument type in the JIT tracer: ",
      c10::demangle_type<T>(),
      ". File a bug report.");
}
64 
65 thread_local std::shared_ptr<TracingState> tracing_state;
66 } // namespace detail
67 
// Global toggle for tracer warnings (consumed by tracer::warn).
static std::atomic<bool> tracer_state_warn_mode{true};

// Returns a mutable reference so callers can flip warning emission on/off.
std::atomic<bool>& getTracerStateWarnMode() {
  return tracer_state_warn_mode;
}
73 
pauseTracing()74 std::function<void()> pauseTracing() {
75   // NOLINTNEXTLINE
76   std::shared_ptr<tracer::TracingState> state = getTracingState();
77   tracer::setTracingState(nullptr);
78 
79   return [state]() { tracer::setTracingState(state); };
80 }
81 
delValueTrace(const IValue & var)82 void delValueTrace(const IValue& var) {
83   getTracingState()->delValue(var);
84 }
delValue(const IValue & var)85 void TracingState::delValue(const IValue& var) {
86   for (const auto i : c10::irange(env_stack.size())) {
87     auto& value_map = env_stack.at(env_stack.size() - 1 - i);
88     auto it = value_map.find(var);
89     if (it == value_map.end()) {
90       continue;
91     }
92     value_map.erase(it);
93   }
94 }
95 
// Given a IValue 'var', return the 'node' which represents the instruction
// which computes the value of this variable in the IR.
// Here, we interpret untraced variables as constants that are just embedded
// in the graph.  This is useful to handle code which does things like this
// (from torch.autograd.variable, now moved to C++):
//
//    def mm(self, matrix):
//      output = Variable(self.data.new(self.data.size(0), matrix.data.size(1)))
//      return Addmm.apply(output, self, matrix, 0, 1, True)
//
// Here, mm fakes up a dummy variable with uninitialized data to do an inplace
// update on, but subsequently ignores it because the alpha scaling factor is
// zero. This is one of the cases where a Variable can be created inside of a
// trace, and if we treat it as a constant, everything will work out.
// See TracingState::getValue for the lookup / constant-baking details.
Value* getValueTrace(const IValue& var) {
  return getTracingState()->getValue(var);
}
getOptTensorValueTrace(const std::optional<at::Tensor> & var)113 static Value* getOptTensorValueTrace(const std::optional<at::Tensor>& var) {
114   return getValueTrace(IValue(var));
115 }
getValue(const IValue & var)116 Value* TracingState::getValue(const IValue& var) {
117   // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
118   // arguments
119   if (var.isTensorList()) {
120     return graph
121         ->insertNode(graph->createList(
122             TensorType::get(),
123             fmap(
124                 var.toTensorVector(),
125                 [&](const IValue& val) { return getValue(val); })))
126         ->output();
127   } else if (var.isTuple()) {
128     return graph
129         ->insertNode(graph->createTuple(fmap(
130             var.toTupleRef().elements(),
131             [&](const IValue& val) { return getValue(val); })))
132         ->output();
133   } else if (var.isGenericDict()) {
134     auto dict = var.toGenericDict();
135     TypePtr key_type = dict.keyType();
136     TypePtr value_type = dict.valueType();
137     std::vector<Value*> keys;
138     std::vector<Value*> values;
139     for (const auto& entry : dict) {
140       keys.emplace_back(getValue(entry.key()));
141       values.emplace_back(getValue(entry.value()));
142     }
143     auto dict_node = graph->createDict(key_type, value_type, keys, values);
144     return graph->insertNode(dict_node)->output();
145   }
146   if (var.isTensor()) {
147     auto& ten = var.toTensor();
148     if (!ten.defined()) {
149       Node* n = graph->createNone();
150       return graph->insertNode(n)->output();
151     }
152     for (const auto i : c10::irange(env_stack.size())) {
153       auto& value_map = env_stack.at(env_stack.size() - 1 - i);
154       auto it = value_map.find(var);
155       if (it == value_map.end()) {
156         continue;
157       }
158       if (!it->second->hasDebugName()) {
159         auto unique_name = getTracingState()->lookup_var_name_fn(ten);
160         if (!unique_name.empty()) {
161           it->second->setDebugName(unique_name);
162         }
163       }
164       return it->second;
165     }
166 
167     // Didn't find it. Bake in a constant
168     if (ten.requires_grad()) {
169       pauseTracing();
170       std::ostringstream oss;
171       oss << "Cannot insert a Tensor that requires grad as a constant. "
172           << "Consider making it a parameter or input, or detaching the gradient\n"
173           << "Tensor:\n"
174           << ten;
175       throw std::runtime_error(oss.str());
176     }
177 
178     Value* constant = graph->insertConstant(ten);
179     recordSourceLocation(constant->node());
180     constant->inferTypeFrom(ten);
181     auto it = env_stack.back().emplace(var, constant);
182     return it.first->second;
183   } else if (var.isFuture() || var.isObject()) {
184     for (const auto i : c10::irange(env_stack.size())) {
185       auto& future_map = env_stack.at(env_stack.size() - 1 - i);
186       auto it = future_map.find(var);
187       if (it == future_map.end()) {
188         continue;
189       }
190       return it->second;
191     }
192 
193     // Find torchbind classes
194     if (isCustomClass(var)) {
195       auto obj = Object(var.toObject());
196       auto qualname = obj.type()->name();
197       auto custom_class_type = getCustomClass(qualname->qualifiedName());
198       if (custom_class_type) {
199         auto capsule = var.toObject()->getAttr("capsule");
200         for (const auto i : c10::irange(env_stack.size())) {
201           auto& value_map = env_stack.at(env_stack.size() - 1 - i);
202           auto it = value_map.find(capsule);
203           if (it == value_map.end()) {
204             continue;
205           }
206           return it->second;
207         }
208       }
209     }
210 
211     std::ostringstream oss;
212     if (var.isFuture()) {
213       oss << "Tried to trace Future or Object that the tracer was not aware of.";
214     } else {
215       oss << "Tried to trace " << var
216           << " but it is not part of the active trace. Modules that are called during a trace"
217           << " must be registered as submodules of the thing being traced.";
218     }
219     throw std::runtime_error(oss.str());
220   } else {
221     // If the values are non-tensors, we try to create constants
222     // and bake those constants into the traced graph
223     auto constant = tryInsertConstant(*graph, var);
224     if (constant) {
225       recordSourceLocation(constant.value()->node());
226       return *constant;
227     }
228     std::ostringstream os;
229     os << "Tracer cannot get value trace for type " << var.tagKind() << ". "
230        << "The below value could not be materialized as a constant:\n"
231        << var;
232     throw std::runtime_error(os.str());
233   }
234 }
hasValue(const IValue & var) const235 bool TracingState::hasValue(const IValue& var) const {
236   for (const auto& frame : env_stack) {
237     if (frame.count(var)) {
238       return true;
239     }
240   }
241   return false;
242 }
243 
getOutput(const IValue & iv,size_t i)244 Value* TracingState::getOutput(const IValue& iv, size_t i) {
245   bool tracing_mode_strict = getTracingState()->strict;
246   if (iv.isTensor()) {
247     const at::Tensor& var = iv.toTensor();
248     if (!var.defined()) {
249       Node* n = graph->createNone();
250       return graph->insertNode(n)->output();
251     }
252 
253     auto& value_map = getTracingState()->env_stack.back();
254     auto it = value_map.find(iv);
255     if (it == value_map.end()) {
256       std::ostringstream os;
257       os << "output " << i << " (" << var
258          << ") of traced region did not have observable "
259          << "data dependence with trace inputs; this probably indicates your "
260             "program "
261          << "cannot be understood by the tracer.";
262       throw std::runtime_error(os.str());
263     }
264     return it->second;
265   } else if (iv.isTensorList()) {
266     if (tracing_mode_strict) {
267       tracer::warn(
268           "Encountering a list at the output of the tracer", STRICT_TRACER_MSG);
269     }
270     return graph
271         ->insertNode(graph->createList(
272             TensorType::get(),
273             fmap(
274                 iv.toTensorVector(),
275                 [&](const IValue& ival) { return getOutput(ival, i); })))
276         ->output();
277   } else if (iv.isTuple()) {
278     const auto& tuple = iv.toTupleRef().elements();
279     auto tuple_node = graph->createTuple(
280         fmap(tuple, [&](const IValue& ival) { return getOutput(ival, i); }));
281     graph->insertNode(tuple_node);
282     return tuple_node->output();
283   } else if (iv.isGenericDict()) {
284     if (tracing_mode_strict) {
285       throw std::runtime_error(
286           "Encountering a dict at the output of the tracer" +
287           std::string(STRICT_TRACER_MSG));
288     }
289     auto dict = iv.toGenericDict();
290     TypePtr key_type = dict.keyType();
291     TypePtr value_type = dict.valueType();
292 
293     bool key_type_valid = key_type->isSubtypeOf(*StringType::get()) ||
294         key_type->isSubtypeOf(*TensorType::get());
295     bool value_type_valid = value_type->isSubtypeOf(*TensorType::get());
296 
297     // Support tuple values that contain only tensors
298     if (value_type->isSubtypeOf(*AnyTupleType::get())) {
299       value_type_valid = true;
300       for (const auto& type : value_type->containedTypes()) {
301         if (!type->isSubtypeOf(*TensorType::get())) {
302           value_type_valid = false;
303           break;
304         }
305       }
306     }
307 
308     if (!key_type_valid || !value_type_valid) {
309       std::ostringstream os;
310       os << "output " << i << " (" << dict << ") of traced region "
311          << "cannot be understood by the tracer, only outputs matching"
312          << "dict[Union[str, Tensor], Union[Tensor, Tuple[Tensor, ...]]] "
313          << "can be a dictionary output of a traced function";
314       throw std::runtime_error(os.str());
315     }
316     std::vector<Value*> keys;
317     std::vector<Value*> values;
318     for (const auto& entry : dict) {
319       keys.emplace_back(getValue(entry.key()));
320       values.emplace_back(getOutput(entry.value(), i));
321     }
322     auto dict_node = graph->createDict(key_type, value_type, keys, values);
323     graph->insertNode(dict_node);
324     return dict_node->output();
325   } else {
326     AT_ERROR(
327         "Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions");
328   }
329 }
330 
createNode(c10::Symbol op_name,size_t num_outputs)331 Node* TracingState::createNode(c10::Symbol op_name, size_t num_outputs) {
332   return graph->create(op_name, num_outputs);
333 }
334 
insertNode(Node * node)335 void TracingState::insertNode(Node* node) {
336   graph->insertNode(node);
337 }
338 
// XXX: this function mutates input
// Registers `input` (with declared type `type`) as a trace input bound to
// graph value `value`, recursing into tuples/dicts/lists so every leaf
// tensor is mapped in the tracing environment. Returns a (possibly rebuilt)
// IValue whose leaf tensors are the ones actually recorded — containers are
// updated in place, hence the mutation warning above.
static IValue addInput(
    const std::shared_ptr<TracingState>& state,
    const IValue& input,
    const TypePtr& type,
    Value* value) {
  value->setType(type);
  if (type->isSubtypeOf(*TensorType::get())) {
    auto input_tensor = input.toTensor();
    auto name = Variable(input_tensor).name();
    // If this tensor is already traced (e.g. passed twice), alias it via a
    // fresh view so each occurrence gets a distinct environment entry.
    if (state->hasValue(input)) {
      input_tensor = input_tensor.view(input_tensor.sizes());
    }
    if (!value->hasDebugName()) {
      value->setDebugName(name);
    }
    state->setValue(input_tensor, value);
    return input_tensor;
  } else if (auto tuple_type = type->cast<TupleType>()) {
    // Statically unpack the tuple and recurse element-wise.
    auto unpack_node =
        state->graph->insertNode(state->graph->createTupleUnpack(value));
    auto elem_values = unpack_node->outputs();
    auto elem_types = tuple_type->elements();
    auto tuple = input.toTuple();
    const auto& elems = tuple->elements();
    size_t num_elems = elems.size();
    AT_ASSERT(
        elem_values.size() == num_elems && elem_types.size() == num_elems);
    for (const auto i : c10::irange(num_elems)) {
      tuple->unsafeSetElement(
          i, addInput(state, elems.at(i), elem_types[i], elem_values[i]));
    }
    return tuple;
  } else if (auto dict_type = type->cast<DictType>()) {
    auto dict = input.toGenericDict();

    // Unpack the list values statically
    for (const auto& entry : dict) {
      const IValue& key = entry.key();
      // Keys are baked in as constants; values are fetched via __getitem__
      // so the access shows up in the trace.
      auto static_key = state->graph->insertConstant(key);
      auto static_value =
          state->graph->insert(aten::__getitem__, {value, static_key});
      recordSourceLocation(static_value->node());
      dict.insert_or_assign(
          entry.key(),
          addInput(
              state, entry.value(), dict_type->getValueType(), static_value));
    }

    return dict;
  } else if (auto list_type = type->cast<ListType>()) {
    size_t num_elems = input.isList() ? input.toListRef().size()
                                      : input.toTensorVector().size();
    // Statically unpack the list to one graph value per element.
    auto list_unpack = state->graph->insertNode(
        state->graph->createListUnpack(value, num_elems));
    auto unpack_outputs = list_unpack->outputs();

    if (input.isTensorList()) {
      auto elems = input.toTensorList();
      for (const auto i : c10::irange(num_elems)) {
        elems[i] = addInput(
                       state,
                       elems.get(i),
                       list_type->getElementType(),
                       unpack_outputs[i])
                       .toTensor();
      }
      return elems;
    } else {
      auto elems = input.toList();
      for (const auto i : c10::irange(num_elems)) {
        elems[i] = addInput(
            state,
            elems.get(i),
            list_type->getElementType(),
            unpack_outputs[i]);
      }
      return elems;
    }
  } else {
    AT_ERROR(
        "Only tensors or (possibly nested) dict or tuples of tensors can be "
        "inputs to traced functions. Got ",
        type->repr_str());
  }
}
425 
gatherParametersAndBuffers(const std::shared_ptr<TracingState> & state,Value * self_value,const Module & self,const std::string & prefix)426 static void gatherParametersAndBuffers(
427     const std::shared_ptr<TracingState>& state,
428     Value* self_value,
429     const Module& self,
430     const std::string& prefix) {
431   Graph& g = *self_value->owningGraph();
432 
433   state->setValue(self._ivalue(), self_value);
434 
435   auto self_ty = self.type();
436   for (const NameValue& s : self.named_attributes(/*recurse=*/false)) {
437     auto qualname = prefix + "." + s.name;
438     Value* trace_get_attr = g.insertNode(g.create(prim::TracedAttr))
439                                 ->s_(attr::scope, qualname)
440                                 ->output()
441                                 ->setType(s.value.type());
442     if (s.value.type()->isSubtypeOf(*TensorType::get())) {
443       addInput(state, s.value, s.value.type(), trace_get_attr);
444     }
445     if (isCustomClass(s.value)) {
446       tracer::setValueTrace(s.value, trace_get_attr);
447     }
448 
449     auto attr_type = self_ty->getAttribute(s.name);
450     // Skipping Parameters and Buffers that are behind an `InterfaceType`
451     // because it is illegal for InterfaceType to expose any attribute.
452     // And these attributes should never be used/exposed outside of
453     // InterfaceType'd module anyway.
454     if (attr_type->is_module() &&
455         attr_type->kind() != TypeKind::InterfaceType) {
456       gatherParametersAndBuffers(
457           state, trace_get_attr, Module(s.value.toObject()), qualname);
458     }
459   }
460 }
461 
trace(Stack inputs,const std::function<Stack (Stack)> & traced_fn,std::function<std::string (const Variable &)> var_name_lookup_fn,bool strict,bool force_outplace,Module * self,const std::vector<std::string> & argument_names)462 std::pair<std::shared_ptr<TracingState>, Stack> trace(
463     Stack inputs,
464     const std::function<Stack(Stack)>& traced_fn,
465     std::function<std::string(const Variable&)> var_name_lookup_fn,
466     bool strict,
467     bool force_outplace,
468     Module* self,
469     const std::vector<std::string>& argument_names) {
470   try {
471     // Start tracing, treating 'inputs' as inputs to the trace, which can be
472     // varied on subsequent invocations of the trace.  Any other variables
473     // will be treated as constants.
474     if (isTracing()) {
475       AT_ERROR("Tracing can't be nested");
476     }
477     auto state = std::make_shared<TracingState>();
478     setTracingState(state);
479 
480     // if we are a module, then make sure the modules parameters are in the map
481     // and mapped to accesses to the self object
482     if (self) {
483       Value* self_value = state->graph->insertInput(0, "self")->setType(
484           self->_ivalue()->type());
485       gatherParametersAndBuffers(state, self_value, *self, {"__module"});
486     }
487 
488     // When enough argument name hints are provided, use them as debug names
489     // for traced function/modules.
490     // Here argument_names is allowed to have more names than needed because
491     // some arguments may have valid default values, therefore they don't need
492     // example inputs.
493     if (argument_names.size() >= inputs.size()) {
494       for (size_t i = 0, e = inputs.size(); i < e; ++i) {
495         IValue& input = inputs[i];
496         input = addInput(
497             state,
498             input,
499             input.type(),
500             state->graph->addInput(argument_names[i]));
501       }
502     } else {
503       for (IValue& input : inputs) {
504         input = addInput(state, input, input.type(), state->graph->addInput());
505       }
506     }
507 
508     auto graph = state->graph;
509 
510     getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
511     getTracingState()->strict = strict;
512     getTracingState()->force_outplace = force_outplace;
513 
514     // Invoke the traced function
515     auto out_stack = traced_fn(inputs);
516 
517     // Exit a trace, treating 'out_stack' as the outputs of the trace.  These
518     // are the variables whose values will be computed upon subsequent
519     // invocations of the trace.
520     size_t i = 0;
521     for (auto& output : out_stack) {
522       // NB: The stack is in "reverse" order, so when we pass the diagnostic
523       // number we need to flip it based on size.
524       state->graph->registerOutput(
525           state->getOutput(output, out_stack.size() - i));
526       i++;
527     }
528     setTracingState(nullptr);
529 
530     if (getInlineEverythingMode()) {
531       Inline(*graph);
532     }
533     FixupTraceScopeBlocks(graph, self);
534     NormalizeOps(graph);
535     return {state, out_stack};
536   } catch (...) {
537     tracer::abandon();
538     throw;
539   }
540 }
541 
// Abort tracing. Used to reset the state in case of errors.
void abandon() {
  setTracingState(nullptr);
}
546 
setValueTrace(const IValue & v,Value * value)547 void setValueTrace(const IValue& v, Value* value) {
548   return getTracingState()->setValue(v, value);
549 }
setValue(const IValue & v,Value * value)550 void TracingState::setValue(const IValue& v, Value* value) {
551   if (v.isTensor()) {
552     auto& var = v.toTensor();
553     AT_ASSERT(var.defined());
554     env_stack.back()[v] = value;
555 
556     // If the value comes from a CallFunction or CallMethod, it may not have
557     // shape information attached. For debuggability, we enhance the type
558     // information by assigning the concrete value's tupe to the jit::Value.
559     if (auto tensor_type = value->type()->cast<TensorType>()) {
560       if (!tensor_type->isComplete()) {
561         value->inferTypeFrom(var);
562       }
563     }
564   } else if (v.isTensorList()) {
565     auto outputs = v.toTensorList();
566     Node* unpack_node =
567         graph->insertNode(graph->createListUnpack(value, outputs.size()));
568     for (const auto i : c10::irange(outputs.size())) {
569       setValue(outputs.get(i), unpack_node->outputs()[i]);
570     }
571   } else if (v.isTuple()) {
572     const auto& outputs = v.toTupleRef().elements();
573     Node* unpack_node = graph->insertNode(graph->createTupleUnpack(value));
574     for (const auto i : c10::irange(outputs.size())) {
575       setValue(outputs[i], unpack_node->outputs()[i]);
576     }
577   } else if (v.isList()) {
578     auto elements = v.toListRef();
579     Node* unpack_node =
580         graph->insertNode(graph->createListUnpack(value, elements.size()));
581     for (const auto i : c10::irange(elements.size())) {
582       setValue(elements[i], unpack_node->outputs()[i]);
583     }
584   } else if (isCustomClass(v)) {
585     auto capsule = v.toObject()->getAttr("capsule");
586     env_stack.back()[capsule] = value;
587   } else if (v.isFuture() || v.isObject()) {
588     env_stack.back()[v] = value;
589   } else if (v.isGenericDict()) {
590     auto dict = v.toGenericDict();
591     TypePtr key_type = dict.keyType();
592     TypePtr value_type = dict.valueType();
593     for (const auto& entry : dict) {
594       auto static_key = graph->insertConstant(entry.key());
595       auto static_value = graph->insert(aten::__getitem__, {value, static_key});
596       setValue(entry.value(), static_value);
597     }
598   } else {
599     std::ostringstream os;
600     os << "Tracer cannot set value trace for type " << v.tagKind() << ". "
601        << "Supported types are tensor, tensor list, and tuple of tensors.";
602     throw std::runtime_error(os.str());
603   }
604 }
605 
addInputs(Node * n,const char * name,int64_t value)606 void addInputs(Node* n, const char* name, int64_t value) {
607   using ArgumentStash = jit::tracer::ArgumentStash;
608   if (ArgumentStash::hasValue(name)) {
609     Value* v = ArgumentStash::popValue(name);
610     n->addInput(v);
611   } else {
612     detail::genericAddInput(n, value);
613   }
614 }
615 
addInputs(Node * n,const char * name,const c10::SymInt & value)616 void addInputs(Node* n, const char* name, const c10::SymInt& value) {
617   addInputs(n, name, value.guard_int(__FILE__, __LINE__));
618 }
619 
addInputs(Node * n,const char * name,std::optional<int64_t> value)620 void addInputs(Node* n, const char* name, std::optional<int64_t> value) {
621   using ArgumentStash = jit::tracer::ArgumentStash;
622   if (ArgumentStash::hasValue(name)) {
623     Value* v = ArgumentStash::popValue(name);
624     n->addInput(v);
625   } else if (value) {
626     detail::genericAddInput(n, *value);
627   } else {
628     Graph* g = n->owningGraph();
629     Value* none = g->insertNode(g->createNone())->output();
630     n->addInput(none);
631   }
632 }
addInputs(Node * n,const char * name,bool value)633 void addInputs(Node* n, const char* name, bool value) {
634   detail::genericAddInput(n, value);
635 }
addInputs(Node * n,const char * name,const std::optional<bool> & value)636 void addInputs(Node* n, const char* name, const std::optional<bool>& value) {
637   detail::genericAddOptionalInput(n, name, value);
638 }
addInputs(Node * n,const char * name,double value)639 void addInputs(Node* n, const char* name, double value) {
640   detail::genericAddInput(n, value);
641 }
addInputs(Node * n,const char * name,const std::optional<double> & value)642 void addInputs(Node* n, const char* name, const std::optional<double>& value) {
643   detail::genericAddOptionalInput(n, name, value);
644 }
addInputs(Node * n,const char * name,const at::Scalar & value)645 void addInputs(Node* n, const char* name, const at::Scalar& value) {
646   using ArgumentStash = jit::tracer::ArgumentStash;
647   if (ArgumentStash::hasValue(name)) {
648     Value* v = ArgumentStash::popValue(name);
649     n->addInput(v);
650   } else {
651     detail::genericAddInput(n, value);
652   }
653 }
addInputs(Node * n,const char * name,const std::optional<at::Scalar> & value)654 void addInputs(
655     Node* n,
656     const char* name,
657     const std::optional<at::Scalar>& value) {
658   detail::genericAddOptionalInput(n, name, value);
659 }
addInputs(Node * n,const char * name,const c10::string_view value)660 void addInputs(Node* n, const char* name, const c10::string_view value) {
661   detail::genericAddInput(n, std::string(value));
662 }
addInputs(Node * n,const char * name,const std::optional<c10::string_view> & value)663 void addInputs(
664     Node* n,
665     const char* name,
666     const std::optional<c10::string_view>& value) {
667   detail::genericAddOptionalInput(n, name, value);
668 }
addInputs(Node * n,const char * name,const at::Tensor & value)669 void addInputs(Node* n, const char* name, const at::Tensor& value) {
670   n->addInput(getValueTrace(value));
671 }
addInputs(Node * n,const char * name,const std::optional<at::Tensor> & value)672 void addInputs(
673     Node* n,
674     const char* name,
675     const std::optional<at::Tensor>& value) {
676   detail::genericAddOptionalInput(n, name, value);
677 }
addInputs(Node * n,const char * name,const std::optional<at::Generator> & value)678 void addInputs(
679     Node* n,
680     const char* name,
681     const std::optional<at::Generator>& value) {
682   Graph* g = n->owningGraph();
683 
684   if (value.has_value() && value->defined()) {
685     detail::genericAddInput(n, *value);
686   } else {
687     Value* undef_gen = g->insertNode(g->createNone())->output();
688     n->addInput(undef_gen);
689   }
690 }
addInputs(Node * n,const char * name,at::Device value)691 void addInputs(Node* n, const char* name, at::Device value) {
692   detail::genericAddInput(n, value);
693 }
addInputs(Node * n,const char * name,c10::Stream stream)694 void addInputs(Node* n, const char* name, c10::Stream stream) {
695   detail::genericAddInput(n, c10::IValue(stream));
696 }
addInputs(Node * n,const char * name,at::Layout value)697 void addInputs(Node* n, const char* name, at::Layout value) {
698   detail::genericAddInput(n, static_cast<int64_t>(value));
699 }
addInputs(Node * n,const char * name,at::ScalarType value)700 void addInputs(Node* n, const char* name, at::ScalarType value) {
701   detail::genericAddInput(n, static_cast<int64_t>(value));
702 }
addInputs(Node * n,const char * name,at::MemoryFormat value)703 void addInputs(Node* n, const char* name, at::MemoryFormat value) {
704   detail::genericAddInput(n, static_cast<int64_t>(value));
705 }
addInputs(Node * n,const char * name,const std::optional<at::MemoryFormat> & value)706 void addInputs(
707     Node* n,
708     const char* name,
709     const std::optional<at::MemoryFormat>& value) {
710   detail::genericAddOptionalInput(n, name, value);
711 }
addInputs(Node * n,const char * name,const std::optional<at::Layout> & value)712 void addInputs(
713     Node* n,
714     const char* name,
715     const std::optional<at::Layout>& value) {
716   detail::genericAddOptionalInput(n, name, value);
717 }
addInputs(Node * n,const char * name,const std::optional<at::Device> & value)718 void addInputs(
719     Node* n,
720     const char* name,
721     const std::optional<at::Device>& value) {
722   detail::genericAddOptionalInput(n, name, value);
723 }
addInputs(Node * n,const char * name,std::optional<at::DimnameList> value)724 void addInputs(
725     Node* n,
726     const char* name,
727     std::optional<at::DimnameList> value) {
728   TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer");
729 }
addInputs(Node * n,const char * name,const std::optional<at::ScalarType> & value)730 void addInputs(
731     Node* n,
732     const char* name,
733     const std::optional<at::ScalarType>& value) {
734   detail::genericAddOptionalInput(n, name, value);
735 }
addInputs(Node * n,const char * name,at::ArrayRef<at::Tensor> value,bool allow_undefined)736 void addInputs(
737     Node* n,
738     const char* name,
739     at::ArrayRef<at::Tensor> value,
740     bool allow_undefined) {
741   addInputs(n, name, at::ITensorListRef(value), allow_undefined);
742 }
addInputs(Node * n,const char * name,const std::vector<at::Tensor> & value,bool allow_undefined)743 void addInputs(
744     Node* n,
745     const char* name,
746     const std::vector<at::Tensor>& value,
747     bool allow_undefined) {
748   addInputs(n, name, at::ITensorListRef(value), allow_undefined);
749 }
addInputs(Node * n,const char * name,at::ITensorListRef value,bool allow_undefined)750 void addInputs(
751     Node* n,
752     const char* name,
753     at::ITensorListRef value,
754     bool allow_undefined) {
755   Graph* g = n->owningGraph();
756   Node* list_node = nullptr;
757   if (allow_undefined) {
758     // if allow undefined, we create a list of optional tensors
759     list_node = g->insertNode(
760         g->createList(OptionalType::ofTensor(), fmap(value, getValueTrace)));
761   } else {
762     list_node = g->insertNode(
763         g->createList(TensorType::get(), fmap(value, getValueTrace)));
764   }
765   n->addInput(list_node->output());
766 }
addInputs(Node * n,const char * name,const List<std::optional<at::Tensor>> & value)767 TORCH_API void addInputs(
768     Node* n,
769     const char* name,
770     const List<std::optional<at::Tensor>>& value) {
771   Graph* g = n->owningGraph();
772   Node* list_node = nullptr;
773   list_node = g->insertNode(g->createList(
774       OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
775   n->addInput(list_node->output());
776 }
addInputs(Node * n,const char * name,ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,const ClassTypePtr & class_type)777 void addInputs(
778     Node* n,
779     const char* name,
780     ArrayRef<c10::intrusive_ptr<c10::ivalue::Object>> value,
781     const ClassTypePtr& class_type) {
782   Graph* g = n->owningGraph();
783   Node* list_node =
784       g->insertNode(g->createList(class_type, fmap(value, getValueTrace)));
785   n->addInput(list_node->output());
786 }
787 
addInputs(Node * n,const char * name,at::IntArrayRef value)788 void addInputs(Node* n, const char* name, at::IntArrayRef value) {
789   using ArgumentStash = jit::tracer::ArgumentStash;
790   std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)
791       ? ArgumentStash::popIntArrayRef(name)
792       : ArgumentStash::IntArrayRefTrace(value.size());
793 
794   auto& g = getTracingState()->graph;
795   for (const auto i : c10::irange(info.size())) {
796     if (info[i] != nullptr)
797       continue;
798     info[i] = g->insertConstant(value[i]);
799     recordSourceLocation(info[i]->node());
800   }
801   for (jit::Value* v : info) {
802     if (*v->type() != *jit::IntType::get()) {
803       throw std::runtime_error(
804           "Type mismatch in setposattr for IntArrayRef. Check that your program "
805           "is valid without tracing, and please file a bug report if it is.");
806     }
807   }
808   n->addInput(
809       g->insertNode(g->createList(jit::IntType::get(), info))->output());
810 }
811 
// SymInt list overload: guards each symbolic int to a concrete value and
// forwards to the IntArrayRef overload.
void addInputs(Node* n, const char* name, c10::SymIntArrayRef value) {
  addInputs(n, name, C10_AS_INTARRAYREF_SLOW(value));
}
815 
addInputs(Node * n,const char * name,std::optional<c10::SymInt> value)816 void addInputs(Node* n, const char* name, std::optional<c10::SymInt> value) {
817   addInputs(
818       n,
819       name,
820       value.has_value()
821           ? std::make_optional(value->guard_int(__FILE__, __LINE__))
822           : std::nullopt);
823 }
824 
// Traces an optional IntArrayRef; std::nullopt becomes a prim::Constant
// None input (see detail::genericAddOptionalInput).
void addInputs(
    Node* n,
    const char* name,
    const std::optional<at::IntArrayRef>& opt_value) {
  detail::genericAddOptionalInput(n, name, opt_value);
}
831 
addInputs(Node * n,const char * name,const at::OptionalIntArrayRef & opt_value)832 void addInputs(
833     Node* n,
834     const char* name,
835     const at::OptionalIntArrayRef& opt_value) {
836   if (opt_value.has_value()) {
837     jit::tracer::addInputs(n, name, *opt_value);
838   } else {
839     Graph* g = n->owningGraph();
840     Value* none = g->insertNode(g->createNone())->output();
841     n->addInput(none);
842   }
843 }
844 
addInputs(Node * n,const char * name,const at::OptionalSymIntArrayRef & opt_value)845 void addInputs(
846     Node* n,
847     const char* name,
848     const at::OptionalSymIntArrayRef& opt_value) {
849   if (opt_value.has_value()) {
850     jit::tracer::addInputs(n, name, *opt_value);
851   } else {
852     Graph* g = n->owningGraph();
853     Value* none = g->insertNode(g->createNone())->output();
854     n->addInput(none);
855   }
856 }
857 
addInputs(Node * n,const char * name,ArrayRef<double> value)858 void addInputs(Node* n, const char* name, ArrayRef<double> value) {
859   std::vector<Value*> info;
860   auto& g = getTracingState()->graph;
861   for (double elt : value) {
862     info.push_back(g->insertConstant(elt));
863     recordSourceLocation(info.back()->node());
864   }
865   n->addInput(
866       g->insertNode(g->createList(jit::FloatType::get(), info))->output());
867 }
868 
// Traces an optional double-list; std::nullopt becomes a prim::Constant
// None input (see detail::genericAddOptionalInput).
void addInputs(
    Node* n,
    const char* name,
    const std::optional<c10::ArrayRef<double>>& opt_value) {
  detail::genericAddOptionalInput(n, name, opt_value);
}
875 
addInputs(Node * n,const char * name,const c10::intrusive_ptr<c10::ivalue::Object> & obj)876 void addInputs(
877     Node* n,
878     const char* name,
879     const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
880   Value* v = getValueTrace(obj);
881   n->addInput(v);
882 }
883 
// Appends a new output Value to `node` and binds `output`'s trace to it.
void addOutput(Node* node, const at::Tensor& output) {
  setOutput(node->addOutput(), output);
}
887 
setOutput(Value * value,const at::Tensor & output)888 void setOutput(Value* value, const at::Tensor& output) {
889   if (output.defined()) {
890     value->inferTypeFrom(output);
891     setValueTrace(output, value);
892   }
893 }
894 
addOutput(Node * node,const std::vector<at::Tensor> & outputs)895 void addOutput(Node* node, const std::vector<at::Tensor>& outputs) {
896   Value* value = node->addOutput()->setType(ListType::ofTensors());
897   Graph* graph = node->owningGraph();
898   Node* unpack_node = graph->insertNode(
899       graph->create(prim::ListUnpack, {value}, outputs.size()));
900   for (const auto i : c10::irange(outputs.size())) {
901     Value* output_val = unpack_node->outputs()[i];
902     output_val->inferTypeFrom(outputs[i]);
903     setValueTrace(outputs[i], output_val);
904   }
905 }
906 
// c10::List overload: materializes the list as a std::vector and reuses the
// vector overload above.
void addOutput(Node* node, const c10::List<at::Tensor>& outputs) {
  return addOutput(node, outputs.vec());
}
910 
addOutput(Node * node,const c10::intrusive_ptr<c10::ivalue::Object> & output)911 void addOutput(
912     Node* node,
913     const c10::intrusive_ptr<c10::ivalue::Object>& output) {
914   Value* output_val = node->addOutput();
915   output_val->inferTypeFrom(output);
916   setValueTrace(output, output_val);
917 }
918 
// Returns the active tracing state; null when tracing is disabled.
const std::shared_ptr<TracingState>& getTracingState() {
  return detail::tracing_state;
}
922 
// Installs `state` as the active tracing state. The tracer dispatch key is
// toggled to match: enabled iff a non-null state is installed.
void setTracingState(std::shared_ptr<TracingState> state) {
  at::tracer::impl::set_dispatch_enabled(state != nullptr);
  detail::tracing_state = std::move(state);
}
927 
TracingState()928 TracingState::TracingState() : graph(new Graph()), env_stack{Frame()} {}
929 
930 TracingState::~TracingState() = default;
931 
// Traces `var.size(dim)`: emits an aten::size node into the graph and
// returns a tensor holding the concrete size, whose trace is bound to the
// (NumToTensor-wrapped) aten::size output so downstream ops observe a
// dynamic value instead of a baked-in constant.
autograd::Variable getSizeOf(const autograd::Variable& var, int64_t dim) {
  auto& state = getTracingState();
  auto& graph = state->graph;

  Variable size_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    size_var = scalar_to_tensor(at::Scalar(var.size(dim)));
  }
  auto* self_value = getValueTrace(var);
  auto* dim_value = graph->insertConstant(dim);
  recordSourceLocation(dim_value->node());
  auto* size_node =
      graph->insertNode(graph->create(aten::size, {self_value, dim_value}));
  recordSourceLocation(size_node);
  size_node->output()->setType(jit::IntType::get());

  auto* as_tensor =
      graph->insertNode(graph->createNumToTensor(size_node->output()))
          ->output();
  setValueTrace(size_var, as_tensor);
  return size_var;
}
954 
// Traces `var.numel()`: emits an aten::numel node into the graph and
// returns a tensor holding the concrete element count, whose trace is bound
// to the (NumToTensor-wrapped) node output so the value stays dynamic in
// the trace.
autograd::Variable getNumelOf(const autograd::Variable& var) {
  auto& state = getTracingState();
  auto& graph = state->graph;

  Variable numel_var;
  {
    // Make sure this scalar to tensor isn't traced!
    at::AutoDispatchBelowADInplaceOrView guard;
    numel_var = scalar_to_tensor(at::Scalar(var.numel()));
  }
  auto* self_value = getValueTrace(var);
  auto* numel_node =
      graph->insertNode(graph->create(Symbol::aten("numel"), {self_value}));
  recordSourceLocation(numel_node);
  numel_node->output()->setType(jit::IntType::get());

  auto* as_tensor =
      graph->insertNode(graph->createNumToTensor(numel_node->output()))
          ->output();
  setValueTrace(numel_var, as_tensor);
  return numel_var;
}
975 
ensureUniqueIfOutOfPlaced(const char * name,const at::Tensor & tensor)976 void ensureUniqueIfOutOfPlaced(const char* name, const at::Tensor& tensor) {
977   auto& state = getTracingState();
978   if (state && state->force_outplace == false) {
979     // If we're not converting in-place ops to out-of-place, this check is
980     // unnecessary
981     return;
982   }
983   auto aliases = tensor.storage().use_count();
984   if (isTracing() && aliases > 1) {
985     std::stringstream ss;
986     ss << "There are " << aliases
987        << " live references to the data region being modified when tracing in-place operator "
988        << name
989        << ". This might cause the trace to be incorrect, because all other views "
990        << "that also reference this data will not reflect this change in the trace! "
991        << "On the other hand, if all other views use the same memory chunk, but are disjoint (e.g. "
992        << "are outputs of torch.split), this might still be safe.";
993     warn(ss.str().c_str());
994   }
995 }
// Optional-tensor overload: an absent tensor is passed through as an
// undefined at::Tensor, which trivially passes the uniqueness check.
void ensureUniqueIfOutOfPlaced(
    const char* name,
    const std::optional<at::Tensor>& tensor) {
  ensureUniqueIfOutOfPlaced(name, tensor.has_value() ? *tensor : at::Tensor());
}
1001 
1002 ////////////////////////////////////////////////////////////////////////////////
1003 // Argument stash
1004 ////////////////////////////////////////////////////////////////////////////////
// Per-thread stash of dynamically-traced argument Values, keyed by argument
// name; filled by stashValue/stashIntArrayRefElem and drained by addInputs.
thread_local ArgumentStash ArgumentStash::stash;
1006 
// Records that element `idx` of the IntArrayRef argument `arg_name`
// (of total length `size`) is a dynamic traced value derived from `var`,
// so the IntArrayRef addInputs overload can splice it in instead of
// baking in a constant.
void ArgumentStash::stashIntArrayRefElem(
    const std::string& arg_name,
    size_t size,
    size_t idx,
    const Variable& var) {
  // TODO: check type?
  if (!isTracing())
    return;
  // emplace returns the existing trace if one is already stashed for this
  // argument, so repeated calls accumulate into the same vector.
  IntArrayRefTrace& list_trace =
      stash.intlists.emplace(arg_name, size).first->second;
  AT_ASSERT(size == list_trace.size());
  AT_ASSERT(idx < list_trace.size());
  AT_ASSERT(list_trace[idx] == nullptr);

  Value* ten = getValueTrace(var);
  auto& g = *ten->owningGraph();
  // Insert the int conversion immediately after the node producing `ten`.
  WithInsertPoint guard(ten->node()->next());
  auto prim = g.insert(aten::Int, {ten});
  list_trace[idx] = prim;
}
1027 
stashValue(const std::string & arg_name,size_t idx,const Variable & var,const TypePtr & type)1028 void ArgumentStash::stashValue(
1029     const std::string& arg_name,
1030     size_t idx,
1031     const Variable& var,
1032     const TypePtr& type) {
1033   if (!isTracing())
1034     return;
1035 
1036   Value* ten = getValueTrace(var);
1037   WithInsertPoint guard(ten->node()->next());
1038   auto& g = *ten->owningGraph();
1039 
1040   if (type == IntType::get()) {
1041     ten = g.insert(aten::Int, {ten});
1042   } else if (type == FloatType::get()) {
1043     ten = g.insert(aten::Float, {ten});
1044   } else if (type == NumberType::get()) {
1045     ten = g.insert(aten::ScalarImplicit, {ten});
1046   }
1047 
1048   stash.values.emplace(arg_name, ten);
1049 }
1050 
1051 ////////////////////////////////////////////////////////////////////////////////
1052 // Stack trace recording
1053 ////////////////////////////////////////////////////////////////////////////////
// no python present so we just do not record source information
static void defaultRecordSourceLocation(Node* n) {}
// Swappable hook (see setRecordSourceLocation) used by the Python frontend
// to attach source info to traced nodes; atomic so it can be read and
// replaced concurrently.
std::atomic<decltype(&defaultRecordSourceLocation)> record_source_location(
    defaultRecordSourceLocation);
void recordSourceLocation(Node* n) {
  return record_source_location.load()(n);
}
void setRecordSourceLocation(void (*v)(Node*)) {
  record_source_location.store(v);
}
1064 
// Default callstack provider: no Python available, so report an empty stack.
static std::vector<StackEntry> defaultPythonCallstack() {
  return std::vector<StackEntry>();
}
// Swappable hook (see setPythonCallstack) for retrieving the current Python
// call stack; atomic so it can be read and replaced concurrently.
std::atomic<decltype(&defaultPythonCallstack)> python_callstack_fn(
    defaultPythonCallstack);
std::vector<StackEntry> pythonCallstack() {
  return python_callstack_fn.load()();
}
void setPythonCallstack(std::vector<StackEntry> (*v)()) {
  python_callstack_fn.store(v);
}
1076 
// Default warning sink: route tracer warnings through TORCH_WARN.
static void defaultWarn(const std::string& str) {
  TORCH_WARN(str);
}
// Swappable warning callback (see setWarn); atomic for concurrent access.
std::atomic<warn_fn_type> warn_callback{defaultWarn};
1081 
// Suffixes appended (via _do_warn's `_kind` argument) to tracer warnings.
// Values whose Python-side data flow the tracer cannot record:
const char* WARN_PYTHON_DATAFLOW =
    " might cause the trace to be incorrect. We can't record the data flow of "
    "Python values, so this value will be treated as a constant in the future. "
    "This means that the trace might not generalize to other inputs!";
// Tensor-constructor results that get baked into the trace as constants:
const char* WARN_CONSTRUCTOR =
    " results are registered as constants in the trace. You can safely ignore this "
    "warning if you use this function to create tensors out of constant variables "
    "that would be the same every time you call this function. In any other case, "
    "this might cause the trace to be incorrect.";
// Operations the JIT cannot represent (their uses are disconnected):
const char* WARN_RESIZE =
    " can't be represented in the JIT at the moment, so we won't connect any uses of "
    "this value with its current trace. If you happen to use it again, it will show "
    "up as a constant in the graph. Consider using `view` or `reshape` to make "
    "it traceable.";
// Mutable containers encountered under strict tracing:
const char* STRICT_TRACER_MSG =
    " might cause the trace to be incorrect, this is only valid if the container "
    "structure does not change based on the module's inputs. Consider using a constant "
    "container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a "
    "`NamedTuple` instead). If you absolutely need this and know the side effects, pass "
    "strict=False to trace() to allow this behavior.";
1102 // XXX: _kind can be a nullptr
_do_warn(const char * _reason,const char * _kind)1103 void _do_warn(const char* _reason, const char* _kind) {
1104   std::string reason{_reason};
1105   std::string kind{_kind ? _kind : ""};
1106   std::ostringstream s;
1107   s << reason << kind;
1108   warn_callback.load()(s.str());
1109 }
1110 
// Replaces the warning callback used by _do_warn / warn().
void setWarn(warn_fn_type fn) {
  warn_callback.store(fn);
}
1114 } // namespace torch::jit::tracer
1115