#pragma once
#include <ATen/TensorGeometry.h>
#include <ATen/core/ivalue.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable_info.h>
#include <torch/csrc/utils/python_stub.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <typeindex>
#include <vector>

// see [Note: Compiled Autograd]

namespace torch::dynamo::autograd {
using namespace torch::autograd;

struct SizeInput {
  // Note: int value is still needed when dynamic to pass as an arg
  enum DynType : uint8_t { STATIC = 0, DYNAMIC = 1 };
  SizeInput(DynType dt, int64_t v) : dyn_type(dt), value(v) {}
  DynType dyn_type;
  int64_t value;
};

struct CacheKeyBuffer {
  CacheKeyBuffer(const uint8_t* key, uint16_t len) : data(new uint8_t[len]) {
    std::memcpy(data.get(), key, len);
  }
  const uint8_t* get() const {
    return data.get();
  }

 private:
  // NOLINTNEXTLINE(*c-array*)
  std::unique_ptr<uint8_t[]> data;
};

struct CacheKey {
  // Key to find the next node in the shadow graph. We use C++ RTTI for the
  // type of the node (ntype), then a key generated with a visitor pattern.
  CacheKey(const std::type_index& ntype, const uint8_t* key, uint16_t len)
      : node_type(ntype), key_size(len), key(key) {}

  bool operator<(const CacheKey& other) const {
    if (node_type != other.node_type) {
      return node_type < other.node_type;
    }
    if (key_size != other.key_size) {
      return key_size < other.key_size;
    }
    return std::memcmp(key, other.key, key_size) < 0;
  }

  bool operator==(const CacheKey& other) const {
    return node_type == other.node_type && key_size == other.key_size &&
        std::memcmp(key, other.key, key_size) == 0;
  }

  size_t hash() const {
    // don't bother hashing the key data, common case 1 cache entry per node
    return std::hash<std::type_index>()(node_type) ^ key_size;
  }

  std::type_index node_type;
  uint16_t key_size;
  const uint8_t* key;
};

struct NodeCall {
  NodeCall(uint32_t id_, std::shared_ptr<Node> node_)
      : id(id_), node(std::move(node_)) {}

  void mark_output(int input_nr, int output_idx) {
    graph_output.emplace_back(input_nr, output_idx);
  }

  uint32_t id;
  std::shared_ptr<Node> node;
  std::vector<std::pair<int, int>> tensor_pre_hooks;
  std::vector<int> pre_hooks;
  std::vector<int> post_hooks;
  std::vector<int> post_acc_grad_hooks;
  std::vector<std::pair<int, int>> graph_output;
  bool needed = true;
};

struct NodeCalls : public std::unordered_map<Node*, NodeCall> {
  NodeCall& lookup(const std::shared_ptr<Node>& function) {
    auto it = find(function.get());
    if (it == end()) {
      it = emplace(function.get(), NodeCall(_next_id++, function)).first;
    }
    return it->second;
  }

 private:
  uint32_t _next_id = 0;
};

struct TensorArg {
  // Represents a de-duplicated tensor that will be passed into the graph
  TensorArg(uint32_t i = 0) : id(i) {}
  uint32_t index() const {
    TORCH_INTERNAL_ASSERT(defined());
    return id - 1;
  }
  bool defined() const {
    return id != 0;
  }
  uint32_t id;
  at::Tensor proxy_tensor;
};
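
// Illustrative note (example added for clarity, not used by the code below):
// TensorArg ids are 1-based so that id == 0 can represent an undefined
// tensor, and index() maps back to the 0-based slot in TensorArgs::inputs:
//
//   TensorArg undefined;  // id == 0, defined() == false
//   TensorArg first(1);   // id == 1, index() == 0 -> inputs[0]
//   TensorArg second(2);  // id == 2, index() == 1 -> inputs[1]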

struct TensorArgs {
  // Manages a collection of TensorArgs and mappings from Tensors/SavedVariables
  // to them. This also allows us to unpack SavedVariable exactly once and
  // store the unpacked Tensor.

  TensorArg& lookup(const at::Tensor& tensor, bool create = false) {
    if (!tensor.defined()) {
      return _undefined;
    }
    auto impl = tensor.unsafeGetTensorImpl();
    auto it = _args.find(impl);
    if (it == _args.end()) {
      TORCH_INTERNAL_ASSERT(create && inputs.size() == _next_id - 1);
      it = _args.emplace(impl, TensorArg(_next_id++)).first;
      inputs.emplace_back(tensor);
    }
    return it->second;
  }

  TensorArg& lookup(const SavedVariable& sv) {
    auto it = _saved_variables.find(&sv);
    TORCH_INTERNAL_ASSERT(it != _saved_variables.end());
    return *it->second;
  }

  TensorArg& add(const at::Tensor& tensor) {
    return lookup(tensor, true);
  }

  TensorArg& add(const SavedVariable& sv, const std::shared_ptr<Node>& node) {
    // TODO(jansel): Here we unpack the SavedVariable exactly once. This might
    // fire SavedTensor hooks. In the future we should try to put saved tensor
    // hooks into the graph.
    at::Tensor tensor = sv.unpack(node);
    TensorArg& arg = add(tensor);
    _saved_variables.emplace(&sv, &arg);
    return arg;
  }

  // the concrete tensors that will get passed into the graph as inputs
  std::vector<at::Tensor> inputs;

 private:
  std::unordered_map<const c10::TensorImpl*, TensorArg> _args;
  // Every TensorArg from this is actually owned by _args (or _undefined) and
  // that's why we have an un-owned pointer here.
  std::unordered_map<const SavedVariable*, TensorArg*> _saved_variables;
  TensorArg _undefined;
  uint32_t _next_id = 1; // id=0 used by _undefined
};

struct LiftedIValueArg {
  LiftedIValueArg() = delete;
  LiftedIValueArg(const at::IValue* ptr)
      : actual_ptr(ptr), proxy(at::IValue::uninitialized()) {}

  const at::IValue* actual_ptr; // lifetime handled by autograd node
  at::IValue proxy;
};

struct LiftedIValueArgs {
  at::IValue& next_proxy(const at::IValue* actual_ptr) {
    TORCH_INTERNAL_ASSERT(next < args.size());
    auto& iv_arg = args.at(next++);
    TORCH_INTERNAL_ASSERT(iv_arg.actual_ptr == actual_ptr);
    return iv_arg.proxy;
  }

  std::vector<LiftedIValueArg> args;
  size_t next = 0;
};

struct AutogradCompilerCall {
  void add_size_input(const c10::SymInt& s) {
    all_size_inputs.emplace_back(
        default_dyn_type, s.guard_int(__FILE__, __LINE__));
  }

  size_t emplace_hook(c10::SafePyObject&& fn) {
    hooks.emplace_back(std::move(fn));
    return hooks.size() - 1;
  }

  TensorArgs tensor_args;
  std::vector<SizeInput> all_size_inputs;
  LiftedIValueArgs lifted_ivalue_args;
  std::vector<int64_t> dyn_size_inputs;
  std::vector<c10::SafePyObject> hooks;
  NodeCalls node_calls;
  SizeInput::DynType default_dyn_type = SizeInput::STATIC;
};

class CompiledNodeArgs {
  // CompiledNodeArgs builds a representation of the constant values found
  // across all the nodes in the compiled graph, via 'collect' overloads. The
  // collected constants are specialized on by concatenation into a cache key.
  // Tensor, symint arguments (which are lifted to become graph inputs rather
  // than specialized on) are forwarded to the compiler and not included in the
  // key.
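
  // Illustrative flow (a simplified sketch, not code from this file): each
  // node's compiled_args() calls collect() on its fields; constants are
  // appended to the specialization key via specialize_on_bytes(), while
  // tensors/SymInts are lifted into AutogradCompilerCall as graph inputs.
  //
  //   CompiledNodeArgs args(compiler_call, node_call);
  //   node->compiled_args(args); // per-Node visitor, defined elsewhere
  //   CacheKey key = args.key(); // typeid(*node) + the collected bytes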
 public:
  void collect(const TensorArg& t) {
    collect_size(t.id);
    if (t.defined()) {
      const at::Tensor& tensor = _compiler.tensor_args.inputs[t.index()];
      // including these in the cache key means dynamo-level tensor guards can
      // be skipped
      collect(tensor.device());
      collect(tensor.dtype());
      collect(tensor.requires_grad());
    }
  }

  void collect(const at::Tensor& t) {
    collect(_compiler.tensor_args.add(t));
  }
  void collect(const SavedVariable& sv, bool is_output) {
    collect(
        _compiler.tensor_args.add(sv, is_output ? _node_call.node : nullptr));
  }
  void collect(const c10::SymInt& t) {
    _compiler.add_size_input(t);
  }
  void collect(const std::vector<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }
  template <typename T>
  void collect(const std::vector<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }
  void collect(const c10::ArrayRef<SavedVariable>& t, bool is_output) {
    collect_size(t.size());
    for (const SavedVariable& i : t) {
      collect(i, is_output);
    }
  }
  template <typename T>
  void collect(const c10::ArrayRef<T>& t) {
    collect_size(t.size());
    for (const T& i : t) {
      collect(i);
    }
  }
  template <typename T>
  void collect(const c10::OptionalArray<T>& t) {
    collect(t.list);
  }
  template <typename T>
  void collect(const std::optional<T>& t) {
    if (cond(t.has_value())) {
      collect(*t);
    }
  }
  template <typename A, typename B>
  void collect(const std::pair<A, B>& t) {
    collect(t.first);
    collect(t.second);
  }
  template <typename V>
  void collect(const ska::flat_hash_map<std::string, V>& m) {
    collect_size(m.size());

    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (const auto& k : keys) {
      collect(k);
      collect(m.at(k));
    }
  }
  void collect(const at::IValue& iv, bool nested = false) {
    // used by AutogradContext::saved_data from CppNode
    if (iv.isList()) {
      c10::List<at::IValue> list = iv.toList();
      collect_size(list.size());
      for (auto&& value : list) {
        collect(value, true);
      }
    } else if (iv.isGenericDict()) {
      c10::Dict<at::IValue, at::IValue> ordered_dict = iv.toGenericDict();
      collect_size(ordered_dict.size());
      // NOLINTNEXTLINE(modernize-loop-convert)
      for (auto it = ordered_dict.begin(); it != ordered_dict.end(); it++) {
        collect(it->key());
        collect(it->value(), true);
      }
    } else if (iv.isTensor()) {
      collect(iv.toTensor());
    } else if (
        !nested &&
        (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat())) {
      // can't lift ivalues nested in collections
      _compiler.lifted_ivalue_args.args.emplace_back(&iv);
    } else {
      try {
        collect(static_cast<uint64_t>(at::IValue::hash(iv)));
      } catch (const std::runtime_error& e) {
        std::string msg =
            "Compiled autograd can not trace unhashable IValues, error: " +
            std::string(e.what());
        TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
      }
    }
  }
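
  // Illustrative behavior of collect(const at::IValue&) above (a sketch, not
  // exhaustive): only top-level numeric IValues are lifted; everything else
  // is folded into the cache key.
  //
  //   collect(at::IValue(3.14));    // top-level double: lifted as a proxy arg
  //   collect(at::IValue(tensor));  // dispatched to collect(const at::Tensor&)
  //   collect(at::IValue(true));    // not a liftable type: hashed into the key
  //   // numbers nested inside lists/dicts are also hashed, not lifted
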
  void collect(const c10::Scalar& t) {
    auto type = t.type();
    specialize_on_bytes(type);
    if (type == c10::ScalarType::Double) {
      collect(t.toDouble());
    } else if (type == c10::ScalarType::Long) {
      collect(t.toLong());
    } else if (type == c10::ScalarType::Bool) {
      collect(t.toBool());
    } else if (type == c10::ScalarType::ComplexDouble) {
      auto c = t.toComplexDouble();
      collect(c.real());
      collect(c.imag());
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
  void collect(const c10::TensorOptions& t) {
    collect(t.device());
    collect(t.dtype());
    collect(t.layout());
    collect(t.requires_grad());
    collect(t.pinned_memory());
    collect(t.memory_format_opt());
  }
  void collect(const at::TensorGeometry& t) {
    collect(t.sym_sizes());
    collect(t.sym_strides());
    collect(t.sym_storage_offset());
  }
  void collect(const torch::autograd::TypeAndSize& t) {
    collect(t.sym_sizes);
    collect(t.options);
  }
  void collect(const c10::Device& t) {
    collect(t.type());
    collect(t.index());
  }
  void collect(const std::string& t) {
    collect_size(t.size());
    for (char c : t) {
      collect(c);
    }
  }
  void collect(const caffe2::TypeMeta& t) {
    specialize_on_bytes(t.id());
  }
  void collect(const std::shared_ptr<Node>& t) {
    // Note: this is only capturing the ID of the node, not everything
    // contained inside it. This is used for tracking connections between
    // nodes; the actual details of the node itself must be handled by
    // a separate call to `node->compiled_args()`.
    if (cond((bool)t)) {
      collect(_compiler.node_calls.lookup(t));
    }
  }
  void collect(const NodeCall& t) {
    collect_size(t.id);
    collect(t.graph_output);
    collect_hooks_from(t.node.get());
  }
  void collect(const Edge& t) {
    if (cond(t.is_valid())) {
      collect_size(_compiler.node_calls.lookup(t.function).id);
      collect_size(t.input_nr);
      collect(t.function->input_metadata(t.input_nr)); // for validate_outputs
    }
  }
  void collect(const InputMetadata& t) {
    TORCH_CHECK(!t.is_nested_tensor(), "NestedTensor not implemented");
    collect(t.options());
    collect(t.is_tensor_subclass());
    collect(t.shape_as_dim_vector());
  }
  void collect(const VariableInfo& t) {
    collect(t.layout);
    collect(t.device);
    collect(t.scalar_type);
    collect(t.size);
    collect(t.requires_grad);
    collect(t.is_empty);
  }
  bool cond(bool cond) {
    collect(cond);
    return cond;
  }

#define COLLECT_AS_BYTES(T) \
  void collect(T t) {       \
    specialize_on_bytes(t); \
  }
  COLLECT_AS_BYTES(c10::ScalarType);
  COLLECT_AS_BYTES(c10::DeviceType);
  COLLECT_AS_BYTES(c10::Layout);
  COLLECT_AS_BYTES(c10::MemoryFormat);
  COLLECT_AS_BYTES(int8_t);
  COLLECT_AS_BYTES(int16_t);
  COLLECT_AS_BYTES(int32_t);
  COLLECT_AS_BYTES(int64_t);
  COLLECT_AS_BYTES(uint8_t);
  COLLECT_AS_BYTES(uint16_t);
  COLLECT_AS_BYTES(uint32_t);
  COLLECT_AS_BYTES(uint64_t);
  COLLECT_AS_BYTES(bool);
  COLLECT_AS_BYTES(float);
  COLLECT_AS_BYTES(double);
#undef COLLECT_AS_BYTES

  void collect_hooks_from(Node* fn) {
    TORCH_CHECK(
        fn->retains_grad_hooks().empty(),
        "retains_grad_hooks not implemented for compiled autograd");
    for (auto& i : fn->tensor_pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->pre_hooks()) {
      i->compiled_args(*this);
    }
    for (auto& i : fn->post_hooks()) {
      i->compiled_args(*this);
    }
    collect_size(_node_call.tensor_pre_hooks.size());
    collect_size(_node_call.pre_hooks.size());
    collect_size(_node_call.post_hooks.size());
    for (const auto& h : _node_call.tensor_pre_hooks) {
      collect_size(static_cast<size_t>(h.second));
    }
  }

  CacheKey key() const {
    Node* node = _node_call.node.get();
    return CacheKey(
        typeid(*node), _specialization_key, _specialization_key_size);
  }

  size_t add_backward(c10::SafePyObject&& obj) {
    return _compiler.emplace_hook(std::move(obj));
  }

  size_t add_backward_state(c10::SafePyObject&& obj) {
    return _compiler.emplace_hook(std::move(obj));
  }

  void add_tensor_pre_hook(c10::SafePyObject&& obj, int index) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
  }

  void add_pre_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.pre_hooks.emplace_back(fn_id);
  }

  void add_post_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_hooks.emplace_back(fn_id);
  }

  void add_post_acc_grad_hook(c10::SafePyObject&& obj) {
    auto fn_id = _compiler.emplace_hook(std::move(obj));
    collect_size(fn_id);
    _node_call.post_acc_grad_hooks.emplace_back(fn_id);
  }

  // Need to template the size_t to silence internal 32-bit build errors due to
  // a mix of -Werror, -Wtautological-type-limit-compare and
  // -Wunknown-pragmas
  template <typename T>
  std::enable_if_t<std::is_unsigned_v<T>, void> collect_size(T s) {
    // we expect sizes to be small, so try to cram them into a single byte
    constexpr uint8_t encode_as_u64 = std::numeric_limits<uint8_t>::max();
    constexpr uint8_t encode_as_u32 = encode_as_u64 - 1;
    constexpr uint8_t encode_as_u16 = encode_as_u64 - 2;
    if (C10_UNLIKELY(s >= encode_as_u16)) {
      // first write a byte indicating the path we followed, then the data
      if (s <= std::numeric_limits<uint16_t>::max()) {
        // 3 bytes
        specialize_on_bytes(encode_as_u16);
        specialize_on_bytes(static_cast<uint16_t>(s));
      } else if (s <= std::numeric_limits<uint32_t>::max()) {
        // 5 bytes
        specialize_on_bytes(encode_as_u32);
        specialize_on_bytes(static_cast<uint32_t>(s));
      } else {
        // 9 bytes
        specialize_on_bytes(encode_as_u64);
        specialize_on_bytes(s);
      }
    } else {
      // happy case, 1 byte
      specialize_on_bytes(static_cast<uint8_t>(s));
    }
  }
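
  // Illustrative byte layouts produced by collect_size() above (assuming a
  // size_t argument, copied in native byte order by specialize_on_bytes()):
  //
  //   s = 5        -> 1 byte:  0x05
  //   s = 300      -> 3 bytes: 0xFD marker, then uint16_t(300)
  //   s = 1000000  -> 5 bytes: 0xFE marker, then uint32_t(1000000)
  //   s = 2^40     -> 9 bytes: 0xFF marker, then the full 64-bit value
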
  SizeInput::DynType set_default_dyn_type(SizeInput::DynType default_dyn_type) {
    return std::exchange(_compiler.default_dyn_type, default_dyn_type);
  }

  CompiledNodeArgs(AutogradCompilerCall& compiler, NodeCall& node_call)
      : _compiler(compiler),
        _node_call(node_call),
        _specialization_key(
            // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
            (uint8_t*)std::malloc(_specialization_key_storage)) {}
  ~CompiledNodeArgs() {
    // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
    std::free(_specialization_key);
  }
  CompiledNodeArgs(const CompiledNodeArgs&) = delete;

 private:
  template <typename T>
  void specialize_on_bytes(const T& t) {
    while (C10_UNLIKELY(
        _specialization_key_size + sizeof(T) > _specialization_key_storage)) {
      _specialization_key_storage *= 2;
      // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
      _specialization_key = (uint8_t*)std::realloc(
          _specialization_key, _specialization_key_storage);
    }
    std::memcpy(_specialization_key + _specialization_key_size, &t, sizeof(T));
    _specialization_key_size += sizeof(T);
  }

  AutogradCompilerCall& _compiler;
  NodeCall& _node_call;
  size_t _specialization_key_size{0};
  size_t _specialization_key_storage{1024};
  uint8_t* _specialization_key;
};

struct TraceState {
  TraceState(std::vector<std::optional<c10::SymInt>>&& ss, size_t num_outputs)
      : sym_sizes(ss), outputs(num_outputs) {}

  void debug_asserts() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index == sym_sizes.size());
  }
  std::optional<c10::SymInt> next_sym_size() {
    TORCH_INTERNAL_ASSERT(sym_sizes_index < sym_sizes.size());
    return sym_sizes[sym_sizes_index++];
  }

  size_t sym_sizes_index{0};
  std::vector<std::optional<c10::SymInt>> sym_sizes;
  variable_list outputs;
};

class SwapSavedVariables {
  // SwapSavedVariables is used during the tracing/compilation phase after a
  // cache-miss. It swaps any 'lifted' inputs (tensors, symints) to proxy nodes,
  // allows tracing to happen, then swaps them back afterwards.
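
  // Illustrative call pattern (a sketch; the driver that performs these calls
  // lives outside this header):
  //
  //   SwapSavedVariables saved(compiler_call, trace_state, py_compiler, call);
  //   saved.before(saved_tensor);  // stash the real value, swap in the proxy
  //   saved.before(saved_symint);
  //   ... trace the node against the proxies ...
  //   saved.after(saved_symint);   // restore the stashed values
  //   saved.after(saved_tensor);
  //   saved.debug_asserts();       // verify every before() was matched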
 public:
  void before(at::Tensor& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_tensors.save(&t, std::move(t));
    if (arg.defined()) {
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = arg.proxy_tensor;
    }
  }
  void after(at::Tensor& t) {
    stashed_tensors.restore(&t);
  }

  void before(SavedVariable& t) {
    TensorArg& arg = compiler.tensor_args.lookup(t);
    stashed_variables.save(&t, std::move(t));
    if (arg.defined()) {
      bool prior = at::SavedTensorDefaultHooks::set_tracing(true);
      TORCH_INTERNAL_ASSERT(arg.proxy_tensor.defined());
      t = SavedVariable(arg.proxy_tensor, false);
      at::SavedTensorDefaultHooks::set_tracing(prior);
    }
  }
  void after(SavedVariable& t) {
    stashed_variables.restore(&t);
  }

  void before(c10::SymInt& t) {
    stashed_symints.save(&t, c10::SymInt(t));
    auto opt_value = state.next_sym_size();
    if (opt_value.has_value()) {
      t = *opt_value; // dynamic shape
    }
  }
  void after(c10::SymInt& t) {
    stashed_symints.restore(&t);
  }

  void before(at::IValue& iv) {
    if (iv.isTensor()) {
      before(iv.toTensor());
    } else {
      stashed_ivalues.save(&iv, at::IValue(iv));
      if (iv.isInt() || iv.isSymInt() || iv.isDouble() || iv.isSymFloat()) {
        iv = compiler.lifted_ivalue_args.next_proxy(&iv);
      }
    }
  }

  void after(at::IValue& t) {
    if (t.isTensor()) {
      after(t.toTensor());
    } else {
      stashed_ivalues.restore(&t);
    }
  }

  void before(Edge& t) {
    if (t.is_valid()) {
      // need for symints used by validate_outputs
      before(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void after(Edge& t) {
    if (t.is_valid()) {
      after(t.function->mutable_input_metadata(t.input_nr));
    }
  }
  void before(InputMetadata& t) {
    before(t.mutable_shape_as_dim_vector());
  }
  void after(InputMetadata& t) {
    after(t.mutable_shape_as_dim_vector());
  }
  void before(at::TensorGeometry& t) {
    before(t.mutable_sizes());
    before(t.mutable_strides());
    before(t.mutable_storage_offset());
    t.recompute();
  }
  void after(at::TensorGeometry& t) {
    after(t.mutable_sizes());
    after(t.mutable_strides());
    after(t.mutable_storage_offset());
    t.recompute();
  }
  void before(torch::autograd::TypeAndSize& t) {
    before(t.sym_sizes);
    before(t.options);
  }
  void after(torch::autograd::TypeAndSize& t) {
    after(t.sym_sizes);
    after(t.options);
  }
  void before(VariableInfo& t) {
    before(t.size);
  }
  void after(VariableInfo& t) {
    after(t.size);
  }

  template <typename T>
  void before(std::vector<T>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T>
  void after(std::vector<T>& t) {
    for (T& i : t) {
      after(i);
    }
  }
  template <typename T, unsigned N>
  void before(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      before(i);
    }
  }
  template <typename T, unsigned N>
  void after(c10::SmallVector<T, N>& t) {
    for (T& i : t) {
      after(i);
    }
  }

  template <typename T>
  void before(c10::OptionalArray<T>& t) {
    before(t.list);
  }
  template <typename T>
  void after(c10::OptionalArray<T>& t) {
    after(t.list);
  }

  template <typename T>
  void before(std::optional<T>& t) {
    if (t.has_value()) {
      before(*t);
    }
  }
  template <typename T>
  void after(std::optional<T>& t) {
    if (t.has_value()) {
      after(*t);
    }
  }

  template <typename V>
  void before(ska::flat_hash_map<std::string, V>& m) {
    std::vector<std::string> keys;
    keys.reserve(m.size());
    std::transform(
        m.begin(), m.end(), std::back_inserter(keys), [](const auto& entry) {
          return entry.first;
        });
    std::sort(keys.begin(), keys.end());
    for (auto& k : keys) {
      before(m.at(k));
    }
  }

  template <typename V>
  void after(ska::flat_hash_map<std::string, V>& m) {
    for (auto& [_, v] : m) {
      after(v);
    }
  }

#define NO_OP_VISIT(T)     \
  void before(const T&) {} \
  void after(const T&) {}
  NO_OP_VISIT(caffe2::TypeMeta);
  NO_OP_VISIT(c10::Device);
  NO_OP_VISIT(c10::DeviceType);
  NO_OP_VISIT(c10::Layout);
  NO_OP_VISIT(c10::MemoryFormat);
  NO_OP_VISIT(c10::ScalarType);
  NO_OP_VISIT(c10::Scalar);
  NO_OP_VISIT(c10::TensorOptions);
  NO_OP_VISIT(std::string);
  NO_OP_VISIT(int64_t);
  NO_OP_VISIT(bool);
  NO_OP_VISIT(double);
#undef NO_OP_VISIT

  SwapSavedVariables(
      AutogradCompilerCall& c,
      TraceState& s,
      PyObject* p,
      const NodeCall& n)
      : compiler(c), state(s), py_compiler(p), curr_node_call(n) {}

  PyObject* get_py_compiler() {
    return py_compiler;
  }

  const NodeCall& get_curr_node_call() {
    return curr_node_call;
  }

  void debug_asserts() {
    stashed_variables.debug_assert();
    stashed_tensors.debug_assert();
    stashed_symints.debug_assert();
  }

 private:
  template <typename T>
  struct Stashed {
    Stashed(T&& v) : prior_value(std::move(v)) {}
    T prior_value;
    // Note: we need count here to support duplicate calls to before()
    // which happen when we have multiple autograd::Edge objects pointing
    // to the same autograd::Node
    int count = 1;
  };

  template <typename T>
  struct StashedVars : public std::unordered_map<const T*, Stashed<T>> {
    void save(const T* key, T&& value) {
      auto [it, inserted] = this->try_emplace(key, std::move(value));
      if (!inserted) {
        // keep the value from the prior save()
        it->second.count++;
      }
    }
    void restore(T* var) {
      auto it = this->find(var);
      TORCH_INTERNAL_ASSERT(it != this->end(), "missing before()");
      if (--it->second.count == 0) {
        // restore the value on the last restore()
        *var = std::move(it->second.prior_value);
        this->erase(it);
      }
    }
    void debug_assert() {
      TORCH_INTERNAL_ASSERT(this->empty(), "missing call to after()");
    }
  };
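
  // Illustrative save()/restore() pairing for a value reached twice (e.g. via
  // two Edges into the same Node), based on the count field above:
  //
  //   stashed.save(&v, T(v));  // first call: prior value stashed, count == 1
  //   stashed.save(&v, T(v));  // duplicate: prior value kept, count == 2
  //   stashed.restore(&v);     // count drops to 1, nothing restored yet
  //   stashed.restore(&v);     // count hits 0: prior value restored, erased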

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  AutogradCompilerCall& compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  TraceState& state;
  // This is a borrowed reference: we do not increment or decrement its
  // refcount, and its lifetime is strictly longer than this object's.
  PyObject* py_compiler;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const NodeCall& curr_node_call;

  // These mappings are used to save the prior values when we overwrite things
  // in before(). In after(), we use these to clean up after ourselves.
  StashedVars<SavedVariable> stashed_variables;
  StashedVars<at::Tensor> stashed_tensors;
  StashedVars<c10::SymInt> stashed_symints;
  StashedVars<at::IValue> stashed_ivalues;
};

} // namespace torch::dynamo::autograd

template <>
struct std::hash<torch::dynamo::autograd::CacheKey> {
  size_t operator()(const torch::dynamo::autograd::CacheKey& k) const {
    return k.hash();
  }
};