1 #include <torch/csrc/jit/mobile/type_parser.h>
2
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/type_factory.h>
5 #include <c10/util/string_view.h>
6 #include <torch/csrc/jit/frontend/parser_constants.h>
7 #include <torch/custom_class.h>
8
9 using torch::jit::valid_single_char_tokens;
10
11 namespace c10 {
12
13 namespace {
14
// Torchbind custom classes always start with the following prefix, so use it
// as an identifier for torchbind custom class types.
17 static constexpr const char* kTypeTorchbindCustomClass =
18 "__torch__.torch.classes";
19 static constexpr const char* kTypeNamedTuple = "NamedTuple";
20
isSpecialChar(char a)21 bool isSpecialChar(char a) {
22 for (const char* c = valid_single_char_tokens; *c; c++) {
23 if (a == *c)
24 return true;
25 }
26 return false;
27 }
28 } // namespace
29
TypeParser(std::string pythonStr)30 TypeParser::TypeParser(std::string pythonStr)
31 : pythonStr_(std::move(pythonStr)), start_(0) {
32 lex();
33 }
34
// Construct a parser over a list of Python type strings. Parsing is deferred
// to parseList(), which lexes each entry on demand (hence no lex() here).
TypeParser::TypeParser(std::vector<std::string>& pythonStrs)
    : start_(0), pythonStrs_(pythonStrs) {}
37
// For the Python string list parsing, the order of the Python strings
// matters. In bytecode, the order of the type list corresponds to the order
// of instructions. In a nested type, the lowest-level type will be at the
// beginning of the type list. It is possible to parse without worrying about
// ordering, but that 1) adds extra cost to process nested types into the
// correct order and 2) loses the signal that the instruction order is likely
// problematic if type list parsing fails.
parseList()45 std::vector<TypePtr> TypeParser::parseList() {
46 std::vector<TypePtr> typePtrs;
47 typePtrs.resize(pythonStrs_.size());
48 static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
49 for (size_t i = 0; i < pythonStrs_.size(); i++) {
50 c10::QualifiedName qn(pythonStrs_[i]);
51 c10::TypePtr type_ptr;
52 if (classPrefix.isPrefixOf(qn)) {
53 type_ptr = torch::getCustomClass(qn.qualifiedName());
54 TORCH_CHECK(
55 type_ptr,
56 "The implementation of class ",
57 qn.qualifiedName(),
58 " cannot be found.");
59 } else {
60 pythonStr_ = pythonStrs_[i];
61 start_ = 0;
62 lex();
63 type_ptr = parse();
64 }
65 typePtrs[i] = type_ptr;
66 str_type_ptr_map_[type_ptr->repr_str()] = type_ptr;
67 }
68 return typePtrs;
69 }
70
71 // The list of non-simple types supported by current parser.
getNonSimpleType()72 const std::unordered_set<std::string>& TypeParser::getNonSimpleType() {
73 static std::unordered_set<std::string> nonSimpleTypes{
74 "List", "Optional", "Dict", "Tuple"};
75 return nonSimpleTypes;
76 }
77
78 // The list of custom types supported by current parser.
getCustomType()79 const std::unordered_set<std::string>& TypeParser::getCustomType() {
80 static std::unordered_set<std::string> customeTypes{
81 kTypeTorchbindCustomClass, kTypeNamedTuple};
82 return customeTypes;
83 }
84
// Given a Python str, get all contained types. It's usually used for
// compatibility checks between model and runtime. For example:
// Python string: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// contained types are: [Dict, int, Tuple, Tensor]
std::unordered_set<std::string> TypeParser::getContainedTypes() {
  // Returns a copy of the type tokens accumulated by parse()/parseList().
  return contained_types_;
}
92
93 template <typename T>
parseSingleElementType()94 TypePtr TypeParser::parseSingleElementType() {
95 expectChar('[');
96 auto result = DynamicTypeFactory::create<T>(parse());
97 expectChar(']');
98 return result;
99 }
100
parseNonSimple(const std::string & token)101 TypePtr TypeParser::parseNonSimple(const std::string& token) {
102 if (token == "List") {
103 return parseSingleElementType<ListType>();
104 } else if (token == "Optional") {
105 return parseSingleElementType<OptionalType>();
106 } else if (token == "Dict") {
107 expectChar('[');
108 auto key = parse();
109 expectChar(',');
110 auto val = parse();
111 expectChar(']');
112 return DynamicTypeFactory::create<DictType>(std::move(key), std::move(val));
113 } else if (token == "Tuple") {
114 std::vector<TypePtr> types;
115 expectChar('[');
116 while (cur() != "]") {
117 types.emplace_back(parse());
118 if (cur() != "]") {
119 expectChar(',');
120 }
121 }
122 expect("]");
123 return DynamicTypeFactory::create<TupleType>(types);
124 }
125 return nullptr;
126 }
127
// Parses one complete type expression starting at the current token.
// Dispatch order: simple base types (int, Tensor, ...), container types
// (List/Optional/Dict/Tuple), "__torch__"-qualified class types, then
// "Union" (unsupported; returns nullptr for backward compatibility).
// Any other token is an error.
TypePtr TypeParser::parse() {
  std::string token = next();
  const auto& baseTypes = DynamicTypeFactory::basePythonTypes();
  auto simpleTypeIt = baseTypes.find(token);
  if (simpleTypeIt != baseTypes.end()) {
    // A simple type must be a complete expression here: the only valid
    // followers are a container delimiter or end of input.
    if (cur() != "]" && cur() != "," && !cur().empty()) {
      TORCH_CHECK(
          false, "Simple type ", token, " is followed by ", "invalid chars.");
    }
    contained_types_.insert(token);
    return simpleTypeIt->second;
  } else if (getNonSimpleType().find(token) != getNonSimpleType().end()) {
    contained_types_.insert(token);
    return parseNonSimple(token);
  } else if (token == "__torch__") {
    expectChar('.');
    if (cur() == "torch") {
      // torch bind class starts with __torch__.torch.classes
      return parseTorchbindClassType();
    } else {
      // other class starts with __torch__ followed by custom names
      return parseCustomType();
    }
  } else if (token == "Union") {
    // TODO Union types are not supported on embedded runtime, and we need to
    // generate compiler errors for users scripting UnionTypes. Right now
    // for preserving backward compatibility we have to return a nullptr since
    // it does not get involved in type reflection.
    return nullptr;
  } else {
    TORCH_CHECK(
        false,
        "Type ",
        token,
        " is not supported in the parser, ",
        "or the token is in wrong format.");
  }
  // Unreachable: every branch above returns or throws.
  return nullptr;
}
167
// NamedTuple custom type will be in the following structure:
// "qualified_named[
//   NamedTuple, [
//       [field_name_1, field_type_1],
//       [field_name_2, field_type_2]
//   ]
// ]"
// Example NamedTuple type:
//  "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
//     NamedTuple, [
//         [float_features, Tensor],
//         [id_list_features, List[Tensor]],
//         [label,  Tensor],
//         [weight, Tensor],
//         ]
//     ]"
parseNamedTuple(const std::string & qualified_name)184 TypePtr TypeParser::parseNamedTuple(const std::string& qualified_name) {
185 std::vector<c10::string_view> field_names;
186 std::vector<TypePtr> field_types;
187 expect(",");
188 expect("[");
189 while (cur() != "]") {
190 expect("[");
191 auto field_name = nextView();
192 expect(",");
193 TypePtr field_type = parse();
194 field_names.emplace_back(field_name);
195 field_types.emplace_back(field_type);
196 expect("]");
197 if (cur() == ",") {
198 next();
199 }
200 }
201 return DynamicTypeFactory::createNamedTuple(
202 qualified_name, field_names, field_types);
203 }
204
// Custom type will be in the following structure:
// "qualified_named[
//   custom_type, [
//       [field_name_1, field_type_1],
//       [field_name_2, field_type_2]
//   ]
// ]"
TypePtr TypeParser::parseCustomType() {
  // cur() is the first name component after "__torch__." (both already
  // consumed by parse()); rebuild the full dotted qualified name from the
  // token stream.
  c10::string_view token = cur();
  std::string qualified_name = "__torch__.";
  qualified_name.reserve(qualified_name.size() + token.size());
  qualified_name.append(token.begin(), token.end());
  next();
  while (cur() == ".") {
    // Append the "." token and then the following name component.
    qualified_name.append(next());
    qualified_name.append(next());
  }
  // After cur() moves to the next token after the qualified name, if it's
  // "[", it means this custom type is followed by its class definition.
  // Otherwise, it's a bare qualified name and we need to look it up in
  // str_type_ptr_map_ to find the typeptr.
  if (cur() == "[") {
    next();
    std::string type_name = next();
    // Currently only supports NamedTuple custom type, if more types need to
    // be supported, extend them here.
    if (type_name == kTypeNamedTuple) {
      contained_types_.insert(kTypeNamedTuple);
      return parseNamedTuple(qualified_name);
    } else {
      TORCH_CHECK(
          false, "Custom Type ", type_name, " is not supported in the parser.");
    }
  } else {
    auto find_type = str_type_ptr_map_.find(qualified_name);
    if (find_type != str_type_ptr_map_.end()) {
      return find_type->second;
    } else {
      // When the type definition can't be found, likely two reasons
      // 1. The type list in bytecode.pkl is not in the correct order
      // 2. This custom type definition doesn't exist in bytecode.pkl type
      //    table
      TORCH_CHECK(
          false, "Can't find definition for the type: ", qualified_name);
    }
    // Unreachable: the TORCH_CHECK above always throws.
    return nullptr;
  }
}
253
parseTorchbindClassType()254 TypePtr TypeParser::parseTorchbindClassType() {
255 static constexpr std::array<const char*, 4> expected_atoms = {
256 "torch", ".", "classes", "."};
257 for (const auto& atom : expected_atoms) {
258 expect(atom);
259 }
260 std::string ns = next();
261 expectChar('.');
262 std::string classname = next();
263 std::string customClassName = "__torch__.torch.classes.";
264 customClassName.reserve(
265 customClassName.size() + ns.size() + 1 + classname.size());
266 customClassName.append(ns);
267 customClassName.push_back('.');
268 customClassName.append(classname);
269 return torch::getCustomClass(customClassName);
270 }
271
expect(const char * s)272 void TypeParser::expect(const char* s) {
273 c10::string_view token = cur();
274 TORCH_CHECK(
275 token == s,
276 "Error when parsing type ",
277 pythonStr_,
278 ": Expect ",
279 s,
280 ", but get ",
281 token);
282 advance();
283 }
284
285 // c10::string_view::operator== calls memcmp to compare against the target
286 // string; we can do better if we specialize for a single character.
expectChar(char c)287 void TypeParser::expectChar(char c) {
288 c10::string_view token = cur();
289 TORCH_CHECK(
290 token.size() == 1 && token[0] == c,
291 "Error when parsing type ",
292 pythonStr_,
293 ": Expect ",
294 c,
295 ", but get ",
296 token);
297 advance();
298 }
299
// Advances next_token_ to the next token of pythonStr_, scanning from
// start_. A token is either a single special character (see isSpecialChar)
// or a maximal run of non-special, non-space characters. Leading spaces are
// skipped. When the input is exhausted, next_token_ is left as-is (empty
// when called via advance(), which clears it first).
void TypeParser::lex() {
  // skip white spaces
  while (start_ < pythonStr_.size() && pythonStr_[start_] == ' ')
    ++start_;
  if (start_ < pythonStr_.size()) {
    if (isSpecialChar(pythonStr_[start_])) {
      // Single-character token; start_ moves past it.
      next_token_ = c10::string_view(pythonStr_.data() + start_++, 1);
    } else { // A word
      size_t end = start_;
      for (; end < pythonStr_.size() && !isSpecialChar(pythonStr_[end]) &&
           pythonStr_[end] != ' ';
           ++end)
        ;
      // The token is a view into pythonStr_ (no copy).
      next_token_ = c10::string_view(pythonStr_.data() + start_, end - start_);
      start_ = end;
    }
  }
}
318
nextView()319 c10::string_view TypeParser::nextView() {
320 TORCH_CHECK(
321 !next_token_.empty(),
322 "Empty token queue in mobile type parser.",
323 "Check the format of the type string and make sure it's correct.");
324 c10::string_view token = cur();
325 advance();
326 return token;
327 }
328
next()329 std::string TypeParser::next() {
330 auto token = nextView();
331 return std::string(token.begin(), token.end());
332 }
333
// Discards the current token and lexes the next one from the input.
void TypeParser::advance() {
  next_token_ = "";
  lex();
}
338
// Peeks at the current token without consuming it.
C10_NODISCARD c10::string_view TypeParser::cur() const {
  return next_token_;
}
342
parseType(const std::string & pythonStr)343 TORCH_API at::TypePtr parseType(const std::string& pythonStr) {
344 at::TypeParser parser(pythonStr);
345 return parser.parse();
346 }
347
parseType(std::vector<std::string> & pythonStrs)348 TORCH_API std::vector<at::TypePtr> parseType(
349 std::vector<std::string>& pythonStrs) {
350 at::TypeParser parser(pythonStrs);
351 return parser.parseList();
352 }
353
354 } // namespace c10
355