1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::TensorFactory;
22
23 class OpTrilTest : public OperatorTest {
24 protected:
op_tril_out(const Tensor & self,int64_t diagonal,Tensor & out)25 Tensor& op_tril_out(const Tensor& self, int64_t diagonal, Tensor& out) {
26 return torch::executor::aten::tril_outf(context_, self, diagonal, out);
27 }
28
29 // Assert `self` and `out` as zero tensors is a no-op.
30 template <ScalarType DTYPE>
test_tril_out_zeros()31 void test_tril_out_zeros() {
32 TensorFactory<DTYPE> tf;
33
34 // clang-format off
35 Tensor self = tf.make(
36 /*sizes=*/{3, 3},
37 /*data=*/
38 {
39 0, 0, 0, // tensor([[ 0, 0, 0],
40 0, 0, 0, // [ 0, 0, 0],
41 0, 0, 0, // [ 0, 0, 0]])
42 }
43 );
44 // clang-format on
45
46 Tensor out = tf.zeros({3, 3});
47
48 op_tril_out(self, 0, out);
49
50 // clang-format off
51 Tensor result = tf.make(
52 /*sizes=*/{3, 3},
53 /*data=*/
54 {
55 0, 0, 0, // tensor([[ 0, 0, 0],
56 0, 0, 0, // [ 0, 0, 0],
57 0, 0, 0, // [ 0, 0, 0]])
58 }
59 );
60 // clang-format on
61
62 EXPECT_TENSOR_EQ(out, result);
63 }
64
65 // Assert `out` as a non-zero tensor yields correct results.
66 template <ScalarType DTYPE>
test_tril_out_ones()67 void test_tril_out_ones() {
68 TensorFactory<DTYPE> tf;
69
70 // clang-format off
71 Tensor self = tf.make(
72 /*sizes=*/{3, 3},
73 /*data=*/
74 {
75 0, 0, 0, // tensor([[ 0, 0, 0],
76 0, 0, 0, // [ 0, 0, 0],
77 0, 0, 0, // [ 0, 0, 0]])
78 }
79 );
80 // clang-format on
81
82 Tensor out = tf.ones({3, 3});
83
84 op_tril_out(self, 0, out);
85
86 // clang-format off
87 Tensor result = tf.make(
88 /*sizes=*/{3, 3},
89 /*data=*/
90 {
91 0, 0, 0, // tensor([[ 0, 0, 0],
92 0, 0, 0, // [ 0, 0, 0],
93 0, 0, 0, // [ 0, 0, 0]])
94 }
95 );
96 // clang-format on
97
98 EXPECT_TENSOR_EQ(out, result);
99 }
100
101 // Assert `tril` works with multiple empty dims.
102 template <ScalarType DTYPE>
test_tril_out_empty_dims()103 void test_tril_out_empty_dims() {
104 TensorFactory<DTYPE> tf;
105 Tensor out = tf.zeros({1, 1, 1, 1});
106
107 // tensor([[[[1]]]])
108 Tensor self = tf.ones({1, 1, 1, 1});
109
110 op_tril_out(self, 0, out);
111
112 // tensor([[[[1]]]])
113 Tensor result = tf.ones({1, 1, 1, 1});
114
115 EXPECT_TENSOR_EQ(out, result);
116 }
117
118 // Assert `tril` works with a square tensor.
119 template <ScalarType DTYPE>
test_tril_out_square()120 void test_tril_out_square() {
121 TensorFactory<DTYPE> tf;
122
123 // clang-format off
124 Tensor self = tf.make(
125 /*sizes=*/{3, 3},
126 /*data=*/
127 {
128 1, 1, 1, // tensor([[ 1, 1, 1],
129 1, 1, 1, // [ 1, 1, 1],
130 1, 1, 1, // [ 1, 1, 1]])
131 }
132 );
133 // clang-format on
134
135 Tensor out = tf.zeros({3, 3});
136
137 op_tril_out(self, 0, out);
138
139 // clang-format off
140 Tensor result = tf.make(
141 /*sizes=*/{3, 3},
142 /*data=*/
143 {
144 1, 0, 0, // tensor([[ 1, 0, 0],
145 1, 1, 0, // [ 1, 1, 0],
146 1, 1, 1, // [ 1, 1, 1]])
147 }
148 );
149 // clang-format on
150
151 EXPECT_TENSOR_EQ(out, result);
152 }
153
154 // Assert `tril` works with a rectangular tensor.
155 template <ScalarType DTYPE>
test_tril_out_rectangle()156 void test_tril_out_rectangle() {
157 TensorFactory<DTYPE> tf;
158
159 // clang-format off
160 Tensor self = tf.make(
161 /*sizes=*/{3, 5},
162 /*data=*/
163 {
164 1, 1, 1, 1, 1, // tensor([[ 1, 1, 1, 1, 1],
165 1, 1, 1, 1, 1, // [ 1, 1, 1, 1, 1],
166 1, 1, 1, 1, 1, // [ 1, 1, 1, 1, 1]])
167 }
168 );
169 // clang-format on
170
171 Tensor out = tf.zeros({3, 5});
172
173 op_tril_out(self, 0, out);
174
175 // clang-format off
176 Tensor result = tf.make(
177 /*sizes=*/{3, 5},
178 /*data=*/
179 {
180 1, 0, 0, 0, 0, // tensor([[ 1, 0, 0, 0, 0],
181 1, 1, 0, 0, 0, // [ 1, 1, 0, 0, 0],
182 1, 1, 1, 0, 0, // [ 1, 1, 1, 0, 0]])
183 }
184 );
185 // clang-format on
186
187 EXPECT_TENSOR_EQ(out, result);
188 }
189
190 // Assert `tril` works with a positive diagonal value.
191 template <ScalarType DTYPE>
test_tril_out_pos_diag()192 void test_tril_out_pos_diag() {
193 TensorFactory<DTYPE> tf;
194
195 // clang-format off
196 Tensor self = tf.make(
197 /*sizes=*/{3, 3},
198 /*data=*/
199 {
200 1, 1, 1, // tensor([[ 1, 1, 1],
201 1, 1, 1, // [ 1, 1, 1],
202 1, 1, 1, // [ 1, 1, 1]])
203 }
204 );
205 // clang-format on
206
207 Tensor out = tf.zeros({3, 3});
208
209 op_tril_out(self, 1, out);
210
211 // clang-format off
212 Tensor result = tf.make(
213 /*sizes=*/{3, 3},
214 /*data=*/
215 {
216 1, 1, 0, // tensor([[ 1, 1, 0],
217 1, 1, 1, // [ 1, 1, 1],
218 1, 1, 1, // [ 1, 1, 1]])
219 }
220 );
221 // clang-format on
222
223 EXPECT_TENSOR_EQ(out, result);
224 }
225
226 // Assert `tril` works with a negative diagonal value.
227 template <ScalarType DTYPE>
test_tril_out_neg_diag()228 void test_tril_out_neg_diag() {
229 TensorFactory<DTYPE> tf;
230
231 // clang-format off
232 Tensor self = tf.make(
233 /*sizes=*/{3, 3},
234 /*data=*/
235 {
236 1, 1, 1, // tensor([[ 1, 1, 1],
237 1, 1, 1, // [ 1, 1, 1],
238 1, 1, 1, // [ 1, 1, 1]])
239 }
240 );
241 // clang-format on
242
243 Tensor out = tf.zeros({3, 3});
244
245 op_tril_out(self, -1, out);
246
247 // clang-format off
248 Tensor result = tf.make(
249 /*sizes=*/{3, 3},
250 /*data=*/
251 {
252 0, 0, 0, // tensor([[ 0, 0, 0],
253 1, 0, 0, // [ 1, 0, 0],
254 1, 1, 0, // [ 1, 1, 0]])
255 }
256 );
257 // clang-format on
258
259 EXPECT_TENSOR_EQ(out, result);
260 }
261
262 // Assert `tril` works with a batch of tensors, where dims are equal.
263 template <ScalarType DTYPE>
test_tril_out_multi_equal_dim()264 void test_tril_out_multi_equal_dim() {
265 TensorFactory<DTYPE> tf;
266
267 // clang-format off
268 Tensor self = tf.make(
269 /*sizes=*/{3, 3, 3},
270 /*data=*/
271 {
272 1, 1, 1, // tensor([[[ 1, 1, 1],
273 1, 1, 1, // [ 1, 1, 1],
274 1, 1, 1, // [ 1, 1, 1]],
275
276 1, 1, 1, // [[ 1, 1, 1],
277 1, 1, 1, // [ 1, 1, 1],
278 1, 1, 1, // [ 1, 1, 1]],
279
280 1, 1, 1, // [[ 1, 1, 1],
281 1, 1, 1, // [ 1, 1, 1],
282 1, 1, 1, // [ 1, 1, 1]]])
283 }
284 );
285 // clang-format on
286
287 Tensor out = tf.zeros({3, 3, 3});
288
289 op_tril_out(self, 0, out);
290
291 // clang-format off
292 Tensor result = tf.make(
293 /*sizes=*/{3, 3, 3},
294 /*data=*/
295 {
296 1, 0, 0, // tensor([[[ 1, 0, 0],
297 1, 1, 0, // [ 1, 1, 0],
298 1, 1, 1, // [ 1, 1, 1]],
299
300 1, 0, 0, // [[ 1, 0, 0],
301 1, 1, 0, // [ 1, 1, 0],
302 1, 1, 1, // [ 1, 1, 1]],
303
304 1, 0, 0, // [[ 1, 0, 0],
305 1, 1, 0, // [ 1, 1, 0],
306 1, 1, 1, // [ 1, 1, 1]]])
307 }
308 );
309 // clang-format on
310
311 EXPECT_TENSOR_EQ(out, result);
312 }
313
314 // Assert `tril` works with a batch of tensors, where dims are unequal.
315 template <ScalarType DTYPE>
test_tril_out_multi_unequal_dim()316 void test_tril_out_multi_unequal_dim() {
317 TensorFactory<DTYPE> tf;
318
319 // clang-format offF
320 Tensor self = tf.make(
321 /*sizes=*/{3, 2, 3},
322 /*data=*/
323 {
324 1,
325 1,
326 1, // tensor([[[ 1, 1, 1],
327 1,
328 1,
329 1, // [ 1, 1, 1]],
330
331 1,
332 1,
333 1, // [[ 1, 1, 1],
334 1,
335 1,
336 1, // [ 1, 1, 1]],
337
338 1,
339 1,
340 1, // [[ 1, 1, 1],
341 1,
342 1,
343 1, // [ 1, 1, 1]]])
344 });
345 // clang-format on
346
347 Tensor out = tf.zeros({3, 2, 3});
348
349 op_tril_out(self, 0, out);
350
351 // clang-format off
352 Tensor result = tf.make(
353 /*sizes=*/{3, 2, 3},
354 /*data=*/
355 {
356 1, 0, 0, // tensor([[[ 1, 0, 0],
357 1, 1, 0, // [ 1, 1, 0]],
358
359 1, 0, 0, // [[ 1, 0, 0],
360 1, 1, 0, // [ 1, 1, 0]],
361
362 1, 0, 0, // [[ 1, 0, 0],
363 1, 1, 0, // [ 1, 1, 0]]])
364 }
365 );
366 // clang-format on
367
368 EXPECT_TENSOR_EQ(out, result);
369 }
370
371 // Assert `tril` works with non-0/1 values on regular diagonal.
372 template <ScalarType DTYPE>
test_tril_out_arange_reg_diag()373 void test_tril_out_arange_reg_diag() {
374 TensorFactory<DTYPE> tf;
375
376 // clang-format off
377 Tensor self = tf.make(
378 /*sizes=*/{3, 3},
379 /*data=*/
380 {
381 1, 2, 3, // tensor([[ 1, 2, 3],
382 4, 5, 6, // [ 4, 5, 6],
383 7, 8, 9, // [ 7, 8, 9]])
384 }
385 );
386 // clang-format on
387
388 Tensor out = tf.zeros({3, 3});
389
390 op_tril_out(self, 0, out);
391
392 // clang-format off
393 Tensor result = tf.make(
394 /*sizes=*/{3, 3},
395 /*data=*/
396 {
397 1, 0, 0, // tensor([[ 1, 0, 0],
398 4, 5, 0, // [ 4, 5, 0],
399 7, 8, 9, // [ 7, 8, 9]])
400 }
401 );
402 // clang-format on
403
404 EXPECT_TENSOR_EQ(out, result);
405 }
406
407 // Assert `tril` works with non-0/1 values on positive diagonal values.
408 // An edge case with a far-out positive diagonal is also included.
409 template <ScalarType DTYPE>
test_tril_out_arange_pos_diag()410 void test_tril_out_arange_pos_diag() {
411 TensorFactory<DTYPE> tf;
412
413 // Case: diag = 1
414
415 // clang-format off
416 Tensor self = tf.make(
417 /*sizes=*/{3, 3},
418 /*data=*/
419 {
420 1, 2, 3, // tensor([[ 1, 2, 3],
421 4, 5, 6, // [ 4, 5, 6],
422 7, 8, 9, // [ 7, 8, 9]])
423 }
424 );
425 // clang-format on
426
427 Tensor out1 = tf.zeros({3, 3});
428
429 op_tril_out(self, 1, out1);
430
431 // clang-format off
432 Tensor result1 = tf.make(
433 /*sizes=*/{3, 3},
434 /*data=*/
435 {
436 1, 2, 0, // tensor([[ 1, 2, 0],
437 4, 5, 6, // [ 4, 5, 6],
438 7, 8, 9, // [ 7, 8, 9]])
439 }
440 );
441 // clang-format on
442
443 EXPECT_TENSOR_EQ(out1, result1);
444
445 // Case: diag = 2
446
447 Tensor out2 = tf.zeros({3, 3});
448 op_tril_out(self, 2, out2);
449 EXPECT_TENSOR_EQ(out2, self);
450
451 // Case: diag = 10
452
453 Tensor out3 = tf.zeros({3, 3});
454 op_tril_out(self, 10, out3);
455 EXPECT_TENSOR_EQ(out3, self);
456 }
457
458 // Assert `tril` works with non-0/1 values on negative diagonal values.
459 // An edge case with a far-out negative diagonal is also included.
460 template <ScalarType DTYPE>
test_tril_out_arange_neg_diag()461 void test_tril_out_arange_neg_diag() {
462 TensorFactory<DTYPE> tf;
463
464 // Case: diag = -1
465
466 // clang-format off
467 Tensor self = tf.make(
468 /*sizes=*/{3, 3},
469 /*data=*/
470 {
471 1, 2, 3, // tensor([[ 1, 2, 3],
472 4, 5, 6, // [ 4, 5, 6],
473 7, 8, 9, // [ 7, 8, 9]])
474 }
475 );
476 // clang-format on
477
478 Tensor out1 = tf.zeros({3, 3});
479
480 op_tril_out(self, -1, out1);
481
482 // clang-format off
483 Tensor result1 = tf.make(
484 /*sizes=*/{3, 3},
485 /*data=*/
486 {
487 0, 0, 0, // tensor([[ 0, 0, 0],
488 4, 0, 0, // [ 4, 0, 0],
489 7, 8, 0, // [ 7, 8, 0]])
490 }
491 );
492 // clang-format on
493
494 EXPECT_TENSOR_EQ(out1, result1);
495
496 // Case: diag = 2
497
498 Tensor out2 = tf.zeros({3, 3});
499
500 op_tril_out(self, -2, out2);
501
502 // clang-format off
503 Tensor result2 = tf.make(
504 /*sizes=*/{3, 3},
505 /*data=*/
506 {
507 0, 0, 0, // tensor([[ 0, 0, 0],
508 0, 0, 0, // [ 0, 0, 0],
509 7, 0, 0, // [ 7, 0, 0]])
510 }
511 );
512 // clang-format on
513
514 EXPECT_TENSOR_EQ(out2, result2);
515
516 // Case: diag = 10
517
518 Tensor out3 = tf.zeros({3, 3});
519
520 op_tril_out(self, -10, out3);
521
522 // clang-format off
523 Tensor result3 = tf.make(
524 /*sizes=*/{3, 3},
525 /*data=*/
526 {
527 0, 0, 0, // tensor([[ 0, 0, 0],
528 0, 0, 0, // [ 0, 0, 0],
529 0, 0, 0, // [ 0, 0, 0]])
530 }
531 );
532 // clang-format on
533
534 EXPECT_TENSOR_EQ(out3, result3);
535 }
536
537 // Assert `tril` works on a batch of tensors with random integers, where dims
538 // are equal.
539 template <ScalarType DTYPE>
test_tril_out_randint_multi_equal()540 void test_tril_out_randint_multi_equal() {
541 TensorFactory<DTYPE> tf;
542
543 // clang-format off
544 Tensor self = tf.make(
545 /*sizes=*/{3, 3, 3, 3},
546 /*data=*/
547 {
548 9, 5, 4, // tensor([[[[ 9, 5, 4],
549 3, 9, 6, // [ 3, 9, 6],
550 9, 9, 5, // [ 9, 9, 5]],
551
552 7, 2, 6, // [[ 7, 2, 6],
553 8, 5, 5, // [ 8, 5, 5],
554 9, 3, 9, // [ 9, 3, 9]],
555
556 1, 2, 1, // [[ 1, 2, 1],
557 6, 2, 6, // [ 6, 2, 6],
558 1, 1, 8, // [ 1, 1, 8]]],
559
560 3, 2, 5, // [[[ 3, 2, 5],
561 4, 4, 1, // [ 4, 4, 1],
562 7, 1, 1, // [ 7, 1, 1]],
563
564 5, 7, 8, // [[ 5, 7, 8],
565 1, 5, 7, // [ 1, 5, 7],
566 7, 6, 3, // [ 7, 6, 3]]],
567
568 3, 5, 9, // [[ 3, 5, 9],
569 4, 2, 2, // [ 4, 2, 2],
570 9, 5, 2, // [ 9, 5, 2]]],
571
572 8, 4, 7, // [[[ 8, 4, 7],
573 8, 7, 5, // [ 8, 7, 5],
574 7, 3, 8, // [ 7, 3, 8]],
575
576 9, 5, 5, // [[ 9, 5, 5],
577 6, 1, 8, // [ 6, 1, 8],
578 8, 9, 7, // [ 8, 9, 7]]],
579
580 1, 2, 3, // [[ 1, 2, 3],
581 7, 9, 1, // [ 7, 9, 1],
582 5, 2, 2, // [ 5, 2, 2]]]])
583 }
584 );
585 // clang-format on
586
587 Tensor out = tf.zeros({3, 3, 3, 3});
588
589 op_tril_out(self, 0, out);
590
591 // clang-format off
592 Tensor result = tf.make(
593 /*sizes=*/{3, 3, 3, 3},
594 /*data=*/
595 {
596 9, 0, 0, // tensor([[[[ 9, 0, 0],
597 3, 9, 0, // [ 3, 9, 0],
598 9, 9, 5, // [ 9, 9, 5]],
599
600 7, 0, 0, // [[ 7, 0, 0],
601 8, 5, 0, // [ 8, 5, 0],
602 9, 3, 9, // [ 9, 3, 9]],
603
604 1, 0, 0, // [[ 1, 0, 0],
605 6, 2, 0, // [ 6, 2, 0],
606 1, 1, 8, // [ 1, 1, 8]]],
607
608 3, 0, 0, // [[[ 3, 0, 0],
609 4, 4, 0, // [ 4, 4, 0],
610 7, 1, 1, // [ 7, 1, 1]],
611
612 5, 0, 0, // [[ 5, 0, 0],
613 1, 5, 0, // [ 1, 5, 0],
614 7, 6, 3, // [ 7, 6, 3]]],
615
616 3, 0, 0, // [[ 3, 0, 0],
617 4, 2, 0, // [ 4, 2, 0],
618 9, 5, 2, // [ 9, 5, 2]]],
619
620 8, 0, 0, // [[[ 8, 0, 0],
621 8, 7, 0, // [ 8, 7, 0],
622 7, 3, 8, // [ 7, 3, 8]],
623
624 9, 0, 0, // [[ 9, 0, 0],
625 6, 1, 0, // [ 6, 1, 0],
626 8, 9, 7, // [ 8, 9, 7]]],
627
628 1, 0, 0, // [[ 1, 0, 0],
629 7, 9, 0, // [ 7, 9, 0],
630 5, 2, 2, // [ 5, 2, 2]]]])
631 }
632 );
633 // clang-format on
634
635 EXPECT_TENSOR_EQ(out, result);
636 }
637
638 // Assert `tril` works on a batch of tensors with random integers, where dims
639 // are unequal.
640 template <ScalarType DTYPE>
test_tril_out_randint_multi_unequal()641 void test_tril_out_randint_multi_unequal() {
642 TensorFactory<DTYPE> tf;
643
644 // clang-format off
645 Tensor self = tf.make(
646 /*sizes=*/{3, 2, 3, 2},
647 /*data=*/
648 {
649 1, 1, // tensor([[[[ 1, 1],
650 1, 1, // [ 1, 1],
651 9, 1, // [ 9, 1]],
652
653 1, 6, // [[ 1, 6],
654 6, 2, // [ 6, 2],
655 7, 2, // [ 7, 2]],
656
657 2, 4, // [[[ 2, 4],
658 8, 3, // [ 8, 3],
659 4, 2, // [ 4, 2]]],
660
661 7, 6, // [[ 7, 6],
662 1, 8, // [ 1, 8],
663 4, 3, // [ 4, 3]],
664
665 2, 2, // [[[ 2, 2],
666 7, 4, // [ 7, 4],
667 3, 7, // [ 3, 7]]],
668
669 7, 8, // [[ 7, 8],
670 4, 9, // [ 4, 9],
671 1, 6, // [ 1, 6]]]])
672 }
673 );
674 // clang-format on
675
676 Tensor out = tf.zeros({3, 2, 3, 2});
677
678 op_tril_out(self, 0, out);
679
680 // clang-format off
681 Tensor result = tf.make(
682 /*sizes=*/{3, 2, 3, 2},
683 /*data=*/
684 {
685 1, 0, // tensor([[[[ 1, 0],
686 1, 1, // [ 1, 1],
687 9, 1, // [ 9, 1]],
688
689 1, 0, // [[ 1, 0],
690 6, 2, // [ 6, 2],
691 7, 2, // [ 7, 2]],
692
693 2, 0, // [[[ 2, 0],
694 8, 3, // [ 8, 3],
695 4, 2, // [ 4, 2]]],
696
697 7, 0, // [[ 7, 0],
698 1, 8, // [ 1, 8],
699 4, 3, // [ 4, 3]],
700
701 2, 0, // [[[ 2, 0],
702 7, 4, // [ 7, 4],
703 3, 7, // [ 3, 7]]],
704
705 7, 0, // [[ 7, 0],
706 4, 9, // [ 4, 9],
707 1, 6, // [ 1, 6]]]])
708 }
709 );
710 // clang-format on
711
712 EXPECT_TENSOR_EQ(out, result);
713 }
714 };
715
716 // Create generic tests for all dtypes. Tensors contain 0s or 1s.
717 #define GENERATE_GENERIC_TEST(_, DTYPE) \
718 TEST_F(OpTrilTest, DTYPE##GenericTest) { \
719 test_tril_out_zeros<ScalarType::DTYPE>(); \
720 test_tril_out_ones<ScalarType::DTYPE>(); \
721 test_tril_out_empty_dims<ScalarType::DTYPE>(); \
722 test_tril_out_square<ScalarType::DTYPE>(); \
723 test_tril_out_rectangle<ScalarType::DTYPE>(); \
724 test_tril_out_pos_diag<ScalarType::DTYPE>(); \
725 test_tril_out_neg_diag<ScalarType::DTYPE>(); \
726 test_tril_out_multi_equal_dim<ScalarType::DTYPE>(); \
727 test_tril_out_multi_unequal_dim<ScalarType::DTYPE>(); \
728 }
729
ET_FORALL_REAL_TYPES_AND(Bool,GENERATE_GENERIC_TEST)730 ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_GENERIC_TEST)
731
732 // Create generic tests for real dtypes. Tensors have diverse values.
733 #define GENERATE_REAL_TEST(_, DTYPE) \
734 TEST_F(OpTrilTest, DTYPE##RealTest) { \
735 test_tril_out_arange_pos_diag<ScalarType::DTYPE>(); \
736 test_tril_out_arange_neg_diag<ScalarType::DTYPE>(); \
737 test_tril_out_randint_multi_equal<ScalarType::DTYPE>(); \
738 test_tril_out_randint_multi_unequal<ScalarType::DTYPE>(); \
739 }
740
741 ET_FORALL_REAL_TYPES(GENERATE_REAL_TEST)
742
743 TEST_F(OpTrilTest, InvalidInputShapesDies) {
744 TensorFactory<ScalarType::Int> tf;
745
746 // `self` and `out` invalid shapes: ndims = 0 is <2.
747 Tensor self1 = tf.zeros({});
748 Tensor out1 = tf.zeros({});
749
750 // Assert `out` can't be filled due to incompatible shapes.
751 ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self1, 0, out1));
752
753 // `self` and `out` invalid shapes: ndims = 1 is <2.
754 Tensor self2 = tf.zeros({1});
755 Tensor out2 = tf.zeros({1});
756
757 // Assert `out` can't be filled due to incompatible shapes.
758 ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self2, 0, out2));
759 }
760
TEST_F(OpTrilTest, MismatchedOutputShapesDies) {
  // The ATen kernel accepts `self` and `out` with different shapes, so this
  // failure expectation only applies outside ATen mode.
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched output shape";
  }

  TensorFactory<ScalarType::Int> tf_int;

  // Same dtype, but `out` is 2x2 while `self` is 2x1.
  Tensor input = tf_int.zeros({2, 1});
  Tensor output = tf_int.zeros({2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(input, 0, output));
}
776
TEST_F(OpTrilTest, MismatchedOutputDtypeDies) {
  TensorFactory<ScalarType::Byte> byte_factory;
  TensorFactory<ScalarType::Float> float_factory;

  // Identical 2x2 shapes, but `self` is Byte while `out` is Float; the kernel
  // must refuse to fill `out`.
  Tensor byte_input = byte_factory.zeros({2, 2});
  Tensor float_output = float_factory.zeros({2, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_tril_out(byte_input, 0, float_output));
}
788
TEST_F(OpTrilTest, InvalidTensorDims) {
  // This case expects the kernel to reject a tensor with too many dims. Shapes
  // of `self` and `out` match here, so the skip is not about shape mismatch.
  // NOTE(review): presumably ATen does not enforce the same dimension limit
  // and would succeed; confirm against the ATen kernel.
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle high-rank inputs";
  }

  TensorFactory<ScalarType::Int> tf;

  // Create `self` and `out` with 25 size-1 dims.
  std::vector<int32_t> sizes(25, 1);
  Tensor self = tf.zeros(sizes);
  Tensor out = tf.zeros(sizes);

  // Assert `out` can't be filled due to too many tensor dims.
  ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self, 0, out));
}
805