// torch/csrc/jit/mobile/promoted_prim_ops.cpp
#include <ATen/ScalarOps.h>
#include <fmt/format.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>

namespace torch::jit {

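// Implements prim::TupleIndex: pops the index, then the tuple, and pushes
// tuple[index]. normalizeIndex is assumed to translate negative indices
// Python-style (e.g. -1 selects the last element); anything still out of
// range after normalization raises std::out_of_range.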
void tupleIndex(Stack& stack) {
  int64_t index = pop(stack).toInt();
  auto tuple = pop(stack).toTuple();
  auto norm_index =
      normalizeIndex(index, static_cast<int64_t>(tuple->elements().size()));
  if (norm_index < 0 ||
      norm_index >= static_cast<int64_t>(tuple->elements().size())) {
    throw std::out_of_range("Tuple list index out of range");
  }
  stack.emplace_back(tuple->elements()[norm_index]);
}

void raiseException(Stack& stack) {
  // This kernel supports RaiseException with only one argument: the error.
  // DEPRECATED as of bytecode_version 8; do not change this kernel, so that
  // backward compatibility is preserved.
  throw JITException(pop(stack).toStringRef());
}

void raiseExceptionWithMessage(Stack& stack) {
  // This kernel supports RaiseException with exactly two arguments: the error
  // and the message. Please make changes only to this kernel.
  std::optional<std::string> qualified_class_name =
      pop(stack).toOptional<std::string>();
  std::string message;
  pop(stack, message);

  throw JITException(message, qualified_class_name);
}

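// Implements aten::__is__ (Python's `is`). IValue::is performs an identity
// check: pointer equality for reference types such as tensors and lists,
// plain value comparison for immediate types like None, bools, and ints.
// isNot below is the matching aten::__isnot__ negation.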
void is(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, self.is(obj));
}

void unInitialized(Stack& stack) {
  push(stack, IValue::uninitialized());
}

void isNot(Stack& stack) {
  IValue self, obj;
  pop(stack, self, obj);
  push(stack, !self.is(obj));
}

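// Implements aten::format. The top of the stack holds the argument count;
// the shared format() helper is assumed to pop that many values (the first
// being the format string) and push the interpolated result, mirroring
// Python's str.format with "{}" placeholders.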
void aten_format(Stack& stack) {
  size_t num_inputs = pop(stack).toInt();
  format(stack, num_inputs);
}

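// size and the sym_* kernels below expose shape metadata to the interpreter.
// The sym_* variants return c10::SymInt values rather than plain integers,
// so symbolic shapes can flow through without being specialized to concrete
// numbers.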
void size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sizes().vec());
}

void sym_size(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_sizes().vec());
}

void sym_size_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_sizes()[dim]);
}

void sym_stride_int(Stack& stack) {
  auto dim = pop(stack).toInt();
  auto t = pop(stack).toTensor();
  push(stack, t.sym_strides()[dim]);
}

void sym_numel(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_numel());
}

void sym_storage_offset(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  push(stack, t.sym_storage_offset());
}

void sym_stride(Stack& stack) {
  auto t = std::move(pop(stack)).toTensor();
  pack(stack, t.sym_strides().vec());
}

void device(Stack& stack) {
  push(stack, pop(stack).toTensor().device());
}

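// Builds a c10::Device from a type string and an index by round-tripping
// through the "type:index" string form; e.g. "cuda" with index 0 parses to
// the same device as c10::Device("cuda:0").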
void device_with_index(Stack& stack) {
  std::string type = pop(stack).toStringRef();
  auto index = pop(stack).toInt();
  std::string device_str = fmt::format("{}:{}", type, index);
  auto device = c10::Device(device_str);
  push(stack, device);
}

void dtype(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, static_cast<int64_t>(a.scalar_type()));
}

void layout(Stack& stack) {
  push(stack, pop(stack).toTensor().layout());
}

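// Implements aten::to.prim_dtype(Tensor, ScalarType?, bool, bool). Operands
// pop in reverse push order: the non_blocking/copy flags first, then the
// optional dtype, then the tensor itself. The device stays nullopt because
// this overload never moves the tensor to a different device.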
void toPrimDType(Stack& stack) {
  bool non_blocking = false;
  bool copy = false;
  pop(stack, non_blocking, copy);
  std::optional<at::ScalarType> scalarType =
      pop(stack).toOptional<at::ScalarType>();
  std::optional<c10::Device> device = std::nullopt;
  at::Tensor self = pop(stack).toTensor();
  push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
}

void dim(Stack& stack) {
  at::Tensor arg = pop(stack).toTensor();
  push(stack, arg.dim());
}

void _not(Stack& stack) {
  push(stack, !pop(stack).toBool());
}

void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, at::native::is_nonzero(a));
}

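// Implements prim::tolist (the back end of Tensor.tolist()). The annotated
// element type and list dimension arrive as integer codes pushed after the
// tensor, so they pop first; the tensor data is then converted into nested
// IValue lists by tensorToListRecursive.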
void toList(Stack& stack) {
  int elem_ty_val = 0;
  int dim_val = 0;
  at::Tensor t;

  pop(stack, elem_ty_val);
  pop(stack, dim_val);
  pop(stack, t);

  // If the Tensor is not on the CPU, transfer it.
  if (!t.device().is_cpu()) {
    t = t.cpu();
  }

  // Rebuild the output type using elem_ty_val and dim_val. Start
  // with the element type corresponding to elem_ty_val.
  at::TypePtr out_ty;
  if (elem_ty_val == 0) {
    out_ty = at::IntType::get();
  } else if (elem_ty_val == 1) {
    out_ty = at::FloatType::get();
  } else if (elem_ty_val == 2) {
    out_ty = at::BoolType::get();
  } else if (elem_ty_val == 3) {
    out_ty = at::ComplexType::get();
  } else {
    TORCH_CHECK(
        false,
        "Unsupported element type for tolist; only int, float, complex and bool are supported");
  }

  // Check that the type of the Tensor matches that of the annotation.
  // Make an exception for the case in which the annotated type is
  // float/complex and the Tensor data type is also float/complex;
  // the elements will be cast to double/c10::complex<double>
  // later.
  TORCH_CHECK(
      (out_ty == at::FloatType::get() && t.is_floating_point()) ||
          (out_ty == at::ComplexType::get() && t.is_complex()) ||
          tryScalarTypeFromJitType(*out_ty) == t.scalar_type(),
      "Output annotation element type and runtime tensor element type must match for tolist(): ",
      *tryScalarTypeFromJitType(*out_ty),
      " vs ",
      t.scalar_type());

  // Check that the dimension of the Tensor matches that of the
  // annotation.
  TORCH_CHECK(
      dim_val == t.dim(),
      "Output annotation list dimension and runtime tensor dimension must match for tolist()");

  // Wrap out_ty in a ListType dim times.
  for (const auto i : c10::irange(dim_val)) {
    (void)i; // Suppress unused variable warning
    out_ty = at::ListType::create(out_ty);
  }

  int64_t dim = t.dim();
  auto sizes = t.sizes();
  auto strides = t.strides();
  size_t element_size = t.element_size();
  char* data = static_cast<char*>(t.data_ptr());
  auto result = tensorToListRecursive(
      data, 0, dim, out_ty, t.scalar_type(), sizes, strides, element_size);
  push(stack, std::move(result));
}

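// prim::NumToTensor wraps a Python number in a tensor; scalar_to_tensor is
// assumed to produce a 0-dim CPU tensor whose dtype matches the scalar
// (double for float, int64 for int, bool for bool).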
void numToTensorScalar(Stack& stack) {
  at::Scalar s;
  pop(stack, s);
  push(stack, c10::scalar_to_tensor(s));
}

void isCuda(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, a.is_cuda());
}

void numToTensorBool(Stack& stack) {
  bool b = false;
  pop(stack, b);
  push(stack, c10::scalar_to_tensor(b));
}

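// Implements dict indexing (aten::__getitem__.Dict_str): a missing key
// raises an error prefixed with "KeyError:", mirroring Python's dict
// semantics.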
void dictIndex(Stack& stack) {
  auto key = pop(stack);
  auto dict = pop(stack).toGenericDict();
  auto value = dict.find(key);
  if (value == dict.end()) {
    AT_ERROR("KeyError: ", key);
  }
  push(stack, value->value());
}

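// Each prim_op_fn_register constructor is assumed to register its kernel
// with the mobile interpreter as a side effect during static initialization;
// the array itself is never read afterwards, hence the C10_UNUSED marker.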
static const C10_UNUSED std::array<mobile::prim_op_fn_register, 16> op_reg = {
    mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex),
    mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor),
    mobile::prim_op_fn_register("aten::format", aten_format),
    mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar),
    mobile::prim_op_fn_register(
        "prim::RaiseException",
        raiseExceptionWithMessage),
    mobile::prim_op_fn_register("prim::device", device),
    mobile::prim_op_fn_register("prim::dtype", dtype),
    mobile::prim_op_fn_register("prim::layout", layout),
    mobile::prim_op_fn_register("aten::__not__", _not),
    mobile::prim_op_fn_register("aten::__is__", is),
    mobile::prim_op_fn_register("aten::__isnot__", isNot),
    mobile::prim_op_fn_register("aten::dim", dim),
    mobile::prim_op_fn_register("prim::Uninitialized", unInitialized),
    mobile::prim_op_fn_register("prim::is_cuda", isCuda),
    mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex),
    mobile::prim_op_fn_register("prim::unchecked_cast", noop),
    // TODO: (@pavithran) size is overloaded with int[] and Tensor,
    // so registering it here throws an error expecting int, not Tensor
    // mobile::prim_op_fn_register("aten::size", size)
};

} // namespace torch::jit