xref: /aosp_15_r20/bionic/tests/libs/stack_tagging_helper.cpp (revision 8d67ca893c1523eb926b9080dbe4e2ffd2a27ba1)
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <errno.h>
18 #include <setjmp.h>
19 #include <signal.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <sys/mman.h>
24 #include <sys/types.h>
25 #include <sys/wait.h>
26 #include <unistd.h>
27 #include <thread>
28 
29 #include <bionic/malloc.h>
30 
31 #include "CHECK.h"
32 
33 #if defined(__aarch64__)
34 
// Stores the logical tag carried in the top byte of `granule` as the
// allocation tag of the 16-byte granule it points to (MTE STG instruction).
template <typename Elem>
static inline void mte_set_tag(Elem* granule) {
  __asm__ __volatile__(
      ".arch_extension memtag\n"
      "stg %0, [%0]\n"
      :
      : "r"(granule)
      : "memory");
}
44 
// Returns `addr` with its logical tag replaced by the allocation tag that
// memory currently holds for the granule containing `addr` (MTE LDG).
template <typename Elem>
static inline Elem* mte_get_tag(Elem* addr) {
  __asm__ __volatile__(
      ".arch_extension memtag\n"
      "ldg %0, [%0]\n"
      : "+r"(addr)
      :
      : "memory");
  return addr;
}
55 
// Returns a copy of `addr` whose logical tag has been advanced by one tag step
// (MTE ADDG with a zero address offset and a tag offset of 1). Only the
// pointer value changes; no allocation tag in memory is written.
template <typename Elem>
static inline Elem* mte_increment_tag(Elem* addr) {
  Elem* retagged;
  __asm__ __volatile__(
      ".arch_extension memtag\n"
      "addg %0, %1, #0, #1\n"
      : "=r"(retagged)
      : "r"(addr)
      : "memory");
  return retagged;
}
67 
// Size in bytes of the on-stack buffers the tests tag and later inspect.
constexpr size_t kStackAllocationSize = 128 * 1024;

// Prevent optimizations: test buffers publish their address here so the
// compiler cannot drop them.
volatile void* sink = nullptr;

// What the vfork child does once its stack buffer has been tagged.
enum class ChildAction { Exit, Execve, Execl };
74 
75 // Either execve or _exit, transferring control back to parent.
vfork_child2(ChildAction action,void * fp_parent)76 __attribute__((no_sanitize("memtag"), optnone, noinline)) void vfork_child2(ChildAction action,
77                                                                             void* fp_parent) {
78   // Make sure that the buffer in the caller has not been optimized out.
79   void* fp = __builtin_frame_address(0);
80   CHECK(reinterpret_cast<uintptr_t>(fp_parent) - reinterpret_cast<uintptr_t>(fp) >=
81         kStackAllocationSize);
82   if (action == ChildAction::Execve) {
83     const char* argv[] = {"/system/bin/true", nullptr};
84     const char* envp[] = {nullptr};
85     execve("/system/bin/true", const_cast<char**>(argv), const_cast<char**>(envp));
86     fprintf(stderr, "execve failed: %m\n");
87     _exit(1);
88   } else if (action == ChildAction::Execl) {
89     execl("/system/bin/true", "/system/bin/true", "unusedA", "unusedB", nullptr);
90     fprintf(stderr, "execl failed: %m\n");
91     _exit(1);
92   } else if (action == ChildAction::Exit) {
93     _exit(0);
94   }
95   CHECK(0);
96 }
97 
98 // Place a tagged buffer on the stack. Do not tag the top half so that the parent does not crash too
99 // early even if things go wrong.
vfork_child(ChildAction action)100 __attribute__((no_sanitize("memtag"), optnone, noinline)) void vfork_child(ChildAction action) {
101   alignas(16) char buf[kStackAllocationSize] __attribute__((uninitialized));
102   sink = &buf;
103 
104   for (char* p = buf; p < buf + sizeof(buf) / 2; p += 16) {
105     char* q = mte_increment_tag(p);
106     mte_set_tag(q);
107     CHECK(mte_get_tag(p) == q);
108   }
109   vfork_child2(action, __builtin_frame_address(0));
110 }
111 
112 // Parent. Check that the stack has correct allocation tags.
vfork_parent(pid_t pid)113 __attribute__((no_sanitize("memtag"), optnone, noinline)) void vfork_parent(pid_t pid) {
114   alignas(16) char buf[kStackAllocationSize] __attribute__((uninitialized));
115   fprintf(stderr, "vfork_parent %p\n", &buf);
116   bool success = true;
117   for (char* p = buf; p < buf + sizeof(buf); p += 16) {
118     char* q = mte_get_tag(p);
119     if (p != q) {
120       fprintf(stderr, "tag mismatch at offset %zx: %p != %p\n", p - buf, p, q);
121       success = false;
122       break;
123     }
124   }
125 
126   int wstatus;
127   do {
128     int res = waitpid(pid, &wstatus, 0);
129     CHECK(res == pid);
130   } while (!WIFEXITED(wstatus) && !WIFSIGNALED(wstatus));
131 
132   CHECK(WIFEXITED(wstatus));
133   CHECK(WEXITSTATUS(wstatus) == 0);
134 
135   if (!success) exit(1);
136 }
137 
test_vfork(ChildAction action)138 void test_vfork(ChildAction action) {
139   pid_t pid = vfork();
140   if (pid == 0) {
141     vfork_child(action);
142   } else {
143     vfork_parent(pid);
144   }
145 }
146 
settag_and_longjmp(jmp_buf cont)147 __attribute__((no_sanitize("memtag"), optnone, noinline)) static void settag_and_longjmp(
148     jmp_buf cont) {
149   alignas(16) char buf[kStackAllocationSize] __attribute__((uninitialized));
150   sink = &buf;
151 
152   for (char* p = buf; p < buf + sizeof(buf) / 2; p += 16) {
153     char* q = mte_increment_tag(p);
154     mte_set_tag(q);
155     if (mte_get_tag(p) != q) {
156       fprintf(stderr, "failed to set allocation tags on stack: %p != %p\n", mte_get_tag(p), q);
157       exit(1);
158     }
159   }
160   longjmp(cont, 42);
161 }
162 
163 // Check that the stack has correct allocation tags.
check_stack_tags()164 __attribute__((no_sanitize("memtag"), optnone, noinline)) static void check_stack_tags() {
165   alignas(16) char buf[kStackAllocationSize] __attribute__((uninitialized));
166   for (char* p = buf; p < buf + sizeof(buf); p += 16) {
167     void* q = mte_get_tag(p);
168     if (p != q) {
169       fprintf(stderr, "stack tags mismatch: expected %p, got %p", p, q);
170       exit(1);
171     }
172   }
173 }
174 
check_longjmp_restores_tags()175 void check_longjmp_restores_tags() {
176   int value;
177   jmp_buf jb;
178   if ((value = setjmp(jb)) == 0) {
179     settag_and_longjmp(jb);
180     exit(2);  // Unreachable.
181   } else {
182     CHECK(value == 42);
183     check_stack_tags();
184   }
185 }
186 
187 class SigAltStackScoped {
188   stack_t old_ss;
189   void* altstack_start;
190   size_t altstack_size;
191 
192  public:
SigAltStackScoped(size_t sz)193   SigAltStackScoped(size_t sz) : altstack_size(sz) {
194     altstack_start = mmap(nullptr, altstack_size, PROT_READ | PROT_WRITE | PROT_MTE,
195                           MAP_PRIVATE | MAP_ANONYMOUS, 0, 0);
196     if (altstack_start == MAP_FAILED) {
197       fprintf(stderr, "sigaltstack mmap failed: %m\n");
198       exit(1);
199     }
200     stack_t ss = {};
201     ss.ss_sp = altstack_start;
202     ss.ss_size = altstack_size;
203     int res = sigaltstack(&ss, &old_ss);
204     CHECK(res == 0);
205   }
206 
~SigAltStackScoped()207   ~SigAltStackScoped() {
208     int res = sigaltstack(&old_ss, nullptr);
209     CHECK(res == 0);
210     munmap(altstack_start, altstack_size);
211   }
212 };
213 
214 class SigActionScoped {
215   int signo;
216   struct sigaction oldsa;
217 
218  public:
219   using handler_t = void (*)(int, siginfo_t* siginfo, void*);
220 
SigActionScoped(int signo,handler_t handler)221   SigActionScoped(int signo, handler_t handler) : signo(signo) {
222     struct sigaction sa = {};
223     sa.sa_sigaction = handler;
224     sa.sa_flags = SA_SIGINFO | SA_ONSTACK;
225     int res = sigaction(signo, &sa, &oldsa);
226     CHECK(res == 0);
227   }
228 
~SigActionScoped()229   ~SigActionScoped() {
230     int res = sigaction(signo, &oldsa, nullptr);
231     CHECK(res == 0);
232   }
233 };
234 
test_longjmp()235 void test_longjmp() {
236   check_longjmp_restores_tags();
237 
238   std::thread t([]() { check_longjmp_restores_tags(); });
239   t.join();
240 }
241 
test_longjmp_sigaltstack()242 void test_longjmp_sigaltstack() {
243   const size_t kAltStackSize = kStackAllocationSize + getpagesize() * 16;
244   SigAltStackScoped sigAltStackScoped(kAltStackSize);
245   SigActionScoped sigActionScoped(
246       SIGUSR1, [](int, siginfo_t*, void*) { check_longjmp_restores_tags(); });
247   raise(SIGUSR1);
248 
249   // same for a secondary thread
250   std::thread t([&]() {
251     SigAltStackScoped sigAltStackScoped(kAltStackSize);
252     raise(SIGUSR1);
253   });
254   t.join();
255 }
256 
test_android_mallopt()257 void test_android_mallopt() {
258   bool memtag_stack;
259   CHECK(android_mallopt(M_MEMTAG_STACK_IS_ON, &memtag_stack, sizeof(memtag_stack)));
260   CHECK(memtag_stack);
261 }
262 
// Returns only the MTE logical-tag bits (56..59) of `addr`, left in place;
// all other bits are cleared.
static uintptr_t GetTag(void* addr) {
  constexpr uintptr_t kTagMask = static_cast<uintptr_t>(0xF) << 56;
  const uintptr_t bits = reinterpret_cast<uintptr_t>(addr);
  return bits & kTagMask;
}

// Same as above, for pointers to volatile data.
static uintptr_t GetTag(volatile void* addr) {
  void* plain = const_cast<void*>(addr);
  return GetTag(plain);
}
270 
// Frame addresses recorded by throws() and skip_frame3();
// test_exception_cleanup() checks the tags of the memory between them.
static volatile char* throw_frame = nullptr;
static volatile char* skip_frame3_frame = nullptr;
// Receives the address of each test buffer so it is not optimized out.
volatile char* x = nullptr;
274 
throws()275 __attribute__((noinline)) void throws() {
276   // Prevent optimization.
277   if (getpid() == 0) return;
278   throw_frame = reinterpret_cast<char*>(__builtin_frame_address(0));
279   throw "error";
280 }
281 
maybe_throws()282 __attribute__((noinline)) void maybe_throws() {
283   // These are all unique sizes so in case of a failure, we can see which ones
284   // are not untagged from the tag dump.
285   volatile char y[5 * 16]= {};
286   x = y;
287   // Make sure y is tagged.
288   CHECK(GetTag(&y) != GetTag(__builtin_frame_address(0)));
289   throws();
290 }
291 
skip_frame()292 __attribute__((noinline, no_sanitize("memtag"))) void skip_frame() {
293   volatile char y[6*16] = {};
294   x = y;
295   // Make sure y is not tagged.
296   CHECK(GetTag(&y) == GetTag(__builtin_frame_address(0)));
297   maybe_throws();
298 }
299 
skip_frame2()300 __attribute__((noinline)) void skip_frame2() {
301   volatile char y[7*16] = {};
302   x = y;
303   // Make sure y is tagged.
304   CHECK(GetTag(&y) != GetTag(__builtin_frame_address(0)));
305   skip_frame();
306 }
307 
skip_frame3()308 __attribute__((noinline, no_sanitize("memtag"))) void skip_frame3() {
309   volatile char y[8*16] = {};
310   x = y;
311   skip_frame3_frame = reinterpret_cast<char*>(__builtin_frame_address(0));
312   // Make sure y is not tagged.
313   CHECK(GetTag(&y) == GetTag(__builtin_frame_address(0)));
314   skip_frame2();
315 }
316 
test_exception_cleanup()317 void test_exception_cleanup() {
318   // This is here for debugging purposes, if something goes wrong we can
319   // verify that this placeholder did not get untagged.
320   volatile char placeholder[16*16] = {};
321   x = placeholder;
322   try {
323     skip_frame3();
324   } catch (const char* e) {
325   }
326   if (throw_frame >= skip_frame3_frame) {
327     fprintf(stderr, "invalid throw frame");
328     exit(1);
329   }
330   for (char* b = const_cast<char*>(throw_frame); b < skip_frame3_frame; ++b) {
331     if (mte_get_tag(b) != b) {
332       fprintf(stderr, "invalid tag at %p", b);
333       exit(1);
334     }
335   }
336 }
337 
main(int argc,char ** argv)338 int main(int argc, char** argv) {
339   if (argc < 2) {
340     printf("nothing to do\n");
341     return 1;
342   }
343 
344   if (strcmp(argv[1], "vfork_execve") == 0) {
345     test_vfork(ChildAction::Execve);
346     return 0;
347   }
348 
349   if (strcmp(argv[1], "vfork_execl") == 0) {
350     test_vfork(ChildAction::Execl);
351     return 0;
352   }
353 
354   if (strcmp(argv[1], "vfork_exit") == 0) {
355     test_vfork(ChildAction::Exit);
356     return 0;
357   }
358 
359   if (strcmp(argv[1], "longjmp") == 0) {
360     test_longjmp();
361     return 0;
362   }
363 
364   if (strcmp(argv[1], "longjmp_sigaltstack") == 0) {
365     test_longjmp_sigaltstack();
366     return 0;
367   }
368 
369   if (strcmp(argv[1], "android_mallopt") == 0) {
370     test_android_mallopt();
371     return 0;
372   }
373 
374   if (strcmp(argv[1], "exception_cleanup") == 0) {
375     test_exception_cleanup();
376     return 0;
377   }
378 
379   printf("unrecognized command: %s\n", argv[1]);
380   return 1;
381 }
382 #else
// Fallback entry point for non-AArch64 builds: the MTE tests cannot run here,
// so report that and fail.
int main(int, char**) {
  puts("aarch64 only");
  return 1;
}
387 #endif  // defined(__aarch64__)
388