xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/compatibility/runtime_compatibility.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dispatch/Dispatcher.h>
2 #include <ATen/core/type_factory.h>
3 #include <caffe2/serialize/inline_container.h>
4 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
5 #include <torch/csrc/jit/mobile/type_parser.h>
6 #include <torch/csrc/jit/runtime/operator.h>
7 #include <torch/custom_class.h>
8 #include <unordered_map>
9 
10 namespace c10 {
11 TypePtr parseType(const std::string& pythonStr);
12 } // namespace c10
13 
14 namespace torch::jit {
15 
_get_runtime_bytecode_version()16 uint64_t _get_runtime_bytecode_version() {
17   return caffe2::serialize::kMaxSupportedBytecodeVersion;
18 }
19 
_get_runtime_bytecode_min_max_versions()20 std::pair<uint64_t, uint64_t> _get_runtime_bytecode_min_max_versions() {
21   return std::pair<uint64_t, uint64_t>(
22       caffe2::serialize::kMinSupportedBytecodeVersion,
23       caffe2::serialize::kMaxSupportedBytecodeVersion);
24 }
25 
_get_runtime_operators_min_max_versions()26 std::pair<uint64_t, uint64_t> _get_runtime_operators_min_max_versions() {
27   return std::pair<uint64_t, uint64_t>(
28       caffe2::serialize::kMinSupportedFileFormatVersion,
29       caffe2::serialize::kMaxSupportedFileFormatVersion);
30 }
31 
32 /*
33  * Returns all registered PyTorch ops and their versioning
34  */
_get_runtime_ops_and_info()35 std::unordered_map<std::string, OperatorInfo> _get_runtime_ops_and_info() {
36   std::unordered_map<std::string, OperatorInfo> result;
37 
38   // Grab the jit operators
39   auto nonDispatcherOperators = torch::jit::getAllOperators();
40   for (const auto& full_op : nonDispatcherOperators) {
41     auto op = full_op->schema();
42     auto num_schema_args = op.arguments().size();
43     auto op_name = op.name();
44     if (!op.overload_name().empty()) {
45       op_name += ("." + op.overload_name());
46     }
47     result.emplace(op_name, OperatorInfo{num_schema_args});
48   }
49 
50   // Grab the dispatcher operators
51   auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
52   for (auto& op : dispatcherOperators) {
53     // grab schema
54     const auto op_handle = c10::Dispatcher::singleton().findOp(op);
55     std::optional<int> num_schema_args;
56     if (op_handle->hasSchema()) {
57       num_schema_args = op_handle->schema().arguments().size();
58     }
59     auto op_name = op.name;
60     if (!op.overload_name.empty()) {
61       op_name += ("." + op.overload_name);
62     }
63     result.emplace(op_name, OperatorInfo{num_schema_args});
64   }
65 
66   return result;
67 }
68 
get()69 RuntimeCompatibilityInfo RuntimeCompatibilityInfo::get() {
70   return RuntimeCompatibilityInfo{
71       _get_runtime_bytecode_min_max_versions(),
72       _get_runtime_ops_and_info(),
73       _get_mobile_supported_types(),
74       _get_runtime_operators_min_max_versions()};
75 }
76 
_get_mobile_supported_types()77 std::unordered_set<std::string> _get_mobile_supported_types() {
78   std::unordered_set<std::string> supported_types;
79   for (const auto& it : c10::DynamicTypeFactory::basePythonTypes()) {
80     supported_types.insert(it.first);
81   }
82   supported_types.insert(
83       at::TypeParser::getNonSimpleType().begin(),
84       at::TypeParser::getNonSimpleType().end());
85   supported_types.insert(
86       at::TypeParser::getCustomType().begin(),
87       at::TypeParser::getCustomType().end());
88 
89   return supported_types;
90 }
91 
_get_loaded_custom_classes()92 TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes() {
93   return torch::getAllCustomClassesNames();
94 }
95 
96 } // namespace torch::jit
97