xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/compiled_autograd.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once
#include <ATen/TensorGeometry.h>
#include <ATen/core/ivalue.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable_info.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

// see [Note: Compiled Autograd]

namespace torch::dynamo::autograd {
using namespace torch::autograd;

struct SizeInput {
  // Note: the int value is still needed when dynamic, so it can be passed as
  // an arg
  enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
  SizeInput(DynType dt, int64_t v) : dyn_type(dt), value(v) {}
  DynType dyn_type;
  int64_t value;
};

struct CacheKeyBuffer {
  CacheKeyBuffer(const uint8_t* key, uint16_t len) : data(new uint8_t[len]) {
    std::memcpy(data.get(), key, len);
  }
  const uint8_t* get() const {
    return data.get();
  }

 private:
  // NOLINTNEXTLINE(*c-array*)
  std::unique_ptr<uint8_t[]> data;
};
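// Note: CacheKey (below) holds only a non-owning pointer to its key bytes, so
// a CacheKeyBuffer owns a stable copy; this lets a cached key outlive the
// temporary specialization buffer it was built from.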

struct CacheKey {
  // Key to find the next node in the shadow graph.  We use C++ RTTI for the
  // type of the node (ntype), then a key generated with a visitor pattern.
  CacheKey(const std::type_index& ntype, const uint8_t* key, uint16_t len)
      : node_type(ntype), key_size(len), key(key) {}

  bool operator<(const CacheKey& other) const {
    if (node_type != other.node_type) {
      return node_type < other.node_type;
    }
    if (key_size != other.key_size) {
      return key_size < other.key_size;
    }
    return std::memcmp(key, other.key, key_size) < 0;
  }

  bool operator==(const CacheKey& other) const {
    return node_type == other.node_type && key_size == other.key_size &&
        std::memcmp(key, other.key, key_size) == 0;
  }

  size_t hash() const {
    // don't bother hashing the key data; the common case is one cache entry
    // per node
    return std::hash<std::type_index>()(node_type) ^ key_size;
  }

  std::type_index node_type;
  uint16_t key_size;
  const uint8_t* key;
};
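// For illustration: two keys compare equal only if the node's RTTI type, the
// key length, and the key bytes all match.  hash() intentionally ignores the
// key bytes (e.g. for an AddBackward0 node it is roughly
// std::hash<std::type_index>()(typeid(AddBackward0)) ^ key_size), so the rare
// collisions are resolved by operator== in the hash table.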

struct NodeCall {
  NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
      : id(id_), node(std::move(node_)) {}

  void mark_output(int input_nr, int output_idx) {
    graph_output.emplace_back(input_nr, output_idx);
  }

  uint32_t id;
  std::shared_ptr<Node> node;
  std::vector<std::pair<int, int>> tensor_pre_hooks;
  std::vector<int> pre_hooks;
  std::vector<int> post_hooks;
  std::vector<int> post_acc_grad_hooks;
  std::vector<std::pair<int, int>> graph_output;
  bool needed = true;
};

struct NodeCalls : public std::unordered_map<Node*, NodeCall> {
  NodeCall& lookup(const std::shared_ptr<Node>& function) {
    auto it = find(function.get());
    if (it == end()) {
      it = emplace(function.get(), NodeCall(_next_id++, function)).first;
    }
    return it->second;
  }

 private:
  uint32_t _next_id = 0;
};
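// Note: lookup() assigns ids in first-seen order (the first node gets id 0,
// the next distinct node id 1, and so on); repeated lookups of the same Node*
// return the same NodeCall entry.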

struct TensorArg {
  // Represents a de-duplicated tensor that will be passed into the graph
  TensorArg(uint32_t i = 0) : id(i) {}
  uint32_t index() const {
    TORCH_INTERNAL_ASSERT(defined());
    return id - 1;
  }
  bool defined() const {
    return id != 0;
  }
  uint32_t id;
  at::Tensor proxy_tensor;
};
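// For example: id == 0 is reserved for the undefined tensor, so the third
// distinct tensor registered gets id == 3 and index() == 2, which is its
// position in TensorArgs::inputs below.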

struct TensorArgs {
  // Manages a collection of TensorArgs and mappings from Tensors/SavedVariables
  // to them.  This also allows us to unpack SavedVariable exactly once and
  // store the unpacked Tensor.

  TensorArg& lookup(const at::Tensor& tensor, bool create = false) {
    if (!tensor.defined()) {
      return _undefined;
    }
    auto impl = tensor.unsafeGetTensorImpl();
    auto it = _args.find(impl);
    if (it == _args.end()) {
      TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      it = _args.emplace(impl, TensorArg(_next_id++)).first;
      inputs.emplace_back(tensor);
    }
    return it->second;
  }

  TensorArg& lookup(const SavedVariable& sv) {
    auto it = _saved_variables.find(&sv);
    TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
    return *it->second;
  }

  TensorArg& add(const at::Tensor& tensor) {
    return lookup(tensor, true);
  }

  TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
    // TODO(jansel): Here we unpack the SavedVariable exactly once.  This might
    // fire SavedTensor hooks.  In the future we should try to put saved tensor
    // hooks into the graph.
    at::Tensor tensor = sv.unpack(node);
    TensorArg& arg = add(tensor);
    _saved_variables.emplace(&sv, &arg);
    return arg;
  }

  // the concrete tensors that will get passed into the graph as inputs
  std::vector<at::Tensor> inputs;

 private:
  std::unordered_map<const c10::TensorImpl*, TensorArg> _args;
  // Every TensorArg handed out here is owned by _args (or is _undefined),
  // which is why we store a non-owning pointer.
  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
  TensorArg _undefined;
  uint32_t _next_id = 1; // id=0 used by _undefined
};
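// Rough flow, as a sketch: during collection, add(sv, node) unpacks the saved
// tensor once and appends the concrete tensor to `inputs`; later, during
// tracing, SwapSavedVariables::before() (below) looks the same SavedVariable
// up again via lookup(sv) and swaps in the TensorArg's proxy_tensor.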

struct LiftedIValueArg {
  LiftedIValueArg() = delete;
  LiftedIValueArg(const at::IValue* ptr)
      : actual_ptr(ptr), proxy(at::IValue::uninitialized()) {}

  const at::IValue* actual_ptr; // lifetime handled by autograd node
  at::IValue proxy;
};

struct LiftedIValueArgs {
  at::IValue& next_proxy(const at::IValue* actual_ptr) {
    TORCH_INTERNAL_ASSERT(next < args.size());
    auto& iv_arg = args.at(next++);
    TORCH_INTERNAL_ASSERT(iv_arg.actual_ptr == actual_ptr);
    return iv_arg.proxy;
  }

  std::vector<LiftedIValueArg> args;
  size_t next = 0;
};
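// Note: next_proxy() assumes proxies are requested in exactly the order the
// IValues were lifted during collection (see collect(const at::IValue&)
// below); the actual_ptr assert enforces that invariant.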

struct AutogradCompilerCall {
  void add_size_input(const c10::SymInt& s) {
    all_size_inputs.emplace_back(
        default_dyn_type, s.guard_int(__FILE__, __LINE__));
  }

  size_t emplace_hook(c10::SafePyObject&& fn) {
    hooks.emplace_back(std::move(fn));
    return hooks.size() - 1;
  }

  TensorArgs tensor_args;
  std::vector<SizeInput> all_size_inputs;
  LiftedIValueArgs lifted_ivalue_args;
  std::vector<int64_t> dyn_size_inputs;
  std::vector<c10::SafePyObject> hooks;
  NodeCalls node_calls;
  SizeInput::DynType default_dyn_type = SizeInput::STATIC;
};
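// Roughly, one collection pass over the autograd graph fills this in: each
// node's CompiledNodeArgs::collect() calls register tensors in tensor_args,
// sizes in all_size_inputs, liftable scalars in lifted_ivalue_args, Python
// hooks in `hooks`, and the node itself in node_calls.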

class CompiledNodeArgs {
  // CompiledNodeArgs builds a representation of the constant values found
  // across all the nodes in the compiled graph, via 'collect' overloads. The
  // collected constants are specialized on by concatenation into a cache key.
  // Tensor and SymInt arguments (which are lifted to become graph inputs
  // rather than specialized on) are forwarded to the compiler and not included
  // in the key.
 public:
  void collect(const TensorArg& t) {
    collect_size(t.id);
    if (t.defined()) {
      const at::Tensor& tensor = _compiler.tensor_args.inputs[t.index()];
      // including these in the cache key means dynamo-level tensor guards can
      // be skipped
      collect(tensor.device());
      collect(tensor.dtype());
      collect(tensor.requires_grad());
    }
  }

  void collect(const at::Tensor& t) {
    collect(_compiler.tensor_args.add(t));
  }
  void collect(const SavedVariable& sv, bool is_output) {
    collect(
        _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
  }
  void collect(const c10::SymInt& t) {
    _compiler.add_size_input(t);
  }
  void collect(const std::vector<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }
  template <typename T>
  void collect(const std::vector<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }
  void collect(const c10::ArrayRef<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }
  template <typename T>
  void collect(const c10::ArrayRef<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }
  template <typename T>
  void collect(const c10::OptionalArray<T>& t) {
    collect(t.list);
  }
  template <typename T>
  void collect(const std::optional<T>& t) {
    if (cond(t.has_value())) {
      collect(*t);
    }
  }
  template <typename A, typename B>
  void collect(const std::pair<A, B>& t) {
    collect(t.first);
    collect(t.second);
  }
  template <typename V>
  void collect(const ska::flat_hash_map<std::string, V>& m) {
    collect_size(m.size());

    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (const auto& k : keys) {
      collect(k);
      collect(m.at(k));
    }
  }
  void collect(const at::IValue& iv, bool nested = false) {
    // used by AutogradContext::saved_data from CppNode
    if (iv.isList()) {
      c10::List<at::IValue> list = iv.toList();
      collect_size(list.size());
      for (auto&& value : list) {
        collect(value, true);
      }
    } else if (iv.isGenericDict()) {
      c10::Dict<at::IValue, at::IValue> ordered_dict = iv.toGenericDict();
      collect_size(ordered_dict.size());
      // NOLINTNEXTLINE(modernize-loop-convert)
      for (auto it = ordered_dict.begin(); it != ordered_dict.end(); it++) {
        collect(it->key());
        collect(it->value(), true);
      }
    } else if (iv.isTensor()) {
      collect(iv.toTensor());
    } else if (
        !nested &&
        (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
      // can't lift ivalues nested in collections
      _compiler.lifted_ivalue_args.args.emplace_back(&iv);
    } else {
      try {
        collect(static_cast<uint64_t>(at::IValue::hash(iv)));
      } catch (const std::runtime_error& e) {
        std::string msg =
            "Compiled autograd cannot trace unhashable IValues, error: " +
            std::string(e.what());
        TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
      }
    }
  }
  void collect(const c10::Scalar& t) {
    auto type = t.type();
    specialize_on_bytes(type);
    if (type == c10::ScalarType::Double) {
      collect(t.toDouble());
    } else if (type == c10::ScalarType::Long) {
      collect(t.toLong());
    } else if (type == c10::ScalarType::Bool) {
      collect(t.toBool());
    } else if (type == c10::ScalarType::ComplexDouble) {
      auto c = t.toComplexDouble();
      collect(c.real());
      collect(c.imag());
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
  void collect(const c10::TensorOptions& t) {
    collect(t.device());
    collect(t.dtype());
    collect(t.layout());
    collect(t.requires_grad());
    collect(t.pinned_memory());
    collect(t.memory_format_opt());
  }
  void collect(const at::TensorGeometry& t) {
    collect(t.sym_sizes());
    collect(t.sym_strides());
    collect(t.sym_storage_offset());
  }
  void collect(const torch::autograd::TypeAndSize& t) {
    collect(t.sym_sizes);
    collect(t.options);
  }
  void collect(const c10::Device& t) {
    collect(t.type());
    collect(t.index());
  }
  void collect(const std::string& t) {
    collect_size(t.size());
    for (char c : t) {
      collect(c);
    }
  }
  void collect(const caffe2::TypeMeta& t) {
    specialize_on_bytes(t.id());
  }
  void collect(const std::shared_ptr<Node>& t) {
    // Note: this only captures the ID of the node, not everything contained
    // inside it.  It is used for tracking connections between nodes; the
    // actual details of the node itself must be handled by a separate call to
    // `node->compiled_args()`.
    if (cond((bool)t)) {
      collect(_compiler.node_calls.lookup(t));
    }
  }
  void collect(const NodeCall& t) {
    collect_size(t.id);
    collect(t.graph_output);
    collect_hooks_from(t.node.get());
  }
  void collect(const Edge& t) {
    if (cond(t.is_valid())) {
      collect_size(_compiler.node_calls.lookup(t.function).id);
      collect_size(t.input_nr);
      collect(t.function->input_metadata(t.input_nr)); // for validate_outputs
    }
  }
  void collect(const InputMetadata& t) {
    TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented");
    collect(t.options());
    collect(t.is_tensor_subclass());
    collect(t.shape_as_dim_vector());
  }
  void collect(const VariableInfo& t) {
    collect(t.layout);
    collect(t.device);
    collect(t.scalar_type);
    collect(t.size);
    collect(t.requires_grad);
    collect(t.is_empty);
  }
  bool cond(bool cond) {
    collect(cond);
    return cond;
  }

#define COLLECT_AS_BYTES(T) \
  void collect(T t) {       \
    specialize_on_bytes(t); \
  }
  COLLECT_AS_BYTES(c10::ScalarType);
  COLLECT_AS_BYTES(c10::DeviceType);
  COLLECT_AS_BYTES(c10::Layout);
  COLLECT_AS_BYTES(c10::MemoryFormat);
  COLLECT_AS_BYTES(int8_t);
  COLLECT_AS_BYTES(int16_t);
  COLLECT_AS_BYTES(int32_t);
  COLLECT_AS_BYTES(int64_t);
  COLLECT_AS_BYTES(uint8_t);
  COLLECT_AS_BYTES(uint16_t);
  COLLECT_AS_BYTES(uint32_t);
  COLLECT_AS_BYTES(uint64_t);
  COLLECT_AS_BYTES(bool);
  COLLECT_AS_BYTES(float);
  COLLECT_AS_BYTES(double);
#undef COLLECT_AS_BYTES

  void collect_hooks_from(Node* fn) {
    TORCH_CHECK(
        fn->retains_grad_hooks().empty(),
        "retains_grad_hooks not implemented for compiled autograd");
    for (auto& i : fn->tensor_pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->post_hooks()) {
      i->compiled_args(*this);
    }
    collect_size(_node_call.tensor_pre_hooks.size());
    collect_size(_node_call.pre_hooks.size());
    collect_size(_node_call.post_hooks.size());
    for (const auto& h : _node_call.tensor_pre_hooks) {
      collect_size(static_cast<size_t>(h.second));
    }
  }

  CacheKey key() const {
    Node* node = _node_call.node.get();
    return CacheKey(
        typeid(*node), _specialization_key, _specialization_key_size);
  }

  size_t add_backward(c10::SafePyObject&& obj) {
    return _compiler.emplace_hook(std::move(obj));
  }

  size_t add_backward_state(c10::SafePyObject&& obj) {
    return _compiler.emplace_hook(std::move(obj));
  }

  void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
  }

  void add_pre_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.pre_hooks.emplace_back(fn_id);
  }

  void add_post_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_hooks.emplace_back(fn_id);
  }

  void add_post_acc_grad_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_acc_grad_hooks.emplace_back(fn_id);
  }

  // Need to template the size_t to silence internal 32-bit build errors due to
  // a mix of -Werror, -Wtautological-type-limit-compare and
  // -Wunknown-pragmas
  template <typename T>
  std::enable_if_t<std::is_unsigned_v<T>, void> collect_size(T s) {
    // we expect sizes to be small, so try to cram them into a single byte
    constexpr uint8_t encode_as_u64 = std::numeric_limits<uint8_t>::max();
    constexpr uint8_t encode_as_u32 = encode_as_u64 - 1;
    constexpr uint8_t encode_as_u16 = encode_as_u64 - 2;
    if (C10_UNLIKELY(s >= encode_as_u16)) {
      // first write a byte indicating the path we followed, then the data
      if (s <= std::numeric_limits<uint16_t>::max()) {
        // 3 bytes
        specialize_on_bytes(encode_as_u16);
        specialize_on_bytes(static_cast<uint16_t>(s));
      } else if (s <= std::numeric_limits<uint32_t>::max()) {
        // 5 bytes
        specialize_on_bytes(encode_as_u32);
        specialize_on_bytes(static_cast<uint32_t>(s));
      } else {
        // 9 bytes
        specialize_on_bytes(encode_as_u64);
        specialize_on_bytes(s);
      }
    } else {
      // happy case, 1 byte
      specialize_on_bytes(static_cast<uint8_t>(s));
    }
  }
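  // Worked examples of the encoding above (byte order is the host's, e.g.
  // little-endian on x86):
  //   s = 5     -> 1 byte:  0x05
  //   s = 252   -> 1 byte:  0xFC  (largest value that fits the happy case)
  //   s = 253   -> 3 bytes: 0xFD, then uint16_t 253 (0xFD 0x00)
  //   s = 70000 -> 5 bytes: 0xFE, then uint32_t 70000
  //   s = 2^40  -> 9 bytes: 0xFF, then uint64_t 2^40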

  SizeInput::DynType set_default_dyn_type(SizeInput::DynType default_dyn_type) {
    return std::exchange(_compiler.default_dyn_type, default_dyn_type);
  }

  CompiledNodeArgs(AutogradCompilerCall& compiler, NodeCall& node_call)
      : _compiler(compiler),
        _node_call(node_call),
        _specialization_key(
            // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
            (uint8_t*)std::malloc(_specialization_key_storage)) {}
  ~CompiledNodeArgs() {
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    std::free(_specialization_key);
  }
  CompiledNodeArgs(const CompiledNodeArgs&) = delete;

 private:
  template <typename T>
  void specialize_on_bytes(const T& t) {
    while (C10_UNLIKELY(
        _specialization_key_size + sizeof(T) > _specialization_key_storage)) {
      _specialization_key_storage *= 2;
      // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
      _specialization_key = (uint8_t*)std::realloc(
          _specialization_key, _specialization_key_storage);
    }
    std::memcpy(_specialization_key + _specialization_key_size, &t, sizeof(T));
    _specialization_key_size += sizeof(T);
  }

  AutogradCompilerCall& _compiler;
  NodeCall& _node_call;
  size_t _specialization_key_size{0};
  size_t _specialization_key_storage{1024};
  uint8_t* _specialization_key;
};
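// Rough usage sketch (hedged; the real driver lives outside this header, and
// `compiler_call` / `node_call` are illustrative names):
//   CompiledNodeArgs node_args(compiler_call, node_call);
//   node->compiled_args(node_args); // each node collects its constants
//   CacheKey key = node_args.key(); // hit or miss in the shadow-graph cache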

struct TraceState {
  TraceState(std::vector<std::optional<c10::SymInt>>&& ss, size_t num_outputs)
      : sym_sizes(std::move(ss)), outputs(num_outputs) {}

  void debug_asserts() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index == sym_sizes.size());
  }
  std::optional<c10::SymInt> next_sym_size() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index < sym_sizes.size());
    return sym_sizes[sym_sizes_index++];
  }

  size_t sym_sizes_index{0};
  std::vector<std::optional<c10::SymInt>> sym_sizes;
  variable_list outputs;
};
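// Note: next_sym_size() is consumed once per SymInt visited during tracing, in
// visit order; entries with a value swap in a dynamic size (see
// SwapSavedVariables::before(c10::SymInt&) below), and debug_asserts() checks
// that tracing consumed them all.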

class SwapSavedVariables {
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache miss.  It swaps any 'lifted' inputs (tensors, symints) for proxy
  // nodes, allows tracing to happen, then swaps them back afterwards.
 public:
  void before(at::Tensor& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_tensors.save(&t, std::move(t));
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = arg.proxy_tensor;
    }
  }
  void after(at::Tensor& t) {
    stashed_tensors.restore(&t);
  }

  void before(SavedVariable& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_variables.save(&t, std::move(t));
    if (arg.defined()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = SavedVariable(arg.proxy_tensor, false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }
  void after(SavedVariable& t) {
    stashed_variables.restore(&t);
  }

  void before(c10::SymInt& t) {
    stashed_symints.save(&t, c10::SymInt(t));
    auto opt_value = state.next_sym_size();
    if (opt_value.has_value()) {
      t = *opt_value; // dynamic shape
    }
  }
  void after(c10::SymInt& t) {
    stashed_symints.restore(&t);
  }

  void before(at::IValue& iv) {
    if (iv.isTensor()) {
      before(iv.toTensor());
    } else {
      stashed_ivalues.save(&iv, at::IValue(iv));
      if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
        iv = compiler.lifted_ivalue_args.next_proxy(&iv);
      }
    }
  }

  void after(at::IValue& t) {
    if (t.isTensor()) {
      after(t.toTensor());
    } else {
      stashed_ivalues.restore(&t);
    }
  }

  void before(Edge& t) {
    if (t.is_valid()) {
      // needed for symints used by validate_outputs
      before(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void after(Edge& t) {
    if (t.is_valid()) {
      after(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void before(InputMetadata& t) {
    before(t.mutable_shape_as_dim_vector());
  }
  void after(InputMetadata& t) {
    after(t.mutable_shape_as_dim_vector());
  }
  void before(at::TensorGeometry& t) {
    before(t.mutable_sizes());
    before(t.mutable_strides());
    before(t.mutable_storage_offset());
    t.recompute();
  }
  void after(at::TensorGeometry& t) {
    after(t.mutable_sizes());
    after(t.mutable_strides());
    after(t.mutable_storage_offset());
    t.recompute();
  }
  void before(torch::autograd::TypeAndSize& t) {
    before(t.sym_sizes);
    before(t.options);
  }
  void after(torch::autograd::TypeAndSize& t) {
    after(t.sym_sizes);
    after(t.options);
  }
  void before(VariableInfo& t) {
    before(t.size);
  }
  void after(VariableInfo& t) {
    after(t.size);
  }

  template <typename T>
  void before(std::vector<T>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T>
  void after(std::vector<T>& t) {
    for (T& i : t) {
      after(i);
    }
  }
  template <typename T, unsigned N>
  void before(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T, unsigned N>
  void after(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      after(i);
    }
  }

  template <typename T>
  void before(c10::OptionalArray<T>& t) {
    before(t.list);
  }
  template <typename T>
  void after(c10::OptionalArray<T>& t) {
    after(t.list);
  }

  template <typename T>
  void before(std::optional<T>& t) {
    if (t.has_value()) {
      before(*t);
    }
  }
  template <typename T>
  void after(std::optional<T>& t) {
    if (t.has_value()) {
      after(*t);
    }
  }

  template <typename V>
  void before(ska::flat_hash_map<std::string, V>& m) {
    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (auto& k : keys) {
      before(m.at(k));
    }
  }

  template <typename V>
  void after(ska::flat_hash_map<std::string, V>& m) {
    for (auto& [_, v] : m) {
      after(v);
    }
  }

#define NO_OP_VISIT(T)     \
  void before(const T&) {} \
  void after(const T&) {}
  NO_OP_VISIT(caffe2::TypeMeta);
  NO_OP_VISIT(c10::Device);
  NO_OP_VISIT(c10::DeviceType);
  NO_OP_VISIT(c10::Layout);
  NO_OP_VISIT(c10::MemoryFormat);
  NO_OP_VISIT(c10::ScalarType);
  NO_OP_VISIT(c10::Scalar);
  NO_OP_VISIT(c10::TensorOptions);
  NO_OP_VISIT(std::string);
  NO_OP_VISIT(int64_t);
  NO_OP_VISIT(bool);
  NO_OP_VISIT(double);
#undef NO_OP_VISIT

  SwapSavedVariables(
      AutogradCompilerCall& c,
      TraceState& s,
      PyObject* p,
      const NodeCall& n)
      : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}

  PyObject* get_py_compiler() {
    return py_compiler;
  }

  const NodeCall& get_curr_node_call() {
    return curr_node_call;
  }

  void debug_asserts() {
    stashed_variables.debug_assert();
    stashed_tensors.debug_assert();
    stashed_symints.debug_assert();
  }

 private:
  template <typename T>
  struct Stashed {
    Stashed(T&& v) : prior_value(std::move(v)) {}
    T prior_value;
    // Note: we need count here to support duplicate calls to before(),
    // which happen when we have multiple autograd::Edge objects pointing
    // to the same autograd::Node
    int count = 1;
  };

  template <typename T>
  struct StashedVars : public std::unordered_map<const T*, Stashed<T>> {
    void save(const T* key, T&& value) {
      auto [it, inserted] = this->try_emplace(key, std::move(value));
      if (!inserted) {
        // keep the value from the prior save()
        it->second.count++;
      }
    }
    void restore(T* var) {
      auto it = this->find(var);
      TORCH_INTERNAL_ASSERT(it != this->end(), "missing before()");
      if (--it->second.count == 0) {
        // restore the value on the last restore()
        *var = std::move(it->second.prior_value);
        this->erase(it);
      }
    }
    void debug_assert() {
      TORCH_INTERNAL_ASSERT(this->empty(), "missing call to after()");
    }
  };

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state;
  // This is a borrowed reference; we do not increment or decrement the
  // refcount.  Its lifetime strictly outlives this object's.
  PyObject* py_compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const NodeCall& curr_node_call;

  // These mappings are used to save the prior values when we overwrite things
  // in before().  In after(), we use them to clean up after ourselves.
  StashedVars<SavedVariable> stashed_variables;
  StashedVars<at::Tensor> stashed_tensors;
  StashedVars<c10::SymInt> stashed_symints;
  StashedVars<at::IValue> stashed_ivalues;
};
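// Rough lifecycle, as a sketch (the real call sites live outside this header):
// construct one SwapSavedVariables per node being traced, call before(...) on
// each lifted value to stash it and swap in its proxy, trace the node against
// the proxies, call after(...) on the same values to restore the originals,
// then debug_asserts() to verify every before() was matched by an after().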

} // namespace torch::dynamo::autograd

template <>
struct std::hash<torch::dynamo::autograd::CacheKey> {
  size_t operator()(const torch::dynamo::autograd::CacheKey& k) const {
    return k.hash();
  }
};