#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Copy.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandUtils.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/TensorShape.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_propagate_xla_data_native.h>
#include <ATen/ops/copy.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/_foreach_copy.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/expand_copy.h>
#endif

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmConvert.h>
#endif

namespace {

using namespace at;

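// True when self is contiguous and src is a transposed (column-major) 2-D
// matrix of the same size, dtype, and conj/neg bits, and the copy is large
// enough (>= 60*60 elements) for the blocked transpose path below to pay off.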
bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
  const int MIN_SZ = 60 * 60;
  return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
      src.stride(0) == 1 && src.stride(1) == src.size(0) &&
      self.scalar_type() == src.scalar_type() &&
      !isBitsType(self.scalar_type()) &&
      self.sizes().equals(src.sizes()) &&
      self.is_neg() == src.is_neg() &&
      self.is_conj() == src.is_conj() &&
      self.numel() >= MIN_SZ;
}

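// Dispatch over every dtype copy_ can handle. Non-mobile builds go through
// AT_DISPATCH_V2 so the float8 and barebones unsigned types are covered;
// mobile builds fall back to the classic macro with a smaller type set.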
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_V2( \
      TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, \
      kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
      kComplexHalf, kHalf, kBool, kBFloat16, \
      TYPE, NAME, __VA_ARGS__)
#endif

// special case copy where tensor is contiguous and src is a transposed matrix
// This can be generalized to most copies, but it's trickier
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t BLOCK_SZ;
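  // Larger tiles for 1-byte element types, presumably so each tile spans a
  // comparable number of bytes to the 60x60 tiles used for wider types.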
  if (self.scalar_type() == kByte) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 120;
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 60;
  }
  Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());

  // The code below is implemented with the assumption that sizes are equal
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));

  _AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
    const scalar_t* sp = src.const_data_ptr<scalar_t>();
    scalar_t* rp = self.data_ptr<scalar_t>();
    scalar_t* bp = buf.data_ptr<scalar_t>();

    int64_t NR = src.size(0);
    int64_t NC = src.size(1);
    for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
      for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
        const scalar_t* spo = sp + R + C * NR;
        scalar_t* rpo = rp + C + R * NC;

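        // nr x nc is the extent of the current tile, clamped at the matrix edges.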
        int nr = std::min(NR - R, BLOCK_SZ);
        int nc = std::min(NC - C, BLOCK_SZ);

        // 1. copy columns from src to buf
        for (const auto c : c10::irange(nc)) {
          memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
        }

        // 2. transpose buf in place
        int rc_max = std::max(nr, nc);
        int rc_min = std::min(nr, nc);
        for (const auto r : c10::irange(rc_max)) {
          int end = std::min(r, rc_min);
          for (const auto c : c10::irange(end)) {
            scalar_t tmp = bp[r + BLOCK_SZ * c];
            bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
            bp[r * BLOCK_SZ + c] = tmp;
          }
        }

        // 3. copy rows from buf to dst
        for (const auto r : c10::irange(nr)) {
          memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
        }
      }
    }
  });
}

// Devices directly supported by this copy implementation. Other device types
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
  DeviceType device_type = device.type();
  return device_type == kCPU || device_type == kCUDA || device_type == kHIP ||
      device_type == kVulkan || device_type == kMetal || device_type == kMPS ||
      device_type == kXPU;
}

} // namespace

namespace at::native {

static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) {
  // TODO: this should be handled during dispatch, but that's missing...
  TORCH_CHECK(self.defined(), "self is undefined");
  TORCH_CHECK(src.defined(), "src is undefined");

  // FBGEMM kernel support exists only for the following cases:
  // 1. The memory format of both the source and destination tensors is contiguous.
  // 2. Both the source and destination tensors reside on the CPU.
  // 3. The dtype conversion is FP32->FP16 or FP16->FP32.
  // This checks that self.sizes() == src.sizes() because this code path doesn't
  // support broadcasting. It also guards against out-of-bounds memory accesses
  // when copying; see fbgemm::Float16ToFloat_ref.
  // https://github.com/pytorch/pytorch/issues/88543
#ifdef USE_FBGEMM
  if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) ||
       (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) &&
      (self.device().is_cpu() && src.device().is_cpu()) &&
      ((self.is_contiguous() && src.is_contiguous()) ||
       (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) &&
      (self.sizes() == src.sizes())) {
    if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
      auto* output_ptr =
          reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
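      // Convert serially for small tensors; otherwise parallelize the
      // conversion over GRAIN_SIZE-sized chunks.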
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::FloatToFloat16_simd(src.const_data_ptr<float>(), output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::FloatToFloat16_simd(
                  src.const_data_ptr<float>() + begin,
                  output_ptr + begin,
                  end - begin);
            });
      }
    } else {
      auto in_data = reinterpret_cast<const fbgemm::float16*>(
          src.const_data_ptr<at::Half>());
      auto* output_ptr = self.data_ptr<float>();
      if (self.numel() < at::internal::GRAIN_SIZE) {
        fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
      } else {
        at::parallel_for(
            0,
            self.numel(),
            at::internal::GRAIN_SIZE,
            [&](int64_t begin, int64_t end) {
              fbgemm::Float16ToFloat_simd(
                  in_data + begin, output_ptr + begin, end - begin);
            });
      }
    }
    return self;
  }
#endif

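  // Copying a tensor onto itself is a no-op.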
  if (self.is_same(src)) {
    return self;
  }

  // Copies into meta self are OK and just ignored (similar to inplace)
  if (self.is_meta()) {
    auto shape = infer_size_symdimvector(self.sym_sizes(), src.sym_sizes());
    TORCH_CHECK(
        self.sym_sizes().equals(shape),
        "output with shape ",
        self.sym_sizes(),
        " doesn't match the broadcast shape ",
        shape);
    return self;
  }

  if (src.is_meta()) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!")
  }

  // Re-dispatch copies when either the src or self device is not implemented here (e.g. XLA).
  // _copy_from has a proper device dispatch setup.
  // This includes:
  //   cpu_tensor.copy_(xla_tensor) => xla_tensor._copy_from(cpu_tensor)
  //   xla_tensor.copy_(cpu_tensor) => cpu_tensor._copy_from(xla_tensor)
  // Both the _copy_from calls above will be dispatched to XLA's _copy_from kernels.
  if (!is_supported_device(src.device()) || !is_supported_device(self.device())) {
    at::_copy_from(src, self, non_blocking);
    return self;
  }

  if (self.is_quantized() && !src.is_quantized()) {
    return quantized_copy_from_float_(self, src);
  }

  if (self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(self.qscheme() == src.qscheme(),
                "Quantized Copy only works with same qscheme");
    TORCH_CHECK(self.scalar_type() == src.scalar_type());
    set_quantizer_(self, src.quantizer());
  }

  if (!self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
  }

  if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
#ifdef USE_VULKAN_API
    return vulkan::ops::copy_(self, src);
#else
    return at::vulkan::vulkan_copy_(self, src);
#endif
  }

  if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
    return at::metal::metal_copy_(self, src);
  }

  // Exit early if self and src are views of the same data
  const bool is_same_data = (
      self.is_alias_of(src) &&
      self.storage_offset() == src.storage_offset() &&
      self.strides().equals(src.strides()) &&
      self.sizes().equals(src.sizes()) &&
      self.scalar_type() == src.scalar_type() &&
      self.is_conj() == src.is_conj() &&
      self.is_neg() == src.is_neg()
    );
  if (is_same_data) {
    return self;
  }

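  // Set up a TensorIterator that permits mismatched dtypes and devices;
  // self's shape is kept fixed (resize_outputs(false)), so src is broadcast
  // to it as needed.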
  auto iter = TensorIteratorConfig()
      .add_output(self)
      .add_const_input(src)
      .resize_outputs(false)
      .check_all_same_dtype(false)
      .check_all_same_device(false)
      .build();

  if (iter.numel() == 0) {
    return self;
  }

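  // If either operand lives on an accelerator (CUDA/HIP/MPS/XPU), dispatch to
  // that backend's copy kernel so it can drive the cross-device transfer.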
  DeviceType device_type = iter.device_type(0);
  if (iter.device_type(1) == kCUDA) {
    device_type = kCUDA;
  } else if (iter.device_type(1) == kHIP) {
    device_type = kHIP;
  } else if (iter.device_type(1) == kMPS) {
    device_type = kMPS;
  } else if (iter.device_type(1) == kXPU) {
    device_type = kXPU;
  }

  // TODO: if we need to, we can also enable this path for quantized tensor
  if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) {
    copy_same_type_transpose_(self, src);
    return self;
  }

#ifdef USE_MPS
  if (self.device().type() == at::kMPS || src.device().type() == at::kMPS) {
    return at::native::mps::mps_copy_(self, src, non_blocking);
  }
#endif

  if (!(self.is_complex() || self.dtype() == at::kBool) && src.is_complex()) {
    TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
  }
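  // Everything else funnels into the per-device copy kernel registered with copy_stub.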
  copy_stub(device_type, iter, non_blocking);
  return self;
}

Tensor copy_meta(const Tensor& self, const Tensor& src, bool non_blocking) {
  // Must directly use self(), so we can dispatch properly if self is a subclass
  auto r = clone_preserve_strides(self);
  r.copy_(src, non_blocking);
  return r;
}

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
  at::Tensor r;
  // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
  // (1) It isn't exposed to the frontend (no python bindings)
  // (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls)
  auto self_storage = self.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();
  // If self has no real storage, we can't actually clone it.
  // Instead, generate an empty tensor with the right sizes/strides, since we should be able to assume
  // that copy_() will fully overwrite all data with that of src.
  if (self_storage->nbytes() == 0) {
    r = at::empty_strided(self.sizes(), self.strides(), self.options());
  } else {
    r = clone_preserve_strides(self);
  }
  r.copy_(src, non_blocking);
  return r;
}

::std::vector<at::Tensor> _foreach_copy(at::TensorList self, at::TensorList src, bool non_blocking) {
  std::vector<at::Tensor> outs;
  outs.reserve(self.size());
  // This is a very slow implementation, but it needs to call the copy() kernel above directly to handle
  // the case where self has zero storage.
  // This kernel should never really be run, except when debugging with compile(backend="aot_eager")
  for (const auto i : c10::irange(src.size())) {
    const auto& curr_src = src[i];
    const auto& curr_self = self[i];
    outs.push_back(at::copy(curr_self, curr_src, non_blocking));
  }
  return outs;
}

Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
  auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
  {
    NoNamesGuard guard;
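    // Writing into a ZeroTensor is disallowed; copying *from* a ZeroTensor is
    // equivalent to just zeroing self.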
    if (self._is_zerotensor()) {
      TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()`, if you want a mutable zero tensor.");
    }
    if (src._is_zerotensor()) {
      return self.zero_();
    }
    copy_impl(self, src, non_blocking);
  }
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
  // Called when we are copying into an overlapping index `dst`, but we don't
  // care which writer wins. Hacky but it works. This is only used by
  // CUDA_tensor_apply2 in case there are write overlaps.
  // FIXME: really, overlapping writes should be illegal/an error in Torch
  auto iter = TensorIteratorConfig()
      .add_output(dst)
      .add_const_input(src)
      .resize_outputs(false)
      .set_check_mem_overlap(false)
      .check_all_same_dtype(true)
      .check_all_same_device(true)
      .build();
  copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
}

void _propagate_xla_data(const Tensor& input, const Tensor& output) {
  TORCH_INTERNAL_ASSERT(input.device().type() == kXLA, "This op should only be called by XLA")
}

DEFINE_DISPATCH(copy_stub);

} // namespace at::native