#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/cuda/CachingHostAllocator.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/PeerToPeerAccess.h>
#include <ATen/native/Copy.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAStream.h>

// TODO(NS): Investigate why FP8 conversion intrinsics end up being slower
#ifdef AT_USE_NV_CVT_INTRINSICS
#include <cuda_fp8.h>
#endif

namespace at::native {

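// Element-wise negation / conjugation kernels implemented in other CUDA
// translation units; declared here so the copy path can reuse them.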
void neg_kernel_cuda(TensorIteratorBase &iter);
void conj_kernel_cuda(TensorIteratorBase &iter);

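// Copy into a Float8 destination. For float/half/bfloat16 sources the
// conversion happens directly in the lambda via gpu_kernel_nocast (bypassing
// TensorIterator's dynamic casting); for any other source dtype we fall back
// to gpu_kernel, which casts the input to the destination Float8 type before
// the identity lambda runs.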
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
  ScalarType dtype = iter.dtype(0);
  ScalarType other_dtype = iter.dtype(1);
  if (dtype == kFloat8_e4m3fn) {
    switch (other_dtype) {
      case kFloat:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
          return Float8_e4m3fn(value);
        });
        break;
      case kHalf:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
          return Float8_e4m3fn(value);
        });
        break;
      case kBFloat16:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
          return Float8_e4m3fn(value);
        });
        break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; });
        break;
    }
  } else if (dtype == kFloat8_e5m2) {
    switch (other_dtype) {
      case kFloat:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
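          // __nv_cvt_* returns the raw 8-bit FP8 storage value, so construct
          // the Float8_e5m2 directly from those bits.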
          const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2);
          return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
          return Float8_e5m2(value);
#endif
        });
        break;
      case kHalf:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
          const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2);
          return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
          return Float8_e5m2(value);
#endif
        });
        break;
      case kBFloat16:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
          const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2);
          return Float8_e5m2(x, Float8_e5m2::from_bits());
#else
          return Float8_e5m2(value);
#endif
        });
        break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; });
        break;
    }
  } else if (dtype == kFloat8_e4m3fnuz) {
    switch (other_dtype) {
      case kFloat:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
          return Float8_e4m3fnuz(value);
        });
        break;
      case kHalf:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
          return Float8_e4m3fnuz(value);
        });
        break;
      case kBFloat16:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
          return Float8_e4m3fnuz(value);
        });
        break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; });
        break;
    }
  } else if (dtype == kFloat8_e5m2fnuz) {
    switch (other_dtype) {
      case kFloat:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
          return Float8_e5m2fnuz(value);
        });
        break;
      case kHalf:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
          return Float8_e5m2fnuz(value);
        });
        break;
      case kBFloat16:
        gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
          return Float8_e5m2fnuz(value);
        });
        break;
      default:
        gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
        break;
    }
  } else {
    TORCH_CHECK(false, "This is supposed to be called only for Float8 types");
  }
}

// TODO: We probably can use the opaque type trick to avoid creating duplicate
// kernels for equivalent bit lengths
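// Identity copy dispatched over the destination dtype. gpu_kernel performs any
// required dtype conversion; quantized and bits dtypes use dedicated dispatch
// macros, and Float8 dtypes are forwarded to float8_copy_kernel_cuda above.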
void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
  ScalarType dtype = iter.dtype(0);
  if (isQIntType(dtype)) {
    AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
  } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
    float8_copy_kernel_cuda(iter);
  } else if (isBitsType(dtype)) {
    TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
        "bits types to different bits types. Source dtype is ", iter.dtype(1), ", target dtype is ", dtype);
    AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
      gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
    });
  } else {
    AT_DISPATCH_V2(
        dtype, "copy_", AT_WRAP([&] {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
        }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBool, kBFloat16, kComplexHalf, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
  }
}

void neg_conj_kernel_cuda(TensorIteratorBase &iter) {
  AT_DISPATCH_COMPLEX_TYPES(iter.common_dtype(), "neg_conj_cuda", [&] {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return -std::conj(x); });
  });
}

using namespace at::cuda;

// device-to-device copy, does type conversion
void copy_device_to_device(TensorIterator& iter,
                           bool non_blocking,
                           bool p2p_enabled) {
  int64_t numel = iter.numel();

  // We can memcpy the memory if both tensors have the same type AND both
  // tensors are contiguous after dimension coalescing and reordering.
  bool same_type = iter.dtype(0) == iter.dtype(1);
  bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
  bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
  bool memcpy_eligible = same_type && same_conj && same_neg && iter.is_contiguous();

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  CUDAGuard device_guard(src_device);

  // We always perform the copy on the source device, using the current stream
  // on the source device, and we fully synchronize on both src and dst's
  // current streams for completion of the copy. We have to explicitly do this
  // for non-contig copies. This mimics the behavior of cross-device
  // cudaMemcpyAsync on the default stream.
  CUDAStream copy_stream = getCurrentCUDAStream(src_device.index());
  if (src_device != dst_device) {
    // This is a cross-device copy on the src current stream and dst current
    // stream. We perform a two-way barrier between both devices' streams
    // before the copy. This ensures that any write-after-write and
    // write-after-read dependencies on the destination side are handled, so
    // that no one is operating on the dst memory when we perform the copy.
    // src waits on dst barrier (src already waits on src)
    CUDAEvent dst_ready;
    device_guard.set_device(dst_device);
    dst_ready.record(getCurrentCUDAStream(dst_device.index()));

    device_guard.set_device(src_device);
    dst_ready.block(copy_stream);
  }

  if (memcpy_eligible) {
    void *dst = iter.data_ptr(0);
    void *src = iter.data_ptr(1);
    size_t size = numel * iter.element_size(0);
    if (src != dst || src_device != dst_device) {
      // Due to bizarre cuda driver intricacies, copies of
      // cudaMallocAsynced memory between devices that aren't
      // peer-to-peer-capable need "cudaMemcpyPeerAsync".
      // So we let the allocator implement the correct call
      // (either cudaMemcpyAsync or cudaMemcpyPeerAsync)
      AT_CUDA_CHECK(CUDACachingAllocator::memcpyAsync(
          dst, dst_device.index(),
          src, src_device.index(),
          size, copy_stream, p2p_enabled));
    }
  } else {
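    // Not memcpy-eligible: run an elementwise copy kernel that also
    // materializes any mismatched conjugate / negative bits while copying.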
    if (same_neg) {
      if (!same_conj) {
        conj_kernel_cuda(iter);
      } else {
        direct_copy_kernel_cuda(iter);
      }
    } else {
      if (!same_conj) {
        neg_conj_kernel_cuda(iter);
      } else {
        neg_kernel_cuda(iter);
      }
    }
  }

  if (src_device != dst_device) {
    // dst waits on src barrier (dst already waits on dst). We cannot
    // operate on dst's copy until the copy is complete.

    // Still on src_device, record stream event
    CUDAEvent src_ready;
    src_ready.record(copy_stream);

    device_guard.set_device(dst_device);
    src_ready.block(getCurrentCUDAStream(dst_device.index()));
  }

  AT_CUDA_CHECK(cudaGetLastError());
}

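// Decides whether a cross-device copy must be staged through contiguous,
// same-dtype temporaries (e.g. a non-contiguous CPU<->GPU copy, or a GPU<->GPU
// copy that needs the copy kernel but has no peer-to-peer access).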
static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  if (dst_device == src_device) {
    // We never require temporaries for copies on the same GPU.
    TORCH_INTERNAL_ASSERT(dst_device.is_cuda() && src_device.is_cuda());
    return false;
  }

  bool same_dtype = iter.dtype(0) == iter.dtype(1);
  if (same_dtype && iter.is_contiguous()) {
    // Contiguous same-dtype copies can always use cudaMemcpyAsync
    return false;
  } else if (dst_device.is_cuda() && src_device.is_cuda()) {
    // Copies between GPUs can use the copy kernel if P2P is supported
    return !p2p_enabled;
  } else {
    // The remaining cases require temporaries. For example, this includes
    // non-contiguous copies between CPU and GPU.
    return true;
  }
}

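// Checks (and, if possible, enables) peer-to-peer access between the two CUDA
// devices; always returns false when either device is the CPU.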
static bool maybe_enable_p2p_access(Device dst_device, Device src_device) {
  if (dst_device.is_cpu() || src_device.is_cpu()) {
    return false;
  }
  return at::cuda::get_p2p_access(src_device.index(), dst_device.index());
}

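// Entry point registered on copy_stub. Handles GPU<->GPU and CPU<->GPU copies,
// staging the transfer through contiguous temporaries when a direct memcpy or
// copy kernel cannot be used.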
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  TORCH_CHECK(iter.ntensors() == 2);

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  // Enable p2p access between devices. (No-op if it involves the CPU)
  bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);

  if (copy_requires_temporaries(iter, p2p_enabled)) {
    // NB: this involves recursive calls to copy. Be careful that those copies
    // don't require temporaries or you will cause an infinite recursion!
    auto& dst = iter.tensor(0);
    Tensor dst_contig;
    Tensor src_contig;

    // If non_blocking is true, type conversions are performed on the GPU.
    // For blocking transfers, conversions are performed on the CPU to avoid
    // allocating extra GPU memory.
    // For GPU-GPU transfers, conversions are performed on the source device.
    auto conversion_device = non_blocking ? kCUDA : kCPU;
    if (iter.device_type(1) == conversion_device) {
      dst_contig = dst.is_contiguous() ? dst : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
    } else {
      bool same_type = iter.dtype(0) == iter.dtype(1);
      dst_contig = (dst.is_contiguous() && same_type) ? dst : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      src_contig = iter.tensor(1).expand_as(dst).contiguous();
    }

    // propagate the correct conjugate bit
    dst_contig._set_conj(dst.is_conj());
    src_contig._set_conj(iter.tensor(1).is_conj());

    dst_contig._set_neg(dst.is_neg());
    src_contig._set_neg(iter.tensor(1).is_neg());

    // perform a same-dtype copy on contiguous tensors
    TORCH_INTERNAL_ASSERT(dst_contig.sizes().equals(src_contig.sizes()));
    TORCH_INTERNAL_ASSERT(dst_contig.scalar_type() == src_contig.scalar_type());
    dst_contig.copy_(src_contig, non_blocking);

    // if necessary, copy back into dst
    if (!dst_contig.is_same(dst)) {
      TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device());
      dst.copy_(dst_contig, non_blocking);
    }
    return;
  }

  // Copy on GPU (or between GPUs)
  if (dst_device.is_cuda() && src_device.is_cuda()) {
    copy_device_to_device(iter, non_blocking, p2p_enabled);
    return;
  }

  // Copy between CPU and GPU
  cuda::OptionalCUDAGuard device_guard;
  cudaMemcpyKind kind;
  if (dst_device.is_cuda() && src_device.is_cpu()) {
    device_guard.set_device(dst_device);
    kind = cudaMemcpyHostToDevice;
  } else if (dst_device.is_cpu() && src_device.is_cuda()) {
    device_guard.set_device(src_device);
    kind = cudaMemcpyDeviceToHost;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()");
  }

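  // At this point the transfer is guaranteed to be contiguous and same-dtype
  // (otherwise copy_requires_temporaries() would have taken the temporary path
  // above), so a plain byte-wise memcpy is sufficient.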
  void* dst = iter.data_ptr(0);
  void* src = iter.data_ptr(1);
  int64_t nbytes = iter.numel() * iter.element_size(0);
  CUDAStream stream = getCurrentCUDAStream();

  if (non_blocking) {
    AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
    // we use both the storage context and the tensor data pointer as the key
    // for the caching host allocator. This allows us to better attribute the
    // events to the original tensor allocation correctly. The cases we seek to
    // handle are:

    // 1: a user can pass a pinned memory tensor with an alternative
    // context, for example if allocating memory directly from the pinned memory
    // allocator and constructing a tensor with torch::from_blob.

    // 2: a user can pass a tensor with a different base pointer to the original
    // allocation (via slicing).
    const auto& dst_tensor = iter.tensor(0);
    const auto& src_tensor = iter.tensor(1);
    const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor);
    auto* ptr = (dst_device == kCPU ? dst : src);
    auto* ctx = host_tensor.storage().data_ptr().get_context();
    // TODO: warn on the return value.
    CachingHostAllocator_recordEvent(ptr, ctx, stream);

  } else {
    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
  }

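  // cudaMemcpy moves raw bits only: if the conjugate or negative dispatch bits
  // differ between source and destination, materialize them on the destination.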
  if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
    iter.tensor(0).conj_physical_();
  }
  if (iter.tensor(0).is_neg() != iter.tensor(1).is_neg()) {
    iter.tensor(0).neg_();
  }
}

REGISTER_DISPATCH(copy_stub, &copy_kernel_cuda);

} // namespace at::native