#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/Copy.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/CopyKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/TypeCast.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/TensorIteratorInternal.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {
inline namespace CPU_CAPABILITY {

namespace {
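// A "reduced" copy converts between a reduced-precision floating type
// (BFloat16/Half, excluding the Float8 variants) and float.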
static bool reduced_input(ScalarType input_t, ScalarType output_t) {
  return !at::isFloat8Type(input_t) && at::isReducedFloatingType(input_t) &&
      output_t == kFloat;
}

static bool reduced_output(ScalarType input_t, ScalarType output_t) {
  return !at::isFloat8Type(output_t) && at::isReducedFloatingType(output_t) &&
      input_t == kFloat;
}
} // namespace

static bool reduced_float_type_copy(
    bool requires_conj,
    TensorIteratorBase& iter) {
  auto strides_out = iter.strides(0);
  auto strides_in = iter.strides(1);

  // Check whether input is in BFloat16/Half data type and output is in float
  // data type, or input is in float data type and output is in BFloat16/Half
  // data type. In addition, input and output need contiguous parts to utilize
  // vectorization.
  return (
      !requires_conj &&
      ((reduced_input(iter.dtype(1), iter.dtype(0)) &&
        sizeof(float) == strides_out[0] &&
        (static_cast<int64_t>(elementSize(iter.dtype(1))) == strides_in[0] ||
         strides_in[0] == 0)) ||
       (reduced_output(iter.dtype(1), iter.dtype(0)) &&
        static_cast<int64_t>(elementSize(iter.dtype(0))) == strides_out[0] &&
        (sizeof(float) == strides_in[0] || strides_in[0] == 0))));
}

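// Vectorized copy between BFloat16/Half and float (in either direction),
// optionally fusing a pending negation. The innermost dimension is processed
// with Vectorized loads/stores; the remaining dimensions are walked by
// serial_for_each inside a parallel_for over the iterator's elements.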
static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_neg) {
  auto strides_out = iter.strides(0);
  auto strides_in = iter.strides(1);
  auto shape = iter.shape();
  c10::SmallBuffer<int64_t, 8> strides(2 * std::max(iter.ndim(), 2));
  auto get_strides = [](int64_t* strides, IntArrayRef strides_out, IntArrayRef strides_in, int64_t ndim) {
      for (const auto dim : c10::irange(ndim)) {
        for (const auto arg : c10::irange(2)) {
          *strides++ = arg == 0? strides_out[dim] : strides_in[dim];
        }
      }
      // Always at least 2d strides to support 2d for_each loops
      if (ndim < 2) {
        std::fill_n(strides, (2 - ndim) * 2, 0);
      }
    };
  get_strides(strides.data(), strides_out, strides_in, iter.ndim());
  if (reduced_input(iter.dtype(1), iter.dtype(0))) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(1), "copy_kernel", [&]() {
      using dest_t = float;
      using Vecd = Vectorized<dest_t>;
      using Vecs = Vectorized<scalar_t>;
      c10::SmallBuffer<char*, 2> ptrs(2);
      dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
      scalar_t* input_data = const_cast<scalar_t*>(iter.tensor_base(1).const_data_ptr<scalar_t>());
      ptrs[0] = reinterpret_cast<char*>(output_data);
      ptrs[1] = reinterpret_cast<char*>(input_data);

      int64_t grain_size = at::internal::GRAIN_SIZE;

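      // Inner loop for BFloat16/Half -> float: each Vecs load of the source is
      // widened into two float Vecd vectors and stored back to back. A zero
      // input stride means the source is a broadcast scalar, which is
      // converted (and optionally negated) once per row and splatted.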
      auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
        std::array<char*, 2> data;
        std::copy_n(base, 2, data.data());
        const int64_t *outer_strides = &strides[2];

        for (const auto it C10_UNUSED : c10::irange(size1)) {
          Vecd dst_s;
          if (strides_in[0] == 0) {
            dst_s = Vecd(dest_t(*((scalar_t*)data[1])));
            if (requires_neg) {
              dst_s = dst_s.neg();
            }
          }
          int64_t i = 0;
          for (; i <= size0 - Vecs::size(); i += Vecs::size()) {
            if (strides_in[0] != 0) {
              Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t));
              auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
              if (requires_neg) {
                data_vec0 = data_vec0.neg();
                data_vec1 = data_vec1.neg();
              }
              data_vec0.store(data[0] + i * sizeof(dest_t));
              data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t));
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t));
              dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t));
            }
          }
          if (i < size0) {
            if (strides_in[0] != 0) {
              Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t), size0 - i);
              auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
              if (requires_neg) {
                data_vec0 = data_vec0.neg();
                data_vec1 = data_vec1.neg();
              }
              data_vec0.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size())?  Vecd::size() : (size0 - i));
              data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size())? (size0 - i - Vecd::size()) : 0);
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size())?  Vecd::size() : (size0 - i));
              dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size())? (size0 - i - Vecd::size()) : 0);
            }
          }
          data[0] += outer_strides[0];
          data[1] += outer_strides[1];
        }

      };

      parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) {
        at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end});
      });
    });
  } else if (reduced_output(iter.dtype(1), iter.dtype(0))) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(0), "copy_kernel", [&]() {
      using dest_t = scalar_t;
      using source_t = float;
      using Vecd = Vectorized<dest_t>;
      using Vecs = Vectorized<source_t>;
      c10::SmallBuffer<char*, 2> ptrs(2);
      dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
      source_t* input_data = const_cast<source_t*>(iter.tensor_base(1).const_data_ptr<source_t>());

      ptrs[0] = reinterpret_cast<char*>(output_data);
      ptrs[1] = reinterpret_cast<char*>(input_data);

      int64_t grain_size = at::internal::GRAIN_SIZE;

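      // Inner loop for float -> BFloat16/Half: two float Vecs loads are
      // narrowed into a single reduced-precision Vecd before storing. The
      // zero-stride (broadcast) source is handled the same way as above.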
      auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
        std::array<char*, 2> data;
        std::copy_n(base, 2, data.data());
        const int64_t *outer_strides = &strides[2];

        for (const auto it C10_UNUSED : c10::irange(size1)) {
          Vecd dst_s;
          if (strides_in[0] == 0) {
            dst_s = Vecd(dest_t(*((source_t*)data[1])));
            if (requires_neg) {
              dst_s = dst_s.neg();
            }
          }
          int64_t i = 0;
          for (; i <= size0 - 2 * Vecs::size(); i += 2 * Vecs::size()) {
            if (strides_in[0] != 0) {
              Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(source_t));
              Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(source_t));
              auto data_vec = convert_from_float<dest_t>(data_vec0, data_vec1);
              if (requires_neg) {
                data_vec = data_vec.neg();
              }
              data_vec.store(data[0] + i * sizeof(dest_t));
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t));
            }

          }
          if (i < size0) {
            if (strides_in[0] != 0) {
              Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(source_t), ((size0 - i) > Vecs::size())?  Vecs::size() : (size0 - i));
              Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(source_t), ((size0 - i) > Vecs::size())?  (size0 - i - Vecs::size()) : 0);
              auto data_vec = convert_from_float<dest_t>(data_vec0, data_vec1);
              if (requires_neg) {
                data_vec = data_vec.neg();
              }
              data_vec.store(data[0] + i * sizeof(dest_t), size0 - i);
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t), size0 - i);
            }
          }
          data[0] += outer_strides[0];
          data[1] += outer_strides[1];
        }

      };
      parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) {
        at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end});
      });
    });

  }
}

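// Dispatch helpers: on non-mobile builds these cover every copyable dtype,
// including the Float8 variants and the barebones unsigned types; mobile
// builds use the smaller AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND* macros,
// which omit those types.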
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                       \
        AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__),                                       \
            kComplexHalf, kHalf, kBool,              \
            kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
            kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...)              \
        AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__),                    \
            kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
            kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                               \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                                               \
            ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, \
            TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(       \
            kBool, kHalf, kBFloat16,                  \
            TYPE, NAME, __VA_ARGS__)
#endif

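// Same-dtype copy with no conjugation or negation pending: dispatch on the
// single dtype and run an identity element-wise kernel, vectorized where a
// Vectorized specialization exists.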
void direct_copy_kernel(TensorIteratorBase &iter) {
  // TODO: we don't actually need separate instantiations per dtype;
  // we only need a separate instantiation per dtype size. This would
  // probably save us a little bit of code size here
  // TODO: not sure if optimizer is able to compile two levels of
  // conditionals into a single jump table.  We should have a
  // single jump table here; might be worth just writing out the
  // dispatch statement by hand instead of using AT_DISPATCH
  ScalarType dtype = iter.dtype(0);
  if (isQIntType(dtype)) {
    AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return a; },
          [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
    });
  } else if (dtype == ScalarType::ComplexHalf) {
    cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
  } else if (isBitsType(dtype)) {
    AT_DISPATCH_BIT_TYPES(dtype, "copy_kernel", [&] {
      cpu_kernel(
          iter,
          [=](scalar_t a) -> scalar_t { return a; });
    });
  } else {
    _AT_DISPATCH_ALL_TYPES_NO_CF(dtype, "copy_kernel", [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return a; },
          [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
    });
  }
}

static void neg_conj_kernel(TensorIteratorBase &iter) {
  // fused a = b.neg().conj_physical()
  AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cpu", [&] {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return -conj_impl(a); },
        [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg().conj(); });
  });
}

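// Same-dtype copy: apply whichever combination of negation/conjugation is
// still pending on the output, or fall through to a plain copy.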
static void copy_same_dtype(TensorIteratorBase &iter, bool requires_conj, bool requires_neg) {
  if (requires_neg) {
    // This case should never actually happen since currently there's no way to get a complex tensor
    // with negative bit.
    if (requires_conj) {
      neg_conj_kernel(iter);
    } else {
      neg_kernel(iter);
    }
  } else {
    if (requires_conj) {
      conj_kernel(iter);
    } else {
      direct_copy_kernel(iter);
    }
  }
}

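// Entry point for copy_stub on CPU. Same-dtype copies go through
// copy_same_dtype; BFloat16/Half <-> float copies take the vectorized
// reduced-float path; everything else performs a generic dtype conversion,
// followed by an in-place fix-up pass if a neg/conj bit still needs to be
// materialized.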
void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
  ScalarType dtype = iter.dtype(0);
  const bool requires_conj = (
      isComplexType(dtype) && (iter.tensor_base(0).is_conj() != iter.tensor_base(1).is_conj()));
  const bool requires_neg = (iter.tensor_base(0).is_neg() != iter.tensor_base(1).is_neg());

  if (dtype == iter.dtype(1)) {
    copy_same_dtype(iter, requires_conj, requires_neg);
  } else if (reduced_float_type_copy(requires_conj, iter)) {
    reduced_float_copy_kernel(iter, requires_neg);
  } else {
    _AT_DISPATCH_ALL_TYPES(dtype, "copy_", [&] {
      using dest_t = scalar_t;
      _AT_DISPATCH_ALL_TYPES(iter.dtype(1), "copy_", [&] {
        if (iter.has_contiguous_first_dim()) {
          TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
          TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

          iter.for_each([](char **data, const int64_t *strides, int64_t size) {
            auto src = reinterpret_cast<const scalar_t*>(data[1]);
            auto dst = reinterpret_cast<dest_t*>(data[0]);
            at::vec::convert(src, dst, size);
          });
        } else {
          cpu_kernel(iter, [](scalar_t x) -> dest_t {
            return c10::convert<dest_t>(x);
          });
        }
      });
    });

    if (requires_conj || requires_neg) {
      // This inplace "copy" will perform any missing neg or conj operations
      auto self = iter.tensor_base(0);
      auto iter = TensorIterator::unary_op(self, self);
      copy_same_dtype(iter, requires_conj, requires_neg);
    }
  }
}

} // namespace CPU_CAPABILITY

REGISTER_DISPATCH(copy_stub, &copy_kernel);

} // namespace at::native