1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/FunctionOfAMatrixUtils.h>
3
4 #include <ATen/Dispatch.h>
5 #include <ATen/TensorIterator.h>
6 #include <c10/util/irange.h>
7
8 #if (defined(_WIN32) || defined(_WIN64))
9 #define RESTRICT __restrict
10 #else
11 #define RESTRICT __restrict__
12 #endif
13
14 namespace at::native {
15
16 namespace {
17
_compute_linear_combination_cpu_kernel(TensorIterator & iter,int64_t in_stride,int64_t coeff_stride,int64_t num_summations)18 void _compute_linear_combination_cpu_kernel(
19 TensorIterator& iter,
20 int64_t in_stride,
21 int64_t coeff_stride,
22 int64_t num_summations
23 ) {
24 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
25 at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
26 iter.dtype(),
27 "_compute_linear_combination_cpu", [&] {
28 auto loop = [&](char** data, const int64_t* strides, int64_t n) {
29 auto* RESTRICT out_ptr = data[0];
30 auto* RESTRICT in_ptr = data[1];
31 auto* RESTRICT coeff_ptr = data[2];
32
33 for (const auto elem C10_UNUSED : c10::irange(n)) {
34 auto* RESTRICT out_data = reinterpret_cast<scalar_t*>(out_ptr);
35 auto* RESTRICT in_data = reinterpret_cast<scalar_t*>(in_ptr);
36 using primitive_t = typename scalar_value_type<scalar_t>::type;
37 auto* RESTRICT coeff_data = reinterpret_cast<primitive_t*>(coeff_ptr);
38
39 // perform summation
40 for (const auto i : c10::irange(num_summations)) {
41 *out_data += in_data[i * in_stride] * coeff_data[i * coeff_stride];
42 }
43
44 out_ptr += strides[0];
45 in_ptr += strides[1];
46 coeff_ptr += strides[2];
47 }
48 };
49 iter.for_each(loop);
50 });
51 }
52
53 }
54
// Register this CPU kernel as the implementation of the
// _compute_linear_combination dispatch stub (the stub is presumably
// declared in the included FunctionOfAMatrixUtils.h — confirm there).
REGISTER_DISPATCH(_compute_linear_combination_stub, &_compute_linear_combination_cpu_kernel);
56
57 } // namespace at::native
58