xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/op_sdpa_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#include <cmath>
#include <limits>

#include <executorch/extension/llm/custom_ops/op_sdpa.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>
18 
19 using namespace ::testing;
20 using executorch::runtime::testing::TensorFactory;
21 
// Test helper mirroring the aten-style scaled_dot_product_attention
// signature: constructs a fresh KernelRuntimeContext and forwards all
// arguments to the custom flash-attention kernel under test.
// Returns `out`, which the kernel fills with the attention result.
exec_aten::Tensor op_scaled_dot_product_attention(
    const exec_aten::Tensor& query,
    const exec_aten::Tensor& key,
    const exec_aten::Tensor& value,
    const exec_aten::optional<exec_aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    exec_aten::optional<double> scale,
    exec_aten::Tensor& out) {
  executorch::runtime::KernelRuntimeContext context{};
  return torch::executor::native::flash_attention_kernel_out(
      context, query, key, value, attn_mask, dropout_p, is_causal, scale, out);
}
35 
/*
Most tests are generated by FACTO
*/
39 
TEST(OpScaledDotProductAttentionTest,CorrectnessTest_105)40 TEST(OpScaledDotProductAttentionTest, CorrectnessTest_105) {
41   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
42 
43   exec_aten::Tensor query = tfFloat.make(
44       {1, 1, 4, 4},
45       {0.4320,
46        0.1461,
47        0.6817,
48        0.8756,
49        0.8619,
50        0.9165,
51        0.1050,
52        0.0488,
53        0.9832,
54        0.8024,
55        0.3185,
56        0.7671,
57        0.5988,
58        0.2772,
59        0.3965,
60        0.1101});
61   exec_aten::Tensor key = tfFloat.make(
62       {1, 1, 4, 4},
63       {0.4951,
64        0.1630,
65        0.7805,
66        0.7971,
67        0.7538,
68        0.5109,
69        0.0012,
70        0.0018,
71        0.3541,
72        0.6563,
73        0.5831,
74        0.0022,
75        0.7363,
76        0.2270,
77        0.1862,
78        0.2762});
79   exec_aten::Tensor value = tfFloat.make(
80       {1, 1, 4, 4},
81       {0.2914,
82        0.4977,
83        0.0895,
84        0.3630,
85        0.6552,
86        0.1495,
87        0.1673,
88        0.5845,
89        0.8988,
90        0.6690,
91        0.5082,
92        0.9999,
93        0.0609,
94        0.7338,
95        0.2203,
96        0.6971});
97   exec_aten::optional<exec_aten::Tensor> attn_mask;
98   double dropout_p = 0;
99   bool is_causal = false;
100   exec_aten::optional<double> scale;
101   exec_aten::Tensor ret_expected = tfFloat.make(
102       {1, 1, 4, 4},
103       {0.4473,
104        0.5221,
105        0.2302,
106        0.6293,
107        0.4910,
108        0.5032,
109        0.2501,
110        0.6689,
111        0.4630,
112        0.5109,
113        0.2368,
114        0.6449,
115        0.4741,
116        0.5132,
117        0.2444,
118        0.6570});
119   std::vector<int32_t> out_size = {1, 1, 4, 4};
120   exec_aten::Tensor out = tfFloat.zeros(out_size);
121   exec_aten::Tensor ret = op_scaled_dot_product_attention(
122       query, key, value, attn_mask, dropout_p, is_causal, scale, out);
123   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected, 1e-4, 1e-4);
124 }
125 
// An all-infinity additive mask drives every softmax logit to +inf,
// so the normalized weights (inf/inf) and thus the output are NaN.
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_11) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 1, 8},
      {75.25, -32.875, -96.375, 75.0, -5.25, -30.0, 71.5, -70.875});
  exec_aten::Tensor key = tfFloat.make(
      {1, 1, 1, 8},
      {50.125, 18.0, 72.625, -95.0, 47.25, -74.875, -46.375, -47.0});
  exec_aten::Tensor value = tfFloat.make(
      {1, 1, 1, 8},
      {99.375, 80.125, -81.0, 8.5, -70.375, -54.25, -80.25, 34.125});
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(
          tfFloat.full({1, 1}, std::numeric_limits<float>::infinity()));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  // Pytorch says these should be NAN
  // but output is -NAN. Both are really the same though
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 1, 8}, {-NAN, -NAN, -NAN, -NAN, -NAN, -NAN, -NAN, -NAN});
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}
154 
// Causal masking with a single query position: only the first key/value
// column is attendable, so each head's output equals value[..., 0, :].
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_13) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 8, 1, 1}, {-47.0, 21.25, 74.75, 46.375, 21.0, -29.0, 2.625, 83.125});
  exec_aten::Tensor key = tfFloat.make(
      {1, 8, 3, 1},
      {-43.0,  12.5,    -68.125, -3.25,  -10.0,  65.0,    49.75,   -83.125,
       97.125, -40.375, -5.5,    93.125, 70.875, -67.375, -44.875, 98.25,
       -76.25, -74.5,   -23.25,  -66.75, 42.625, -88.0,   -37.75,  -61.625});
  exec_aten::Tensor value = tfFloat.make(
      {1, 8, 3, 1},
      {65.0,   81.125,  8.125,  68.375, -54.25, -1.125, -73.25, -54.0,
       -28.75, -23.875, 49.0,   63.5,   96.375, 16.625, 79.5,   33.125,
       32.875, -73.75,  69.125, 7.25,   -35.0,  94.0,   6.75,   65.75});
  exec_aten::optional<exec_aten::Tensor> attn_mask;
  double dropout_p = 0.0;
  bool is_causal = true;
  exec_aten::optional<double> scale;
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 8, 1, 1},
      {65.0, 68.375, -73.25, -23.875, 96.375, 33.125, 69.125, 94.0});
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}
183 
TEST(OpScaledDotProductAttentionTest,CorrectnessTest_17)184 TEST(OpScaledDotProductAttentionTest, CorrectnessTest_17) {
185   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
186 
187   exec_aten::Tensor query = tfFloat.make(
188       {3, 2, 2, 6},
189       {69.625,  -98.125, -22.0,   -17.25, -75.625, -43.875, -74.75,  14.5,
190        82.0,    -82.625, 25.125,  -98.0,  -91.5,   65.875,  23.0,    50.25,
191        30.125,  58.25,   -1.375,  23.0,   72.625,  47.875,  -76.125, -62.25,
192        82.0,    -89.25,  75.25,   99.0,   -4.375,  -46.75,  94.875,  -16.375,
193        -90.875, 81.875,  63.75,   -67.25, -13.625, 17.625,  -12.875, 86.0,
194        10.875,  -57.625, 62.75,   -69.5,  -96.625, 80.0,    94.875,  17.5,
195        -17.125, -69.5,   26.375,  25.5,   -51.625, 32.5,    15.0,    65.5,
196        -49.0,   -71.25,  -18.625, -82.0,  94.25,   -56.25,  2.0,     21.25,
197        37.125,  -9.0,    65.0,    -86.75, -77.0,   -26.75,  -99.875, -8.5});
198   exec_aten::Tensor key = tfFloat.make(
199       {3, 2, 4, 6},
200       {98.125,  -86.25,  25.25,   -33.125, -98.0,   -42.5,   44.75,   42.375,
201        -68.625, -97.375, 70.625,  0.75,    51.375,  89.75,   -62.5,   0.5,
202        6.75,    92.875,  10.375,  -20.5,   20.75,   13.625,  -11.0,   99.0,
203        52.75,   31.625,  -97.375, -51.0,   -31.25,  -78.5,   92.125,  -99.75,
204        -10.5,   -39.125, 46.375,  98.5,    -81.5,   -61.375, 29.5,    -39.75,
205        -54.875, 12.0,    80.25,   40.875,  58.25,   96.0,    -97.625, 31.625,
206        63.625,  -3.875,  86.5,    -27.25,  8.875,   57.625,  88.375,  57.125,
207        -17.5,   83.875,  84.75,   -27.375, 90.625,  -24.5,   76.5,    28.625,
208        -71.625, 6.75,    -91.5,   -19.125, 24.5,    -76.0,   -6.5,    -77.625,
209        46.625,  21.125,  -53.25,  -80.375, 59.0,    -21.125, -39.125, 90.75,
210        -68.5,   -18.75,  44.625,  -44.75,  -24.0,   37.0,    -58.125, 13.25,
211        -71.125, 16.875,  -4.625,  10.25,   12.375,  92.875,  76.0,    12.875,
212        32.125,  94.5,    -58.25,  83.25,   -28.375, -27.875, 32.5,    -51.875,
213        -94.75,  -65.5,   -48.875, 18.375,  -54.125, 52.625,  -51.0,   -66.125,
214        64.5,    -31.0,   82.25,   42.0,    37.5,    -72.5,   66.625,  -96.5,
215        59.375,  -69.625, -47.25,  -11.5,   -8.5,    -90.875, -64.75,  -61.75,
216        97.0,    1.75,    -17.375, 99.875,  -85.375, 6.25,    41.625,  5.75,
217        78.375,  -50.75,  9.75,    36.875,  84.5,    19.625,  -83.75,  17.0});
218   exec_aten::Tensor value = tfFloat.make(
219       {3, 2, 4, 6},
220       {-26.375, -65.0,   55.5,    37.0,    90.0,    54.25,   83.75,   -33.75,
221        2.375,   99.5,    71.5,    70.5,    -3.625,  -30.875, 46.125,  -60.5,
222        -7.375,  -82.25,  42.5,    -3.125,  -9.25,   54.0,    -36.875, -67.875,
223        -5.75,   -51.625, -8.875,  -36.25,  86.625,  84.5,    -28.75,  23.375,
224        -39.625, 79.375,  95.0,    -51.125, -28.625, -82.375, 14.5,    -85.875,
225        -92.125, 97.875,  -78.125, -34.0,   16.375,  -1.625,  70.375,  -58.625,
226        96.75,   -95.125, -36.375, -72.875, 16.375,  -38.75,  -58.875, -97.0,
227        -94.25,  -76.125, -30.0,   -60.0,   77.375,  34.75,   -16.5,   5.5,
228        -16.25,  -40.75,  -7.625,  18.875,  -59.125, -56.0,   -7.25,   -14.375,
229        -44.375, 87.625,  38.75,   79.5,    61.5,    29.375,  7.25,    -4.5,
230        -46.25,  -88.875, -0.625,  -6.0,    -23.375, -18.25,  86.0,    33.375,
231        60.25,   -23.125, 37.75,   5.5,     83.875,  -14.625, -89.75,  -84.625,
232        -33.5,   90.5,    -53.125, 11.625,  90.875,  49.0,    -89.625, -6.75,
233        -31.25,  -29.0,   -5.5,    72.5,    44.25,   66.0,    -76.75,  -7.375,
234        52.375,  76.375,  -30.125, -72.875, 37.125,  -83.625, 60.875,  -98.125,
235        -23.625, 85.875,  -25.875, 57.625,  50.75,   76.625,  -72.5,   26.0,
236        65.875,  13.125,  -19.625, 7.5,     -25.5,   40.25,   75.25,   -48.0,
237        8.25,    5.125,   42.375,  23.75,   65.25,   -77.0,   35.625,  -12.0});
238   exec_aten::optional<exec_aten::Tensor> attn_mask;
239   double dropout_p = 0.0;
240   bool is_causal = false;
241   exec_aten::optional<double> scale;
242   exec_aten::Tensor ret_expected = tfFloat.make(
243       {3, 2, 2, 6},
244       {-26.375, -65.0,   55.5,    37.0,    90.0,    54.25,   83.75,   -33.75,
245        2.375,   99.5,    71.5,    70.5,    -28.625, -82.375, 14.5,    -85.875,
246        -92.125, 97.875,  -78.125, -34.0,   16.375,  -1.625,  70.375,  -58.625,
247        77.375,  34.75,   -16.5,   5.5,     -16.25,  -40.75,  -58.875, -97.0,
248        -94.25,  -76.125, -30.0,   -60.0,   37.75,   5.5,     83.875,  -14.625,
249        -89.75,  -84.625, 37.75,   5.5,     83.875,  -14.625, -89.75,  -84.625,
250        -89.625, -6.75,   -31.25,  -29.0,   -5.5,    72.5,    -30.125, -72.875,
251        37.125,  -83.625, 60.875,  -98.125, -23.625, 85.875,  -25.875, 57.625,
252        50.75,   76.625,  -23.625, 85.875,  -25.875, 57.625,  50.75,   76.625});
253   std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
254   exec_aten::Tensor out = tfFloat.zeros(out_size);
255   exec_aten::Tensor ret = op_scaled_dot_product_attention(
256       query, key, value, attn_mask, dropout_p, is_causal, scale, out);
257   EXPECT_TENSOR_CLOSE(ret, ret_expected);
258 }
259 
TEST(OpScaledDotProductAttentionTest,CorrectnessTest_18)260 TEST(OpScaledDotProductAttentionTest, CorrectnessTest_18) {
261   TensorFactory<exec_aten::ScalarType::Float> tfFloat;
262 
263   exec_aten::Tensor query = tfFloat.make(
264       {3, 2, 2, 6},
265       {44.0,    -13.875, -10.125, 36.625,  72.875,  -45.0,   87.5,    -5.375,
266        25.25,   -28.625, 8.75,    -95.125, -75.5,   -59.25,  2.25,    -5.75,
267        50.25,   83.375,  -19.0,   43.875,  -98.5,   43.375,  -27.875, 7.875,
268        -15.875, 77.625,  92.5,    -16.375, -2.375,  20.25,   -75.875, -33.875,
269        13.75,   9.875,   0.625,   78.5,    6.625,   -71.625, -38.25,  -33.5,
270        -0.375,  -47.25,  55.875,  -49.0,   66.25,   88.625,  -28.75,  -49.75,
271        -6.5,    23.5,    -84.875, -13.25,  4.875,   -2.125,  -56.25,  85.75,
272        44.5,    -78.75,  -39.875, 31.0,    -73.125, 68.875,  -42.625, 29.75,
273        35.125,  83.0,    29.625,  89.75,   64.875,  91.875,  40.375,  -92.75});
274   exec_aten::Tensor key = tfFloat.make(
275       {3, 2, 4, 6},
276       {-11.375, -70.5,   10.125,  -76.125, -26.5,   -11.375, -1.125,  7.5,
277        94.375,  -50.125, 43.125,  61.75,   39.375,  -79.25,  41.375,  88.75,
278        -72.625, -17.125, 48.0,    80.75,   -66.125, -8.625,  -41.0,   6.75,
279        -37.75,  91.375,  4.0,     27.625,  51.625,  80.5,    -64.5,   21.875,
280        89.0,    -71.625, 32.75,   29.25,   -70.625, 6.875,   -1.75,   55.875,
281        -19.125, -99.125, -73.0,   -62.75,  -17.25,  37.625,  -86.75,  58.75,
282        -40.75,  45.125,  -38.5,   -60.125, 90.625,  99.875,  71.25,   -88.625,
283        74.625,  42.0,    -75.875, 57.375,  -29.0,   -25.75,  72.5,    76.875,
284        -27.0,   -2.625,  -26.375, -94.0,   -71.625, -18.125, -25.875, -62.0,
285        7.625,   73.125,  -87.625, 98.875,  -61.25,  -96.75,  -25.625, -57.875,
286        53.75,   68.25,   84.125,  36.125,  38.125,  -82.375, 92.5,    -82.75,
287        -91.25,  -60.25,  -46.375, 79.625,  20.25,   13.125,  -54.125, -32.625,
288        -35.25,  -51.75,  13.625,  -62.375, 91.0,    -45.5,   85.125,  17.625,
289        99.5,    8.875,   -92.75,  81.375,  18.625,  37.625,  0.75,    23.125,
290        -81.5,   76.75,   10.875,  40.125,  -22.875, -24.875, 52.5,    0.875,
291        59.25,   48.125,  40.875,  -43.25,  -65.625, -27.25,  58.0,    91.125,
292        78.625,  45.875,  76.0,    79.375,  17.0,    9.75,    -26.75,  15.0,
293        -1.0,    -84.75,  -38.5,   -50.625, -68.375, 0.375,   -47.0,   -91.75});
294   exec_aten::Tensor value = tfFloat.make(
295       {3, 2, 4, 6},
296       {-39.25,  69.875, -28.125, 18.375,  89.375,  -39.5,   -55.25,  -42.0,
297        -7.875,  26.625, -6.125,  -98.25,  48.625,  33.625,  48.0,    96.75,
298        -59.125, -85.25, -22.25,  -91.0,   -1.75,   14.25,   -7.75,   -94.375,
299        -97.625, 71.0,   90.875,  -11.5,   -14.625, 52.875,  90.875,  32.875,
300        -84.25,  -57.75, -78.875, -81.75,  86.0,    54.125,  -75.625, -28.375,
301        24.375,  45.125, 80.375,  -42.25,  3.5,     -68.5,   2.875,   58.75,
302        9.625,   -52.75, -31.25,  74.25,   -98.0,   38.0,    59.25,   45.5,
303        67.75,   52.5,   -59.75,  20.0,    83.875,  -46.75,  5.25,    74.375,
304        14.125,  -67.0,  -60.625, 28.5,    20.5,    -96.625, -89.125, 33.875,
305        -89.25,  9.875,  -99.25,  -20.5,   78.625,  37.875,  -72.375, -49.625,
306        22.0,    -54.25, 18.125,  57.75,   72.375,  -11.5,   -52.5,   -28.125,
307        -86.875, -45.0,  60.25,   34.625,  -88.875, 91.0,    -48.25,  98.75,
308        100.0,   33.0,   -69.625, -88.25,  -46.625, -24.75,  -77.5,   93.5,
309        -45.125, 42.75,  -50.0,   -86.0,   -17.375, 85.25,   -28.125, -28.375,
310        46.375,  26.625, 23.0,    -55.875, 39.125,  87.25,   -9.625,  95.375,
311        -27.875, 59.5,   15.5,    -90.0,   39.5,    -15.75,  -16.375, -96.875,
312        -96.125, -47.0,  0.75,    -45.875, 74.625,  46.0,    20.5,    -42.875,
313        -55.0,   30.375, -27.375, 99.375,  18.375,  0.375,   54.25,   -57.75});
314   exec_aten::optional<exec_aten::Tensor> attn_mask;
315   double dropout_p = 0.0;
316   bool is_causal = false;
317   exec_aten::optional<double> scale = exec_aten::optional<double>(-INFINITY);
318   exec_aten::Tensor ret_expected = tfFloat.make(
319       {3, 2, 2, 6},
320       {NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
321        NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
322        NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
323        NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
324        NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN,
325        NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN, NAN});
326   std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
327   exec_aten::Tensor out = tfFloat.zeros(out_size);
328   exec_aten::Tensor ret = op_scaled_dot_product_attention(
329       query, key, value, attn_mask, dropout_p, is_causal, scale, out);
330   EXPECT_TENSOR_CLOSE(ret, ret_expected);
331 }
332 
/*
// Disabling this test because right now we are enforcing that
// attention mask must be 2D
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_19) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {3, 2, 2, 6},
      {-50.875, 17.375,  -42.875, 8.125,   -59.625, -59.125, 0.0,     -76.375,
       39.625,  -27.75,  -43.375, 71.0,    -96.5,   -48.75,  23.125,  11.125,
       30.125,  36.75,   -22.25,  35.625,  37.875,  -43.375, -22.875, 74.75,
       79.375,  -75.25,  66.75,   48.875,  88.875,  -73.5,   79.375,  55.5,
       -84.0,   93.0,    -19.625, -49.875, 88.625,  -5.0,    -94.625, -13.375,
       88.375,  -30.625, 39.75,   -15.625, -80.5,   -40.25,  -90.375, -0.5,
       -47.625, 86.875,  -27.125, 26.75,   41.0,    48.0,    4.375,   10.125,
       -26.375, 4.25,    56.5,    -45.625, -78.75,  99.625,  -5.5,    -85.0,
       18.125,  -71.5,   6.0,     -44.125, 59.125,  49.25,   21.125,  -6.5});
  exec_aten::Tensor key = tfFloat.make(
      {3, 2, 4, 6},
      {-36.25,  -6.125,  49.0,    -14.375, 22.25,   17.75,   69.125,  22.625,
       -0.125,  -85.875, -71.125, -1.375,  -43.75,  -55.25,  71.125,  58.375,
       19.875,  -98.0,   -16.875, -29.375, 83.875,  19.125,  -18.5,   -34.75,
       -59.75,  -92.625, -19.375, 55.625,  -1.75,   25.0,    82.25,   8.0,
       6.75,    28.5,    8.125,   -24.375, 52.875,  -39.75,  66.625,  -31.375,
       -42.25,  -30.25,  -20.875, 24.75,   -34.5,   -69.75,  -9.0,    65.625,
       42.125,  89.5,    -1.875,  -88.375, 82.375,  80.25,   7.875,   71.0,
       84.125,  -9.625,  -62.0,   7.625,   83.0,    55.0,    -65.125, -55.125,
       -10.0,   17.75,   67.0,    83.25,   51.125,  -13.75,  40.875,  -77.625,
       19.125,  -48.125, -86.125, -20.5,   -93.125, 64.5,    -5.5,    72.375,
       86.625,  -21.0,   77.0,    -85.625, 14.5,    69.75,   99.875,  -14.125,
       36.875,  -50.375, -65.5,   94.5,    64.0,    61.0,    -73.0,   -24.375,
       -11.5,   -16.75,  92.0,    62.5,    62.375,  -81.625, -25.125, -53.25,
       -61.375, 58.5,    -67.625, 26.5,    64.0,    27.25,   84.5,    4.125,
       -82.375, 2.0,     21.5,    0.75,    80.0,    -87.375, 38.75,   -25.25,
       68.75,   -18.875, 74.75,   -45.625, -15.875, 13.5,    51.25,   37.25,
       -12.0,   -15.5,   -45.75,  7.375,   1.25,    -54.375, 80.25,   18.875,
       89.0,    -30.625, -39.5,   -39.0,   46.625,  -46.0,   -87.125, -18.0});
  exec_aten::Tensor value = tfFloat.make(
      {3, 2, 4, 6},
      {-74.5,   -0.25,   -77.125, -74.375, -53.0,   33.625,  -45.0,   66.0,
       -66.875, -71.875, -9.75,   -41.125, 37.0,    -65.25,  -50.25,  84.75,
       -67.875, 54.0,    16.875,  -96.5,   91.75,   14.625,  80.875,  -25.875,
       -62.75,  -92.5,   -77.75,  40.75,   -53.125, -71.875, 10.0,    -4.75,
       -54.875, -24.25,  48.625,  9.375,   -9.625,  32.875,  -62.75,  99.5,
       25.125,  85.625,  -29.0,   -33.75,  44.0,    -83.75,  44.125,  -88.625,
       -17.75,  22.625,  -79.5,   1.0,     -10.625, 10.0,    70.25,   -91.625,
       -86.0,   83.875,  68.25,   -35.125, -6.25,   -81.25,  -38.375, 56.0,
       26.875,  -51.75,  -79.625, 83.375,  -31.625, 83.375,  -4.75,   81.875,
       53.0,    -31.625, -48.625, 76.75,   71.625,  -63.0,   17.25,   -22.0,
       -7.75,   -77.25,  -92.25,  -2.0,    -88.0,   88.5,    -54.125, -7.875,
       98.0,    -56.75,  96.125,  -90.0,   -70.0,   -50.125, -53.5,   -65.125,
       48.375,  98.125,  -89.0,   -97.125, 20.625,  85.5,    -77.625, 76.0,
       73.625,  58.625,  -90.375, 11.75,   -16.5,   78.125,  95.375,  86.375,
       -69.125, -92.375, -65.25,  27.875,  77.125,  -59.875, 79.5,    -78.625,
       15.25,   53.75,   44.625,  -22.0,   -84.0,   -7.25,   22.0,    25.875,
       17.625,  -86.875, 22.75,   -74.0,   -79.875, -68.0,   -71.125, -81.625,
       -4.125,  65.875,  1.875,   76.125,  -43.75,  -15.25,  -4.625,  -66.125});
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
          {3, 1, 2, 2, 4},
          {39.0,  49.375,  -87.125, -99.125, 49.375,  -41.125, 26.25,   79.75,
           91.0,  -3.125,  65.75,   63.5,    -48.375, 43.375,  22.5,    -53.625,
           -70.0, 2.125,   21.875,  6.375,   -6.375,  75.25,   -35.875, 86.375,
           71.5,  -35.875, 19.75,   11.625,  -87.25,  49.0,    -6.0,    62.875,
           7.125, 87.375,  -14.75,  55.5,    59.125,  24.75,   -66.5,   72.375,
           2.25,  81.375,  -87.125, 35.125,  -39.125, 43.5,    52.875,  39.5}));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {3, 1, 2, 2, 6},
      {37.0,
       -65.25,
       -50.25,
       84.75,
       -67.875,
       54.0,
       16.874713897705078,
       -96.4992446899414,
       91.749267578125,
       14.624600410461426,
       80.87458038330078,
       -25.87506866455078,
       -62.75,
       -92.5,
       -77.75,
       40.75,
       -53.125,
       -71.875,
       -29.0,
       -33.75,
       44.0,
       -83.75,
       44.125,
       -88.625,
       -79.625,
       83.375,
       -31.625,
       83.375,
       -4.75,
       81.875,
       -6.25,
       -81.25,
       -38.375,
       56.0,
       26.875,
       -51.75,
       17.25,
       -22.0,
       -7.75,
       -77.25,
       -92.25,
       -2.0,
       53.0,
       -31.625,
       -48.625,
       76.75,
       71.625,
       -63.0,
       -77.625,
       76.0,
       73.625,
       58.625,
       -90.375,
       11.75,
       48.375,
       98.125,
       -89.0,
       -97.125,
       20.625,
       85.5,
       1.875,
       76.125,
       -43.75,
       -15.25,
       -4.625,
       -66.125,
       -79.875,
       -68.0,
       -71.125,
       -81.625,
       -4.125,
       65.875});
  Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}
*/
481 
// 2-D additive attention mask broadcast over batch/head dimensions;
// with these magnitudes the softmax saturates so each output row picks
// a single value row (expected values from the reference SDPA).
TEST(OpScaledDotProductAttentionTest, CorrectnessTest_51) {
  TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor query = tfFloat.make(
      {1, 1, 8, 3},
      {-14.0,  46.125, -78.125, -61.375, 52.375, -9.125, 57.875, 88.25,
       -95.75, 8.875,  -64.625, 41.75,   -62.25, 41.25,  -67.25, 51.25,
       48.0,   67.625, 30.0,    -59.0,   42.25,  -33.0,  -10.25, -77.5});
  exec_aten::Tensor key = tfFloat.make(
      {1, 1, 3, 3},
      {6.0, 58.5, -37.875, -11.125, -18.5, 35.0, 59.25, 73.0, 34.125});
  exec_aten::Tensor value = tfFloat.make(
      {1, 1, 3, 3},
      {70.375, 30.875, 72.125, 53.0, 39.125, -4.625, 26.5, 79.5, 88.625});
  // Mask shape {8, 3}: one additive bias per (query, key) pair.
  exec_aten::optional<exec_aten::Tensor> attn_mask =
      exec_aten::optional<exec_aten::Tensor>(tfFloat.make(
          {8, 3},
          {-59.25, -26.25, -3.0,  -24.125, 47.75,  92.375,  87.5,    21.5,
           64.5,   45.0,   -54.0, 17.375,  -67.75, 14.625,  88.75,   36.0,
           88.375, 25.75,  42.5,  -13.375, -82.75, -59.625, -21.125, 6.5}));
  double dropout_p = 0.0;
  bool is_causal = false;
  exec_aten::optional<double> scale;
  exec_aten::Tensor ret_expected = tfFloat.make(
      {1, 1, 8, 3},
      {70.375, 30.875, 72.125, 70.375, 30.875, 72.125, 70.375, 30.875,
       72.125, 53.0,   39.125, -4.625, 70.375, 30.875, 72.125, 26.5,
       79.5,   88.625, 53.0,   39.125, -4.625, 70.375, 30.875, 72.125});
  std::vector<int32_t> out_size(query.sizes().begin(), query.sizes().end());
  exec_aten::Tensor out = tfFloat.zeros(out_size);
  exec_aten::Tensor ret = op_scaled_dot_product_attention(
      query, key, value, attn_mask, dropout_p, is_causal, scale, out);
  EXPECT_TENSOR_CLOSE(ret, ret_expected);
}
516