1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/api/module.h> 5 6 #include <iosfwd> 7 #include <memory> 8 #include <string> 9 #include <utility> 10 11 namespace at { 12 class Tensor; 13 } // namespace at 14 15 namespace torch { 16 using at::Tensor; 17 namespace jit { 18 struct Module; 19 } // namespace jit 20 } // namespace torch 21 22 namespace torch { 23 namespace serialize { 24 class TORCH_API OutputArchive final { 25 public: 26 explicit OutputArchive(std::shared_ptr<jit::CompilationUnit> cu); OutputArchive()27 explicit OutputArchive() 28 : cu_(std::make_shared<jit::CompilationUnit>()), 29 module_("__torch__.Module", cu_) {} 30 31 // Move is allowed. 32 OutputArchive(OutputArchive&&) = default; 33 OutputArchive& operator=(OutputArchive&&) = default; 34 35 // Copy is disallowed. 36 OutputArchive(OutputArchive&) = delete; 37 OutputArchive& operator=(OutputArchive&) = delete; 38 compilation_unit()39 std::shared_ptr<jit::CompilationUnit> compilation_unit() const { 40 return cu_; 41 } 42 43 /// Writes an `IValue` to the `OutputArchive`. 44 void write(const std::string& key, const c10::IValue& ivalue); 45 46 /// Writes a `(key, tensor)` pair to the `OutputArchive`, and marks it as 47 /// being or not being a buffer (non-differentiable tensor). 48 void write( 49 const std::string& key, 50 const Tensor& tensor, 51 bool is_buffer = false); 52 53 /// Writes a nested `OutputArchive` under the given `key` to this 54 /// `OutputArchive`. 55 void write(const std::string& key, OutputArchive& nested_archive); 56 57 /// Saves the `OutputArchive` into a serialized representation in a file at 58 /// `filename`. 59 void save_to(const std::string& filename); 60 61 /// Saves the `OutputArchive` into a serialized representation into the given 62 /// `stream`. 63 void save_to(std::ostream& stream); 64 65 /// Saves the `OutputArchive` into a serialized representation using the 66 /// given writer function. 67 void save_to(const std::function<size_t(const void*, size_t)>& func); 68 69 /// Forwards all arguments to `write()`. 70 /// Useful for generic code that can be re-used for both `OutputArchive` and 71 /// `InputArchive` (where `operator()` forwards to `read()`). 72 template <typename... Ts> operator()73 void operator()(Ts&&... ts) { 74 write(std::forward<Ts>(ts)...); 75 } 76 77 private: 78 std::shared_ptr<jit::CompilationUnit> cu_; 79 jit::Module module_; 80 }; 81 } // namespace serialize 82 } // namespace torch 83