1 /*
2 * Copyright (C) 2024 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <android-base/unique_fd.h>
#include <arpa/inet.h>
#include <bpf/libbpf.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <gtest/gtest.h>
#include <linux/bpf.h> // SO_ATTACH_BPF
#include <linux/netlink.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
#include <array>
#include <memory>
#include <ostream>
#include <string>
#include <string_view>
#include <vector>
27
// Asserts that a POSIX-style call returned a non-negative value; on failure the
// gtest message is followed by strerror(errno). `e` is passed to ASSERT_GE
// unparenthesized, so gtest's failure text shows the expression exactly as written.
#define ASSERT_UNIX_OK(e) ASSERT_GE(e, 0) << strerror(errno)

// TODO(bvanassche): remove the code below. See also b/357099095.
// Fallback value, used only when none of the included headers define SO_ATTACH_BPF.
#ifndef SO_ATTACH_BPF
#define SO_ATTACH_BPF 50 // From <asm-generic/socket.h>.
#endif

using ::android::base::unique_fd;
using ::testing::ScopedTrace;
37
// One parameterized case of filterPseTest.
struct test_data {
    // Expected outcome: true if the BPF filter should drop the message, i.e. the
    // test expects to read 0 bytes from the receiving socket.
    bool discarded;
    // Payload written after an all-zero struct nlmsghdr; may contain embedded NULs.
    std::string_view str;
};
42
// BPF object bytes handed to bpf_object__open_mem() by the filterPse test.
// The included header must expand to a comma-separated list of byte values;
// presumably it is generated by the build from the filterPowerSupplyEvents BPF
// program source (TODO confirm).
static const uint8_t binary_bpf_prog[] = {
#include "filterPowerSupplyEvents.h"
};
46
47 static std::vector<std::unique_ptr<ScopedTrace>>* msg_vec;
48
operator <<(std::ostream & os,const test_data & td)49 std::ostream& operator<<(std::ostream& os, const test_data& td) {
50 os << "{.discarded=" << td.discarded << ", .str=";
51 for (auto c : td.str) {
52 if (isprint(c)) {
53 os << c;
54 } else {
55 os << ".";
56 }
57 }
58 return os << '}';
59 }
60
// Formats a printf-style message and appends it to msg_vec as a gtest ScopedTrace
// tagged with the caller's __FILE__/__LINE__, so it shows up in the output of any
// later assertion failure until msg_vec is cleared. msg_vec must already be
// allocated. If formatting fails, the message is silently dropped.
#define RECORD_ERR_MSG(fmt, ...)                                                      \
    do {                                                                              \
        char* record_err_msg_text;                                                    \
        if (asprintf(&record_err_msg_text, fmt, ##__VA_ARGS__) >= 0) {                \
            msg_vec->push_back(std::make_unique<ScopedTrace>(__FILE__, __LINE__,      \
                                                             record_err_msg_text));   \
            free(record_err_msg_text);                                                \
        }                                                                             \
    } while (0)
69
libbpf_print_fn(enum libbpf_print_level,const char * fmt,va_list args)70 int libbpf_print_fn(enum libbpf_print_level, const char* fmt, va_list args) {
71 char* str;
72 if (vasprintf(&str, fmt, args) < 0) {
73 return 0;
74 }
75 msg_vec->emplace_back(std::make_unique<ScopedTrace>(__FILE__, -1, str));
76 free(str);
77 return 0;
78 }
79
record_libbpf_output()80 static void record_libbpf_output() {
81 libbpf_set_print(libbpf_print_fn);
82 }
83
84 class filterPseTest : public testing::TestWithParam<test_data> {};
85
86 struct ConnectedSockets {
87 unique_fd write_fd;
88 unique_fd read_fd;
89 };
90
91 // socketpair() only supports AF_UNIX sockets. AF_UNIX sockets do not
92 // support BPF filters. Hence connect two TCP sockets with each other.
ConnectSockets(int domain,int type,int protocol)93 static ConnectedSockets ConnectSockets(int domain, int type, int protocol) {
94 int _server_fd = socket(domain, type, protocol);
95 if (_server_fd < 0) {
96 return {};
97 }
98 unique_fd server_fd(_server_fd);
99
100 int _write_fd = socket(domain, type, protocol);
101 if (_write_fd < 0) {
102 RECORD_ERR_MSG("socket: %s", strerror(errno));
103 return {};
104 }
105 unique_fd write_fd(_write_fd);
106
107 struct sockaddr_in sa = {.sin_family = AF_INET, .sin_addr.s_addr = INADDR_ANY};
108 if (bind(_server_fd, (const struct sockaddr*)&sa, sizeof(sa)) < 0) {
109 RECORD_ERR_MSG("bind: %s", strerror(errno));
110 return {};
111 }
112 if (listen(_server_fd, 1) < 0) {
113 RECORD_ERR_MSG("listen: %s", strerror(errno));
114 return {};
115 }
116 socklen_t addr_len = sizeof(sa);
117 if (getsockname(_server_fd, (struct sockaddr*)&sa, &addr_len) < 0) {
118 RECORD_ERR_MSG("getsockname: %s", strerror(errno));
119 return {};
120 }
121 errno = 0;
122 if (connect(_write_fd, (const struct sockaddr*)&sa, sizeof(sa)) < 0 && errno != EINPROGRESS) {
123 RECORD_ERR_MSG("connect: %s", strerror(errno));
124 return {};
125 }
126 int _read_fd = accept(_server_fd, NULL, NULL);
127 if (_read_fd < 0) {
128 RECORD_ERR_MSG("accept: %s", strerror(errno));
129 return {};
130 }
131 unique_fd read_fd(_read_fd);
132
133 return {.write_fd = std::move(write_fd), .read_fd = std::move(read_fd)};
134 }
135
// Loads the filterPowerSupplyEvents BPF program, attaches it with SO_ATTACH_BPF to
// the receiving end of a TCP connection and writes one message: an all-zero
// struct nlmsghdr followed by the parameterized payload. It then checks that the
// message was dropped (param.discarded) or delivered intact. Skipped unless run
// as root.
//
// NOTE(review): the single non-blocking read() assumes the local TCP connection
// delivers the data before write() returns; confirm this cannot make the
// non-discarded cases flaky.
TEST_P(filterPseTest, filterPse) {
    if (getuid() != 0) {
        GTEST_SKIP() << "Must be run as root.";
    }
    if (!msg_vec) {
        msg_vec = new std::vector<std::unique_ptr<ScopedTrace>>();
    }
    // Drop the traces recorded by RECORD_ERR_MSG() / libbpf_print_fn() when this
    // test body exits by any path, including early returns from ASSERT_*.
    struct ClearMsgVecOnExit {
        ~ClearMsgVecOnExit() { msg_vec->clear(); }
    } clear_msg_vec_at_end_of_scope;
    record_libbpf_output();

    auto connected_sockets = ConnectSockets(AF_INET, SOCK_STREAM, 0);
    unique_fd write_fd = std::move(connected_sockets.write_fd);
    unique_fd read_fd = std::move(connected_sockets.read_fd);
    // ConnectSockets() returns invalid fds on failure; the error it recorded is
    // included in the failure output of these assertions.
    ASSERT_GE(write_fd.get(), 0);
    ASSERT_GE(read_fd.get(), 0);

    ASSERT_UNIX_OK(fcntl(read_fd.get(), F_SETFL, O_NONBLOCK));

    // Owning handle: bpf_object__close() runs on every exit path, including failed
    // ASSERT_*s below (unique_ptr does not invoke the deleter for nullptr).
    std::unique_ptr<bpf_object, decltype(&bpf_object__close)> obj(
            bpf_object__open_mem(binary_bpf_prog, sizeof(binary_bpf_prog), nullptr),
            &bpf_object__close);
    ASSERT_TRUE(obj != nullptr) << "bpf_object__open_mem() failed: " << strerror(errno);

    // Find the BPF program within the object.
    bpf_program* prog = bpf_object__find_program_by_name(obj.get(), "filterPowerSupplyEvents");
    ASSERT_TRUE(prog);

    ASSERT_UNIX_OK(bpf_program__set_type(prog, BPF_PROG_TYPE_SOCKET_FILTER));

    ASSERT_UNIX_OK(bpf_object__load(obj.get()));

    int filter_fd = bpf_program__fd(prog);
    ASSERT_UNIX_OK(filter_fd);

    const int setsockopt_result =
            setsockopt(read_fd.get(), SOL_SOCKET, SO_ATTACH_BPF, &filter_fd, sizeof(filter_fd));
    ASSERT_UNIX_OK(setsockopt_result);

    const test_data param = GetParam();
    // Attach a printable rendering of the case to any failure below.
    SCOPED_TRACE(param);

    // Message = all-zero struct nlmsghdr followed by the payload (which may
    // contain embedded NULs, preserved by the string_view length).
    const std::string data = std::string(sizeof(struct nlmsghdr), '\0') + std::string(param.str);
    const size_t len = data.size();
    ASSERT_EQ(write(write_fd.get(), data.data(), len), static_cast<ssize_t>(len))
            << strerror(errno);

    std::array<uint8_t, 512> read_buf;
    ssize_t bytes_read = read(read_fd.get(), read_buf.data(), read_buf.size());
    if (bytes_read < 0) {
        // read_fd is non-blocking: an empty receive queue (message dropped by the
        // filter) is reported as EAGAIN.
        ASSERT_EQ(errno, EAGAIN) << strerror(errno);
        bytes_read = 0;
    } else {
        // A completely filled buffer could hide a truncated read.
        ASSERT_LT(static_cast<size_t>(bytes_read), read_buf.size());
    }
    EXPECT_EQ(static_cast<size_t>(bytes_read), param.discarded ? 0 : len);
}
194
195 static constexpr char input0[] = "a";
196 static constexpr char input1[] = "abc\0SUBSYSTEM=block\0";
197 static constexpr char input2[] = "\0SUBSYSTEM=block";
198 static constexpr char input3[] = "\0SUBSYSTEM=power_supply";
199 static constexpr char input4[] = "\0SUBSYSTEM=power_supply\0";
200 static constexpr char input5[] =
201 "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
202 "012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
203 "012345678901234567890123456789012345678901234567890123456789\0SUBSYSTEM=block\0";
204
205 INSTANTIATE_TEST_SUITE_P(
206 filterPse, filterPseTest,
207 testing::Values(test_data{false, std::string_view(input0, sizeof(input0) - 1)},
208 test_data{true, std::string_view(input1, sizeof(input1) - 1)},
209 test_data{true, std::string_view(input2, sizeof(input2) - 1)},
210 test_data{true, std::string_view(input3, sizeof(input3) - 1)},
211 test_data{false, std::string_view(input4, sizeof(input4) - 1)},
212 test_data{false, std::string_view(input5, sizeof(input5) - 1)}));
213