xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/op_tile_crop_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/extension/llm/custom_ops/op_tile_crop.h>
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
13 #include <gtest/gtest.h>
14 
15 using namespace ::testing;
16 using exec_aten::ScalarType;
17 using exec_aten::Tensor;
18 using executorch::runtime::testing::TensorFactory;
19 
20 class OpTileCropOutTest : public OperatorTest {
21  protected:
op_tile_crop_out(const Tensor & self,int64_t tile_size,Tensor & out)22   Tensor& op_tile_crop_out(const Tensor& self, int64_t tile_size, Tensor& out) {
23     return torch::executor::native::tile_crop_out_impl(
24         context_, self, tile_size, out);
25   }
26 
27   template <ScalarType DTYPE_IN>
test_tile_crop()28   void test_tile_crop() {
29     TensorFactory<DTYPE_IN> tf_in;
30 
31     const std::vector<int32_t> sizes = {1, 4, 4};
32     const std::vector<int32_t> out_sizes = {4, 1, 2, 2};
33 
34     Tensor out = tf_in.zeros(out_sizes);
35 
36     // clang-format off
37     op_tile_crop_out(
38         tf_in.make(
39             sizes, { 0,  1,  2,  3,
40                      4,  5,  6,  7,
41                      8,  9, 10, 11,
42                     12, 13, 14, 15}),
43         2,
44         out);
45     EXPECT_TENSOR_EQ(
46         out,
47         tf_in.make(
48             out_sizes, {0,  1,  4,  5,
49                         2,  3,  6,  7,
50                         8,  9, 12, 13,
51                        10, 11, 14, 15}));
52     // clang-format on
53   }
54 };
55 
//
// Correctness Tests
//

/**
 * Uses the test_tile_crop<> member template above to test all real input
 * dtypes.
 */
TEST_F(OpTileCropOutTest, AllRealDtypesSupported) {
  // Instantiate and run the fixture's dtype check once per entry of
  // ET_FORALL_REAL_TYPES; the C++ element type is not needed here.
#define RUN_TILE_CROP_FOR_DTYPE(unused_ctype, dtype_name) \
  test_tile_crop<ScalarType::dtype_name>();
  ET_FORALL_REAL_TYPES(RUN_TILE_CROP_FOR_DTYPE)
#undef RUN_TILE_CROP_FOR_DTYPE
}
68 
// Invalid-argument tests (input shape, input rank, dtype, tile size).
// Input height not divisible by tile_size.
TEST_F(OpTileCropOutTest, InvalidInputShapeDies) {
  TensorFactory<ScalarType::Int> tf;

  // Height 7 is not divisible by tile_size 2. The output is sized as it
  // would be for an 8x8 input: (8 / 2) * (8 / 2) = 16 tiles of {1, 2, 2}.
  Tensor in = tf.ones(/*sizes=*/{1, 7, 8});
  Tensor out = tf.zeros(/*sizes=*/{16, 1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(in, 2, out));
}

// Input width not divisible by tile_size. This is a separate test so that
// each test checks exactly one failing argument.
TEST_F(OpTileCropOutTest, InvalidInputWidthDies) {
  TensorFactory<ScalarType::Int> tf;

  // Width 7 is not divisible by tile_size 2; output sized as for an 8x8 input.
  Tensor in = tf.ones(/*sizes=*/{1, 8, 7});
  Tensor out = tf.zeros(/*sizes=*/{16, 1, 2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(in, 2, out));
}
79 
TEST_F(OpTileCropOutTest, WrongInputRankDies) {
  TensorFactory<ScalarType::Int> tf_int;

  // The kernel expects a rank-3 input; a rank-2 {1, 2} input (paired with an
  // output of the same shape) must therefore be rejected.
  const Tensor rank2_in = tf_int.ones(/*sizes=*/{1, 2});
  Tensor rank2_out = tf_int.zeros(/*sizes=*/{1, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_tile_crop_out(rank2_in, /*tile_size=*/2, rank2_out));
}
89 
TEST_F(OpTileCropOutTest, DifferentDtypeDies) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Float> tf_float;

  // 2 channels of 12x12 with tile_size 4 -> (12 / 4) * (12 / 4) = 9 tiles of
  // {2, 4, 4}, so {9, 2, 4, 4} is exactly the right output shape. That leaves
  // dtype as the only invalid argument. A shape mismatch here would let the
  // test pass without ever reaching the dtype check.
  Tensor in = tf.ones(/*sizes=*/{2, 12, 12});

  // Tile crop requires two tensors with the same dtype.
  Tensor out = tf_float.zeros(/*sizes=*/{9, 2, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(in, 4, out));
}
101 
TEST_F(OpTileCropOutTest, NegativeTileSizeDies) {
  TensorFactory<ScalarType::Int> tf;
  Tensor in = tf.ones(/*sizes=*/{2, 12, 12});
  Tensor out = tf.zeros(/*sizes=*/{9, 2, 4, 4});

  // A negative tile size must be rejected.
  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(in, -3, out));
}

// tile_size == 0 is the boundary case. The spatial dims are required to be
// divisible by tile_size, so the kernel must reject zero rather than divide
// by it. This is a separate test so the zero case is checked in isolation.
TEST_F(OpTileCropOutTest, ZeroTileSizeDies) {
  TensorFactory<ScalarType::Int> tf;
  Tensor in = tf.ones(/*sizes=*/{2, 12, 12});
  Tensor out = tf.zeros(/*sizes=*/{9, 2, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_tile_crop_out(in, 0, out));
}
108