1 #pragma once 2 3 #include <istream> 4 #include <memory> 5 #include <string> 6 #include <unordered_map> 7 #include <vector> 8 9 #include <ATen/core/ivalue.h> 10 #include <c10/core/Device.h> 11 #include <c10/macros/Macros.h> 12 #include <torch/csrc/jit/mobile/module.h> 13 #include <optional> 14 15 /** 16 * Defines the public API for loading flatbuffer-serialized mobile modules. 17 * Note that this header must not include or depend on flatbuffer-defined 18 * types, to avoid leaking those details to PyTorch clients. 19 */ 20 21 namespace torch::jit { 22 23 /// All non-copied data pointers provided to `parse_and_initialize_*` functions 24 /// must be aligned to this boundary. Since the Module will point directly into 25 /// the data, this alignment is necessary to ensure that certain types/structs 26 /// are properly aligned. 27 constexpr size_t kFlatbufferDataAlignmentBytes = 16; 28 29 /// Maps file names to file contents. 30 using ExtraFilesMap = std::unordered_map<std::string, std::string>; 31 32 // On high level, to produce a Module from a file on disk, we need to go 33 // through the follow steps: 34 // 1. Read: Read the file from disk -> memory 35 // 2. Deserialize: Parse the bytes to produce some in memory manipulable 36 // structure 37 // 3. Module initialization: Produce mobile::Module out of the structure 38 // produced in 2. 39 // Under this context, the structure described in 2. is the flatbuffer-defined 40 // type mobile::serialization::Module. However, this step/type is not visible in 41 // the public API. 42 43 // Parse a mobile::Module from raw bytes. 44 // 45 // This function does steps 2+3 described above. 46 // 47 // Does not take ownership of `data`; if you want it to take ownership, see the 48 // shared_ptr overload of this function. 49 // 50 // If should_copy_tensor_memory is true, then the returned module will NOT have 51 // refences to `data`, so `data` can be freed immediately. 52 // 53 // If should_copy_tensor_memory is false, then returned module will have tensors 54 // that points inside of `data`; the caller will need to make sure that `data` 55 // outlives the returned Module. Also, `data` must be aligned to 56 // kFlatbufferDataAlignmentBytes. 57 TORCH_API mobile::Module parse_and_initialize_mobile_module( 58 void* data, 59 size_t size, // of `data`, in bytes. 60 std::optional<at::Device> device = std::nullopt, 61 ExtraFilesMap* extra_files = nullptr, 62 bool should_copy_tensor_memory = false); 63 64 // Parse a mobile::Module from raw bytes. 65 // 66 // This function does steps 2+3 described above. 67 // 68 // The returned Module holds a reference to `data`, which must be aligned to 69 // kFlatbufferDataAlignmentBytes. 70 // 71 // If you do not want the Module to hold a reference to `data`, see the raw 72 // pointer overload of this function. 73 TORCH_API mobile::Module parse_and_initialize_mobile_module( 74 std::shared_ptr<char> data, 75 size_t size, // of `data`, in bytes. 76 std::optional<at::Device> device = std::nullopt, 77 ExtraFilesMap* extra_files = nullptr); 78 79 // Parse a mobile::Module from raw bytes, also returning JIT-related metadata. 80 // 81 // This is the same as parse_and_initialize_mobile_module() except that it also 82 // extracts JIT source files and constants. Can be used to construct a 83 // jit::Module. 84 TORCH_API mobile::Module parse_and_initialize_mobile_module_for_jit( 85 void* data, 86 size_t size, // of `data`, in bytes. 87 ExtraFilesMap& jit_sources, 88 std::vector<IValue>& jit_constants, 89 std::optional<at::Device> device = std::nullopt, 90 ExtraFilesMap* extra_files = nullptr); 91 92 // Load a mobile::Module from a filepath. 93 // 94 // This function does steps 1+2+3 described above. 95 // 96 // We need to have this as a convienience because Python API will need to wrap 97 // this. C++ clients should use one of the versions of 98 // parse_and_initialize_mobile_module() so they can manage the raw data more 99 // directly. 100 TORCH_API mobile::Module load_mobile_module_from_file( 101 const std::string& filename, 102 std::optional<at::Device> device = std::nullopt, 103 ExtraFilesMap* extra_files = nullptr); 104 105 TORCH_API uint64_t get_bytecode_version(std::istream& in); 106 TORCH_API uint64_t get_bytecode_version(const std::string& filename); 107 TORCH_API uint64_t get_bytecode_version_from_bytes(char* flatbuffer_content); 108 109 TORCH_API mobile::ModuleInfo get_module_info_from_flatbuffer( 110 char* flatbuffer_content); 111 112 // The methods below are less efficient because it need to read the stream in 113 // its entirity to a buffer 114 TORCH_API mobile::Module load_mobile_module_from_stream_with_copy( 115 std::istream& in, 116 std::optional<at::Device> device = std::nullopt, 117 ExtraFilesMap* extra_files = nullptr); 118 119 TORCH_API mobile::Module parse_flatbuffer_no_object( 120 std::shared_ptr<char> data, 121 size_t size, 122 std::optional<at::Device> device); 123 124 TORCH_API mobile::Module parse_and_initialize_mobile_module( 125 void* data, 126 size_t, 127 std::optional<at::Device>, 128 ExtraFilesMap* extra_files, 129 bool should_copy_tensor_memory); 130 131 // no op, TODO(qihan) delete 132 TORCH_API bool register_flatbuffer_loader(); 133 134 } // namespace torch::jit 135