xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/vararg_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/vararg_functions.h>
2 
3 #include <ATen/Functions.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/core/class_type.h>
6 #include <c10/util/irange.h>
7 
8 namespace torch::jit {
9 
10 namespace {
11 static constexpr int defaultPrecision = 6;
12 
13 // IValue tags are intentionally private, so we need additional logic to cast
14 // the IValue type to the specified format.
addFormattedArg(char key,const IValue & ival,std::stringstream & ss,int precision=defaultPrecision)15 void addFormattedArg(
16     char key,
17     const IValue& ival,
18     std::stringstream& ss,
19     int precision = defaultPrecision) {
20   // TODO: Implement precision-based formatting
21   std::stringstream tmp;
22   switch (key) {
23     case 'd':
24     case 'i':
25       TORCH_CHECK(
26           ival.isScalar(),
27           "%",
28           key,
29           " requires a number for formatting, but got ",
30           ival.tagKind());
31       if (ival.isInt()) {
32         ss << ival.toInt();
33       } else {
34         ss << static_cast<int>(ival.toDouble());
35       }
36       break;
37     case 'e':
38     case 'E':
39       TORCH_CHECK(
40           ival.isScalar(),
41           "%",
42           key,
43           " requires a number for formatting, but got ",
44           ival.tagKind());
45       tmp << std::setprecision(precision) << std::scientific;
46       if (key == 'E') {
47         tmp << std::uppercase;
48       }
49       if (ival.isInt()) {
50         tmp << static_cast<float>(ival.toInt());
51       } else {
52         tmp << static_cast<float>(ival.toDouble());
53       }
54       ss << tmp.str();
55       break;
56     case 'f':
57     case 'F':
58       TORCH_CHECK(
59           ival.isScalar(),
60           "%",
61           key,
62           " requires a number for formatting, but got ",
63           ival.tagKind());
64       tmp << std::setprecision(precision) << std::fixed;
65       if (ival.isInt()) {
66         tmp << static_cast<float>(ival.toInt());
67       } else {
68         tmp << static_cast<float>(ival.toDouble());
69       }
70       ss << tmp.str();
71       break;
72     case 'c':
73       TORCH_CHECK(
74           ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1),
75           "%",
76           key,
77           " requires an int or char for formatting, but got ",
78           ival.tagKind());
79       if (ival.isInt()) {
80         ss << static_cast<char>(ival.toInt());
81       } else {
82         ss << ival.toStringRef();
83       }
84       break;
85     case 's':
86       if (ival.isString()) {
87         ss << ival.toStringRef();
88       } else {
89         ss << ival;
90       }
91       break;
92     default:
93       TORCH_CHECK(
94           false,
95           "The specifier %",
96           key,
97           " is not supported in TorchScript format strings");
98   }
99 }
100 
101 } // namespace
102 
tupleUnpack(Stack & stack)103 void tupleUnpack(Stack& stack) {
104   auto tuple = pop(stack).toTuple();
105   stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
106 }
107 
format(Stack & stack,size_t num_inputs)108 void format(Stack& stack, size_t num_inputs) {
109   TORCH_CHECK(
110       num_inputs > 0 && num_inputs <= stack.size(),
111       "Invalid number of inputs for format string: ",
112       num_inputs);
113 
114   // static const std::regex unsupported_options("\\{(.*?)\\}");
115   auto format = peek(stack, 0, num_inputs).toStringRef();
116   // // Temporally comment out the warning message because of
117   // // "StdRegexIsAwful" internal Lint error, to prevent sev
118   // // of std::regex from PT mobile.
119   // if (std::regex_search(format, unsupported_options)) {
120   //   TORCH_WARN("Format options are not supported.");
121   // }
122 
123   auto args = last(stack, num_inputs - 1);
124   std::stringstream ss;
125   for (size_t begin = 0, used_args = 0; true; ++used_args) {
126     size_t loc = format.find("{}", begin);
127     if (loc == std::string::npos) {
128       ss << format.substr(begin);
129       break;
130     }
131     ss << format.substr(begin, loc - begin);
132     if (used_args >= args.size()) {
133       AT_ERROR("Too few arguments for format string: ", format);
134     }
135     ss << args[used_args];
136     begin = loc + 2;
137   }
138 
139   drop(stack, num_inputs);
140   push(stack, ss.str());
141 }
142 
einsum(Stack & stack,size_t num_inputs)143 void einsum(Stack& stack, size_t num_inputs) {
144   TORCH_CHECK(
145       num_inputs >= 2,
146       "einsum(): must specify the equation string and at least one operand, ",
147       "or at least one operand and its subscripts list");
148 
149   const auto args = last(stack, num_inputs);
150 
151   // Convert the subscript list format which is an interleaving of operand and
152   // its subscripts list with an optional output subscripts list at the end
153   // (see documentation for more details on this) to the equation string
154   // format by creating the equation string from the subscripts list and
155   // grouping the input operands into a tensorlist (List[Tensor]).
156   std::stringstream ss;
157 
158   auto parse_sublist = [&ss](const c10::List<int64_t>& l, size_t arg_num) {
159     for (const auto i : c10::irange(l.size())) {
160       TORCH_CHECK(
161           l[i] >= 0 && l[i] < 52,
162           "einsum(): expected subscript ",
163           i,
164           " in argument ",
165           arg_num,
166           " to be within the range [0, 52), but got ",
167           l[i]);
168       if (l[i] < 26) {
169         ss << static_cast<char>(l[i] + 'A');
170       } else {
171         ss << static_cast<char>(l[i] - 26 + 'a');
172       }
173     }
174   };
175 
176   // Parse subscripts for input operands
177   for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) {
178     TORCH_CHECK(
179         args[i].isIntList(),
180         "einsum(): expected List[int] in argument ",
181         i,
182         ", but got ",
183         args[i].type()->repr_str());
184     parse_sublist(args[i].toIntList(), i);
185     if (i + 2 < num_inputs) {
186       ss << ',';
187     }
188   }
189 
190   // Parse optional output subscripts (provided if #args is odd)
191   if (num_inputs % 2 == 1) {
192     TORCH_CHECK(
193         args.back().isIntList(),
194         "einsum(): expected List[int] in argument ",
195         num_inputs - 1,
196         ", but got ",
197         args.back().type()->repr_str());
198     ss << "->";
199     parse_sublist(args.back().toIntList(), num_inputs - 1);
200   }
201 
202   const auto equation = ss.str();
203   std::vector<at::Tensor> operands;
204 
205   // Parse input operands
206   const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs;
207   for (auto i = decltype(num_inputs){0}; i < end; i += 2) {
208     TORCH_CHECK(
209         args[i].isTensor(),
210         "einsum(): expected Tensor in argument ",
211         i,
212         ", but got ",
213         args[i].type()->repr_str());
214     operands.emplace_back(args[i].toTensor());
215   }
216 
217   drop(stack, num_inputs);
218   push(stack, at::einsum(equation, operands));
219 }
220 
percentFormat(Stack & stack,size_t num_inputs)221 void percentFormat(Stack& stack, size_t num_inputs) {
222   auto format_str = peek(stack, 0, num_inputs).toStringRef();
223   auto args = last(stack, num_inputs - 1)[0];
224   size_t args_size = 1; // assumed size
225   if (args.isTuple()) {
226     args_size = args.toTupleRef().elements().size();
227   }
228   std::stringstream ss;
229   size_t used_args = 0;
230   size_t begin = 0;
231   while (true) {
232     size_t percent_idx = format_str.find('%', begin);
233     if (percent_idx == std::string::npos) {
234       ss << format_str.substr(begin);
235       break;
236     }
237     size_t format_idx = percent_idx + 1;
238     TORCH_CHECK(
239         percent_idx < format_str.length() - 1, "Incomplete format specifier");
240     ss << format_str.substr(begin, percent_idx - begin);
241     if (format_str.at(format_idx) == '%') {
242       ss << '%';
243       begin = percent_idx + 2; // skip the `%` and the format specifier
244       continue;
245     }
246     TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
247     char key = format_str.at(format_idx);
248     IValue arg;
249     if (args.isTuple()) {
250       arg = args.toTupleRef().elements()[used_args];
251     } else {
252       arg = args;
253     }
254     addFormattedArg(key, arg, ss);
255     begin = percent_idx + 2;
256     ++used_args;
257   }
258   TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
259   drop(stack, num_inputs);
260   push(stack, ss.str());
261 }
262 
listUnpack(Stack & stack,size_t num_outputs)263 void listUnpack(Stack& stack, size_t num_outputs) {
264   auto list = pop(stack).toList();
265   TORCH_CHECK(
266       list.size() == num_outputs,
267       "Expected ",
268       num_outputs,
269       " elements in a list but found ",
270       list.size());
271   stack.insert(stack.end(), list.begin(), list.end());
272 }
273 
tupleConstruct(Stack & stack,size_t num_inputs)274 void tupleConstruct(Stack& stack, size_t num_inputs) {
275   if (num_inputs > stack.size()) {
276     TORCH_CHECK(false, "Invalid number of inputs: ", num_inputs);
277   }
278   switch (num_inputs) {
279     case 0:
280       stack.emplace_back(c10::ivalue::Tuple::create());
281       break;
282     case 1:
283       stack.back() = c10::ivalue::Tuple::create(std::move(stack.back()));
284       break;
285     case 2: {
286       auto tuple = c10::ivalue::Tuple::create(
287           std::move(stack[stack.size() - 2]),
288           std::move(stack[stack.size() - 1]));
289       stack.pop_back();
290       stack.back() = std::move(tuple);
291       break;
292     }
293     case 3: {
294       auto tuple = c10::ivalue::Tuple::create(
295           std::move(stack[stack.size() - 3]),
296           std::move(stack[stack.size() - 2]),
297           std::move(stack[stack.size() - 1]));
298       stack.pop_back();
299       stack.pop_back();
300       stack.back() = std::move(tuple);
301       break;
302     }
303     default: {
304       std::vector<IValue> elems{
305           std::make_move_iterator(stack.end() - num_inputs),
306           std::make_move_iterator(stack.end())};
307       drop(stack, num_inputs - 1);
308       stack.back() = c10::ivalue::Tuple::create(std::move(elems));
309       break;
310     }
311   }
312 }
313 
namedTupleConstruct(Stack & stack,c10::TypePtr tuple_type,size_t num_inputs)314 void namedTupleConstruct(
315     Stack& stack,
316     c10::TypePtr tuple_type,
317     size_t num_inputs) {
318   std::vector<IValue> elems{
319       std::make_move_iterator(stack.end() - num_inputs),
320       std::make_move_iterator(stack.end())};
321   drop(stack, num_inputs);
322   push(
323       stack,
324       c10::ivalue::Tuple::createNamed(std::move(elems), std::move(tuple_type)));
325 }
326 
listConstruct(Stack & stack,const c10::Type & list_type,size_t num_inputs)327 void listConstruct(
328     Stack& stack,
329     const c10::Type& list_type,
330     size_t num_inputs) {
331   // Structuring the implementation this way allows NRVO to avoid
332   // move-constructing vals on its way onto the stack. Moving a List
333   // isn't free.
334   auto makeList =
335       [](Stack& stack, const c10::Type& list_type, size_t num_inputs) {
336         c10::List<IValue> vals(list_type.containedType(0));
337         vals.reserve(num_inputs);
338         for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) {
339           vals.push_back(std::move(stack[i]));
340         }
341         drop(stack, num_inputs);
342         return vals;
343       };
344   stack.emplace_back(makeList(stack, list_type, num_inputs));
345 }
346 
dictConstruct(Stack & stack,const c10::Type & dict_type,size_t num_inputs)347 void dictConstruct(
348     Stack& stack,
349     const c10::Type& dict_type,
350     size_t num_inputs) {
351   auto vals = c10::impl::GenericDict(
352       dict_type.containedType(0), dict_type.containedType(1));
353   vals.reserve(num_inputs / 2);
354   // loop from the bottom of the stack to ensure the dictConstruct preserve
355   // the inputs order.
356   auto inputs = last(stack, num_inputs);
357   for (size_t i = 0; i < num_inputs; i += 2) {
358     auto key = inputs[i];
359     auto val = inputs[i + 1];
360     vals.insert_or_assign(std::move(key), std::move(val));
361   }
362   drop(stack, num_inputs);
363   push(stack, std::move(vals));
364 }
365 
createObject(Stack & stack,const at::ClassTypePtr & type,bool as_weak_ref)366 void createObject(
367     Stack& stack,
368     const at::ClassTypePtr& type,
369     bool as_weak_ref) {
370   if (as_weak_ref) {
371     c10::WeakTypePtr weak(type->compilation_unit(), type);
372     auto userObj = c10::ivalue::Object::create(
373         c10::WeakOrStrongTypePtr(weak), type->numAttributes());
374     push(stack, std::move(userObj));
375   } else {
376     auto userObj = c10::ivalue::Object::create(
377         c10::StrongTypePtr(type->compilation_unit(), type),
378         type->numAttributes());
379     push(stack, std::move(userObj));
380   }
381 }
382 
isinstance(Stack & stack,at::ArrayRef<at::TypePtr> types)383 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
384   at::TypePtr ty = pop(stack).type();
385   for (const at::TypePtr& candidate : types) {
386     if (ty->isSubtypeOf(*candidate)) {
387       push(stack, true);
388       return;
389     }
390   }
391 
392   push(stack, false);
393 }
394 
tupleSlice(Stack & stack,size_t begin,size_t end)395 void tupleSlice(Stack& stack, size_t begin, size_t end) {
396   auto tuple = pop(stack).toTuple();
397   push(
398       stack,
399       c10::ivalue::Tuple::create(
400           tuple->elements().asArrayRef().slice(begin, end - begin)));
401 }
402 
dequantize(Stack & stack)403 void dequantize(Stack& stack) {
404   auto iv = pop(stack);
405   if (iv.isTuple()) {
406     auto tuple = iv.toTuple();
407     const auto& elems = tuple->elements();
408     std::vector<IValue> output_elems;
409     output_elems.reserve(elems.size());
410     for (const auto& elem : elems) {
411       if (elem.isTensor()) {
412         output_elems.emplace_back(at::dequantize(elem.toTensor()));
413       } else {
414         output_elems.emplace_back(elem);
415       }
416     }
417     push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
418   } else if (iv.isTensorList()) {
419     auto elems = iv.toTensorList();
420     auto output_list = c10::impl::GenericList(elems.elementType());
421     for (auto&& elem : elems) {
422       output_list.emplace_back(at::dequantize(elem));
423     }
424     push(stack, std::move(output_list));
425   } else {
426     TORCH_CHECK(
427         false,
428         "Unsupported type in dequantize, only List[Tensor] and \
429  Tuple[Tensor or other types] are supported, got type:",
430         toString(iv.type()));
431   }
432 }
433 
434 } // namespace torch::jit
435