// /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/pickler.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <ATen/quantized/Quantizer.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>
#include <type_traits>

namespace torch::jit {

using ::c10::IValue;

// Protocol 2 is the highest that can be decoded by Python 2
// See https://docs.python.org/3/library/pickle.html#data-stream-format
constexpr static uint8_t PROTOCOL_VERSION = 2;

// NOLINTNEXTLINE(bugprone-exception-escape)
Pickler::~Pickler() {
  flush();
}

void Pickler::protocol() {
  push<PickleOpCode>(PickleOpCode::PROTO);
  push<uint8_t>(PROTOCOL_VERSION);
}

void Pickler::startTuple() {
  // All attributes get pushed into a tuple and their indices saved in the
  // module def
  push<PickleOpCode>(PickleOpCode::MARK);
}

void Pickler::endTuple() {
  push<PickleOpCode>(PickleOpCode::TUPLE);
}

void Pickler::stop() {
  push<PickleOpCode>(PickleOpCode::STOP);
  flush();
}
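
// A minimal usage sketch (writer_fn is hypothetical; the concrete constructor
// overloads live in pickler.h, and callers such as torch::jit::pickle() drive
// the class in roughly this way):
//
//   Pickler pickler(writer_fn);
//   pickler.protocol();        // PROTO 2
//   pickler.pushIValue(value); // body of the pickle program
//   pickler.stop();            // STOP, then flush buffered bytes to writer_fn
//
// startTuple()/endTuple() are only needed when several values should be packed
// into one top-level tuple.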

// unmemoized version called by pushIValue
void Pickler::pushIValueImpl(const IValue& ivalue) {
  if (ivalue.isTensor()) {
    pushTensor(ivalue);
  } else if (ivalue.isTuple()) {
    pushTuple(ivalue);
  } else if (ivalue.isDouble()) {
    pushDouble(ivalue.toDouble());
  } else if (ivalue.isComplexDouble()) {
    pushComplexDouble(ivalue);
  } else if (ivalue.isInt()) {
    pushInt(ivalue.toInt());
  } else if (ivalue.isBool()) {
    pushBool(ivalue.toBool());
  } else if (ivalue.isString()) {
    pushString(ivalue.toStringRef());
  } else if (ivalue.isGenericDict()) {
    pushDict(ivalue);
  } else if (ivalue.isNone()) {
    push<PickleOpCode>(PickleOpCode::NONE);
  } else if (ivalue.isIntList()) {
    pushSpecializedList(ivalue, "build_intlist", [this](const IValue& ivalue) {
      for (const int64_t item : ivalue.toIntVector()) {
        pushInt(item);
      }
    });
  } else if (ivalue.isTensorList()) {
    pushSpecializedList(
        ivalue, "build_tensorlist", [this](const IValue& ivalue) {
          for (const at::Tensor& item : ivalue.toTensorVector()) {
            pushIValue(item);
          }
        });
  } else if (ivalue.isDoubleList()) {
    pushSpecializedList(
        ivalue, "build_doublelist", [this](const IValue& ivalue) {
          for (double item : ivalue.toDoubleVector()) {
            pushDouble(item);
          }
        });
  } else if (ivalue.isBoolList()) {
    pushSpecializedList(ivalue, "build_boollist", [this](const IValue& ivalue) {
      for (bool item : ivalue.toBoolList()) {
        pushBool(item);
      }
    });
    // note: isList must be after isIntList and friends because
    // isList is true for all lists.
  } else if (ivalue.isList()) {
    pushGenericList(ivalue);
  } else if (ivalue.isObject()) {
    auto obj = ivalue.toObject();
    auto type = obj->type();
    if (memoized_class_types_ != nullptr) {
      // memoize every class type the Pickler encountered
      // This is used to make sure we capture all the run-time types
      // and serialize them properly for class/interface polymorphism
      memoized_class_types_->emplace_back(type);
    }
    auto type_name = type->name().value();
    if (type_renamer_) {
      type_name = type_renamer_(type);
    }
    pushGlobal(type_name.prefix(), type_name.name());
    push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    push<PickleOpCode>(PickleOpCode::NEWOBJ);
    if (checkHasValidSetGetState(type)) {
      Function& getstate = type->getMethod("__getstate__");
      pushIValue(getstate({obj}));
    } else {
      push<PickleOpCode>(PickleOpCode::EMPTY_DICT);
      push<PickleOpCode>(PickleOpCode::MARK);
      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
        pushString(type->getAttributeName(i));
        pushIValue(obj->getSlot(i));
      }
      push<PickleOpCode>(PickleOpCode::SETITEMS);
    }
    push<PickleOpCode>(PickleOpCode::BUILD);
  } else if (ivalue.isDevice()) {
    pushDevice(ivalue);
  } else if (ivalue.isCapsule()) {
    std::stringstream err;
    err << "Cannot serialize custom bound C++ class";
    if (memoized_class_types_ && !memoized_class_types_->empty()) {
      if (auto qualname = memoized_class_types_->back()->name()) {
        err << " " << qualname->qualifiedName();
      }
    }
    err << ". Please define serialization methods via def_pickle() for "
           "this class.";
    AT_ERROR(err.str());
  } else if (ivalue.isRRef()) {
#ifdef USE_RPC
    TORCH_CHECK(
        torch::distributed::rpc::getAllowJitRRefPickle() == true,
        "RRef jit pickling is only allowed inside RPC calls.");
    pushRRef(ivalue);
#else
    TORCH_CHECK(
        false, "RRef pickling is only supported with the distributed package");
#endif
  } else if (ivalue.isEnum()) {
    auto enum_holder = ivalue.toEnumHolder();
    const auto& qualified_class_name =
        enum_holder->type()->qualifiedClassName();
    pushGlobal(qualified_class_name.prefix(), qualified_class_name.name());
    pushIValue(enum_holder->value());
    push<PickleOpCode>(PickleOpCode::REDUCE);
  } else {
    AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
  }
}
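
// For reference, the program emitted above for a TorchScript object is
// roughly (a sketch of standard protocol-2 pickle semantics, not extra
// machinery in this file):
//
//   GLOBAL 'namespace\nClassName\n'   (or the renamed qualified name)
//   EMPTY_TUPLE NEWOBJ                -> cls.__new__(cls)
//   <state>                           -> __getstate__() result, or an attr dict
//   BUILD                             -> __setstate__(state) / __dict__ update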

void Pickler::pushDevice(const IValue& ivalue) {
  auto device = ivalue.toDevice();
  auto deviceStr = device.str();
  auto it = memoized_devices_map_.find(deviceStr);
  if (it == memoized_devices_map_.end()) {
    pushGlobal("torch", "device");
    pushString(deviceStr);
    push<PickleOpCode>(PickleOpCode::TUPLE1);
    push<PickleOpCode>(PickleOpCode::REDUCE);
    memoized_devices_map_[deviceStr] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
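
// The first occurrence of a device unpickles as torch.device(deviceStr);
// for "cuda:0" the emitted stream is roughly
//
//   GLOBAL 'torch\ndevice\n'  BINUNICODE 'cuda:0'  TUPLE1  REDUCE  BINPUT k
//
// and every later occurrence of the same device string is just BINGET k.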

#ifdef USE_RPC
void Pickler::pushRRef(const IValue& ivalue) {
  // It is the same as how rref is pickled in python, see PyRRef::pickle
  auto rrefInterface = ivalue.toRRef();
  auto rref =
      c10::static_intrusive_pointer_cast<distributed::rpc::RRef>(rrefInterface);
  pushGlobal("torch.distributed.rpc", "rref");
  auto& ctx = distributed::rpc::RRefContext::getInstance();
  auto rrefForkData = ctx.prepareChildFork(rref);
  push<PickleOpCode>(PickleOpCode::MARK);
  pushInt(rrefForkData.ownerId_);
  pushInt(rrefForkData.rrefId_.createdOn_);
  pushInt(rrefForkData.rrefId_.localId_);
  pushInt(rrefForkData.forkId_.createdOn_);
  pushInt(rrefForkData.forkId_.localId_);
  pushInt(rrefForkData.parent_);
  pushString(rrefForkData.typeStr_);
  push<PickleOpCode>(PickleOpCode::TUPLE);
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
#endif

void Pickler::pushIValue(const IValue& ivalue) {
  bool shouldMemoizeByPointer =
      ivalue.isPtrType() && !ivalue.isString() && ivalue.use_count() > 1;

  // Mutable ivalues are memoized by pointer equality, which we handle at this
  // outer granularity.  Immutable ivalues are memoized by value equality which
  // is handled in the type-specific handlers inside pushIValueImpl.
  if (shouldMemoizeByPointer) {
    const void* ptr = ivalue.internalToPointer();
    TORCH_CHECK(
        ptr != nullptr,
        "Pickler cannot memoize ",
        ivalue.tagKind(),
        " IValue ",
        ivalue);
    auto memo_entry = memoized_ivalue_map_.find(ptr);
    if (memo_entry != memoized_ivalue_map_.end()) {
      // This value has already been pushed, just do a BINGET
      pushBinGet(memo_entry->second);
      return;
    }

    pushIValueImpl(ivalue);

    memoized_ivalues_.push_back(ivalue);
    memoized_ivalue_map_[ptr] = pushNextBinPut();
  } else {
    pushIValueImpl(ivalue);
  }
}
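
// The net effect is the usual pickle memo pattern: the first time a shared,
// mutable value is seen its program is followed by BINPUT <memo_id>, and every
// later reference to the same object emits only BINGET <memo_id>, so aliasing
// (e.g. the same list appearing twice in a tuple) survives a round trip.
// memoized_ivalues_ keeps those IValues alive so the raw pointer keys stay
// valid for the lifetime of the Pickler.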

void Pickler::pushInt(int64_t n) {
  if (n >= std::numeric_limits<uint8_t>::min() &&
      n <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT1);
    push<uint8_t>(n);
  } else if (
      n >= std::numeric_limits<uint16_t>::min() &&
      n <= std::numeric_limits<uint16_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT2);
    push<uint16_t>(to_le16(n));
  } else if (
      n >= std::numeric_limits<int32_t>::min() &&
      n <= std::numeric_limits<int32_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BININT);
    push<int32_t>(to_le32(n));
  } else {
    // Push 8 byte integer
    push<PickleOpCode>(PickleOpCode::LONG1);
    push<uint8_t>(8);
    push<int64_t>(to_le64(n));
  }
}
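
// Worked examples of the size tiers above (multi-byte payloads are
// little-endian):
//   5       -> BININT1 0x05                    (unsigned 8-bit)
//   300     -> BININT2 0x2c 0x01               (unsigned 16-bit)
//   -1      -> BININT  0xff 0xff 0xff 0xff     (signed 32-bit)
//   1 << 40 -> LONG1 0x08 + 8 little-endian bytes of the value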

void Pickler::pushBool(bool value) {
  push<PickleOpCode>(value ? PickleOpCode::NEWTRUE : PickleOpCode::NEWFALSE);
}

void Pickler::pushBinGet(uint32_t memo_id) {
  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINGET);
    push<uint8_t>(memo_id);
  } else {
    // Memoized too many items, issue a LONG_BINGET instead
    push<PickleOpCode>(PickleOpCode::LONG_BINGET);
    push<uint32_t>(memo_id);
  }
}

// unmemoized encoding of a string
void Pickler::pushStringImpl(const std::string& string) {
  if (string.size() <= UINT_MAX) {
    push<PickleOpCode>(PickleOpCode::BINUNICODE);
    push<uint32_t>(to_le32(string.size()));
    pushBytes(string);
  } else {
    push<PickleOpCode>(PickleOpCode::BINUNICODE8);
    push<int64_t>(to_le64(string.size()));
    pushBytes(string);
  }
}

void Pickler::pushString(const std::string& string) {
  auto it = memoized_strings_map_.find(string);
  if (it == memoized_strings_map_.end()) {
    pushStringImpl(string);
    memoized_strings_map_[string] = pushNextBinPut();
  } else {
    pushBinGet(it->second);
  }
}
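
// String layout: BINUNICODE is followed by a 4-byte little-endian length and
// the raw bytes of the string (expected to be UTF-8); BINUNICODE8 uses an
// 8-byte length for strings larger than 4 GB. Unlike other immutable values,
// strings are memoized here by value, so a repeated key like "weight" is
// written once and then referenced via BINGET.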

void Pickler::pushStorageOfTensor(const at::Tensor& tensor) {
  const at::Storage& storage = tensor.storage();
  void* addr = storage.unsafeGetStorageImpl();
  auto it = memoized_storage_map_.find(addr);
  if (it != memoized_storage_map_.end()) {
    pushBinGet(it->second);
    return;
  }

  // Tuple for persistent_load
  push<PickleOpCode>(PickleOpCode::MARK);
  // typename
  pushString("storage");
  // data_type
  std::string data_type =
      std::string(toString(tensor.scalar_type())).append("Storage");
  pushGlobal("torch", data_type);
  // root_key
  std::string root_key = get_tensor_id_ != nullptr
      ? get_tensor_id_(tensor)
      : std::to_string(tensor_data_.size());
  pushString(root_key);
  // location
  pushString(tensor.device().str());
  // size
  pushInt(
      static_cast<int64_t>(tensor.storage().nbytes() / tensor.element_size()));

  push<PickleOpCode>(PickleOpCode::TUPLE);
  push<PickleOpCode>(PickleOpCode::BINPERSID);

  // TODO: Skip this if not writing tensors
  memoized_storage_map_[addr] = pushNextBinPut();
  tensor_data_.push_back(tensor);
}
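
// The persistent id built above is the 5-tuple
//   ('storage', torch.<ScalarType>Storage, root_key, location, numel)
// consumed by the reader's persistent_load hook; the raw bytes for each entry
// in tensor_data_ are written separately, after the pickle program itself.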

void Pickler::pushBytes(const std::string& string) {
  static const size_t kSmallStr = 32;
  if (string.size() <= kSmallStr &&
      bufferPos_ + string.size() <= buffer_.size()) {
    // Small string that fits: buffer the data.
    memcpy(buffer_.data() + bufferPos_, string.data(), string.size());
    bufferPos_ += string.size();
  } else {
    // Otherwise, first flush, then write directly.
    flush();
    writer_(string.data(), string.size());
  }
}

void Pickler::pushGlobal(
    c10::string_view module_name,
    c10::string_view class_name) {
  std::string key;
  key.reserve(module_name.size() + class_name.size() + 2);
  key.append(module_name.data(), module_name.size());
  key.push_back('\n');
  key.append(class_name.data(), class_name.size());
  key.push_back('\n');

  const auto memo_entry = memoized_globals_map_.find(key);
  if (memo_entry == memoized_globals_map_.end()) {
    push<PickleOpCode>(PickleOpCode::GLOBAL);
    pushBytes(key);
    // Push BINPUT without adding anything to the memoized_ivalues_
    size_t memo_id = pushNextBinPut();
    memoized_globals_map_.insert({key, memo_id});
  } else {
    pushBinGet(memo_entry->second);
  }
}
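
// Example: the first pushGlobal("collections", "OrderedDict") writes
//   GLOBAL "collections\nOrderedDict\n"  BINPUT k
// (GLOBAL's argument is the newline-terminated module/name pair), and any
// later use of the same global collapses to BINGET k.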

void Pickler::pushTensor(const IValue& ivalue) {
  if (tensor_table_ == nullptr) {
    pushLiteralTensor(ivalue);
  } else {
    pushTensorReference(ivalue);
  }
}

void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
  pushGlobal("torch._utils", "_rebuild_sparse_tensor");
  push<PickleOpCode>(PickleOpCode::MARK);
  // layout
  auto layout = tensor.layout();
  pushInt(static_cast<int>(layout));
  switch (layout) {
    case c10::Layout::Sparse:
      // size
      push<PickleOpCode>(PickleOpCode::MARK);
      for (auto size : tensor.sizes()) {
        pushInt(size);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);
      // requires grad
      pushIValue(tensor.requires_grad());
      // indices
      pushTensor(tensor._indices());
      // values
      pushTensor(tensor._values());
      break;
    case c10::Layout::SparseCsr:
      push<PickleOpCode>(PickleOpCode::MARK);
      for (auto size : tensor.sizes()) {
        pushInt(size);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);

      pushIValue(tensor.requires_grad());
      pushTensor(tensor.crow_indices());
      pushTensor(tensor.col_indices());
      pushTensor(tensor.values());
      break;
    default:
      TORCH_CHECK(
          false,
          "Unsupported sparse tensor layout type in serialization ",
          layout);
      break;
  }
  // backward_hooks
  pushGlobal("collections", "OrderedDict");
  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
  // Construct the collections.OrderedDict for the backward_hooks
  push<PickleOpCode>(PickleOpCode::REDUCE);
  push<PickleOpCode>(PickleOpCode::TUPLE);
  // Call torch._utils._rebuild_sparse_tensor
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushLiteralTensor(const IValue& ivalue) {
  // In contrast to tensor references, literal tensors are included in the
  // pickle program binary blob. They are written to the file after the STOP
  // opcode. They can't be included in the pickle program itself without a bunch
  // of extra machinery since byte strings are limited to 4 GB.
  //
  // The format here is the same one used by `torch.save()`. The code for the
  // format can be found in `torch/serialization.py`.
  auto& tensor = ivalue.toTensor();

  if (tensor.is_sparse() || tensor.is_sparse_csr()) {
    pushLiteralSparseTensor(tensor);
    return;
  }

  bool quantized = tensor.is_quantized();
  // The arguments to this function are:
  //    storage, storage_offset, size, stride, requires_grad, backward_hooks
  pushGlobal(
      "torch._utils", quantized ? "_rebuild_qtensor" : "_rebuild_tensor_v2");

  push<PickleOpCode>(PickleOpCode::MARK);
  pushStorageOfTensor(tensor);

  // storage offset
  pushInt(tensor.storage_offset());

  // size
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto size : tensor.sizes()) {
    pushInt(size);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // stride
  push<PickleOpCode>(PickleOpCode::MARK);
  for (auto stride : tensor.strides()) {
    pushInt(stride);
  }
  push<PickleOpCode>(PickleOpCode::TUPLE);

  if (quantized) {
    push<PickleOpCode>(PickleOpCode::MARK);
    pushGlobal("torch", toString(tensor.qscheme()));
    // tuple of (qscheme, scale, zp) or (qscheme, scales, zps, axis)
    switch (tensor.qscheme()) {
      case at::kPerTensorAffine:
        pushDouble(tensor.q_scale());
        pushInt(tensor.q_zero_point());
        break;
      case at::kPerChannelAffineFloatQParams:
      case at::kPerChannelAffine: {
        pushTensor(tensor.q_per_channel_scales());
        pushTensor(tensor.q_per_channel_zero_points());
        pushInt(tensor.q_per_channel_axis());
      } break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported tensor quantization type in serialization ",
            toString(tensor.qscheme()));
        break;
    }
    push<PickleOpCode>(PickleOpCode::TUPLE);
  }

  // requires_grad
  pushIValue(tensor.requires_grad());

  // backward_hooks
  pushGlobal("collections", "OrderedDict");
  push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
  // Construct the collections.OrderedDict for the backward_hooks
  push<PickleOpCode>(PickleOpCode::REDUCE);

  if (!quantized) {
    // Only push it for regular tensor if the dictionary is not empty.
    auto metadata = torch::jit::getTensorMetadata(tensor);
    if (!metadata.empty()) {
      // IValues based on std::unordered_map<K, V> are slow and deprecated.
      // Thus, pass a c10::Dict to pushDict.
      c10::Dict<std::string, bool> math_bits_;
      for (const auto& pair : metadata) {
        math_bits_.insert(pair.first, pair.second);
      }
      pushDict(math_bits_);
    }
  }

  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call torch._utils._rebuild_tensor_v2
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
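
// On the reading side the REDUCE above is, schematically,
//   torch._utils._rebuild_tensor_v2(storage, storage_offset, size, stride,
//                                   requires_grad, backward_hooks[, metadata])
// for ordinary tensors, or _rebuild_qtensor with the extra quantizer tuple for
// quantized ones (see torch/_utils.py for the exact signatures).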

void Pickler::pushSpecializedList(
    const IValue& ivalue,
    const char* list_name,
    const std::function<void(const IValue&)>& item_pusher) {
  pushGlobal("torch.jit._pickle", list_name);

  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);

  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
  // Mark list
  push<PickleOpCode>(PickleOpCode::MARK);

  // Add all items
  item_pusher(ivalue);

  // Finish list
  push<PickleOpCode>(PickleOpCode::APPENDS);

  // Finish tuple
  push<PickleOpCode>(PickleOpCode::TUPLE);

  // Call reduce
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
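
// For example, an int list [1, 2] is emitted roughly as
//   GLOBAL 'torch.jit._pickle\nbuild_intlist\n'
//   MARK EMPTY_LIST MARK BININT1 1 BININT1 2 APPENDS TUPLE REDUCE
// which unpickles as torch.jit._pickle.build_intlist([1, 2]), letting the
// unpickler reconstruct the list with its element type preserved.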

static inline double swapDouble(double value) {
  const char* bytes = reinterpret_cast<const char*>(&value);
  double flipped = 0;
  char* out_bytes = reinterpret_cast<char*>(&flipped);
  for (const auto i : c10::irange(sizeof(double))) {
    out_bytes[i] = bytes[sizeof(double) - i - 1];
  }
  return *reinterpret_cast<double*>(out_bytes);
}

void Pickler::pushDouble(double value) {
  push<PickleOpCode>(PickleOpCode::BINFLOAT);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  // Python pickle format is big endian, swap.
  push<double>(swapDouble(value));
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  push<double>(value);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}
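
// BINFLOAT's payload is an 8-byte big-endian IEEE-754 double; 1.0, for
// example, is stored as 3f f0 00 00 00 00 00 00, hence the byte swap on
// little-endian hosts in pushDouble() above.
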
void Pickler::pushComplexDouble(const IValue& value) {
  c10::complex<double> d = value.toComplexDouble();
  pushGlobal("builtins", "complex");
  pushIValue(d.real());
  pushIValue(d.imag());
  push<PickleOpCode>(PickleOpCode::TUPLE2);
  push<PickleOpCode>(PickleOpCode::REDUCE);
}

void Pickler::pushLong(const std::string& data) {
  uint64_t size = data.size();

  TORCH_INTERNAL_ASSERT(
      size <= std::numeric_limits<uint8_t>::max(),
      "Cannot pickle a long larger than 255 bytes");
  push<PickleOpCode>(PickleOpCode::LONG1);
  push<uint8_t>(size);
  pushBytes(data);
}

void Pickler::pushTensorReference(const IValue& ivalue) {
  pushGlobal("torch.jit._pickle", "build_tensor_from_id");
  tensor_table_->push_back(ivalue.toTensor());
  auto tensor_id = tensor_table_->size() - 1;
  // Reduce arguments are spread (e.g. `*args`) before calling the global,
  // so wrap in a tuple
  push<PickleOpCode>(PickleOpCode::MARK);
  pushIValue(static_cast<int64_t>(tensor_id));
  push<PickleOpCode>(PickleOpCode::TUPLE);

  push<PickleOpCode>(PickleOpCode::REDUCE);
}

// startTypeTag() and endTypeTag() must be called in a pair, with 1 argument
// pushed on the stack in between them. They will add the type of a container
// ivalue to the stack as a string so we can preserve type tags across
// serialization
void Pickler::startTypeTag() {
  if (tag_aggregates_) {
    pushGlobal("torch.jit._pickle", "restore_type_tag");
  }
}
namespace {
std::optional<std::string> type_printer(const c10::Type& type) {
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->fallback()->annotation_str(type_printer);
  }
  return std::nullopt;
}
} // namespace

// See startTypeTag
void Pickler::endTypeTag(const IValue& ivalue) {
  if (!tag_aggregates_) {
    return;
  }
  TORCH_INTERNAL_ASSERT(ivalue.isGenericDict() || ivalue.isList());

  // Push the dict type
  auto type = ivalue.type();
  TORCH_INTERNAL_ASSERT(type);

  auto annot_str = type->annotation_str(type_printer);
  pushString(annot_str);

  // Pop the dict and type into a tuple
  push<PickleOpCode>(PickleOpCode::TUPLE2);

  // Call function via reduce
  push<PickleOpCode>(PickleOpCode::REDUCE);
}
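
// With tagging enabled, a Dict[str, int] d is therefore pickled roughly as
//   GLOBAL 'torch.jit._pickle\nrestore_type_tag\n'
//   EMPTY_DICT MARK <key/value pairs> SETITEMS
//   BINUNICODE 'Dict[str, int]' TUPLE2 REDUCE
// i.e. restore_type_tag(d, "Dict[str, int]"), so the container's static type
// survives the round trip.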

void Pickler::pushDict(const IValue& ivalue) {
  auto dict = ivalue.toGenericDict();

  startTypeTag();

  push<PickleOpCode>(PickleOpCode::EMPTY_DICT);

  static_assert(
      std::is_unsigned_v<decltype(dict.size())>,
      "Expected size to be non-negative.");
  push<PickleOpCode>(PickleOpCode::MARK);

  // Emit entries in the dict's iteration (insertion) order; they are not
  // re-sorted here.
  for (const auto& entry : dict) {
    pushIValue(entry.key());
    pushIValue(entry.value());
  }

  push<PickleOpCode>(PickleOpCode::SETITEMS);

  endTypeTag(ivalue);
}

size_t Pickler::pushNextBinPut() {
  if (memo_id_ <= std::numeric_limits<uint8_t>::max()) {
    push<PickleOpCode>(PickleOpCode::BINPUT);
    push<uint8_t>(memo_id_);
  } else {
    // Memoized too many items, issue a LONG_BINPUT instead
    push<PickleOpCode>(PickleOpCode::LONG_BINPUT);
    push<uint32_t>(memo_id_);
  }
  AT_ASSERT(memo_id_ <= std::numeric_limits<uint32_t>::max());
  ++memo_id_;
  return memo_id_ - 1;
}

void Pickler::pushGenericList(const IValue& ivalue) {
  auto list = ivalue.toListRef();
  startTypeTag();

  // Push the list items
  push<PickleOpCode>(PickleOpCode::EMPTY_LIST);
  push<PickleOpCode>(PickleOpCode::MARK);
  for (const IValue& item : list) {
    pushIValue(item);
  }
  push<PickleOpCode>(PickleOpCode::APPENDS);

  endTypeTag(ivalue);
}

void Pickler::pushTuple(const IValue& ivalue) {
  auto tuple = ivalue.toTuple();
  auto tuple_size = tuple->elements().size();

  switch (tuple_size) {
    case 0: {
      push<PickleOpCode>(PickleOpCode::EMPTY_TUPLE);
    } break;
    case 1: {
      pushIValue(tuple->elements()[0]);
      push<PickleOpCode>(PickleOpCode::TUPLE1);
    } break;
    case 2: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      push<PickleOpCode>(PickleOpCode::TUPLE2);
    } break;
    case 3: {
      pushIValue(tuple->elements()[0]);
      pushIValue(tuple->elements()[1]);
      pushIValue(tuple->elements()[2]);
      push<PickleOpCode>(PickleOpCode::TUPLE3);
    } break;
    default: {
      push<PickleOpCode>(PickleOpCode::MARK);
      for (const IValue& item : tuple->elements()) {
        pushIValue(item);
      }
      push<PickleOpCode>(PickleOpCode::TUPLE);
    } break;
  }
}

WriteableTensorData getWriteableTensorData(
    const at::Tensor& tensor,
    bool to_cpu) {
  WriteableTensorData result;
  result.tensor_ = tensor;
  result.size_ = tensor.storage().nbytes();
  // TODO HIP support
  if (tensor.storage().device_type() != DeviceType::CPU && to_cpu) {
    // NB: This new tensor is created to support cuda tensors.
    // Storages can be mutated when converting tensors from cuda to cpu,
    // and we need a cpu tensor to copy data from.
    result.tensor_ =
        at::empty({0}, tensor.options())
            .set_(
                tensor.storage(),
                /* storage_offset = */ 0,
                /* size = */
                {static_cast<int64_t>(
                    tensor.storage().nbytes() / tensor.element_size())},
                /* stride = */ {1})
            .cpu();
    TORCH_CHECK(
        result.tensor_.storage().nbytes() == result.size_,
        "Storage tensor size did not match record size");
  }
  return result;
}

bool checkHasValidSetGetState(const std::shared_ptr<c10::ClassType>& cls) {
  // Check that the schemas for __getstate__ and __setstate__ are correct
  auto getstate = cls->findMethod("__getstate__");
  if (getstate == nullptr) {
    return false;
  }
  auto get_schema = getstate->getSchema();

  // Check __getstate__
  //   __getstate__ is expected to be (self) -> T
  TORCH_CHECK(
      get_schema.arguments().size() == 1,
      "'__getstate__' must have 'self' as its only argument, but found ",
      get_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      get_schema.returns().size() == 1,
      "'__getstate__' must return 1 value, but found ",
      get_schema.returns().size());

  // Check __setstate__ if the method exists
  //   __setstate__ is expected to be (self, T) -> None
  auto setstate = cls->findMethod("__setstate__");
  if (!setstate) {
    return false;
  }
  auto set_schema = setstate->getSchema();

  TORCH_CHECK(
      set_schema.arguments().size() == 2,
      "'__setstate__' must have 'self' and the state as its "
      "only arguments, but found ",
      set_schema.arguments().size(),
      " arguments");
  TORCH_CHECK(
      set_schema.returns().size() == 1,
      "'__setstate__' must return None, but found ",
      set_schema.returns().size(),
      " return values");
  TORCH_CHECK(
      set_schema.returns().at(0).type()->isSubtypeOf(*NoneType::get()),
      "'__setstate__' must return None, but found value of type ",
      set_schema.returns().at(0).type()->annotation_str());

  // Check that the return type of __getstate__ matches the input to
  // __setstate__
  auto get_type = get_schema.returns().at(0).type();
  auto set_type = set_schema.arguments().at(1).type();

  TORCH_CHECK(
      get_type->isSubtypeOf(*set_type),
      "'__getstate__'s return type (",
      get_type->annotation_str(),
      ") does not match '__setstate__'s argument type (",
      set_type->annotation_str(),
      ")");

  return true;
}

} // namespace torch::jit