xref: /aosp_15_r20/external/ltp/lib/tst_netlink.c (revision 49cdfc7efb34551c7342be41a7384b9c40d7cab7)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Copyright (c) 2021 Linux Test Project
4  */
5 
6 #include <stdlib.h>
7 #include <limits.h>
8 #include <asm/types.h>
9 #include <linux/netlink.h>
10 #include <linux/rtnetlink.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <sys/poll.h>
14 #define TST_NO_DEFAULT_MAIN
15 #include "tst_test.h"
16 #include "tst_netlink.h"
17 
18 struct tst_netlink_context {
19 	int socket;
20 	pid_t pid;
21 	uint32_t seq;
22 	size_t bufsize, datalen;
23 	char *buffer;
24 	struct nlmsghdr *curmsg;
25 };
26 
27 int tst_netlink_errno;
28 
netlink_grow_buffer(const char * file,const int lineno,struct tst_netlink_context * ctx,size_t size)29 static int netlink_grow_buffer(const char *file, const int lineno,
30 	struct tst_netlink_context *ctx, size_t size)
31 {
32 	size_t needed, offset, curlen = NLMSG_ALIGN(ctx->datalen);
33 	char *buf;
34 
35 	if (ctx->bufsize - curlen >= size)
36 		return 1;
37 
38 	needed = size - (ctx->bufsize - curlen);
39 	size = ctx->bufsize + (ctx->bufsize > needed ? ctx->bufsize : needed);
40 	size = NLMSG_ALIGN(size);
41 	buf = safe_realloc(file, lineno, ctx->buffer, size);
42 
43 	if (!buf)
44 		return 0;
45 
46 	memset(buf + ctx->bufsize, 0, size - ctx->bufsize);
47 	offset = ((char *)ctx->curmsg) - ctx->buffer;
48 	ctx->buffer = buf;
49 	ctx->curmsg = (struct nlmsghdr *)(buf + offset);
50 	ctx->bufsize = size;
51 
52 	return 1;
53 }
54 
tst_netlink_destroy_context(const char * file,const int lineno,struct tst_netlink_context * ctx)55 void tst_netlink_destroy_context(const char *file, const int lineno,
56 	struct tst_netlink_context *ctx)
57 {
58 	if (!ctx)
59 		return;
60 
61 	safe_close(file, lineno, NULL, ctx->socket);
62 	free(ctx->buffer);
63 	free(ctx);
64 }
65 
tst_netlink_create_context(const char * file,const int lineno,int protocol)66 struct tst_netlink_context *tst_netlink_create_context(const char *file,
67 	const int lineno, int protocol)
68 {
69 	struct tst_netlink_context *ctx;
70 	struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
71 
72 	ctx = safe_malloc(file, lineno, NULL,
73 		sizeof(struct tst_netlink_context));
74 
75 	if (!ctx)
76 		return NULL;
77 
78 	ctx->pid = 0;
79 	ctx->seq = 0;
80 	ctx->buffer = NULL;
81 	ctx->bufsize = 1024;
82 	ctx->datalen = 0;
83 	ctx->curmsg = NULL;
84 	ctx->socket = safe_socket(file, lineno, NULL, AF_NETLINK,
85 		SOCK_DGRAM | SOCK_CLOEXEC, protocol);
86 
87 	if (ctx->socket < 0) {
88 		free(ctx);
89 		return NULL;
90 	}
91 
92 	if (safe_bind(file, lineno, NULL, ctx->socket, (struct sockaddr *)&addr,
93 		sizeof(addr))) {
94 		tst_netlink_destroy_context(file, lineno, ctx);
95 		return NULL;
96 	}
97 
98 	ctx->buffer = safe_malloc(file, lineno, NULL, ctx->bufsize);
99 
100 	if (!ctx->buffer) {
101 		tst_netlink_destroy_context(file, lineno, ctx);
102 		return NULL;
103 	}
104 
105 	memset(ctx->buffer, 0, ctx->bufsize);
106 
107 	return ctx;
108 }
109 
tst_netlink_free_message(struct tst_netlink_message * msg)110 void tst_netlink_free_message(struct tst_netlink_message *msg)
111 {
112 	if (!msg)
113 		return;
114 
115 	// all ptr->header and ptr->info pointers point to the same buffer
116 	// msg->header is the start of the buffer
117 	free(msg->header);
118 	free(msg);
119 }
120 
tst_netlink_send(const char * file,const int lineno,struct tst_netlink_context * ctx)121 int tst_netlink_send(const char *file, const int lineno,
122 	struct tst_netlink_context *ctx)
123 {
124 	int ret;
125 	struct sockaddr_nl addr = { .nl_family = AF_NETLINK };
126 	struct iovec iov;
127 	struct msghdr msg = {
128 		.msg_name = &addr,
129 		.msg_namelen = sizeof(addr),
130 		.msg_iov = &iov,
131 		.msg_iovlen = 1
132 	};
133 
134 	if (!ctx->curmsg) {
135 		tst_brk_(file, lineno, TBROK, "%s(): No message to send",
136 			__func__);
137 		return 0;
138 	}
139 
140 	if (ctx->curmsg->nlmsg_flags & NLM_F_MULTI) {
141 		struct nlmsghdr eom = { .nlmsg_type = NLMSG_DONE };
142 
143 		if (!tst_netlink_add_message(file, lineno, ctx, &eom, NULL, 0))
144 			return 0;
145 
146 		/* NLMSG_DONE message must not have NLM_F_MULTI flag */
147 		ctx->curmsg->nlmsg_flags = 0;
148 	}
149 
150 	iov.iov_base = ctx->buffer;
151 	iov.iov_len = ctx->datalen;
152 	ret = safe_sendmsg(file, lineno, ctx->datalen, ctx->socket, &msg, 0);
153 
154 	if (ret > 0)
155 		ctx->curmsg = NULL;
156 
157 	return ret;
158 }
159 
tst_netlink_wait(struct tst_netlink_context * ctx)160 int tst_netlink_wait(struct tst_netlink_context *ctx)
161 {
162 	struct pollfd fdinfo = {
163 		.fd = ctx->socket,
164 		.events = POLLIN
165 	};
166 
167 	return poll(&fdinfo, 1, 1000);
168 }
169 
tst_netlink_recv(const char * file,const int lineno,struct tst_netlink_context * ctx)170 struct tst_netlink_message *tst_netlink_recv(const char *file,
171 	const int lineno, struct tst_netlink_context *ctx)
172 {
173 	char tmp, *tmpbuf, *buffer = NULL;
174 	struct tst_netlink_message *ret;
175 	struct nlmsghdr *ptr;
176 	size_t retsize, bufsize = 0;
177 	ssize_t size;
178 	int i, size_left, msgcount;
179 
180 	/* Each recv() call returns one message, read all pending messages */
181 	while (1) {
182 		errno = 0;
183 		size = recv(ctx->socket, &tmp, 1,
184 			MSG_DONTWAIT | MSG_PEEK | MSG_TRUNC);
185 
186 		if (size < 0) {
187 			if (errno != EAGAIN) {
188 				tst_brk_(file, lineno, TBROK | TERRNO,
189 					"recv() failed");
190 			}
191 
192 			break;
193 		}
194 
195 		tmpbuf = safe_realloc(file, lineno, buffer, bufsize + size);
196 
197 		if (!tmpbuf)
198 			break;
199 
200 		buffer = tmpbuf;
201 		size = safe_recv(file, lineno, size, ctx->socket,
202 			buffer + bufsize, size, 0);
203 
204 		if (size < 0)
205 			break;
206 
207 		bufsize += size;
208 	}
209 
210 	if (!bufsize) {
211 		free(buffer);
212 		return NULL;
213 	}
214 
215 	ptr = (struct nlmsghdr *)buffer;
216 	size_left = bufsize;
217 	msgcount = 0;
218 
219 	for (; size_left > 0 && NLMSG_OK(ptr, size_left); msgcount++)
220 		ptr = NLMSG_NEXT(ptr, size_left);
221 
222 	retsize = (msgcount + 1) * sizeof(struct tst_netlink_message);
223 	ret = safe_malloc(file, lineno, NULL, retsize);
224 
225 	if (!ret) {
226 		free(buffer);
227 		return NULL;
228 	}
229 
230 	memset(ret, 0, retsize);
231 	ptr = (struct nlmsghdr *)buffer;
232 	size_left = bufsize;
233 
234 	for (i = 0; i < msgcount; i++, ptr = NLMSG_NEXT(ptr, size_left)) {
235 		ret[i].header = ptr;
236 		ret[i].payload = NLMSG_DATA(ptr);
237 		ret[i].payload_size = NLMSG_PAYLOAD(ptr, 0);
238 
239 		if (ptr->nlmsg_type == NLMSG_ERROR)
240 			ret[i].err = NLMSG_DATA(ptr);
241 	}
242 
243 	return ret;
244 }
245 
tst_netlink_add_message(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct nlmsghdr * header,const void * payload,size_t payload_size)246 int tst_netlink_add_message(const char *file, const int lineno,
247 	struct tst_netlink_context *ctx, const struct nlmsghdr *header,
248 	const void *payload, size_t payload_size)
249 {
250 	size_t size;
251 	unsigned int extra_flags = 0;
252 
253 	if (!netlink_grow_buffer(file, lineno, ctx, NLMSG_SPACE(payload_size)))
254 		return 0;
255 
256 	if (!ctx->curmsg) {
257 		/*
258 		 * datalen may hold the size of last sent message for ACK
259 		 * checking, reset it back to 0 here
260 		 */
261 		ctx->datalen = 0;
262 		ctx->curmsg = (struct nlmsghdr *)ctx->buffer;
263 	} else {
264 		size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
265 
266 		extra_flags = NLM_F_MULTI;
267 		ctx->curmsg->nlmsg_flags |= extra_flags;
268 		ctx->curmsg = NLMSG_NEXT(ctx->curmsg, size);
269 		ctx->datalen = NLMSG_ALIGN(ctx->datalen);
270 	}
271 
272 	*ctx->curmsg = *header;
273 	ctx->curmsg->nlmsg_len = NLMSG_LENGTH(payload_size);
274 	ctx->curmsg->nlmsg_flags |= extra_flags;
275 	ctx->curmsg->nlmsg_seq = ctx->seq++;
276 	ctx->curmsg->nlmsg_pid = ctx->pid;
277 
278 	if (payload_size)
279 		memcpy(NLMSG_DATA(ctx->curmsg), payload, payload_size);
280 
281 	ctx->datalen += ctx->curmsg->nlmsg_len;
282 
283 	return 1;
284 }
285 
tst_netlink_add_attr(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const void * data,unsigned short len)286 int tst_netlink_add_attr(const char *file, const int lineno,
287 	struct tst_netlink_context *ctx, unsigned short type,
288 	const void *data, unsigned short len)
289 {
290 	size_t size = NLA_HDRLEN + NLA_ALIGN(len);
291 	struct nlattr *attr;
292 
293 	if (!ctx->curmsg) {
294 		tst_brk_(file, lineno, TBROK,
295 			"%s(): No message to add attributes to", __func__);
296 		return 0;
297 	}
298 
299 	if (!netlink_grow_buffer(file, lineno, ctx, size))
300 		return 0;
301 
302 	size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
303 	attr = (struct nlattr *)(((char *)ctx->curmsg) + size);
304 	attr->nla_type = type;
305 	attr->nla_len = NLA_HDRLEN + len;
306 	memcpy(((char *)attr) + NLA_HDRLEN, data, len);
307 	ctx->curmsg->nlmsg_len = size + attr->nla_len;
308 	ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->nla_len;
309 
310 	return 1;
311 }
312 
tst_netlink_add_attr_string(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const char * data)313 int tst_netlink_add_attr_string(const char *file, const int lineno,
314 	struct tst_netlink_context *ctx, unsigned short type,
315 	const char *data)
316 {
317 	return tst_netlink_add_attr(file, lineno, ctx, type, data,
318 		strlen(data) + 1);
319 }
320 
tst_netlink_add_attr_list(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct tst_netlink_attr_list * list)321 int tst_netlink_add_attr_list(const char *file, const int lineno,
322 	struct tst_netlink_context *ctx,
323 	const struct tst_netlink_attr_list *list)
324 {
325 	int i, ret;
326 	size_t offset;
327 
328 	for (i = 0; list[i].len >= 0; i++) {
329 		if (list[i].len > USHRT_MAX) {
330 			tst_brk_(file, lineno, TBROK,
331 				"%s(): Attribute value too long", __func__);
332 			return -1;
333 		}
334 
335 		offset = NLMSG_ALIGN(ctx->datalen);
336 		ret = tst_netlink_add_attr(file, lineno, ctx, list[i].type,
337 			list[i].data, list[i].len);
338 
339 		if (!ret)
340 			return -1;
341 
342 		if (list[i].sublist) {
343 			struct rtattr *attr;
344 
345 			ret = tst_netlink_add_attr_list(file, lineno, ctx,
346 				list[i].sublist);
347 
348 			if (ret < 0)
349 				return ret;
350 
351 			attr = (struct rtattr *)(ctx->buffer + offset);
352 
353 			if (ctx->datalen - offset > USHRT_MAX) {
354 				tst_brk_(file, lineno, TBROK,
355 					"%s(): Sublist too long", __func__);
356 				return -1;
357 			}
358 
359 			attr->rta_len = ctx->datalen - offset;
360 		}
361 	}
362 
363 	return i;
364 }
365 
tst_rtnl_add_attr(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const void * data,unsigned short len)366 int tst_rtnl_add_attr(const char *file, const int lineno,
367 	struct tst_netlink_context *ctx, unsigned short type,
368 	const void *data, unsigned short len)
369 {
370 	size_t size;
371 	struct rtattr *attr;
372 
373 	if (!ctx->curmsg) {
374 		tst_brk_(file, lineno, TBROK,
375 			"%s(): No message to add attributes to", __func__);
376 		return 0;
377 	}
378 
379 	if (!netlink_grow_buffer(file, lineno, ctx, RTA_SPACE(len)))
380 		return 0;
381 
382 	size = NLMSG_ALIGN(ctx->curmsg->nlmsg_len);
383 	attr = (struct rtattr *)(((char *)ctx->curmsg) + size);
384 	attr->rta_type = type;
385 	attr->rta_len = RTA_LENGTH(len);
386 	memcpy(RTA_DATA(attr), data, len);
387 	ctx->curmsg->nlmsg_len = size + attr->rta_len;
388 	ctx->datalen = NLMSG_ALIGN(ctx->datalen) + attr->rta_len;
389 
390 	return 1;
391 }
392 
tst_rtnl_add_attr_string(const char * file,const int lineno,struct tst_netlink_context * ctx,unsigned short type,const char * data)393 int tst_rtnl_add_attr_string(const char *file, const int lineno,
394 	struct tst_netlink_context *ctx, unsigned short type,
395 	const char *data)
396 {
397 	return tst_rtnl_add_attr(file, lineno, ctx, type, data,
398 		strlen(data) + 1);
399 }
400 
tst_rtnl_add_attr_list(const char * file,const int lineno,struct tst_netlink_context * ctx,const struct tst_netlink_attr_list * list)401 int tst_rtnl_add_attr_list(const char *file, const int lineno,
402 	struct tst_netlink_context *ctx,
403 	const struct tst_netlink_attr_list *list)
404 {
405 	int i, ret;
406 	size_t offset;
407 
408 	for (i = 0; list[i].len >= 0; i++) {
409 		if (list[i].len > USHRT_MAX) {
410 			tst_brk_(file, lineno, TBROK,
411 				"%s(): Attribute value too long", __func__);
412 			return -1;
413 		}
414 
415 		offset = NLMSG_ALIGN(ctx->datalen);
416 		ret = tst_rtnl_add_attr(file, lineno, ctx, list[i].type,
417 			list[i].data, list[i].len);
418 
419 		if (!ret)
420 			return -1;
421 
422 		if (list[i].sublist) {
423 			struct rtattr *attr;
424 
425 			ret = tst_rtnl_add_attr_list(file, lineno, ctx,
426 				list[i].sublist);
427 
428 			if (ret < 0)
429 				return ret;
430 
431 			attr = (struct rtattr *)(ctx->buffer + offset);
432 
433 			if (ctx->datalen - offset > USHRT_MAX) {
434 				tst_brk_(file, lineno, TBROK,
435 					"%s(): Sublist too long", __func__);
436 				return -1;
437 			}
438 
439 			attr->rta_len = ctx->datalen - offset;
440 		}
441 	}
442 
443 	return i;
444 }
445 
tst_netlink_check_acks(const char * file,const int lineno,struct tst_netlink_context * ctx,struct tst_netlink_message * res)446 int tst_netlink_check_acks(const char *file, const int lineno,
447 	struct tst_netlink_context *ctx, struct tst_netlink_message *res)
448 {
449 	struct nlmsghdr *msg = (struct nlmsghdr *)ctx->buffer;
450 	int size_left = ctx->datalen;
451 
452 	for (; size_left > 0 && NLMSG_OK(msg, size_left);
453 		msg = NLMSG_NEXT(msg, size_left)) {
454 
455 		if (!(msg->nlmsg_flags & NLM_F_ACK))
456 			continue;
457 
458 		while (res->header && !(res->err && res->err->error) &&
459 			res->header->nlmsg_seq != msg->nlmsg_seq)
460 			res++;
461 
462 		if (res->err && res->err->error) {
463 			tst_netlink_errno = -res->err->error;
464 			return 0;
465 		}
466 
467 		if (!res->header || res->header->nlmsg_seq != msg->nlmsg_seq) {
468 			tst_brk_(file, lineno, TBROK,
469 				"No ACK found for Netlink message %u",
470 				msg->nlmsg_seq);
471 			return 0;
472 		}
473 	}
474 
475 	return 1;
476 }
477 
tst_netlink_send_validate(const char * file,const int lineno,struct tst_netlink_context * ctx)478 int tst_netlink_send_validate(const char *file, const int lineno,
479 	struct tst_netlink_context *ctx)
480 {
481 	struct tst_netlink_message *response;
482 	int ret;
483 
484 	tst_netlink_errno = 0;
485 
486 	if (tst_netlink_send(file, lineno, ctx) <= 0)
487 		return 0;
488 
489 	tst_netlink_wait(ctx);
490 	response = tst_netlink_recv(file, lineno, ctx);
491 
492 	if (!response)
493 		return 0;
494 
495 	ret = tst_netlink_check_acks(file, lineno, ctx, response);
496 	tst_netlink_free_message(response);
497 
498 	return ret;
499 }
500