xref: /aosp_15_r20/hardware/interfaces/health/utils/libhealthloop/filterPowerSupplyEventsTest.cpp (revision 4d7e907c777eeecc4c5bd7cf640a754fac206ff7)
1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <android-base/unique_fd.h>
18 #include <bpf/libbpf.h>
19 #include <gtest/gtest.h>
20 #include <linux/bpf.h>  // SO_ATTACH_BPF
21 #include <linux/netlink.h>
22 #include <netinet/in.h>
23 #include <stdio.h>
24 #include <sys/socket.h>
25 #include <string>
26 #include <string_view>
27 
// Asserts that |e|, the result of a UNIX-style call, is non-negative. On
// failure the strerror() text for the current errno is appended to the
// assertion message.
#define ASSERT_UNIX_OK(e) ASSERT_GE(e, 0) << strerror(errno)

// TODO(bvanassche): remove the code below. See also b/357099095.
// Fallback for the case where the included headers do not define
// SO_ATTACH_BPF.
#if !defined(SO_ATTACH_BPF)
#define SO_ATTACH_BPF 50  // From <asm-generic/socket.h>.
#endif
34 
35 using ::android::base::unique_fd;
36 using ::testing::ScopedTrace;
37 
// One case of the parameterized filter test.
struct test_data {
    // Expected outcome: true when no bytes are expected to be readable from
    // the receiving socket after the payload has been sent.
    bool discarded;
    // Payload that the test sends after a zero-filled netlink message header.
    // The view may contain embedded NUL bytes.
    std::string_view str;
};
42 
43 static const uint8_t binary_bpf_prog[] = {
44 #include "filterPowerSupplyEvents.h"
45 };
46 
47 static std::vector<std::unique_ptr<ScopedTrace>>* msg_vec;
48 
operator <<(std::ostream & os,const test_data & td)49 std::ostream& operator<<(std::ostream& os, const test_data& td) {
50     os << "{.discarded=" << td.discarded << ", .str=";
51     for (auto c : td.str) {
52         if (isprint(c)) {
53             os << c;
54         } else {
55             os << ".";
56         }
57     }
58     return os << '}';
59 }
60 
// Formats a printf-style message and appends it to |msg_vec| as a ScopedTrace
// tagged with the file and line of the macro invocation. Nothing is recorded
// when asprintf() fails.
#define RECORD_ERR_MSG(fmt, ...)                                                \
    do {                                                                        \
        char* formatted_msg;                                                    \
        if (asprintf(&formatted_msg, fmt, ##__VA_ARGS__) >= 0) {                \
            msg_vec->push_back(std::make_unique<ScopedTrace>(__FILE__, __LINE__, \
                                                             formatted_msg));   \
            free(formatted_msg);                                                \
        }                                                                       \
    } while (0)
69 
libbpf_print_fn(enum libbpf_print_level,const char * fmt,va_list args)70 int libbpf_print_fn(enum libbpf_print_level, const char* fmt, va_list args) {
71     char* str;
72     if (vasprintf(&str, fmt, args) < 0) {
73         return 0;
74     }
75     msg_vec->emplace_back(std::make_unique<ScopedTrace>(__FILE__, -1, str));
76     free(str);
77     return 0;
78 }
79 
record_libbpf_output()80 static void record_libbpf_output() {
81     libbpf_set_print(libbpf_print_fn);
82 }
83 
84 class filterPseTest : public testing::TestWithParam<test_data> {};
85 
86 struct ConnectedSockets {
87     unique_fd write_fd;
88     unique_fd read_fd;
89 };
90 
91 // socketpair() only supports AF_UNIX sockets. AF_UNIX sockets do not
92 // support BPF filters. Hence connect two TCP sockets with each other.
ConnectSockets(int domain,int type,int protocol)93 static ConnectedSockets ConnectSockets(int domain, int type, int protocol) {
94     int _server_fd = socket(domain, type, protocol);
95     if (_server_fd < 0) {
96         return {};
97     }
98     unique_fd server_fd(_server_fd);
99 
100     int _write_fd = socket(domain, type, protocol);
101     if (_write_fd < 0) {
102         RECORD_ERR_MSG("socket: %s", strerror(errno));
103         return {};
104     }
105     unique_fd write_fd(_write_fd);
106 
107     struct sockaddr_in sa = {.sin_family = AF_INET, .sin_addr.s_addr = INADDR_ANY};
108     if (bind(_server_fd, (const struct sockaddr*)&sa, sizeof(sa)) < 0) {
109         RECORD_ERR_MSG("bind: %s", strerror(errno));
110         return {};
111     }
112     if (listen(_server_fd, 1) < 0) {
113         RECORD_ERR_MSG("listen: %s", strerror(errno));
114         return {};
115     }
116     socklen_t addr_len = sizeof(sa);
117     if (getsockname(_server_fd, (struct sockaddr*)&sa, &addr_len) < 0) {
118         RECORD_ERR_MSG("getsockname: %s", strerror(errno));
119         return {};
120     }
121     errno = 0;
122     if (connect(_write_fd, (const struct sockaddr*)&sa, sizeof(sa)) < 0 && errno != EINPROGRESS) {
123         RECORD_ERR_MSG("connect: %s", strerror(errno));
124         return {};
125     }
126     int _read_fd = accept(_server_fd, NULL, NULL);
127     if (_read_fd < 0) {
128         RECORD_ERR_MSG("accept: %s", strerror(errno));
129         return {};
130     }
131     unique_fd read_fd(_read_fd);
132 
133     return {.write_fd = std::move(write_fd), .read_fd = std::move(read_fd)};
134 }
135 
// Sends one parameterized payload, prefixed with a zero-filled netlink message
// header, over a TCP connection whose receiving socket has the
// "filterPowerSupplyEvents" BPF program attached as a socket filter. The test
// then checks that the whole message can be read back when param.discarded is
// false and that nothing can be read when it is true.
TEST_P(filterPseTest, filterPse) {
    if (getuid() != 0) {
        // GTEST_SKIP() returns from the test body by itself.
        GTEST_SKIP() << "Must be run as root.";
    }
    if (!msg_vec) {
        msg_vec = new std::vector<std::unique_ptr<ScopedTrace>>;
    }
    // Drops all recorded traces when the test body is left, including the
    // early returns performed by ASSERT_*().
    struct MsgVecClearer {
        ~MsgVecClearer() { msg_vec->clear(); }
    } clear_msg_vec_at_end_of_scope;
    record_libbpf_output();

    auto connected_sockets = ConnectSockets(AF_INET, SOCK_STREAM, 0);
    unique_fd write_fd = std::move(connected_sockets.write_fd);
    unique_fd read_fd = std::move(connected_sockets.read_fd);
    // ConnectSockets() returns an empty result on failure; stop here instead
    // of failing later on an invalid descriptor.
    ASSERT_GE(write_fd.get(), 0) << "ConnectSockets() failed";
    ASSERT_GE(read_fd.get(), 0) << "ConnectSockets() failed";

    ASSERT_UNIX_OK(fcntl(read_fd, F_SETFL, O_NONBLOCK));

    // Own the BPF object so that it is also closed when an ASSERT_*() below
    // returns early. Declared after the guard above, so it is closed before
    // the recorded traces are cleared.
    std::unique_ptr<bpf_object, decltype(&bpf_object__close)> obj(
            bpf_object__open_mem(binary_bpf_prog, sizeof(binary_bpf_prog), nullptr),
            &bpf_object__close);
    ASSERT_TRUE(obj != nullptr) << "bpf_object__open_mem() failed: " << strerror(errno);

    // Find the BPF program within the object.
    bpf_program* prog = bpf_object__find_program_by_name(obj.get(), "filterPowerSupplyEvents");
    ASSERT_TRUE(prog != nullptr);

    ASSERT_UNIX_OK(bpf_program__set_type(prog, BPF_PROG_TYPE_SOCKET_FILTER));

    ASSERT_UNIX_OK(bpf_object__load(obj.get()));

    const int filter_fd = bpf_program__fd(prog);
    ASSERT_UNIX_OK(filter_fd);

    const int setsockopt_result =
            setsockopt(read_fd, SOL_SOCKET, SO_ATTACH_BPF, &filter_fd, sizeof(filter_fd));
    ASSERT_UNIX_OK(setsockopt_result);

    const test_data param = GetParam();
    const std::string header(sizeof(struct nlmsghdr), '\0');
    ASSERT_EQ(header.length(), sizeof(struct nlmsghdr));
    const std::string data = header + std::string(param.str);
    const size_t len = data.length();
    // NOTE(review): this dumps the raw message bytes to stderr followed by a
    // ")" that has no matching "(" -- looks like leftover debug output; kept
    // unchanged, confirm whether it is still wanted.
    std::cerr.write(data.data(), data.length());
    std::cerr << ")\n";
    ASSERT_EQ(write(write_fd, data.data(), len), static_cast<ssize_t>(len));

    // NOTE(review): the read is non-blocking and directly follows the write,
    // so EAGAIN is interpreted as "the filter discarded the message". This
    // presumably relies on the data already having been delivered -- confirm.
    std::array<uint8_t, 512> read_buf;
    const ssize_t read_result = read(read_fd, read_buf.data(), read_buf.size());
    size_t bytes_read = 0;
    if (read_result < 0) {
        ASSERT_EQ(errno, EAGAIN);
    } else {
        bytes_read = static_cast<size_t>(read_result);
        // A completely filled buffer would mean the message may be truncated.
        ASSERT_LT(bytes_read, read_buf.size());
    }
    EXPECT_EQ(bytes_read, param.discarded ? 0 : len);
}
194 
// Payloads for the parameterized test. They are kept as char arrays because
// several of them contain embedded NUL bytes; the test instantiation derives
// each length from sizeof() so that those bytes are part of the payload.

// Single character, no NUL and no "SUBSYSTEM=" key.
static constexpr char input0[] = "a";
// "SUBSYSTEM=block" preceded by other text and followed by a NUL.
static constexpr char input1[] = "abc\0SUBSYSTEM=block\0";
// "SUBSYSTEM=block" directly after a NUL, without a trailing NUL.
static constexpr char input2[] = "\0SUBSYSTEM=block";
// "SUBSYSTEM=power_supply" without a trailing NUL.
static constexpr char input3[] = "\0SUBSYSTEM=power_supply";
// "SUBSYSTEM=power_supply" followed by a NUL.
static constexpr char input4[] = "\0SUBSYSTEM=power_supply\0";
// "SUBSYSTEM=block" placed after a long run of digits.
static constexpr char input5[] =
        "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
        "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
        "012345678901234567890123456789012345678901234567890123456789\0SUBSYSTEM=block\0";
204 
205 INSTANTIATE_TEST_SUITE_P(
206         filterPse, filterPseTest,
207         testing::Values(test_data{false, std::string_view(input0, sizeof(input0) - 1)},
208                         test_data{true, std::string_view(input1, sizeof(input1) - 1)},
209                         test_data{true, std::string_view(input2, sizeof(input2) - 1)},
210                         test_data{true, std::string_view(input3, sizeof(input3) - 1)},
211                         test_data{false, std::string_view(input4, sizeof(input4) - 1)},
212                         test_data{false, std::string_view(input5, sizeof(input5) - 1)}));
213