// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/AccumulateType.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/AccumulateType.h>

namespace at {

// Returns the ScalarType that values of `type` accumulate in on `device`.
// The mapping itself is defined at compile time by at::acc_type_device in
// AccumulateType.h; this function only lifts it to a runtime dispatch on
// (type, device).
c10::ScalarType toAccumulateType(c10::ScalarType type, c10::DeviceType device) {
  switch (type) {
#define DEFINE_CASE(scalar_t, TypeNum)                                                             \
    case ScalarType::TypeNum:                                                                      \
      switch (device) {                                                                            \
        case DeviceType::CUDA:                                                                     \
          return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::CUDA>>::value; \
        case DeviceType::XPU:                                                                      \
          return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::XPU>>::value;  \
        case DeviceType::MPS:                                                                      \
          return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::MPS>>::value;  \
        default:                                                                                   \
          return CppTypeToScalarType<at::acc_type_device<scalar_t, c10::DeviceType::CPU>>::value;  \
      }

    // Stamp out one case per scalar type covered by the X-macro list.
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF_F8NZ(DEFINE_CASE)
#undef DEFINE_CASE

    default: TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type);
  }
}
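
// For illustration only (not part of the original file): applied to the pair
// (at::Half, Half) from the X-macro list, DEFINE_CASE above expands to roughly
//
//   case ScalarType::Half:
//     switch (device) {
//       case DeviceType::CUDA:
//         return CppTypeToScalarType<
//             at::acc_type_device<at::Half, c10::DeviceType::CUDA>>::value;
//       ...
//     }
//
// so each (type, device) pair resolves to whatever acc_type_device picks at
// compile time, with no per-type logic written out by hand.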

// Convenience overload for callers that only distinguish CUDA from CPU:
// true selects the CUDA mapping, false the CPU one.
c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda) {
  return is_cuda ? toAccumulateType(type, c10::DeviceType::CUDA)
                 : toAccumulateType(type, c10::DeviceType::CPU);
}

} // namespace at
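
// Usage sketch (an illustrative addition, not part of this file). Assuming the
// acc_type_device mappings in AccumulateType.h at this revision, Half
// accumulates in Float on CUDA, while CPU accumulation widens Float to Double:
#if 0
#include <ATen/AccumulateType.h>

void accumulate_type_example() {
  // Half sums are accumulated in Float on CUDA.
  auto cuda_acc = at::toAccumulateType(at::kHalf, c10::DeviceType::CUDA);
  TORCH_INTERNAL_ASSERT(cuda_acc == at::kFloat);

  // The boolean overload: false selects the CPU table, where Float widens
  // to Double.
  auto cpu_acc = at::toAccumulateType(at::kFloat, /*is_cuda=*/false);
  TORCH_INTERNAL_ASSERT(cpu_acc == at::kDouble);
}
#endif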