/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <limits>

#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;
using executorch::runtime::testing::TensorFactory;

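// Thin test helper that forwards to the custom flash-attention kernel,
// mirroring the argument order of ATen's scaled_dot_product_attention.
// Mathematically the op computes softmax(Q @ K^T * scale + attn_mask) @ V,
// with scale defaulting to 1/sqrt(head_dim) when not provided.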
exec_aten::Tensor op_scaled_dot_product_attention(
    const exec_aten::Tensor& query,
    const exec_aten::Tensor& key,
    const exec_aten::Tensor& value,
    const exec_aten::optional<exec_aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    exec_aten::optional<double> scale,
    exec_aten::Tensor& out) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::flash_attention_kernel_out(
      context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
}

/*
Most tests are generated by FACTO
*/

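// Dense correctness check on smooth [0, 1) inputs: no mask, no dropout,
// default scale. Expected values presumably come from the FACTO-generated
// PyTorch reference, hence the explicit 1e-4 tolerance.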
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_105) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 4, 4},
      {0.4320,
       0.1461,
       0.6817,
       0.8756,
       0.8619,
       0.9165,
       0.1050,
       0.0488,
       0.9832,
       0.8024,
       0.3185,
       0.7671,
       0.5988,
       0.2772,
       0.3965,
       0.1101});
  exec_aten::Tensor key = tfFloat.make(
      {1, 1, 4, 4},
      {0.4951,
       0.1630,
       0.7805,
       0.7971,
       0.7538,
       0.5109,
       0.0012,
       0.0018,
       0.3541,
       0.6563,
       0.5831,
       0.0022,
       0.7363,
       0.2270,
       0.1862,
       0.2762});
  exec_aten::Tensor value = tfFloat.make(
      {1, 1, 4, 4},
      {0.2914,
       0.4977,
       0.0895,
       0.3630,
       0.6552,
       0.1495,
       0.1673,
       0.5845,
       0.8988,
       0.6690,
       0.5082,
       0.9999,
       0.0609,
       0.7338,
       0.2203,
       0.6971});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 4, 4},
      {0.4473,
       0.5221,
       0.2302,
       0.6293,
       0.4910,
       0.5032,
       0.2501,
       0.6689,
       0.4630,
       0.5109,
       0.2368,
       0.6449,
       0.4741,
       0.5132,
       0.2444,
       0.6570});
  std::vector<int32_t> out_size = {1, 1, 4, 4};
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected, 1e-4, 1e-4);
}

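// All-infinity additive mask: every attention logit becomes +inf, the softmax
// evaluates to inf/inf = NaN, so the entire output is expected to be NaN.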
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_11) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 1, 8},
      {75.25, -32.875, -96.375, 75.0, -5.25, -30.0, 71.5, -70.875});
  exec_aten::Tensor key = tfFloat.make(
      {1, 1, 1, 8},
      {50.125, 18.0, 72.625, -95.0, 47.25, -74.875, -46.375, -47.0});
  exec_aten::Tensor value = tfFloat.make(
      {1, 1, 1, 8},
      {99.375, 80.125, -81.0, 8.5, -70.375, -54.25, -80.25, 34.125});
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(
          tfFloat.full({1, 1}, std::numeric_limits<float>::infinity()));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  // PyTorch reports these as NAN while this kernel produces -NAN; the sign
  // bit of a NaN carries no meaning, so the two are equivalent.
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 1, 8}, {-NAN, -NAN, -NAN, -NAN, -NAN, -NAN, -NAN, -NAN});
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}

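// Causal masking with a single query position: each head may attend only to
// its first key, so the output is each head's first value row verbatim.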
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_13) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 8, 1, 1}, {-47.0, 21.25, 74.75, 46.375, 21.0, -29.0, 2.625, 83.125});
  exec_aten::Tensor key = tfFloat.make(
      {1, 8, 3, 1},
      {-43.0, 12.5, -68.125, -3.25, -10.0, 65.0, 49.75, -83.125,
       97.125, -40.375, -5.5, 93.125, 70.875, -67.375, -44.875, 98.25,
       -76.25, -74.5, -23.25, -66.75, 42.625, -88.0, -37.75, -61.625});
  exec_aten::Tensor value = tfFloat.make(
      {1, 8, 3, 1},
      {65.0, 81.125, 8.125, 68.375, -54.25, -1.125, -73.25, -54.0,
       -28.75, -23.875, 49.0, 63.5, 96.375, 16.625, 79.5, 33.125,
       32.875, -73.75, 69.125, 7.25, -35.0, 94.0, 6.75, 65.75});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0.0;
  bool is_causal = true;
  exec_aten::optional<double> scale;
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 8, 1, 1},
      {65.0, 68.375, -73.25, -23.875, 96.375, 33.125, 69.125, 94.0});
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}

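// Large-magnitude inputs: the softmax saturates to a (numerically) one-hot
// distribution, so each output row collapses onto a single value row.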
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_17) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {3, 2, 2, 6},
      {69.625, -98.125, -22.0, -17.25, -75.625, -43.875, -74.75, 14.5,
       82.0, -82.625, 25.125, -98.0, -91.5, 65.875, 23.0, 50.25,
       30.125, 58.25, -1.375, 23.0, 72.625, 47.875, -76.125, -62.25,
       82.0, -89.25, 75.25, 99.0, -4.375, -46.75, 94.875, -16.375,
       -90.875, 81.875, 63.75, -67.25, -13.625, 17.625, -12.875, 86.0,
       10.875, -57.625, 62.75, -69.5, -96.625, 80.0, 94.875, 17.5,
       -17.125, -69.5, 26.375, 25.5, -51.625, 32.5, 15.0, 65.5,
       -49.0, -71.25, -18.625, -82.0, 94.25, -56.25, 2.0, 21.25,
       37.125, -9.0, 65.0, -86.75, -77.0, -26.75, -99.875, -8.5});
  exec_aten::Tensor key = tfFloat.make(
      {3, 2, 4, 6},
      {98.125, -86.25, 25.25, -33.125, -98.0, -42.5, 44.75, 42.375,
       -68.625, -97.375, 70.625, 0.75, 51.375, 89.75, -62.5, 0.5,
       6.75, 92.875, 10.375, -20.5, 20.75, 13.625, -11.0, 99.0,
       52.75, 31.625, -97.375, -51.0, -31.25, -78.5, 92.125, -99.75,
       -10.5, -39.125, 46.375, 98.5, -81.5, -61.375, 29.5, -39.75,
       -54.875, 12.0, 80.25, 40.875, 58.25, 96.0, -97.625, 31.625,
       63.625, -3.875, 86.5, -27.25, 8.875, 57.625, 88.375, 57.125,
       -17.5, 83.875, 84.75, -27.375, 90.625, -24.5, 76.5, 28.625,
       -71.625, 6.75, -91.5, -19.125, 24.5, -76.0, -6.5, -77.625,
       46.625, 21.125, -53.25, -80.375, 59.0, -21.125, -39.125, 90.75,
       -68.5, -18.75, 44.625, -44.75, -24.0, 37.0, -58.125, 13.25,
       -71.125, 16.875, -4.625, 10.25, 12.375, 92.875, 76.0, 12.875,
       32.125, 94.5, -58.25, 83.25, -28.375, -27.875, 32.5, -51.875,
       -94.75, -65.5, -48.875, 18.375, -54.125, 52.625, -51.0, -66.125,
       64.5, -31.0, 82.25, 42.0, 37.5, -72.5, 66.625, -96.5,
       59.375, -69.625, -47.25, -11.5, -8.5, -90.875, -64.75, -61.75,
       97.0, 1.75, -17.375, 99.875, -85.375, 6.25, 41.625, 5.75,
       78.375, -50.75, 9.75, 36.875, 84.5, 19.625, -83.75, 17.0});
  exec_aten::Tensor value = tfFloat.make(
      {3, 2, 4, 6},
      {-26.375, -65.0, 55.5, 37.0, 90.0, 54.25, 83.75, -33.75,
       2.375, 99.5, 71.5, 70.5, -3.625, -30.875, 46.125, -60.5,
       -7.375, -82.25, 42.5, -3.125, -9.25, 54.0, -36.875, -67.875,
       -5.75, -51.625, -8.875, -36.25, 86.625, 84.5, -28.75, 23.375,
       -39.625, 79.375, 95.0, -51.125, -28.625, -82.375, 14.5, -85.875,
       -92.125, 97.875, -78.125, -34.0, 16.375, -1.625, 70.375, -58.625,
       96.75, -95.125, -36.375, -72.875, 16.375, -38.75, -58.875, -97.0,
       -94.25, -76.125, -30.0, -60.0, 77.375, 34.75, -16.5, 5.5,
       -16.25, -40.75, -7.625, 18.875, -59.125, -56.0, -7.25, -14.375,
       -44.375, 87.625, 38.75, 79.5, 61.5, 29.375, 7.25, -4.5,
       -46.25, -88.875, -0.625, -6.0, -23.375, -18.25, 86.0, 33.375,
       60.25, -23.125, 37.75, 5.5, 83.875, -14.625, -89.75, -84.625,
       -33.5, 90.5, -53.125, 11.625, 90.875, 49.0, -89.625, -6.75,
       -31.25, -29.0, -5.5, 72.5, 44.25, 66.0, -76.75, -7.375,
       52.375, 76.375, -30.125, -72.875, 37.125, -83.625, 60.875, -98.125,
       -23.625, 85.875, -25.875, 57.625, 50.75, 76.625, -72.5, 26.0,
       65.875, 13.125, -19.625, 7.5, -25.5, 40.25, 75.25, -48.0,
       8.25, 5.125, 42.375, 23.75, 65.25, -77.0, 35.625, -12.0});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {3, 2, 2, 6},
      {-26.375, -65.0, 55.5, 37.0, 90.0, 54.25, 83.75, -33.75,
       2.375, 99.5, 71.5, 70.5, -28.625, -82.375, 14.5, -85.875,
       -92.125, 97.875, -78.125, -34.0, 16.375, -1.625, 70.375, -58.625,
       77.375, 34.75, -16.5, 5.5, -16.25, -40.75, -58.875, -97.0,
       -94.25, -76.125, -30.0, -60.0, 37.75, 5.5, 83.875, -14.625,
       -89.75, -84.625, 37.75, 5.5, 83.875, -14.625, -89.75, -84.625,
       -89.625, -6.75, -31.25, -29.0, -5.5, 72.5, -30.125, -72.875,
       37.125, -83.625, 60.875, -98.125, -23.625, 85.875, -25.875, 57.625,
       50.75, 76.625, -23.625, 85.875, -25.875, 57.625, 50.75, 76.625});
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}

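// Explicit scale of -infinity: every nonzero logit is driven to +/-inf, so
// the softmax, and therefore the whole output, is NaN.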
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_18) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {3, 2, 2, 6},
      {44.0, -13.875, -10.125, 36.625, 72.875, -45.0, 87.5, -5.375,
       25.25, -28.625, 8.75, -95.125, -75.5, -59.25, 2.25, -5.75,
       50.25, 83.375, -19.0, 43.875, -98.5, 43.375, -27.875, 7.875,
       -15.875, 77.625, 92.5, -16.375, -2.375, 20.25, -75.875, -33.875,
       13.75, 9.875, 0.625, 78.5, 6.625, -71.625, -38.25, -33.5,
       -0.375, -47.25, 55.875, -49.0, 66.25, 88.625, -28.75, -49.75,
       -6.5, 23.5, -84.875, -13.25, 4.875, -2.125, -56.25, 85.75,
       44.5, -78.75, -39.875, 31.0, -73.125, 68.875, -42.625, 29.75,
       35.125, 83.0, 29.625, 89.75, 64.875, 91.875, 40.375, -92.75});
  exec_aten::Tensor key = tfFloat.make(
      {3, 2, 4, 6},
      {-11.375, -70.5, 10.125, -76.125, -26.5, -11.375, -1.125, 7.5,
       94.375, -50.125, 43.125, 61.75, 39.375, -79.25, 41.375, 88.75,
       -72.625, -17.125, 48.0, 80.75, -66.125, -8.625, -41.0, 6.75,
       -37.75, 91.375, 4.0, 27.625, 51.625, 80.5, -64.5, 21.875,
       89.0, -71.625, 32.75, 29.25, -70.625, 6.875, -1.75, 55.875,
       -19.125, -99.125, -73.0, -62.75, -17.25, 37.625, -86.75, 58.75,
       -40.75, 45.125, -38.5, -60.125, 90.625, 99.875, 71.25, -88.625,
       74.625, 42.0, -75.875, 57.375, -29.0, -25.75, 72.5, 76.875,
       -27.0, -2.625, -26.375, -94.0, -71.625, -18.125, -25.875, -62.0,
       7.625, 73.125, -87.625, 98.875, -61.25, -96.75, -25.625, -57.875,
       53.75, 68.25, 84.125, 36.125, 38.125, -82.375, 92.5, -82.75,
       -91.25, -60.25, -46.375, 79.625, 20.25, 13.125, -54.125, -32.625,
       -35.25, -51.75, 13.625, -62.375, 91.0, -45.5, 85.125, 17.625,
       99.5, 8.875, -92.75, 81.375, 18.625, 37.625, 0.75, 23.125,
       -81.5, 76.75, 10.875, 40.125, -22.875, -24.875, 52.5, 0.875,
       59.25, 48.125, 40.875, -43.25, -65.625, -27.25, 58.0, 91.125,
       78.625, 45.875, 76.0, 79.375, 17.0, 9.75, -26.75, 15.0,
       -1.0, -84.75, -38.5, -50.625, -68.375, 0.375, -47.0, -91.75});
  exec_aten::Tensor value = tfFloat.make(
      {3, 2, 4, 6},
      {-39.25, 69.875, -28.125, 18.375, 89.375, -39.5, -55.25, -42.0,
       -7.875, 26.625, -6.125, -98.25, 48.625, 33.625, 48.0, 96.75,
       -59.125, -85.25, -22.25, -91.0, -1.75, 14.25, -7.75, -94.375,
       -97.625, 71.0, 90.875, -11.5, -14.625, 52.875, 90.875, 32.875,
       -84.25, -57.75, -78.875, -81.75, 86.0, 54.125, -75.625, -28.375,
       24.375, 45.125, 80.375, -42.25, 3.5, -68.5, 2.875, 58.75,
       9.625, -52.75, -31.25, 74.25, -98.0, 38.0, 59.25, 45.5,
       67.75, 52.5, -59.75, 20.0, 83.875, -46.75, 5.25, 74.375,
       14.125, -67.0, -60.625, 28.5, 20.5, -96.625, -89.125, 33.875,
       -89.25, 9.875, -99.25, -20.5, 78.625, 37.875, -72.375, -49.625,
       22.0, -54.25, 18.125, 57.75, 72.375, -11.5, -52.5, -28.125,
       -86.875, -45.0, 60.25, 34.625, -88.875, 91.0, -48.25, 98.75,
       100.0, 33.0, -69.625, -88.25, -46.625, -24.75, -77.5, 93.5,
       -45.125, 42.75, -50.0, -86.0, -17.375, 85.25, -28.125, -28.375,
       46.375, 26.625, 23.0, -55.875, 39.125, 87.25, -9.625, 95.375,
       -27.875, 59.5, 15.5, -90.0, 39.5, -15.75, -16.375, -96.875,
       -96.125, -47.0, 0.75, -45.875, 74.625, 46.0, 20.5, -42.875,
       -55.0, 30.375, -27.375, 99.375, 18.375, 0.375, 54.25, -57.75});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale = exec_aten::optional<double>(-INFINITY);
  exec_aten::Tensor ret_expected = tfFloat.make(
      {3, 2, 2, 6},
      {NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
       NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
       NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
       NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
       NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
       NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}

/*
// Disabling this test because right now we are enforcing that the
// attention mask must be 2D
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_19) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {3, 2, 2, 6},
      {-50.875, 17.375, -42.875, 8.125, -59.625, -59.125, 0.0, -76.375,
       39.625, -27.75, -43.375, 71.0, -96.5, -48.75, 23.125, 11.125,
       30.125, 36.75, -22.25, 35.625, 37.875, -43.375, -22.875, 74.75,
       79.375, -75.25, 66.75, 48.875, 88.875, -73.5, 79.375, 55.5,
       -84.0, 93.0, -19.625, -49.875, 88.625, -5.0, -94.625, -13.375,
       88.375, -30.625, 39.75, -15.625, -80.5, -40.25, -90.375, -0.5,
       -47.625, 86.875, -27.125, 26.75, 41.0, 48.0, 4.375, 10.125,
       -26.375, 4.25, 56.5, -45.625, -78.75, 99.625, -5.5, -85.0,
       18.125, -71.5, 6.0, -44.125, 59.125, 49.25, 21.125, -6.5});
  exec_aten::Tensor key = tfFloat.make(
      {3, 2, 4, 6},
      {-36.25, -6.125, 49.0, -14.375, 22.25, 17.75, 69.125, 22.625,
       -0.125, -85.875, -71.125, -1.375, -43.75, -55.25, 71.125, 58.375,
       19.875, -98.0, -16.875, -29.375, 83.875, 19.125, -18.5, -34.75,
       -59.75, -92.625, -19.375, 55.625, -1.75, 25.0, 82.25, 8.0,
       6.75, 28.5, 8.125, -24.375, 52.875, -39.75, 66.625, -31.375,
       -42.25, -30.25, -20.875, 24.75, -34.5, -69.75, -9.0, 65.625,
       42.125, 89.5, -1.875, -88.375, 82.375, 80.25, 7.875, 71.0,
       84.125, -9.625, -62.0, 7.625, 83.0, 55.0, -65.125, -55.125,
       -10.0, 17.75, 67.0, 83.25, 51.125, -13.75, 40.875, -77.625,
       19.125, -48.125, -86.125, -20.5, -93.125, 64.5, -5.5, 72.375,
       86.625, -21.0, 77.0, -85.625, 14.5, 69.75, 99.875, -14.125,
       36.875, -50.375, -65.5, 94.5, 64.0, 61.0, -73.0, -24.375,
       -11.5, -16.75, 92.0, 62.5, 62.375, -81.625, -25.125, -53.25,
       -61.375, 58.5, -67.625, 26.5, 64.0, 27.25, 84.5, 4.125,
       -82.375, 2.0, 21.5, 0.75, 80.0, -87.375, 38.75, -25.25,
       68.75, -18.875, 74.75, -45.625, -15.875, 13.5, 51.25, 37.25,
       -12.0, -15.5, -45.75, 7.375, 1.25, -54.375, 80.25, 18.875,
       89.0, -30.625, -39.5, -39.0, 46.625, -46.0, -87.125, -18.0});
  exec_aten::Tensor value = tfFloat.make(
      {3, 2, 4, 6},
      {-74.5, -0.25, -77.125, -74.375, -53.0, 33.625, -45.0, 66.0,
       -66.875, -71.875, -9.75, -41.125, 37.0, -65.25, -50.25, 84.75,
       -67.875, 54.0, 16.875, -96.5, 91.75, 14.625, 80.875, -25.875,
       -62.75, -92.5, -77.75, 40.75, -53.125, -71.875, 10.0, -4.75,
       -54.875, -24.25, 48.625, 9.375, -9.625, 32.875, -62.75, 99.5,
       25.125, 85.625, -29.0, -33.75, 44.0, -83.75, 44.125, -88.625,
       -17.75, 22.625, -79.5, 1.0, -10.625, 10.0, 70.25, -91.625,
       -86.0, 83.875, 68.25, -35.125, -6.25, -81.25, -38.375, 56.0,
       26.875, -51.75, -79.625, 83.375, -31.625, 83.375, -4.75, 81.875,
       53.0, -31.625, -48.625, 76.75, 71.625, -63.0, 17.25, -22.0,
       -7.75, -77.25, -92.25, -2.0, -88.0, 88.5, -54.125, -7.875,
       98.0, -56.75, 96.125, -90.0, -70.0, -50.125, -53.5, -65.125,
       48.375, 98.125, -89.0, -97.125, 20.625, 85.5, -77.625, 76.0,
       73.625, 58.625, -90.375, 11.75, -16.5, 78.125, 95.375, 86.375,
       -69.125, -92.375, -65.25, 27.875, 77.125, -59.875, 79.5, -78.625,
       15.25, 53.75, 44.625, -22.0, -84.0, -7.25, 22.0, 25.875,
       17.625, -86.875, 22.75, -74.0, -79.875, -68.0, -71.125, -81.625,
       -4.125, 65.875, 1.875, 76.125, -43.75, -15.25, -4.625, -66.125});
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
          {3, 1, 2, 2, 4},
          {39.0, 49.375, -87.125, -99.125, 49.375, -41.125, 26.25, 79.75,
           91.0, -3.125, 65.75, 63.5, -48.375, 43.375, 22.5, -53.625,
           -70.0, 2.125, 21.875, 6.375, -6.375, 75.25, -35.875, 86.375,
           71.5, -35.875, 19.75, 11.625, -87.25, 49.0, -6.0, 62.875,
           7.125, 87.375, -14.75, 55.5, 59.125, 24.75, -66.5, 72.375,
           2.25, 81.375, -87.125, 35.125, -39.125, 43.5, 52.875, 39.5}));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {3, 1, 2, 2, 6},
      {37.0,
       -65.25,
       -50.25,
       84.75,
       -67.875,
       54.0,
       16.874713897705078,
       -96.4992446899414,
       91.749267578125,
       14.624600410461426,
       80.87458038330078,
       -25.87506866455078,
       -62.75,
       -92.5,
       -77.75,
       40.75,
       -53.125,
       -71.875,
       -29.0,
       -33.75,
       44.0,
       -83.75,
       44.125,
       -88.625,
       -79.625,
       83.375,
       -31.625,
       83.375,
       -4.75,
       81.875,
       -6.25,
       -81.25,
       -38.375,
       56.0,
       26.875,
       -51.75,
       17.25,
       -22.0,
       -7.75,
       -77.25,
       -92.25,
       -2.0,
       53.0,
       -31.625,
       -48.625,
       76.75,
       71.625,
       -63.0,
       -77.625,
       76.0,
       73.625,
       58.625,
       -90.375,
       11.75,
       48.375,
       98.125,
       -89.0,
       -97.125,
       20.625,
       85.5,
       1.875,
       76.125,
       -43.75,
       -15.25,
       -4.625,
       -66.125,
       -79.875,
       -68.0,
       -71.125,
       -81.625,
       -4.125,
       65.875});
  Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}
*/

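// 2-D additive mask, broadcast across the batch and head dimensions. The
// large mask values again saturate the softmax, so each query position
// selects a single value row.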
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_51) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 8, 3},
      {-14.0, 46.125, -78.125, -61.375, 52.375, -9.125, 57.875, 88.25,
       -95.75, 8.875, -64.625, 41.75, -62.25, 41.25, -67.25, 51.25,
       48.0, 67.625, 30.0, -59.0, 42.25, -33.0, -10.25, -77.5});
  exec_aten::Tensor key = tfFloat.make(
      {1, 1, 3, 3},
      {6.0, 58.5, -37.875, -11.125, -18.5, 35.0, 59.25, 73.0, 34.125});
  exec_aten::Tensor value = tfFloat.make(
      {1, 1, 3, 3},
      {70.375, 30.875, 72.125, 53.0, 39.125, -4.625, 26.5, 79.5, 88.625});
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
          {8, 3},
          {-59.25, -26.25, -3.0, -24.125, 47.75, 92.375, 87.5, 21.5,
           64.5, 45.0, -54.0, 17.375, -67.75, 14.625, 88.75, 36.0,
           88.375, 25.75, 42.5, -13.375, -82.75, -59.625, -21.125, 6.5}));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 8, 3},
      {70.375, 30.875, 72.125, 70.375, 30.875, 72.125, 70.375, 30.875,
       72.125, 53.0, 39.125, -4.625, 70.375, 30.875, 72.125, 26.5,
       79.5, 88.625, 53.0, 39.125, -4.625, 70.375, 30.875, 72.125});
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}

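// A minimal hand-written sanity check (an addition, not FACTO-generated):
// with a single key/value position, the softmax over one logit is exactly
// 1.0, so the output must equal the value row repeated for every query
// position, independent of the query contents. The test name and all
// numeric values below are illustrative choices.
TEST(OpScaledDotProductAttentionTest, SingleKeyCopiesValue) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  // Two query positions, one key/value position, head_dim = 4.
  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 2, 4}, {1.0, -2.0, 3.0, -4.0, 0.5, 0.25, -0.125, 8.0});
  exec_aten::Tensor key = tfFloat.make({1, 1, 1, 4}, {2.0, -1.0, 0.5, 4.0});
  exec_aten::Tensor value = tfFloat.make({1, 1, 1, 4}, {7.0, -3.0, 0.25, 11.0});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  // Each output row is the (single) value row, so the result is exact.
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 2, 4}, {7.0, -3.0, 0.25, 11.0, 7.0, -3.0, 0.25, 11.0});
  std::vector<int32_t> out_size = {1, 1, 2, 4};
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}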