#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <ATen/quantized/Quantizer.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>
#include <type_traits>

namespace torch::jit {

using ::c10::IValue;

// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;

// NOLINTNEXTLINE(bugprone-exception-escape)
Pickler::~Pickler() {
  flush();
}

void Pickler::protocol() {
  push<PickleOpCode>(PickleOpCode::PROTO);
  push<uint8_t>(PROTOCOL_VERSION);
}

void Pickler::startTuple() {
  // All attributes get pushed into a tuple and their indices saved in the
  // module def
  push<PickleOpCode>(PickleOpCode::MARK);
}

void Pickler::endTuple() {
  push<PickleOpCode>(PickleOpCode::TUPLE);
}

void Pickler::stop() {
  push<PickleOpCode>(PickleOpCode::STOP);
  flush();
}
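
// For illustration only (not used by the code below): a typical pickle
// program produced by this class has the rough shape
//
//   PROTO 2            (protocol())
//   MARK               (startTuple())
//     ... opcodes for each pushed IValue ...
//   TUPLE              (endTuple())
//   STOP               (stop())
//
// When literal tensors are pickled, their raw bytes are written to the file
// after the STOP opcode (see pushLiteralTensor below).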

// unmemoized version called by pushIValue
void Pickler::pushIValueImpl(const IValue& ivalue) {
  if (ivalue.isTensor()) {
    pushTensor(ivalue);
  } else if (ivalue.isTuple()) {
    pushTuple(ivalue);
  } else if (ivalue.isDouble()) {
    pushDouble(ivalue.toDouble());
  } else if (ivalue.isComplexDouble()) {
    pushComplexDouble(ivalue);
  } else if (ivalue.isInt()) {
    pushInt(ivalue.toInt());
  } else if (ivalue.isBool()) {
    pushBool(ivalue.toBool());
  } else if (ivalue.isString()) {
    pushString(ivalue.toStringRef());
  } else if (ivalue.isGenericDict()) {
    pushDict(ivalue);
  } else if (ivalue.isNone()) {
    push<PickleOpCode>(PickleOpCode::NONE);
  } else if (ivalue.isIntList()) {
    pushSpecializedList(ivalue, "build_intlist", [this](const IValue& ivalue) {
      for (const int64_t item : ivalue.toIntVector()) {
        pushInt(item);
      }
    });
  } else if (ivalue.isTensorList()) {
    pushSpecializedList(
        ivalue, "build_tensorlist", [this](const IValue& ivalue) {
          for (const at::Tensor& item : ivalue.toTensorVector()) {
            pushIValue(item);
          }
        });
  } else if (ivalue.isDoubleList()) {
    pushSpecializedList(
        ivalue, "build_doublelist", [this](const IValue& ivalue) {
          for (double item : ivalue.toDoubleVector()) {
            pushDouble(item);
          }
        });
  } else if (ivalue.isBoolList()) {
    pushSpecializedList(ivalue, "build_boollist", [this](const IValue& ivalue) {
      for (bool item : ivalue.toBoolList()) {
        pushBool(item);
      }
    });
    // note: isList must be after isIntList and friends because
    // isList is true for all lists.
  } else if (ivalue.isList()) {
    pushGenericList(ivalue);
  } else if (ivalue.isObject()) {
    auto obj = ivalue.toObject();
    auto type = obj->type();
    if (memoized_class_types_ != nullptr) {
      // memoize every class type the Pickler encountered
      // This is used to make sure we capture all the run-time types
      // and serialize them properly for class/interface polymorphism
      memoized_class_types_->emplace_back(type);
    }
    auto type_name = type->name().value();
    if (type_renamer_) {
      type_name = type_renamer_(type);
    }
    pushGlobal(type_name.prefix(), type_name.name());
    push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    push<PickleOpCode>(PickleOpCode::NEWOBJ);
    if (checkHasValidSetGetState(type)) {
      Function& getstate = type->getMethod("__getstate__");
      pushIValue(getstate({obj}));
    } else {
      push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
      push<PickleOpCode>(PickleOpCode::MARK);
      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
        pushString(type->getAttributeName(i));
        pushIValue(obj->getSlot(i));
      }
      push<PickleOpCode>(PickleOpCode::SETITEMS);
    }
    push<PickleOpCode>(PickleOpCode::BUILD);
  } else if (ivalue.isDevice()) {
    pushDevice(ivalue);
  } else if (ivalue.isCapsule()) {
    std::stringstream err;
    err << "Cannot serialize custom bound C++ class";
    if (memoized_class_types_ && !memoized_class_types_->empty()) {
      if (auto qualname = memoized_class_types_->back()->name()) {
        err << " " << qualname->qualifiedName();
      }
    }
    err << ". Please define serialization methods via def_pickle() for "
           "this class.";
    AT_ERROR(err.str());
  } else if (ivalue.isRRef()) {
#ifdef USE_RPC
    TORCH_CHECK(
        torch::distributed::rpc::getAllowJitRRefPickle() == true,
        "RRef jit pickling is only allowed inside RPC calls.");
    pushRRef(ivalue);
#else
    TORCH_CHECK(
        false, "RRef pickling is only supported with the distributed package");
#endif
  } else if (ivalue.isEnum()) {
    auto enum_holder = ivalue.toEnumHolder();
    const auto& qualified_class_name =
        enum_holder->type()->qualifiedClassName();
    pushGlobal(qualified_class_name.prefix(), qualified_class_name.name());
    pushIValue(enum_holder->value());
    push<PickleOpCode>(PickleOpCode::REDUCE);
  } else {
    AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
  }
}

void Pickler::pushDevice(const IValue& ivalue) {
  auto device = ivalue.toDevice();
  auto deviceStr = device.str();
  auto it = memoized_devices_map_.find(deviceStr);
  if (it == memoized_devices_map_.end()) {
    pushGlobal("torch", "device");
    pushString(deviceStr);
    push<PickleOpCode>(PickleOpCode::TUPLE1);
    push<PickleOpCode>(PickleOpCode::REDUCE);
    memoized_devices_map_[deviceStr] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
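
// For illustration only: pushing a Device emits, roughly,
//
//   GLOBAL     "torch\ndevice\n"
//   BINUNICODE "cuda:0"          (hypothetical device string)
//   TUPLE1
//   REDUCE
//   BINPUT     <memo id>
//
// which the unpickler evaluates as torch.device("cuda:0"). Subsequent pushes
// of the same device string reuse the memo entry via BINGET.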

#ifdef USE_RPC
void Pickler::pushRRef(const IValue& ivalue) {
  // This mirrors how an RRef is pickled on the Python side; see
  // PyRRef::pickle.
  auto rrefInterface = ivalue.toRRef();
  auto rref =
      c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(rrefInterface);
  pushGlobal("torch.distributed.rpc", "rref");
  auto& ctx = distributed::rpc::RRefContext::getInstance();
  auto rrefForkData = ctx.prepareChildFork(rref);
  push<PickleOpCode>(PickleOpCode::MARK);
  pushInt(rrefForkData.ownerId_);
  pushInt(rrefForkData.rrefId_.createdOn_);
  pushInt(rrefForkData.rrefId_.localId_);
  pushInt(rrefForkData.forkId_.createdOn_);
  pushInt(rrefForkData.forkId_.localId_);
  pushInt(rrefForkData.parent_);
  pushString(rrefForkData.typeStr_);
  push<PickleOpCode>(PickleOpCode::TUPLE);
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
#endif

void Pickler::pushIValue(const IValue& ivalue) {
  bool shouldMemoizeByPointer =
      ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;

  // Mutable ivalues are memoized by pointer equality, which we handle at this
  // outer granularity. Immutable ivalues are memoized by value equality, which
  // is handled in the type-specific handlers inside pushIValueImpl.
  if (shouldMemoizeByPointer) {
    const void* ptr = ivalue.internalToPointer();
    TORCH_CHECK(
        ptr != nullptr,
        "Pickler cannot memoize ",
        ivalue.tagKind(),
        " IValue ",
        ivalue);
    auto memo_entry = memoized_ivalue_map_.find(ptr);
    if (memo_entry != memoized_ivalue_map_.end()) {
      // This value has already been pushed, just do a BINGET
      pushBinGet(memo_entry->second);
      return;
    }

    pushIValueImpl(ivalue);

    memoized_ivalues_.push_back(ivalue);
    memoized_ivalue_map_[ptr] = pushNextBinPut();
  } else {
    pushIValueImpl(ivalue);
  }
}
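
// For illustration only: the pointer-based memo table means that pushing the
// same list object twice produces its opcodes once and then a cheap lookup,
// roughly
//
//   EMPTY_LIST ... APPENDS   (first push emits the full encoding)
//   BINPUT 0                 (record it in the memo table)
//   ...
//   BINGET 0                 (second push of the same object)
//
// so aliasing is preserved: both occurrences unpickle to the same object.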

void Pickler::pushInt(int64_t n) {
  if (n >= std::numeric_limits<uint8_t>::min() &&
      n <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT1);
    push<uint8_t>(n);
  } else if (
      n >= std::numeric_limits<uint16_t>::min() &&
      n <= std::numeric_limits<uint16_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT2);
    push<uint16_t>(to_le16(n));
  } else if (
      n >= std::numeric_limits<int32_t>::min() &&
      n <= std::numeric_limits<int32_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT);
    push<int32_t>(to_le32(n));
  } else {
    // Push 8 byte integer
    push<PickleOpCode>(PickleOpCode::LONG1);
    push<uint8_t>(8);
    push<int64_t>(to_le64(n));
  }
}
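
// For illustration only: the opcode is chosen by the smallest encoding that
// fits the value, e.g.
//
//   pushInt(5)        -> BININT1 0x05                   (1 payload byte)
//   pushInt(300)      -> BININT2 0x2c 0x01              (2 bytes, little endian)
//   pushInt(-7)       -> BININT  <4-byte little-endian two's complement>
//   pushInt(1 << 40)  -> LONG1 8 <8-byte little-endian payload>
//
// These byte layouts follow the standard pickle format; see the CPython
// pickletools documentation for the authoritative description.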

void Pickler::pushBool(bool value) {
  push<PickleOpCode>(value ? PickleOpCode::NEWTRUE : PickleOpCode::NEWFALSE);
}

void Pickler::pushBinGet(uint32_t memo_id) {
  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINGET);
    push<uint8_t>(memo_id);
  } else {
    // Memoized too many items, issue a LONG_BINGET instead
    push<PickleOpCode>(PickleOpCode::LONG_BINGET);
    push<uint32_t>(memo_id);
  }
}

// unmemoized encoding of a string
void Pickler::pushStringImpl(const std::string& string) {
  if (string.size() <= UINT_MAX) {
    push<PickleOpCode>(PickleOpCode::BINUNICODE);
    push<uint32_t>(to_le32(string.size()));
    pushBytes(string);
  } else {
    push<PickleOpCode>(PickleOpCode::BINUNICODE8);
    push<int64_t>(to_le64(string.size()));
    pushBytes(string);
  }
}

void Pickler::pushString(const std::string& string) {
  auto it = memoized_strings_map_.find(string);
  if (it == memoized_strings_map_.end()) {
    pushStringImpl(string);
    memoized_strings_map_[string] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
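
// For illustration only: strings are memoized by value, so
//
//   pushString("cpu"); pushString("cpu");
//
// emits roughly BINUNICODE <len=3> "cpu", BINPUT <id>, and then BINGET <id>
// for the second call, rather than encoding the bytes twice.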

void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
  const at::Storage& storage = tensor.storage();
  void* addr = storage.unsafeGetStorageImpl();
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    pushBinGet(it->second);
    return;
  }

  // Tuple for persistent_load
  push<PickleOpCode>(PickleOpCode::MARK);
  // typename
  pushString("storage");
  // data_type
  std::string data_type =
      std::string(toString(tensor.scalar_type())).append("Storage");
  pushGlobal("torch", data_type);
  // root_key
  std::string root_key = get_tensor_id_ != nullptr
      ? get_tensor_id_(tensor)
      : std::to_string(tensor_data_.size());
  pushString(root_key);
  // location
  pushString(tensor.device().str());
  // size
  pushInt(
      static_cast<int64_t>(tensor.storage().nbytes() / tensor.element_size()));

  push<PickleOpCode>(PickleOpCode::TUPLE);
  push<PickleOpCode>(PickleOpCode::BINPERSID);

  // TODO: Skip this if not writing tensors
  memoized_storage_map_[addr] = pushNextBinPut();
  tensor_data_.push_back(tensor);
}
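
// For illustration only: the persistent id pushed above is a tuple of the
// rough form (hypothetical values)
//
//   ("storage", torch.FloatStorage, "0", "cpu", 16)
//    typename   data_type           key  loc    numel
//
// BINPERSID hands that tuple to the unpickler's persistent_load callback,
// which resolves it to the storage bytes written out separately.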

void Pickler::pushBytes(const std::string& string) {
  static const size_t kSmallStr = 32;
  if (string.size() <= kSmallStr &&
      bufferPos_ + string.size() <= buffer_.size()) {
    // Small string that fits: buffer the data.
    memcpy(buffer_.data() + bufferPos_, string.data(), string.size());
    bufferPos_ += string.size();
  } else {
    // Otherwise, first flush, then write directly.
    flush();
    writer_(string.data(), string.size());
  }
}

void Pickler::pushGlobal(
    c10::string_view module_name,
    c10::string_view class_name) {
  std::string key;
  key.reserve(module_name.size() + class_name.size() + 2);
  key.append(module_name.data(), module_name.size());
  key.push_back('\n');
  key.append(class_name.data(), class_name.size());
  key.push_back('\n');

  const auto memo_entry = memoized_globals_map_.find(key);
  if (memo_entry == memoized_globals_map_.end()) {
    push<PickleOpCode>(PickleOpCode::GLOBAL);
    pushBytes(key);
    // Push BINPUT without adding anything to the memoized_ivalues_
    size_t memo_id = pushNextBinPut();
    memoized_globals_map_.insert({key, memo_id});
  } else {
    pushBinGet(memo_entry->second);
  }
}
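
// For illustration only: pushGlobal("torch._utils", "_rebuild_tensor_v2")
// emits the newline-terminated key
//
//   GLOBAL "torch._utils\n_rebuild_tensor_v2\n"
//
// which the unpickler resolves to a callable; a later REDUCE applies it to an
// argument tuple. Globals are memoized, so each key is written only once per
// pickle program.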

void Pickler::pushTensor(const IValue& ivalue) {
  if (tensor_table_ == nullptr) {
    pushLiteralTensor(ivalue);
  } else {
    pushTensorReference(ivalue);
  }
}

void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
  pushGlobal("torch._utils", "_rebuild_sparse_tensor");
  push<PickleOpCode>(PickleOpCode::MARK);
  // layout
  auto layout = tensor.layout();
  pushInt(static_cast<int>(layout));
  switch (layout) {
    case c10::Layout::Sparse:
      // size
      push<PickleOpCode>(PickleOpCode::MARK);
      for (auto size : tensor.sizes()) {
        pushInt(size);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);
      // requires grad
      pushIValue(tensor.requires_grad());
      // indices
      pushTensor(tensor._indices());
      // values
      pushTensor(tensor._values());
      break;
    case c10::Layout::SparseCsr:
      push<PickleOpCode>(PickleOpCode::MARK);
      for (auto size : tensor.sizes()) {
        pushInt(size);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);

      pushIValue(tensor.requires_grad());
      pushTensor(tensor.crow_indices());
      pushTensor(tensor.col_indices());
      pushTensor(tensor.values());
      break;
    default:
      TORCH_CHECK(
          false,
          "Unsupported sparse tensor layout type in serialization ",
          layout);
      break;
  }
  // backward_hooks
  pushGlobal("collections", "OrderedDict");
  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
  // Construct the collections.OrderedDict for the backward_hooks
  push<PickleOpCode>(PickleOpCode::REDUCE);
  push<PickleOpCode>(PickleOpCode::TUPLE);
  // Call torch._utils._rebuild_sparse_tensor
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushLiteralTensor(const IValue& ivalue) {
  // In contrast to tensor references, literal tensors are included in the
  // pickle program binary blob. They are written to the file after the STOP
  // opcode. They can't be included in the pickle program itself without a
  // bunch of extra machinery since byte strings are limited to 4 GB.
  //
  // The format here is the same one used by `torch.save()`. The code for the
  // format can be found in `torch/serialization.py`.
  auto& tensor = ivalue.toTensor();

  if (tensor.is_sparse() || tensor.is_sparse_csr()) {
    pushLiteralSparseTensor(tensor);
    return;
  }

  bool quantized = tensor.is_quantized();
  // The arguments to this function are:
  //   storage, storage_offset, size, stride, requires_grad, backward_hooks
  pushGlobal(
      "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");

  push<PickleOpCode>(PickleOpCode::MARK);
  pushStorageOfTensor(tensor);

  // storage offset
  pushInt(tensor.storage_offset());

  // size
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto size : tensor.sizes()) {
    pushInt(size);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // stride
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto stride : tensor.strides()) {
    pushInt(stride);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  if (quantized) {
    push<PickleOpCode>(PickleOpCode::MARK);
    pushGlobal("torch", toString(tensor.qscheme()));
    // tuple of (qscheme, scale, zp) or (qscheme, scales, zps, axis)
    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        pushDouble(tensor.q_scale());
        pushInt(tensor.q_zero_point());
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        pushTensor(tensor.q_per_channel_scales());
        pushTensor(tensor.q_per_channel_zero_points());
        pushInt(tensor.q_per_channel_axis());
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }
    push<PickleOpCode>(PickleOpCode::TUPLE);
  }

  // requires_grad
  pushIValue(tensor.requires_grad());

  // backward_hooks
  pushGlobal("collections", "OrderedDict");
  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
  // Construct the collections.OrderedDict for the backward_hooks
  push<PickleOpCode>(PickleOpCode::REDUCE);

  if (!quantized) {
    // Only push the metadata dict for regular tensors, and only if it is not
    // empty.
    auto metadata = torch::jit::getTensorMetadata(tensor);
    if (!metadata.empty()) {
      // IValues based on std::unordered_map<K, V> are slow and deprecated.
      // Thus, pass a c10::Dict to pushDict.
      c10::Dict<std::string, bool> math_bits_;
      for (const auto& pair : metadata) {
        math_bits_.insert(pair.first, pair.second);
      }
      pushDict(math_bits_);
    }
  }

  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call torch._utils._rebuild_tensor_v2
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
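
// For illustration only: for a plain (non-quantized, strided) tensor the
// program above reduces to something like
//
//   torch._utils._rebuild_tensor_v2(
//       storage, storage_offset, size, stride, requires_grad,
//       backward_hooks, metadata?)
//
// e.g. a contiguous 2x3 float tensor yields size == (2, 3) and
// stride == (3, 1), with the trailing metadata dict present only when
// getTensorMetadata() returns a non-empty map.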

void Pickler::pushSpecializedList(
    const IValue& ivalue,
    const char* list_name,
    const std::function<void(const IValue&)>& item_pusher) {
  pushGlobal("torch.jit._pickle", list_name);

  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);

  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
  // Mark list
  push<PickleOpCode>(PickleOpCode::MARK);

  // Add all items
  item_pusher(ivalue);

  // Finish list
  push<PickleOpCode>(PickleOpCode::APPENDS);

  // Finish tuple
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call reduce
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
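
// For illustration only: a List[int] such as [1, 2] is emitted roughly as
//
//   GLOBAL "torch.jit._pickle\nbuild_intlist\n"
//   MARK
//     EMPTY_LIST
//     MARK
//       BININT1 1
//       BININT1 2
//     APPENDS
//   TUPLE
//   REDUCE
//
// i.e. torch.jit._pickle.build_intlist([1, 2]), which lets the unpickler
// restore the specialized element type instead of a generic list.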

static inline double swapDouble(double value) {
  const char* bytes = reinterpret_cast<const char*>(&value);
  double flipped = 0;
  char* out_bytes = reinterpret_cast<char*>(&flipped);
  for (const auto i : c10::irange(sizeof(double))) {
    out_bytes[i] = bytes[sizeof(double) - i - 1];
  }
  return *reinterpret_cast<double*>(out_bytes);
}

void Pickler::pushDouble(double value) {
  push<PickleOpCode>(PickleOpCode::BINFLOAT);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  // Python pickle format is big endian, swap.
  push<double>(swapDouble(value));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  push<double>(value);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}
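
// For illustration only: BINFLOAT carries an 8-byte IEEE-754 double in
// big-endian order, so on a little-endian host pushDouble(1.0) writes the
// opcode followed by 0x3f 0xf0 0x00 0x00 0x00 0x00 0x00 0x00 (the bytes of
// 1.0 with their order reversed relative to the host representation).
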
void Pickler::pushComplexDouble(const IValue& value) {
  c10::complex<double> d = value.toComplexDouble();
  pushGlobal("builtins", "complex");
  pushIValue(d.real());
  pushIValue(d.imag());
  push<PickleOpCode>(PickleOpCode::TUPLE2);
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushLong(const std::string& data) {
  uint64_t size = data.size();

  TORCH_INTERNAL_ASSERT(
      size <= std::numeric_limits<uint8_t>::max(),
      "Cannot pickle a long larger than 255 bytes");
  push<PickleOpCode>(PickleOpCode::LONG1);
  push<uint8_t>(size);
  pushBytes(data);
}

void Pickler::pushTensorReference(const IValue& ivalue) {
  pushGlobal("torch.jit._pickle", "build_tensor_from_id");
  tensor_table_->push_back(ivalue.toTensor());
  auto tensor_id = tensor_table_->size() - 1;
  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);
  pushIValue(static_cast<int64_t>(tensor_id));
  push<PickleOpCode>(PickleOpCode::TUPLE);

  push<PickleOpCode>(PickleOpCode::REDUCE);
}

// startTypeTag() and endTypeTag() must be called as a pair, with exactly one
// value pushed on the stack in between them. They add the type of a container
// ivalue to the stack as a string so we can preserve type tags across
// serialization.
void Pickler::startTypeTag() {
  if (tag_aggregates_) {
    pushGlobal("torch.jit._pickle", "restore_type_tag");
  }
}

namespace {
std::optional<std::string> type_printer(const c10::Type& type) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(type_printer);
  }
  return std::nullopt;
}
} // namespace

// See startTypeTag
void Pickler::endTypeTag(const IValue& ivalue) {
  if (!tag_aggregates_) {
    return;
  }
  TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList());

  // Push the dict type
  auto type = ivalue.type();
  TORCH_INTERNAL_ASSERT(type);

  auto annot_str = type->annotation_str(type_printer);
  pushString(annot_str);

  // Pop the dict and type into a tuple
  push<PickleOpCode>(PickleOpCode::TUPLE2);

  // Call function via reduce
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushDict(const IValue& ivalue) {
  auto dict = ivalue.toGenericDict();

  startTypeTag();

  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);

  static_assert(
      std::is_unsigned_v<decltype(dict.size())>,
      "Expected size to be non-negative.");
  push<PickleOpCode>(PickleOpCode::MARK);

  // c10::Dict iterates in insertion order, so the emitted key order is
  // deterministic.
  for (const auto& entry : dict) {
    pushIValue(entry.key());
    pushIValue(entry.value());
  }

  push<PickleOpCode>(PickleOpCode::SETITEMS);

  endTypeTag(ivalue);
}
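
// For illustration only: when type tagging is enabled, a Dict[str, int] such
// as {"a": 1} is emitted roughly as
//
//   GLOBAL "torch.jit._pickle\nrestore_type_tag\n"   (startTypeTag)
//   EMPTY_DICT
//   MARK
//     BINUNICODE "a"
//     BININT1 1
//   SETITEMS
//   BINUNICODE "Dict[str, int]"                      (endTypeTag)
//   TUPLE2
//   REDUCE
//
// i.e. restore_type_tag({"a": 1}, "Dict[str, int]"), which reattaches the
// container's static type on load.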

size_t Pickler::pushNextBinPut() {
  if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINPUT);
    push<uint8_t>(memo_id_);
  } else {
    // Memoized too many items, issue a LONG_BINPUT instead
    push<PickleOpCode>(PickleOpCode::LONG_BINPUT);
    push<uint32_t>(memo_id_);
  }
  AT_ASSERT(memo_id_ <= std::numeric_limits<uint32_t>::max());
  ++memo_id_;
  return memo_id_ - 1;
}

void Pickler::pushGenericList(const IValue& ivalue) {
  auto list = ivalue.toListRef();
  startTypeTag();

  // Push the list items
  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
  push<PickleOpCode>(PickleOpCode::MARK);
  for (const IValue& item : list) {
    pushIValue(item);
  }
  push<PickleOpCode>(PickleOpCode::APPENDS);

  endTypeTag(ivalue);
}

void Pickler::pushTuple(const IValue& ivalue) {
  auto tuple = ivalue.toTuple();
  auto tuple_size = tuple->elements().size();

  switch (tuple_size) {
    case 0: {
      push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    } break;
    case 1: {
      pushIValue(tuple->elements()[0]);
      push<PickleOpCode>(PickleOpCode::TUPLE1);
    } break;
    case 2: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      push<PickleOpCode>(PickleOpCode::TUPLE2);
    } break;
    case 3: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      pushIValue(tuple->elements()[2]);
      push<PickleOpCode>(PickleOpCode::TUPLE3);
    } break;
    default: {
      push<PickleOpCode>(PickleOpCode::MARK);
      for (const IValue& item : tuple->elements()) {
        pushIValue(item);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);
    } break;
  }
}
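
// For illustration only: tuples of size 0-3 use the compact dedicated
// opcodes, while longer tuples fall back to the MARK form, e.g.
//
//   ()           -> EMPTY_TUPLE
//   (1, 2)       -> BININT1 1, BININT1 2, TUPLE2
//   (1, 2, 3, 4) -> MARK, BININT1 1, ..., BININT1 4, TUPLE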

WriteableTensorData getWriteableTensorData(
    const at::Tensor& tensor,
    bool to_cpu) {
  WriteableTensorData result;
  result.tensor_ = tensor;
  result.size_ = tensor.storage().nbytes();
  // TODO HIP support
  if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
    // NB: This new tensor is created to support cuda tensors.
    // Storages can be mutated when converting tensors from cuda to cpu,
    // and we need a cpu tensor to copy data from.
    result.tensor_ =
        at::empty({0}, tensor.options())
            .set_(
                tensor.storage(),
                /* storage_offset = */ 0,
                /* size = */
                {static_cast<int64_t>(
                    tensor.storage().nbytes() / tensor.element_size())},
                /* stride = */ {1})
            .cpu();
    TORCH_CHECK(
        result.tensor_.storage().nbytes() == result.size_,
        "Storage tensor size did not match record size");
  }
  return result;
}

bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
  // Check that the schemas for __getstate__ and __setstate__ are correct
  auto getstate = cls->findMethod("__getstate__");
  if (getstate == nullptr) {
    return false;
  }
  auto get_schema = getstate->getSchema();

  // Check __getstate__
  // __getstate__ is expected to be (self) -> T
  TORCH_CHECK(
      get_schema.arguments().size() == 1,
      "'__getstate__' must have 'self' as its only argument, but found ",
      get_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      get_schema.returns().size() == 1,
      "'__getstate__' must return 1 value, but found ",
      get_schema.returns().size());

  // Check __setstate__ if the method exists
  // __setstate__ is expected to be (self, T) -> None
  auto setstate = cls->findMethod("__setstate__");
  if (!setstate) {
    return false;
  }
  auto set_schema = setstate->getSchema();

  TORCH_CHECK(
      set_schema.arguments().size() == 2,
      "'__setstate__' must have 'self' and the state as its "
      "only arguments, but found ",
      set_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      set_schema.returns().size() == 1,
      "'__setstate__' must return None, but found ",
      set_schema.returns().size(),
      " return values");
  TORCH_CHECK(
      set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
      "'__setstate__' must return None, but found a value of type ",
      set_schema.returns().at(0).type()->annotation_str());

  // Check that the return type of __getstate__ matches the input to
  // __setstate__
  auto get_type = get_schema.returns().at(0).type();
  auto set_type = set_schema.arguments().at(1).type();

  TORCH_CHECK(
      get_type->isSubtypeOf(*set_type),
      "'__getstate__'s return type (",
      get_type->annotation_str(),
      ") does not match '__setstate__'s argument type (",
      set_type->annotation_str(),
      ")");

  return true;
}
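
// For illustration only, a TorchScript class satisfying this contract might
// look like (hypothetical example):
//
//   @torch.jit.script
//   class Foo:
//       def __init__(self, x: int):
//           self.x = x
//       def __getstate__(self) -> int:
//           return self.x
//       def __setstate__(self, state: int) -> None:
//           self.x = state
//
// i.e. __getstate__ takes only self and returns a single value whose type is
// a subtype of the state argument accepted by __setstate__, and __setstate__
// returns None.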

} // namespace torch::jit