#include <torch/csrc/jit/mobile/type_parser.h>

#include <ATen/core/jit_type.h>
#include <ATen/core/type_factory.h>
#include <c10/util/string_view.h>
#include <torch/csrc/jit/frontend/parser_constants.h>
#include <torch/custom_class.h>

using torch::jit::valid_single_char_tokens;

namespace c10 {

namespace {

// A Torchbind custom class always starts with the following prefix, so use it
// as an identifier for the Torchbind custom class type.
static constexpr const char* kTypeTorchbindCustomClass =
    "__torch__.torch.classes";
static constexpr const char* kTypeNamedTuple = "NamedTuple";

bool isSpecialChar(char a) {
  for (const char* c = valid_single_char_tokens; *c; c++) {
    if (a == *c)
      return true;
  }
  return false;
}
} // namespace

TypeParser::TypeParser(std::string pythonStr)
    : pythonStr_(std::move(pythonStr)), start_(0) {
  lex();
}

TypeParser::TypeParser(std::vector<std::string>& pythonStrs)
    : start_(0), pythonStrs_(pythonStrs) {}

// For Python string list parsing, the order of the Python strings matters.
// In bytecode, the order of the type list corresponds to the order of the
// instructions, and in a nested type the lowest-level type appears at the
// beginning of the type list. It is possible to parse without relying on this
// ordering, but doing so would 1) add extra cost to reorder nested types and
// 2) lose the hint that the instruction order is likely problematic when type
// list parsing fails.
std::vector<TypePtr> TypeParser::parseList() {
  std::vector<TypePtr> typePtrs;
  typePtrs.resize(pythonStrs_.size());
  static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
  for (size_t i = 0; i < pythonStrs_.size(); i++) {
    c10::QualifiedName qn(pythonStrs_[i]);
    c10::TypePtr type_ptr;
    if (classPrefix.isPrefixOf(qn)) {
      type_ptr = torch::getCustomClass(qn.qualifiedName());
      TORCH_CHECK(
          type_ptr,
          "The implementation of class ",
          qn.qualifiedName(),
          " cannot be found.");
    } else {
      pythonStr_ = pythonStrs_[i];
      start_ = 0;
      lex();
      type_ptr = parse();
    }
    typePtrs[i] = type_ptr;
    str_type_ptr_map_[type_ptr->repr_str()] = type_ptr;
  }
  return typePtrs;
}
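// Illustrative sketch (hypothetical type strings, not taken from a real
// model): given
//   pythonStrs_ = {"__torch__.a.Foo[NamedTuple, [[x, Tensor]]]",
//                  "List[__torch__.a.Foo]"},
// the NamedTuple definition in the first string is parsed and cached in
// str_type_ptr_map_, so the bare reference "__torch__.a.Foo" inside the
// second string can later be resolved from that map by parseCustomType().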

// The list of non-simple types supported by the current parser.
const std::unordered_set<std::string>& TypeParser::getNonSimpleType() {
  static std::unordered_set<std::string> nonSimpleTypes{
      "List", "Optional", "Dict", "Tuple"};
  return nonSimpleTypes;
}

// The list of custom types supported by the current parser.
const std::unordered_set<std::string>& TypeParser::getCustomType() {
  static std::unordered_set<std::string> customTypes{
      kTypeTorchbindCustomClass, kTypeNamedTuple};
  return customTypes;
}

// Given a Python type string, get all contained types. This is usually used
// for a compatibility check between model and runtime. For example, for the
// Python string "Dict[int, Tuple[Tensor, Tensor, Tensor]]", the contained
// types are [Dict, int, Tuple, Tensor].
std::unordered_set<std::string> TypeParser::getContainedTypes() {
  return contained_types_;
}
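// Usage sketch (hypothetical type string, for illustration only):
//   TypeParser parser("Dict[int, Tuple[Tensor, Tensor, Tensor]]");
//   parser.parse();
//   // parser.getContainedTypes() now holds {"Dict", "int", "Tuple", "Tensor"}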

template <typename T>
TypePtr TypeParser::parseSingleElementType() {
  expectChar('[');
  auto result = DynamicTypeFactory::create<T>(parse());
  expectChar(']');
  return result;
}

TypePtr TypeParser::parseNonSimple(const std::string& token) {
  if (token == "List") {
    return parseSingleElementType<ListType>();
  } else if (token == "Optional") {
    return parseSingleElementType<OptionalType>();
  } else if (token == "Dict") {
    expectChar('[');
    auto key = parse();
    expectChar(',');
    auto val = parse();
    expectChar(']');
    return DynamicTypeFactory::create<DictType>(std::move(key), std::move(val));
  } else if (token == "Tuple") {
    std::vector<TypePtr> types;
    expectChar('[');
    while (cur() != "]") {
      types.emplace_back(parse());
      if (cur() != "]") {
        expectChar(',');
      }
    }
    expect("]");
    return DynamicTypeFactory::create<TupleType>(types);
  }
  return nullptr;
}
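// Illustrative walk-through (hypothetical input): for "Dict[str, Tensor]",
// parse() consumes the "Dict" token and dispatches here; this function then
// consumes "[", recursively parses "str" as the key type, consumes ",",
// parses "Tensor" as the value type, and consumes "]" before building the
// DictType.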

TypePtr TypeParser::parse() {
  std::string token = next();
  const auto& baseTypes = DynamicTypeFactory::basePythonTypes();
  auto simpleTypeIt = baseTypes.find(token);
  if (simpleTypeIt != baseTypes.end()) {
    if (cur() != "]" && cur() != "," && !cur().empty()) {
      TORCH_CHECK(
          false, "Simple type ", token, " is followed by invalid chars.");
    }
    contained_types_.insert(token);
    return simpleTypeIt->second;
  } else if (getNonSimpleType().find(token) != getNonSimpleType().end()) {
    contained_types_.insert(token);
    return parseNonSimple(token);
  } else if (token == "__torch__") {
    expectChar('.');
    if (cur() == "torch") {
      // A Torchbind class starts with __torch__.torch.classes.
      return parseTorchbindClassType();
    } else {
      // Other classes start with __torch__ followed by custom names.
      return parseCustomType();
    }
  } else if (token == "Union") {
    // TODO: Union types are not supported in the embedded runtime, and we
    // need to generate compiler errors for users scripting UnionTypes. Right
    // now, to preserve backward compatibility, we have to return a nullptr
    // since it does not get involved in type reflection.
    return nullptr;
  } else {
    TORCH_CHECK(
        false,
        "Type ",
        token,
        " is not supported in the parser, ",
        "or the token is in wrong format.");
  }
  return nullptr;
}
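// Dispatch summary (hypothetical inputs): "int" returns the corresponding
// base python type directly; "List[...]", "Optional[...]", "Dict[...]" and
// "Tuple[...]" go through parseNonSimple(); "__torch__.torch.classes.a.B"
// goes through parseTorchbindClassType(); any other "__torch__."-prefixed
// name goes through parseCustomType(); "Union" currently yields nullptr.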

// A NamedTuple custom type has the following structure:
// "qualified_name[
//   NamedTuple, [
//       [field_name_1, field_type_1],
//       [field_name_2, field_type_2]
//   ]
// ]"
//  Example NamedTuple type:
//  "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
//     NamedTuple, [
//         [float_features, Tensor],
//         [id_list_features, List[Tensor]],
//         [label,  Tensor],
//         [weight, Tensor],
//         ]
//     ]"
TypePtr TypeParser::parseNamedTuple(const std::string& qualified_name) {
  std::vector<c10::string_view> field_names;
  std::vector<TypePtr> field_types;
  expect(",");
  expect("[");
  while (cur() != "]") {
    expect("[");
    auto field_name = nextView();
    expect(",");
    TypePtr field_type = parse();
    field_names.emplace_back(field_name);
    field_types.emplace_back(field_type);
    expect("]");
    if (cur() == ",") {
      next();
    }
  }
  return DynamicTypeFactory::createNamedTuple(
      qualified_name, field_names, field_types);
}

// A custom type has the following structure:
// "qualified_name[
//   custom_type, [
//       [field_name_1, field_type_1],
//       [field_name_2, field_type_2]
//   ]
// ]"
TypePtr TypeParser::parseCustomType() {
  c10::string_view token = cur();
  std::string qualified_name = "__torch__.";
  qualified_name.reserve(qualified_name.size() + token.size());
  qualified_name.append(token.begin(), token.end());
  next();
  while (cur() == ".") {
    qualified_name.append(next());
    qualified_name.append(next());
  }
  // After cur() moves past the qualified name, if the next token is "[", this
  // custom type is followed by its class definition. Otherwise, it's a bare
  // qualified name and we need to look up str_type_ptr_map_ to find the
  // TypePtr.
  if (cur() == "[") {
    next();
    std::string type_name = next();
    // Currently only the NamedTuple custom type is supported; if more types
    // need to be supported, extend them here.
    if (type_name == kTypeNamedTuple) {
      contained_types_.insert(kTypeNamedTuple);
      return parseNamedTuple(qualified_name);
    } else {
      TORCH_CHECK(
          false, "Custom Type ", type_name, " is not supported in the parser.");
    }
  } else {
    auto find_type = str_type_ptr_map_.find(qualified_name);
    if (find_type != str_type_ptr_map_.end()) {
      return find_type->second;
    } else {
      // When the type definition can't be found, there are likely two reasons:
      // 1. The type list in bytecode.pkl is not in the correct order.
      // 2. This custom type definition doesn't exist in the bytecode.pkl type
      //    table.
      TORCH_CHECK(
          false, "Can't find definition for the type: ", qualified_name);
    }
    return nullptr;
  }
}
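// Illustrative sketch (hypothetical type strings):
// "__torch__.a.B[NamedTuple, [[x, int]]]" takes the first branch above and
// builds the NamedTuple definition, while a later bare "__torch__.a.B" takes
// the second branch and is resolved from str_type_ptr_map_, which parseList()
// populated when the definition was parsed.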

TypePtr TypeParser::parseTorchbindClassType() {
  static constexpr std::array<const char*, 4> expected_atoms = {
      "torch", ".", "classes", "."};
  for (const auto& atom : expected_atoms) {
    expect(atom);
  }
  std::string ns = next();
  expectChar('.');
  std::string classname = next();
  std::string customClassName = "__torch__.torch.classes.";
  customClassName.reserve(
      customClassName.size() + ns.size() + 1 + classname.size());
  customClassName.append(ns);
  customClassName.push_back('.');
  customClassName.append(classname);
  return torch::getCustomClass(customClassName);
}
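// Illustrative sketch (hypothetical names): for
// "__torch__.torch.classes.my_ns.MyClass", the namespace ("my_ns") and class
// name ("MyClass") tokens are reassembled into the full qualified name and
// handed to torch::getCustomClass(), which looks up the class registered on
// the C++ side, e.g. via torch::class_<MyClass>("my_ns", "MyClass").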

void TypeParser::expect(const char* s) {
  c10::string_view token = cur();
  TORCH_CHECK(
      token == s,
      "Error when parsing type ",
      pythonStr_,
      ": Expect ",
      s,
      ", but get ",
      token);
  advance();
}

// c10::string_view::operator== calls memcmp to compare against the target
// string; we can do better if we specialize for a single character.
void TypeParser::expectChar(char c) {
  c10::string_view token = cur();
  TORCH_CHECK(
      token.size() == 1 && token[0] == c,
      "Error when parsing type ",
      pythonStr_,
      ": Expect ",
      c,
      ", but get ",
      token);
  advance();
}

void TypeParser::lex() {
  // Skip whitespace.
  while (start_ < pythonStr_.size() && pythonStr_[start_] == ' ')
    ++start_;
  if (start_ < pythonStr_.size()) {
    if (isSpecialChar(pythonStr_[start_])) {
      next_token_ = c10::string_view(pythonStr_.data() + start_++, 1);
    } else { // A word
      size_t end = start_;
      for (; end < pythonStr_.size() && !isSpecialChar(pythonStr_[end]) &&
           pythonStr_[end] != ' ';
           ++end)
        ;
      next_token_ = c10::string_view(pythonStr_.data() + start_, end - start_);
      start_ = end;
    }
  }
}
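// Illustrative tokenization (hypothetical input): "Dict[int, str]" is lexed
// one token per call as "Dict", "[", "int", ",", "str", "]"; whitespace is
// skipped, and each character in valid_single_char_tokens becomes its own
// single-character token.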

c10::string_view TypeParser::nextView() {
  TORCH_CHECK(
      !next_token_.empty(),
      "Empty token queue in mobile type parser.",
      "Check the format of the type string and make sure it's correct.");
  c10::string_view token = cur();
  advance();
  return token;
}

std::string TypeParser::next() {
  auto token = nextView();
  return std::string(token.begin(), token.end());
}

void TypeParser::advance() {
  next_token_ = "";
  lex();
}

C10_NODISCARD c10::string_view TypeParser::cur() const {
  return next_token_;
}

TORCH_API at::TypePtr parseType(const std::string& pythonStr) {
  at::TypeParser parser(pythonStr);
  return parser.parse();
}

TORCH_API std::vector<at::TypePtr> parseType(
    std::vector<std::string>& pythonStrs) {
  at::TypeParser parser(pythonStrs);
  return parser.parseList();
}
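// Usage sketch (hypothetical type strings, for illustration only):
//   at::TypePtr t = c10::parseType("List[Optional[Tensor]]");
//
//   std::vector<std::string> strs{
//       "__torch__.a.B[NamedTuple, [[x, int]]]", "List[__torch__.a.B]"};
//   std::vector<at::TypePtr> types = c10::parseType(strs);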

} // namespace c10