xref: /aosp_15_r20/external/pytorch/test/inductor/custom_ops.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/api/include/torch/types.h>
2 
3 #include <cstdint>
4 #include <iostream>
5 #include <string>
6 
7 namespace at {
8 
custom_add_impl(Tensor t1,Tensor t2)9 Tensor custom_add_impl(Tensor t1, Tensor t2) {
10   return t1 + t2;
11 }
12 
fn_with_all_inputs_impl(const Tensor & tensor,const c10::List<Tensor> & tensors,const c10::List<std::optional<Tensor>> & optional_tensors,const bool b8,const c10::List<bool> & b8s,const int64_t i64,const c10::List<int64_t> & i64s,const int64_t & symint,const IntArrayRef symints,const double f64,const c10::List<double> & f64s,const at::Scalar & scalar,at::ArrayRef<at::Scalar> scalars,const std::string & string,const std::vector<std::string> & strings,const Device & device,const std::optional<Tensor> & o_tensor,const std::optional<c10::List<Tensor>> & o_tensors,const std::optional<bool> & o_b8,const std::optional<c10::List<bool>> & o_b8s,const std::optional<int64_t> & o_i64,const std::optional<c10::List<int64_t>> & o_i64s,const std::optional<int64_t> & o_symint,const std::optional<IntArrayRef> & o_symints,const std::optional<double> & o_f64,const std::optional<c10::List<double>> & o_f64s,const std::optional<at::Scalar> & o_scalar,const std::optional<at::ArrayRef<at::Scalar>> & o_scalars,const std::optional<std::string> & o_string,const std::optional<std::vector<std::string>> & o_strings,const std::optional<Device> & o_device)13 Tensor fn_with_all_inputs_impl(
14     const Tensor& tensor,
15     const c10::List<Tensor>& tensors,
16     const c10::List<std::optional<Tensor>>& optional_tensors,
17     const bool b8,
18     const c10::List<bool>& b8s,
19     const int64_t i64,
20     const c10::List<int64_t>& i64s,
21     const int64_t& symint,
22     const IntArrayRef symints,
23     const double f64,
24     const c10::List<double>& f64s,
25     const at::Scalar& scalar,
26     at::ArrayRef<at::Scalar> scalars,
27     const std::string& string,
28     const std::vector<std::string>& strings,
29     // const c10::ScalarType& dtype,
30     // const MemoryFormat& memory_format,
31     // const Layout& layout,
32     const Device& device,
33     // optional
34     const std::optional<Tensor>& o_tensor,
35     const std::optional<c10::List<Tensor>>& o_tensors,
36     const std::optional<bool>& o_b8,
37     const std::optional<c10::List<bool>>& o_b8s,
38     const std::optional<int64_t>& o_i64,
39     const std::optional<c10::List<int64_t>>& o_i64s,
40     const std::optional<int64_t>& o_symint,
41     const std::optional<IntArrayRef>& o_symints,
42     const std::optional<double>& o_f64,
43     const std::optional<c10::List<double>>& o_f64s,
44     const std::optional<at::Scalar>& o_scalar,
45     const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
46     const std::optional<std::string>& o_string,
47     const std::optional<std::vector<std::string>>& o_strings,
48     // const std::optional<c10::ScalarType>& o_dtype,
49     // const std::optional<MemoryFormat>& o_memory_format,
50     // const std::optional<Layout>& o_layout,
51     const std::optional<Device>& o_device) {
52   std::cout << "tensor shape: " << tensor.sizes() << std::endl;
53 
54   std::cout << "tensors shape: ";
55   for (auto t : tensors) {
56     std::cout << t.get().toTensor().sizes() << ", ";
57   }
58   std::cout << std::endl;
59 
60   std::cout << "optional tensors shape: ";
61   for (auto t : optional_tensors) {
62     if (t.get().toOptional<Tensor>().has_value()) {
63       std::cout << t.get().toTensor().sizes() << ", ";
64     } else {
65       std::cout << "None, ";
66     }
67   }
68   std::cout << std::endl;
69 
70   std::cout << "b8 " << c10::IValue(b8) << std::endl;
71   std::cout << "b8s " << c10::IValue(b8s) << std::endl;
72   std::cout << "i64 " << c10::IValue(i64) << std::endl;
73   std::cout << "i64s " << c10::IValue(i64s) << std::endl;
74   std::cout << "symint " << c10::IValue(symint) << std::endl;
75   std::cout << "symints " << c10::IValue(symints) << std::endl;
76   std::cout << "f64 " << c10::IValue(f64) << std::endl;
77   std::cout << "f64s " << c10::IValue(f64s) << std::endl;
78   std::cout << "scalar " << c10::IValue(scalar) << std::endl;
79   std::cout << "scalars " << c10::IValue(scalars) << std::endl;
80   std::cout << "string " << c10::IValue(string) << std::endl;
81   std::cout << "strings " << c10::IValue(strings) << std::endl;
82   // std::cout << "dtype " << c10::IValue(dtype) << std::endl;
83   // std::cout << "memory_format " << c10::IValue(memory_format) << std::endl;
84   // std::cout << "layout " << c10::IValue(layout) << std::endl;
85   std::cout << "device " << c10::IValue(device) << std::endl;
86 
87   std::cout << "o_tensor "
88             << (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes())
89                                      : "None")
90             << std::endl;
91 
92   std::cout << "o_tensors shape: ";
93   if (o_tensors.has_value()) {
94     for (auto t : o_tensors.value()) {
95       std::cout << t.get().toTensor().sizes() << ", ";
96     }
97   } else {
98     std::cout << "None";
99   }
100   std::cout << std::endl;
101 
102   std::cout << "o_b8 "
103             << (o_b8.has_value() ? c10::IValue(o_b8.value()) : "None")
104             << std::endl;
105   std::cout << "o_b8s "
106             << (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None")
107             << std::endl;
108   std::cout << "o_i64 "
109             << (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None")
110             << std::endl;
111   std::cout << "o_i64s "
112             << (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None")
113             << std::endl;
114   std::cout << "o_symint "
115             << (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None")
116             << std::endl;
117   std::cout << "o_symints "
118             << (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None")
119             << std::endl;
120   std::cout << "o_f64 "
121             << (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None")
122             << std::endl;
123   std::cout << "o_f64s "
124             << (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None")
125             << std::endl;
126   std::cout << "o_scalar "
127             << (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None")
128             << std::endl;
129   std::cout << "o_scalars "
130             << (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None")
131             << std::endl;
132   std::cout << "o_string "
133             << (o_string.has_value() ? c10::IValue(o_string.value()) : "None")
134             << std::endl;
135   std::cout << "o_strings "
136             << (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None")
137             << std::endl;
138   // std::cout << "o_dtype "
139   //           << (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None")
140   //           << std::endl;
141   // std::cout << "o_memory_format "
142   //           << (o_memory_format.has_value()
143   //                   ? c10::IValue(o_memory_format.value())
144   //                   : "None")
145   //           << std::endl;
146   // std::cout << "o_layout "
147   //           << (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None")
148   //           << std::endl;
149   std::cout << "o_device "
150             << (o_device.has_value() ? c10::IValue(o_device.value()) : "None")
151             << std::endl;
152 
153   int64_t int_hash = 0;
154   int_hash ^= i64;
155   for (auto i : i64s) {
156     int_hash ^= i;
157   }
158   if (o_i64.has_value()) {
159     int_hash ^= o_i64.value();
160   }
161   if (o_i64s.has_value()) {
162     for (auto i : o_i64s.value()) {
163       int_hash ^= i;
164     }
165   }
166 
167   int_hash ^= symint;
168   for (auto i : symints) {
169     int_hash ^= i;
170   }
171   if (o_symint.has_value()) {
172     int_hash ^= o_symint.value();
173   }
174   if (o_symints.has_value()) {
175     for (auto i : o_symints.value()) {
176       int_hash ^= i;
177     }
178   }
179 
180   return tensor + int_hash;
181 }
182 
// Eager kernel for aoti_custom_ops::fn_with_default_input: adds the integer
// `i64` (schema default 3) to every element of `tensor`.
Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) {
  return tensor.add(i64);
}
186 
fn_with_tuple_output_impl(const Tensor & tensor,const int64_t i64)187 std::tuple<Tensor, Tensor> fn_with_tuple_output_impl(
188     const Tensor& tensor,
189     const int64_t i64) {
190   return {tensor + i64, tensor - i64};
191 }
192 
fn_with_list_output_impl(TensorList tensors,const int64_t i64)193 std::vector<Tensor> fn_with_list_output_impl(
194     TensorList tensors,
195     const int64_t i64) {
196   std::vector<Tensor> outputs;
197   for (auto& t : tensors) {
198     outputs.emplace_back(t + i64);
199   }
200   return outputs;
201 }
202 
fn_with_mix_outputs_impl(const Tensor & tensor,TensorList tensors)203 std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_impl(
204     const Tensor& tensor,
205     TensorList tensors) {
206   std::vector<Tensor> outputs;
207   for (auto& t : tensors) {
208     outputs.emplace_back(t + 2);
209   }
210   return {tensor + 1, outputs};
211 }
212 
fn_with_input_mutation_impl(Tensor & t0,const Tensor & t1,Tensor & t2)213 std::tuple<Tensor, Tensor> fn_with_input_mutation_impl(
214     Tensor& t0,
215     const Tensor& t1,
216     Tensor& t2) {
217   t0.add_(1);
218   t2.sub_(1);
219   return {t1 + 1, t1 + 2};
220 }
221 
222 // NOLINTBEGIN(clang-diagnostic-unused-parameter)
// Meta (metadata-only) kernel for aoti_custom_ops::fn_with_all_inputs.
//
// The eager kernel returns `tensor + <integer>`, a freshly allocated tensor
// shaped like `tensor`. The schema declares a plain `-> Tensor` output with no
// alias annotation, so the output must not alias any input. Returning
// `tensor` itself would make the fake output alias the input. A clone is
// returned instead, matching the other meta kernels in this file.
// Every argument other than `tensor` is accepted only to match the schema and
// is unused here.
Tensor fn_with_all_inputs_meta(
    const Tensor& tensor,
    const c10::List<Tensor>& tensors,
    const c10::List<std::optional<Tensor>>& optional_tensors,
    const bool b8,
    const c10::List<bool>& b8s,
    const int64_t i64,
    const c10::List<int64_t>& i64s,
    const c10::SymInt& symint,
    c10::SymIntArrayRef symints,
    const double f64,
    const c10::List<double>& f64s,
    const at::Scalar& scalar,
    at::ArrayRef<at::Scalar> scalars,
    const std::string& string,
    const std::vector<std::string>& strings,
    // const c10::ScalarType& dtype,
    // const MemoryFormat& memory_format,
    // const Layout& layout,
    const Device& device,
    // optional
    const std::optional<Tensor>& o_tensor,
    const std::optional<c10::List<Tensor>>& o_tensors,
    const std::optional<bool>& o_b8,
    const std::optional<c10::List<bool>>& o_b8s,
    const std::optional<int64_t>& o_i64,
    const std::optional<c10::List<int64_t>>& o_i64s,
    const std::optional<c10::SymInt>& o_symint,
    at::OptionalSymIntArrayRef o_symints,
    const std::optional<double>& o_f64,
    const std::optional<c10::List<double>>& o_f64s,
    const std::optional<at::Scalar>& o_scalar,
    const std::optional<at::ArrayRef<at::Scalar>>& o_scalars,
    const std::optional<std::string>& o_string,
    const std::optional<std::vector<std::string>>& o_strings,
    // const std::optional<c10::ScalarType>& o_dtype,
    // const std::optional<MemoryFormat>& o_memory_format,
    // const std::optional<Layout>& o_layout,
    const std::optional<Device>& o_device) {
  return tensor.clone();
}
264 
// Meta kernel for aoti_custom_ops::fn_with_default_input: produces a fresh
// tensor with the metadata of `tensor`. `i64` is not used here.
Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) {
  Tensor result = tensor.clone();
  return result;
}
268 
fn_with_tuple_output_meta(const Tensor & tensor,const int64_t i64)269 std::tuple<Tensor, Tensor> fn_with_tuple_output_meta(
270     const Tensor& tensor,
271     const int64_t i64) {
272   return {tensor.clone(), tensor.clone()};
273 }
274 
fn_with_list_output_meta(TensorList tensors,const int64_t i64)275 std::vector<Tensor> fn_with_list_output_meta(
276     TensorList tensors,
277     const int64_t i64) {
278   std::vector<Tensor> outputs;
279   for (auto& t : tensors) {
280     outputs.push_back(t.clone());
281   }
282   return outputs;
283 }
284 
fn_with_mix_outputs_meta(const Tensor & tensor,TensorList tensors)285 std::tuple<Tensor, std::vector<Tensor>> fn_with_mix_outputs_meta(
286     const Tensor& tensor,
287     TensorList tensors) {
288   std::vector<Tensor> outputs;
289   for (auto& t : tensors) {
290     outputs.push_back(t.clone());
291   }
292   return {tensor.clone(), outputs};
293 }
294 
fn_with_input_mutation_meta(Tensor & t0,const Tensor & t1,Tensor & t2)295 std::tuple<Tensor, Tensor> fn_with_input_mutation_meta(
296     Tensor& t0,
297     const Tensor& t1,
298     Tensor& t2) {
299   return {t1.clone(), t1.clone()};
300 }
301 
302 } // namespace at
303 
// Declares the schemas of the aoti_custom_ops test operators. The kernels are
// attached by the TORCH_LIBRARY_IMPL blocks below (CompositeExplicitAutograd
// and Meta).
TORCH_LIBRARY(aoti_custom_ops, m) {
  // Element-wise sum of two tensors.
  m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");

  // One parameter of each listed schema type, first positionally and then as
  // a keyword-only parameter of the corresponding optional (`?`) type.
  // ScalarType, MemoryFormat and Layout parameters are currently disabled
  // (kept below as comments).
  m.def(
      "fn_with_all_inputs("
      // positional
      "Tensor tensor, "
      "Tensor[] tensors, "
      "Tensor?[] optional_tensors, "
      "bool b8, "
      "bool[] b8s, "
      "int i64, "
      "int[] i64s, "
      "SymInt symint, "
      "SymInt[] symints, "
      "float f64, "
      "float[] f64s, "
      "Scalar scalar, "
      "Scalar[] scalars, "
      "str string, "
      "str[] strings, "
      // "ScalarType dtype, "
      // "MemoryFormat memory_format, "
      // "Layout layout, "
      "Device device, "
      // keyword-only, optional-typed
      "*, "
      "Tensor? o_tensor, "
      "Tensor[]? o_tensors, "
      "bool? o_b8, "
      "bool[]? o_b8s, "
      "int? o_i64, "
      "int[]? o_i64s, "
      "SymInt? o_symint, "
      "SymInt[]? o_symints, "
      "float? o_f64, "
      "float[]? o_f64s, "
      "Scalar? o_scalar, "
      "Scalar[]? o_scalars, "
      "str? o_string, "
      "str[]? o_strings, "
      // "ScalarType? o_dtype, "
      // "MemoryFormat? o_memory_format, "
      // "Layout? o_layout, "
      "Device? o_device"
      ") -> Tensor");

  // `i` has a schema default of 3.
  m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor");

  // Two tensor outputs.
  m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)");

  // A list-of-tensors output.
  m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]");

  // A tensor output followed by a list-of-tensors output.
  m.def(
      "fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])");

  // t0 and t2 are annotated as mutated in place (a!, b!).
  m.def(
      "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)");
}
346 
// Eager kernels for every aoti_custom_ops operator, registered under the
// CompositeExplicitAutograd dispatch key.
TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
  m.impl("custom_add", &at::custom_add_impl);
  m.impl("fn_with_all_inputs", &at::fn_with_all_inputs_impl);
  m.impl("fn_with_default_input", &at::fn_with_default_input_impl);
  m.impl("fn_with_tuple_output", &at::fn_with_tuple_output_impl);
  m.impl("fn_with_list_output", &at::fn_with_list_output_impl);
  m.impl("fn_with_mix_outputs", &at::fn_with_mix_outputs_impl);
  m.impl("fn_with_input_mutation", &at::fn_with_input_mutation_impl);
}
356 
// Metadata-only kernels registered under the Meta dispatch key. custom_add
// has no Meta kernel registered in this block.
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
  m.impl("fn_with_all_inputs", &at::fn_with_all_inputs_meta);
  m.impl("fn_with_default_input", &at::fn_with_default_input_meta);
  m.impl("fn_with_tuple_output", &at::fn_with_tuple_output_meta);
  m.impl("fn_with_list_output", &at::fn_with_list_output_meta);
  m.impl("fn_with_mix_outputs", &at::fn_with_mix_outputs_meta);
  m.impl("fn_with_input_mutation", &at::fn_with_input_mutation_meta);
}
365