#include <c10/xpu/XPUFunctions.h>

#include <ATen/ATen.h>
#include <ATen/record_function.h>

#include <Attr.h>
#include <Utils.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

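// Runs a oneDNN matmul primitive on the current XPU device:
//   result = mat1 x mat2 (+ bias, + fused post-ops described by `attr`).
// `result` must be 2D or 3D (batched). When `m2_trans` is false, mat2's
// strides are swapped so the primitive consumes its transpose without a
// physical copy. `deps` are SYCL events the computation must wait on; the
// returned event signals completion.
//
// Illustrative call (hypothetical shapes and variables; real callers normally
// reach this helper through the XPU Blas/linear bindings):
//   at::Tensor out = at::empty({M, N}, mat1.options());
//   sycl::event done =
//       matmul(out, mat1, mat2, bias, /*m2_trans=*/true, attr, {});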
sycl::event matmul(
    at::Tensor& result,
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Tensor& b_raw,
    bool m2_trans,
    Attr attr,
    const std::vector<sycl::event>& deps) {
  int64_t dims = result.dim();
  TORCH_CHECK(
      dims == 2 || dims == 3,
      "oneDNN matmul only works with 2D or 3D tensors, got ",
      dims);
  TORCH_CHECK(
      dims == mat1.dim() && dims == mat2.dim(),
      "oneDNN input matrices must have the same rank");
  TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");

  at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device());
  auto engine = GpuEngineManager::Instance().get_engine(cur_device);
  auto stream = GpuStreamManager::Instance().get_stream();

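  // oneDNN only accepts certain stride patterns; fall back to contiguous
  // copies when an input (or the output) does not qualify.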
  at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
  at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
  at::Tensor dst =
      is_onednn_matmul_strides(result, true) ? result : result.contiguous();

  int64_t m = dst.size(-2);
  int64_t n = dst.size(-1);
  int64_t k = m1.size(-1);
  int64_t mb = 1;

  if (dims == 3) {
    mb = dst.size(0);
    TORCH_CHECK(
        mb == m1.size(0) && mb == m2.size(0),
        "batch size mismatch, dst mb: ",
        mb,
        ", m1 mb: ",
        m1.size(0),
        ", m2 mb: ",
        m2.size(0));
  }

  // validate bias and make it compatible with the oneDNN implementation
  bool with_bias = false;
  at::Tensor b = b_raw;
  if (b.defined()) {
    with_bias = true;
    if (b.dim() == 1) {
      TORCH_CHECK(
          b.size(0) == n || b.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1 ...");
      if (b.size(0) == 0) {
        with_bias = false;
      } else if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else if (m1.dim() == 2) {
        b = b.expand({1, n}).contiguous();
      }
    } else if (b.dim() == 2) {
      TORCH_CHECK(
          (b.size(0) == m && b.size(1) == n) ||
              (b.size(0) == 1 && b.size(1) == n) ||
              (b.size(0) == m && b.size(1) == 1) ||
              (b.size(0) == 1 && b.size(1) == 1),
          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
      if (b.size(0) == 1 && b.size(1) == 1)
        b = b.expand({1, n}).contiguous();
    } else if (b.dim() == 3) {
      TORCH_CHECK(
          at::are_expandable({mb, m, n}, b.sizes()),
          "matmul bias must be expandable to: ",
          dst.sizes(),
          " but got: ",
          b.sizes());
      b = b.expand({mb, m, n}).contiguous();
    } else if (b.dim() == 0) {
      TORCH_CHECK(
          b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else {
        b = b.expand({1, n}).contiguous();
      }
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
    }
  }

  b = b.contiguous(); // avoid reordering the bias twice

  // XPU matmul supports both ab/ba layouts for the m2 tensor, so no further
  // layout check is needed here.
  auto m1_usr_dt = get_onednn_dtype(m1);
  auto m2_usr_dt = get_onednn_dtype(m2);
  auto dst_usr_dt = get_onednn_dtype(dst);

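  // *_usr_dt describe the tensors as provided; the *_dt variants below may be
  // adjusted (see the bf16 master-weight handling) to select the primitive's
  // compute/output data types.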
  auto m1_dt = m1_usr_dt;
  auto m2_dt = m2_usr_dt;
  auto dst_dt = dst_usr_dt;
  dnnl::memory::data_type bias_dt;

  dnnl::memory::desc m1_md, m1_usr_md, m1_any_md;
  dnnl::memory::desc m2_md, m2_usr_md, m2_any_md;
  dnnl::memory::desc dst_md, dst_usr_md, dst_any_md;
  dnnl::memory::desc bias_md;

  // Naive master-weight support: if one operand is bf16 and the other fp32
  // (e.g. bf16 activations with fp32 master weights), run the computation and
  // produce the output in bf16.
  if (m1_dt == dnnl::memory::data_type::bf16 &&
      m2_dt == dnnl::memory::data_type::f32) {
    m2_dt = dnnl::memory::data_type::bf16;
    dst_dt = dnnl::memory::data_type::bf16;
  } else if (
      m1_dt == dnnl::memory::data_type::f32 &&
      m2_dt == dnnl::memory::data_type::bf16) {
    m1_dt = dnnl::memory::data_type::bf16;
    dst_dt = dnnl::memory::data_type::bf16;
  }

  dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
  dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
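  // oneDNN matmul consumes plain strided memory descriptors, so a transposed
  // m2 is expressed by swapping its strides rather than materializing a copy.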
  if (dims == 2) {
    m1_dims = {m, k};
    m2_dims = {k, n};
    dst_dims = {m, n};

    m1_strides = {m1.stride(0), m1.stride(1)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1)};
    } else {
      m2_strides = {m2.stride(1), m2.stride(0)};
    }
    dst_strides = {dst.stride(0), dst.stride(1)};
  } else {
    m1_dims = {mb, m, k};
    m2_dims = {mb, k, n};
    dst_dims = {mb, m, n};

    m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
    } else {
      m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
    }
    dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
  }

  if (with_bias) {
    bias_dims = get_onednn_dims(b);
    bias_dt = get_onednn_dtype(b);
    bias_strides = get_onednn_strides(b);
  }

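  // Post-ops (fused eltwise/binary/sum operations) recorded in `attr` are
  // attached to the primitive attributes below.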
  dnnl::post_ops po = attr.extract_post_ops(dst);

  std::unordered_map<int, dnnl::memory> args;
  dnnl::matmul matmul_p;
  dnnl::matmul::primitive_desc matmul_pd;

  // STEP1: create memory descriptors
  m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
  m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
  dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);

  // STEP2: create primitive attributes
  dnnl::primitive_attr pattr;
  pattr.set_post_ops(po);

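  // Honor the PyTorch deterministic-algorithms settings when the oneDNN build
  // exposes the deterministic primitive attribute.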
#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn()) {
    pattr.set_deterministic(true);
  }
#endif

  // Use a user-managed scratchpad so the temporary workspace is allocated
  // through the PyTorch allocator (see the scratchpad tensor below).
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

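  // For fp32 sources keep strict math mode so oneDNN does not apply implicit
  // down-conversions (e.g. tf32).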
  if (m1_dt == dnnl::memory::data_type::f32) {
    pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
  }

  // STEP3: create primitive
  if (with_bias) {
    bias_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
    matmul_pd = dnnl::matmul::primitive_desc(
        engine, m1_md, m2_md, bias_md, dst_md, pattr);
  } else {
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
  }

  matmul_p = dnnl::matmul(matmul_pd);

  m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
  m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
  dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);

  // STEP4: create memory objects
  auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
  auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());

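  // Memory descriptors the primitive expects; since the primitive was created
  // from plain strided descriptors, these match the user memories and the user
  // memories are bound directly below.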
  auto expected_m1_md = matmul_pd.src_desc();
  auto expected_m2_md = matmul_pd.weights_desc();
  auto expected_dst_md = matmul_pd.dst_desc();

  dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
  at::Tensor m1_, m2_, dst_;

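  // Binary post-ops need extra memory arguments (e.g. the tensor being added);
  // let the Attr helper register them in `args`.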
  if (attr.with_binary())
    attr.construct_post_binary(matmul_pd, args);

  size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      m1.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

  args.insert({DNNL_ARG_SRC, m1_m});
  args.insert({DNNL_ARG_WEIGHTS, m2_m});
  args.insert({DNNL_ARG_DST, dst_m});
  if (with_bias) {
    auto bias_m = make_onednn_memory(bias_md, engine, b.data_ptr());
    args.insert({DNNL_ARG_BIAS, bias_m});
  }

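  // Enqueue the primitive on the SYCL stream; the returned event lets callers
  // order subsequent work after this matmul.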
  sycl::event matmul_event =
      dnnl::sycl_interop::execute(matmul_p, stream, args, deps);

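  // If the computation ran on a contiguous temporary (result's strides were
  // not oneDNN-compatible), copy the values back into the caller's tensor.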
  if (!dst.is_same(result))
    result.copy_(dst);

  return matmul_event;
}

} // namespace at::native::onednn