xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/serialization/pickle.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ivalue.h>
4 #include <c10/util/ArrayRef.h>
5 #include <caffe2/serialize/inline_container.h>
6 #include <torch/csrc/Export.h>
7 #include <torch/csrc/jit/serialization/pickler.h>
8 #include <torch/csrc/jit/serialization/unpickler.h>
9 
10 namespace torch::jit {
11 
12 /// Pickle an IValue by calling a function to handle writing the data.
13 ///
14 /// `writer` is a function that takes in a pointer to a chunk of memory and its
15 /// size and consumes it.
16 ///
17 /// See `jit::pickle` for more details.
18 TORCH_API void pickle(
19     std::function<void(const char* data_start, size_t data_len)> writer,
20     const IValue& ivalue,
21     std::vector<at::Tensor>* tensor_table = nullptr);
22 
23 /// Save a `torch::IValue` in a format compatible with Python's `pickle` module
24 ///
25 /// If present, `tensor_table` is a pointer to a table in which tensors that
26 /// are contained within `ivalue` are stored, and the bytes returned by the
27 /// pickler will only include references to these tensors in the table. This can
28 /// be used to keep the binary blob size small.
29 /// If not provided, tensors are stored in the same byte stream as the pickle
30 /// data, similar to `torch.save()` in eager Python.
31 ///
32 /// Pickled values can be loaded in Python and C++:
33 /// \rst
34 /// .. code-block:: cpp
35 ///
36 ///  torch::IValue float_value(2.3);
37 ///
38 ///  // TODO: when tensors are stored in the pickle, delete this
39 ///  std::vector<at::Tensor> tensor_table;
40 ///  auto data = torch::jit::pickle(float_value, &tensor_table);
41 ///
42 ///  std::vector<torch::IValue> ivalues =
43 ///      torch::jit::unpickle(data.data(), data.size());
44 ///
45 /// .. code-block:: python
46 ///
47 ///   values = torch.load('data.pkl')
48 ///   print(values)
49 ///
50 /// \endrst
51 TORCH_API std::vector<char> pickle(
52     const IValue& ivalue,
53     std::vector<at::Tensor>* tensor_table = nullptr);
54 
55 /// Save a `torch::IValue` in a format that can be loaded by both
56 /// `torch::pickle_load` in C++ and `torch.load` in Python.
57 TORCH_API std::vector<char> pickle_save(const IValue& ivalue);
58 
59 /// Deserialize a `torch::IValue` from bytes produced by either
60 /// `torch::pickle_save` in C++ or `torch.save` in Python
61 TORCH_API IValue pickle_load(const std::vector<char>& data);
62 
63 /// Deserialize a `torch::IValue` from bytes produced by either
64 /// `torch::pickle_save` in C++ or `torch.save` in Python with custom object.
65 TORCH_API IValue pickle_load_obj(std::string_view data);
66 
67 /// `reader` is a function that takes in a size to read from some pickled
68 /// binary. `reader` should remember where it last read, and return
69 /// the number of bytes read.
70 /// See `torch::pickle` for details.
71 /// type_resolver is used to resolve any JIT type based on type str
72 TORCH_API IValue unpickle(
73     std::function<size_t(char*, size_t)> reader,
74     TypeResolver type_resolver,
75     c10::ArrayRef<at::Tensor> tensor_table,
76     c10::TypePtr (*type_parser)(const std::string&) =
77         Unpickler::defaultTypeParser,
78     ObjLoader obj_loader = nullptr);
79 
80 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
81 ///
82 /// If any `torch::IValue`s in the pickled data are `Object`s, then a
83 /// `class_resolver` function must be provided.
84 ///
85 /// See `torch::pickle` for details.
86 TORCH_API IValue unpickle(
87     const char* data,
88     size_t size,
89     TypeResolver type_resolver = nullptr,
90     c10::ArrayRef<at::Tensor> tensor_table = {},
91     c10::TypePtr (*type_parser)(const std::string&) =
92         Unpickler::defaultTypeParser);
93 
94 /// Decode a chunk of memory containing pickled data into its `torch::IValue`s.
95 ///
96 /// If any `torch::IValue`s in the pickled data are `Object`s, then a
97 /// `class_resolver` function must be provided.
98 ///
99 /// See `torch::pickle` for details.
100 TORCH_API IValue unpickle(
101     const char* data,
102     size_t size,
103     ObjLoader obj_loader,
104     TypeResolver type_resolver = nullptr,
105     c10::ArrayRef<at::Tensor> tensor_table = {},
106     c10::TypePtr (*type_parser)(const std::string&) =
107         Unpickler::defaultTypeParser);
108 
109 #ifndef C10_MOBILE
110 class VectorReader : public caffe2::serialize::ReadAdapterInterface {
111  public:
VectorReader(std::vector<char> data)112   VectorReader(std::vector<char> data) : data_(std::move(data)) {}
113 
size()114   size_t size() const override {
115     return data_.size();
116   }
117 
118   size_t read(uint64_t pos, void* buf, size_t n, const char* what)
119       const override;
120 
121  private:
122   std::vector<char> data_;
123 };
124 
125 class StringViewReader : public caffe2::serialize::ReadAdapterInterface {
126  public:
StringViewReader(std::string_view data)127   StringViewReader(std::string_view data) : data_(data) {}
128 
size()129   size_t size() const override {
130     return data_.size();
131   }
132 
133   size_t read(uint64_t pos, void* buf, size_t n, const char* what)
134       const override;
135 
136  private:
137   std::string_view data_;
138 };
139 #endif
140 } // namespace torch::jit
141