1 // Copyright 2023 Code Intelligence GmbH
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "dex_file_manager.h"
16
17 #include <algorithm>
18 #include <iostream>
19 #include <sstream>
20 #include <string>
21 #include <vector>
22
23 #include "jazzer_jvmti_allocator.h"
24 #include "jvmti.h"
25 #include "slicer/dex_ir.h"
26 #include "slicer/reader.h"
27 #include "slicer/writer.h"
28
// Wraps a class name in the "L<class_name>;" form, which is how types are
// stored in a DEX file's type table and therefore what slicer needs in order
// to find the class.
std::string GetName(const char* name) {
  std::string descriptor("L");
  descriptor.append(name);
  descriptor.push_back(';');
  return descriptor;
}
36
IsValidIndex(dex::u4 index)37 bool IsValidIndex(dex::u4 index) { return index != (unsigned)-1; }
38
addDexFile(const unsigned char * bytes,int length)39 void DexFileManager::addDexFile(const unsigned char* bytes, int length) {
40 unsigned char* newArr = new unsigned char[length];
41 std::copy(bytes, bytes + length, newArr);
42
43 dexFiles.push_back(newArr);
44 dexFilesSize.push_back(length);
45 }
46
getClassBytes(const char * className,int dexFileIndex,jvmtiEnv * jvmti,size_t * newSize)47 unsigned char* DexFileManager::getClassBytes(const char* className,
48 int dexFileIndex, jvmtiEnv* jvmti,
49 size_t* newSize) {
50 dex::Reader dexReader(dexFiles[dexFileIndex], dexFilesSize[dexFileIndex]);
51 auto descName = GetName(className);
52
53 auto classIndex = dexReader.FindClassIndex(descName.c_str());
54 if (!IsValidIndex(classIndex)) {
55 *newSize = *newSize;
56 return nullptr;
57 }
58
59 dexReader.CreateClassIr(classIndex);
60 auto oldIr = dexReader.GetIr();
61
62 dex::Writer writer(oldIr);
63 JazzerJvmtiAllocator allocator(jvmti);
64 return writer.CreateImage(&allocator, newSize);
65 }
66
findDexFileForClass(const char * className)67 uint32_t DexFileManager::findDexFileForClass(const char* className) {
68 for (int i = 0; i < dexFiles.size(); i++) {
69 dex::Reader dexReader(dexFiles[i], dexFilesSize[i]);
70
71 std::string descName = GetName(className);
72 dex::u4 classIndex = dexReader.FindClassIndex(descName.c_str());
73
74 if (IsValidIndex(classIndex)) {
75 return i;
76 }
77 }
78
79 return -1;
80 }
81
getMethodDescriptions(std::vector<ir::EncodedMethod * > * encMethodList)82 std::vector<std::string> getMethodDescriptions(
83 std::vector<ir::EncodedMethod*>* encMethodList) {
84 std::vector<std::string> methodDescs;
85
86 for (int i = 0; i < encMethodList->size(); i++) {
87 std::stringstream ss;
88 ss << (*encMethodList)[i]->access_flags;
89 ss << (*encMethodList)[i]->decl->name->c_str();
90 ss << (*encMethodList)[i]->decl->prototype->Signature().c_str();
91
92 methodDescs.push_back(ss.str());
93 }
94
95 sort(methodDescs.begin(), methodDescs.end());
96 return methodDescs;
97 }
98
getFieldDescriptions(std::vector<ir::EncodedField * > * encFieldList)99 std::vector<std::string> getFieldDescriptions(
100 std::vector<ir::EncodedField*>* encFieldList) {
101 std::vector<std::string> fieldDescs;
102
103 for (int i = 0; i < encFieldList->size(); i++) {
104 std::stringstream ss;
105 ss << (*encFieldList)[i]->access_flags;
106 ss << (*encFieldList)[i]->decl->type->descriptor->c_str();
107 ss << (*encFieldList)[i]->decl->name->c_str();
108 fieldDescs.push_back(ss.str());
109 }
110
111 sort(fieldDescs.begin(), fieldDescs.end());
112 return fieldDescs;
113 }
114
matchFields(std::vector<ir::EncodedField * > * encodedFieldListOne,std::vector<ir::EncodedField * > * encodedFieldListTwo)115 bool matchFields(std::vector<ir::EncodedField*>* encodedFieldListOne,
116 std::vector<ir::EncodedField*>* encodedFieldListTwo) {
117 std::vector<std::string> fDescListOne =
118 getFieldDescriptions(encodedFieldListOne);
119 std::vector<std::string> fDescListTwo =
120 getFieldDescriptions(encodedFieldListTwo);
121
122 if (fDescListOne.size() != fDescListTwo.size()) {
123 return false;
124 }
125
126 for (int i = 0; i < fDescListOne.size(); i++) {
127 if (fDescListOne[i] != fDescListTwo[i]) {
128 return false;
129 }
130 }
131
132 return true;
133 }
134
matchMethods(std::vector<ir::EncodedMethod * > * encodedMethodListOne,std::vector<ir::EncodedMethod * > * encodedMethodListTwo)135 bool matchMethods(std::vector<ir::EncodedMethod*>* encodedMethodListOne,
136 std::vector<ir::EncodedMethod*>* encodedMethodListTwo) {
137 std::vector<std::string> mDescListOne =
138 getMethodDescriptions(encodedMethodListOne);
139 std::vector<std::string> mDescListTwo =
140 getMethodDescriptions(encodedMethodListTwo);
141
142 if (mDescListOne.size() != mDescListTwo.size()) {
143 return false;
144 }
145
146 for (int i = 0; i < mDescListOne.size(); i++) {
147 if (mDescListOne[i] != mDescListTwo[i]) {
148 return false;
149 }
150 }
151
152 return true;
153 }
154
classStructureMatches(ir::Class * classOne,ir::Class * classTwo)155 bool classStructureMatches(ir::Class* classOne, ir::Class* classTwo) {
156 return matchMethods(&(classOne->direct_methods),
157 &(classTwo->direct_methods)) &&
158 matchMethods(&(classOne->virtual_methods),
159 &(classTwo->virtual_methods)) &&
160 matchFields(&(classOne->static_fields), &(classTwo->static_fields)) &&
161 matchFields(&(classOne->instance_fields),
162 &(classTwo->instance_fields)) &&
163 classOne->access_flags == classTwo->access_flags;
164 }
165
structureMatches(dex::Reader * oldReader,dex::Reader * newReader,const char * className)166 bool DexFileManager::structureMatches(dex::Reader* oldReader,
167 dex::Reader* newReader,
168 const char* className) {
169 std::string descName = GetName(className);
170
171 dex::u4 oldReaderIndex = oldReader->FindClassIndex(descName.c_str());
172 dex::u4 newReaderIndex = newReader->FindClassIndex(descName.c_str());
173
174 if (!IsValidIndex(oldReaderIndex) || !IsValidIndex(newReaderIndex)) {
175 return false;
176 }
177
178 oldReader->CreateClassIr(oldReaderIndex);
179 newReader->CreateClassIr(newReaderIndex);
180
181 std::shared_ptr<ir::DexFile> oldDexFile = oldReader->GetIr();
182 std::shared_ptr<ir::DexFile> newDexFile = newReader->GetIr();
183
184 for (int i = 0; i < oldDexFile->classes.size(); i++) {
185 const char* oldClassDescriptor =
186 oldDexFile->classes[i]->type->descriptor->c_str();
187 if (strcmp(oldClassDescriptor, descName.c_str()) != 0) {
188 continue;
189 }
190
191 bool match = false;
192 for (int j = 0; j < newDexFile->classes.size(); j++) {
193 const char* newClassDescriptor =
194 newDexFile->classes[j]->type->descriptor->c_str();
195 if (strcmp(oldClassDescriptor, newClassDescriptor) == 0) {
196 match = classStructureMatches(oldDexFile->classes[i].get(),
197 newDexFile->classes[j].get());
198 break;
199 }
200 }
201
202 if (!match) {
203 return false;
204 }
205 }
206
207 return true;
208 }
209