1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <gtest/gtest.h>

#include <cstdint>
#include <cstdlib>
#include <tuple>
17
18 using namespace ::testing;
19 using exec_aten::IntArrayRef;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using executorch::runtime::MemoryAllocator;
23 using torch::executor::testing::TensorFactory;
24
25 class TempMemoryAllocator final : public MemoryAllocator {
26 private:
27 // We allocate a little more than requested and use that memory as a node in
28 // a linked list, pushing the allocated buffers onto a list that's iterated
29 // and freed when the KernelRuntimeContext is destroyed.
30 struct AllocationNode {
31 void* data;
32 AllocationNode* next;
33 };
34
35 AllocationNode* head_ = nullptr;
36
37 public:
TempMemoryAllocator()38 TempMemoryAllocator() : MemoryAllocator(0, nullptr) {}
39
allocate(size_t size,size_t alignment=kDefaultAlignment)40 void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
41 if (!isPowerOf2(alignment)) {
42 ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
43 return nullptr;
44 }
45
46 // Allocate enough memory for the node, the data and the alignment bump.
47 size_t alloc_size = sizeof(AllocationNode) + size + alignment;
48 void* node_memory = malloc(alloc_size);
49
50 // If allocation failed, log message and return nullptr.
51 if (node_memory == nullptr) {
52 ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size);
53 return nullptr;
54 }
55
56 // Compute data pointer.
57 uint8_t* data_ptr =
58 reinterpret_cast<uint8_t*>(node_memory) + sizeof(AllocationNode);
59
60 // Align the data pointer.
61 void* aligned_data_ptr = alignPointer(data_ptr, alignment);
62
63 // Assert that the alignment didn't overflow the allocated memory.
64 ET_DCHECK_MSG(
65 reinterpret_cast<uintptr_t>(aligned_data_ptr) + size <=
66 reinterpret_cast<uintptr_t>(node_memory) + alloc_size,
67 "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu",
68 aligned_data_ptr,
69 size,
70 node_memory,
71 alloc_size);
72
73 // Construct the node.
74 AllocationNode* new_node = reinterpret_cast<AllocationNode*>(node_memory);
75 new_node->data = aligned_data_ptr;
76 new_node->next = head_;
77 head_ = new_node;
78
79 // Return the aligned data pointer.
80 return head_->data;
81 }
82
reset()83 void reset() override {
84 AllocationNode* current = head_;
85 while (current != nullptr) {
86 AllocationNode* next = current->next;
87 free(current);
88 current = next;
89 }
90 head_ = nullptr;
91 }
92
~TempMemoryAllocator()93 ~TempMemoryAllocator() override {
94 reset();
95 }
96 };
97
op_topk_values(const Tensor & input,int64_t k,int64_t dim,bool largest,bool sorted,Tensor & values,Tensor & indices)98 std::tuple<Tensor&, Tensor&> op_topk_values(
99 const Tensor& input,
100 int64_t k,
101 int64_t dim,
102 bool largest,
103 bool sorted,
104 Tensor& values,
105 Tensor& indices) {
106 TempMemoryAllocator allocator = TempMemoryAllocator();
107 executorch::runtime::KernelRuntimeContext context(nullptr, &allocator);
108 return torch::executor::aten::topk_outf(
109 context, input, k, dim, largest, sorted, values, indices);
110 }
111
112 class OpTopkValuesTest : public ::testing::Test {
113 protected:
SetUp()114 void SetUp() override {
115 // Since these tests cause ET_LOG to be called, the PAL must be initialized
116 // first.
117 torch::executor::runtime_init();
118 }
119 };
120
// Runs topk with k = 2 along dim 0 of a 3x2x2 float tensor holding 1..12.
// Along dim 0 the slices are {1..4}, {5..8} and {9..12}, so the two largest
// (sorted, largest first) are slice 2 then slice 1.
TEST_F(OpTopkValuesTest, SmokeTest) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  const Tensor input =
      tf_float.make({3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

  // Output buffers that the kernel writes into.
  Tensor out_values = tf_float.zeros({2, 2, 2});
  Tensor out_indices = tf_long.zeros({2, 2, 2});

  const Tensor expected_values =
      tf_float.make({2, 2, 2}, {9, 10, 11, 12, 5, 6, 7, 8});
  const Tensor expected_indices =
      tf_long.make({2, 2, 2}, {2, 2, 2, 2, 1, 1, 1, 1});

  op_topk_values(
      input,
      /*k=*/2,
      /*dim=*/0,
      /*largest=*/true,
      /*sorted=*/true,
      out_values,
      out_indices);

  EXPECT_TENSOR_CLOSE(out_values, expected_values);
  EXPECT_TENSOR_EQ(out_indices, expected_indices);
}
139