1 #include <ATen/AccumulateType.h>
2
3 namespace at {
4
toAccumulateType(c10::ScalarType type,c10::DeviceType device)5 c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) {
6 switch (type) {
7 #define DEFINE_CASE(scalar_t, TypeNum) \
8 case ScalarType::TypeNum: \
9 switch (device) { \
10 case DeviceType::CUDA: \
11 return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::CUDA>>::value; \
12 case DeviceType::XPU: \
13 return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::XPU>>::value; \
14 case DeviceType::MPS: \
15 return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::MPS>>::value; \
16 default: \
17 return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::CPU>>::value; \
18 }
19
20 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(DEFINE_CASE)
21 #undef DEFINE_CASE
22
23 default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
24 }
25 }
26
toAccumulateType(c10::ScalarType type,bool is_cuda)27 c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
28 return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA) : toAccumulateType(type, c10::DeviceType::CPU);
29 }
30
31 }
32