// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/LinearAlgebra.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/SharedReduceOps.h>
7 #include <ATen/native/cpu/Reduce.h>
8 #include <ATen/native/cpu/Loops.h>
9 #include <c10/util/irange.h>
10 
11 namespace at::native { namespace {
12 
addr_kernel(TensorIterator & iter,const Scalar & beta,const Scalar & alpha)13 void addr_kernel(TensorIterator &iter,
14                  const Scalar& beta, const Scalar& alpha) {
15   if (iter.dtype() == ScalarType::Bool) {
16     using scalar_t = bool;
17     auto beta_val = beta.to<scalar_t>();
18     auto alpha_val = alpha.to<scalar_t>();
19 
20     // when beta is false, values in self should be ignored,
21     // nans and infs in self should not propagate.
22     if (beta_val == false) {
23       cpu_kernel(iter,
24         [=](scalar_t /*self_val*/,
25             scalar_t vec1_val,
26             scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
27           return alpha_val && vec1_val && vec2_val;
28         }
29       );
30     } else {
31       cpu_kernel(iter,
32         [=](scalar_t self_val,
33             scalar_t vec1_val,
34             scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
35           return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
36         }
37       );
38     }
39     return;
40   }
41 
42   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
43     iter.dtype(), "addr_cpu", [&]() {
44       using Vec = Vectorized<scalar_t>;
45 
46       auto beta_val = beta.to<scalar_t>();
47       auto alpha_val = alpha.to<scalar_t>();
48 
49       auto beta_vec = Vec(beta_val);
50       auto alpha_vec = Vec(alpha_val);
51 
52       const scalar_t zero_val(0);
53       // when beta == 0, values in self should be ignored,
54       // nans and infs in self should not propagate.
55       if (beta_val == zero_val) {
56         cpu_kernel_vec(iter,
57           [=](scalar_t /*self_val*/,
58               scalar_t vec1_val,
59               scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
60             return alpha_val * vec1_val * vec2_val;
61           },
62           [=](Vec /*self_vec*/,
63               Vec vec1_vec,
64               Vec vec2_vec) __ubsan_ignore_undefined__ {
65             return alpha_vec * vec1_vec * vec2_vec;
66           }
67         );
68       } else {
69         cpu_kernel_vec(iter,
70           [=](scalar_t self_val,
71               scalar_t vec1_val,
72               scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
73             return beta_val * self_val + alpha_val * vec1_val * vec2_val;
74           },
75           [=](Vec self_vec,
76               Vec vec1_vec,
77               Vec vec2_vec) __ubsan_ignore_undefined__ {
78             return beta_vec * self_vec + alpha_vec * vec1_vec * vec2_vec;
79           }
80         );
81       }
82     }
83   );
84 }
85 
86 } // anonymous namespace
87 
88 REGISTER_DISPATCH(addr_stub, &addr_kernel);
89 } // namespace at::native
90