xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/import_read.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <caffe2/serialize/inline_container.h>
2 #include <torch/csrc/jit/serialization/import_read.h>
3 
4 #include <utility>
5 
6 namespace torch::jit {
7 
readArchiveAndTensors(const std::string & archive_name,const std::string & pickle_prefix,const std::string & tensor_prefix,std::optional<TypeResolver> type_resolver,std::optional<ObjLoader> obj_loader,std::optional<at::Device> device,caffe2::serialize::PyTorchStreamReader & stream_reader,c10::TypePtr (* type_parser)(const std::string &),std::shared_ptr<DeserializationStorageContext> storage_context)8 IValue readArchiveAndTensors(
9     const std::string& archive_name,
10     const std::string& pickle_prefix,
11     const std::string& tensor_prefix,
12     std::optional<TypeResolver> type_resolver,
13     std::optional<ObjLoader> obj_loader,
14     std::optional<at::Device> device,
15     caffe2::serialize::PyTorchStreamReader& stream_reader,
16     c10::TypePtr (*type_parser)(const std::string&),
17     std::shared_ptr<DeserializationStorageContext> storage_context) {
18   std::string picklename = pickle_prefix + archive_name + ".pkl";
19   at::DataPtr pickle_ptr;
20   size_t pickle_size = 0;
21   std::tie(pickle_ptr, pickle_size) = stream_reader.getRecord(picklename);
22 
23   size_t bytes_read = 0;
24   auto data = reinterpret_cast<const char*>(pickle_ptr.get());
25   auto reader = [&](char* buffer, size_t len) -> size_t {
26     if (bytes_read >= pickle_size) {
27       return 0;
28     }
29     len = std::min(pickle_size - bytes_read, len);
30     // Copy len bytes into buffer
31     const char* start = data + bytes_read;
32     std::memcpy(buffer, start, len);
33     bytes_read += len;
34     return len;
35   };
36 
37   std::string tensor_dir_path =
38       (!tensor_prefix.empty()) ? tensor_prefix : archive_name + "/";
39 
40   auto read_record = [&](const std::string& name) {
41     std::string ss = tensor_dir_path + name;
42     return std::get<0>(stream_reader.getRecord(ss));
43   };
44 
45   Unpickler unpickler(
46       reader,
47       type_resolver ? std::move(*type_resolver) : nullptr,
48       obj_loader ? std::move(*obj_loader) : nullptr,
49       std::move(read_record),
50       device,
51       false,
52       type_parser,
53       std::move(storage_context));
54   unpickler.set_version(stream_reader.version());
55   return unpickler.parse_ivalue();
56 }
57 
check_zip_file(const std::shared_ptr<caffe2::serialize::ReadAdapterInterface> & rai)58 bool check_zip_file(
59     const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai) {
60   std::array<uint8_t, 2> first_short{};
61   static constexpr uint8_t first_slot = 0x80;
62   static constexpr uint8_t second_slot = 0x02;
63   rai->read(
64       /*pos=*/0,
65       /*buf=*/&first_short,
66       /*n=*/2,
67       /*what=*/"checking archive");
68 
69   // NB: zip files by spec can start with any data, so technically they might
70   // start with 0x80 0x02, but in practice zip files start with a file entry
71   // which begins with 0x04034b50. Furthermore, PyTorch will never produce zip
72   // files that do not start with the file entry, so it is relatively safe to
73   // perform this check.
74   return !(first_short[0] == first_slot && first_short[1] == second_slot);
75 }
76 
77 } // namespace torch::jit
78