1 #include <caffe2/serialize/inline_container.h>
2 #include <torch/csrc/jit/serialization/import_read.h>
3
4 #include <utility>
5
6 namespace torch::jit {
7
readArchiveAndTensors(const std::string & archive_name,const std::string & pickle_prefix,const std::string & tensor_prefix,std::optional<TypeResolver> type_resolver,std::optional<ObjLoader> obj_loader,std::optional<at::Device> device,caffe2::serialize::PyTorchStreamReader & stream_reader,c10::TypePtr (* type_parser)(const std::string &),std::shared_ptr<DeserializationStorageContext> storage_context)8 IValue readArchiveAndTensors(
9 const std::string& archive_name,
10 const std::string& pickle_prefix,
11 const std::string& tensor_prefix,
12 std::optional<TypeResolver> type_resolver,
13 std::optional<ObjLoader> obj_loader,
14 std::optional<at::Device> device,
15 caffe2::serialize::PyTorchStreamReader& stream_reader,
16 c10::TypePtr (*type_parser)(const std::string&),
17 std::shared_ptr<DeserializationStorageContext> storage_context) {
18 std::string picklename = pickle_prefix + archive_name + ".pkl";
19 at::DataPtr pickle_ptr;
20 size_t pickle_size = 0;
21 std::tie(pickle_ptr, pickle_size) = stream_reader.getRecord(picklename);
22
23 size_t bytes_read = 0;
24 auto data = reinterpret_cast<const char*>(pickle_ptr.get());
25 auto reader = [&](char* buffer, size_t len) -> size_t {
26 if (bytes_read >= pickle_size) {
27 return 0;
28 }
29 len = std::min(pickle_size - bytes_read, len);
30 // Copy len bytes into buffer
31 const char* start = data + bytes_read;
32 std::memcpy(buffer, start, len);
33 bytes_read += len;
34 return len;
35 };
36
37 std::string tensor_dir_path =
38 (!tensor_prefix.empty()) ? tensor_prefix : archive_name + "/";
39
40 auto read_record = [&](const std::string& name) {
41 std::string ss = tensor_dir_path + name;
42 return std::get<0>(stream_reader.getRecord(ss));
43 };
44
45 Unpickler unpickler(
46 reader,
47 type_resolver ? std::move(*type_resolver) : nullptr,
48 obj_loader ? std::move(*obj_loader) : nullptr,
49 std::move(read_record),
50 device,
51 false,
52 type_parser,
53 std::move(storage_context));
54 unpickler.set_version(stream_reader.version());
55 return unpickler.parse_ivalue();
56 }
57
check_zip_file(const std::shared_ptr<caffe2::serialize::ReadAdapterInterface> & rai)58 bool check_zip_file(
59 const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai) {
60 std::array<uint8_t, 2> first_short{};
61 static constexpr uint8_t first_slot = 0x80;
62 static constexpr uint8_t second_slot = 0x02;
63 rai->read(
64 /*pos=*/0,
65 /*buf=*/&first_short,
66 /*n=*/2,
67 /*what=*/"checking archive");
68
69 // NB: zip files by spec can start with any data, so technically they might
70 // start with 0x80 0x02, but in practice zip files start with a file entry
71 // which begins with 0x04034b50. Furthermore, PyTorch will never produce zip
72 // files that do not start with the file entry, so it is relatively safe to
73 // perform this check.
74 return !(first_short[0] == first_slot && first_short[1] == second_slot);
75 }
76
77 } // namespace torch::jit
78