xref: /aosp_15_r20/external/pytorch/aten/src/ATen/Utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <c10/util/accumulate.h>

#include <algorithm>
#include <cstdlib>

10*da0073e9SAndroid Build Coastguard Worker namespace at {
11*da0073e9SAndroid Build Coastguard Worker 
// Writes a zero into slot `arg` of a 3-byte stack buffer without any bounds
// check and then reads slot 0 back.
//
// An out-of-range `arg` is an out-of-bounds stack write on purpose; judging by
// the name, this presumably exists so that callers can confirm a build is
// AddressSanitizer-instrumented (ASAN would report the overflow) — confirm with
// callers. `volatile` keeps the compiler from eliding the store/load.
//
// Returns 0 when arg == 0; for arg 1 or 2 the returned slot was never written,
// so the value is indeterminate.
int _crash_if_asan(int arg) {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  volatile char buf[3];
  buf[arg] = 0; // intentionally unchecked index
  const int first = buf[0];
  return first;
}
18*da0073e9SAndroid Build Coastguard Worker 
19*da0073e9SAndroid Build Coastguard Worker namespace detail {
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker template <typename T>
tensor_cpu(ArrayRef<T> values,const TensorOptions & options)22*da0073e9SAndroid Build Coastguard Worker Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options) {
23*da0073e9SAndroid Build Coastguard Worker   auto result = at::empty(values.size(), options);
24*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(result.is_contiguous());
25*da0073e9SAndroid Build Coastguard Worker   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(result.scalar_type(), "tensor_cpu", [&] {
26*da0073e9SAndroid Build Coastguard Worker     std::copy(
27*da0073e9SAndroid Build Coastguard Worker         values.begin(), values.end(), result.template data_ptr<scalar_t>());
28*da0073e9SAndroid Build Coastguard Worker   });
29*da0073e9SAndroid Build Coastguard Worker   return result;
30*da0073e9SAndroid Build Coastguard Worker }
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker template <typename T>
tensor_backend(ArrayRef<T> values,const TensorOptions & options)33*da0073e9SAndroid Build Coastguard Worker Tensor tensor_backend(ArrayRef<T> values, const TensorOptions& options) {
34*da0073e9SAndroid Build Coastguard Worker   auto cpu_tensor = tensor_cpu(values, options.device(DeviceType::CPU));
35*da0073e9SAndroid Build Coastguard Worker   return cpu_tensor.to(options.device());
36*da0073e9SAndroid Build Coastguard Worker }
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker template <typename T>
tensor_complex_cpu(ArrayRef<T> values,const TensorOptions & options)39*da0073e9SAndroid Build Coastguard Worker Tensor tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options) {
40*da0073e9SAndroid Build Coastguard Worker   auto result = at::empty(values.size(), options);
41*da0073e9SAndroid Build Coastguard Worker   AT_ASSERT(result.is_contiguous());
42*da0073e9SAndroid Build Coastguard Worker   AT_DISPATCH_COMPLEX_TYPES(result.scalar_type(), "tensor_cpu", [&] {
43*da0073e9SAndroid Build Coastguard Worker     std::copy(
44*da0073e9SAndroid Build Coastguard Worker         values.begin(), values.end(), result.template data_ptr<scalar_t>());
45*da0073e9SAndroid Build Coastguard Worker   });
46*da0073e9SAndroid Build Coastguard Worker   return result;
47*da0073e9SAndroid Build Coastguard Worker }
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker template <typename T>
tensor_complex_backend(ArrayRef<T> values,const TensorOptions & options)50*da0073e9SAndroid Build Coastguard Worker Tensor tensor_complex_backend(
51*da0073e9SAndroid Build Coastguard Worker     ArrayRef<T> values,
52*da0073e9SAndroid Build Coastguard Worker     const TensorOptions& options) {
53*da0073e9SAndroid Build Coastguard Worker   auto cpu_tensor = tensor_complex_cpu(values, options.device(DeviceType::CPU));
54*da0073e9SAndroid Build Coastguard Worker   return cpu_tensor.to(options.device());
55*da0073e9SAndroid Build Coastguard Worker }
56*da0073e9SAndroid Build Coastguard Worker } // namespace detail
57*da0073e9SAndroid Build Coastguard Worker 
58*da0073e9SAndroid Build Coastguard Worker #define TENSOR(T, _1)                                               \
59*da0073e9SAndroid Build Coastguard Worker   Tensor tensor(ArrayRef<T> values, const TensorOptions& options) { \
60*da0073e9SAndroid Build Coastguard Worker     if (options.device().type() != c10::DeviceType::CPU) {          \
61*da0073e9SAndroid Build Coastguard Worker       return at::detail::tensor_backend(values, options);           \
62*da0073e9SAndroid Build Coastguard Worker     } else {                                                        \
63*da0073e9SAndroid Build Coastguard Worker       return at::detail::tensor_cpu(values, options);               \
64*da0073e9SAndroid Build Coastguard Worker     }                                                               \
65*da0073e9SAndroid Build Coastguard Worker   }
66*da0073e9SAndroid Build Coastguard Worker AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR)
67*da0073e9SAndroid Build Coastguard Worker #undef TENSOR
68*da0073e9SAndroid Build Coastguard Worker 
69*da0073e9SAndroid Build Coastguard Worker #define TENSOR(T, _1)                                               \
70*da0073e9SAndroid Build Coastguard Worker   Tensor tensor(ArrayRef<T> values, const TensorOptions& options) { \
71*da0073e9SAndroid Build Coastguard Worker     if (options.device().type() != c10::DeviceType::CPU) {          \
72*da0073e9SAndroid Build Coastguard Worker       return at::detail::tensor_complex_backend(values, options);   \
73*da0073e9SAndroid Build Coastguard Worker     } else {                                                        \
74*da0073e9SAndroid Build Coastguard Worker       return at::detail::tensor_complex_cpu(values, options);       \
75*da0073e9SAndroid Build Coastguard Worker     }                                                               \
76*da0073e9SAndroid Build Coastguard Worker   }
77*da0073e9SAndroid Build Coastguard Worker AT_FORALL_COMPLEX_TYPES(TENSOR)
78*da0073e9SAndroid Build Coastguard Worker #undef TENSOR
79*da0073e9SAndroid Build Coastguard Worker } // namespace at
80