#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/LinearAlgebra.h>
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/Loops.h>
#include <c10/util/irange.h>

namespace at::native { namespace {

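// Computes the rank-1 ("outer product") update behind torch.addr:
//   out = beta * self + alpha * outer(vec1, vec2)
// The call site is expected to have set up `iter` so that vec1 is broadcast
// as a column and vec2 as a row against self; each output element then sees
// the scalar triple (self_val, vec1_val, vec2_val).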
void addr_kernel(TensorIterator &iter,
                 const Scalar& beta, const Scalar& alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    // When beta is false, values in self must be ignored entirely,
    // so the lambda never reads self_val.
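    // In boolean arithmetic, multiplication maps to && and addition to ||,
    // so beta * self + alpha * vec1 * vec2 becomes
    // (beta && self) || (alpha && vec1 && vec2).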
    if (beta_val == false) {
      cpu_kernel(iter,
        [=](scalar_t /*self_val*/,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return alpha_val && vec1_val && vec2_val;
        }
      );
    } else {
      cpu_kernel(iter,
        [=](scalar_t self_val,
            scalar_t vec1_val,
            scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
          return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
        }
      );
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
    iter.dtype(), "addr_cpu", [&]() {
      using Vec = Vectorized<scalar_t>;

      auto beta_val = beta.to<scalar_t>();
      auto alpha_val = alpha.to<scalar_t>();

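      // Broadcast beta and alpha across SIMD lanes once, up front, so the
      // vectorized lambdas below can reuse them on every iteration.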
      auto beta_vec = Vec(beta_val);
      auto alpha_vec = Vec(alpha_val);

      const scalar_t zero_val(0);
      // When beta == 0, values in self must be ignored: even a zero beta
      // would propagate nans and infs from self (0 * inf == nan), so the
      // self term is skipped rather than multiplied out.
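      // cpu_kernel_vec takes a scalar lambda plus a SIMD lambda: the Vec
      // overload processes full vector lanes and the scalar one handles
      // whatever tail elements remain.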
      if (beta_val == zero_val) {
        cpu_kernel_vec(iter,
          [=](scalar_t /*self_val*/,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return alpha_val * vec1_val * vec2_val;
          },
          [=](Vec /*self_vec*/,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return alpha_vec * vec1_vec * vec2_vec;
          }
        );
      } else {
        cpu_kernel_vec(iter,
          [=](scalar_t self_val,
              scalar_t vec1_val,
              scalar_t vec2_val) __ubsan_ignore_undefined__ -> scalar_t {
            return beta_val * self_val + alpha_val * vec1_val * vec2_val;
          },
          [=](Vec self_vec,
              Vec vec1_vec,
              Vec vec2_vec) __ubsan_ignore_undefined__ {
            return beta_vec * self_vec + alpha_vec * vec1_vec * vec2_vec;
          }
        );
      }
    }
  );
}
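
// Rough shape of a call that lands in this kernel (illustrative only:
// `self`, `v1`, and `v2` are placeholder tensors; beta and alpha default
// to 1 in the public API):
//   at::Tensor out = at::addr(self, v1, v2, /*beta=*/0, /*alpha=*/1);
// For CPU tensors the addr op builds a TensorIterator over (self, vec1,
// vec2) and invokes addr_stub, which REGISTER_DISPATCH below binds to
// addr_kernel.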

} // anonymous namespace

REGISTER_DISPATCH(addr_stub, &addr_kernel);
} // namespace at::native