// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Blas.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/NamedTensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Config.h>

#include <ATen/native/mkldnn/Matmul.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_efficientzerotensor.h>
#include <ATen/ops/addmv.h>
#include <ATen/ops/addmv_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/dot.h>
#include <ATen/ops/dot_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/mul_cpu_dispatch.h>
#include <ATen/ops/mv_native.h>
#include <ATen/ops/scalar_tensor_native.h>
#include <ATen/ops/vdot_native.h>
#endif

namespace at::meta {
TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) {
  TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1),
    "vector + matrix @ vector expected, got ", self.dim(), ", ", mat.dim(), ", ", vec.dim());

  TORCH_CHECK(mat.size(1) == vec.size(0) && (mat.size(0) == self.numel() || self.numel() == 1),
    "size mismatch, got input (", self.size(0), "), mat (", mat.size(0), "x", mat.size(1), "), vec (", vec.size(0), ")");
  auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self);
  set_output_raw_strided(0, IntArrayRef(mat.sizes().data(), 1), {}, vec.options(), names);
}
} // namespace at::meta

namespace at::native {

template<typename scalar_t>
void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, int64_t lda, const scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);

template<typename scalar_t>
scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);

template<typename scalar_t>
scalar_t vdot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy);

constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) {
  return n == 1 || lda >= std::max<int64_t>(1L, m);
}
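// Explanatory note (added): lda_cond mirrors the standard BLAS requirement
// that the leading dimension passed to gemv be at least max(1, m), which is
// trivially satisfied when there is only one column (n == 1). The dispatch in
// addmv_out_cpu below uses it to decide whether mat can be handed to gemv
// directly as column-major ('n', lda = mat.stride(1)), as row-major via a
// transposed call ('t', lda = mat.stride(0)), or only after a contiguous copy.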

TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) {
  c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)});
  auto betaval = beta_.toComplexDouble();
  if (mat.numel() == 0) {
    // shortcut for an empty matrix
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (betaval == 0.0) {
      result.zero_();
    } else {
      at::cpu::mul_out(
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<Tensor&>(result),
          self,
          at::native::scalar_tensor(
              beta_, self.scalar_type(), std::nullopt /* layout */, at::kCPU, std::nullopt /* pin_memory */));
    }
  } else {
    if (!result.is_same(*self_) && betaval != 0.0) { // if beta is 0, result contents are ignored
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      at::native::copy_(const_cast<Tensor&>(result), *self_);
    }
    if (result.numel() != 0) {

      NoNamesGuard guard;
      if (use_mkldnn_matmul(mat, vec, /*result=*/Tensor())){
        mkldnn_matmul(mat, vec, result, beta_.to<float>(), alpha_.to<float>());
        return;
      }

      auto r_stride = result.stride(0);
      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, mat.scalar_type(), "addmv_impl_cpu", [&] {
        auto beta = beta_.to<scalar_t>();
        auto alpha = alpha_.to<scalar_t>();
        if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) {
          gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(1),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) {
          gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.const_data_ptr<scalar_t>(), mat.stride(0),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
        else {
          Tensor cmat = mat.contiguous();
          gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.const_data_ptr<scalar_t>(), cmat.stride(0),
              vec.const_data_ptr<scalar_t>(), vec.stride(0), beta, result.mutable_data_ptr<scalar_t>(), r_stride);
        }
      });
    }
  }
}
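// Illustrative usage (a sketch, not part of this translation unit): the kernel
// above implements result = beta * self + alpha * (mat @ vec); the 'n', 't',
// and contiguous-fallback gemv paths produce the same values and differ only
// in how mat's memory layout is consumed.
//
//   at::Tensor mat  = at::randn({3, 4});
//   at::Tensor vec  = at::randn({4});
//   at::Tensor bias = at::randn({3});
//   at::Tensor out  = at::addmv(bias, mat, vec, /*beta=*/0.5, /*alpha=*/2.0);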

Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) {
  // The self arg sent to addmv_out cannot be resized.
  // We would like to use result as the self argument for addmv, but result is
  // user-supplied and may have the wrong size. That is not a hard error here,
  // because we are allowed to resize result, but it becomes a hard error in
  // addmv, which expects self to already satisfy the proper conditions.
  // To avoid this, supply a correctly sized self; its contents don't matter because beta is 0.
  if (result.dim() > 1 || (result.numel() != self.size(0) || result.numel() != 1)) {
    Tensor self_addmv = at::empty({self.size(0)}, vec.options());
    return at::addmv_out(result, self_addmv, self, vec, 0, 1);
  }
  return at::addmv_out(result, result, self, vec, 0, 1);
}
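// Illustrative usage (a sketch, not part of this translation unit): because
// mv_out falls back to addmv_out with a correctly sized dummy self, a result
// tensor of the wrong size is accepted here and resized by the structured
// addmv_out kernel rather than rejected.
//
//   at::Tensor m   = at::randn({3, 4});
//   at::Tensor v   = at::randn({4});
//   at::Tensor out = at::empty({0}, v.options());  // deliberately wrong size
//   at::mv_out(out, m, v);                          // out ends up with shape {3}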

Tensor mv(const Tensor &self, const Tensor &vec) {
  Tensor result = at::empty({self.size(0)}, vec.options());
  //inplace version is more efficient if we can use it
  return at::addmv_(result, self, vec, 0, 1);
}
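// Note (added): passing the freshly allocated, uninitialized result as the
// self argument of addmv_ is safe here because beta is 0; as documented in
// addmv_out_cpu above, self's contents (including NaNs/Infs in uninitialized
// memory) are ignored when beta == 0.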

inline void dot_check(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.dim() == 1 && other.dim() == 1,
      "1D tensors expected, but got ",
      self.dim(),
      "D and ",
      other.dim(),
      "D tensors");

  TORCH_CHECK(
      self.scalar_type() == other.scalar_type(),
      "dot : expected both vectors to have same dtype, but found ",
      self.scalar_type(),
      " and ",
      other.scalar_type());

  TORCH_CHECK(
      self.numel() == other.numel(),
      "inconsistent tensor size, expected tensor [",
      self.numel(),
      "] and src [",
      other.numel(),
      "] to have the same number of elements, but got ",
      self.numel(),
      " and ",
      other.numel(),
      " elements respectively");
}

Tensor dot(const Tensor &self, const Tensor &other){
  if (self.is_complex()) {
    if (self.is_conj()) {
      if (other.is_conj()) {
        return (at::native::dot(self.conj(), other.conj())).conj();
      } else {
        return at::native::vdot(self.conj(), other);
      }
    } else if (other.is_conj()) {
      return at::native::vdot(other.conj(), self);
    }
  }

  at::NoNamesGuard guard;
  dot_check(self, other);

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  if (use_mkldnn_matmul(self, other, /*result=*/Tensor())){
    // mkldnn matmul expects result to have sizes info to create ideep tensor
    auto r = at::empty({1, 1}, self.options());
    mkldnn_matmul(self, other, r, /*beta=*/0);
    return r;
  }

  return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(), "dot", [&] {
    Tensor result = at::empty({}, self.options());
    result.fill_(dot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t*>(other.const_data_ptr<scalar_t>()), other.stride(0)));
    return result;
  });
}
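// Illustrative relationship for complex inputs (a sketch, not part of this
// translation unit): dot applies no conjugation while vdot conjugates its
// first argument, which is why the conj-bit fast paths above can forward
// between the two.
//
//   at::Tensor a = at::randn({5}, at::kComplexFloat);
//   at::Tensor b = at::randn({5}, at::kComplexFloat);
//   // at::dot(a.conj(), b) and at::vdot(a, b) both compute sum_i conj(a_i) * b_i.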

Tensor vdot(const Tensor &self, const Tensor &other){
  // Dispatch to `dot` for real dtypes.
  if (!self.is_complex()){
    return at::dot(self, other);
  }

  if (self.is_conj()) {
    if (other.is_conj()) {
      return at::native::vdot(other.conj(), self.conj());
    } else {
      return at::native::dot(self.conj(), other);
    }
  } else if (other.is_conj()) {
    return (at::native::dot(self, other.conj())).conj();
  }

  at::NoNamesGuard guard;
  // For complex dtypes.
  dot_check(self, other);

  if (self._is_zerotensor() || other._is_zerotensor()) {
    return at::_efficientzerotensor({}, self.options());
  }

  return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
    Tensor result = at::empty({}, self.options());
    result.fill_(vdot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t *>(other.const_data_ptr<scalar_t>()), other.stride(0)));
    return result;
  });
}
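// Illustrative usage (a sketch, not part of this translation unit): because
// vdot conjugates its first argument, vdot(a, a) equals sum_i |a_i|^2 and has
// zero imaginary part, whereas dot(a, a) generally does not.
//
//   at::Tensor a = at::randn({4}, at::kComplexDouble);
//   at::Tensor s = at::vdot(a, a);  // 0-dim complex tensor, imaginary part == 0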

}  // namespace at::native