1 // SPDX-License-Identifier: GPL-2.0
2 #include "comm.h"
3 #include <errno.h>
4 #include <string.h>
5 #include <internal/rc_check.h>
6 #include <linux/refcount.h>
7 #include <linux/zalloc.h>
8 #include <tools/libc_compat.h> // reallocarray
9
10 #include "rwsem.h"
11
/*
 * An interned, reference-counted comm string. The characters live inline
 * after the header in a flexible array member, so each entry is a single
 * allocation (see comm_str__new()).
 */
DECLARE_RC_STRUCT(comm_str) {
	refcount_t refcnt;	/* references, including the one held by the _comm_strs array */
	char str[];		/* NUL-terminated copy of the comm */
};
16
/*
 * Global table of interned comm strings. Entries are kept sorted in strcmp()
 * order so lookups can use bsearch(). The table is initialised lazily by
 * comm_strs__get().
 */
static struct comm_strs {
	struct rw_semaphore lock;	/* taken around every access to the fields below */
	struct comm_str **strs;		/* sorted array of entries */
	int num_strs;			/* number of entries in use */
	int capacity;			/* number of allocated slots in strs */
} _comm_strs;

/* Forward declaration: comm_str__put() calls this before its definition. */
static void comm_strs__remove_if_last(struct comm_str *cs);
25
comm_strs__init(void)26 static void comm_strs__init(void)
27 {
28 init_rwsem(&_comm_strs.lock);
29 _comm_strs.capacity = 16;
30 _comm_strs.num_strs = 0;
31 _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs));
32 }
33
comm_strs__get(void)34 static struct comm_strs *comm_strs__get(void)
35 {
36 static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT;
37
38 pthread_once(&comm_strs_type_once, comm_strs__init);
39
40 return &_comm_strs;
41 }
42
comm_str__refcnt(struct comm_str * cs)43 static refcount_t *comm_str__refcnt(struct comm_str *cs)
44 {
45 return &RC_CHK_ACCESS(cs)->refcnt;
46 }
47
comm_str__str(const struct comm_str * cs)48 static const char *comm_str__str(const struct comm_str *cs)
49 {
50 return &RC_CHK_ACCESS(cs)->str[0];
51 }
52
comm_str__get(struct comm_str * cs)53 static struct comm_str *comm_str__get(struct comm_str *cs)
54 {
55 struct comm_str *result;
56
57 if (RC_CHK_GET(result, cs))
58 refcount_inc_not_zero(comm_str__refcnt(cs));
59
60 return result;
61 }
62
/*
 * Drop a reference to @cs. A NULL @cs is ignored.
 *
 * When the count reaches zero the string is freed. If exactly one reference
 * remains afterwards, it may be the one held by the _comm_strs array.
 * comm_strs__remove_if_last() then re-checks under the write lock and, if
 * so, drops the array's reference too.
 *
 * NOTE(review): the decrement and the following refcount_read() happen
 * outside comm_strs->lock. If two threads drop references to the same entry
 * concurrently, both can observe a count of 1 and both call
 * comm_strs__remove_if_last(). The second then reads the refcount of an
 * entry the first has already freed. Confirm whether callers serialise puts
 * of the same comm_str.
 */
static void comm_str__put(struct comm_str *cs)
{
	if (!cs)
		return;

	if (refcount_dec_and_test(comm_str__refcnt(cs))) {
		RC_CHK_FREE(cs);
	} else {
		/* Possibly only the array's reference is left; let it be released. */
		if (refcount_read(comm_str__refcnt(cs)) == 1)
			comm_strs__remove_if_last(cs);

		RC_CHK_PUT(cs);
	}
}
77
comm_str__new(const char * str)78 static struct comm_str *comm_str__new(const char *str)
79 {
80 struct comm_str *result = NULL;
81 RC_STRUCT(comm_str) *cs;
82
83 cs = malloc(sizeof(*cs) + strlen(str) + 1);
84 if (ADD_RC_CHK(result, cs)) {
85 refcount_set(comm_str__refcnt(result), 1);
86 strcpy(&cs->str[0], str);
87 }
88 return result;
89 }
90
/*
 * bsearch() comparator. @_key is the C string being looked up, and @_member
 * points at one slot of the comm_str pointer array. Orders by strcmp().
 */
static int comm_str__search(const void *_key, const void *_member)
{
	const struct comm_str * const *slot = _member;

	return strcmp((const char *)_key, comm_str__str(*slot));
}
98
comm_strs__remove_if_last(struct comm_str * cs)99 static void comm_strs__remove_if_last(struct comm_str *cs)
100 {
101 struct comm_strs *comm_strs = comm_strs__get();
102
103 down_write(&comm_strs->lock);
104 /*
105 * Are there only references from the array, if so remove the array
106 * reference under the write lock so that we don't race with findnew.
107 */
108 if (refcount_read(comm_str__refcnt(cs)) == 1) {
109 struct comm_str **entry;
110
111 entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs,
112 sizeof(struct comm_str *), comm_str__search);
113 comm_str__put(*entry);
114 for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++)
115 comm_strs->strs[i] = comm_strs->strs[i + 1];
116 comm_strs->num_strs--;
117 }
118 up_write(&comm_strs->lock);
119 }
120
__comm_strs__find(struct comm_strs * comm_strs,const char * str)121 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str)
122 {
123 struct comm_str **result;
124
125 result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *),
126 comm_str__search);
127
128 if (!result)
129 return NULL;
130
131 return comm_str__get(*result);
132 }
133
comm_strs__findnew(const char * str)134 static struct comm_str *comm_strs__findnew(const char *str)
135 {
136 struct comm_strs *comm_strs = comm_strs__get();
137 struct comm_str *result;
138
139 if (!comm_strs)
140 return NULL;
141
142 down_read(&comm_strs->lock);
143 result = __comm_strs__find(comm_strs, str);
144 up_read(&comm_strs->lock);
145 if (result)
146 return result;
147
148 down_write(&comm_strs->lock);
149 result = __comm_strs__find(comm_strs, str);
150 if (!result) {
151 if (comm_strs->num_strs == comm_strs->capacity) {
152 struct comm_str **tmp;
153
154 tmp = reallocarray(comm_strs->strs,
155 comm_strs->capacity + 16,
156 sizeof(*comm_strs->strs));
157 if (!tmp) {
158 up_write(&comm_strs->lock);
159 return NULL;
160 }
161 comm_strs->strs = tmp;
162 comm_strs->capacity += 16;
163 }
164 result = comm_str__new(str);
165 if (result) {
166 int low = 0, high = comm_strs->num_strs - 1;
167 int insert = comm_strs->num_strs; /* Default to inserting at the end. */
168
169 while (low <= high) {
170 int mid = low + (high - low) / 2;
171 int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str);
172
173 if (cmp < 0) {
174 low = mid + 1;
175 } else {
176 high = mid - 1;
177 insert = mid;
178 }
179 }
180 memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert],
181 (comm_strs->num_strs - insert) * sizeof(struct comm_str *));
182 comm_strs->num_strs++;
183 comm_strs->strs[insert] = result;
184 }
185 }
186 up_write(&comm_strs->lock);
187 return comm_str__get(result);
188 }
189
/*
 * Allocate a zeroed struct comm that references the interned string for @str
 * and records @timestamp and @exec. Returns NULL if either allocation fails.
 */
struct comm *comm__new(const char *str, u64 timestamp, bool exec)
{
	struct comm *comm = zalloc(sizeof(*comm));

	if (comm == NULL)
		return NULL;

	comm->comm_str = comm_strs__findnew(str);
	if (comm->comm_str == NULL) {
		free(comm);
		return NULL;
	}

	comm->start = timestamp;
	comm->exec = exec;
	return comm;
}
208
/*
 * Point @comm at the interned string for @str and update its start time.
 * The exec flag is only ever set here, never cleared. Returns 0 on success,
 * or -ENOMEM with @comm left unchanged if the new string cannot be obtained.
 */
int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
{
	struct comm_str *replacement = comm_strs__findnew(str);

	if (replacement == NULL)
		return -ENOMEM;

	/* Release the previous string only once the new one is secured. */
	comm_str__put(comm->comm_str);
	comm->comm_str = replacement;
	comm->start = timestamp;
	comm->exec = comm->exec || exec;

	return 0;
}
225
comm__free(struct comm * comm)226 void comm__free(struct comm *comm)
227 {
228 comm_str__put(comm->comm_str);
229 free(comm);
230 }
231
comm__str(const struct comm * comm)232 const char *comm__str(const struct comm *comm)
233 {
234 return comm_str__str(comm->comm_str);
235 }
236