// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/flatbuffer_loader.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <istream>
4 #include <memory>
5 #include <string>
6 #include <unordered_map>
7 #include <vector>
8 
9 #include <ATen/core/ivalue.h>
10 #include <c10/core/Device.h>
11 #include <c10/macros/Macros.h>
12 #include <torch/csrc/jit/mobile/module.h>
13 #include <optional>
14 
15 /**
16  * Defines the public API for loading flatbuffer-serialized mobile modules.
17  * Note that this header must not include or depend on flatbuffer-defined
18  * types, to avoid leaking those details to PyTorch clients.
19  */
20 
21 namespace torch::jit {
22 
/// All non-copied data pointers provided to `parse_and_initialize_*` functions
/// must be aligned to this boundary. Since the Module will point directly into
/// the data, this alignment is necessary to ensure that certain types/structs
/// are properly aligned.
///
/// `inline` guarantees one shared definition across all translation units that
/// include this header, instead of a separate internal-linkage copy per TU.
inline constexpr size_t kFlatbufferDataAlignmentBytes = 16;
28 
/// Maps extra-file names to their (string) contents. Used both as an output
/// parameter (populated while loading) and to request specific extra files.
using ExtraFilesMap = std::unordered_map<std::string, std::string>;
31 
// At a high level, to produce a Module from a file on disk, we need to go
// through the following steps:
// 1. Read: Read the file from disk -> memory
// 2. Deserialize: Parse the bytes to produce some in-memory, manipulable
//    structure
// 3. Module initialization: Produce mobile::Module out of the structure
//    produced in 2.
// In this context, the structure described in 2. is the flatbuffer-defined
// type mobile::serialization::Module. However, this step/type is not visible in
// the public API.
42 
// Parse a mobile::Module from raw bytes.
//
// This function does steps 2+3 described above.
//
// Does not take ownership of `data`; if you want it to take ownership, see the
// shared_ptr overload of this function.
//
// If should_copy_tensor_memory is true, then the returned module will NOT have
// references to `data`, so `data` can be freed immediately.
//
// If should_copy_tensor_memory is false, then the returned module will have
// tensors that point inside of `data`; the caller will need to make sure that
// `data` outlives the returned Module. Also, `data` must be aligned to
// kFlatbufferDataAlignmentBytes.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t size, // of `data`, in bytes.
    std::optional<at::Device> device = std::nullopt,
    ExtraFilesMap* extra_files = nullptr,
    bool should_copy_tensor_memory = false);
63 
// Parse a mobile::Module from raw bytes.
//
// This function does steps 2+3 described above.
//
// The returned Module holds a reference to `data` (via the shared_ptr), which
// must be aligned to kFlatbufferDataAlignmentBytes.
//
// If you do not want the Module to hold a reference to `data`, see the raw
// pointer overload of this function.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
    std::shared_ptr<char> data,
    size_t size, // of `data`, in bytes.
    std::optional<at::Device> device = std::nullopt,
    ExtraFilesMap* extra_files = nullptr);
78 
// Parse a mobile::Module from raw bytes, also returning JIT-related metadata.
//
// This is the same as parse_and_initialize_mobile_module() except that it also
// extracts JIT source files and constants, written into the `jit_sources` and
// `jit_constants` output parameters. Can be used to construct a jit::Module.
TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit(
    void* data,
    size_t size, // of `data`, in bytes.
    ExtraFilesMap& jit_sources,
    std::vector<IValue>& jit_constants,
    std::optional<at::Device> device = std::nullopt,
    ExtraFilesMap* extra_files = nullptr);
91 
// Load a mobile::Module from a filepath.
//
// This function does steps 1+2+3 described above.
//
// We need to have this as a convenience because the Python API will need to
// wrap this. C++ clients should use one of the versions of
// parse_and_initialize_mobile_module() so they can manage the raw data more
// directly.
TORCH_API mobile::Module load_mobile_module_from_file(
    const std::string& filename,
    std::optional<at::Device> device = std::nullopt,
    ExtraFilesMap* extra_files = nullptr);
104 
// Return the bytecode version of a serialized module, reading from a stream, a
// file path, or an in-memory flatbuffer buffer respectively.
// NOTE(review): `flatbuffer_content` is presumably subject to the same
// kFlatbufferDataAlignmentBytes requirement as the parse_* functions —
// confirm against the implementation.
TORCH_API uint64_t get_bytecode_version(std::istream& in);
TORCH_API uint64_t get_bytecode_version(const std::string& filename);
TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content);
108 
// Extract module metadata (mobile::ModuleInfo) from an in-memory flatbuffer
// buffer, without producing a full mobile::Module.
TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer(
    char* flatbuffer_content);
111 
// The methods below are less efficient because they need to read the stream in
// its entirety to a buffer before parsing.
TORCH_API mobile::Module load_mobile_module_from_stream_with_copy(
    std::istream& in,
    std::optional<at::Device> device = std::nullopt,
    ExtraFilesMap* extra_files = nullptr);
118 
// Parse a mobile::Module from `data`, skipping object restoration (per the
// function name; see the implementation for the exact semantics).
// NOTE(review): unlike the other parse_* overloads, `device` has no default
// argument here — confirm whether that asymmetry is intentional.
TORCH_API mobile::Module parse_flatbuffer_no_object(
    std::shared_ptr<char> data,
    size_t size,
    std::optional<at::Device> device);
123 
// NOTE(review): this re-declares the raw-pointer overload declared near the
// top of this file (default arguments may not be repeated on a
// re-declaration, hence none appear here). It looks redundant; consider
// removing it in a follow-up.
TORCH_API mobile::Module parse_and_initialize_mobile_module(
    void* data,
    size_t,
    std::optional<at::Device>,
    ExtraFilesMap* extra_files,
    bool should_copy_tensor_memory);
130 
// No-op; presumably retained so existing call sites keep compiling/linking.
// TODO(qihan): delete.
TORCH_API bool register_flatbuffer_loader();
133 
134 } // namespace torch::jit
135