// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Copy.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Copy.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/ExpandUtils.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/quantized/Copy.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/TensorShape.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_copy_from.h>
#include <ATen/ops/_propagate_xla_data.h>
#include <ATen/ops/_propagate_xla_data_native.h>
#include <ATen/ops/copy.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/_foreach_copy.h>
#include <ATen/ops/_foreach_copy_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#include <ATen/ops/expand_copy.h>
#endif

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmConvert.h>
#endif

namespace {

using namespace at;

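// Gate for the blocked-transpose fast path below: the destination must be
// contiguous and the source a non-empty 2-D column-major (i.e. transposed)
// matrix of the same dtype and shape (and not a bits type), with matching
// conj/neg bits, and large enough (at least 60 * 60 elements) for the
// blocking to pay off.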
bool copy_transpose_valid(const Tensor& self, const Tensor& src) {
  const int MIN_SZ = 60 * 60;
  return self.is_contiguous() && src.numel() != 0 && src.dim() == 2 &&
      src.stride(0) == 1 && src.stride(1) == src.size(0) &&
      self.scalar_type() == src.scalar_type() &&
      !isBitsType(self.scalar_type()) &&
      self.sizes().equals(src.sizes()) &&
      self.is_neg() == src.is_neg() &&
      self.is_conj() == src.is_conj() &&
      self.numel() >= MIN_SZ;
}

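// Dispatch macro for the same-dtype transpose copy. On non-mobile builds it
// covers all standard and complex types plus ComplexHalf, Half, Bool,
// BFloat16, the float8 variants, and the barebones unsigned types; on mobile
// it falls back to the smaller AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 set.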
#if !defined(C10_MOBILE)
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...)                              \
        AT_DISPATCH_V2(                             \
            TYPE, NAME, AT_WRAP(__VA_ARGS__), kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2,            \
            kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...)     \
        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(    \
            kComplexHalf, kHalf, kBool, kBFloat16, \
            TYPE, NAME, __VA_ARGS__)
#endif

// Special-case copy where the destination tensor is contiguous and src is a
// transposed matrix. This can be generalized to most copies, but it's trickier.
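// The strategy: walk src in BLOCK_SZ x BLOCK_SZ tiles, copy each tile's
// columns into a small contiguous staging buffer, transpose the buffer in
// place, then memcpy the buffer's rows into the contiguous destination.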
void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t BLOCK_SZ;
  if (self.scalar_type() == kByte) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 120;
  } else {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
    BLOCK_SZ = 60;
  }
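  // Staging buffer for one tile. The larger block for 1-byte elements
  // presumably keeps the buffer's byte footprint cache-friendly; the exact
  // constants are heuristic.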
  Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());

  // The code below is implemented with the assumption that sizes are equal
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));

  _AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
    const scalar_t* sp = src.const_data_ptr<scalar_t>();
    scalar_t* rp = self.data_ptr<scalar_t>();
    scalar_t* bp = buf.data_ptr<scalar_t>();

    int64_t NR = src.size(0);
    int64_t NC = src.size(1);
    for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
      for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
        const scalar_t* spo = sp + R + C * NR;
        scalar_t* rpo = rp + C + R * NC;

        int nr = std::min(NR - R, BLOCK_SZ);
        int nc = std::min(NC - C, BLOCK_SZ);

        // 1. copy columns from src to buf
        for (const auto c : c10::irange(nc)) {
          memcpy(bp + c * BLOCK_SZ, spo + c * NR, nr * sizeof(scalar_t));
        }

        // 2. transpose buf in place
        int rc_max = std::max(nr, nc);
        int rc_min = std::min(nr, nc);
        for (const auto r : c10::irange(rc_max)) {
          int end = std::min(r, rc_min);
          for (const auto c : c10::irange(end)) {
            scalar_t tmp = bp[r + BLOCK_SZ * c];
            bp[r + BLOCK_SZ * c] = bp[r * BLOCK_SZ + c];
            bp[r * BLOCK_SZ + c] = tmp;
          }
        }

        // 3. copy rows from buf to dst
        for (const auto r : c10::irange(nr)) {
          memcpy(rpo + r * NC, bp + r * BLOCK_SZ, nc * sizeof(scalar_t));
        }
      }
    }
  });
}

// Devices directly supported by this copy implementation. Other device types
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
  DeviceType device_type = device.type();
  return device_type == kCPU || device_type == kCUDA || device_type == kHIP ||
      device_type == kVulkan || device_type == kMetal || device_type == kMPS ||
      device_type == kXPU;
}

} // namespace

namespace at::native {

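// Shared implementation behind Tensor::copy_. Ordered roughly from most to
// least specialized: the FBGEMM fp16<->fp32 fast path, self-copy and meta
// early outs, re-dispatch to _copy_from for unsupported devices, quantized /
// Vulkan / Metal handling, the same-data early exit, the blocked transpose
// fast path, MPS, and finally the generic TensorIterator-based copy_stub.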
static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) {
  // TODO: this should be handled during dispatch, but that's missing...
  TORCH_CHECK(self.defined(), "self is undefined");
  TORCH_CHECK(src.defined(), "src is undefined");

  // FBGEMM kernel support exists only for the following case:
  // 1. Memory format for both the source and destination tensors is contiguous.
  // 2. Device for both the source and destination tensors is CPU.
  // 3. dtype conversion is FP32->FP16 or FP16->FP32.
  // This checks that self.sizes() == src.sizes() because this code path doesn't
  // support broadcasting. It also guards against out-of-bounds memory access
  // when copying; see fbgemm::Float16ToFloat_ref.
  // https://github.com/pytorch/pytorch/issues/88543
  #ifdef USE_FBGEMM
    if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) ||
         (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) &&
        (self.device().is_cpu() && src.device().is_cpu()) &&
        ((self.is_contiguous() && src.is_contiguous()) ||
         (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) &&
        (self.sizes() == src.sizes())) {
      if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
        auto* output_ptr =
            reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
        if (self.numel() < at::internal::GRAIN_SIZE) {
          fbgemm::FloatToFloat16_simd(src.const_data_ptr<float>(), output_ptr, self.numel());
        } else {
          at::parallel_for(
              0,
              self.numel(),
              at::internal::GRAIN_SIZE,
              [&](int64_t begin, int64_t end) {
                fbgemm::FloatToFloat16_simd(
                    src.const_data_ptr<float>() + begin,
                    output_ptr + begin,
                    end - begin);
              });
        }
      } else {
        auto in_data = reinterpret_cast<const fbgemm::float16*>(
            src.const_data_ptr<at::Half>());
        auto* output_ptr = self.data_ptr<float>();
        if (self.numel() < at::internal::GRAIN_SIZE) {
          fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
        } else {
          at::parallel_for(
              0,
              self.numel(),
              at::internal::GRAIN_SIZE,
              [&](int64_t begin, int64_t end) {
                fbgemm::Float16ToFloat_simd(
                    in_data + begin, output_ptr + begin, end - begin);
              });
        }
      }
      return self;
    }
  #endif

  if (self.is_same(src)) {
    return self;
  }

  // Copies into meta self are OK and just ignored (similar to inplace)
  if (self.is_meta()) {
    auto shape = infer_size_symdimvector(self.sym_sizes(), src.sym_sizes());
    TORCH_CHECK(
        self.sym_sizes().equals(shape),
        "output with shape ",
        self.sym_sizes(),
        " doesn't match the broadcast shape ",
        shape);
    return self;
  }

  if (src.is_meta()) {
    TORCH_CHECK_NOT_IMPLEMENTED(false, "Cannot copy out of meta tensor; no data!")
  }

  // Re-dispatch copies when the device of either src or self is not implemented
  // here (e.g. XLA). _copy_from has a proper device dispatch setup.
  // This includes:
  //   cpu_tensor.copy_(xla_tensor) => xla_tensor._copy_from(cpu_tensor)
  //   xla_tensor.copy_(cpu_tensor) => cpu_tensor._copy_from(xla_tensor)
  // Both of the _copy_from calls above will be dispatched to XLA's _copy_from kernels.

  if (!is_supported_device(src.device()) || !is_supported_device(self.device())) {
    at::_copy_from(src, self, non_blocking);
    return self;
  }

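  // Quantized copy rules: float -> quantized goes through
  // quantized_copy_from_float_; quantized -> quantized requires matching
  // qscheme and dtype (and adopts src's quantizer); quantized -> float is
  // rejected, callers should dequantize() first.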
  if (self.is_quantized() && !src.is_quantized()) {
    return quantized_copy_from_float_(self, src);
  }

  if (self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(self.qscheme() == src.qscheme(),
                "Quantized Copy only works with same qscheme");
    TORCH_CHECK(self.scalar_type() == src.scalar_type());
    set_quantizer_(self, src.quantizer());
  }

  if (!self.is_quantized() && src.is_quantized()) {
    TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
  }

  if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
  #ifdef USE_VULKAN_API
    return vulkan::ops::copy_(self, src);
  #else
    return at::vulkan::vulkan_copy_(self, src);
  #endif
  }

  if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
    return at::metal::metal_copy_(self, src);
  }

  // Exit early if self and src are views of the same data
  const bool is_same_data = (
      self.is_alias_of(src) &&
      self.storage_offset() == src.storage_offset() &&
      self.strides().equals(src.strides()) &&
      self.sizes().equals(src.sizes()) &&
      self.scalar_type() == src.scalar_type() &&
      self.is_conj() == src.is_conj() &&
      self.is_neg() == src.is_neg()
    );
  if (is_same_data) {
    return self;
  }

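  // Everything else funnels through a TensorIterator: resize_outputs(false)
  // because self already has its final shape, and the same-dtype/same-device
  // checks are relaxed so the copy kernels can convert dtypes and move data
  // across devices.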
  auto iter = TensorIteratorConfig()
    .add_output(self)
    .add_const_input(src)
    .resize_outputs(false)
    .check_all_same_dtype(false)
    .check_all_same_device(false)
    .build();

  if (iter.numel() == 0) {
    return self;
  }

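  // Pick the backend that runs the copy kernel: default to the output's
  // device, but let an accelerator (CUDA/HIP/MPS/XPU) input take over so the
  // cross-device copy is driven by the accelerator backend.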
  DeviceType device_type = iter.device_type(0);
  if (iter.device_type(1) == kCUDA) {
    device_type = kCUDA;
  } else if (iter.device_type(1) == kHIP) {
    device_type = kHIP;
  } else if (iter.device_type(1) == kMPS) {
    device_type = kMPS;
  } else if (iter.device_type(1) == kXPU) {
    device_type = kXPU;
  }

  // TODO: if we need to, we can also enable this path for quantized tensor
  if (device_type == kCPU && copy_transpose_valid(self, src) && !self.is_quantized()) {
    copy_same_type_transpose_(self, src);
    return self;
  }

#ifdef USE_MPS
  if (self.device().type() == at::kMPS || src.device().type() == at::kMPS) {
    return at::native::mps::mps_copy_(self, src, non_blocking);
  }
#endif

  if (!(self.is_complex() || self.dtype() == at::kBool) && src.is_complex()) {
    TORCH_WARN_ONCE("Casting complex values to real discards the imaginary part");
  }
  copy_stub(device_type, iter, non_blocking);
  return self;
}

Tensor copy_meta(const Tensor& self, const Tensor& src, bool non_blocking) {
  // Must directly use self(), so we can dispatch properly if self is a subclass
  auto r = clone_preserve_strides(self);
  r.copy_(src, non_blocking);
  return r;
}

Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
  at::Tensor r;
  // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but:
  // (1) It isn't exposed to the frontend (no python bindings)
  // (2) It isn't exposed to the backend (it's a composite that decomposes into to() and expand_as() calls)
  auto self_storage = self.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl();
  // If self has no real storage, we can't actually clone it.
  // Instead, generate an empty tensor with the right sizes/strides, since we should be able to assume
  // that copy_() will fully overwrite all data with that of src
  if (self_storage->nbytes() == 0) {
    r = at::empty_strided(self.sizes(), self.strides(), self.options());
  } else {
    r = clone_preserve_strides(self);
  }
  r.copy_(src, non_blocking);
  return r;
}

::std::vector<at::Tensor> _foreach_copy(at::TensorList self, at::TensorList src, bool non_blocking) {
  std::vector<at::Tensor> outs;
  outs.reserve(self.size());
  // This is a very slow implementation, but it needs to directly call the copy() kernel
  // above to handle the case where self has zero storage.
  // This kernel should never really be run, except when debugging with compile(backend="aot_eager").
  for (const auto i : c10::irange(src.size())) {
    const auto& curr_src = src[i];
    const auto& curr_self = self[i];
    outs.push_back(at::copy(curr_self, curr_src, non_blocking));
  }
  return outs;
}

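// Example usage from the Python frontend (a sketch of the expected behavior,
// not exercised by this file):
//   dst = torch.empty(2, 3)
//   dst.copy_(torch.ones(3))   # src is broadcast to dst's shape and cast to dst's dtype
//   dst.copy_(cuda_src, non_blocking=True)  # cuda_src is a hypothetical CUDA tensor;
//                                           # the call may return before the copy finishes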
Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) {
  auto maybe_outnames = namedinference::compute_broadcast_outnames(self, src);
  {
    NoNamesGuard guard;
    if (self._is_zerotensor()) {
      TORCH_CHECK(false, "ZeroTensors are immutable. Please materialize the tensor using `.clone()`, if you want a mutable zero tensor.");
    }
    if (src._is_zerotensor()) {
      return self.zero_();
    }
    copy_impl(self, src, non_blocking);
  }
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}

void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
  // Called when we are copying into an overlapping index `dst`, but we don't
  // care which writer wins. Hacky but it works. This is only used by
  // CUDA_tensor_apply2 when there are write overlaps.
  // FIXME: really, overlapping writes should be illegal/an error in Torch
  auto iter = TensorIteratorConfig()
      .add_output(dst)
      .add_const_input(src)
      .resize_outputs(false)
      .set_check_mem_overlap(false)
      .check_all_same_dtype(true)
      .check_all_same_device(true)
      .build();
  copy_stub(iter.device_type(), iter, /*non_blocking=*/false);
}

void _propagate_xla_data(const Tensor& input, const Tensor& output) {
  TORCH_INTERNAL_ASSERT(input.device().type() == kXLA, "This op should only be called by XLA")
}

DEFINE_DISPATCH(copy_stub);

} // namespace at::native