xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/cl/program_cache.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/cl/program_cache.h"
17 
18 #include <cstdint>
19 #include <string>
20 #include <utility>
21 
22 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
23 #include "tensorflow/lite/delegates/gpu/cl/cl_program.h"
24 #include "tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h"
25 #include "tensorflow/lite/delegates/gpu/cl/util.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include <farmhash.h>
28 
29 namespace tflite {
30 namespace gpu {
31 namespace cl {
32 namespace {
33 
// Folds two 64-bit fingerprints into a single 64-bit value.
//
// The mixing is Murmur-inspired: xor the inputs, multiply by an odd constant,
// and apply xor-shifts between further multiplications. It is not symmetric in
// its arguments, because the second mixing step folds `h` in again.
inline uint64_t CombineFingerprints(uint64_t l, uint64_t h) {
  constexpr uint64_t kMul = 0x9ddfea08eb382d69ULL;

  uint64_t first_pass = (l ^ h) * kMul;
  first_pass ^= first_pass >> 47;

  uint64_t mixed = (h ^ first_pass) * kMul;
  mixed ^= mixed >> 44;
  mixed *= kMul;
  mixed ^= mixed >> 41;
  return mixed * kMul;
}
47 
GetProgramFingerprint(const std::string & code,const std::string & compiler_options)48 uint64_t GetProgramFingerprint(const std::string& code,
49                                const std::string& compiler_options) {
50   const uint64_t code_fingerprint = ::util::Fingerprint64(code);
51   const uint64_t options_fingerprint =
52       ::util::Fingerprint64(compiler_options);
53   return CombineFingerprints(code_fingerprint, options_fingerprint);
54 }
55 
GetDriverVersion(const CLDevice & device)56 std::string GetDriverVersion(const CLDevice& device) {
57   return device.GetPlatformVersion() + "_jet_version_0";
58 }
59 
60 }  // namespace
61 
ProgramDescriptor(const std::string & code,const std::string & compiler_options)62 ProgramCache::ProgramDescriptor::ProgramDescriptor(
63     const std::string& code, const std::string& compiler_options)
64     : fingerprint(GetProgramFingerprint(code, compiler_options)) {}
65 
ProgramDescriptor(uint64_t fingerprints)66 ProgramCache::ProgramDescriptor::ProgramDescriptor(uint64_t fingerprints)
67     : fingerprint(fingerprints) {}
68 
ProgramCache(ProgramCache && program_cache)69 ProgramCache::ProgramCache(ProgramCache&& program_cache)
70     : programs_(std::move(program_cache.programs_)) {}
71 
operator =(ProgramCache && program_cache)72 ProgramCache& ProgramCache::operator=(ProgramCache&& program_cache) {
73   if (this != &program_cache) {
74     programs_ = std::move(program_cache.programs_);
75   }
76   return *this;
77 }
78 
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const std::vector<CompilerOptions> & compiler_options,const CLContext & context,const CLDevice & device,CLKernel * result,uint64_t * kernel_fingerprint)79 absl::Status ProgramCache::GetOrCreateCLKernel(
80     const std::string& code, const std::string& function_name,
81     const std::vector<CompilerOptions>& compiler_options,
82     const CLContext& context, const CLDevice& device, CLKernel* result,
83     uint64_t* kernel_fingerprint) {
84   const std::string options =
85       CompilerOptionsToString(device.GetInfo(), compiler_options);
86   ProgramDescriptor desc(code, options);
87   if (kernel_fingerprint) {
88     *kernel_fingerprint = desc.fingerprint;
89   }
90   auto it = programs_.find(desc);
91   if (it != programs_.end()) {
92     return result->CreateFromProgram(it->second, function_name);
93   }
94 
95   CLProgram program;
96   RETURN_IF_ERROR(CreateCLProgram(code, options, context, device, &program));
97   RETURN_IF_ERROR(result->CreateFromProgram(program, function_name));
98   programs_.insert(std::make_pair(std::move(desc), std::move(program)));
99   return absl::OkStatus();
100 }
101 
GetOrCreateCLKernel(const std::string & code,const std::string & function_name,const CLContext & context,const CLDevice & device,CLKernel * result,uint64_t * kernel_fingerprint)102 absl::Status ProgramCache::GetOrCreateCLKernel(const std::string& code,
103                                                const std::string& function_name,
104                                                const CLContext& context,
105                                                const CLDevice& device,
106                                                CLKernel* result,
107                                                uint64_t* kernel_fingerprint) {
108   return GetOrCreateCLKernel(code, function_name, {}, context, device, result,
109                              kernel_fingerprint);
110 }
111 
GetKernel(uint64_t fingerprint,const std::string & function_name,CLKernel * result) const112 absl::Status ProgramCache::GetKernel(uint64_t fingerprint,
113                                      const std::string& function_name,
114                                      CLKernel* result) const {
115   ProgramDescriptor desc(fingerprint);
116   auto it = programs_.find(desc);
117   if (it == programs_.end()) {
118     return absl::NotFoundError("No program with this fingerprint.");
119   }
120   return result->CreateFromProgram(it->second, function_name);
121 }
122 
AddProgramBinary(const CLContext & context,const CLDevice & device,uint64_t fingerprint,absl::Span<const uint8_t> binary)123 absl::Status ProgramCache::AddProgramBinary(const CLContext& context,
124                                             const CLDevice& device,
125                                             uint64_t fingerprint,
126                                             absl::Span<const uint8_t> binary) {
127   ProgramDescriptor desc(fingerprint);
128   auto it = programs_.find(desc);
129   if (it == programs_.end()) {
130     CLProgram program;
131     RETURN_IF_ERROR(
132         CreateCLProgramFromBinary(context, device, binary, &program));
133     programs_.insert(std::make_pair(std::move(desc), std::move(program)));
134   }
135   return absl::OkStatus();
136 }
137 
GetProgramBinary(uint64_t fingerprint,std::vector<uint8_t> * program_binary) const138 absl::Status ProgramCache::GetProgramBinary(
139     uint64_t fingerprint, std::vector<uint8_t>* program_binary) const {
140   ProgramDescriptor desc(fingerprint);
141   auto it = programs_.find(desc);
142   if (it == programs_.end()) {
143     return absl::NotFoundError("No program with this fingerprint.");
144   }
145   return it->second.GetBinary(program_binary);
146 }
147 
AddSerializedCache(const CLContext & context,const CLDevice & device,absl::Span<const uint8_t> serialized_cache)148 absl::Status ProgramCache::AddSerializedCache(
149     const CLContext& context, const CLDevice& device,
150     absl::Span<const uint8_t> serialized_cache) {
151   flatbuffers::Verifier verifier(serialized_cache.data(),
152                                  serialized_cache.size());
153   if (!data::VerifyCompiledCacheBuffer(verifier)) {
154     return absl::InvalidArgumentError("Serialized model is corrupted.");
155   }
156 
157   auto model = data::GetCompiledCache(serialized_cache.data());
158   std::string platform_version(model->driver_version()->c_str(),
159                                model->driver_version()->size());
160 
161   if (GetDriverVersion(device) != platform_version) {
162     return absl::InvalidArgumentError(
163         "OpenCL driver changed, cache invalid, should be regenerated");
164   }
165 
166   for (auto serialized_program : *model->programs()) {
167     auto binary_span = absl::MakeSpan(serialized_program->binary()->data(),
168                                       serialized_program->binary()->size());
169     RETURN_IF_ERROR(AddProgramBinary(
170         context, device, serialized_program->fingerprint(), binary_span));
171   }
172   return absl::OkStatus();
173 }
174 
GetSerializedCache(const CLDevice & device,std::vector<uint8_t> * serialized_cache) const175 absl::Status ProgramCache::GetSerializedCache(
176     const CLDevice& device, std::vector<uint8_t>* serialized_cache) const {
177   ::flatbuffers::FlatBufferBuilder builder;
178   std::vector<flatbuffers::Offset<data::Program>> serialized_programs;
179   for (auto& program : programs_) {
180     std::vector<uint8_t> binary;
181     RETURN_IF_ERROR(program.second.GetBinary(&binary));
182     auto binary_offset = builder.CreateVector(binary);
183     data::ProgramBuilder program_builder(builder);
184     program_builder.add_fingerprint(program.first.fingerprint);
185     program_builder.add_binary(binary_offset);
186     serialized_programs.push_back(program_builder.Finish());
187   }
188   auto driver_version = builder.CreateString(GetDriverVersion(device));
189   auto programs_s = builder.CreateVector(serialized_programs);
190   data::CompiledCacheBuilder cache_builder(builder);
191   cache_builder.add_driver_version(driver_version);
192   cache_builder.add_programs(programs_s);
193   data::FinishCompiledCacheBuffer(builder, cache_builder.Finish());
194   size_t next_element = serialized_cache->size();
195   serialized_cache->resize(serialized_cache->size() + builder.GetSize());
196   std::memcpy(&(*serialized_cache)[next_element], builder.GetBufferPointer(),
197               builder.GetSize());
198   return absl::OkStatus();
199 }
200 
201 }  // namespace cl
202 }  // namespace gpu
203 }  // namespace tflite
204