1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
17
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21
22 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
23 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
24 #include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"
25 #include "tensorflow/lite/delegates/gpu/cl/util.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include <farmhash.h>
28
29 namespace tflite {
30 namespace gpu {
31 namespace cl {
32 namespace {
33
34 // Farmhash Fingerprint
// Mixes two 64-bit fingerprints into a single 64-bit value.
// The result depends on argument order: swapping `l` and `h` yields a
// different value whenever they differ.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  // Murmur-inspired multiply / xor-shift rounds.
  constexpr uint64_t kMul = 0x9ddfea08eb382d69ULL;

  uint64_t low_mix = (l ^ h) * kMul;
  low_mix ^= low_mix >> 47;

  uint64_t mixed = (h ^ low_mix) * kMul;
  mixed = (mixed ^ (mixed >> 44)) * kMul;
  mixed = (mixed ^ (mixed >> 41)) * kMul;
  return mixed;
}
47
GetProgramFingerprint(const std::string & code,const std::string & compiler_options)48 uint64_t GetProgramFingerprint(const std::string& code,
49 const std::string& compiler_options) {
50 const uint64_t code_fingerprint = ::util::Fingerprint64(code);
51 const uint64_t options_fingerprint =
52 ::util::Fingerprint64(compiler_options);
53 return CombineFingerprints(code_fingerprint, options_fingerprint);
54 }
55
GetDriverVersion(const CLDevice & device)56 std::string GetDriverVersion(const CLDevice& device) {
57 return device.GetPlatformVersion() + "_jet_version_0";
58 }
59
60 } // namespace
61
ProgramDescriptor(const std::string & code,const std::string & compiler_options)62 ProgramCache::ProgramDescriptor::ProgramDescriptor(
63 const std::string& code, const std::string& compiler_options)
64 : fingerprint(GetProgramFingerprint(code, compiler_options)) {}
65
ProgramDescriptor(uint64_t fingerprints)66 ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
67 : fingerprint(fingerprints) {}
68
ProgramCache(ProgramCache && program_cache)69 ProgramCache::ProgramCache(ProgramCache&& program_cache)
70 : programs_(std::move(program_cache.programs_)) {}
71
operator =(ProgramCache && program_cache)72 ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) {
73 if (this != &program_cache) {
74 programs_ = std::move(program_cache.programs_);
75 }
76 return *this;
77 }
78
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const std::vector<CompilerOptions> & compiler_options,const CLContext & context,const CLDevice & device,CLKernel * result,uint64_t * kernel_fingerprint)79 absl::Status ProgramCache::GetOrCreateCLKernel(
80 const std::string& code, const std::string& function_name,
81 const std::vector<CompilerOptions>& compiler_options,
82 const CLContext& context, const CLDevice& device, CLKernel* result,
83 uint64_t* kernel_fingerprint) {
84 const std::string options =
85 CompilerOptionsToString(device.GetInfo(), compiler_options);
86 ProgramDescriptor desc(code, options);
87 if (kernel_fingerprint) {
88 *kernel_fingerprint = desc.fingerprint;
89 }
90 auto it = programs_.find(desc);
91 if (it != programs_.end()) {
92 return result->CreateFromProgram(it->second, function_name);
93 }
94
95 CLProgram program;
96 RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
97 RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
98 programs_.insert(std::make_pair(std::move(desc), std::move(program)));
99 return absl::OkStatus();
100 }
101
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const CLContext & context,const CLDevice & device,CLKernel * result,uint64_t * kernel_fingerprint)102 absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
103 const std::string& function_name,
104 const CLContext& context,
105 const CLDevice& device,
106 CLKernel* result,
107 uint64_t* kernel_fingerprint) {
108 return GetOrCreateCLKernel(code, function_name, {}, context, device, result,
109 kernel_fingerprint);
110 }
111
GetKernel(uint64_t fingerprint,const std::string & function_name,CLKernel * result) const112 absl::Status ProgramCache::GetKernel(uint64_t fingerprint,
113 const std::string& function_name,
114 CLKernel* result) const {
115 ProgramDescriptor desc(fingerprint);
116 auto it = programs_.find(desc);
117 if (it == programs_.end()) {
118 return absl::NotFoundError("No program with this fingerprint.");
119 }
120 return result->CreateFromProgram(it->second, function_name);
121 }
122
AddProgramBinary(const CLContext & context,const CLDevice & device,uint64_t fingerprint,absl::Span<const uint8_t> binary)123 absl::Status ProgramCache::AddProgramBinary(const CLContext& context,
124 const CLDevice& device,
125 uint64_t fingerprint,
126 absl::Span<const uint8_t> binary) {
127 ProgramDescriptor desc(fingerprint);
128 auto it = programs_.find(desc);
129 if (it == programs_.end()) {
130 CLProgram program;
131 RETURN_IF_ERROR(
132 CreateCLProgramFromBinary(context, device, binary, &program));
133 programs_.insert(std::make_pair(std::move(desc), std::move(program)));
134 }
135 return absl::OkStatus();
136 }
137
GetProgramBinary(uint64_t fingerprint,std::vector<uint8_t> * program_binary) const138 absl::Status ProgramCache::GetProgramBinary(
139 uint64_t fingerprint, std::vector<uint8_t>* program_binary) const {
140 ProgramDescriptor desc(fingerprint);
141 auto it = programs_.find(desc);
142 if (it == programs_.end()) {
143 return absl::NotFoundError("No program with this fingerprint.");
144 }
145 return it->second.GetBinary(program_binary);
146 }
147
AddSerializedCache(const CLContext & context,const CLDevice & device,absl::Span<const uint8_t> serialized_cache)148 absl::Status ProgramCache::AddSerializedCache(
149 const CLContext& context, const CLDevice& device,
150 absl::Span<const uint8_t> serialized_cache) {
151 flatbuffers::Verifier verifier(serialized_cache.data(),
152 serialized_cache.size());
153 if (!data::VerifyCompiledCacheBuffer(verifier)) {
154 return absl::InvalidArgumentError("Serialized model is corrupted.");
155 }
156
157 auto model = data::GetCompiledCache(serialized_cache.data());
158 std::string platform_version(model->driver_version()->c_str(),
159 model->driver_version()->size());
160
161 if (GetDriverVersion(device) != platform_version) {
162 return absl::InvalidArgumentError(
163 "OpenCL driver changed, cache invalid, should be regenerated");
164 }
165
166 for (auto serialized_program : *model->programs()) {
167 auto binary_span = absl::MakeSpan(serialized_program->binary()->data(),
168 serialized_program->binary()->size());
169 RETURN_IF_ERROR(AddProgramBinary(
170 context, device, serialized_program->fingerprint(), binary_span));
171 }
172 return absl::OkStatus();
173 }
174
GetSerializedCache(const CLDevice & device,std::vector<uint8_t> * serialized_cache) const175 absl::Status ProgramCache::GetSerializedCache(
176 const CLDevice& device, std::vector<uint8_t>* serialized_cache) const {
177 ::flatbuffers::FlatBufferBuilder builder;
178 std::vector<flatbuffers::Offset<data::Program>> serialized_programs;
179 for (auto& program : programs_) {
180 std::vector<uint8_t> binary;
181 RETURN_IF_ERROR(program.second.GetBinary(&binary));
182 auto binary_offset = builder.CreateVector(binary);
183 data::ProgramBuilder program_builder(builder);
184 program_builder.add_fingerprint(program.first.fingerprint);
185 program_builder.add_binary(binary_offset);
186 serialized_programs.push_back(program_builder.Finish());
187 }
188 auto driver_version = builder.CreateString(GetDriverVersion(device));
189 auto programs_s = builder.CreateVector(serialized_programs);
190 data::CompiledCacheBuilder cache_builder(builder);
191 cache_builder.add_driver_version(driver_version);
192 cache_builder.add_programs(programs_s);
193 data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());
194 size_t next_element = serialized_cache->size();
195 serialized_cache->resize(serialized_cache->size() + builder.GetSize());
196 std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
197 builder.GetSize());
198 return absl::OkStatus();
199 }
200
201 } // namespace cl
202 } // namespace gpu
203 } // namespace tflite
204