xref: /aosp_15_r20/external/pytorch/c10/core/CopyBytes.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/CopyBytes.h>
2 #include <c10/util/Logging.h>
3 
4 namespace c10 {
5 
6 // First dimension of the array is `bool async`: 0 is sync,
7 // 1 is async (non-blocking)
8 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
9 static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES]
10                                      [COMPILE_TIME_MAX_DEVICE_TYPES];
11 
_CopyBytesFunctionRegisterer(DeviceType fromType,DeviceType toType,CopyBytesFunction func_sync,CopyBytesFunction func_async)12 _CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer(
13     DeviceType fromType,
14     DeviceType toType,
15     CopyBytesFunction func_sync,
16     CopyBytesFunction func_async) {
17   auto from = static_cast<int>(fromType);
18   auto to = static_cast<int>(toType);
19   if (!func_async) {
20     // default to the sync function
21     func_async = func_sync;
22   }
23   CHECK(
24       g_copy_bytes[0][from][to] == nullptr &&
25       g_copy_bytes[1][from][to] == nullptr)
26       << "Duplicate registration for device type pair "
27       << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType);
28   g_copy_bytes[0][from][to] = func_sync;
29   g_copy_bytes[1][from][to] = func_async;
30 }
31 
CopyBytes(size_t nbytes,const void * src,Device src_device,void * dst,Device dst_device,bool async)32 void CopyBytes(
33     size_t nbytes,
34     const void* src,
35     Device src_device,
36     void* dst,
37     Device dst_device,
38     bool async) {
39   auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())]
40                          [static_cast<int>(dst_device.type())];
41   CAFFE_ENFORCE(
42       ptr,
43       "No function found for copying from ",
44       c10::DeviceTypeName(src_device.type()),
45       " to ",
46       c10::DeviceTypeName(dst_device.type()));
47   ptr(nbytes, src, src_device, dst, dst_device);
48 }
49 
50 } // namespace c10
51