xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/compatibility/backport_manager.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <functional>
#include <istream>
#include <memory>
#include <sstream>
#include <unordered_map>
6 
namespace c10 {
// Forward declaration only; nothing in this header needs the full definition.
struct IValue;
} // namespace c10

namespace caffe2::serialize {
// Forward declaration only; BackportManager::backport() takes the writer by
// reference, so the complete type is not required here.
class PyTorchStreamWriter;
} // namespace caffe2::serialize

namespace torch::jit {

/*
BackportManager manages a list of bytecode backport functions, each of which
backports from version n to version n-1, and provides a function to check
whether a backport function exists for a specific version.
*/
class BackportManager final {
 public:
  // Signature shared by every bytecode backport function: it receives a
  // stringstream by reference and returns a new stringstream by value.
  // NOTE(review): presumably input/output hold a serialized model at version
  // n / n-1 respectively -- confirm against the definitions in the .cpp file.
  using BytecodeBackportFunction =
      std::function<std::stringstream(std::stringstream&)>;

  // Returns whether a bytecode backport function is available for
  // `from_version`.
  bool hasBytecodeBackportFunction(int64_t from_version) const;

  // Returns a mutable reference to the map from version number to backport
  // function. The reference is non-const even though the member function is
  // const, so the storage is not held directly by `*this`.
  // NOTE(review): likely a function-local static in the .cpp -- confirm.
  std::unordered_map<int64_t, BytecodeBackportFunction>&
  bytecodeBackportFunctions() const;

  // Backports the content read from `oss` (an input stream, despite the name)
  // from `from_version` to `to_version`, emitting through `final_writer`.
  // NOTE(review): the bool result presumably signals success -- confirm
  // against the definition.
  bool backport(
      std::istream& oss,
      caffe2::serialize::PyTorchStreamWriter& final_writer,
      int64_t from_version,
      int64_t to_version) const;

  // Not copyable.
  BackportManager(BackportManager const&) = delete;
  BackportManager& operator=(BackportManager const&) = delete;
  BackportManager();

 private:
  // Registry of backport functions: associates `backport_function` with
  // `from_version`.
  void registerBytecodeBackportFunction(
      int64_t from_version,
      const BytecodeBackportFunction& backport_function);
};

} // namespace torch::jit
49