#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/type_factory.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/custom_class.h>
#include <unordered_map>
#include <unordered_set>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch::jit {

uint64_t _get_runtime_bytecode_version() {
  return caffe2::serialize::kMaxSupportedBytecodeVersion;
}
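
// Illustrative usage (a sketch, not part of this file): compare a model's
// bytecode version against the newest version this runtime can load.
// `model_bytecode_version` is a hypothetical value read from the model file.
//
//   uint64_t runtime_version = torch::jit::_get_runtime_bytecode_version();
//   if (model_bytecode_version > runtime_version) {
//     // The model was exported with a newer bytecode format than we support.
//   }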

std::pair<uint64_t, uint64_t> _get_runtime_bytecode_min_max_versions() {
  return std::pair<uint64_t, uint64_t>(
      caffe2::serialize::kMinSupportedBytecodeVersion,
      caffe2::serialize::kMaxSupportedBytecodeVersion);
}
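
// Illustrative usage (a sketch): check that a model's bytecode version falls
// within the supported [min, max] range; `model_version` is hypothetical.
//
//   auto [min_v, max_v] =
//       torch::jit::_get_runtime_bytecode_min_max_versions();
//   bool compatible = model_version >= min_v && model_version <= max_v;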

std::pair<uint64_t, uint64_t> _get_runtime_operators_min_max_versions() {
  return std::pair<uint64_t, uint64_t>(
      caffe2::serialize::kMinSupportedFileFormatVersion,
      caffe2::serialize::kMaxSupportedFileFormatVersion);
}
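
// Note: as the body above shows, the operator version range is expressed via
// the serializer's file-format version constants
// (kMin/kMaxSupportedFileFormatVersion) rather than a separate
// operator-version counter.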

/*
 * Returns all registered PyTorch ops, mapped to an OperatorInfo recording the
 * number of arguments in each op's schema (when a schema is available).
 */
std::unordered_map<std::string, OperatorInfo> _get_runtime_ops_and_info() {
  std::unordered_map<std::string, OperatorInfo> result;

  // Grab the JIT operators (registered outside the c10 dispatcher).
  auto nonDispatcherOperators = torch::jit::getAllOperators();
  for (const auto& full_op : nonDispatcherOperators) {
    auto op = full_op->schema();
    auto num_schema_args = op.arguments().size();
    auto op_name = op.name();
    if (!op.overload_name().empty()) {
      op_name += ("." + op.overload_name());
    }
    result.emplace(op_name, OperatorInfo{num_schema_args});
  }

  // Grab the dispatcher operators.
  auto dispatcherOperators = c10::Dispatcher::singleton().getAllOpNames();
  for (auto& op : dispatcherOperators) {
    // Dispatcher ops may be registered without a schema; in that case the
    // argument count is left unset.
    const auto op_handle = c10::Dispatcher::singleton().findOp(op);
    std::optional<int> num_schema_args;
    if (op_handle->hasSchema()) {
      num_schema_args = op_handle->schema().arguments().size();
    }
    auto op_name = op.name;
    if (!op.overload_name.empty()) {
      op_name += ("." + op.overload_name);
    }
    result.emplace(op_name, OperatorInfo{num_schema_args});
  }

  return result;
}
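
// Illustrative usage (a sketch): check that every operator a model requires is
// present in this runtime; `required_ops` is a hypothetical list of fully
// qualified op names (e.g. "aten::add.Tensor") extracted from the model.
//
//   auto runtime_ops = torch::jit::_get_runtime_ops_and_info();
//   for (const auto& name : required_ops) {
//     if (runtime_ops.find(name) == runtime_ops.end()) {
//       // Op is missing from this runtime; the model cannot run here.
//     }
//   }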

RuntimeCompatibilityInfo RuntimeCompatibilityInfo::get() {
  return RuntimeCompatibilityInfo{
      _get_runtime_bytecode_min_max_versions(),
      _get_runtime_ops_and_info(),
      _get_mobile_supported_types(),
      _get_runtime_operators_min_max_versions()};
}
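
// Illustrative usage (a sketch): take one snapshot of everything this runtime
// supports (bytecode version range, registered ops, supported types, operator
// version range) for comparison against a model's requirements.
//
//   auto info = torch::jit::RuntimeCompatibilityInfo::get();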

std::unordered_set<std::string> _get_mobile_supported_types() {
  std::unordered_set<std::string> supported_types;
  for (const auto& it : c10::DynamicTypeFactory::basePythonTypes()) {
    supported_types.insert(it.first);
  }
  supported_types.insert(
      at::TypeParser::getNonSimpleType().begin(),
      at::TypeParser::getNonSimpleType().end());
  supported_types.insert(
      at::TypeParser::getCustomType().begin(),
      at::TypeParser::getCustomType().end());

  return supported_types;
}
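
// Illustrative usage (a sketch): check whether a type string parsed from a
// model is one the mobile type parser understands; `type_str` is hypothetical.
//
//   auto types = torch::jit::_get_mobile_supported_types();
//   bool supported = types.count(type_str) > 0;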

TORCH_API std::unordered_set<std::string> _get_loaded_custom_classes() {
  return torch::getAllCustomClassesNames();
}
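
// Illustrative usage (a sketch): enumerate the custom (TorchBind) classes
// registered in this process so far.
//
//   for (const auto& cls : torch::jit::_get_loaded_custom_classes()) {
//     // cls is the registered class name.
//   }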

} // namespace torch::jit