xref: /aosp_15_r20/external/executorch/kernels/test/op_tril_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::TensorFactory;
22 
23 class OpTrilTest : public OperatorTest {
24  protected:
op_tril_out(const Tensor & self,int64_t diagonal,Tensor & out)25   Tensor& op_tril_out(const Tensor& self, int64_t diagonal, Tensor& out) {
26     return torch::executor::aten::tril_outf(context_, self, diagonal, out);
27   }
28 
29   // Assert `self` and `out` as zero tensors is a no-op.
30   template <ScalarType DTYPE>
test_tril_out_zeros()31   void test_tril_out_zeros() {
32     TensorFactory<DTYPE> tf;
33 
34     // clang-format off
35   Tensor self = tf.make(
36     /*sizes=*/{3, 3},
37     /*data=*/
38     {
39         0,  0,  0, // tensor([[ 0,  0,  0],
40         0,  0,  0, //         [ 0,  0,  0],
41         0,  0,  0, //         [ 0,  0,  0]])
42     }
43   );
44     // clang-format on
45 
46     Tensor out = tf.zeros({3, 3});
47 
48     op_tril_out(self, 0, out);
49 
50     // clang-format off
51   Tensor result = tf.make(
52     /*sizes=*/{3, 3},
53     /*data=*/
54     {
55         0,  0,  0, // tensor([[ 0,  0,  0],
56         0,  0,  0, //         [ 0,  0,  0],
57         0,  0,  0, //         [ 0,  0,  0]])
58     }
59   );
60     // clang-format on
61 
62     EXPECT_TENSOR_EQ(out, result);
63   }
64 
65   // Assert `out` as a non-zero tensor yields correct results.
66   template <ScalarType DTYPE>
test_tril_out_ones()67   void test_tril_out_ones() {
68     TensorFactory<DTYPE> tf;
69 
70     // clang-format off
71   Tensor self = tf.make(
72     /*sizes=*/{3, 3},
73     /*data=*/
74     {
75         0,  0,  0, // tensor([[ 0,  0,  0],
76         0,  0,  0, //         [ 0,  0,  0],
77         0,  0,  0, //         [ 0,  0,  0]])
78     }
79   );
80     // clang-format on
81 
82     Tensor out = tf.ones({3, 3});
83 
84     op_tril_out(self, 0, out);
85 
86     // clang-format off
87   Tensor result = tf.make(
88     /*sizes=*/{3, 3},
89     /*data=*/
90     {
91         0,  0,  0, // tensor([[ 0,  0,  0],
92         0,  0,  0, //         [ 0,  0,  0],
93         0,  0,  0, //         [ 0,  0,  0]])
94     }
95   );
96     // clang-format on
97 
98     EXPECT_TENSOR_EQ(out, result);
99   }
100 
101   // Assert `tril` works with multiple empty dims.
102   template <ScalarType DTYPE>
test_tril_out_empty_dims()103   void test_tril_out_empty_dims() {
104     TensorFactory<DTYPE> tf;
105     Tensor out = tf.zeros({1, 1, 1, 1});
106 
107     // tensor([[[[1]]]])
108     Tensor self = tf.ones({1, 1, 1, 1});
109 
110     op_tril_out(self, 0, out);
111 
112     // tensor([[[[1]]]])
113     Tensor result = tf.ones({1, 1, 1, 1});
114 
115     EXPECT_TENSOR_EQ(out, result);
116   }
117 
118   // Assert `tril` works with a square tensor.
119   template <ScalarType DTYPE>
test_tril_out_square()120   void test_tril_out_square() {
121     TensorFactory<DTYPE> tf;
122 
123     // clang-format off
124   Tensor self = tf.make(
125     /*sizes=*/{3, 3},
126     /*data=*/
127     {
128         1,  1,  1, // tensor([[ 1,  1,  1],
129         1,  1,  1, //         [ 1,  1,  1],
130         1,  1,  1, //         [ 1,  1,  1]])
131     }
132   );
133     // clang-format on
134 
135     Tensor out = tf.zeros({3, 3});
136 
137     op_tril_out(self, 0, out);
138 
139     // clang-format off
140   Tensor result = tf.make(
141     /*sizes=*/{3, 3},
142     /*data=*/
143     {
144         1,  0,  0, // tensor([[ 1,  0,  0],
145         1,  1,  0, //         [ 1,  1,  0],
146         1,  1,  1, //         [ 1,  1,  1]])
147     }
148   );
149     // clang-format on
150 
151     EXPECT_TENSOR_EQ(out, result);
152   }
153 
154   // Assert `tril` works with a rectangular tensor.
155   template <ScalarType DTYPE>
test_tril_out_rectangle()156   void test_tril_out_rectangle() {
157     TensorFactory<DTYPE> tf;
158 
159     // clang-format off
160   Tensor self = tf.make(
161     /*sizes=*/{3, 5},
162     /*data=*/
163     {
164         1,  1,  1,  1,  1, // tensor([[ 1,  1,  1,  1,  1],
165         1,  1,  1,  1,  1, //         [ 1,  1,  1,  1,  1],
166         1,  1,  1,  1,  1, //         [ 1,  1,  1,  1,  1]])
167     }
168   );
169     // clang-format on
170 
171     Tensor out = tf.zeros({3, 5});
172 
173     op_tril_out(self, 0, out);
174 
175     // clang-format off
176   Tensor result = tf.make(
177     /*sizes=*/{3, 5},
178     /*data=*/
179     {
180         1,  0,  0,  0,  0, // tensor([[ 1,  0,  0,  0,  0],
181         1,  1,  0,  0,  0, //         [ 1,  1,  0,  0,  0],
182         1,  1,  1,  0,  0, //         [ 1,  1,  1,  0,  0]])
183     }
184   );
185     // clang-format on
186 
187     EXPECT_TENSOR_EQ(out, result);
188   }
189 
190   // Assert `tril` works with a positive diagonal value.
191   template <ScalarType DTYPE>
test_tril_out_pos_diag()192   void test_tril_out_pos_diag() {
193     TensorFactory<DTYPE> tf;
194 
195     // clang-format off
196   Tensor self = tf.make(
197     /*sizes=*/{3, 3},
198     /*data=*/
199     {
200         1,  1,  1, // tensor([[ 1,  1,  1],
201         1,  1,  1, //         [ 1,  1,  1],
202         1,  1,  1, //         [ 1,  1,  1]])
203     }
204   );
205     // clang-format on
206 
207     Tensor out = tf.zeros({3, 3});
208 
209     op_tril_out(self, 1, out);
210 
211     // clang-format off
212   Tensor result = tf.make(
213     /*sizes=*/{3, 3},
214     /*data=*/
215     {
216         1,  1,  0, // tensor([[ 1,  1,  0],
217         1,  1,  1, //         [ 1,  1,  1],
218         1,  1,  1, //         [ 1,  1,  1]])
219     }
220   );
221     // clang-format on
222 
223     EXPECT_TENSOR_EQ(out, result);
224   }
225 
226   // Assert `tril` works with a negative diagonal value.
227   template <ScalarType DTYPE>
test_tril_out_neg_diag()228   void test_tril_out_neg_diag() {
229     TensorFactory<DTYPE> tf;
230 
231     // clang-format off
232   Tensor self = tf.make(
233     /*sizes=*/{3, 3},
234     /*data=*/
235     {
236         1,  1,  1, // tensor([[ 1,  1,  1],
237         1,  1,  1, //         [ 1,  1,  1],
238         1,  1,  1, //         [ 1,  1,  1]])
239     }
240   );
241     // clang-format on
242 
243     Tensor out = tf.zeros({3, 3});
244 
245     op_tril_out(self, -1, out);
246 
247     // clang-format off
248   Tensor result = tf.make(
249     /*sizes=*/{3, 3},
250     /*data=*/
251     {
252         0,  0,  0, // tensor([[ 0,  0,  0],
253         1,  0,  0, //         [ 1,  0,  0],
254         1,  1,  0, //         [ 1,  1,  0]])
255     }
256   );
257     // clang-format on
258 
259     EXPECT_TENSOR_EQ(out, result);
260   }
261 
262   // Assert `tril` works with a batch of tensors, where dims are equal.
263   template <ScalarType DTYPE>
test_tril_out_multi_equal_dim()264   void test_tril_out_multi_equal_dim() {
265     TensorFactory<DTYPE> tf;
266 
267     // clang-format off
268   Tensor self = tf.make(
269     /*sizes=*/{3, 3, 3},
270     /*data=*/
271     {
272         1,  1,  1, // tensor([[[ 1,  1,  1],
273         1,  1,  1, //          [ 1,  1,  1],
274         1,  1,  1, //          [ 1,  1,  1]],
275 
276         1,  1,  1, //         [[ 1,  1,  1],
277         1,  1,  1, //          [ 1,  1,  1],
278         1,  1,  1, //          [ 1,  1,  1]],
279 
280         1,  1,  1, //         [[ 1,  1,  1],
281         1,  1,  1, //          [ 1,  1,  1],
282         1,  1,  1, //          [ 1,  1,  1]]])
283     }
284   );
285     // clang-format on
286 
287     Tensor out = tf.zeros({3, 3, 3});
288 
289     op_tril_out(self, 0, out);
290 
291     // clang-format off
292   Tensor result = tf.make(
293     /*sizes=*/{3, 3, 3},
294     /*data=*/
295     {
296         1,  0,  0, // tensor([[[ 1,  0,  0],
297         1,  1,  0, //          [ 1,  1,  0],
298         1,  1,  1, //          [ 1,  1,  1]],
299 
300         1,  0,  0, //         [[ 1,  0,  0],
301         1,  1,  0, //          [ 1,  1,  0],
302         1,  1,  1, //          [ 1,  1,  1]],
303 
304         1,  0,  0, //         [[ 1,  0,  0],
305         1,  1,  0, //          [ 1,  1,  0],
306         1,  1,  1, //          [ 1,  1,  1]]])
307     }
308   );
309     // clang-format on
310 
311     EXPECT_TENSOR_EQ(out, result);
312   }
313 
314   // Assert `tril` works with a batch of tensors, where dims are unequal.
315   template <ScalarType DTYPE>
test_tril_out_multi_unequal_dim()316   void test_tril_out_multi_unequal_dim() {
317     TensorFactory<DTYPE> tf;
318 
319     // clang-format offF
320     Tensor self = tf.make(
321         /*sizes=*/{3, 2, 3},
322         /*data=*/
323         {
324             1,
325             1,
326             1, // tensor([[[ 1,  1,  1],
327             1,
328             1,
329             1, //          [ 1,  1,  1]],
330 
331             1,
332             1,
333             1, //         [[ 1,  1,  1],
334             1,
335             1,
336             1, //          [ 1,  1,  1]],
337 
338             1,
339             1,
340             1, //         [[ 1,  1,  1],
341             1,
342             1,
343             1, //          [ 1,  1,  1]]])
344         });
345     // clang-format on
346 
347     Tensor out = tf.zeros({3, 2, 3});
348 
349     op_tril_out(self, 0, out);
350 
351     // clang-format off
352   Tensor result = tf.make(
353     /*sizes=*/{3, 2, 3},
354     /*data=*/
355     {
356         1,  0,  0, // tensor([[[ 1,  0,  0],
357         1,  1,  0, //          [ 1,  1,  0]],
358 
359         1,  0,  0, //         [[ 1,  0,  0],
360         1,  1,  0, //          [ 1,  1,  0]],
361 
362         1,  0,  0, //         [[ 1,  0,  0],
363         1,  1,  0, //          [ 1,  1,  0]]])
364     }
365   );
366     // clang-format on
367 
368     EXPECT_TENSOR_EQ(out, result);
369   }
370 
371   // Assert `tril` works with non-0/1 values on regular diagonal.
372   template <ScalarType DTYPE>
test_tril_out_arange_reg_diag()373   void test_tril_out_arange_reg_diag() {
374     TensorFactory<DTYPE> tf;
375 
376     // clang-format off
377   Tensor self = tf.make(
378     /*sizes=*/{3, 3},
379     /*data=*/
380     {
381         1,  2,  3, // tensor([[ 1,  2,  3],
382         4,  5,  6, //         [ 4,  5,  6],
383         7,  8,  9, //         [ 7,  8,  9]])
384     }
385   );
386     // clang-format on
387 
388     Tensor out = tf.zeros({3, 3});
389 
390     op_tril_out(self, 0, out);
391 
392     // clang-format off
393   Tensor result = tf.make(
394     /*sizes=*/{3, 3},
395     /*data=*/
396     {
397         1,  0,  0, // tensor([[ 1,  0,  0],
398         4,  5,  0, //         [ 4,  5,  0],
399         7,  8,  9, //         [ 7,  8,  9]])
400     }
401   );
402     // clang-format on
403 
404     EXPECT_TENSOR_EQ(out, result);
405   }
406 
407   // Assert `tril` works with non-0/1 values on positive diagonal values.
408   // An edge case with a far-out positive diagonal is also included.
409   template <ScalarType DTYPE>
test_tril_out_arange_pos_diag()410   void test_tril_out_arange_pos_diag() {
411     TensorFactory<DTYPE> tf;
412 
413     // Case: diag = 1
414 
415     // clang-format off
416   Tensor self = tf.make(
417     /*sizes=*/{3, 3},
418     /*data=*/
419     {
420         1,  2,  3, // tensor([[ 1,  2,  3],
421         4,  5,  6, //         [ 4,  5,  6],
422         7,  8,  9, //         [ 7,  8,  9]])
423     }
424   );
425     // clang-format on
426 
427     Tensor out1 = tf.zeros({3, 3});
428 
429     op_tril_out(self, 1, out1);
430 
431     // clang-format off
432   Tensor result1 = tf.make(
433     /*sizes=*/{3, 3},
434     /*data=*/
435     {
436         1,  2,  0, // tensor([[ 1,  2,  0],
437         4,  5,  6, //         [ 4,  5,  6],
438         7,  8,  9, //         [ 7,  8,  9]])
439     }
440   );
441     // clang-format on
442 
443     EXPECT_TENSOR_EQ(out1, result1);
444 
445     // Case: diag = 2
446 
447     Tensor out2 = tf.zeros({3, 3});
448     op_tril_out(self, 2, out2);
449     EXPECT_TENSOR_EQ(out2, self);
450 
451     // Case: diag = 10
452 
453     Tensor out3 = tf.zeros({3, 3});
454     op_tril_out(self, 10, out3);
455     EXPECT_TENSOR_EQ(out3, self);
456   }
457 
458   // Assert `tril` works with non-0/1 values on negative diagonal values.
459   // An edge case with a far-out negative diagonal is also included.
460   template <ScalarType DTYPE>
test_tril_out_arange_neg_diag()461   void test_tril_out_arange_neg_diag() {
462     TensorFactory<DTYPE> tf;
463 
464     // Case: diag = -1
465 
466     // clang-format off
467   Tensor self = tf.make(
468     /*sizes=*/{3, 3},
469     /*data=*/
470     {
471         1,  2,  3, // tensor([[ 1,  2,  3],
472         4,  5,  6, //         [ 4,  5,  6],
473         7,  8,  9, //         [ 7,  8,  9]])
474     }
475   );
476     // clang-format on
477 
478     Tensor out1 = tf.zeros({3, 3});
479 
480     op_tril_out(self, -1, out1);
481 
482     // clang-format off
483   Tensor result1 = tf.make(
484     /*sizes=*/{3, 3},
485     /*data=*/
486     {
487         0,  0,  0, // tensor([[ 0,  0,  0],
488         4,  0,  0, //         [ 4,  0,  0],
489         7,  8,  0, //         [ 7,  8,  0]])
490     }
491   );
492     // clang-format on
493 
494     EXPECT_TENSOR_EQ(out1, result1);
495 
496     // Case: diag = 2
497 
498     Tensor out2 = tf.zeros({3, 3});
499 
500     op_tril_out(self, -2, out2);
501 
502     // clang-format off
503   Tensor result2 = tf.make(
504     /*sizes=*/{3, 3},
505     /*data=*/
506     {
507         0,  0,  0, // tensor([[ 0,  0,  0],
508         0,  0,  0, //         [ 0,  0,  0],
509         7,  0,  0, //         [ 7,  0,  0]])
510     }
511   );
512     // clang-format on
513 
514     EXPECT_TENSOR_EQ(out2, result2);
515 
516     // Case: diag = 10
517 
518     Tensor out3 = tf.zeros({3, 3});
519 
520     op_tril_out(self, -10, out3);
521 
522     // clang-format off
523   Tensor result3 = tf.make(
524     /*sizes=*/{3, 3},
525     /*data=*/
526     {
527         0,  0,  0, // tensor([[ 0,  0,  0],
528         0,  0,  0, //         [ 0,  0,  0],
529         0,  0,  0, //         [ 0,  0,  0]])
530     }
531   );
532     // clang-format on
533 
534     EXPECT_TENSOR_EQ(out3, result3);
535   }
536 
537   // Assert `tril` works on a batch of tensors with random integers, where dims
538   // are equal.
539   template <ScalarType DTYPE>
test_tril_out_randint_multi_equal()540   void test_tril_out_randint_multi_equal() {
541     TensorFactory<DTYPE> tf;
542 
543     // clang-format off
544   Tensor self = tf.make(
545     /*sizes=*/{3, 3, 3, 3},
546     /*data=*/
547     {
548         9,  5,  4, // tensor([[[[ 9,  5,  4],
549         3,  9,  6, //           [ 3,  9,  6],
550         9,  9,  5, //           [ 9,  9,  5]],
551 
552         7,  2,  6, //          [[ 7,  2,  6],
553         8,  5,  5, //           [ 8,  5,  5],
554         9,  3,  9, //           [ 9,  3,  9]],
555 
556         1,  2,  1, //          [[ 1,  2,  1],
557         6,  2,  6, //           [ 6,  2,  6],
558         1,  1,  8, //           [ 1,  1,  8]]],
559 
560         3,  2,  5, //         [[[ 3,  2,  5],
561         4,  4,  1, //           [ 4,  4,  1],
562         7,  1,  1, //           [ 7,  1,  1]],
563 
564         5,  7,  8, //          [[ 5,  7,  8],
565         1,  5,  7, //           [ 1,  5,  7],
566         7,  6,  3, //           [ 7,  6,  3]]],
567 
568         3,  5,  9, //          [[ 3,  5,  9],
569         4,  2,  2, //           [ 4,  2,  2],
570         9,  5,  2, //           [ 9,  5,  2]]],
571 
572         8,  4,  7, //         [[[ 8,  4,  7],
573         8,  7,  5, //           [ 8,  7,  5],
574         7,  3,  8, //           [ 7,  3,  8]],
575 
576         9,  5,  5, //          [[ 9,  5,  5],
577         6,  1,  8, //           [ 6,  1,  8],
578         8,  9,  7, //           [ 8,  9,  7]]],
579 
580         1,  2,  3, //          [[ 1,  2,  3],
581         7,  9,  1, //           [ 7,  9,  1],
582         5,  2,  2, //           [ 5,  2,  2]]]])
583     }
584   );
585     // clang-format on
586 
587     Tensor out = tf.zeros({3, 3, 3, 3});
588 
589     op_tril_out(self, 0, out);
590 
591     // clang-format off
592   Tensor result = tf.make(
593     /*sizes=*/{3, 3, 3, 3},
594     /*data=*/
595     {
596         9,  0,  0, // tensor([[[[ 9,  0,  0],
597         3,  9,  0, //           [ 3,  9,  0],
598         9,  9,  5, //           [ 9,  9,  5]],
599 
600         7,  0,  0, //          [[ 7,  0,  0],
601         8,  5,  0, //           [ 8,  5,  0],
602         9,  3,  9, //           [ 9,  3,  9]],
603 
604         1,  0,  0, //          [[ 1,  0,  0],
605         6,  2,  0, //           [ 6,  2,  0],
606         1,  1,  8, //           [ 1,  1,  8]]],
607 
608         3,  0,  0, //         [[[ 3,  0,  0],
609         4,  4,  0, //           [ 4,  4,  0],
610         7,  1,  1, //           [ 7,  1,  1]],
611 
612         5,  0,  0, //          [[ 5,  0,  0],
613         1,  5,  0, //           [ 1,  5,  0],
614         7,  6,  3, //           [ 7,  6,  3]]],
615 
616         3,  0,  0, //          [[ 3,  0,  0],
617         4,  2,  0, //           [ 4,  2,  0],
618         9,  5,  2, //           [ 9,  5,  2]]],
619 
620         8,  0,  0, //         [[[ 8,  0,  0],
621         8,  7,  0, //           [ 8,  7,  0],
622         7,  3,  8, //           [ 7,  3,  8]],
623 
624         9,  0,  0, //          [[ 9,  0,  0],
625         6,  1,  0, //           [ 6,  1,  0],
626         8,  9,  7, //           [ 8,  9,  7]]],
627 
628         1,  0,  0, //          [[ 1,  0,  0],
629         7,  9,  0, //           [ 7,  9,  0],
630         5,  2,  2, //           [ 5,  2,  2]]]])
631     }
632   );
633     // clang-format on
634 
635     EXPECT_TENSOR_EQ(out, result);
636   }
637 
638   // Assert `tril` works on a batch of tensors with random integers, where dims
639   // are unequal.
640   template <ScalarType DTYPE>
test_tril_out_randint_multi_unequal()641   void test_tril_out_randint_multi_unequal() {
642     TensorFactory<DTYPE> tf;
643 
644     // clang-format off
645   Tensor self = tf.make(
646     /*sizes=*/{3, 2, 3, 2},
647     /*data=*/
648     {
649         1,  1, // tensor([[[[ 1,  1],
650         1,  1, //           [ 1,  1],
651         9,  1, //           [ 9,  1]],
652 
653         1,  6, //          [[ 1,  6],
654         6,  2, //           [ 6,  2],
655         7,  2, //           [ 7,  2]],
656 
657         2,  4, //         [[[ 2,  4],
658         8,  3, //           [ 8,  3],
659         4,  2, //           [ 4,  2]]],
660 
661         7,  6, //          [[ 7,  6],
662         1,  8, //           [ 1,  8],
663         4,  3, //           [ 4,  3]],
664 
665         2,  2, //         [[[ 2,  2],
666         7,  4, //           [ 7,  4],
667         3,  7, //           [ 3,  7]]],
668 
669         7,  8, //          [[ 7,  8],
670         4,  9, //           [ 4,  9],
671         1,  6, //           [ 1,  6]]]])
672     }
673   );
674     // clang-format on
675 
676     Tensor out = tf.zeros({3, 2, 3, 2});
677 
678     op_tril_out(self, 0, out);
679 
680     // clang-format off
681   Tensor result = tf.make(
682     /*sizes=*/{3, 2, 3, 2},
683     /*data=*/
684     {
685         1,  0, // tensor([[[[ 1,  0],
686         1,  1, //           [ 1,  1],
687         9,  1, //           [ 9,  1]],
688 
689         1,  0, //          [[ 1,  0],
690         6,  2, //           [ 6,  2],
691         7,  2, //           [ 7,  2]],
692 
693         2,  0, //         [[[ 2,  0],
694         8,  3, //           [ 8,  3],
695         4,  2, //           [ 4,  2]]],
696 
697         7,  0, //          [[ 7,  0],
698         1,  8, //           [ 1,  8],
699         4,  3, //           [ 4,  3]],
700 
701         2,  0, //         [[[ 2,  0],
702         7,  4, //           [ 7,  4],
703         3,  7, //           [ 3,  7]]],
704 
705         7,  0, //          [[ 7,  0],
706         4,  9, //           [ 4,  9],
707         1,  6, //           [ 1,  6]]]])
708     }
709   );
710     // clang-format on
711 
712     EXPECT_TENSOR_EQ(out, result);
713   }
714 };
715 
716 // Create generic tests for all dtypes. Tensors contain 0s or 1s.
717 #define GENERATE_GENERIC_TEST(_, DTYPE)                   \
718   TEST_F(OpTrilTest, DTYPE##GenericTest) {                \
719     test_tril_out_zeros<ScalarType::DTYPE>();             \
720     test_tril_out_ones<ScalarType::DTYPE>();              \
721     test_tril_out_empty_dims<ScalarType::DTYPE>();        \
722     test_tril_out_square<ScalarType::DTYPE>();            \
723     test_tril_out_rectangle<ScalarType::DTYPE>();         \
724     test_tril_out_pos_diag<ScalarType::DTYPE>();          \
725     test_tril_out_neg_diag<ScalarType::DTYPE>();          \
726     test_tril_out_multi_equal_dim<ScalarType::DTYPE>();   \
727     test_tril_out_multi_unequal_dim<ScalarType::DTYPE>(); \
728   }
729 
ET_FORALL_REAL_TYPES_AND(Bool,GENERATE_GENERIC_TEST)730 ET_FORALL_REAL_TYPES_AND(Bool, GENERATE_GENERIC_TEST)
731 
732 // Create generic tests for real dtypes. Tensors have diverse values.
733 #define GENERATE_REAL_TEST(_, DTYPE)                          \
734   TEST_F(OpTrilTest, DTYPE##RealTest) {                       \
735     test_tril_out_arange_pos_diag<ScalarType::DTYPE>();       \
736     test_tril_out_arange_neg_diag<ScalarType::DTYPE>();       \
737     test_tril_out_randint_multi_equal<ScalarType::DTYPE>();   \
738     test_tril_out_randint_multi_unequal<ScalarType::DTYPE>(); \
739   }
740 
741 ET_FORALL_REAL_TYPES(GENERATE_REAL_TEST)
742 
743 TEST_F(OpTrilTest, InvalidInputShapesDies) {
744   TensorFactory<ScalarType::Int> tf;
745 
746   // `self` and `out` invalid shapes: ndims = 0 is <2.
747   Tensor self1 = tf.zeros({});
748   Tensor out1 = tf.zeros({});
749 
750   // Assert `out` can't be filled due to incompatible shapes.
751   ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self1, 0, out1));
752 
753   // `self` and `out` invalid shapes: ndims = 1 is <2.
754   Tensor self2 = tf.zeros({1});
755   Tensor out2 = tf.zeros({1});
756 
757   // Assert `out` can't be filled due to incompatible shapes.
758   ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self2, 0, out2));
759 }
760 
TEST_F(OpTrilTest,MismatchedOutputShapesDies)761 TEST_F(OpTrilTest, MismatchedOutputShapesDies) {
762   // Skip ATen test since it supports `self` and `out` having different shapes.
763   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
764     GTEST_SKIP() << "ATen kernel can handle mismatched output shape";
765   }
766 
767   TensorFactory<ScalarType::Int> tf;
768 
769   // `self` and `out` have different shapes but same dtype.
770   Tensor self = tf.zeros({2, 1});
771   Tensor out = tf.zeros({2, 2});
772 
773   // Assert `out` can't be filled due to incompatible shapes.
774   ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self, 0, out));
775 }
776 
TEST_F(OpTrilTest,MismatchedOutputDtypeDies)777 TEST_F(OpTrilTest, MismatchedOutputDtypeDies) {
778   TensorFactory<ScalarType::Byte> tf_byte;
779   TensorFactory<ScalarType::Float> tf_float;
780 
781   // `self` and `out` have different dtypes but same shape.
782   Tensor self = tf_byte.zeros({2, 2});
783   Tensor out = tf_float.zeros({2, 2});
784 
785   // Assert `out` can't be filled due to incompatible dtype.
786   ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self, 0, out));
787 }
788 
TEST_F(OpTrilTest,InvalidTensorDims)789 TEST_F(OpTrilTest, InvalidTensorDims) {
790   // Skip ATen test since it supports `self` and `out` having different shapes.
791   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
792     GTEST_SKIP() << "ATen kernel can handle mismatched output shape";
793   }
794 
795   TensorFactory<ScalarType::Int> tf;
796 
797   // Create `self` and `out` with 25 dims.
798   std::vector<int32_t> sizes(25, 1);
799   Tensor self = tf.zeros(sizes);
800   Tensor out = tf.zeros(sizes);
801 
802   // Assert `out` can't be filled due to too many tensor dims.
803   ET_EXPECT_KERNEL_FAILURE(context_, op_tril_out(self, 0, out));
804 }
805