#include <ATen/ATen.h>
#include <ATen/core/Dict.h>
#ifdef USE_RPC
#include <torch/csrc/distributed/rpc/rref_context.h>
#endif
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>
#include <utility>

namespace torch::jit {

using ::c10::IValue;

static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  if (root.isObject()) {
    restoreAccurateTypeTags(root, root.type());
  }
}

// Pickled objects are stored in a form compatible with Python pickling.
// In TorchScript, List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered. This
// information is not stored in the Python pickle data, but we can
// recover it from the static type of the top-level object being
// unpickled, because we have a record of the types of the objects it
// contains as attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
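//
// Illustrative sketch of the entry point (assumes `unpickler` was run on
// a module archive, so the result is an object):
//
//   IValue ivalue = unpickler.parse_ivalue();
//   if (ivalue.isObject()) {
//     // Contained containers get retagged, e.g. a generic
//     // Dict[Any, Any] attribute statically declared as
//     // Dict[str, Tensor] has its key/value types restored.
//     restoreAccurateTypeTags(ivalue, ivalue.type());
//   }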
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
  struct Work {
    TypePtr type;
    IValue value;
  };
  std::vector<Work> to_process = {{type_tag, root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // Ensure we only scan each pointer value once; otherwise this
    // can become exponential (and if we allow recursive data in the
    // future, it would not terminate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    auto kind = w.type->kind();
    if (auto dyn = w.type->castRaw<c10::DynamicType>()) {
      kind = dyn->dynamicKind();
    }
    switch (kind) {
      case TensorType::Kind:
      case StorageType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case ComplexType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case QuantizerType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case PyObjectType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
      case StreamObjType::Kind:
      case QSchemeType::Kind:
      case LayoutType::Kind:
      case MemoryFormatType::Kind:
      case ScalarTypeType::Kind:
      case RRefType::Kind:
      case AnyType::Kind:
      case AnyListType::Kind:
      case AnyTupleType::Kind:
      case AnyClassType::Kind:
      case AnyEnumType::Kind:
        // no op, there is nothing to tag
        break;
      case c10::SymIntType::Kind:
        // TODO: Can this really show up though? :think:
        TORCH_CHECK(!w.value.toSymInt().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case c10::SymFloatType::Kind:
        TORCH_CHECK(!w.value.toSymFloat().is_symbolic());
        // no op, there is nothing to tag
        break;
      case c10::SymBoolType::Kind:
        TORCH_CHECK(!w.value.toSymBool().is_heap_allocated());
        // no op, there is nothing to tag
        break;
      case DynamicType::Kind:
      case UnionType::Kind:
      case EnumType::Kind:
        // TODO(gmagogsfm): Implement serialization/deserialization of Enum.
        TORCH_INTERNAL_ASSERT(false);
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        for (size_t i = 0; i < w.type->containedTypeSize(); ++i) {
          Work elem = {w.type->containedType(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        if (f->completed()) {
          Work elem = {w.type->containedType(0), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case AwaitType::Kind: {
        auto aw = w.value.toAwait();
        if (aw->completed()) {
          Work elem = {w.type->containedType(0), aw->wait()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        if (!w.value.isNone()) {
          Work elem = {w.type->containedType(0), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // specialized lists do not need their type refined, so we can exit
        // early here
        if (!w.value.isList()) {
          break;
        }
        auto elem_type = w.type->containedType(0);
        auto lst = w.value.toList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto d = w.value.toGenericDict();
        auto keyType = w.type->containedType(0);
        auto valType = w.type->containedType(1);
        d.unsafeSetKeyType(keyType);
        d.unsafeSetValueType(valType);
        for (const auto& item : d) {
          Work kelem = {keyType, item.key()};
          Work velem = {valType, item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // in both cases the dynamic type is a class, and we are going to tag with
      // the dynamic type
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
    }
  }
}

namespace {
template <typename T>
bool is(const Type& type) {
  if (type.kind() == T::Kind) {
    return true;
  }
  if (auto dyn = type.castRaw<c10::DynamicType>()) {
    return dyn->tag() == c10::DynamicTypeTrait<T>::tagValue();
  }
  return false;
}
} // namespace

static void restoreContainerTypeTags(
    const IValue& ivalue,
    const TypePtr& type) {
  if (is<DictType>(*type)) {
    auto dict = ivalue.toGenericDict();
    dict.unsafeSetKeyType(type->containedType(0));
    dict.unsafeSetValueType(type->containedType(1));
  } else if (is<ListType>(*type)) {
    ivalue.toList().unsafeSetElementType(type->containedType(0));
  } else {
    AT_ERROR("Unknown type for tag restoration: " + type->annotation_str());
  }
}

IValue Unpickler::parse_ivalue() {
  run();
  TORCH_CHECK(
      stack_.size() == 1,
      "Unpickler expected 1 element on the stack, but found ",
      stack_.size());
  if (version_ <= 2) {
    // See [type tag serialization]
    restoreAccurateTypeTagsIfPossible(stack_[0]);
  }
  return stack_[0];
}

double Unpickler::readFloat() {
  AT_ASSERT(sizeof(double) == 8);
  double big_endian = read<double>();
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
  double little_endian = 0;

  // Pickle floats are big endian, so reverse the bytes
  auto big_endian_ptr = reinterpret_cast<const char*>(&big_endian);
  std::reverse_copy(
      big_endian_ptr,
      big_endian_ptr + sizeof(big_endian),
      reinterpret_cast<char*>(&little_endian));

  return little_endian;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
  return big_endian;
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
}
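
// For reference, a sketch of the byte-level conversion: the big-endian
// bytes 0x3F 0xF0 0x00 0x00 0x00 0x00 0x00 0x00 encode 1.0 in IEEE 754,
// so on a little-endian host readFloat() reverses them before
// interpreting the result as a double.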

void Unpickler::run() {
  // Expect a PROTO opcode and protocol number at the start of the blob
  auto opcode = readOpCode();
  TORCH_CHECK(
      opcode == PickleOpCode::PROTO,
      "Expected PROTO opcode at the start"
      " of pickle archive, found ",
      int(static_cast<uint8_t>(opcode)));
  uint8_t protocol = read<uint8_t>();
  TORCH_CHECK(
      protocol == 2,
      "Only Pickle protocol 2 is supported, found protocol = ",
      protocol);

  while (true) {
    PickleOpCode opcode = readInstruction();
    if (opcode == PickleOpCode::STOP) {
      return;
    }
  }
}
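
// A minimal sketch of the archive shape run() accepts (real archives are
// produced by the Pickler):
//
//   0x80 0x02   PROTO, protocol 2
//   ...         data opcodes, handled by readInstruction()
//   0x2e        STOP ('.')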
void Unpickler::setInput(size_t memo_id) {
  AT_ASSERT(!stack_.empty());
  if (memo_id >= memo_table_.size()) {
    memo_table_.insert(
        memo_table_.end(), memo_id - memo_table_.size(), IValue());
    memo_table_.push_back(stack_.back());
  } else {
    memo_table_[memo_id] = stack_.back();
  }
}
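
// Memoization sketch: `BINPUT 5` stores the stack top at memo_table_[5]
// (padding any holes with empty IValues), and a later `BINGET 5` pushes
// that value back onto the stack.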

// emplace_back on bool vectors does not exist on some systems, so call
// push_back for bool instead
template <typename T>
inline void append(std::vector<T>& a, T&& e) {
  a.emplace_back(std::forward<T>(e));
}
template <>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
inline void append<bool>(std::vector<bool>& a, bool&& e) {
  a.push_back(e);
}
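
// Usage sketch: `append(vec, std::move(flag))` with a std::vector<bool>
// routes to the push_back specialization above; every other element type
// goes through the generic emplace_back overload.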

static std::vector<int64_t> tupleToIntList(const IValue& v) {
  return fmap(v.toTupleRef().elements(), [](const IValue& v) -> int64_t {
    return v.toInt();
  });
}

// note we cannot use toIntList, toDoubleList because during unpickling the
// lists are not yet tagged
template <typename T>
static std::vector<T> convertList(const IValue& v) {
  return fmap(v.toListRef(), [](const IValue& elem) { return elem.to<T>(); });
}

PickleOpCode Unpickler::readInstruction() {
  auto opcode = readOpCode();
  switch (opcode) {
    case PickleOpCode::EMPTY_LIST: {
      stack_.emplace_back(c10::impl::GenericList(AnyType::get()));
    } break;
    case PickleOpCode::EMPTY_TUPLE: {
      if (empty_tuple_.isNone()) {
        // we only need one object, since tuples are not mutable.
        empty_tuple_ = c10::ivalue::Tuple::create(std::vector<IValue>());
      }
      stack_.emplace_back(empty_tuple_);
    } break;
    case PickleOpCode::BINPUT: {
      size_t memo_id = read<uint8_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::LONG_BINPUT: {
      TORCH_CHECK(
          std::numeric_limits<size_t>::max() >=
              std::numeric_limits<uint32_t>::max(),
          "Found a LONG_BINPUT opcode, but size_t on this system is "
          "not big enough to decode it");
      size_t memo_id = read<uint32_t>();
      setInput(memo_id);
    } break;
    case PickleOpCode::MARK: {
      // Mark location of the container ivalue in the stack
      marks_.push_back(stack_.size());
    } break;
    case PickleOpCode::NEWTRUE: {
      stack_.emplace_back(true);
    } break;
    case PickleOpCode::NEWFALSE: {
      stack_.emplace_back(false);
    } break;
    case PickleOpCode::NONE: {
      stack_.emplace_back();
    } break;
    case PickleOpCode::BININT1: {
      uint8_t value = read<uint8_t>();
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT2: {
      uint16_t value = from_le16(read<uint16_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::BININT: {
      int32_t value = from_le32(read<int32_t>());
      stack_.emplace_back(int64_t(value));
    } break;
    case PickleOpCode::LONG1: {
      // Only read LONG1s with 8 as the length
      uint8_t length = read<uint8_t>();
      TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
      stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
    } break;
    case PickleOpCode::BINUNICODE: {
      uint32_t length = from_le32(read<uint32_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINUNICODE8: {
      int64_t length = from_le64(read<int64_t>());
      stack_.emplace_back(readBytes(length));
    } break;
    case PickleOpCode::BINFLOAT:
      stack_.emplace_back(readFloat());
      break;
    case PickleOpCode::TUPLE: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      std::vector<IValue> elements;
      TORCH_CHECK(
          stack_.size() >= start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      const auto tupleSize = stack_.size() - start;
      switch (tupleSize) {
        case 3: {
          auto e3 = pop(stack_);
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(c10::ivalue::Tuple::create(
              std::move(e1), std::move(e2), std::move(e3)));
          break;
        }
        case 2: {
          auto e2 = pop(stack_);
          auto e1 = pop(stack_);
          stack_.emplace_back(
              c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
          break;
        }
        case 1:
          stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
          break;
        default: {
          elements.reserve(stack_.size() - start);
          auto start_it = stack_.begin() + static_cast<std::ptrdiff_t>(start);
          for (auto it = start_it; it != stack_.end(); ++it) {
            elements.emplace_back(std::move(*it));
          }
          stack_.erase(start_it, stack_.end());
          stack_.emplace_back(c10::ivalue::Tuple::create(std::move(elements)));
          break;
        }
      }
    } break;
    case PickleOpCode::TUPLE1: {
      TORCH_CHECK(
          !stack_.empty(),
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 1 expected");
      stack_.emplace_back(c10::ivalue::Tuple::create(pop(stack_)));
    } break;
    case PickleOpCode::TUPLE2: {
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(
          c10::ivalue::Tuple::create(std::move(e1), std::move(e2)));
    } break;
    case PickleOpCode::TUPLE3: {
      TORCH_CHECK(
          stack_.size() > 2,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 3 expected");
      auto e3 = pop(stack_);
      auto e2 = pop(stack_);
      auto e1 = pop(stack_);
      stack_.emplace_back(c10::ivalue::Tuple::create(
          std::move(e1), std::move(e2), std::move(e3)));
    } break;
    case PickleOpCode::EMPTY_DICT:
      stack_.emplace_back(
          c10::impl::GenericDict(AnyType::get(), AnyType::get()));
      break;
    case PickleOpCode::APPENDS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      auto list_ivalue = stack_.at(start - 1);
      readList(list_ivalue);
    } break;
    case PickleOpCode::APPEND: {
      TORCH_CHECK(
          stack_.size() >= 2, "Parsing error: missing elements in stack_.");
      auto list_ivalue = stack_.at(stack_.size() - 2);
      readListElements(list_ivalue, stack_.size() - 1);
    } break;
    case PickleOpCode::LIST: {
      IValue list_ivalue = c10::impl::GenericList(AnyType::get());
      readList(list_ivalue);
      stack_.push_back(std::move(list_ivalue));
    } break;
    case PickleOpCode::DICT: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          stack_.size() > start,
          "Parsing error: wrong start index ",
          start,
          " for stack_ of size ",
          stack_.size());
      auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get());
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elements at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
      stack_.emplace_back(std::move(dict));
    } break;
    case PickleOpCode::SETITEMS: {
      TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
      size_t start = marks_.back();
      marks_.pop_back();
      TORCH_CHECK(
          start > 0 && start <= stack_.size(),
          "Parsing error: wrong start index for stack_");
      auto dict = stack_.at(start - 1).toGenericDict();
      TORCH_CHECK(
          (stack_.size() - start) % 2 == 0,
          "Parsing error: stack_ is of size ",
          stack_.size(),
          " and start index is ",
          start,
          ", but stack_ is iterated by two elements at a time");
      for (size_t i = start; i < stack_.size(); i += 2) {
        dict.insert_or_assign(stack_[i], stack_[i + 1]);
      }
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
    } break;
    case PickleOpCode::BINGET: {
      auto pos = read<uint8_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::LONG_BINGET: {
      auto pos = read<uint32_t>();
      TORCH_CHECK(
          memo_table_.size() > pos,
          "Parsing error: out of bounds access at ",
          (size_t)pos,
          " to memo_table_ which is of size ",
          memo_table_.size());
      stack_.push_back(memo_table_.at(pos));
    } break;
    case PickleOpCode::STOP:
      break;
    case PickleOpCode::GLOBAL: {
      // Read the qualified name (module and class) and dispatch to the
      // matching handler
      auto module_name = readString();
      auto class_name = readString();
      readGlobal(module_name, class_name);
    } break;
    case PickleOpCode::NEWOBJ: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      // pop empty tuple, the actual action is stored in the globals_stack_
      stack_.pop_back();
    } break;
    // Because NEWOBJ does nothing, BUILD and REDUCE end up doing the
    // same thing
    case PickleOpCode::BUILD:
    case PickleOpCode::REDUCE: {
      // stack is: <functor_idx> <functor_arg>
      // extract <functor_idx> and remove from the stack:
      TORCH_CHECK(
          stack_.size() > 1,
          "Parsing error: stack_ contains ",
          stack_.size(),
          " elements, at least 2 expected");
      std::swap(*(stack_.end() - 2), *(stack_.end() - 1));
      size_t idx = stack_.back().toInt();
      stack_.pop_back();
      // stack is: <functor_arg>
      TORCH_CHECK(
          idx < globals_.size(),
          "Parsing error: out of bounds access to globals_");
      globals_.at(idx)();
    } break;
    case PickleOpCode::BINPERSID: {
      TORCH_CHECK(!stack_.empty(), "Parsing error: stack_ is empty");
      auto tuple = pop(stack_).toTuple();
      const auto& args = tuple->elements();
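      // The persistent id tuple for a tensor storage has the layout
      //   ('storage', ScalarType, key, location, numel),
      // which is what the fields below unpack in order.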
      AT_ASSERT(
          args.at(0).toStringRef() == "storage",
          "unknown PERSID key ",
          args.at(0).toStringRef());
      at::ScalarType type = args.at(1).toScalarType();
      const std::string& key = args.at(2).toStringRef();

      at::Device device(args.at(3).toStringRef());
      // remap device location if it's not meta
      if (device_ && !device.is_meta()) {
        device = *device_;
      }

      at::Storage storage;
      if (storage_context_ != nullptr && storage_context_->hasStorage(key)) {
        // for torch.package logic where storage may be loaded already
        storage = storage_context_->getStorage(key);
      } else {
        int64_t numel = args.at(4).toInt();
        caffe2::TypeMeta dtype = at::CPU(type).typeMeta();

        at::DataPtr storage_ptr;
        if (numel > 0) {
          // If there are no elements in the tensor, there's no point in
          // reading a zero (0) byte file from the input stream and paying
          // that cost.
          storage_ptr = read_record_(key);
        }

        storage = at::Storage(
            c10::Storage::use_byte_size_t(),
            numel * dtype.itemsize(),
            std::move(storage_ptr),
            /*allocator=*/nullptr,
            /*resizable=*/false); // NB: we didn't set any allocator for the
                                  // tensor
        if (storage_context_ != nullptr) {
          storage_context_->addStorage(key, storage);
        }
      }

      auto options = at::CPU(type).options();
      if (use_storage_device_) {
        options = options.device(storage.device());
        device = storage.device();
      }

      at::Tensor tensor;
      if (options.backend() == c10::Backend::QuantizedCPU) {
        tensor = at::_empty_affine_quantized({}, options, 0, 0)
                     .set_(storage, 0, {}, {});
      } else {
        tensor = at::empty({0}, options).set_(storage);
      }

      if (device.is_cuda() || device.is_xpu() || device.is_meta() ||
          device.is_hpu() || device.is_mps() || device.is_privateuseone()) {
        tensor = tensor.to(device, tensor.scalar_type());
      } else if (device.type() != DeviceType::CPU) {
        AT_ERROR(
            "supported devices include CPU, CUDA, XPU, HPU, MPS, Meta and ",
            c10::get_privateuse1_backend(),
            " however got ",
            DeviceTypeName(device.type(), false));
      }
      stack_.emplace_back(std::move(tensor));
    } break;
    case PickleOpCode::SETITEM: {
      // At this OpCode, stack looks like
      // | Stack Bottom |
      // | ......       |
      // | Dict         | -> (stack_size - 3)
      // | Key          | -> (stack_size - 2)
      // | Value        | -> (stack_size - 1)
      TORCH_CHECK(
          stack_.size() >= 3,
          "Parsing error: stack doesn't have enough elements");

      auto stack_size = stack_.size();
      auto dict_pos = stack_size - 3;
      auto key_pos = stack_size - 2;
      auto val_pos = stack_size - 1;

      TORCH_CHECK(
          (dict_pos < stack_size) && (key_pos < stack_size) &&
              (val_pos < stack_size),
          "Parsing error: attempted out-of-bounds access while processing SETITEM opcode");

      auto dict = stack_.at(dict_pos).toGenericDict();
      dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos));
      stack_.erase(
          stack_.begin() + static_cast<std::ptrdiff_t>(key_pos), stack_.end());
    } break;
    default: {
      AT_ERROR(
          "Unknown opcode for unpickling at ",
          // NOLINTNEXTLINE(performance-no-int-to-ptr)
          reinterpret_cast<void*>(opcode),
          ": ",
          int(static_cast<uint8_t>(opcode)));
    } break;
  }
  return opcode;
}

void Unpickler::readGlobal(
    const std::string& module_name,
    const std::string& class_name) {
  if (this->skip_next_read_global) {
    // See [NOTE] skip_next_read_global
    this->skip_next_read_global--;
    if (this->skip_next_read_global == 1) {
      // Pass through to the correct handler
    } else if (this->skip_next_read_global == 0) {
      // Corresponds to the type of `Tensor` being unpickled
      if (module_name != "torch" || class_name != "Tensor") {
        TORCH_WARN(
            "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++");
      }
      stack_.emplace_back(int64_t(globals_.size() - 1));
      return;
    } else {
      TORCH_CHECK(false, "Invalid skip_next_read_global value");
    }
  }
  // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this
  // is only here for bc-compatibility reasons
  if (module_name == "__main__") {
    if (class_name == "TensorID") {
      globals_.emplace_back([this] {
        auto setitem_data = stack_.back();
        stack_.pop_back();
        TORCH_INTERNAL_ASSERT(
            !tensor_table_.empty(),
            "Pickler tried to write a tensor but had no tensor table to write to");
        stack_.emplace_back(tensor_table_.at(setitem_data.toInt()));
      });
    } else if (class_name == "IntList") {
      globals_.emplace_back([this] {
        stack_.back().toList().unsafeSetElementType(IntType::get());
      });
    } else {
      AT_ERROR("Unknown pickler class id ", class_name);
    }
  } else if (module_name == "torch.jit._pickle") {
    if (class_name == "build_tensor_from_id") {
      globals_.emplace_back([this] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0);
        stack_.pop_back();
        TORCH_CHECK(
            !tensor_table_.empty(),
            "Found a tensor table reference but Unpickler"
            " has no tensor table\n");
        stack_.emplace_back(tensor_table_.at(data.toInt()));
      });
    } else if (class_name == "restore_type_tag") {
      globals_.emplace_back([this] {
        auto tuple = stack_.back().toTuple();
        const auto& data = tuple->elements();
        auto type_str = data.at(1).toStringRef();
        stack_.pop_back();
        TypePtr type = nullptr;
        auto entry = type_cache_.find(type_str);
        if (entry != type_cache_.end()) {
          type = entry->second;
        } else {
          if (type_resolver_ == nullptr) {
            // If we haven't injected a custom way of retrieving types from
            // names, use a barebones type parser.
            type = type_parser_(type_str);
          } else {
            type = type_resolver_(type_str).type_;
          }
          type_cache_[type_str] = type;
        }
        // TODO: Use lookahead to avoid creating the tuple and immediately
        // destroying it here
        restoreContainerTypeTags(data.at(0), type);
        stack_.emplace_back(data.at(0));
      });
    } else {
      TypePtr elem_type = nullptr;
      if (class_name == "build_intlist") {
        elem_type = IntType::get();
      } else if (class_name == "build_tensorlist") {
        elem_type = TensorType::get();
      } else if (class_name == "build_doublelist") {
        elem_type = FloatType::get();
      } else if (class_name == "build_boollist") {
        elem_type = BoolType::get();
      } else {
        AT_ERROR("Unknown pickler class id ", class_name);
      }
      // Unpickle a list specialization (e.g. List[Tensor], List[int], ...)
      globals_.emplace_back([this, elem_type] {
        // Pop reduce arg off the stack
        auto data = stack_.back().toTupleRef().elements().at(0).toList();
        stack_.pop_back();
        data.unsafeSetElementType(elem_type);
        stack_.emplace_back(std::move(data));
      });
    }
  } else if (
      module_name == "torch._utils" &&
      (class_name == "_rebuild_tensor_v2" ||
       class_name == "_rebuild_qtensor")) {
    // Unpickle a tensor
    bool quantized = class_name == "_rebuild_qtensor";
    rebuildTensor(quantized);
  } else if (
      module_name == "torch._tensor" &&
      (class_name == "_rebuild_from_type_v2")) {
    // Unpickle a Tensor with Python attributes or
    // a Subclassed Tensor.
    rebuildTensorFromTypeV2();
  } else if (
      module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") {
    rebuildSparseTensor();
  } else if (module_name == "builtins" && class_name == "complex") {
    globals_.emplace_back([this] {
      auto tuple = pop(stack_).toTuple();
      const auto& elems = tuple->elements();
      AT_ASSERT(elems.size() == 2);
      auto complex =
          c10::complex<double>(elems.at(0).toDouble(), elems.at(1).toDouble());
      stack_.emplace_back(complex);
    });

  } else if (module_name == "collections" && class_name == "OrderedDict") {
    // collections.OrderedDict is used in tensor serialization for a tensor's
    // backward hooks (but they are not actually saved with this Pickler)
    globals_.emplace_back([this] {
      // Drop the Tuple that was the argument to OrderedDict and replace it
      // with None. OrderedDicts only appear in tensor deserialization and
      // their value is never used.
      stack_.back() = IValue();
    });
  } else if (module_name == "torch" && class_name == "device") {
    globals_.emplace_back([this] {
      auto device_string = stack_.back().toTupleRef().elements().at(0);
      stack_.pop_back();
      stack_.emplace_back(c10::Device(device_string.toStringRef()));
    });
    stack_.emplace_back(int64_t(globals_.size() - 1));
    return;
  } else if (module_name == "torch.distributed.rpc" && class_name == "rref") {
#ifdef USE_RPC
    return rebuildRRef();
#else
    TORCH_INTERNAL_ASSERT(
        false,
        "RRef unpickling is only supported with the distributed package");
#endif
  } else if (module_name == "torch") {
    // Try to manually resolve several global enums
    // NOTE: this does not put a global into the global table,
    // like the other branches here, because no REDUCE or BUILD will
    // be called on this value. Instead, we just put it on the stack
    // and return early
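    // For example (a sketch of the disassembled stream): the opcode
    //   GLOBAL 'torch FloatStorage'
    // resolves here to c10::kFloat, which is pushed as an int64_t tag.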
    std::optional<c10::ScalarType> scalar_type;
#define CHECK_SCALAR(_, name)          \
  if (class_name == #name "Storage") { \
    scalar_type = c10::k##name;        \
  }
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CHECK_SCALAR)
#undef CHECK_SCALAR
    if (scalar_type.has_value()) {
      stack_.emplace_back(int64_t(*scalar_type));
      return;
    }

    std::optional<at::QScheme> qscheme;
    for (int i = 0; i < at::COMPILE_TIME_NUM_QSCHEMES; ++i) {
      if (class_name == toString(static_cast<at::QScheme>(i))) {
        qscheme = static_cast<at::QScheme>(i);
      }
    }
    if (qscheme.has_value()) {
      stack_.emplace_back(int64_t(*qscheme));
      return;
    }
    TORCH_CHECK(
        false,
        "Unpickler found unknown torch global, 'torch.",
        class_name,
        "'");
  } else {
    TORCH_CHECK(
        type_resolver_,
        "Unpickler found unknown type ",
        module_name,
        ".",
        class_name);
    at::StrongTypePtr type =
        type_resolver_(c10::QualifiedName(module_name, class_name));
    if (auto enum_type = type.type_->cast<c10::EnumType>()) {
      globals_.emplace_back([this, enum_type] {
        auto val = stack_.back();
        stack_.pop_back();
        for (const auto& p : enum_type->enumNamesValues()) {
          if (p.second == val) {
            auto enum_holder = c10::make_intrusive<at::ivalue::EnumHolder>(
                enum_type, p.first, p.second);
            stack_.emplace_back(std::move(enum_holder));
            return;
          }
        }
      });
    } else {
      // Otherwise, global is a class/object type.
      globals_.emplace_back([this, type] {
        auto val = stack_.back();
        stack_.pop_back();
        auto obj = obj_loader_(type, val);
        stack_.emplace_back(std::move(obj));
      });
    }
  }
  stack_.emplace_back(int64_t(globals_.size() - 1));
}

void Unpickler::rebuildSparseTensor() {
  globals_.emplace_back([this] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
    size_t idx = 0;
    auto layout = elements.at(idx++).toInt();
    at::Tensor result;
    switch (layout) {
      case static_cast<int>(c10::Layout::Sparse): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& indices_tensor = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::Sparse)
                           .requires_grad(requires_grad);
        result = at::_sparse_coo_tensor_unsafe(
            indices_tensor, values_tensor, size, options);
        result = autograd::make_variable(result, options.requires_grad());
        break;
      }
      case static_cast<int>(c10::Layout::SparseCsr): {
        std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
        bool requires_grad = elements.at(idx++).toBool();
        auto& crow_indices = elements.at(idx++).toTensor();
        auto& col_indices = elements.at(idx++).toTensor();
        auto& values_tensor = elements.at(idx++).toTensor();
        auto options = values_tensor.options()
                           .layout(c10::Layout::SparseCsr)
                           .requires_grad(requires_grad);
        result = at::_sparse_csr_tensor_unsafe(
            crow_indices, col_indices, values_tensor, size, options);
        result =
            autograd::make_variable(std::move(result), options.requires_grad());
        break;
      }
      default:
        TORCH_CHECK(
            false,
            "Unsupported sparse tensor layout type in serialization ",
            static_cast<c10::Layout>(layout));
        break;
    }
    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensor(bool quantized) {
  globals_.emplace_back([this, quantized] {
    auto tup = pop(stack_).toTuple();
    const auto& elements = tup->elements();
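    // A sketch of the reduce-args layout this lambda unpacks (mirrors
    // torch._utils._rebuild_tensor_v2 / _rebuild_qtensor):
    //   (storage_tensor, storage_offset, size, stride,
    //    [quantizer_params,] requires_grad, backward_hooks[, metadata])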
    size_t idx = 0;
    auto& storage_tensor = elements.at(idx++).toTensor();
    int64_t storage_offset = elements.at(idx++).toInt();
    std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
    std::vector<int64_t> stride = tupleToIntList(elements.at(idx++));
    at::Tensor result;
    if (quantized) {
      auto qparams_tuple = elements.at(idx++).toTuple();
      const auto& qparams = qparams_tuple->elements();
      auto qscheme = static_cast<at::QScheme>(qparams.at(0).toInt());
      switch (qscheme) {
        case at::kPerTensorAffine: {
          double q_scale = qparams.at(1).toDouble();
          int64_t q_zero_point = qparams.at(2).toInt();
          result = at::_empty_affine_quantized(
              {0}, storage_tensor.options(), q_scale, q_zero_point);
        } break;
        case at::kPerChannelAffineFloatQParams:
        case at::kPerChannelAffine: {
          const auto& scales = qparams.at(1).toTensor();
          const auto& zero_points = qparams.at(2).toTensor();
          int64_t axis = qparams.at(3).toInt();
          result = at::_empty_per_channel_affine_quantized(
              {0}, scales, zero_points, axis, storage_tensor.options());
        } break;
        default:
          TORCH_CHECK(
              false,
              "Unsupported tensor quantization type in serialization ",
              toString(qscheme));
          break;
      }
    } else {
      result = at::empty({0}, storage_tensor.options());
    }
    bool requires_grad = elements.at(idx++).toBool();
    idx++; // backwards hooks is empty
    at::TensorImpl* impl = result.unsafeGetTensorImpl();
    impl->set_storage_keep_dtype(storage_tensor.storage());
    impl->set_storage_offset(storage_offset);
    impl->set_sizes_and_strides(size, stride);
    result = autograd::make_variable(result, requires_grad);

    // Handle the case where math_bits were pickled. See the `args` of
    // _reduce_ex_internal for a regular tensor (the final else case).
    // Tensors pickled before this patch didn't have this argument for
    // storing MathBits, in which case we do nothing.
    // NOTE: `math_bits` is the 7th arg.
    // NOTE: This is only meant for regular tensors, not quantized ones,
    //       which also have 7 args serialized.
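    // Illustrative sketch of the expected dict (keys assumed from the
    // Python serialization side): {"conj": False, "neg": False}.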
    if (!quantized && elements.size() == 7) {
      auto math_bits = elements.at(idx++).toGenericDict();
      torch::jit::setTensorMetadata(result, math_bits);
    }

    stack_.emplace_back(std::move(result));
  });
}

void Unpickler::rebuildTensorFromTypeV2() {
  // [NOTE] skip_next_read_global
  // When rebuilding a Tensor with Python attributes or a subclassed
  // Tensor, we receive `(func, type(self), args, state)` on the stack
  // for `rebuildTensorFromTypeV2`.
  // Thus the next call to readGlobal corresponds to `func`, which is
  // the function to rebuild the base tensor.
  // The call after `func` corresponds to the `type` of the Tensor,
  // where we raise a warning if the type is not `torch.Tensor`.
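  // Illustrative GLOBAL sequence (a sketch):
  //   GLOBAL 'torch._tensor _rebuild_from_type_v2'  <- handled here
  //   GLOBAL 'torch._utils _rebuild_tensor_v2'      <- `func` (counter 2 -> 1)
  //   GLOBAL 'torch Tensor'                         <- `type` (counter 1 -> 0)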
  this->skip_next_read_global = 2;
  auto curr_globals_idx = globals_.size();
  globals_.emplace_back([this, curr_globals_idx] {
    // args is a tuple with the following data:
    //  (function to rebuild base tensor, type of tensor,
    //   arguments to construct base tensor, Python state (as dict))
    auto args = pop(stack_).toTuple();
    size_t tup_idx = 0;
    const auto args_elems = args->elements();
    auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple();
    auto py_state = args_elems.at(tup_idx + 3).toGenericDict();
    if (!py_state.empty()) {
      TORCH_WARN(
          "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded");
    }
    // This calls the function to rebuild the base tensor,
    // e.g. `rebuildTensor` or `rebuildSparseTensor`.
    stack_.emplace_back(base_tensor_args);
    globals_[curr_globals_idx + 1]();
    stack_.emplace_back(pop(stack_));
  });
}

#ifdef USE_RPC
void Unpickler::rebuildRRef() {
  globals_.emplace_back([this] {
    // It is the same as how an rref is unpickled in Python,
    // see PyRRef::unpickle
    auto tuple = std::move(stack_.back()).toTuple();
    const auto& args = tuple->elements();
    stack_.pop_back();
    TORCH_INTERNAL_ASSERT(
        args.size() == distributed::rpc::RFD_TUPLE_SIZE,
        "Pickled RRefForkData must contain 7 fields.");
    auto ownerId =
        static_cast<int16_t>(args.at(distributed::rpc::OWNER_IDX).toInt());
    // const reference will extend the lifetime of the temporary variable
    const auto& rrefId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::RREFID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::RREFID_ID_IDX).toInt()));
    const auto& forkId = distributed::rpc::RRefId(
        static_cast<int16_t>(args.at(distributed::rpc::FORKID_ON_IDX).toInt()),
        static_cast<int64_t>(args.at(distributed::rpc::FORKID_ID_IDX).toInt()));
    auto parent =
        static_cast<int16_t>(args.at(distributed::rpc::PARENT_IDX).toInt());
    const auto& typeStr = static_cast<std::string>(
        args.at(distributed::rpc::TYPE_IDX).toStringRef());
    auto rrefForkData = distributed::rpc::RRefForkData(
        ownerId, rrefId, forkId, parent, typeStr);
    auto& ctx = distributed::rpc::RRefContext::getInstance();
    c10::intrusive_ptr<distributed::rpc::RRef> rref;
    TORCH_INTERNAL_ASSERT(
        type_resolver_ != nullptr, "type_resolver_ is nullptr.");
    at::StrongTypePtr type = type_resolver_(c10::QualifiedName(typeStr));
    rref = ctx.getOrCreateRRef(rrefForkData, type.type_);
    ctx.notifyOwnerAndParentOfFork(
        rrefForkData.forkId_, rrefForkData.parent_, rref);
    stack_.emplace_back(
        c10::static_intrusive_pointer_cast<c10::RRefInterface>(rref));
  });
  stack_.emplace_back(int64_t(globals_.size() - 1));
}
#endif

void Unpickler::readSlowWithBuffer(char* dest, size_t sz) {
  // First, read any partial from buffer (may be 0).
  // We explicitly assume that sz > buffer_remaining_,
  // and that sz is never bigger than buffer_.size().
  AT_ASSERT(sz > buffer_remaining_);
  const size_t from_old_buf = buffer_remaining_;
  if (from_old_buf != 0) {
    memcpy(dest, buffer_.data() + buffer_pos_, from_old_buf);
  }
  const size_t needed = sz - from_old_buf;
  // Full read into the buffer. The calls here all explicitly
  // assume that one buffer will be enough for any sz.
  AT_ASSERT(sz <= buffer_.size());
  buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
  if (buffer_remaining_ < needed) {
    AT_ERROR("Unexpected end of pickler archive.");
  }
  memcpy(dest + from_old_buf, buffer_.data(), needed);
  buffer_pos_ = needed; // assignment (0'ed from read)
  buffer_remaining_ -= needed;
}
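
// Worked example (a sketch): with a 4-byte tail left in buffer_ and a
// request for sz == 10, the 4 buffered bytes are copied out first, the
// buffer is refilled from reader_, and the remaining 6 bytes come from
// the front of the fresh buffer.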

// Read a number of bytes from the input stream
std::string Unpickler::readBytes(size_t length) {
  std::string data;
  static const size_t kSmallString = 64;
  TORCH_CHECK(
      length <= data.max_size(),
      "Parsing error: can't read ",
      length,
      " bytes to a string");
  if (length <= buffer_remaining_) {
    // Fast-path: entirely in buffer.
    data.assign(buffer_.data() + buffer_pos_, length);
    buffer_pos_ += length;
    buffer_remaining_ -= length;
  } else if (length <= kSmallString) {
    // If the string is smallish, do a full buffer read,
    // and read out of that buffer.
    data.resize(length);
    readSlowWithBuffer(&data[0], length);
  } else {
    // Otherwise, for larger strings, read what we can from
    // the buffer, and then read directly to the destination.
    const size_t from_old_buf = buffer_remaining_;
    if (from_old_buf != 0) {
      data.reserve(length);
      data.append(buffer_.data() + buffer_pos_, from_old_buf);
    }
    data.resize(length);
    const size_t needed = length - from_old_buf;
    size_t nread = reader_(&data[from_old_buf], needed);
    if (nread != needed) {
      AT_ERROR("Unexpected end of pickler archive.");
    }
    buffer_remaining_ = 0;
    // buffer_pos_ has no meaning with buffer_remaining_ == 0.
  }
  return data;
}
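
// Worked example (a sketch): for length == 4096 with 10 bytes buffered,
// the buffered 10 bytes are appended first and the remaining 4086 are
// read straight into the destination string, bypassing buffer_.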

void Unpickler::readListElements(IValue list_ivalue, size_t start) {
  auto num_elements = stack_.size() - start;
  auto elements = c10::ArrayRef<IValue>(stack_).slice(start);
  if (list_ivalue.isIntList()) {
    auto list = std::move(list_ivalue).toIntList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toInt());
    }
  } else if (list_ivalue.isTensorList()) {
    auto list = std::move(list_ivalue).toTensorList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toTensor());
    }
  } else if (list_ivalue.isDoubleList()) {
    auto list = std::move(list_ivalue).toDoubleList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem.toDouble());
    }
  } else if (list_ivalue.isBoolList()) {
    auto list = std::move(list_ivalue).toBoolList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.push_back(elem.toBool());
    }
  } else if (list_ivalue.isList()) {
    auto list = std::move(list_ivalue).toList();
    list.reserve(num_elements);
    for (const auto& elem : elements) {
      list.emplace_back(elem);
    }
  } else {
    AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind());
  }
  stack_.erase(
      stack_.begin() + static_cast<std::ptrdiff_t>(start), stack_.end());
}

// Pop all the list items off of the stack and append them to the list at
// the corresponding MARK
void Unpickler::readList(IValue list_ivalue) {
  TORCH_CHECK(!marks_.empty(), "Parsing error: marks_ is empty");
  size_t start = marks_.back();
  marks_.pop_back();
  readListElements(std::move(list_ivalue), start);
}

inline bool is_valid_python_id_char(char c) {
  return c == '_' || c == '.' || (c >= '0' && c <= '9') ||
      (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
}

// Read a newline-terminated string
std::string Unpickler::readString() {
  std::string ss;
  while (true) {
    auto* const bufferStart = buffer_.data() + buffer_pos_;
    const auto bufferLeft = buffer_.size() - buffer_pos_;
    char* const newlinePtr =
        static_cast<char*>(memchr(bufferStart, '\n', bufferLeft));
    if (newlinePtr) {
      // read up to the newline and we are done.
      auto const charsRead = newlinePtr - bufferStart;
      ss.append(bufferStart, charsRead);
      buffer_remaining_ -= charsRead + 1;
      buffer_pos_ += charsRead + 1;
      break;
    } else {
      // read the whole buffer, then refill
      for (const char* p = bufferStart; p < bufferStart + bufferLeft; ++p) {
        // Simple check just in case there is no terminating '\n'
        TORCH_CHECK(
            is_valid_python_id_char(*p),
            "Found character '",
            int(uint8_t(*p)),
            "' in string, ",
            "strings must be qualified Python identifiers");
      }
      ss.append(bufferStart, bufferLeft);
      buffer_remaining_ = reader_(buffer_.data(), buffer_.size());
      buffer_pos_ = 0;
    }
  }
  return ss;
}

} // namespace torch::jit