xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/irparser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/irparser.h>
2 
3 #include <ATen/EmptyTensor.h>
4 #include <torch/csrc/jit/frontend/lexer.h>
5 #include <torch/csrc/jit/frontend/parse_string_literal.h>
6 #include <torch/csrc/jit/frontend/schema_type_parser.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #else
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/empty_strided.h>
14 #endif
15 
16 #include <string>
17 #include <vector>
18 
19 namespace torch::jit {
20 
21 struct VarWithType;
22 struct ParsedLiteral;
23 
24 class IRParser {
25   friend void parseIR(
26       const std::string& str,
27       torch::jit::Graph* graph,
28       std::unordered_map<std::string, Value*>& vmap,
29       bool parse_tensor_constants);
IRParser(const std::string & str,torch::jit::Graph * graph,std::unordered_map<std::string,Value * > & vmap,bool parse_tensor_constants)30   IRParser(
31       const std::string& str,
32       torch::jit::Graph* graph,
33       std::unordered_map<std::string, Value*>& vmap,
34       bool parse_tensor_constants)
35       : L(std::make_shared<Source>(str)),
36         g(graph),
37         vmap(vmap),
38         type_parser(
39             L,
40             /*parse_complete_tensor_types*/ true,
41             /*allow_type_vars*/ true),
42         parse_tensor_constants_(parse_tensor_constants) {}
43 
44   std::string parseVar();
45   VarWithType parseVarWithType(bool allow_optional = false);
46   ParsedLiteral parseScalarLiteral(Node* n);
47 
48   void parse();
49   void parseGraphInputs();
50   void parseReturnOperator();
51 
52   void parseBlocks(Node* parentNode);
53   void parseBlock(Node* parentNode);
54   void parseBlockInputs(Block* b);
55   void parseBlockOutputs(Block* b);
56 
57   void parseOperatorsList(Block* b);
58   void parseOperator(Block* b);
59   void parseOperatorOutputs(std::vector<VarWithType>* outs);
60   std::string parseOperatorName();
61   void parseOperatorInputs(Node* n);
62   void parseAttrs(Node* n);
63   void parseAttr(Node* n);
64 
65   void parseList(
66       int begin,
67       int sep,
68       int end,
69       const std::function<void()>& callback);
70 
71   void bypassTypeAnnotationList();
72 
73   Value* findValueInVMap(const std::string& name);
74 
75   torch::jit::Lexer L;
76   torch::jit::Graph* g = nullptr;
77   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
78   std::unordered_map<std::string, Value*>& vmap;
79   SchemaTypeParser type_parser;
80   bool parse_tensor_constants_;
81   std::vector<Node*> deferred_tensor_value_initializations_;
82   std::vector<Node*> deferred_empty_container_initializations_;
83 };
84 
85 struct ParsedLiteral {
86   ParsedLiteral() = default;
87 
88   AttributeKind k = AttributeKind::t;
89 
90   int64_t i = 0;
91   std::string s = "";
92   double f = 0.0;
93   c10::complex<double> c = c10::complex<double>(0, 0);
94   TypePtr ty;
95   std::vector<int64_t> is;
96   std::vector<std::string> ss;
97   std::vector<double> fs;
98   std::vector<c10::complex<double>> cs;
99   std::vector<TypePtr> tys;
100 };
101 
102 struct VarWithType {
103   VarWithType() = default;
104   std::string name;
105   TypePtr type;
106 };
107 
parseIR(const std::string & str,torch::jit::Graph * graph,std::unordered_map<std::string,Value * > & vmap,bool parse_tensor_constants)108 void parseIR(
109     const std::string& str,
110     torch::jit::Graph* graph,
111     std::unordered_map<std::string, Value*>& vmap,
112     bool parse_tensor_constants) {
113   torch::jit::IRParser p(str, graph, vmap, parse_tensor_constants);
114   p.parse();
115 }
116 
parseIR(const std::string & str,torch::jit::Graph * graph,bool parse_tensor_constants)117 void parseIR(
118     const std::string& str,
119     torch::jit::Graph* graph,
120     bool parse_tensor_constants) {
121   std::unordered_map<std::string, Value*> vmap;
122   parseIR(str, graph, vmap, parse_tensor_constants);
123 }
124 
parseVarWithType(bool allow_optional)125 VarWithType IRParser::parseVarWithType(bool allow_optional) {
126   VarWithType r;
127   r.name = parseVar();
128   if (allow_optional) {
129     r.type = nullptr;
130   } else {
131     r.type = TensorType::get();
132   }
133   if (L.nextIf(':')) {
134     auto type_alias = type_parser.parseType();
135     AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
136     r.type = type_alias.first;
137   }
138   return r;
139 }
140 
parseVar()141 std::string IRParser::parseVar() {
142   L.expect('%');
143   std::string name;
144   bool continue_parsing = false;
145   do {
146     if (L.cur().kind == TK_IDENT) {
147       name += L.expect(TK_IDENT).text();
148     } else {
149       name += L.expect(TK_NUMBER).text();
150     }
151     continue_parsing = false;
152     if (L.nextIf('.')) {
153       continue_parsing = true;
154       name += '.';
155     } else if (L.cur().kind == TK_NUMBER && L.cur().text()[0] == '.') {
156       continue_parsing = true;
157     }
158   } while (continue_parsing);
159   return name;
160 }
161 
parseOperatorOutputs(std::vector<VarWithType> * outs)162 void IRParser::parseOperatorOutputs(std::vector<VarWithType>* outs) {
163   if (L.cur().kind != '%') {
164     return;
165   }
166   parseList(TK_NOTHING, ',', TK_NOTHING, [&] {
167     outs->push_back(parseVarWithType(true));
168   });
169   L.expect('=');
170 }
171 
172 // Parse string or numeric literal and return it along with its type.
parseScalarLiteral(Node * n)173 ParsedLiteral IRParser::parseScalarLiteral(Node* n) {
174   auto token = L.cur();
175   std::string str;
176   std::pair<TypePtr, std::optional<c10::AliasInfo>> type_alias;
177   ParsedLiteral r;
178   switch (token.kind) {
179     case TK_STRINGLITERAL:
180       r.k = AttributeKind::s;
181       r.s = parseStringLiteral(token.range, token.text());
182       L.next();
183       return r;
184     case '-':
185       str = "-";
186       L.next();
187       if (L.cur().kind != TK_NUMBER) {
188         throw(
189             ErrorReport(token.range)
190             << "Expected a number after '-' but got:" << token.text());
191       }
192       [[fallthrough]];
193     case TK_NUMBER:
194       str += L.cur().text();
195       if (str.find('j') != std::string::npos) {
196         r.k = AttributeKind::c;
197         double imag = 0.0f;
198         try {
199           imag = std::stod(str.substr(0, str.size() - 1));
200         } catch (const std::invalid_argument& e) {
201           throw(
202               ErrorReport(token.range)
203               << "Number cannot be converted to double");
204         } catch (const std::out_of_range& e) {
205           throw(
206               ErrorReport(token.range)
207               << "Number is too long to be represented in type double");
208         }
209         r.c = c10::complex<double>(0, imag);
210       } else if (
211           str.find('.') != std::string::npos ||
212           str.find('e') != std::string::npos) {
213         r.k = AttributeKind::f;
214         try {
215           r.f = std::stod(str);
216         } catch (const std::invalid_argument& e) {
217           throw(
218               ErrorReport(token.range)
219               << "Number cannot be converted to double");
220         } catch (const std::out_of_range& e) {
221           throw(
222               ErrorReport(token.range)
223               << "Number is too long to be represented in type double");
224         }
225       } else {
226         r.k = AttributeKind::i;
227         try {
228           r.i = std::stoll(str);
229         } catch (const std::invalid_argument& e) {
230           throw(
231               ErrorReport(token.range)
232               << "Number cannot be converted to integer");
233         } catch (const std::out_of_range& e) {
234           throw(ErrorReport(token.range) << "Number is too big");
235         }
236       }
237       L.next();
238       return r;
239     case TK_IDENT:
240       // Type literal
241       r.k = AttributeKind::ty;
242       type_alias = type_parser.parseType();
243       AT_ASSERTM(!type_alias.second, "Parsing IR with Alias Info not handled");
244       r.ty = type_alias.first;
245       return r;
246     case '<': {
247       L.next();
248       auto text = L.expect(TK_IDENT);
249       if (text.text() != "Tensor") {
250         throw(
251             ErrorReport(token.range)
252             << "Could not parse literal" << token.text());
253       }
254       if (!parse_tensor_constants_) {
255         throw(
256             ErrorReport(token.range)
257             << "Tensor constant encountered but `parse_tensor_constants` set to false"
258             << token.text());
259       }
260       L.expect('>');
261       // these values will be set with randomly initialized data in
262       // a post processing pass;
263       deferred_tensor_value_initializations_.push_back(n);
264       r.k = AttributeKind::t;
265       return r;
266     }
267     case '{': {
268       L.next();
269       if (L.cur().kind == '-') {
270         L.next();
271       }
272       auto text = L.expect(TK_NUMBER);
273       if (!parse_tensor_constants_) {
274         throw(
275             ErrorReport(token.range)
276             << "Single-element tensor constant encountered but "
277             << "`parse_tensor_constants` is set to false " << token.text());
278       }
279       L.expect('}');
280       deferred_tensor_value_initializations_.push_back(n);
281       r.k = AttributeKind::t;
282       return r;
283     }
284     default:
285       throw(
286           ErrorReport(token.range)
287           << "Could not parse literal" << token.text());
288   }
289 }
290 
bypassTypeAnnotationList()291 void IRParser::bypassTypeAnnotationList() {
292   int depth = 0;
293   bool bypassed_list = false;
294   while (depth != 0 || !bypassed_list) {
295     if (L.cur().kind == '[') {
296       bypassed_list = true;
297       depth++;
298     } else if (L.cur().kind == ']') {
299       depth--;
300     }
301     L.next();
302   }
303 }
304 
305 /** \brief Parse attribute and add it to the node N.
306  *
307  * The function determines the attribute type (string, int, float, complex, list
308  * of strings, list of ints, list of floats, list of complex, and a list of
309  * tensors (currently only for empty lists)). An attribute looks like the
310  * following: AttrName=AttrValue Where AttrValue can be a list or a scalar
311  * literal, e.g.: size = 27 name = "Bob" coefs = [1.2, 3.4, 0.6]
312  */
parseAttr(Node * n)313 void IRParser::parseAttr(Node* n) {
314   std::string attrname = L.expect(TK_IDENT).text();
315   L.expect('=');
316   if (L.cur().kind == '[') {
317     // list
318     AttributeKind k = AttributeKind::ts;
319     c10::List<int64_t> is;
320     c10::List<std::string> ss;
321     c10::List<double> fs;
322     c10::List<c10::complex<double>> cs;
323     std::vector<TypePtr> tys;
324     int elem_num = 0;
325     parseList('[', ',', ']', [&] {
326       ParsedLiteral r = parseScalarLiteral(n);
327       switch (r.k) {
328         case AttributeKind::s:
329           ss.push_back(r.s);
330           AT_ASSERT(!elem_num++ || k == AttributeKind::ss);
331           k = AttributeKind::ss;
332           break;
333         case AttributeKind::i:
334           is.push_back(r.i);
335           AT_ASSERT(!elem_num++ || k == AttributeKind::is);
336           k = AttributeKind::is;
337           break;
338         case AttributeKind::f:
339           fs.push_back(r.f);
340           AT_ASSERT(!elem_num++ || k == AttributeKind::fs);
341           k = AttributeKind::fs;
342           break;
343         case AttributeKind::c:
344           cs.push_back(r.c);
345           AT_ASSERT(!elem_num++ || k == AttributeKind::cs);
346           k = AttributeKind::cs;
347           break;
348         case AttributeKind::ty:
349           tys.push_back(r.ty);
350           AT_ASSERT(!elem_num++ || k == AttributeKind::tys);
351           k = AttributeKind::tys;
352           break;
353         default:
354           throw(ErrorReport(L.cur().range) << "Unexpected attr type");
355       }
356     });
357     switch (k) {
358       case AttributeKind::ts:
359         n->ival_(Symbol::attr(attrname), IValue());
360         break;
361       case AttributeKind::ss:
362         n->ival_(Symbol::attr(attrname), IValue(ss));
363         break;
364       case AttributeKind::fs:
365         n->ival_(Symbol::attr(attrname), IValue(fs));
366         break;
367       case AttributeKind::cs:
368         n->ival_(Symbol::attr(attrname), IValue(cs));
369         break;
370       case AttributeKind::is:
371         n->ival_(Symbol::attr(attrname), IValue(is));
372         break;
373       case AttributeKind::tys:
374         n->tys_(Symbol::attr(attrname), tys);
375         break;
376       default:
377         throw(ErrorReport(L.cur().range) << "Unexpected attr type");
378     }
379   } else if (L.cur().text() == "annotate") {
380     L.next();
381     L.expect('(');
382     auto type = L.cur().text();
383     if (type != "List" && type != "Dict") {
384       throw(
385           ErrorReport(L.cur().range)
386           << "Unexpected annotation (only List and Dict can be parsed)");
387     }
388     L.next();
389     // ignore the annotations on the IValue constants, and instead recover
390     // type from the Node output
391     // Note: we could also use script_type_parser
392     bypassTypeAnnotationList();
393     L.expect(',');
394     // expect an empty definition (note - this isn't always true)
395     if (type == "Dict") {
396       L.expect('{');
397       L.expect('}');
398     } else if (type == "List") {
399       L.expect('[');
400       L.expect(']');
401     }
402     L.expect(')');
403     deferred_empty_container_initializations_.push_back(n);
404   } else {
405     // scalar
406     ParsedLiteral r = parseScalarLiteral(n);
407     switch (r.k) {
408       case AttributeKind::s:
409         n->s_(Symbol::attr(attrname), r.s);
410         break;
411       case AttributeKind::i:
412         n->i_(Symbol::attr(attrname), r.i);
413         break;
414       case AttributeKind::f:
415         n->f_(Symbol::attr(attrname), r.f);
416         break;
417       case AttributeKind::c:
418         n->c_(Symbol::attr(attrname), r.c);
419         break;
420       case AttributeKind::ty:
421         n->ty_(Symbol::attr(attrname), r.ty);
422         break;
423       case AttributeKind::t:
424         // initialized with random data later
425         break;
426       default:
427         throw(ErrorReport(L.cur().range) << "Unexpected attr type");
428     }
429     return;
430   }
431 }
432 
parseAttrs(Node * n)433 void IRParser::parseAttrs(Node* n) {
434   parseList('[', ',', ']', [&] { parseAttr(n); });
435 }
436 
parseOperatorInputs(Node * n)437 void IRParser::parseOperatorInputs(Node* n) {
438   if (L.cur().kind == '[') {
439     parseAttrs(n);
440   }
441   parseList('(', ',', ')', [&] {
442     std::string var_name = parseVar();
443     n->addInput(findValueInVMap(var_name));
444   });
445 }
446 
parseBlocks(Node * parentNode)447 void IRParser::parseBlocks(Node* parentNode) {
448   L.expect(TK_INDENT);
449   while (L.cur().kind != TK_DEDENT) {
450     parseBlock(parentNode);
451   }
452   L.expect(TK_DEDENT);
453 }
454 
parseBlockInputs(Block * b)455 void IRParser::parseBlockInputs(Block* b) {
456   parseList('(', ',', ')', [&] {
457     VarWithType v = parseVarWithType();
458     // If the name isn't valid, don't use it
459     std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
460     vmap[v.name] = b->addInput(uniq_name);
461     vmap[v.name]->setType(v.type);
462   });
463 }
464 
parseBlockOutputs(Block * b)465 void IRParser::parseBlockOutputs(Block* b) {
466   L.expect(TK_ARROW);
467   parseList('(', ',', ')', [&] {
468     std::string var_name = parseVar();
469     b->registerOutput(findValueInVMap(var_name));
470   });
471   L.expect(TK_NEWLINE);
472   L.expect(TK_DEDENT);
473 }
474 
475 /** \brief Parse a block.
476  *
477  * It should look like the following:
478  * blockName(input1, input2, input3, ...):
479  *   op1
480  *   op2
481  *   ...
482  *   opN
483  *   -> (output1, output2, output3, ...)
484  */
parseBlock(Node * parentNode)485 void IRParser::parseBlock(Node* parentNode) {
486   Block* b = parentNode->addBlock();
487   L.expect(TK_IDENT).text(); // Block name is not used anywhere.
488   parseBlockInputs(b);
489   L.expect(':');
490   parseOperatorsList(b);
491   parseBlockOutputs(b);
492 }
493 
494 /** \brief Parse a list of statements.
495  *
496  * It is expected to be delimited by TK_NEWLINE and end with TK_RETURN or
497  * TK_ARROW.
498  */
parseOperatorsList(Block * b)499 void IRParser::parseOperatorsList(Block* b) {
500   L.expect(TK_INDENT);
501   while (L.cur().kind != TK_ARROW && L.cur().kind != TK_RETURN) {
502     parseOperator(b);
503   }
504 }
505 
parseOperatorName()506 std::string IRParser::parseOperatorName() {
507   std::string name = L.expect(TK_IDENT).text();
508   L.expect(':');
509   L.expect(':');
510   name += "::" + L.expect(TK_IDENT).text();
511   return name;
512 }
513 
514 /** \brief Parse a statement.
515  *
516  * It should look like the following:
517  *   <outputs> = NodeName[<attributes>](<inputs>)
518  *     <blocks>
519  * Outputs, blocks and attributes are optional.
520  */
parseOperator(Block * b)521 void IRParser::parseOperator(Block* b) {
522   // Parse lefthand side.
523   std::vector<VarWithType> outs;
524   parseOperatorOutputs(&outs);
525 
526   // Parse the name and create the corresponding node in the graph.
527   auto source_range = L.cur().range;
528   std::string name = parseOperatorName();
529   Node* n = g->create(Symbol::fromQualString(name), {}, outs.size())
530                 ->setSourceRange(source_range);
531 
532   // Parse attributes and inputs.
533   parseOperatorInputs(n);
534 
535   const FunctionSchema* schema = n->maybeSchema();
536 
537   // Register outputs.
538   unsigned idx = 0;
539   for (const VarWithType& v : outs) {
540     vmap[v.name] = n->outputs()[idx];
541     if (schema && !schema->is_varret()) {
542       TORCH_CHECK(
543           schema->returns().size() > idx,
544           "Operator parsing error: out of bounds access at ",
545           idx,
546           " to schema->returns() which size is ",
547           schema->returns().size(),
548           " in size");
549       auto schema_return_type = schema->returns().at(idx).type();
550       if (!v.type) {
551         vmap[v.name]->setType(schema_return_type);
552       } else {
553         // Don't currently support checking against type variables
554         // TODO: support?
555         if (!schema_return_type->hasFreeVariables() &&
556             !v.type->isSubtypeOf(*schema_return_type)) {
557           throw(
558               ErrorReport(source_range)
559               << "Annotated type " << v.type->repr_str()
560               << " does not match schema type "
561               << schema_return_type->repr_str() << " for operator " << *schema);
562         }
563         vmap[v.name]->setType(v.type);
564       }
565     } else {
566       vmap[v.name]->setType(v.type ? v.type : TensorType::get());
567     }
568     idx++;
569   }
570 
571   // Insert the new node into block B.
572   b->appendNode(n);
573 
574   // If the statement has nested blocks, parse them:
575   if (L.cur().kind == TK_INDENT) {
576     parseBlocks(n);
577   }
578   L.nextIf(TK_NEWLINE);
579 }
580 
parseGraphInputs()581 void IRParser::parseGraphInputs() {
582   parseList('(', ',', ')', [&] {
583     VarWithType v = parseVarWithType();
584     // If the name isn't valid, don't use it
585     std::string uniq_name = Value::isValidName(v.name) ? v.name : "";
586     vmap[v.name] = g->addInput(uniq_name);
587     vmap[v.name]->setType(v.type);
588   });
589 }
590 
591 /** \brief Parse return statement.
592  *
593  * It should look like the following:
594  *   return (x : TypeX, y : TypeY, z, ...)
595  */
parseReturnOperator()596 void IRParser::parseReturnOperator() {
597   L.expect(TK_RETURN);
598 
599   // Parse output names and types
600   parseList('(', ',', ')', [&] {
601     std::string var_name = parseVar();
602     g->registerOutput(findValueInVMap(var_name));
603   });
604 
605   // Consume ending tokens
606   if (L.cur().kind != TK_EOF) {
607     L.expect(TK_NEWLINE);
608     L.expect(TK_DEDENT);
609   }
610 }
611 
612 /** \brief Parse entire graph.
613  *
614  * It should look like the following:
615  *   graphName (input1, input2, ... inputN):
616  *     op1
617  *     op2
618  *     ...
619  *     opN
620  *     return (output1, output2, ... outputN)
621  */
parse()622 void IRParser::parse() {
623   // Parse graph definition, it should look like the following:
624   // graphName (input1, input2, ... inputN):
625   L.expect(TK_IDENT);
626   parseGraphInputs();
627   L.expect(':');
628 
629   // After the definition we should have a list of statements, parse it:
630   parseOperatorsList(g->block());
631 
632   // The last statement should be return, which specifies graph outputs
633   parseReturnOperator();
634 
635   for (Node* n : deferred_tensor_value_initializations_) {
636     auto type = n->output()->type()->expect<TensorType>();
637     auto tt = n->output()->type()->cast<TensorType>();
638     TORCH_INTERNAL_ASSERT(tt, "expected tensor output ", *n);
639     auto sizes = tt->sizes().concrete_sizes();
640     TORCH_INTERNAL_ASSERT(sizes);
641     auto strides = tt->strides().concrete_sizes();
642     TORCH_INTERNAL_ASSERT(strides);
643     auto device = tt->device();
644     TORCH_INTERNAL_ASSERT(device);
645     auto dtype = tt->scalarType();
646     TORCH_INTERNAL_ASSERT(dtype);
647     auto options = at::TensorOptions(*device).dtype(dtype);
648     auto t = n->t_(attr::value, at::empty_strided(*sizes, *strides, options));
649     (void)t;
650   }
651 
652   for (Node* n : deferred_empty_container_initializations_) {
653     auto type = n->output()->type();
654     IValue val;
655     if (type->kind() == TypeKind::ListType) {
656       val = c10::impl::GenericList(type->containedType(0));
657     } else if (type->kind() == TypeKind::DictType) {
658       val = c10::impl::GenericDict(
659           type->containedType(0), type->containedType(1));
660     }
661     n->ival_(attr::value, val);
662   }
663 }
664 
parseList(int begin,int sep,int end,const std::function<void ()> & callback)665 void IRParser::parseList(
666     int begin,
667     int sep,
668     int end,
669     const std::function<void()>& callback) {
670   if (begin != TK_NOTHING) {
671     L.expect(begin);
672   }
673   if (L.cur().kind != end) {
674     do {
675       callback();
676     } while (L.nextIf(sep));
677   }
678   if (end != TK_NOTHING) {
679     L.expect(end);
680   }
681 }
682 
findValueInVMap(const std::string & name)683 Value* IRParser::findValueInVMap(const std::string& name) {
684   if (!vmap.count(name)) {
685     throw(
686         ErrorReport(L.cur().range)
687         << "Cannot find a variable with name '" << name << "'");
688   }
689   return vmap.at(name);
690 }
691 
692 } // namespace torch::jit
693