// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/compatibility/model_compatibility.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/ivalue.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>

#include <cstdint>
#include <istream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
11 
// Forward declarations of the serialization types used below. This header only
// names them through references and std::shared_ptr, so their complete
// definitions are not required here.
namespace caffe2::serialize {
class ReadAdapterInterface;
class PyTorchStreamReader;
} // namespace caffe2::serialize
16 
17 namespace torch::jit {
18 
19 // The family of methods below to get bytecode version from a model
20 // Throws if not passed in a well formed model
21 TORCH_API uint64_t _get_model_bytecode_version(std::istream& in);
22 
23 TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename);
24 
25 TORCH_API uint64_t _get_model_bytecode_version(
26     const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai);
27 
28 uint64_t _get_model_bytecode_version(
29     const std::vector<c10::IValue>& bytecode_ivalues);
30 
31 // The family of methods below to get the operator version from a model
32 // Throws if not passed in a well formed model
33 TORCH_API uint64_t _get_model_operator_version(std::istream& in);
34 
35 TORCH_API uint64_t _get_model_operator_version(const std::string& filename);
36 
37 TORCH_API uint64_t _get_model_operator_version(
38     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
39 
40 // Utility Functions
41 std::vector<c10::IValue> get_bytecode_ivalues(
42     caffe2::serialize::PyTorchStreamReader& reader);
43 
44 c10::IValue readArchive(
45     const std::string& archive_name,
46     caffe2::serialize::PyTorchStreamReader& stream_reader);
47 
48 bool check_zip_file(
49     const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai);
50 
51 // The family of methods below to get the root ops and information from a model
52 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
53     std::istream& in);
54 
55 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
56     const std::string& filename);
57 
58 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
59     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
60 
61 // The family of methods below to get contained types from a model
62 // Throws if not passed in a well formed model
63 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
64     std::istream& in);
65 
66 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
67     const std::string& filename);
68 
69 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
70     std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
71 
72 std::unordered_set<std::string> _get_mobile_model_contained_types(
73     const std::vector<c10::IValue>& bytecode_ivalues);
74 
75 // The family of methods below return the compatibility information of a model
76 struct ModelCompatibilityInfo {
77   uint64_t bytecode_version;
78   std::unordered_map<std::string, OperatorInfo> operator_info;
79   std::unordered_set<std::string> type_table;
80   uint64_t operator_version;
81 
82   // Factory Methods
83   static TORCH_API ModelCompatibilityInfo get(std::istream& in);
84   static TORCH_API ModelCompatibilityInfo get(const std::string& filename);
85   static TORCH_API ModelCompatibilityInfo
86   get(std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
87 };
88 
89 enum ModelCompatibilityStatus {
90   OK = 1,
91   ERROR = 2,
92 };
93 
94 struct ModelCompatCheckResult {
95   ModelCompatibilityStatus status;
96   std::vector<std::string> errors{};
97 };
98 // Takes in information about a runtime and a model and returns if the two are
99 // compatible with one another.
100 TORCH_API ModelCompatCheckResult is_compatible(
101     RuntimeCompatibilityInfo runtime_info,
102     const ModelCompatibilityInfo& model_info);
103 
104 } // namespace torch::jit
105