1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3 * Copyright (c) 2021 Linux Test Project
4 */
5
6 #include <stdlib.h>
7 #include <limits.h>
8 #include <asm/types.h>
9 #include <linux/netlink.h>
10 #include <linux/rtnetlink.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <sys/poll.h>
14 #define TST_NO_DEFAULT_MAIN
15 #include "tst_test.h"
16 #include "tst_netlink.h"
17
/*
 * Per-socket state used to build, send and validate netlink messages.
 */
struct tst_netlink_context {
	/* netlink socket created and bound by tst_netlink_create_context() */
	int socket;
	/* value copied into nlmsg_pid of every queued message */
	pid_t pid;
	/* sequence number assigned to the next queued message */
	uint32_t seq;
	/* allocated size of buffer in bytes */
	size_t bufsize;
	/* bytes of queued (or most recently sent) messages in buffer */
	size_t datalen;
	/* outgoing message buffer */
	char *buffer;
	/* last message added to buffer; NULL when nothing is queued */
	struct nlmsghdr *curmsg;
};
26
/*
 * Error reported by the last tst_netlink_send_validate() call: reset to 0 on
 * entry, then set by tst_netlink_check_acks() to the negated error field of
 * the first NLMSG_ERROR reply carrying a non-zero error code.
 */
int tst_netlink_errno;
28
netlink_grow_buffer(const char * file,const int lineno,struct tst_netlink_context * ctx,size_t size)29 static int netlink_grow_buffer(const char *file, const int lineno,
30 struct tst_netlink_context *ctx, size_t size)
31 {
32 size_t needed, offset, curlen = NLMSG_ALIGN(ctx->datalen);
33 char *buf;
34
35 if (ctx->bufsize - curlen >= size)
36 return 1;
37
38 needed = size - (ctx->bufsize - curlen);
39 size = ctx->bufsize + (ctx->bufsize > needed ? ctx->bufsize : needed);
40 size = NLMSG_ALIGN(size);
41 buf = safe_realloc(file, lineno, ctx->buffer, size);
42
43 if (!buf)
44 return 0;
45
46 memset(buf + ctx->bufsize, 0, size - ctx->bufsize);
47 offset = ((char *)ctx->curmsg) - ctx->buffer;
48 ctx->buffer = buf;
49 ctx->curmsg = (struct nlmsghdr *)(buf + offset);
50 ctx->bufsize = size;
51
52 return 1;
53 }
54
tst_netlink_destroy_context(const char * file,const int lineno,struct tst_netlink_context * ctx)55 void tst_netlink_destroy_context(const char *file, const int lineno,
56 struct tst_netlink_context *ctx)
57 {
58 if (!ctx)
59 return;
60
61 safe_close(file, lineno, NULL, ctx->socket);
62 free(ctx->buffer);
63 free(ctx);
64 }
65
tst_netlink_create_context(const char * file,const int lineno,int protocol)66 struct tst_netlink_context *tst_netlink_create_context(const char *file,
67 const int lineno, int protocol)
68 {
69 struct tst_netlink_context *ctx;
70 struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
71
72 ctx = safe_malloc(file, lineno, NULL,
73 sizeof(struct tst_netlink_context));
74
75 if (!ctx)
76 return NULL;
77
78 ctx->pid = 0;
79 ctx->seq = 0;
80 ctx->buffer = NULL;
81 ctx->bufsize = 1024;
82 ctx->datalen = 0;
83 ctx->curmsg = NULL;
84 ctx->socket = safe_socket(file, lineno, NULL, AF_NETLINK,
85 SOCK_DGRAM | SOCK_CLOEXEC, protocol);
86
87 if (ctx->socket < 0) {
88 free(ctx);
89 return NULL;
90 }
91
92 if (safe_bind(file, lineno, NULL, ctx->socket, (struct sockaddr *)&addr,
93 sizeof(addr))) {
94 tst_netlink_destroy_context(file, lineno, ctx);
95 return NULL;
96 }
97
98 ctx->buffer = safe_malloc(file, lineno, NULL, ctx->bufsize);
99
100 if (!ctx->buffer) {
101 tst_netlink_destroy_context(file, lineno, ctx);
102 return NULL;
103 }
104
105 memset(ctx->buffer, 0, ctx->bufsize);
106
107 return ctx;
108 }
109
tst_netlink_free_message(struct tst_netlink_message * msg)110 void tst_netlink_free_message(struct tst_netlink_message *msg)
111 {
112 if (!msg)
113 return;
114
115 // all ptr->header and ptr->info pointers point to the same buffer
116 // msg->header is the start of the buffer
117 free(msg->header);
118 free(msg);
119 }
120
tst_netlink_send(const char * file,const int lineno,struct tst_netlink_context * ctx)121 int tst_netlink_send(const char *file, const int lineno,
122 struct tst_netlink_context *ctx)
123 {
124 int ret;
125 struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
126 struct iovec iov;
127 struct msghdr msg = {
128 .msg_name = &addr,
129 .msg_namelen = sizeof(addr),
130 .msg_iov = &iov,
131 .msg_iovlen = 1
132 };
133
134 if (!ctx->curmsg) {
135 tst_brk_(file, lineno, TBROK, "%s(): No message to send",
136 __func__);
137 return 0;
138 }
139
140 if (ctx->curmsg->nlmsg_flags & NLM_F_MULTI) {
141 struct nlmsghdr eom = { .nlmsg_type = NLMSG_DONE };
142
143 if (!tst_netlink_add_message(file, lineno, ctx, &eom, NULL, 0))
144 return 0;
145
146 /* NLMSG_DONE message must not have NLM_F_MULTI flag */
147 ctx->curmsg->nlmsg_flags = 0;
148 }
149
150 iov.iov_base = ctx->buffer;
151 iov.iov_len = ctx->datalen;
152 ret = safe_sendmsg(file, lineno, ctx->datalen, ctx->socket, &msg, 0);
153
154 if (ret > 0)
155 ctx->curmsg = NULL;
156
157 return ret;
158 }
159
tst_netlink_wait(struct tst_netlink_context * ctx)160 int tst_netlink_wait(struct tst_netlink_context *ctx)
161 {
162 struct pollfd fdinfo = {
163 .fd = ctx->socket,
164 .events = POLLIN
165 };
166
167 return poll(&fdinfo, 1, 1000);
168 }
169
tst_netlink_recv(const char * file,const int lineno,struct tst_netlink_context * ctx)170 struct tst_netlink_message *tst_netlink_recv(const char *file,
171 const int lineno, struct tst_netlink_context *ctx)
172 {
173 char tmp, *tmpbuf, *buffer = NULL;
174 struct tst_netlink_message *ret;
175 struct nlmsghdr *ptr;
176 size_t retsize, bufsize = 0;
177 ssize_t size;
178 int i, size_left, msgcount;
179
180 /* Each recv() call returns one message, read all pending messages */
181 while (1) {
182 errno = 0;
183 size = recv(ctx->socket, &tmp, 1,
184 MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC);
185
186 if (size < 0) {
187 if (errno != EAGAIN) {
188 tst_brk_(file, lineno, TBROK | TERRNO,
189 "recv() failed");
190 }
191
192 break;
193 }
194
195 tmpbuf = safe_realloc(file, lineno, buffer, bufsize + size);
196
197 if (!tmpbuf)
198 break;
199
200 buffer = tmpbuf;
201 size = safe_recv(file, lineno, size, ctx->socket,
202 buffer + bufsize, size, 0);
203
204 if (size < 0)
205 break;
206
207 bufsize += size;
208 }
209
210 if (!bufsize) {
211 free(buffer);
212 return NULL;
213 }
214
215 ptr = (struct nlmsghdr *)buffer;
216 size_left = bufsize;
217 msgcount = 0;
218
219 for (; size_left > 0 && NLMSG_OK(ptr, size_left); msgcount++)
220 ptr = NLMSG_NEXT(ptr, size_left);
221
222 retsize = (msgcount + 1) * sizeof(struct tst_netlink_message);
223 ret = safe_malloc(file, lineno, NULL, retsize);
224
225 if (!ret) {
226 free(buffer);
227 return NULL;
228 }
229
230 memset(ret, 0, retsize);
231 ptr = (struct nlmsghdr *)buffer;
232 size_left = bufsize;
233
234 for (i = 0; i < msgcount; i++, ptr = NLMSG_NEXT(ptr, size_left)) {
235 ret[i].header = ptr;
236 ret[i].payload = NLMSG_DATA(ptr);
237 ret[i].payload_size = NLMSG_PAYLOAD(ptr, 0);
238
239 if (ptr->nlmsg_type == NLMSG_ERROR)
240 ret[i].err = NLMSG_DATA(ptr);
241 }
242
243 return ret;
244 }
245
tst_netlink_add_message(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct nlmsghdr * header,const void * payload,size_t payload_size)246 int tst_netlink_add_message(const char *file, const int lineno,
247 struct tst_netlink_context *ctx, const struct nlmsghdr *header,
248 const void *payload, size_t payload_size)
249 {
250 size_t size;
251 unsigned int extra_flags = 0;
252
253 if (!netlink_grow_buffer(file, lineno, ctx, NLMSG_SPACE(payload_size)))
254 return 0;
255
256 if (!ctx->curmsg) {
257 /*
258 * datalen may hold the size of last sent message for ACK
259 * checking, reset it back to 0 here
260 */
261 ctx->datalen = 0;
262 ctx->curmsg = (struct nlmsghdr *)ctx->buffer;
263 } else {
264 size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
265
266 extra_flags = NLM_F_MULTI;
267 ctx->curmsg->nlmsg_flags |= extra_flags;
268 ctx->curmsg = NLMSG_NEXT(ctx->curmsg, size);
269 ctx->datalen = NLMSG_ALIGN(ctx->datalen);
270 }
271
272 *ctx->curmsg = *header;
273 ctx->curmsg->nlmsg_len = NLMSG_LENGTH(payload_size);
274 ctx->curmsg->nlmsg_flags |= extra_flags;
275 ctx->curmsg->nlmsg_seq = ctx->seq++;
276 ctx->curmsg->nlmsg_pid = ctx->pid;
277
278 if (payload_size)
279 memcpy(NLMSG_DATA(ctx->curmsg), payload, payload_size);
280
281 ctx->datalen += ctx->curmsg->nlmsg_len;
282
283 return 1;
284 }
285
tst_netlink_add_attr(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const void * data,unsigned short len)286 int tst_netlink_add_attr(const char *file, const int lineno,
287 struct tst_netlink_context *ctx, unsigned short type,
288 const void *data, unsigned short len)
289 {
290 size_t size = NLA_HDRLEN + NLA_ALIGN(len);
291 struct nlattr *attr;
292
293 if (!ctx->curmsg) {
294 tst_brk_(file, lineno, TBROK,
295 "%s(): No message to add attributes to", __func__);
296 return 0;
297 }
298
299 if (!netlink_grow_buffer(file, lineno, ctx, size))
300 return 0;
301
302 size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
303 attr = (struct nlattr *)(((char *)ctx->curmsg) + size);
304 attr->nla_type = type;
305 attr->nla_len = NLA_HDRLEN + len;
306 memcpy(((char *)attr) + NLA_HDRLEN, data, len);
307 ctx->curmsg->nlmsg_len = size + attr->nla_len;
308 ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->nla_len;
309
310 return 1;
311 }
312
tst_netlink_add_attr_string(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const char * data)313 int tst_netlink_add_attr_string(const char *file, const int lineno,
314 struct tst_netlink_context *ctx, unsigned short type,
315 const char *data)
316 {
317 return tst_netlink_add_attr(file, lineno, ctx, type, data,
318 strlen(data) + 1);
319 }
320
tst_netlink_add_attr_list(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct tst_netlink_attr_list * list)321 int tst_netlink_add_attr_list(const char *file, const int lineno,
322 struct tst_netlink_context *ctx,
323 const struct tst_netlink_attr_list *list)
324 {
325 int i, ret;
326 size_t offset;
327
328 for (i = 0; list[i].len >= 0; i++) {
329 if (list[i].len > USHRT_MAX) {
330 tst_brk_(file, lineno, TBROK,
331 "%s(): Attribute value too long", __func__);
332 return -1;
333 }
334
335 offset = NLMSG_ALIGN(ctx->datalen);
336 ret = tst_netlink_add_attr(file, lineno, ctx, list[i].type,
337 list[i].data, list[i].len);
338
339 if (!ret)
340 return -1;
341
342 if (list[i].sublist) {
343 struct rtattr *attr;
344
345 ret = tst_netlink_add_attr_list(file, lineno, ctx,
346 list[i].sublist);
347
348 if (ret < 0)
349 return ret;
350
351 attr = (struct rtattr *)(ctx->buffer + offset);
352
353 if (ctx->datalen - offset > USHRT_MAX) {
354 tst_brk_(file, lineno, TBROK,
355 "%s(): Sublist too long", __func__);
356 return -1;
357 }
358
359 attr->rta_len = ctx->datalen - offset;
360 }
361 }
362
363 return i;
364 }
365
tst_rtnl_add_attr(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const void * data,unsigned short len)366 int tst_rtnl_add_attr(const char *file, const int lineno,
367 struct tst_netlink_context *ctx, unsigned short type,
368 const void *data, unsigned short len)
369 {
370 size_t size;
371 struct rtattr *attr;
372
373 if (!ctx->curmsg) {
374 tst_brk_(file, lineno, TBROK,
375 "%s(): No message to add attributes to", __func__);
376 return 0;
377 }
378
379 if (!netlink_grow_buffer(file, lineno, ctx, RTA_SPACE(len)))
380 return 0;
381
382 size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
383 attr = (struct rtattr *)(((char *)ctx->curmsg) + size);
384 attr->rta_type = type;
385 attr->rta_len = RTA_LENGTH(len);
386 memcpy(RTA_DATA(attr), data, len);
387 ctx->curmsg->nlmsg_len = size + attr->rta_len;
388 ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->rta_len;
389
390 return 1;
391 }
392
tst_rtnl_add_attr_string(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const char * data)393 int tst_rtnl_add_attr_string(const char *file, const int lineno,
394 struct tst_netlink_context *ctx, unsigned short type,
395 const char *data)
396 {
397 return tst_rtnl_add_attr(file, lineno, ctx, type, data,
398 strlen(data) + 1);
399 }
400
tst_rtnl_add_attr_list(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct tst_netlink_attr_list * list)401 int tst_rtnl_add_attr_list(const char *file, const int lineno,
402 struct tst_netlink_context *ctx,
403 const struct tst_netlink_attr_list *list)
404 {
405 int i, ret;
406 size_t offset;
407
408 for (i = 0; list[i].len >= 0; i++) {
409 if (list[i].len > USHRT_MAX) {
410 tst_brk_(file, lineno, TBROK,
411 "%s(): Attribute value too long", __func__);
412 return -1;
413 }
414
415 offset = NLMSG_ALIGN(ctx->datalen);
416 ret = tst_rtnl_add_attr(file, lineno, ctx, list[i].type,
417 list[i].data, list[i].len);
418
419 if (!ret)
420 return -1;
421
422 if (list[i].sublist) {
423 struct rtattr *attr;
424
425 ret = tst_rtnl_add_attr_list(file, lineno, ctx,
426 list[i].sublist);
427
428 if (ret < 0)
429 return ret;
430
431 attr = (struct rtattr *)(ctx->buffer + offset);
432
433 if (ctx->datalen - offset > USHRT_MAX) {
434 tst_brk_(file, lineno, TBROK,
435 "%s(): Sublist too long", __func__);
436 return -1;
437 }
438
439 attr->rta_len = ctx->datalen - offset;
440 }
441 }
442
443 return i;
444 }
445
tst_netlink_check_acks(const char * file,const int lineno,struct tst_netlink_context * ctx,struct tst_netlink_message * res)446 int tst_netlink_check_acks(const char *file, const int lineno,
447 struct tst_netlink_context *ctx, struct tst_netlink_message *res)
448 {
449 struct nlmsghdr *msg = (struct nlmsghdr *)ctx->buffer;
450 int size_left = ctx->datalen;
451
452 for (; size_left > 0 && NLMSG_OK(msg, size_left);
453 msg = NLMSG_NEXT(msg, size_left)) {
454
455 if (!(msg->nlmsg_flags & NLM_F_ACK))
456 continue;
457
458 while (res->header && !(res->err && res->err->error) &&
459 res->header->nlmsg_seq != msg->nlmsg_seq)
460 res++;
461
462 if (res->err && res->err->error) {
463 tst_netlink_errno = -res->err->error;
464 return 0;
465 }
466
467 if (!res->header || res->header->nlmsg_seq != msg->nlmsg_seq) {
468 tst_brk_(file, lineno, TBROK,
469 "No ACK found for Netlink message %u",
470 msg->nlmsg_seq);
471 return 0;
472 }
473 }
474
475 return 1;
476 }
477
tst_netlink_send_validate(const char * file,const int lineno,struct tst_netlink_context * ctx)478 int tst_netlink_send_validate(const char *file, const int lineno,
479 struct tst_netlink_context *ctx)
480 {
481 struct tst_netlink_message *response;
482 int ret;
483
484 tst_netlink_errno = 0;
485
486 if (tst_netlink_send(file, lineno, ctx) <= 0)
487 return 0;
488
489 tst_netlink_wait(ctx);
490 response = tst_netlink_recv(file, lineno, ctx);
491
492 if (!response)
493 return 0;
494
495 ret = tst_netlink_check_acks(file, lineno, ctx, response);
496 tst_netlink_free_message(response);
497
498 return ret;
499 }
500