1 // SPDX-License-Identifier: GPL-2.0
2 #include "comm.h"
3 #include <errno.h>
4 #include <string.h>
5 #include <internal/rc_check.h>
6 #include <linux/refcount.h>
7 #include <linux/zalloc.h>
8 #include <tools/libc_compat.h> // reallocarray
9 
10 #include "rwsem.h"
11 
/*
 * A reference-counted command-name string.  Equal names are shared through
 * the global _comm_strs table, so one allocation serves every user of a
 * given name.  The characters are stored inline after the header.
 */
DECLARE_RC_STRUCT(comm_str) {
	refcount_t refcnt;	/* one reference held by the table, one per user */
	char str[];		/* NUL-terminated copy of the name (flexible array) */
};
16 
/*
 * Process-wide table of comm_str objects.  It is kept sorted by string so
 * that lookups can use bsearch(), and is initialized lazily through
 * comm_strs__get().
 */
static struct comm_strs {
	struct rw_semaphore lock;	/* read: lookup; write: insert/remove/grow */
	struct comm_str **strs;		/* sorted array; each slot owns one reference */
	int num_strs;			/* number of slots in use */
	int capacity;			/* number of slots allocated in strs */
} _comm_strs;

/* comm_str__put() needs this before its definition further down. */
static void comm_strs__remove_if_last(struct comm_str *cs);
25 
comm_strs__init(void)26 static void comm_strs__init(void)
27 {
28 	init_rwsem(&_comm_strs.lock);
29 	_comm_strs.capacity = 16;
30 	_comm_strs.num_strs = 0;
31 	_comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs));
32 }
33 
comm_strs__get(void)34 static struct comm_strs *comm_strs__get(void)
35 {
36 	static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT;
37 
38 	pthread_once(&comm_strs_type_once, comm_strs__init);
39 
40 	return &_comm_strs;
41 }
42 
comm_str__refcnt(struct comm_str * cs)43 static refcount_t *comm_str__refcnt(struct comm_str *cs)
44 {
45 	return &RC_CHK_ACCESS(cs)->refcnt;
46 }
47 
comm_str__str(const struct comm_str * cs)48 static const char *comm_str__str(const struct comm_str *cs)
49 {
50 	return &RC_CHK_ACCESS(cs)->str[0];
51 }
52 
/*
 * Take an additional reference on @cs.
 *
 * Returns the handle produced by RC_CHK_GET().  When RC_CHK_GET() reports
 * failure, the counter is not touched and that (failed) handle is returned.
 */
static struct comm_str *comm_str__get(struct comm_str *cs)
{
	struct comm_str *ref;
	const bool acquired = RC_CHK_GET(ref, cs);

	if (acquired)
		refcount_inc_not_zero(comm_str__refcnt(cs));
	return ref;
}
62 
/*
 * Drop a reference on @cs.  A NULL @cs is ignored.
 *
 * The global table holds one reference to every comm_str it contains.
 *  - If our decrement takes the count to 0, @cs is freed here.
 *  - If it leaves the count at 1, only the table's reference may remain.
 *    comm_strs__remove_if_last() then re-checks under the write lock and
 *    drops the table's reference, which frees the string.
 *
 * NOTE(review): two threads that drop the last two user references at the
 * same time can both read a count of 1 and both call
 * comm_strs__remove_if_last().  The second call then reads @cs after the
 * first has already freed it.  Confirm whether callers rule this out.
 */
static void comm_str__put(struct comm_str *cs)
{
	if (!cs)
		return;

	if (refcount_dec_and_test(comm_str__refcnt(cs))) {
		RC_CHK_FREE(cs);
	} else {
		/* Possibly only the table's reference is left. */
		if (refcount_read(comm_str__refcnt(cs)) == 1)
			comm_strs__remove_if_last(cs);

		/* Release this caller's handle; the object itself stays alive. */
		RC_CHK_PUT(cs);
	}
}
77 
comm_str__new(const char * str)78 static struct comm_str *comm_str__new(const char *str)
79 {
80 	struct comm_str *result = NULL;
81 	RC_STRUCT(comm_str) *cs;
82 
83 	cs = malloc(sizeof(*cs) + strlen(str) + 1);
84 	if (ADD_RC_CHK(result, cs)) {
85 		refcount_set(comm_str__refcnt(result), 1);
86 		strcpy(&cs->str[0], str);
87 	}
88 	return result;
89 }
90 
/*
 * bsearch() comparator for the comm_str table.  @_key is the string being
 * looked up; @_member points at one slot of the comm_str pointer array.
 * Entries are ordered with strcmp().
 */
static int comm_str__search(const void *_key, const void *_member)
{
	const struct comm_str * const *slot = _member;

	return strcmp((const char *)_key, comm_str__str(*slot));
}
98 
comm_strs__remove_if_last(struct comm_str * cs)99 static void comm_strs__remove_if_last(struct comm_str *cs)
100 {
101 	struct comm_strs *comm_strs = comm_strs__get();
102 
103 	down_write(&comm_strs->lock);
104 	/*
105 	 * Are there only references from the array, if so remove the array
106 	 * reference under the write lock so that we don't race with findnew.
107 	 */
108 	if (refcount_read(comm_str__refcnt(cs)) == 1) {
109 		struct comm_str **entry;
110 
111 		entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs,
112 				sizeof(struct comm_str *), comm_str__search);
113 		comm_str__put(*entry);
114 		for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++)
115 			comm_strs->strs[i] = comm_strs->strs[i + 1];
116 		comm_strs->num_strs--;
117 	}
118 	up_write(&comm_strs->lock);
119 }
120 
__comm_strs__find(struct comm_strs * comm_strs,const char * str)121 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str)
122 {
123 	struct comm_str **result;
124 
125 	result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *),
126 			 comm_str__search);
127 
128 	if (!result)
129 		return NULL;
130 
131 	return comm_str__get(*result);
132 }
133 
comm_strs__findnew(const char * str)134 static struct comm_str *comm_strs__findnew(const char *str)
135 {
136 	struct comm_strs *comm_strs = comm_strs__get();
137 	struct comm_str *result;
138 
139 	if (!comm_strs)
140 		return NULL;
141 
142 	down_read(&comm_strs->lock);
143 	result = __comm_strs__find(comm_strs, str);
144 	up_read(&comm_strs->lock);
145 	if (result)
146 		return result;
147 
148 	down_write(&comm_strs->lock);
149 	result = __comm_strs__find(comm_strs, str);
150 	if (!result) {
151 		if (comm_strs->num_strs == comm_strs->capacity) {
152 			struct comm_str **tmp;
153 
154 			tmp = reallocarray(comm_strs->strs,
155 					   comm_strs->capacity + 16,
156 					   sizeof(*comm_strs->strs));
157 			if (!tmp) {
158 				up_write(&comm_strs->lock);
159 				return NULL;
160 			}
161 			comm_strs->strs = tmp;
162 			comm_strs->capacity += 16;
163 		}
164 		result = comm_str__new(str);
165 		if (result) {
166 			int low = 0, high = comm_strs->num_strs - 1;
167 			int insert = comm_strs->num_strs; /* Default to inserting at the end. */
168 
169 			while (low <= high) {
170 				int mid = low + (high - low) / 2;
171 				int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str);
172 
173 				if (cmp < 0) {
174 					low = mid + 1;
175 				} else {
176 					high = mid - 1;
177 					insert = mid;
178 				}
179 			}
180 			memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert],
181 				(comm_strs->num_strs - insert) * sizeof(struct comm_str *));
182 			comm_strs->num_strs++;
183 			comm_strs->strs[insert] = result;
184 		}
185 	}
186 	up_write(&comm_strs->lock);
187 	return comm_str__get(result);
188 }
189 
/*
 * Allocate a zeroed comm record.  It gets start = @timestamp, exec = @exec,
 * and a shared reference to the comm_str for @str.  Returns NULL if either
 * the record or the shared string cannot be obtained; on the second failure
 * the record is freed before returning.
 */
struct comm *comm__new(const char *str, u64 timestamp, bool exec)
{
	struct comm *comm;
	struct comm_str *cs;

	comm = zalloc(sizeof(*comm));
	if (comm == NULL)
		return NULL;

	comm->start = timestamp;
	comm->exec = exec;

	cs = comm_strs__findnew(str);
	if (cs == NULL) {
		free(comm);
		return NULL;
	}
	comm->comm_str = cs;

	return comm;
}
208 
/*
 * Replace the name of @comm with @str and set its start time to @timestamp.
 * The exec flag is only ever raised here, never cleared.
 *
 * Returns 0 on success, or -ENOMEM if the new shared string cannot be
 * obtained.  In that case @comm is left untouched.
 */
int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
{
	struct comm_str *previous = comm->comm_str;
	struct comm_str *replacement = comm_strs__findnew(str);

	if (replacement == NULL)
		return -ENOMEM;

	/* Release the old name only after the new one is secured. */
	comm_str__put(previous);
	comm->comm_str = replacement;
	comm->start = timestamp;
	if (exec)
		comm->exec = true;

	return 0;
}
225 
comm__free(struct comm * comm)226 void comm__free(struct comm *comm)
227 {
228 	comm_str__put(comm->comm_str);
229 	free(comm);
230 }
231 
comm__str(const struct comm * comm)232 const char *comm__str(const struct comm *comm)
233 {
234 	return comm_str__str(comm->comm_str);
235 }
236