1 #include <ATen/ScalarOps.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/jit/mobile/promoted_prim_ops.h>
4
5 namespace torch::jit {
6
tupleIndex(Stack & stack)7 void tupleIndex(Stack& stack) {
8 int64_t index = pop(stack).toInt();
9 auto tuple = pop(stack).toTuple();
10 auto norm_index =
11 normalizeIndex(index, static_cast<int64_t>(tuple->elements().size()));
12 if (norm_index < 0 ||
13 norm_index >= static_cast<int64_t>(tuple->elements().size())) {
14 throw std::out_of_range("Tuple list index out of range");
15 }
16 stack.emplace_back(tuple->elements()[norm_index]);
17 }
18
// Legacy one-argument RaiseException kernel: pops the error string from the
// top of the stack and throws it as a JITException.
// Note: the op_reg table in this file maps "prim::RaiseException" to
// raiseExceptionWithMessage, not to this function.
void raiseException(Stack& stack) {
  // this kernel supports RaiseException with only one argument: the error
  // DEPRECATED from bytecode_version 8;
  // Please do not make any changes to this to support BC
  throw JITException(pop(stack).toStringRef());
}
25
raiseExceptionWithMessage(Stack & stack)26 void raiseExceptionWithMessage(Stack& stack) {
27 // this kernel supports RaiseException with only two arguments: the error and
28 // the message Please make changes only to this kernel
29 std::optional<std::string> qualified_class_name =
30 pop(stack).toOptional<std::string>();
31 std::string message;
32 pop(stack, message);
33
34 throw JITException(message, qualified_class_name);
35 }
36
is(Stack & stack)37 void is(Stack& stack) {
38 IValue self, obj;
39 pop(stack, self, obj);
40 push(stack, self.is(obj));
41 }
42
unInitialized(Stack & stack)43 void unInitialized(Stack& stack) {
44 push(stack, IValue::uninitialized());
45 }
46
isNot(Stack & stack)47 void isNot(Stack& stack) {
48 IValue self, obj;
49 pop(stack, self, obj);
50 push(stack, !self.is(obj));
51 }
52
aten_format(Stack & stack)53 void aten_format(Stack& stack) {
54 size_t num_inputs = pop(stack).toInt();
55 format(stack, num_inputs);
56 }
57
size(Stack & stack)58 void size(Stack& stack) {
59 auto t = std::move(pop(stack)).toTensor();
60 pack(stack, t.sizes().vec());
61 }
62
sym_size(Stack & stack)63 void sym_size(Stack& stack) {
64 auto t = std::move(pop(stack)).toTensor();
65 pack(stack, t.sym_sizes().vec());
66 }
sym_size_int(Stack & stack)67 void sym_size_int(Stack& stack) {
68 auto dim = pop(stack).toInt();
69 auto t = pop(stack).toTensor();
70 push(stack, t.sym_sizes()[dim]);
71 }
sym_stride_int(Stack & stack)72 void sym_stride_int(Stack& stack) {
73 auto dim = pop(stack).toInt();
74 auto t = pop(stack).toTensor();
75 push(stack, t.sym_strides()[dim]);
76 }
77
sym_numel(Stack & stack)78 void sym_numel(Stack& stack) {
79 auto t = std::move(pop(stack)).toTensor();
80 push(stack, t.sym_numel());
81 }
82
sym_storage_offset(Stack & stack)83 void sym_storage_offset(Stack& stack) {
84 auto t = std::move(pop(stack)).toTensor();
85 push(stack, t.sym_storage_offset());
86 }
87
sym_stride(Stack & stack)88 void sym_stride(Stack& stack) {
89 auto t = std::move(pop(stack)).toTensor();
90 pack(stack, t.sym_strides().vec());
91 }
92
device(Stack & stack)93 void device(Stack& stack) {
94 push(stack, pop(stack).toTensor().device());
95 }
96
device_with_index(Stack & stack)97 void device_with_index(Stack& stack) {
98 std::string type = pop(stack).toStringRef();
99 auto index = pop(stack).toInt();
100 std::string device_str = fmt::format("{}:{}", type, index);
101 auto device = c10::Device(device_str);
102 push(stack, device);
103 }
104
dtype(Stack & stack)105 void dtype(Stack& stack) {
106 at::Tensor a;
107 pop(stack, a);
108 push(stack, static_cast<int64_t>(a.scalar_type()));
109 }
110
layout(Stack & stack)111 void layout(Stack& stack) {
112 push(stack, pop(stack).toTensor().layout());
113 }
114
toPrimDType(Stack & stack)115 void toPrimDType(Stack& stack) {
116 bool non_blocking = false;
117 bool copy = false;
118 pop(stack, non_blocking, copy);
119 std::optional<at::ScalarType> scalarType =
120 pop(stack).toOptional<at::ScalarType>();
121 std::optional<c10::Device> device = std::nullopt;
122 at::Tensor self = pop(stack).toTensor();
123 push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
124 }
125
dim(Stack & stack)126 void dim(Stack& stack) {
127 at::Tensor arg = pop(stack).toTensor();
128 push(stack, arg.dim());
129 }
130
_not(Stack & stack)131 void _not(Stack& stack) {
132 push(stack, !pop(stack).toBool());
133 }
134
boolTensor(Stack & stack)135 void boolTensor(Stack& stack) {
136 at::Tensor a;
137 pop(stack, a);
138 push(stack, at::native::is_nonzero(a));
139 }
140
toList(Stack & stack)141 void toList(Stack& stack) {
142 int elem_ty_val = 0;
143 int dim_val = 0;
144 at::Tensor t;
145
146 pop(stack, elem_ty_val);
147 pop(stack, dim_val);
148 pop(stack, t);
149
150 // If the Tensor is not on the CPU, transfer it.
151 if (!t.device().is_cpu()) {
152 t = t.cpu();
153 }
154
155 // Rebuild the output type using elem_ty_val and dim_val. Start
156 // with the element type corresponding to elem_ty_val.
157 at::TypePtr out_ty;
158 if (elem_ty_val == 0) {
159 out_ty = at::IntType::get();
160 } else if (elem_ty_val == 1) {
161 out_ty = at::FloatType::get();
162 } else if (elem_ty_val == 2) {
163 out_ty = at::BoolType::get();
164 } else if (elem_ty_val == 3) {
165 out_ty = at::ComplexType::get();
166 } else {
167 TORCH_CHECK(
168 false,
169 "Unsupported element type for tolist; only int, float, complex and bool are supported");
170 }
171
172 // Check that type of the Tensor matches that of the annotation.
173 // Make an exception for the case in which the annotated type is
174 // float/complex and the Tensor data type is also float/complex;
175 // the elements will be casted to double/c10::complex<double>
176 // later.
177 TORCH_CHECK(
178 (out_ty == at::FloatType::get() && t.is_floating_point()) ||
179 (out_ty == at::ComplexType::get() && t.is_complex()) ||
180 tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
181 "Output annotation element type and runtime tensor element type must match for tolist(): ",
182 *tryScalarTypeFromJitType(*out_ty),
183 " vs ",
184 t.scalar_type());
185
186 // Check that the dimension of the Tensor matches that of the
187 // annotation.
188 TORCH_CHECK(
189 dim_val == t.dim(),
190 "Output annotation list dimension and runtime tensor dimension must match for tolist()");
191
192 // Wrap out_ty in a ListType dim times.
193 for (const auto i : c10::irange(dim_val)) {
194 (void)i; // Suppress unused variable warning
195 out_ty = at::ListType::create(out_ty);
196 }
197
198 int64_t dim = t.dim();
199 auto sizes = t.sizes();
200 auto strides = t.strides();
201 size_t element_size = t.element_size();
202 char* data = static_cast<char*>(t.data_ptr());
203 auto result = tensorToListRecursive(
204 data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
205 push(stack, std::move(result));
206 }
207
numToTensorScalar(Stack & stack)208 void numToTensorScalar(Stack& stack) {
209 at::Scalar s;
210 pop(stack, s);
211 push(stack, c10::scalar_to_tensor(s));
212 }
213
isCuda(Stack & stack)214 void isCuda(Stack& stack) {
215 at::Tensor a;
216 pop(stack, a);
217 push(stack, a.is_cuda());
218 }
219
numToTensorBool(Stack & stack)220 void numToTensorBool(Stack& stack) {
221 bool b = false;
222 pop(stack, b);
223 push(stack, c10::scalar_to_tensor(b));
224 }
225
dictIndex(Stack & stack)226 void dictIndex(Stack& stack) {
227 auto key = pop(stack);
228 auto dict = pop(stack).toGenericDict();
229 auto value = dict.find(key);
230 if (value == dict.end()) {
231 AT_ERROR("KeyError: ", key);
232 }
233 push(stack, value->value());
234 }
235
236 static const C10_UNUSED std::array<mobile::prim_op_fn_register, 16> op_reg = {
237 mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
238 mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
239 mobile::prim_op_fn_register("aten::format", aten_format),
240 mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
241 mobile::prim_op_fn_register(
242 "prim::RaiseException",
243 raiseExceptionWithMessage),
244 mobile::prim_op_fn_register("prim::device", device),
245 mobile::prim_op_fn_register("prim::dtype", dtype),
246 mobile::prim_op_fn_register("prim::layout", layout),
247 mobile::prim_op_fn_register("aten::__not__", _not),
248 mobile::prim_op_fn_register("aten::__is__", is),
249 mobile::prim_op_fn_register("aten::__isnot__", isNot),
250 mobile::prim_op_fn_register("aten::dim", dim),
251 mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
252 mobile::prim_op_fn_register("prim::is_cuda", isCuda),
253 mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
254 mobile::prim_op_fn_register("prim::unchecked_cast", noop),
255 // TODO: (@pavithran) size is overloaded with int[] and Tensor
256 // so this throws error expecting int not Tensor
257 // mobile::prim_op_fn_register("aten::size", size)
258 };
259
260 } // namespace torch::jit
261