1 /**************************************************************************
2 *
3 * Copyright 2009-2013 VMware, Inc.
4 * All Rights Reserved.
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 *
26 **************************************************************************/
27
28 #include <windows.h>
29 #include <tlhelp32.h>
30
31 #include "util/compiler.h"
32 #include "util/u_debug.h"
33 #include "stw_tls.h"
34
35 static DWORD tlsIndex = TLS_OUT_OF_INDEXES;
36
37
38 /**
39 * Static mutex to protect the access to g_pendingTlsData global and
40 * stw_tls_data::next member.
41 */
42 static CRITICAL_SECTION g_mutex = {
43 (PCRITICAL_SECTION_DEBUG)-1, -1, 0, 0, 0, 0
44 };
45
46 /**
47 * There is no way to invoke TlsSetValue for a different thread, so we
48 * temporarily put the thread data for non-current threads here.
49 */
50 static struct stw_tls_data *g_pendingTlsData = NULL;
51
52
53 static struct stw_tls_data *
54 stw_tls_data_create(DWORD dwThreadId);
55
56 static struct stw_tls_data *
57 stw_tls_lookup_pending_data(DWORD dwThreadId);
58
59
60 bool
stw_tls_init(void)61 stw_tls_init(void)
62 {
63 tlsIndex = TlsAlloc();
64 if (tlsIndex == TLS_OUT_OF_INDEXES) {
65 return false;
66 }
67
68 /*
69 * DllMain is called with DLL_THREAD_ATTACH only for threads created after
70 * the DLL is loaded by the process. So enumerate and add our hook to all
71 * previously existing threads.
72 *
73 * XXX: Except for the current thread since it there is an explicit
74 * stw_tls_init_thread() call for it later on.
75 */
76 #ifndef _GAMING_XBOX
77 DWORD dwCurrentProcessId = GetCurrentProcessId();
78 DWORD dwCurrentThreadId = GetCurrentThreadId();
79 HANDLE hSnapshot = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, dwCurrentProcessId);
80 if (hSnapshot != INVALID_HANDLE_VALUE) {
81 THREADENTRY32 te;
82 te.dwSize = sizeof te;
83 if (Thread32First(hSnapshot, &te)) {
84 do {
85 if (te.dwSize >= FIELD_OFFSET(THREADENTRY32, th32OwnerProcessID) +
86 sizeof te.th32OwnerProcessID) {
87 if (te.th32OwnerProcessID == dwCurrentProcessId) {
88 if (te.th32ThreadID != dwCurrentThreadId) {
89 struct stw_tls_data *data;
90 data = stw_tls_data_create(te.th32ThreadID);
91 if (data) {
92 EnterCriticalSection(&g_mutex);
93 data->next = g_pendingTlsData;
94 g_pendingTlsData = data;
95 LeaveCriticalSection(&g_mutex);
96 }
97 }
98 }
99 }
100 te.dwSize = sizeof te;
101 } while (Thread32Next(hSnapshot, &te));
102 }
103 CloseHandle(hSnapshot);
104 }
105 #endif /* _GAMING_XBOX */
106
107 return true;
108 }
109
110
111 /**
112 * Install windows hook for a given thread (not necessarily the current one).
113 */
114 static struct stw_tls_data *
stw_tls_data_create(DWORD dwThreadId)115 stw_tls_data_create(DWORD dwThreadId)
116 {
117 struct stw_tls_data *data;
118
119 if (0) {
120 debug_printf("%s(0x%04lx)\n", __func__, dwThreadId);
121 }
122
123 data = calloc(1, sizeof *data);
124 if (!data) {
125 goto no_data;
126 }
127
128 data->dwThreadId = dwThreadId;
129
130 #ifndef _GAMING_XBOX
131 data->hCallWndProcHook = SetWindowsHookEx(WH_CALLWNDPROC,
132 stw_call_window_proc,
133 NULL,
134 dwThreadId);
135 #else
136 data->hCallWndProcHook = NULL;
137 #endif
138 if (data->hCallWndProcHook == NULL) {
139 goto no_hook;
140 }
141
142 return data;
143
144 no_hook:
145 free(data);
146 no_data:
147 return NULL;
148 }
149
150 /**
151 * Destroy the per-thread data/hook.
152 *
153 * It is important to remove all hooks when unloading our DLL, otherwise our
154 * hook function might be called after it is no longer there.
155 */
156 static void
stw_tls_data_destroy(struct stw_tls_data * data)157 stw_tls_data_destroy(struct stw_tls_data *data)
158 {
159 assert(data);
160 if (!data) {
161 return;
162 }
163
164 if (0) {
165 debug_printf("%s(0x%04lx)\n", __func__, data->dwThreadId);
166 }
167
168 #ifndef _GAMING_XBOX
169 if (data->hCallWndProcHook) {
170 UnhookWindowsHookEx(data->hCallWndProcHook);
171 data->hCallWndProcHook = NULL;
172 }
173 #endif
174
175 free(data);
176 }
177
178 bool
stw_tls_init_thread(void)179 stw_tls_init_thread(void)
180 {
181 struct stw_tls_data *data;
182
183 if (tlsIndex == TLS_OUT_OF_INDEXES) {
184 return false;
185 }
186
187 data = stw_tls_data_create(GetCurrentThreadId());
188 if (!data) {
189 return false;
190 }
191
192 TlsSetValue(tlsIndex, data);
193
194 return true;
195 }
196
197 void
stw_tls_cleanup_thread(void)198 stw_tls_cleanup_thread(void)
199 {
200 struct stw_tls_data *data;
201
202 if (tlsIndex == TLS_OUT_OF_INDEXES) {
203 return;
204 }
205
206 data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
207 if (data) {
208 TlsSetValue(tlsIndex, NULL);
209 } else {
210 /* See if there this thread's data in on the pending list */
211 data = stw_tls_lookup_pending_data(GetCurrentThreadId());
212 }
213
214 if (data) {
215 stw_tls_data_destroy(data);
216 }
217 }
218
219 void
stw_tls_cleanup(void)220 stw_tls_cleanup(void)
221 {
222 if (tlsIndex != TLS_OUT_OF_INDEXES) {
223 /*
224 * Destroy all items in g_pendingTlsData linked list.
225 */
226 EnterCriticalSection(&g_mutex);
227 while (g_pendingTlsData) {
228 struct stw_tls_data * data = g_pendingTlsData;
229 g_pendingTlsData = data->next;
230 stw_tls_data_destroy(data);
231 }
232 LeaveCriticalSection(&g_mutex);
233
234 TlsFree(tlsIndex);
235 tlsIndex = TLS_OUT_OF_INDEXES;
236 }
237 }
238
239 /*
240 * Search for the current thread in the g_pendingTlsData linked list.
241 *
242 * It will remove and return the node on success, or return NULL on failure.
243 */
244 static struct stw_tls_data *
stw_tls_lookup_pending_data(DWORD dwThreadId)245 stw_tls_lookup_pending_data(DWORD dwThreadId)
246 {
247 struct stw_tls_data ** p_data;
248 struct stw_tls_data *data = NULL;
249
250 EnterCriticalSection(&g_mutex);
251 for (p_data = &g_pendingTlsData; *p_data; p_data = &(*p_data)->next) {
252 if ((*p_data)->dwThreadId == dwThreadId) {
253 data = *p_data;
254
255 /*
256 * Unlink the node.
257 */
258 *p_data = data->next;
259 data->next = NULL;
260
261 break;
262 }
263 }
264 LeaveCriticalSection(&g_mutex);
265
266 return data;
267 }
268
269 struct stw_tls_data *
stw_tls_get_data(void)270 stw_tls_get_data(void)
271 {
272 struct stw_tls_data *data;
273
274 if (tlsIndex == TLS_OUT_OF_INDEXES) {
275 return NULL;
276 }
277
278 data = (struct stw_tls_data *) TlsGetValue(tlsIndex);
279 if (!data) {
280 DWORD dwCurrentThreadId = GetCurrentThreadId();
281
282 /*
283 * Search for the current thread in the g_pendingTlsData linked list.
284 */
285 data = stw_tls_lookup_pending_data(dwCurrentThreadId);
286
287 if (!data) {
288 /*
289 * This should be impossible now.
290 */
291 assert(!"Failed to find thread data for thread id");
292
293 /*
294 * DllMain is called with DLL_THREAD_ATTACH only by threads created
295 * after the DLL is loaded by the process
296 */
297 data = stw_tls_data_create(dwCurrentThreadId);
298 if (!data) {
299 return NULL;
300 }
301 }
302
303 TlsSetValue(tlsIndex, data);
304 }
305
306 assert(data);
307 assert(data->dwThreadId == GetCurrentThreadId());
308 assert(data->next == NULL);
309
310 return data;
311 }
312