1 #pragma once 2 3 #include <ATen/core/ivalue.h> 4 #include <c10/macros/Export.h> 5 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h> 6 7 #include <istream> 8 #include <memory> 9 #include <unordered_map> 10 #include <vector> 11 12 namespace caffe2::serialize { 13 class PyTorchStreamReader; 14 class ReadAdapterInterface; 15 } // namespace caffe2::serialize 16 17 namespace torch::jit { 18 19 // The family of methods below to get bytecode version from a model 20 // Throws if not passed in a well formed model 21 TORCH_API uint64_t _get_model_bytecode_version(std::istream& in); 22 23 TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename); 24 25 TORCH_API uint64_t _get_model_bytecode_version( 26 const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai); 27 28 uint64_t _get_model_bytecode_version( 29 const std::vector<c10::IValue>& bytecode_ivalues); 30 31 // The family of methods below to get the operator version from a model 32 // Throws if not passed in a well formed model 33 TORCH_API uint64_t _get_model_operator_version(std::istream& in); 34 35 TORCH_API uint64_t _get_model_operator_version(const std::string& filename); 36 37 TORCH_API uint64_t _get_model_operator_version( 38 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); 39 40 // Utility Functions 41 std::vector<c10::IValue> get_bytecode_ivalues( 42 caffe2::serialize::PyTorchStreamReader& reader); 43 44 c10::IValue readArchive( 45 const std::string& archive_name, 46 caffe2::serialize::PyTorchStreamReader& stream_reader); 47 48 bool check_zip_file( 49 const std::shared_ptr<caffe2::serialize::ReadAdapterInterface>& rai); 50 51 // The family of methods below to get the root ops and information from a model 52 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( 53 std::istream& in); 54 55 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( 56 const std::string& filename); 57 58 TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( 59 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); 60 61 // The family of methods below to get contained types from a model 62 // Throws if not passed in a well formed model 63 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( 64 std::istream& in); 65 66 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( 67 const std::string& filename); 68 69 TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( 70 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); 71 72 std::unordered_set<std::string> _get_mobile_model_contained_types( 73 const std::vector<c10::IValue>& bytecode_ivalues); 74 75 // The family of methods below return the compatibility information of a model 76 struct ModelCompatibilityInfo { 77 uint64_t bytecode_version; 78 std::unordered_map<std::string, OperatorInfo> operator_info; 79 std::unordered_set<std::string> type_table; 80 uint64_t operator_version; 81 82 // Factory Methods 83 static TORCH_API ModelCompatibilityInfo get(std::istream& in); 84 static TORCH_API ModelCompatibilityInfo get(const std::string& filename); 85 static TORCH_API ModelCompatibilityInfo 86 get(std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); 87 }; 88 89 enum ModelCompatibilityStatus { 90 OK = 1, 91 ERROR = 2, 92 }; 93 94 struct ModelCompatCheckResult { 95 ModelCompatibilityStatus status; 96 std::vector<std::string> errors{}; 97 }; 98 // Takes in information about a runtime and a model and returns if the two are 99 // compatible with one another. 100 TORCH_API ModelCompatCheckResult is_compatible( 101 RuntimeCompatibilityInfo runtime_info, 102 const ModelCompatibilityInfo& model_info); 103 104 } // namespace torch::jit 105