#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Config.h>

#include <ATen/native/mkldnn/Matmul.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/addmv.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mul_cpu_dispatch.h>
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::meta {
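// addmv computes beta * self + alpha * (mat @ vec).
// Shapes: mat is (m, n), vec is (n,), and self must be a scalar or a 1-D
// tensor with either 1 or m elements; the output is a 1-D tensor of length m.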
TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) {
  TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
    "vector + matrix @ vector expected, got ", self.dim(), ", ", mat.dim(), ", ", vec.dim());

  TORCH_CHECK(mat.size(1) == vec.size(0) && (mat.size(0) == self.numel() || self.numel() == 1),
    "size mismatch, got input (", self.size(0), "), mat (", mat.size(0), "x", mat.size(1), "), vec (", vec.size(0), ")");
  auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
  set_output_raw_strided(0, IntArrayRef(mat.sizes().data(), 1), {}, vec.options(), names);
}
} // namespace at::meta

namespace at::native {

template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);

template<typename scalar_t>
scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);

template<typename scalar_t>
scalar_t vdot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);
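
// The templates above are forward declarations only; the per-dtype
// specializations live in the CPU BLAS kernel implementation (they typically
// wrap an external BLAS when one is available and otherwise fall back to a
// native loop).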

// gemv requires the leading dimension to be at least max(1, m); the check is
// relaxed when there is only one column (n == 1), since the leading dimension
// is then never used to step between columns.
constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) {
  return n == 1 || lda >= std::max<int64_t>(1L, m);
}

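// addmv_out_cpu: handle the empty-matrix case up front, copy self into result
// when needed (beta != 0 and result is not aliased with self), and then
// compute the matrix-vector product either via mkldnn_matmul or via a
// dtype-dispatched gemv.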
TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta_.toComplexDouble();
  if (mat.numel() == 0) {
    // Shortcut for an empty matrix.
    // By definition, when beta == 0 the values in self are ignored; NaNs and
    // infs must not propagate.
    if (betaval == 0.0) {
      result.zero_();
    } else {
      at::cpu::mul_out(
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta_, self.scalar_type(), std::nullopt /* layout */, at::kCPU, std::nullopt /* pin_memory */));
    }
  } else {
    if (!result.is_same(*self_) && betaval != 0.0) { // if beta is 0, the contents of result are ignored
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      at::native::copy_(const_cast<Tensor&>(result), *self_);
    }
    if (result.numel() != 0) {

      NoNamesGuard guard;
      if (use_mkldnn_matmul(mat, vec, /*result=*/Tensor())){
        mkldnn_matmul(mat, vec, result, beta_.to<float>(), alpha_.to<float>());
        return;
      }

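      // Dispatch to a BLAS-style gemv. The 'n' path treats mat as a
      // column-major (Fortran-order) matrix, the 't' path treats it as
      // row-major and asks gemv to transpose; if neither layout matches,
      // a contiguous copy of mat is made and the 't' path is used.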
      auto r_stride = result.stride(0);
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, mat.scalar_type(), "addmv_impl_cpu", [&] {
        auto beta = beta_.to<scalar_t>();
        auto alpha = alpha_.to<scalar_t>();
        if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) {
          gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(1),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) {
          gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(0),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else {
          Tensor cmat = mat.contiguous();
          gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.const_data_ptr<scalar_t>(), cmat.stride(0),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
      });
    }
  }
}

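// mv(self, vec) computes self @ vec. Both mv and mv_out delegate to addmv
// with beta = 0 and alpha = 1, so the initial contents of the output tensor
// are irrelevant.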
Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
  // The self argument passed to addmv_out cannot be resized. We would like to
  // pass result as the self argument of addmv, but result is user supplied and
  // may have the wrong size. That is not a hard error here, because we allow
  // resizing result, but it would be a hard error in addmv, which expects self
  // to already satisfy the proper conditions. To avoid this, supply a
  // correctly sized self; its contents don't matter because beta is 0.
  if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() != 1)) {
    Tensor self_addmv = at::empty({self.size(0)}, vec.options());
    return at::addmv_out(result, self_addmv, self, vec, 0, 1);
  }
  return at::addmv_out(result, result, self, vec, 0, 1);
}

Tensor mv(const Tensor &self, const Tensor &vec) {
  Tensor result = at::empty({self.size(0)}, vec.options());
  // the in-place version is more efficient when we can use it
  return at::addmv_(result, self, vec, 0, 1);
}

inline void dot_check(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.dim() == 1 && other.dim() == 1,
      "1D tensors expected, but got ",
      self.dim(),
      "D and ",
      other.dim(),
      "D tensors");

  TORCH_CHECK(
      self.scalar_type() == other.scalar_type(),
      "dot : expected both vectors to have same dtype, but found ",
      self.scalar_type(),
      " and ",
      other.scalar_type());

  TORCH_CHECK(
      self.numel() == other.numel(),
      "inconsistent tensor size, expected tensor [",
      self.numel(),
      "] and src [",
      other.numel(),
      "] to have the same number of elements, but got ",
      self.numel(),
      " and ",
      other.numel(),
      " elements respectively");
}

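// dot(a, b) = sum_i a_i * b_i (no conjugation). For lazily conjugated inputs
// the conj bit is folded in using dot(conj(a), conj(b)) = conj(dot(a, b)) and
// dot(conj(a), b) = vdot(a, b), so the conjugate never has to be materialized.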
Tensor dot(const Tensor &self, const Tensor &other){
  if (self.is_complex()) {
    if (self.is_conj()) {
      if (other.is_conj()) {
        return (at::native::dot(self.conj(), other.conj())).conj();
      } else {
        return at::native::vdot(self.conj(), other);
      }
    } else if (other.is_conj()) {
      return at::native::vdot(other.conj(), self);
    }
  }

  at::NoNamesGuard guard;
  dot_check(self, other);

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  if (use_mkldnn_matmul(self, other, /*result=*/Tensor())){
    // mkldnn matmul expects result to carry size information so that the
    // ideep tensor can be created
    auto r = at::empty({1, 1}, self.options());
    mkldnn_matmul(self, other, r, /*beta=*/0);
    return r;
  }

  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(), "dot", [&] {
    Tensor result = at::empty({}, self.options());
    result.fill_(dot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t*>(other.const_data_ptr<scalar_t>()), other.stride(0)));
    return result;
  });
}

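// vdot(a, b) = sum_i conj(a_i) * b_i. Real dtypes are forwarded to dot; for
// complex inputs with the conj bit set, the identities
// vdot(conj(a), conj(b)) = vdot(b, a), vdot(conj(a), b) = dot(a, b), and
// vdot(a, conj(b)) = conj(dot(a, b)) are used so that the conjugate never has
// to be materialized.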
Tensor vdot(const Tensor &self, const Tensor &other){
  // Dispatch to `dot` for real dtypes.
  if (!self.is_complex()){
    return at::dot(self, other);
  }

  if (self.is_conj()) {
    if (other.is_conj()) {
      return at::native::vdot(other.conj(), self.conj());
    } else {
      return at::native::dot(self.conj(), other);
    }
  } else if (other.is_conj()) {
    return (at::native::dot(self, other.conj())).conj();
  }

  at::NoNamesGuard guard;
  // For complex dtypes.
  dot_check(self, other);

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
    Tensor result = at::empty({}, self.options());
    result.fill_(vdot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t *>(other.const_data_ptr<scalar_t>()), other.stride(0)));
    return result;
  });
}

} // namespace at::native