1 #include <torch/csrc/jit/runtime/vararg_functions.h>
2
3 #include <ATen/Functions.h>
4 #include <ATen/Tensor.h>
5 #include <ATen/core/class_type.h>
6 #include <c10/util/irange.h>
7
8 namespace torch::jit {
9
10 namespace {
// Default number of digits used by %e/%E/%f/%F formatting in
// addFormattedArg below. `static` dropped: entities in an anonymous
// namespace already have internal linkage (readability-redundant-static).
constexpr int defaultPrecision = 6;
12
13 // IValue tags are intentionally private, so we need additional logic to cast
14 // the IValue type to the specified format.
addFormattedArg(char key,const IValue & ival,std::stringstream & ss,int precision=defaultPrecision)15 void addFormattedArg(
16 char key,
17 const IValue& ival,
18 std::stringstream& ss,
19 int precision = defaultPrecision) {
20 // TODO: Implement precision-based formatting
21 std::stringstream tmp;
22 switch (key) {
23 case 'd':
24 case 'i':
25 TORCH_CHECK(
26 ival.isScalar(),
27 "%",
28 key,
29 " requires a number for formatting, but got ",
30 ival.tagKind());
31 if (ival.isInt()) {
32 ss << ival.toInt();
33 } else {
34 ss << static_cast<int>(ival.toDouble());
35 }
36 break;
37 case 'e':
38 case 'E':
39 TORCH_CHECK(
40 ival.isScalar(),
41 "%",
42 key,
43 " requires a number for formatting, but got ",
44 ival.tagKind());
45 tmp << std::setprecision(precision) << std::scientific;
46 if (key == 'E') {
47 tmp << std::uppercase;
48 }
49 if (ival.isInt()) {
50 tmp << static_cast<float>(ival.toInt());
51 } else {
52 tmp << static_cast<float>(ival.toDouble());
53 }
54 ss << tmp.str();
55 break;
56 case 'f':
57 case 'F':
58 TORCH_CHECK(
59 ival.isScalar(),
60 "%",
61 key,
62 " requires a number for formatting, but got ",
63 ival.tagKind());
64 tmp << std::setprecision(precision) << std::fixed;
65 if (ival.isInt()) {
66 tmp << static_cast<float>(ival.toInt());
67 } else {
68 tmp << static_cast<float>(ival.toDouble());
69 }
70 ss << tmp.str();
71 break;
72 case 'c':
73 TORCH_CHECK(
74 ival.isInt() || (ival.isString() && ival.toStringRef().length() == 1),
75 "%",
76 key,
77 " requires an int or char for formatting, but got ",
78 ival.tagKind());
79 if (ival.isInt()) {
80 ss << static_cast<char>(ival.toInt());
81 } else {
82 ss << ival.toStringRef();
83 }
84 break;
85 case 's':
86 if (ival.isString()) {
87 ss << ival.toStringRef();
88 } else {
89 ss << ival;
90 }
91 break;
92 default:
93 TORCH_CHECK(
94 false,
95 "The specifier %",
96 key,
97 " is not supported in TorchScript format strings");
98 }
99 }
100
101 } // namespace
102
tupleUnpack(Stack & stack)103 void tupleUnpack(Stack& stack) {
104 auto tuple = pop(stack).toTuple();
105 stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
106 }
107
format(Stack & stack,size_t num_inputs)108 void format(Stack& stack, size_t num_inputs) {
109 TORCH_CHECK(
110 num_inputs > 0 && num_inputs <= stack.size(),
111 "Invalid number of inputs for format string: ",
112 num_inputs);
113
114 // static const std::regex unsupported_options("\\{(.*?)\\}");
115 auto format = peek(stack, 0, num_inputs).toStringRef();
116 // // Temporally comment out the warning message because of
117 // // "StdRegexIsAwful" internal Lint error, to prevent sev
118 // // of std::regex from PT mobile.
119 // if (std::regex_search(format, unsupported_options)) {
120 // TORCH_WARN("Format options are not supported.");
121 // }
122
123 auto args = last(stack, num_inputs - 1);
124 std::stringstream ss;
125 for (size_t begin = 0, used_args = 0; true; ++used_args) {
126 size_t loc = format.find("{}", begin);
127 if (loc == std::string::npos) {
128 ss << format.substr(begin);
129 break;
130 }
131 ss << format.substr(begin, loc - begin);
132 if (used_args >= args.size()) {
133 AT_ERROR("Too few arguments for format string: ", format);
134 }
135 ss << args[used_args];
136 begin = loc + 2;
137 }
138
139 drop(stack, num_inputs);
140 push(stack, ss.str());
141 }
142
einsum(Stack & stack,size_t num_inputs)143 void einsum(Stack& stack, size_t num_inputs) {
144 TORCH_CHECK(
145 num_inputs >= 2,
146 "einsum(): must specify the equation string and at least one operand, ",
147 "or at least one operand and its subscripts list");
148
149 const auto args = last(stack, num_inputs);
150
151 // Convert the subscript list format which is an interleaving of operand and
152 // its subscripts list with an optional output subscripts list at the end
153 // (see documentation for more details on this) to the equation string
154 // format by creating the equation string from the subscripts list and
155 // grouping the input operands into a tensorlist (List[Tensor]).
156 std::stringstream ss;
157
158 auto parse_sublist = [&ss](const c10::List<int64_t>& l, size_t arg_num) {
159 for (const auto i : c10::irange(l.size())) {
160 TORCH_CHECK(
161 l[i] >= 0 && l[i] < 52,
162 "einsum(): expected subscript ",
163 i,
164 " in argument ",
165 arg_num,
166 " to be within the range [0, 52), but got ",
167 l[i]);
168 if (l[i] < 26) {
169 ss << static_cast<char>(l[i] + 'A');
170 } else {
171 ss << static_cast<char>(l[i] - 26 + 'a');
172 }
173 }
174 };
175
176 // Parse subscripts for input operands
177 for (auto i = decltype(num_inputs){1}; i < num_inputs; i += 2) {
178 TORCH_CHECK(
179 args[i].isIntList(),
180 "einsum(): expected List[int] in argument ",
181 i,
182 ", but got ",
183 args[i].type()->repr_str());
184 parse_sublist(args[i].toIntList(), i);
185 if (i + 2 < num_inputs) {
186 ss << ',';
187 }
188 }
189
190 // Parse optional output subscripts (provided if #args is odd)
191 if (num_inputs % 2 == 1) {
192 TORCH_CHECK(
193 args.back().isIntList(),
194 "einsum(): expected List[int] in argument ",
195 num_inputs - 1,
196 ", but got ",
197 args.back().type()->repr_str());
198 ss << "->";
199 parse_sublist(args.back().toIntList(), num_inputs - 1);
200 }
201
202 const auto equation = ss.str();
203 std::vector<at::Tensor> operands;
204
205 // Parse input operands
206 const auto end = num_inputs % 2 == 1 ? num_inputs - 1 : num_inputs;
207 for (auto i = decltype(num_inputs){0}; i < end; i += 2) {
208 TORCH_CHECK(
209 args[i].isTensor(),
210 "einsum(): expected Tensor in argument ",
211 i,
212 ", but got ",
213 args[i].type()->repr_str());
214 operands.emplace_back(args[i].toTensor());
215 }
216
217 drop(stack, num_inputs);
218 push(stack, at::einsum(equation, operands));
219 }
220
percentFormat(Stack & stack,size_t num_inputs)221 void percentFormat(Stack& stack, size_t num_inputs) {
222 auto format_str = peek(stack, 0, num_inputs).toStringRef();
223 auto args = last(stack, num_inputs - 1)[0];
224 size_t args_size = 1; // assumed size
225 if (args.isTuple()) {
226 args_size = args.toTupleRef().elements().size();
227 }
228 std::stringstream ss;
229 size_t used_args = 0;
230 size_t begin = 0;
231 while (true) {
232 size_t percent_idx = format_str.find('%', begin);
233 if (percent_idx == std::string::npos) {
234 ss << format_str.substr(begin);
235 break;
236 }
237 size_t format_idx = percent_idx + 1;
238 TORCH_CHECK(
239 percent_idx < format_str.length() - 1, "Incomplete format specifier");
240 ss << format_str.substr(begin, percent_idx - begin);
241 if (format_str.at(format_idx) == '%') {
242 ss << '%';
243 begin = percent_idx + 2; // skip the `%` and the format specifier
244 continue;
245 }
246 TORCH_CHECK(used_args < args_size, "Too few arguments for format string");
247 char key = format_str.at(format_idx);
248 IValue arg;
249 if (args.isTuple()) {
250 arg = args.toTupleRef().elements()[used_args];
251 } else {
252 arg = args;
253 }
254 addFormattedArg(key, arg, ss);
255 begin = percent_idx + 2;
256 ++used_args;
257 }
258 TORCH_CHECK(used_args == args_size, "Too many arguments for format string");
259 drop(stack, num_inputs);
260 push(stack, ss.str());
261 }
262
listUnpack(Stack & stack,size_t num_outputs)263 void listUnpack(Stack& stack, size_t num_outputs) {
264 auto list = pop(stack).toList();
265 TORCH_CHECK(
266 list.size() == num_outputs,
267 "Expected ",
268 num_outputs,
269 " elements in a list but found ",
270 list.size());
271 stack.insert(stack.end(), list.begin(), list.end());
272 }
273
tupleConstruct(Stack & stack,size_t num_inputs)274 void tupleConstruct(Stack& stack, size_t num_inputs) {
275 if (num_inputs > stack.size()) {
276 TORCH_CHECK(false, "Invalid number of inputs: ", num_inputs);
277 }
278 switch (num_inputs) {
279 case 0:
280 stack.emplace_back(c10::ivalue::Tuple::create());
281 break;
282 case 1:
283 stack.back() = c10::ivalue::Tuple::create(std::move(stack.back()));
284 break;
285 case 2: {
286 auto tuple = c10::ivalue::Tuple::create(
287 std::move(stack[stack.size() - 2]),
288 std::move(stack[stack.size() - 1]));
289 stack.pop_back();
290 stack.back() = std::move(tuple);
291 break;
292 }
293 case 3: {
294 auto tuple = c10::ivalue::Tuple::create(
295 std::move(stack[stack.size() - 3]),
296 std::move(stack[stack.size() - 2]),
297 std::move(stack[stack.size() - 1]));
298 stack.pop_back();
299 stack.pop_back();
300 stack.back() = std::move(tuple);
301 break;
302 }
303 default: {
304 std::vector<IValue> elems{
305 std::make_move_iterator(stack.end() - num_inputs),
306 std::make_move_iterator(stack.end())};
307 drop(stack, num_inputs - 1);
308 stack.back() = c10::ivalue::Tuple::create(std::move(elems));
309 break;
310 }
311 }
312 }
313
namedTupleConstruct(Stack & stack,c10::TypePtr tuple_type,size_t num_inputs)314 void namedTupleConstruct(
315 Stack& stack,
316 c10::TypePtr tuple_type,
317 size_t num_inputs) {
318 std::vector<IValue> elems{
319 std::make_move_iterator(stack.end() - num_inputs),
320 std::make_move_iterator(stack.end())};
321 drop(stack, num_inputs);
322 push(
323 stack,
324 c10::ivalue::Tuple::createNamed(std::move(elems), std::move(tuple_type)));
325 }
326
listConstruct(Stack & stack,const c10::Type & list_type,size_t num_inputs)327 void listConstruct(
328 Stack& stack,
329 const c10::Type& list_type,
330 size_t num_inputs) {
331 // Structuring the implementation this way allows NRVO to avoid
332 // move-constructing vals on its way onto the stack. Moving a List
333 // isn't free.
334 auto makeList =
335 [](Stack& stack, const c10::Type& list_type, size_t num_inputs) {
336 c10::List<IValue> vals(list_type.containedType(0));
337 vals.reserve(num_inputs);
338 for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) {
339 vals.push_back(std::move(stack[i]));
340 }
341 drop(stack, num_inputs);
342 return vals;
343 };
344 stack.emplace_back(makeList(stack, list_type, num_inputs));
345 }
346
dictConstruct(Stack & stack,const c10::Type & dict_type,size_t num_inputs)347 void dictConstruct(
348 Stack& stack,
349 const c10::Type& dict_type,
350 size_t num_inputs) {
351 auto vals = c10::impl::GenericDict(
352 dict_type.containedType(0), dict_type.containedType(1));
353 vals.reserve(num_inputs / 2);
354 // loop from the bottom of the stack to ensure the dictConstruct preserve
355 // the inputs order.
356 auto inputs = last(stack, num_inputs);
357 for (size_t i = 0; i < num_inputs; i += 2) {
358 auto key = inputs[i];
359 auto val = inputs[i + 1];
360 vals.insert_or_assign(std::move(key), std::move(val));
361 }
362 drop(stack, num_inputs);
363 push(stack, std::move(vals));
364 }
365
createObject(Stack & stack,const at::ClassTypePtr & type,bool as_weak_ref)366 void createObject(
367 Stack& stack,
368 const at::ClassTypePtr& type,
369 bool as_weak_ref) {
370 if (as_weak_ref) {
371 c10::WeakTypePtr weak(type->compilation_unit(), type);
372 auto userObj = c10::ivalue::Object::create(
373 c10::WeakOrStrongTypePtr(weak), type->numAttributes());
374 push(stack, std::move(userObj));
375 } else {
376 auto userObj = c10::ivalue::Object::create(
377 c10::StrongTypePtr(type->compilation_unit(), type),
378 type->numAttributes());
379 push(stack, std::move(userObj));
380 }
381 }
382
isinstance(Stack & stack,at::ArrayRef<at::TypePtr> types)383 void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
384 at::TypePtr ty = pop(stack).type();
385 for (const at::TypePtr& candidate : types) {
386 if (ty->isSubtypeOf(*candidate)) {
387 push(stack, true);
388 return;
389 }
390 }
391
392 push(stack, false);
393 }
394
tupleSlice(Stack & stack,size_t begin,size_t end)395 void tupleSlice(Stack& stack, size_t begin, size_t end) {
396 auto tuple = pop(stack).toTuple();
397 push(
398 stack,
399 c10::ivalue::Tuple::create(
400 tuple->elements().asArrayRef().slice(begin, end - begin)));
401 }
402
// Implements the non-Tensor dequantize overloads: accepts either a tuple
// (dequantizing its Tensor elements and passing other elements through
// unchanged) or a List[Tensor] (dequantizing every element). Any other
// input type raises c10::Error.
void dequantize(Stack& stack) {
  auto iv = pop(stack);
  if (iv.isTuple()) {
    auto tuple = iv.toTuple();
    const auto& elems = tuple->elements();
    std::vector<IValue> output_elems;
    output_elems.reserve(elems.size());
    for (const auto& elem : elems) {
      if (elem.isTensor()) {
        output_elems.emplace_back(at::dequantize(elem.toTensor()));
      } else {
        // Non-Tensor tuple members are forwarded as-is.
        output_elems.emplace_back(elem);
      }
    }
    push(stack, c10::ivalue::Tuple::create(std::move(output_elems)));
  } else if (iv.isTensorList()) {
    auto elems = iv.toTensorList();
    auto output_list = c10::impl::GenericList(elems.elementType());
    for (auto&& elem : elems) {
      output_list.emplace_back(at::dequantize(elem));
    }
    push(stack, std::move(output_list));
  } else {
    // NOTE(review): the backslash line-splice below embeds the next line's
    // leading whitespace into the runtime error message; left untouched to
    // keep the emitted string unchanged.
    TORCH_CHECK(
        false,
        "Unsupported type in dequantize, only List[Tensor] and \
        Tuple[Tensor or other types] are supported, got type:",
        toString(iv.type()));
  }
}
433
434 } // namespace torch::jit
435