xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/serialize/output-archive.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/api/module.h>
5 
6 #include <iosfwd>
7 #include <memory>
8 #include <string>
9 #include <utility>
10 
// Forward declarations for types named in the interface below.
namespace at {
class Tensor;
} // namespace at

namespace torch {
namespace jit {
// `jit::Module` is used as a by-value member further down, so its complete
// definition must come from <torch/csrc/jit/api/module.h>; this forward
// declaration is redundant but harmless.
struct Module;
} // namespace jit

// Re-export `at::Tensor` into `torch` so that `Tensor` can be spelled
// unqualified inside `torch::serialize`.
using at::Tensor;
} // namespace torch
21 
22 namespace torch {
23 namespace serialize {
24 class TORCH_API OutputArchive final {
25  public:
26   explicit OutputArchive(std::shared_ptr<jit::CompilationUnit> cu);
OutputArchive()27   explicit OutputArchive()
28       : cu_(std::make_shared<jit::CompilationUnit>()),
29         module_("__torch__.Module", cu_) {}
30 
31   // Move is allowed.
32   OutputArchive(OutputArchive&&) = default;
33   OutputArchive& operator=(OutputArchive&&) = default;
34 
35   // Copy is disallowed.
36   OutputArchive(OutputArchive&) = delete;
37   OutputArchive& operator=(OutputArchive&) = delete;
38 
compilation_unit()39   std::shared_ptr<jit::CompilationUnit> compilation_unit() const {
40     return cu_;
41   }
42 
43   /// Writes an `IValue` to the `OutputArchive`.
44   void write(const std::string& key, const c10::IValue& ivalue);
45 
46   /// Writes a `(key, tensor)` pair to the `OutputArchive`, and marks it as
47   /// being or not being a buffer (non-differentiable tensor).
48   void write(
49       const std::string& key,
50       const Tensor& tensor,
51       bool is_buffer = false);
52 
53   /// Writes a nested `OutputArchive` under the given `key` to this
54   /// `OutputArchive`.
55   void write(const std::string& key, OutputArchive& nested_archive);
56 
57   /// Saves the `OutputArchive` into a serialized representation in a file at
58   /// `filename`.
59   void save_to(const std::string& filename);
60 
61   /// Saves the `OutputArchive` into a serialized representation into the given
62   /// `stream`.
63   void save_to(std::ostream& stream);
64 
65   /// Saves the `OutputArchive` into a serialized representation using the
66   /// given writer function.
67   void save_to(const std::function<size_t(const void*, size_t)>& func);
68 
69   /// Forwards all arguments to `write()`.
70   /// Useful for generic code that can be re-used for both `OutputArchive` and
71   /// `InputArchive` (where `operator()` forwards to `read()`).
72   template <typename... Ts>
operator()73   void operator()(Ts&&... ts) {
74     write(std::forward<Ts>(ts)...);
75   }
76 
77  private:
78   std::shared_ptr<jit::CompilationUnit> cu_;
79   jit::Module module_;
80 };
81 } // namespace serialize
82 } // namespace torch
83