1 // Copyright 2017 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
#include "base/win/com_init_check_hook.h"

#include <objbase.h>

#include <shlobj.h>
#include <stdint.h>
#include <wrl/client.h>

#include "base/test/gtest_util.h"
#include "base/win/com_init_util.h"
#include "base/win/patch_util.h"
#include "base/win/scoped_com_initializer.h"
#include "testing/gtest/include/gtest/gtest.h"
17
18 namespace base {
19 namespace win {
20
21 using Microsoft::WRL::ComPtr;
22
// With a check hook installed and no COM apartment on this thread,
// CoCreateInstance is expected to DCHECK when the hook is compiled in, and to
// fail with CO_E_NOTINITIALIZED when it is not.
TEST(ComInitCheckHook, AssertNotInitialized) {
  ComInitCheckHook hook;
  AssertComApartmentType(ComApartmentType::NONE);
  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                         IID_PPV_ARGS(&link)));
#else
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
#endif
}
36
// After a hook has gone out of scope, CoCreateInstance on a thread without a
// COM apartment is expected to simply return CO_E_NOTINITIALIZED.
TEST(ComInitCheckHook, HookRemoval) {
  AssertComApartmentType(ComApartmentType::NONE);
  {
    // Install the hook and let it be torn down at the end of this scope.
    ComInitCheckHook transient_hook;
  }
  ComPtr<IUnknown> link;
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
}
45
// Once COM is initialized via ScopedCOMInitializer, the hook must not
// interfere and CoCreateInstance is expected to succeed.
TEST(ComInitCheckHook, NoAssertComInitialized) {
  ComInitCheckHook hook;
  ScopedCOMInitializer com_scope;
  // Declared after |com_scope| so the interface is released before COM is
  // uninitialized.
  ComPtr<IUnknown> link;
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_TRUE(SUCCEEDED(result));
}
53
// Two coexisting hooks are expected to give the same outcome as one: an
// uninitialized CoCreateInstance DCHECKs when the hook is compiled in, and
// returns CO_E_NOTINITIALIZED otherwise.
TEST(ComInitCheckHook, MultipleHooks) {
  ComInitCheckHook first_hook;
  ComInitCheckHook second_hook;
  AssertComApartmentType(ComApartmentType::NONE);
  ComPtr<IUnknown> link;
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  EXPECT_DCHECK_DEATH(::CoCreateInstance(CLSID_ShellLink, nullptr, CLSCTX_ALL,
                                         IID_PPV_ARGS(&link)));
#else
  const HRESULT result = ::CoCreateInstance(CLSID_ShellLink, nullptr,
                                            CLSCTX_ALL, IID_PPV_ARGS(&link));
  EXPECT_EQ(CO_E_NOTINITIALIZED, result);
#endif
}
68
// Verifies that constructing a ComInitCheckHook DCHECKs when the first byte of
// the 5-byte region immediately preceding CoCreateInstance has been changed to
// an unexpected value (0xdb).
TEST(ComInitCheckHook, UnexpectedHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Drops the reference taken by LoadLibrary on every exit path, including
  // the early returns performed by failing ASSERT_* macros below.
  struct ScopedFreeLibrary {
    HMODULE module;
    ~ScopedFreeLibrary() { ::FreeLibrary(module); }
  };
  ScopedFreeLibrary ole32_releaser{ole32_library};

  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  // Without this check a NULL result would be offset and dereferenced below.
  ASSERT_TRUE(co_create_instance);

  // uintptr_t (not uint32_t) so the pointer round-trip never truncates.
  const uintptr_t co_create_instance_padded_address =
      reinterpret_cast<uintptr_t>(co_create_instance) - 5;
  const unsigned char* const co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char unexpected_byte = 0xdb;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_padded_address),
                reinterpret_cast<const void*>(&unexpected_byte),
                sizeof(unexpected_byte)));

  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               reinterpret_cast<const void*>(&original_byte),
               sizeof(original_byte)));
#endif
}
102
// Verifies that constructing a ComInitCheckHook DCHECKs when the first byte of
// CoCreateInstance itself has been overwritten with a JMP opcode (0xe9), i.e.
// when something else appears to have patched the function already.
TEST(ComInitCheckHook, ExternallyHooked) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Drops the reference taken by LoadLibrary on every exit path, including
  // the early returns performed by failing ASSERT_* macros below.
  struct ScopedFreeLibrary {
    HMODULE module;
    ~ScopedFreeLibrary() { ::FreeLibrary(module); }
  };
  ScopedFreeLibrary ole32_releaser{ole32_library};

  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  // Without this check a NULL result would be dereferenced below.
  ASSERT_TRUE(co_create_instance);

  // uintptr_t (not uint32_t) so the pointer round-trip never truncates.
  const uintptr_t co_create_instance_address =
      reinterpret_cast<uintptr_t>(co_create_instance);
  const unsigned char* const co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char jmp_byte = 0xe9;
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_address),
                reinterpret_cast<const void*>(&jmp_byte), sizeof(jmp_byte)));

  // Externally patched instances should crash so we catch these cases on bots.
  EXPECT_DCHECK_DEATH({ ComInitCheckHook com_check_hook; });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(
      static_cast<DWORD>(NO_ERROR),
      internal::ModifyCode(reinterpret_cast<void*>(co_create_instance_address),
                           reinterpret_cast<const void*>(&original_byte),
                           sizeof(original_byte)));
#endif
}
134
// Intended to verify that a DCHECK fires when the bytes in front of
// CoCreateInstance are modified while a ComInitCheckHook is installed.
TEST(ComInitCheckHook, UnexpectedChangeDuringHook) {
#if defined(COM_INIT_CHECK_HOOK_ENABLED)
  HMODULE ole32_library = ::LoadLibrary(L"ole32.dll");
  ASSERT_TRUE(ole32_library);

  // Drops the reference taken by LoadLibrary on every exit path, including
  // the early returns performed by failing ASSERT_* macros below.
  struct ScopedFreeLibrary {
    HMODULE module;
    ~ScopedFreeLibrary() { ::FreeLibrary(module); }
  };
  ScopedFreeLibrary ole32_releaser{ole32_library};

  const FARPROC co_create_instance =
      ::GetProcAddress(ole32_library, "CoCreateInstance");
  // Without this check a NULL result would be offset and dereferenced below.
  ASSERT_TRUE(co_create_instance);

  // uintptr_t (not uint32_t) so the pointer round-trip never truncates.
  const uintptr_t co_create_instance_padded_address =
      reinterpret_cast<uintptr_t>(co_create_instance) - 5;
  const unsigned char* const co_create_instance_bytes =
      reinterpret_cast<const unsigned char*>(co_create_instance_padded_address);
  const unsigned char original_byte = co_create_instance_bytes[0];
  const unsigned char unexpected_byte = 0xdb;
  // NOTE(review): this pre-patch is identical to the one in UnexpectedHook,
  // where it alone makes the ComInitCheckHook constructor DCHECK. If so, the
  // death below is triggered by construction and the in-hook ModifyCode never
  // runs, so the "change during hook" path may not actually be exercised.
  // Confirm against the hook's unregister logic before relying on this test.
  ASSERT_EQ(static_cast<DWORD>(NO_ERROR),
            internal::ModifyCode(
                reinterpret_cast<void*>(co_create_instance_padded_address),
                reinterpret_cast<const void*>(&unexpected_byte),
                sizeof(unexpected_byte)));

  EXPECT_DCHECK_DEATH({
    ComInitCheckHook com_check_hook;

    internal::ModifyCode(
        reinterpret_cast<void*>(co_create_instance_padded_address),
        reinterpret_cast<const void*>(&unexpected_byte),
        sizeof(unexpected_byte));
  });

  // If this call fails, really bad things are going to happen to other tests
  // so CHECK here.
  CHECK_EQ(static_cast<DWORD>(NO_ERROR),
           internal::ModifyCode(
               reinterpret_cast<void*>(co_create_instance_padded_address),
               reinterpret_cast<const void*>(&original_byte),
               sizeof(original_byte)));
#endif
}
175
176 } // namespace win
177 } // namespace base
178