xref: /aosp_15_r20/external/executorch/kernels/test/op_topk_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <executorch/runtime/platform/runtime.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::IntArrayRef;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using executorch::runtime::MemoryAllocator;
23 using torch::executor::testing::TensorFactory;
24 
25 class TempMemoryAllocator final : public MemoryAllocator {
26  private:
27   // We allocate a little more than requested and use that memory as a node in
28   // a linked list, pushing the allocated buffers onto a list that's iterated
29   // and freed when the KernelRuntimeContext is destroyed.
30   struct AllocationNode {
31     void* data;
32     AllocationNode* next;
33   };
34 
35   AllocationNode* head_ = nullptr;
36 
37  public:
TempMemoryAllocator()38   TempMemoryAllocator() : MemoryAllocator(0, nullptr) {}
39 
allocate(size_t size,size_t alignment=kDefaultAlignment)40   void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
41     if (!isPowerOf2(alignment)) {
42       ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
43       return nullptr;
44     }
45 
46     // Allocate enough memory for the node, the data and the alignment bump.
47     size_t alloc_size = sizeof(AllocationNode) + size + alignment;
48     void* node_memory = malloc(alloc_size);
49 
50     // If allocation failed, log message and return nullptr.
51     if (node_memory == nullptr) {
52       ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size);
53       return nullptr;
54     }
55 
56     // Compute data pointer.
57     uint8_t* data_ptr =
58         reinterpret_cast<uint8_t*>(node_memory) + sizeof(AllocationNode);
59 
60     // Align the data pointer.
61     void* aligned_data_ptr = alignPointer(data_ptr, alignment);
62 
63     // Assert that the alignment didn't overflow the allocated memory.
64     ET_DCHECK_MSG(
65         reinterpret_cast<uintptr_t>(aligned_data_ptr) + size <=
66             reinterpret_cast<uintptr_t>(node_memory) + alloc_size,
67         "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu",
68         aligned_data_ptr,
69         size,
70         node_memory,
71         alloc_size);
72 
73     // Construct the node.
74     AllocationNode* new_node = reinterpret_cast<AllocationNode*>(node_memory);
75     new_node->data = aligned_data_ptr;
76     new_node->next = head_;
77     head_ = new_node;
78 
79     // Return the aligned data pointer.
80     return head_->data;
81   }
82 
reset()83   void reset() override {
84     AllocationNode* current = head_;
85     while (current != nullptr) {
86       AllocationNode* next = current->next;
87       free(current);
88       current = next;
89     }
90     head_ = nullptr;
91   }
92 
~TempMemoryAllocator()93   ~TempMemoryAllocator() override {
94     reset();
95   }
96 };
97 
op_topk_values(const Tensor & input,int64_t k,int64_t dim,bool largest,bool sorted,Tensor & values,Tensor & indices)98 std::tuple<Tensor&, Tensor&> op_topk_values(
99     const Tensor& input,
100     int64_t k,
101     int64_t dim,
102     bool largest,
103     bool sorted,
104     Tensor& values,
105     Tensor& indices) {
106   TempMemoryAllocator allocator = TempMemoryAllocator();
107   executorch::runtime::KernelRuntimeContext context(nullptr, &allocator);
108   return torch::executor::aten::topk_outf(
109       context, input, k, dim, largest, sorted, values, indices);
110 }
111 
112 class OpTopkValuesTest : public ::testing::Test {
113  protected:
SetUp()114   void SetUp() override {
115     // Since these tests cause ET_LOG to be called, the PAL must be initialized
116     // first.
117     torch::executor::runtime_init();
118   }
119 };
120 
// Takes the two largest entries, sorted, along dim 0 of a 3x2x2 tensor
// holding 1..12 and checks both the selected values and their source indices.
TEST_F(OpTopkValuesTest, SmokeTest) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> long_factory;

  // Kernel arguments.
  const int64_t k = 2;
  const int64_t dim = 0;
  const bool largest = true;
  const bool sorted = true;

  // Along dim 0 the input consists of the slices {1..4}, {5..8}, {9..12}.
  Tensor input =
      float_factory.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

  // Output tensors, pre-sized to the result shape {k, 2, 2}.
  Tensor values = float_factory.zeros({2, 2, 2});
  Tensor indices = long_factory.zeros({2, 2, 2});

  // The largest slice (index 2) comes first, followed by slice index 1.
  Tensor values_expected =
      float_factory.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
  Tensor indices_expected =
      long_factory.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});

  op_topk_values(input, k, dim, largest, sorted, values, indices);

  EXPECT_TENSOR_CLOSE(values, values_expected);
  EXPECT_TENSOR_EQ(indices, indices_expected);
}
139