#include <c10/xpu/XPUFunctions.h>

#include <ATen/ATen.h>
#include <ATen/record_function.h>

#include <Attr.h>
#include <Utils.h>

#include <oneapi/dnnl/dnnl.hpp>

namespace at::native::onednn {

sycl::event matmul(
    at::Tensor& result,
    const at::Tensor& mat1,
    const at::Tensor& mat2,
    const at::Tensor& b_raw,
    bool m2_trans,
    Attr attr,
    const std::vector<sycl::event>& deps) {
  int64_t dims = result.dim();
  TORCH_CHECK(
      dims == 2 || dims == 3,
      "oneDNN matmul only works with 2D or 3D tensors, got ",
      dims);
  TORCH_CHECK(
      dims == mat1.dim() && dims == mat2.dim(),
      "oneDNN input matrices must have the same rank");
  TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined");

  at::Device cur_device = at::Device(at::kXPU, c10::xpu::current_device());
  auto engine = GpuEngineManager::Instance().get_engine(cur_device);
  auto stream = GpuStreamManager::Instance().get_stream();

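  // oneDNN matmul only accepts certain stride layouts; fall back to
  // contiguous copies when the given strides are not directly usable.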
  at::Tensor m1 = is_onednn_matmul_strides(mat1) ? mat1 : mat1.contiguous();
  at::Tensor m2 = is_onednn_matmul_strides(mat2) ? mat2 : mat2.contiguous();
  at::Tensor dst =
      is_onednn_matmul_strides(result, true) ? result : result.contiguous();

  int64_t m = dst.size(-2);
  int64_t n = dst.size(-1);
  int64_t k = m1.size(-1);
  int64_t mb = 1;

  if (dims == 3) {
    mb = dst.size(0);
    TORCH_CHECK(
        mb == m1.size(0) && mb == m2.size(0),
        "batch size mismatch, dst mb: ",
        mb,
        ", m1 mb: ",
        m1.size(0),
        ", m2 mb: ",
        m2.size(0));
  }

  // validate bias and make it compatible with oneDNN implementation
  bool with_bias = false;
  at::Tensor b = b_raw;
  if (b.defined()) {
    with_bias = true;
    if (b.dim() == 1) {
      TORCH_CHECK(
          b.size(0) == n || b.size(0) == 1,
          "matmul supports [n] or [1] when bias dim is 1 ...");
      if (b.size(0) == 0) {
        with_bias = false;
      } else if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else if (m1.dim() == 2) {
        b = b.expand({1, n}).contiguous();
      }
    } else if (b.dim() == 2) {
      TORCH_CHECK(
          (b.size(0) == m && b.size(1) == n) ||
              (b.size(0) == 1 && b.size(1) == n) ||
              (b.size(0) == m && b.size(1) == 1) ||
              (b.size(0) == 1 && b.size(1) == 1),
          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
      if (b.size(0) == 1 && b.size(1) == 1)
        b = b.expand({1, n}).contiguous();
    } else if (b.dim() == 3) {
      TORCH_CHECK(
          at::are_expandable({mb, m, n}, b.sizes()),
          "matmul bias must be expandable to: ",
          dst.sizes(),
          " but got: ",
          b.sizes());
      b = b.expand({mb, m, n}).contiguous();
    } else if (b.dim() == 0) {
      TORCH_CHECK(
          b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
      if (m1.dim() == 3) {
        b = b.expand({mb, m, n}).contiguous();
      } else {
        b = b.expand({1, n}).contiguous();
      }
    } else {
      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
    }
  }

  if (b.defined()) {
    b = b.contiguous(); // keep the bias contiguous once to avoid reordering it twice
  }

  // XPU matmul supports both ab and ba layouts for the m2 tensor,
  // so no further shape check is needed here.
  auto m1_usr_dt = get_onednn_dtype(m1);
  auto m2_usr_dt = get_onednn_dtype(m2);
  auto dst_usr_dt = get_onednn_dtype(dst);

  auto m1_dt = m1_usr_dt;
  auto m2_dt = m2_usr_dt;
  auto dst_dt = dst_usr_dt;
  dnnl::memory::data_type bias_dt;

  dnnl::memory::desc m1_md, m1_usr_md, m1_any_md;
  dnnl::memory::desc m2_md, m2_usr_md, m2_any_md;
  dnnl::memory::desc dst_md, dst_usr_md, dst_any_md;
  dnnl::memory::desc bias_md;

  // Naive master-weight handling: when one operand is bf16 and the other f32,
  // run the computation in bf16.
  if (m1_dt == dnnl::memory::data_type::bf16 &&
      m2_dt == dnnl::memory::data_type::f32) {
    m2_dt = dnnl::memory::data_type::bf16;
    dst_dt = dnnl::memory::data_type::bf16;
  } else if (
      m1_dt == dnnl::memory::data_type::f32 &&
      m2_dt == dnnl::memory::data_type::bf16) {
    m1_dt = dnnl::memory::data_type::bf16;
    dst_dt = dnnl::memory::data_type::bf16;
  }

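  // Describe m1/m2/dst with explicit dims and strides. m2 always gets the
  // logical shape {k, n}; when m2_trans is false its two innermost strides are
  // swapped so the physically transposed tensor is read as {k, n}.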
  dnnl::memory::dims m1_dims, m2_dims, dst_dims, bias_dims;
  dnnl::memory::dims m1_strides, m2_strides, dst_strides, bias_strides;
  if (dims == 2) {
    m1_dims = {m, k};
    m2_dims = {k, n};
    dst_dims = {m, n};

    m1_strides = {m1.stride(0), m1.stride(1)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1)};
    } else {
      m2_strides = {m2.stride(1), m2.stride(0)};
    }
    dst_strides = {dst.stride(0), dst.stride(1)};
  } else {
    m1_dims = {mb, m, k};
    m2_dims = {mb, k, n};
    dst_dims = {mb, m, n};

    m1_strides = {m1.stride(0), m1.stride(1), m1.stride(2)};
    if (m2_trans) {
      m2_strides = {m2.stride(0), m2.stride(1), m2.stride(2)};
    } else {
      m2_strides = {m2.stride(0), m2.stride(2), m2.stride(1)};
    }
    dst_strides = {dst.stride(0), dst.stride(1), dst.stride(2)};
  }

  if (with_bias) {
    bias_dims = get_onednn_dims(b);
    bias_dt = get_onednn_dtype(b);
    bias_strides = get_onednn_strides(b);
  }

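  // Translate the post-ops described by Attr into oneDNN post-ops for this
  // matmul.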
  dnnl::post_ops po = attr.extract_post_ops(dst);

  std::unordered_map<int, dnnl::memory> args;
  dnnl::matmul matmul_p;
  dnnl::matmul::primitive_desc matmul_pd;

  // STEP1: create memory desc
  m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides);
  m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides);
  dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides);

  // STEP2: create attribute
  dnnl::primitive_attr pattr;
  pattr.set_post_ops(po);

#if ONEDNN_SUPPORT_DETERMINISTIC
  if (at::globalContext().deterministicAlgorithms() ||
      at::globalContext().deterministicMkldnn())
    pattr.set_deterministic(true);
#endif

  // scratchpad is provided by the caller (user mode) rather than by oneDNN
  pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  if (m1_dt == dnnl::memory::data_type::f32) {
    // force strict f32 math so no implicit down-conversion (e.g. tf32) is used
    pattr.set_fpmath_mode(dnnl::fpmath_mode::strict);
  }

  // STEP3: create primitive
  if (with_bias) {
    bias_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
    matmul_pd = dnnl::matmul::primitive_desc(
        engine, m1_md, m2_md, bias_md, dst_md, pattr);
  } else {
    matmul_pd =
        dnnl::matmul::primitive_desc(engine, m1_md, m2_md, dst_md, pattr);
  }

  matmul_p = dnnl::matmul(matmul_pd);

  m1_usr_md = dnnl::memory::desc(m1_dims, m1_usr_dt, m1_strides);
  m2_usr_md = dnnl::memory::desc(m2_dims, m2_usr_dt, m2_strides);
  dst_usr_md = dnnl::memory::desc(dst_dims, dst_usr_dt, dst_strides);

  // STEP4: create memory
  auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr());
  auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr());
  auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr());

  auto expected_m1_md = matmul_pd.src_desc();
  auto expected_m2_md = matmul_pd.weights_desc();
  auto expected_dst_md = matmul_pd.dst_desc();

  dnnl::memory m1_m = m1_usr_m, m2_m = m2_usr_m, dst_m = dst_usr_m;
  at::Tensor m1_, m2_, dst_;

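  // bind any additional memory arguments required by binary post-ops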
  if (attr.with_binary())
    attr.construct_post_binary(matmul_pd, args);

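  // allocate the user-mode scratchpad requested by the primitive as an ATen
  // byte tensor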
  size_t scratchpad_size = matmul_pd.scratchpad_desc().get_size();
  at::Tensor scratchpad_tensor = at::empty(
      {static_cast<int64_t>(scratchpad_size)},
      m1.options().dtype(at::kByte),
      std::nullopt);
  auto scratchpad_memory = make_onednn_memory(
      matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr());
  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory});

  args.insert({DNNL_ARG_SRC, m1_m});
  args.insert({DNNL_ARG_WEIGHTS, m2_m});
  args.insert({DNNL_ARG_DST, dst_m});
  if (with_bias) {
    auto bias_m = make_onednn_memory(bias_md, engine, b.data_ptr());
    args.insert({DNNL_ARG_BIAS, bias_m});
  }

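  // submit the matmul to the SYCL stream; the returned event lets callers
  // order later work against this computation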
  sycl::event matmul_event =
      dnnl::sycl_interop::execute(matmul_p, stream, args, deps);

  if (!dst.is_same(result))
    result.copy_(dst);

  return matmul_event;
}

} // namespace at::native::onednn