xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/FunctionOfAMatrixUtilsKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/FunctionOfAMatrixUtils.h>
3 
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorIterator.h>
6 #include <c10/util/irange.h>
7 
// RESTRICT expands to a compiler-specific spelling of the non-standard
// `restrict` pointer qualifier: `__restrict` when _WIN32/_WIN64 is defined,
// `__restrict__` otherwise. Declaring a pointer RESTRICT promises the
// compiler that it does not alias the other pointers used alongside it.
8 #if (defined(_WIN32) || defined(_WIN64))
9 #define RESTRICT __restrict
10 #else
11 #define RESTRICT __restrict__
12 #endif
13 
14 namespace at::native {
15 
16 namespace {
17 
_compute_linear_combination_cpu_kernel(TensorIterator & iter,int64_t in_stride,int64_t coeff_stride,int64_t num_summations)18 void _compute_linear_combination_cpu_kernel(
19   TensorIterator& iter,
20   int64_t in_stride,
21   int64_t coeff_stride,
22   int64_t num_summations
23 ) {
24   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
25     at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
26     iter.dtype(),
27     "_compute_linear_combination_cpu", [&] {
28       auto loop = [&](char** data, const int64_t* strides, int64_t n) {
29         auto* RESTRICT out_ptr = data[0];
30         auto* RESTRICT in_ptr = data[1];
31         auto* RESTRICT coeff_ptr = data[2];
32 
33         for (const auto elem C10_UNUSED : c10::irange(n)) {
34           auto* RESTRICT out_data = reinterpret_cast<scalar_t*>(out_ptr);
35           auto* RESTRICT in_data = reinterpret_cast<scalar_t*>(in_ptr);
36           using primitive_t = typename scalar_value_type<scalar_t>::type;
37           auto* RESTRICT coeff_data = reinterpret_cast<primitive_t*>(coeff_ptr);
38 
39           // perform summation
40           for (const auto i : c10::irange(num_summations)) {
41             *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
42           }
43 
44           out_ptr += strides[0];
45           in_ptr += strides[1];
46           coeff_ptr += strides[2];
47         }
48       };
49       iter.for_each(loop);
50   });
51 }
52 
53 }
54 
55 REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cpu_kernel);
56 
57 } // namespace at::native
58