#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

// TODO(NS): Investigate why FP8 conversion intrinsics end up being slower
#ifdef AT_USE_NV_CVT_INTRINSICS
#include <cuda_fp8.h>
#endif

namespace at::native {

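// Forward declarations; these kernels are defined in other ATen CUDA
// translation units.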
void neg_kernel_cuda(TensorIteratorBase &iter);
void conj_kernel_cuda(TensorIteratorBase &iter);

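// Specialized copy path for the Float8 dtypes. For Float/Half/BFloat16
// sources the conversion is done explicitly inside the lambda (via
// gpu_kernel_nocast); any other source dtype falls back to a plain identity
// kernel that relies on TensorIterator's dtype casting.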
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
  ScalarType dtype = iter.dtype(0);
  ScalarType other_dtype = iter.dtype(1);
  if (dtype == kFloat8_e4m3fn) {
    switch (other_dtype) {
      case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
             return Float8_e4m3fn(value);
         });
         break;
      case kHalf:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
             return Float8_e4m3fn(value);
         });
         break;
      case kBFloat16:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
             return Float8_e4m3fn(value);
         });
         break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; });
        break;
    }
  } else if (dtype == kFloat8_e5m2) {
    switch (other_dtype) {
      case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
             const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2);
             return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
             return Float8_e5m2(value);
#endif
         });
         break;
      case kHalf:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
             const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2);
             return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
             return Float8_e5m2(value);
#endif
         });
         break;
      case kBFloat16:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
             const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2);
             return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
             return Float8_e5m2(value);
#endif
         });
         break;
      default:
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; });
         break;
    }
  } else if (dtype == kFloat8_e4m3fnuz) {
    switch (other_dtype) {
      case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
             return Float8_e4m3fnuz(value);
         });
         break;
      case kHalf:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
             return Float8_e4m3fnuz(value);
         });
         break;
      case kBFloat16:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
             return Float8_e4m3fnuz(value);
         });
         break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; });
        break;
    }
  } else if (dtype == kFloat8_e5m2fnuz) {
    switch (other_dtype) {
      case kFloat:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
             return Float8_e5m2fnuz(value);
         });
         break;
      case kHalf:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
             return Float8_e5m2fnuz(value);
         });
         break;
      case kBFloat16:
         gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
             return Float8_e5m2fnuz(value);
         });
         break;
      default:
         gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
         break;
    }
  } else {
    TORCH_CHECK(false, "This is supposed to be called only for Float8 types");
  }
}

// TODO: We probably can use the opaque type trick to avoid creating duplicate
// kernels for equivalent bit lengths
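// direct_copy_kernel_cuda performs a same-value copy dispatched on the
// destination dtype: quantized ints, the Float8 dtypes, and bits types take
// dedicated paths; everything else goes through a generic identity kernel
// under AT_DISPATCH_V2.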
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
  ScalarType dtype = iter.dtype(0);
  if (isQIntType(dtype)) {
    AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
     float8_copy_kernel_cuda(iter);
  } else if (isBitsType(dtype)) {
    TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
      "bits types to different bits types. Source dtype is ", iter.dtype(1), ", target dtype is ", dtype);
    AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
      gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
  } else {
    AT_DISPATCH_V2(
        dtype, "copy_", AT_WRAP([&] {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
  }
}

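// Fused negate-and-conjugate kernel, used by copy_device_to_device when the
// source and destination disagree on both the neg and the conj bit.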
void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
  AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cuda", [&] {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); });
  });
}

using namespace at::cuda;

// device-to-device copy, does type conversion
void copy_device_to_device(TensorIterator& iter,
                           bool non_blocking,
                           bool p2p_enabled) {
  int64_t numel = iter.numel();

  // We can memcpy the memory if both tensors have the same type AND both
  // tensors are contiguous after dimension coalescing and reordering.
  bool same_type = iter.dtype(0) == iter.dtype(1);
  bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
  bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
  bool memcpy_eligible = same_type && same_conj && same_neg && iter.is_contiguous();

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  CUDAGuard device_guard(src_device);

  // We always perform the copy on the source device, using the current stream
  // on the source device, and we fully synchronize on both src and dst's
  // current streams for completion of the copy. We have to explicitly do this
  // for non-contig copies. This mimics the behavior of cross-device
  // cudaMemcpyAsync on the default stream.
  CUDAStream copy_stream = getCurrentCUDAStream(src_device.index());
  if (src_device != dst_device) {
    // This is a cross-device copy on the src current stream and dst current
    // stream. We perform a two-way barrier between both devices' streams
    // before the copy. This ensures that any write-after-write and
    // write-after-read dependencies on the destination side are handled, so
    // that no one is operating on the dst memory when we perform the copy.
    // src waits on dst barrier (src already waits on src)
    CUDAEvent dst_ready;
    device_guard.set_device(dst_device);
    dst_ready.record(getCurrentCUDAStream(dst_device.index()));

    device_guard.set_device(src_device);
    dst_ready.block(copy_stream);
  }

  if (memcpy_eligible) {
    void *dst = iter.data_ptr(0);
    void *src = iter.data_ptr(1);
    size_t size = numel * iter.element_size(0);
    if (src != dst || src_device != dst_device) {
      // Due to bizarre cuda driver intricacies, copies of
      // cudaMallocAsynced memory between devices that aren't
      // peer-to-peer-capable need "cudaMemcpyPeerAsync".
      // So we let the allocator implement the correct call
      // (either cudaMemcpyAsync or cudaMemcpyPeerAsync)
      AT_CUDA_CHECK(CUDACachingAllocator::memcpyAsync(
        dst, dst_device.index(),
        src, src_device.index(),
        size, copy_stream, p2p_enabled));
    }
  } else {
    if (same_neg) {
      if (!same_conj) {
        conj_kernel_cuda(iter);
      } else {
        direct_copy_kernel_cuda(iter);
      }
    } else {
      if (!same_conj) {
        neg_conj_kernel_cuda(iter);
      } else {
        neg_kernel_cuda(iter);
      }
    }
  }

  if (src_device != dst_device) {
    // dst waits on src barrier (dst already waits on dst). We cannot
    // operate on dst's copy until the copy is complete.

    // Still on src_device, record stream event
    CUDAEvent src_ready;
    src_ready.record(copy_stream);

    device_guard.set_device(dst_device);
    src_ready.block(getCurrentCUDAStream(dst_device.index()));
  }

  AT_CUDA_CHECK(cudaGetLastError());
}

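// Returns true when the copy cannot be done directly with cudaMemcpyAsync or
// a single device-side kernel and must instead be staged through contiguous
// (and possibly dtype-converted) temporaries.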
static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  if (dst_device == src_device) {
    // We never require temporaries for copies on the same GPU.
    TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
    return false;
  }

  bool same_dtype = iter.dtype(0) == iter.dtype(1);
  if (same_dtype && iter.is_contiguous()) {
    // Contiguous same-dtype copies can always use cudaMemcpyAsync
    return false;
  } else if (dst_device.is_cuda() && src_device.is_cuda()) {
    // Copies between GPUs can use the copy kernel if P2P is supported
    return !p2p_enabled;
  } else {
    // The remaining cases require temporaries. For example, this includes
    // non-contiguous copies between CPU and GPU.
    return true;
  }
}

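// Enables peer-to-peer access between the two devices when both are CUDA
// devices; always returns false if either side of the copy is the CPU.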
static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
  if (dst_device.is_cpu() || src_device.is_cpu()) {
    return false;
  }
  return at::cuda::get_p2p_access(src_device.index(), dst_device.index());
}

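// CUDA implementation of copy_stub: stages through temporaries when required,
// dispatches GPU<->GPU copies to copy_device_to_device, and handles CPU<->GPU
// transfers with cudaMemcpyAsync / memcpy_and_sync.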
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  TORCH_CHECK(iter.ntensors() == 2);

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  // Enable p2p access between devices. (No-op if it involves the CPU)
  bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);

  if (copy_requires_temporaries(iter, p2p_enabled)) {
    // NB: this involves recursive calls to copy. Be careful that those copies
    // don't require temporaries or you will cause an infinite recursion!
    auto& dst = iter.tensor(0);
    Tensor dst_contig;
    Tensor src_contig;

    // If non_blocking is true, type conversions are performed on the GPU.
    // For blocking transfers, conversions are performed on the CPU to avoid
    // allocating extra GPU memory.
    // For GPU-GPU transfers, conversions are performed on the source device.
    auto conversion_device = non_blocking ? kCUDA : kCPU;
    if (iter.device_type(1) == conversion_device) {
      dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
    } else {
      bool same_type = iter.dtype(0) == iter.dtype(1);
      dst_contig = (dst.is_contiguous() && same_type) ? dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).expand_as(dst).contiguous();
    }

    // propagate the correct conjugate bit
    dst_contig._set_conj(dst.is_conj());
    src_contig._set_conj(iter.tensor(1).is_conj());

    dst_contig._set_neg(dst.is_neg());
    src_contig._set_neg(iter.tensor(1).is_neg());

    // perform a same-dtype copy on contiguous tensors
    TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
    TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
    dst_contig.copy_(src_contig, non_blocking);

    // if necessary, copy back into dst
    if (!dst_contig.is_same(dst)) {
      TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device());
      dst.copy_(dst_contig, non_blocking);
    }
    return;
  }

  // Copy on GPU (or between GPUs)
  if (dst_device.is_cuda() && src_device.is_cuda()) {
    copy_device_to_device(iter, non_blocking, p2p_enabled);
    return;
  }

  // Copy between CPU and GPU
  cuda::OptionalCUDAGuard device_guard;
  cudaMemcpyKind kind;
  if (dst_device.is_cuda() && src_device.is_cpu()) {
    device_guard.set_device(dst_device);
    kind = cudaMemcpyHostToDevice;
  } else if (dst_device.is_cpu() && src_device.is_cuda()) {
    device_guard.set_device(src_device);
    kind = cudaMemcpyDeviceToHost;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()");
  }

  void* dst = iter.data_ptr(0);
  void* src = iter.data_ptr(1);
  int64_t nbytes = iter.numel() * iter.element_size(0);
  CUDAStream stream = getCurrentCUDAStream();

  if (non_blocking) {
    AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
    // we use both the storage context and the tensor data pointer as the key
    // for the caching host allocator. This allows us to better attribute the
    // events to the original tensor allocation correctly. The cases we seek to
    // handle are:

    // 1: a user can pass a pinned memory tensor with an alternative
    // context, for example if allocating memory directly from the pinned memory
    // allocator and constructing a tensor with torch::from_blob.

    // 2: a user can pass a tensor with a different base pointer to the original
    // allocation (via slicing).
    const auto& dst_tensor = iter.tensor(0);
    const auto& src_tensor = iter.tensor(1);
    const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor);
    auto* ptr = (dst_device == kCPU ? dst : src);
    auto* ctx = host_tensor.storage().data_ptr().get_context();
    // TODO: warn on the return value.
    CachingHostAllocator_recordEvent(ptr, ctx, stream);

  } else {
    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
  }

  if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
     iter.tensor(0).conj_physical_();
  }
  if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
     iter.tensor(0).neg_();
  }
}

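// Illustrative usage (not part of this file): once copy_stub is registered
// below, a CUDA-side copy such as
//   at::Tensor dst = at::empty({4}, at::kCUDA);
//   dst.copy_(src, /*non_blocking=*/true);
// is routed through copy_kernel_cuda.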
REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);

} // namespace at::native