#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/Copy.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/CopyKernel.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/TypeCast.h>
#include <ATen/native/cpu/zmath.h>
#include <ATen/TensorIteratorInternal.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {
inline namespace CPU_CAPABILITY {

namespace {
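// Helper predicates for the reduced-precision fast path below:
// reduced_input  -> copying BFloat16/Half (but not float8) into float,
// reduced_output -> copying float into BFloat16/Half (but not float8).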
static bool reduced_input(ScalarType input_t, ScalarType output_t) {
  return !at::isFloat8Type(input_t) && at::isReducedFloatingType(input_t) &&
      output_t == kFloat;
}

static bool reduced_output(ScalarType input_t, ScalarType output_t) {
  return !at::isFloat8Type(output_t) && at::isReducedFloatingType(output_t) &&
      input_t == kFloat;
}
} // namespace

static bool reduced_float_type_copy(
    bool requires_conj,
    TensorIteratorBase& iter) {
  auto strides_out = iter.strides(0);
  auto strides_in = iter.strides(1);

  // Check whether input is in BFloat16/Half data type and output is in float
  // data type, or input is in float data type and output is in BFloat16/Half
  // data type. In addition, input and output need contiguous parts to utilize
  // vectorization.
  return (
      !requires_conj &&
      ((reduced_input(iter.dtype(1), iter.dtype(0)) &&
        sizeof(float) == strides_out[0] &&
        (static_cast<int64_t>(elementSize(iter.dtype(1))) == strides_in[0] ||
         strides_in[0] == 0)) ||
       (reduced_output(iter.dtype(1), iter.dtype(0)) &&
        static_cast<int64_t>(elementSize(iter.dtype(0))) == strides_out[0] &&
        (sizeof(float) == strides_in[0] || strides_in[0] == 0))));
}

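// Vectorized fast path for float <-> BFloat16/Half copies (reached from
// Tensor::copy_ when, e.g., a BFloat16 tensor is copied into a float tensor).
// The dtype conversion is done with explicit Vectorized loads/stores, with an
// optional negation fused in; an input with inner stride 0 (a broadcasted
// value) is converted once and then stored repeatedly.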
static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_neg) {
  auto strides_out = iter.strides(0);
  auto strides_in = iter.strides(1);
  auto shape = iter.shape();
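  // Flatten the per-dimension strides into the interleaved layout consumed by
  // serial_for_each: [out_stride(dim0), in_stride(dim0), out_stride(dim1),
  // in_stride(dim1), ...], padded out to at least two dimensions.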
  c10::SmallBuffer<int64_t, 8> strides(2 * std::max(iter.ndim(), 2));
  auto get_strides = [](int64_t* strides, IntArrayRef strides_out, IntArrayRef strides_in, int64_t ndim) {
    for (const auto dim : c10::irange(ndim)) {
      for (const auto arg : c10::irange(2)) {
        *strides++ = arg == 0 ? strides_out[dim] : strides_in[dim];
      }
    }
    // Always at least 2d strides to support 2d for_each loops
    if (ndim < 2) {
      std::fill_n(strides, (2 - ndim) * 2, 0);
    }
  };
  get_strides(strides.data(), strides_out, strides_in, iter.ndim());
  if (reduced_input(iter.dtype(1), iter.dtype(0))) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(1), "copy_kernel", [&]() {
      using dest_t = float;
      using Vecd = Vectorized<dest_t>;
      using Vecs = Vectorized<scalar_t>;
      c10::SmallBuffer<char*, 2> ptrs(2);
      dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
      scalar_t* input_data = const_cast<scalar_t*>(iter.tensor_base(1).const_data_ptr<scalar_t>());
      ptrs[0] = reinterpret_cast<char*>(output_data);
      ptrs[1] = reinterpret_cast<char*>(input_data);

      int64_t grain_size = at::internal::GRAIN_SIZE;

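      // serial_for_each hands this loop a 2-D tile: size0 elements along the
      // innermost dimension, repeated size1 times with outer_strides applied
      // between rows. An inner input stride of 0 means the input value is
      // broadcast, so it is converted (and negated, if needed) only once.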
      auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
        std::array<char*, 2> data;
        std::copy_n(base, 2, data.data());
        const int64_t *outer_strides = &strides[2];

        for (const auto it C10_UNUSED : c10::irange(size1)) {
          Vecd dst_s;
          if (strides_in[0] == 0) {
            dst_s = Vecd(dest_t(*((scalar_t*)data[1])));
            if (requires_neg) {
              dst_s = dst_s.neg();
            }
          }
          int64_t i = 0;
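          // Main loop: each Vecs load of BFloat16/Half widens into two float
          // vectors via convert_to_float (Vecs carries twice as many lanes as
          // Vecd), so every iteration does one load and two stores.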
          for (; i <= size0 - Vecs::size(); i += Vecs::size()) {
            if (strides_in[0] != 0) {
              Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t));
              auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
              if (requires_neg) {
                data_vec0 = data_vec0.neg();
                data_vec1 = data_vec1.neg();
              }
              data_vec0.store(data[0] + i * sizeof(dest_t));
              data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t));
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t));
              dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t));
            }
          }
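          // Tail: fewer than Vecs::size() elements remain, so use the
          // count-limited loadu/store overloads to avoid touching memory past
          // the end of the row.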
          if (i < size0) {
            if (strides_in[0] != 0) {
              Vecs data_vec = Vecs::loadu(data[1] + i * sizeof(scalar_t), size0 - i);
              auto [data_vec0, data_vec1] = convert_to_float<scalar_t>(data_vec);
              if (requires_neg) {
                data_vec0 = data_vec0.neg();
                data_vec1 = data_vec1.neg();
              }
              data_vec0.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size()) ? Vecd::size() : (size0 - i));
              data_vec1.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size()) ? (size0 - i - Vecd::size()) : 0);
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t), ((size0 - i) > Vecd::size()) ? Vecd::size() : (size0 - i));
              dst_s.store(data[0] + (i + Vecd::size()) * sizeof(dest_t), ((size0 - i) > Vecd::size()) ? (size0 - i - Vecd::size()) : 0);
            }
          }
          data[0] += outer_strides[0];
          data[1] += outer_strides[1];
        }

      };

      parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) {
        at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end});
      });
    });
  } else if (reduced_output(iter.dtype(1), iter.dtype(0))) {
    AT_DISPATCH_REDUCED_FLOATING_TYPES(iter.dtype(0), "copy_kernel", [&]() {
      using dest_t = scalar_t;
      using source_t = float;
      using Vecd = Vectorized<dest_t>;
      using Vecs = Vectorized<source_t>;
      c10::SmallBuffer<char*, 2> ptrs(2);
      dest_t* output_data = iter.tensor_base(0).data_ptr<dest_t>();
      source_t* input_data = const_cast<source_t*>(iter.tensor_base(1).const_data_ptr<source_t>());

      ptrs[0] = reinterpret_cast<char*>(output_data);
      ptrs[1] = reinterpret_cast<char*>(input_data);

      int64_t grain_size = at::internal::GRAIN_SIZE;

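      // Same tiling as the widening branch above, but narrowing: two float
      // vectors are packed into one BFloat16/Half vector per store via
      // convert_from_float, so the main loop advances by 2 * Vecs::size().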
      auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
        std::array<char*, 2> data;
        std::copy_n(base, 2, data.data());
        const int64_t *outer_strides = &strides[2];

        for (const auto it C10_UNUSED : c10::irange(size1)) {
          Vecd dst_s;
          if (strides_in[0] == 0) {
            dst_s = Vecd(dest_t(*((source_t*)data[1])));
            if (requires_neg) {
              dst_s = dst_s.neg();
            }
          }
          int64_t i = 0;
          for (; i <= size0 - 2 * Vecs::size(); i += 2 * Vecs::size()) {
            if (strides_in[0] != 0) {
              Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(source_t));
              Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(source_t));
              auto data_vec = convert_from_float<dest_t>(data_vec0, data_vec1);
              if (requires_neg) {
                data_vec = data_vec.neg();
              }
              data_vec.store(data[0] + i * sizeof(dest_t));
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t));
            }

          }
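          // Tail: handle the remaining elements with count-limited
          // loads/stores.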
          if (i < size0) {
            if (strides_in[0] != 0) {
              Vecs data_vec0 = Vecs::loadu(data[1] + i * sizeof(source_t), ((size0 - i) > Vecs::size()) ? Vecs::size() : (size0 - i));
              Vecs data_vec1 = Vecs::loadu(data[1] + (i + Vecs::size()) * sizeof(source_t), ((size0 - i) > Vecs::size()) ? (size0 - i - Vecs::size()) : 0);
              auto data_vec = convert_from_float<dest_t>(data_vec0, data_vec1);
              if (requires_neg) {
                data_vec = data_vec.neg();
              }
              data_vec.store(data[0] + i * sizeof(dest_t), size0 - i);
            } else {
              dst_s.store(data[0] + i * sizeof(dest_t), size0 - i);
            }
          }
          data[0] += outer_strides[0];
          data[1] += outer_strides[1];
        }

      };
      parallel_for(0, iter.numel(), grain_size, [&] (int64_t begin, int64_t end) {
        at::internal::serial_for_each(shape, strides, ptrs.data(), 2, loop, {begin, end});
      });
    });

  }
}

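// Dispatch helpers for the generic copy path below. Non-mobile builds use
// AT_DISPATCH_V2 so the kernels also cover the float8 and barebones unsigned
// types; C10_MOBILE builds fall back to the smaller classic dispatch macros.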
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
        AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
            kComplexHalf, kHalf, kBool, \
            kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
            kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
        AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
            kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
            kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
            ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, \
            TYPE, NAME, __VA_ARGS__)
#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \
            kBool, kHalf, kBFloat16, \
            TYPE, NAME, __VA_ARGS__)
#endif

void direct_copy_kernel(TensorIteratorBase &iter) {
  // TODO: we don't actually need separate instantiations per dtype;
  // we only need a separate instantiation per dtype size. This would
  // probably save us a little bit of code size here
  // TODO: not sure if optimizer is able to compile two levels of
  // conditionals into a single jump table. We should have a
  // single jump table here; might be worth just writing out the
  // dispatch statement by hand instead of using AT_DISPATCH
  ScalarType dtype = iter.dtype(0);
  if (isQIntType(dtype)) {
    AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return a; },
          [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
    });
  } else if (dtype == ScalarType::ComplexHalf) {
    cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
  } else if (isBitsType(dtype)) {
    AT_DISPATCH_BIT_TYPES(dtype, "copy_kernel", [&] {
      cpu_kernel(
          iter,
          [=](scalar_t a) -> scalar_t { return a; });
    });
  } else {
    _AT_DISPATCH_ALL_TYPES_NO_CF(dtype, "copy_kernel", [&] {
      cpu_kernel_vec(
          iter,
          [=](scalar_t a) -> scalar_t { return a; },
          [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a; });
    });
  }
}

static void neg_conj_kernel(TensorIteratorBase &iter) {
  // fused a = b.neg().conj_physical()
  AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cpu", [&] {
    cpu_kernel_vec(
        iter,
        [=](scalar_t a) -> scalar_t { return -conj_impl(a); },
        [=](Vectorized<scalar_t> a) -> Vectorized<scalar_t> { return a.neg().conj(); });
  });
}

static void copy_same_dtype(TensorIteratorBase &iter, bool requires_conj, bool requires_neg) {
  if (requires_neg) {
    // This case should never actually happen since currently there's no way to get a complex tensor
    // with negative bit.
    if (requires_conj) {
      neg_conj_kernel(iter);
    } else {
      neg_kernel(iter);
    }
  } else {
    if (requires_conj) {
      conj_kernel(iter);
    } else {
      direct_copy_kernel(iter);
    }
  }
}

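// CPU kernel registered on copy_stub (the stub behind Tensor::copy_).
// Same-dtype copies go through copy_same_dtype; float <-> BFloat16/Half
// copies take the vectorized fast path above; every other dtype pair falls
// back to a nested dispatch that converts via at::vec::convert (contiguous
// inner dimension) or a scalar cpu_kernel. Any outstanding conj/neg is then
// applied in a second, in-place pass.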
void copy_kernel(TensorIterator& iter, bool /*non_blocking*/) {
  ScalarType dtype = iter.dtype(0);
  const bool requires_conj = (
      isComplexType(dtype) && (iter.tensor_base(0).is_conj() != iter.tensor_base(1).is_conj()));
  const bool requires_neg = (iter.tensor_base(0).is_neg() != iter.tensor_base(1).is_neg());

  if (dtype == iter.dtype(1)) {
    copy_same_dtype(iter, requires_conj, requires_neg);
  } else if (reduced_float_type_copy(requires_conj, iter)) {
    reduced_float_copy_kernel(iter, requires_neg);
  } else {
    _AT_DISPATCH_ALL_TYPES(dtype, "copy_", [&] {
      using dest_t = scalar_t;
      _AT_DISPATCH_ALL_TYPES(iter.dtype(1), "copy_", [&] {
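        // If the inner dimension is contiguous for both operands, convert the
        // whole row at once with at::vec::convert; otherwise fall back to a
        // scalar elementwise conversion.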
        if (iter.has_contiguous_first_dim()) {
          TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
          TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

          iter.for_each([](char **data, const int64_t *strides, int64_t size) {
            auto src = reinterpret_cast<const scalar_t*>(data[1]);
            auto dst = reinterpret_cast<dest_t*>(data[0]);
            at::vec::convert(src, dst, size);
          });
        } else {
          cpu_kernel(iter, [](scalar_t x) -> dest_t {
            return c10::convert<dest_t>(x);
          });
        }
      });
    });

    if (requires_conj || requires_neg) {
      // This inplace "copy" will perform any missing neg or conj operations
      auto self = iter.tensor_base(0);
      auto iter = TensorIterator::unary_op(self, self);
      copy_same_dtype(iter, requires_conj, requires_neg);
    }
  }
}

} // namespace CPU_CAPABILITY

REGISTER_DISPATCH(copy_stub, &copy_kernel);

} // namespace at::native