xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/task/serialization_base.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base.h"
17 
18 #include <cstdint>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "tensorflow/lite/delegates/gpu/common/precision.h"
26 #include "tensorflow/lite/delegates/gpu/common/task/arguments.h"
27 #include "tensorflow/lite/delegates/gpu/common/task/buffer_desc.h"
28 #include "tensorflow/lite/delegates/gpu/common/task/gpu_object_desc.h"
29 #include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
30 #include "tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h"
31 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
32 
33 namespace tflite {
34 namespace gpu {
35 
36 namespace {
ToFB(AccessType type)37 data::AccessType ToFB(AccessType type) {
38   switch (type) {
39     case AccessType::READ:
40       return data::AccessType::READ;
41     case AccessType::WRITE:
42       return data::AccessType::WRITE;
43     case AccessType::READ_WRITE:
44       return data::AccessType::READ_WRITE;
45     default:
46       return data::AccessType::READ_WRITE;
47   }
48 }
49 
ToFB(DataType type)50 data::DataType ToFB(DataType type) {
51   switch (type) {
52     case DataType::BOOL:
53       return data::DataType::BOOL;
54     case DataType::FLOAT16:
55       return data::DataType::FLOAT16;
56     case DataType::FLOAT32:
57       return data::DataType::FLOAT32;
58     case DataType::FLOAT64:
59       return data::DataType::FLOAT64;
60     case DataType::UINT8:
61       return data::DataType::UINT8;
62     case DataType::INT8:
63       return data::DataType::INT8;
64     case DataType::UINT16:
65       return data::DataType::UINT16;
66     case DataType::INT16:
67       return data::DataType::INT16;
68     case DataType::UINT32:
69       return data::DataType::UINT32;
70     case DataType::INT32:
71       return data::DataType::INT32;
72     case DataType::UINT64:
73       return data::DataType::UINT64;
74     case DataType::INT64:
75       return data::DataType::INT64;
76     case DataType::UNKNOWN:
77       return data::DataType::UNKNOWN;
78   }
79 }
80 
ToFB(MemoryType type)81 data::MemoryType ToFB(MemoryType type) {
82   switch (type) {
83     case MemoryType::CONSTANT:
84       return data::MemoryType::CONSTANT;
85     case MemoryType::GLOBAL:
86       return data::MemoryType::GLOBAL;
87     case MemoryType::LOCAL:
88       return data::MemoryType::LOCAL;
89   }
90 }
91 
ToFB(TensorStorageType type)92 data::TensorStorageType ToFB(TensorStorageType type) {
93   switch (type) {
94     case TensorStorageType::BUFFER:
95       return data::TensorStorageType::BUFFER;
96     case TensorStorageType::IMAGE_BUFFER:
97       return data::TensorStorageType::IMAGE_BUFFER;
98     case TensorStorageType::TEXTURE_2D:
99       return data::TensorStorageType::TEXTURE_2D;
100     case TensorStorageType::TEXTURE_ARRAY:
101       return data::TensorStorageType::TEXTURE_ARRAY;
102     case TensorStorageType::TEXTURE_3D:
103       return data::TensorStorageType::TEXTURE_3D;
104     case TensorStorageType::SINGLE_TEXTURE_2D:
105       return data::TensorStorageType::SINGLE_TEXTURE_2D;
106     case TensorStorageType::UNKNOWN:
107       return data::TensorStorageType::UNKNOWN;
108   }
109 }
110 
ToFB(Layout type)111 data::Layout ToFB(Layout type) {
112   switch (type) {
113     case Layout::HWC:
114       return data::Layout::HWC;
115     case Layout::BHWC:
116       return data::Layout::BHWC;
117     case Layout::HWDC:
118       return data::Layout::HWDC;
119     case Layout::BHWDC:
120       return data::Layout::BHWDC;
121     case Layout::LINEAR:
122       return data::Layout::LINEAR;
123     case Layout::HW:
124       return data::Layout::HW;
125     default:
126       return data::Layout::UNKNOWN;
127   }
128 }
129 
ToEnum(data::DataType type)130 DataType ToEnum(data::DataType type) {
131   switch (type) {
132     case data::DataType::BOOL:
133       return DataType::BOOL;
134     case data::DataType::FLOAT16:
135       return DataType::FLOAT16;
136     case data::DataType::FLOAT32:
137       return DataType::FLOAT32;
138     case data::DataType::FLOAT64:
139       return DataType::FLOAT64;
140     case data::DataType::UINT8:
141       return DataType::UINT8;
142     case data::DataType::INT8:
143       return DataType::INT8;
144     case data::DataType::UINT16:
145       return DataType::UINT16;
146     case data::DataType::INT16:
147       return DataType::INT16;
148     case data::DataType::UINT32:
149       return DataType::UINT32;
150     case data::DataType::INT32:
151       return DataType::INT32;
152     case data::DataType::UINT64:
153       return DataType::UINT64;
154     case data::DataType::INT64:
155       return DataType::INT64;
156     case data::DataType::UNKNOWN:
157       return DataType::UNKNOWN;
158   }
159 }
160 
ToEnum(data::AccessType type)161 AccessType ToEnum(data::AccessType type) {
162   switch (type) {
163     case data::AccessType::READ:
164       return AccessType::READ;
165     case data::AccessType::WRITE:
166       return AccessType::WRITE;
167     case data::AccessType::READ_WRITE:
168       return AccessType::READ_WRITE;
169   }
170 }
171 
ToEnum(data::MemoryType type)172 MemoryType ToEnum(data::MemoryType type) {
173   switch (type) {
174     case data::MemoryType::CONSTANT:
175       return MemoryType::CONSTANT;
176     case data::MemoryType::GLOBAL:
177       return MemoryType::GLOBAL;
178     case data::MemoryType::LOCAL:
179       return MemoryType::LOCAL;
180   }
181 }
182 
ToEnum(data::TensorStorageType type)183 TensorStorageType ToEnum(data::TensorStorageType type) {
184   switch (type) {
185     case data::TensorStorageType::BUFFER:
186       return TensorStorageType::BUFFER;
187     case data::TensorStorageType::IMAGE_BUFFER:
188       return TensorStorageType::IMAGE_BUFFER;
189     case data::TensorStorageType::TEXTURE_2D:
190       return TensorStorageType::TEXTURE_2D;
191     case data::TensorStorageType::TEXTURE_ARRAY:
192       return TensorStorageType::TEXTURE_ARRAY;
193     case data::TensorStorageType::TEXTURE_3D:
194       return TensorStorageType::TEXTURE_3D;
195     case data::TensorStorageType::SINGLE_TEXTURE_2D:
196       return TensorStorageType::SINGLE_TEXTURE_2D;
197     case data::TensorStorageType::UNKNOWN:
198       return TensorStorageType::UNKNOWN;
199   }
200 }
201 
ToEnum(data::Layout type)202 Layout ToEnum(data::Layout type) {
203   switch (type) {
204     case data::Layout::HWC:
205       return Layout::HWC;
206     case data::Layout::BHWC:
207       return Layout::BHWC;
208     case data::Layout::HWDC:
209       return Layout::HWDC;
210     case data::Layout::BHWDC:
211       return Layout::BHWDC;
212     case data::Layout::LINEAR:
213       return Layout::LINEAR;
214     case data::Layout::HW:
215       return Layout::HW;
216     default:
217       return Layout::UNKNOWN;
218   }
219 }
220 
ToFB(CalculationsPrecision type)221 data::CalculationsPrecision ToFB(CalculationsPrecision type) {
222   switch (type) {
223     case CalculationsPrecision::F32:
224       return data::CalculationsPrecision::F32;
225     case CalculationsPrecision::F32_F16:
226       return data::CalculationsPrecision::F32_F16;
227     case CalculationsPrecision::F16:
228       return data::CalculationsPrecision::F16;
229   }
230 }
231 
ToFB(TensorToGrid type)232 data::TensorToGrid ToFB(TensorToGrid type) {
233   switch (type) {
234     case TensorToGrid::kCustom:
235       return data::TensorToGrid::CUSTOM;
236     case TensorToGrid::kWBToX_HDToY_SToZ:
237       return data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z;
238     case TensorToGrid::kWBToX_HDToY_ZIs1:
239       return data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1;
240     case TensorToGrid::kWBToX_HToY_DToZ:
241       return data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z;
242     case TensorToGrid::kBToX_YIs1_ZIs1:
243       return data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1;
244   }
245 }
246 
ToFB(CompilerOptions type)247 data::CompilerOptions ToFB(CompilerOptions type) {
248   switch (type) {
249     case CompilerOptions::kAdrenoFullSimd:
250       return data::CompilerOptions::ADRENO_FULL_SIMD_LINE;
251     case CompilerOptions::kAdrenoMoreWaves:
252       return data::CompilerOptions::ADRENO_MORE_WAVES;
253     case CompilerOptions::kClFastRelaxedMath:
254       return data::CompilerOptions::CL_FAST_RELAXED_MATH;
255     case CompilerOptions::kClDisableOptimizations:
256       return data::CompilerOptions::CL_OPT_DISABLE;
257     case CompilerOptions::kCl20:
258       return data::CompilerOptions::CL_2_0;
259     case CompilerOptions::kCl30:
260       return data::CompilerOptions::CL_3_0;
261   }
262 }
263 
ToEnum(data::CalculationsPrecision type)264 CalculationsPrecision ToEnum(data::CalculationsPrecision type) {
265   switch (type) {
266     case data::CalculationsPrecision::F32:
267       return CalculationsPrecision::F32;
268     case data::CalculationsPrecision::F32_F16:
269       return CalculationsPrecision::F32_F16;
270     case data::CalculationsPrecision::F16:
271       return CalculationsPrecision::F16;
272   }
273 }
274 
ToEnum(data::TensorToGrid type)275 TensorToGrid ToEnum(data::TensorToGrid type) {
276   switch (type) {
277     case data::TensorToGrid::CUSTOM:
278       return TensorToGrid::kCustom;
279     case data::TensorToGrid::WB_TO_X_HD_TO_Y_S_TO_Z:
280       return TensorToGrid::kWBToX_HDToY_SToZ;
281     case data::TensorToGrid::WB_TO_X_HD_TO_Y_Z_IS_1:
282       return TensorToGrid::kWBToX_HDToY_ZIs1;
283     case data::TensorToGrid::WB_TO_X_H_TO_Y_D_TO_Z:
284       return TensorToGrid::kWBToX_HToY_DToZ;
285     case data::TensorToGrid::B_TO_X_Y_IS_1_Z_IS_1:
286       return TensorToGrid::kBToX_YIs1_ZIs1;
287   }
288 }
289 
ToEnum(data::CompilerOptions type)290 CompilerOptions ToEnum(data::CompilerOptions type) {
291   switch (type) {
292     case data::CompilerOptions::ADRENO_FULL_SIMD_LINE:
293       return CompilerOptions::kAdrenoFullSimd;
294     case data::CompilerOptions::ADRENO_MORE_WAVES:
295       return CompilerOptions::kAdrenoMoreWaves;
296     case data::CompilerOptions::CL_FAST_RELAXED_MATH:
297       return CompilerOptions::kClFastRelaxedMath;
298     case data::CompilerOptions::CL_OPT_DISABLE:
299       return CompilerOptions::kClDisableOptimizations;
300     case data::CompilerOptions::CL_2_0:
301       return CompilerOptions::kCl20;
302     case data::CompilerOptions::CL_3_0:
303       return CompilerOptions::kCl30;
304   }
305 }
306 
307 }  // namespace
308 
Encode(const int2 & v,flatbuffers::FlatBufferBuilder * builder)309 flatbuffers::Offset<data::Int2> Encode(
310     const int2& v, flatbuffers::FlatBufferBuilder* builder) {
311   data::Int2Builder int2_builder(*builder);
312   int2_builder.add_x(v.x);
313   int2_builder.add_y(v.y);
314   return int2_builder.Finish();
315 }
316 
Encode(const int3 & v,flatbuffers::FlatBufferBuilder * builder)317 flatbuffers::Offset<data::Int3> Encode(
318     const int3& v, flatbuffers::FlatBufferBuilder* builder) {
319   data::Int3Builder int3_builder(*builder);
320   int3_builder.add_x(v.x);
321   int3_builder.add_y(v.y);
322   int3_builder.add_z(v.z);
323   return int3_builder.Finish();
324 }
325 
Encode(const GPUObjectDescriptor & desc,flatbuffers::FlatBufferBuilder * builder)326 flatbuffers::Offset<data::GPUObjectDescriptor> Encode(
327     const GPUObjectDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) {
328   std::vector<flatbuffers::Offset<data::StateVariable>> state_vars_fb;
329   for (auto& v0 : desc.state_vars_) {
330     auto key_fb = builder->CreateString(v0.first);
331     auto value_fb = builder->CreateString(v0.second);
332     data::StateVariableBuilder state_builder(*builder);
333     state_builder.add_key(key_fb);
334     state_builder.add_value(value_fb);
335     state_vars_fb.push_back(state_builder.Finish());
336   }
337   auto state_vars_fb_vec = builder->CreateVector(state_vars_fb);
338   data::GPUObjectDescriptorBuilder obj_builder(*builder);
339   obj_builder.add_state_vars(state_vars_fb_vec);
340   obj_builder.add_access_type(ToFB(desc.access_type_));
341   return obj_builder.Finish();
342 }
343 
Decode(const data::GPUObjectDescriptor * fb_obj,GPUObjectDescriptor * obj)344 void Decode(const data::GPUObjectDescriptor* fb_obj, GPUObjectDescriptor* obj) {
345   obj->access_type_ = ToEnum(fb_obj->access_type());
346   for (auto state_fb : *fb_obj->state_vars()) {
347     std::string key(state_fb->key()->c_str(), state_fb->key()->size());
348     std::string value(state_fb->value()->c_str(), state_fb->value()->size());
349     obj->state_vars_[key] = value;
350   }
351 }
352 
Encode(const BufferDescriptor & desc,flatbuffers::FlatBufferBuilder * builder)353 flatbuffers::Offset<data::BufferDescriptor> Encode(
354     const BufferDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) {
355   auto obj_fb =
356       Encode(*static_cast<const GPUObjectDescriptor*>(&desc), builder);
357 
358   std::vector<flatbuffers::Offset<flatbuffers::String>> attributes_fb;
359   attributes_fb.reserve(desc.attributes.size());
360   for (auto& attr : desc.attributes) {
361     attributes_fb.push_back(builder->CreateString(attr));
362   }
363   auto attributes_fb_vec = builder->CreateVector(attributes_fb);
364   auto data_fb = builder->CreateVector(desc.data);
365   data::BufferDescriptorBuilder buf_builder(*builder);
366   buf_builder.add_base_obj(obj_fb);
367   buf_builder.add_element_type(ToFB(desc.element_type));
368   buf_builder.add_element_size(desc.element_size);
369   buf_builder.add_memory_type(ToFB(desc.memory_type));
370   buf_builder.add_attributes(attributes_fb_vec);
371   buf_builder.add_size(desc.size);
372   buf_builder.add_data(data_fb);
373   return buf_builder.Finish();
374 }
375 
Decode(const data::BufferDescriptor * fb_desc,BufferDescriptor * desc)376 void Decode(const data::BufferDescriptor* fb_desc, BufferDescriptor* desc) {
377   Decode(fb_desc->base_obj(), desc);
378   desc->element_type = ToEnum(fb_desc->element_type());
379   desc->element_size = fb_desc->element_size();
380   desc->memory_type = ToEnum(fb_desc->memory_type());
381   for (auto attr_fb : *fb_desc->attributes()) {
382     std::string attr(attr_fb->c_str(), attr_fb->size());
383     desc->attributes.push_back(attr);
384   }
385   desc->size = fb_desc->size();
386   desc->data =
387       std::vector<uint8_t>(fb_desc->data()->data(),
388                            fb_desc->data()->data() + fb_desc->data()->size());
389 }
390 
Encode(const TensorDescriptor & desc,flatbuffers::FlatBufferBuilder * builder)391 flatbuffers::Offset<data::TensorDescriptor> Encode(
392     const TensorDescriptor& desc, flatbuffers::FlatBufferBuilder* builder) {
393   auto obj_fb =
394       Encode(*static_cast<const GPUObjectDescriptor*>(&desc), builder);
395 
396   data::BHWDCBuilder shape_builder(*builder);
397   shape_builder.add_b(desc.GetBHWDCShape().b);
398   shape_builder.add_h(desc.GetBHWDCShape().h);
399   shape_builder.add_w(desc.GetBHWDCShape().w);
400   shape_builder.add_d(desc.GetBHWDCShape().d);
401   shape_builder.add_c(desc.GetBHWDCShape().c);
402   auto shape_fb = shape_builder.Finish();
403 
404   auto data_fb = builder->CreateVector(desc.GetData());
405   data::TensorDescriptorBuilder tensor_builder(*builder);
406   tensor_builder.add_base_obj(obj_fb);
407   tensor_builder.add_data_type(ToFB(desc.data_type_));
408   tensor_builder.add_storage_type(ToFB(desc.storage_type_));
409   tensor_builder.add_layout(ToFB(desc.layout_));
410   tensor_builder.add_shape(shape_fb);
411   tensor_builder.add_data(data_fb);
412   tensor_builder.add_use_buffer_for_write_only_2d_texture(
413       desc.use_buffer_for_write_only_2d_texture_);
414   tensor_builder.add_use_buffer_for_write_only_image_buffer(
415       desc.use_buffer_for_write_only_image_buffer_);
416   return tensor_builder.Finish();
417 }
418 
Decode(const data::TensorDescriptor * fb_desc,TensorDescriptor * desc)419 void Decode(const data::TensorDescriptor* fb_desc, TensorDescriptor* desc) {
420   Decode(fb_desc->base_obj(), desc);
421   desc->data_type_ = ToEnum(fb_desc->data_type());
422   desc->storage_type_ = ToEnum(fb_desc->storage_type());
423   desc->layout_ = ToEnum(fb_desc->layout());
424   desc->SetBHWDCShape(BHWDC(fb_desc->shape()->b(), fb_desc->shape()->h(),
425                             fb_desc->shape()->w(), fb_desc->shape()->d(),
426                             fb_desc->shape()->c()));
427   desc->SetData(
428       std::vector<uint8_t>(fb_desc->data()->data(),
429                            fb_desc->data()->data() + fb_desc->data()->size()));
430   desc->use_buffer_for_write_only_2d_texture_ =
431       fb_desc->use_buffer_for_write_only_2d_texture();
432   desc->use_buffer_for_write_only_image_buffer_ =
433       fb_desc->use_buffer_for_write_only_image_buffer();
434 }
435 
Decode(const data::Arguments * fb_args,Arguments * args)436 absl::Status Decode(const data::Arguments* fb_args, Arguments* args) {
437   args->int_values_.clear();
438   for (auto int_values_fb : *fb_args->int_values()) {
439     Arguments::IntValue value;
440     value.value = int_values_fb->value();
441     value.active = int_values_fb->active();
442     std::string name(int_values_fb->name()->c_str(),
443                      int_values_fb->name()->size());
444     args->int_values_[name] = value;
445   }
446 
447   args->float_values_.clear();
448   for (auto float_values_fb : *fb_args->float_values()) {
449     Arguments::FloatValue value;
450     value.value = float_values_fb->value();
451     value.active = float_values_fb->active();
452     std::string name(float_values_fb->name()->c_str(),
453                      float_values_fb->name()->size());
454     args->float_values_[name] = value;
455   }
456 
457   args->half_values_.clear();
458   for (auto half_values_fb : *fb_args->half_values()) {
459     Arguments::HalfValue value;
460     value.value = half_values_fb->value();
461     value.active = half_values_fb->active();
462     std::string name(half_values_fb->name()->c_str(),
463                      half_values_fb->name()->size());
464     args->half_values_[name] = value;
465   }
466 
467   for (auto buffer_pair_fb : *fb_args->buffer_objects()) {
468     std::string key(buffer_pair_fb->key()->c_str(),
469                     buffer_pair_fb->key()->size());
470     BufferDescriptor desc;
471     Decode(buffer_pair_fb->value(), &desc);
472     args->AddObject(key, std::make_unique<BufferDescriptor>(std::move(desc)));
473   }
474 
475   for (auto tensor_pair_fb : *fb_args->tensor_objects()) {
476     std::string key(tensor_pair_fb->key()->c_str(),
477                     tensor_pair_fb->key()->size());
478     TensorDescriptor desc;
479     Decode(tensor_pair_fb->value(), &desc);
480     args->AddObject(key, std::make_unique<TensorDescriptor>(std::move(desc)));
481   }
482 
483   for (auto buffer_pair_fb : *fb_args->buffer_refs()) {
484     std::string key(buffer_pair_fb->key()->c_str(),
485                     buffer_pair_fb->key()->size());
486     BufferDescriptor desc;
487     Decode(buffer_pair_fb->value(), &desc);
488     auto access_type = desc.GetAccess();
489     args->AddObjectRef(key, access_type,
490                        std::make_unique<BufferDescriptor>(std::move(desc)));
491   }
492 
493   for (auto tensor_pair_fb : *fb_args->tensor_refs()) {
494     std::string key(tensor_pair_fb->key()->c_str(),
495                     tensor_pair_fb->key()->size());
496     TensorDescriptor desc;
497     Decode(tensor_pair_fb->value(), &desc);
498     auto access_type = desc.GetAccess();
499     args->AddObjectRef(key, access_type,
500                        std::make_unique<TensorDescriptor>(std::move(desc)));
501   }
502   return absl::OkStatus();
503 }
504 
Encode(const Arguments & args,flatbuffers::FlatBufferBuilder * builder)505 flatbuffers::Offset<data::Arguments> Encode(
506     const Arguments& args, flatbuffers::FlatBufferBuilder* builder) {
507   std::vector<flatbuffers::Offset<data::IntValue>> int_values_fb;
508   for (auto& value : args.int_values_) {
509     auto name_fb = builder->CreateString(value.first);
510     data::IntValueBuilder value_builder(*builder);
511     value_builder.add_name(name_fb);
512     value_builder.add_value(value.second.value);
513     value_builder.add_active(value.second.active);
514     int_values_fb.push_back(value_builder.Finish());
515   }
516 
517   std::vector<flatbuffers::Offset<data::FloatValue>> float_values_fb;
518   for (auto& value : args.float_values_) {
519     auto name_fb = builder->CreateString(value.first);
520     data::FloatValueBuilder value_builder(*builder);
521     value_builder.add_name(name_fb);
522     value_builder.add_value(value.second.value);
523     value_builder.add_active(value.second.active);
524     float_values_fb.push_back(value_builder.Finish());
525   }
526 
527   std::vector<flatbuffers::Offset<data::HalfValue>> half_values_fb;
528   for (auto& value : args.half_values_) {
529     auto name_fb = builder->CreateString(value.first);
530     data::HalfValueBuilder value_builder(*builder);
531     value_builder.add_name(name_fb);
532     value_builder.add_value(value.second.value);
533     value_builder.add_active(value.second.active);
534     half_values_fb.push_back(value_builder.Finish());
535   }
536 
537   std::vector<flatbuffers::Offset<data::BufferDescriptorMapValue>>
538       buffer_objs_fb;
539   for (auto& value : args.objects_) {
540     const auto* buffer_desc =
541         dynamic_cast<const BufferDescriptor*>(value.second.get());
542     if (!buffer_desc) continue;
543     auto desc_fb = Encode(*buffer_desc, builder);
544     auto key_fb = builder->CreateString(value.first);
545     data::BufferDescriptorMapValueBuilder buf_map_builder(*builder);
546     buf_map_builder.add_key(key_fb);
547     buf_map_builder.add_value(desc_fb);
548     buffer_objs_fb.push_back(buf_map_builder.Finish());
549   }
550   std::vector<flatbuffers::Offset<data::TensorDescriptorMapValue>>
551       tensor_objs_fb;
552   for (auto& value : args.objects_) {
553     const auto* tensor_desc =
554         dynamic_cast<const TensorDescriptor*>(value.second.get());
555     if (!tensor_desc) continue;
556     auto desc_fb = Encode(*tensor_desc, builder);
557     auto key_fb = builder->CreateString(value.first);
558     data::TensorDescriptorMapValueBuilder ten_map_builder(*builder);
559     ten_map_builder.add_key(key_fb);
560     ten_map_builder.add_value(desc_fb);
561     tensor_objs_fb.push_back(ten_map_builder.Finish());
562   }
563 
564   std::vector<flatbuffers::Offset<data::BufferDescriptorMapValue>>
565       buffer_refs_fb;
566   for (auto& value : args.object_refs_) {
567     const auto* buffer_desc =
568         dynamic_cast<const BufferDescriptor*>(value.second.get());
569     if (!buffer_desc) continue;
570     auto desc_fb = Encode(*buffer_desc, builder);
571     auto key_fb = builder->CreateString(value.first);
572     data::BufferDescriptorMapValueBuilder buf_map_builder(*builder);
573     buf_map_builder.add_key(key_fb);
574     buf_map_builder.add_value(desc_fb);
575     buffer_refs_fb.push_back(buf_map_builder.Finish());
576   }
577   std::vector<flatbuffers::Offset<data::TensorDescriptorMapValue>>
578       tensor_refs_fb;
579   for (auto& value : args.object_refs_) {
580     const auto* tensor_desc =
581         dynamic_cast<const TensorDescriptor*>(value.second.get());
582     if (!tensor_desc) continue;
583     auto desc_fb = Encode(*tensor_desc, builder);
584     auto key_fb = builder->CreateString(value.first);
585     data::TensorDescriptorMapValueBuilder ten_map_builder(*builder);
586     ten_map_builder.add_key(key_fb);
587     ten_map_builder.add_value(desc_fb);
588     tensor_refs_fb.push_back(ten_map_builder.Finish());
589   }
590 
591   auto int_values_fb_vec = builder->CreateVector(int_values_fb);
592   auto float_values_fb_vec = builder->CreateVector(float_values_fb);
593   auto half_values_fb_vec = builder->CreateVector(half_values_fb);
594   auto buffer_objs_fb_vec = builder->CreateVector(buffer_objs_fb);
595   auto tensor_objs_fb_vec = builder->CreateVector(tensor_objs_fb);
596   auto buffer_refs_fb_vec = builder->CreateVector(buffer_refs_fb);
597   auto tensor_refs_fb_vec = builder->CreateVector(tensor_refs_fb);
598   data::ArgumentsBuilder arguments_builder(*builder);
599   arguments_builder.add_int_values(int_values_fb_vec);
600   arguments_builder.add_float_values(float_values_fb_vec);
601   arguments_builder.add_half_values(half_values_fb_vec);
602   arguments_builder.add_buffer_objects(buffer_objs_fb_vec);
603   arguments_builder.add_tensor_objects(tensor_objs_fb_vec);
604   arguments_builder.add_buffer_refs(buffer_refs_fb_vec);
605   arguments_builder.add_tensor_refs(tensor_refs_fb_vec);
606   return arguments_builder.Finish();
607 }
608 
Encode(const OperationDef & def,flatbuffers::FlatBufferBuilder * builder)609 flatbuffers::Offset<data::OperationDef> Encode(
610     const OperationDef& def, flatbuffers::FlatBufferBuilder* builder) {
611   std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
612       src_tensors_fb;
613   for (auto& desc : def.src_tensors) {
614     auto desc_fb = Encode(desc, builder);
615     src_tensors_fb.push_back(desc_fb);
616   }
617 
618   std::vector<flatbuffers::Offset<tflite::gpu::data::TensorDescriptor>>
619       dst_tensors_fb;
620   for (auto& desc : def.dst_tensors) {
621     auto desc_fb = Encode(desc, builder);
622     dst_tensors_fb.push_back(desc_fb);
623   }
624 
625   auto src_tensors_fb_vec = builder->CreateVector(src_tensors_fb);
626   auto dst_tensors_fb_vec = builder->CreateVector(dst_tensors_fb);
627 
628   data::OperationDefBuilder def_builder(*builder);
629   def_builder.add_precision(ToFB(def.precision));
630   def_builder.add_src_tensors(src_tensors_fb_vec);
631   def_builder.add_dst_tensors(dst_tensors_fb_vec);
632   return def_builder.Finish();
633 }
634 
Decode(const data::OperationDef * fb_def,OperationDef * def)635 void Decode(const data::OperationDef* fb_def, OperationDef* def) {
636   for (auto src_fb : *fb_def->src_tensors()) {
637     TensorDescriptor desc;
638     Decode(src_fb, &desc);
639     def->src_tensors.push_back(std::move(desc));
640   }
641   for (auto dst_fb : *fb_def->dst_tensors()) {
642     TensorDescriptor desc;
643     Decode(dst_fb, &desc);
644     def->dst_tensors.push_back(std::move(desc));
645   }
646   def->precision = ToEnum(fb_def->precision());
647 }
648 
Decode(const data::GPUOperation * fb_op,GPUOperation * op)649 absl::Status Decode(const data::GPUOperation* fb_op, GPUOperation* op) {
650   RETURN_IF_ERROR(Decode(fb_op->arguments(), &op->args_));
651   op->work_group_size_.x = fb_op->work_group_size()->x();
652   op->work_group_size_.y = fb_op->work_group_size()->y();
653   op->work_group_size_.z = fb_op->work_group_size()->z();
654   op->tensor_to_grid_ = ToEnum(fb_op->tensor_to_grid());
655   op->flops_ = fb_op->flops();
656   Decode(fb_op->definition(), &op->definition_);
657   op->grid_dimension_ = fb_op->grid_dimension();
658   op->work_group_launch_order_.x = fb_op->work_group_launch_order()->x();
659   op->work_group_launch_order_.y = fb_op->work_group_launch_order()->y();
660   op->work_group_launch_order_.z = fb_op->work_group_launch_order()->z();
661   op->grid_size_.x = fb_op->grid_size()->x();
662   op->grid_size_.y = fb_op->grid_size()->y();
663   op->grid_size_.z = fb_op->grid_size()->z();
664   for (auto name_fb : *fb_op->src_tensors_names()) {
665     std::string name(name_fb->c_str(), name_fb->size());
666     op->src_tensors_names_.push_back(std::move(name));
667   }
668   for (auto name_fb : *fb_op->dst_tensors_names()) {
669     std::string name(name_fb->c_str(), name_fb->size());
670     op->dst_tensors_names_.push_back(std::move(name));
671   }
672   op->work_groups_count_.x = fb_op->work_groups_count()->x();
673   op->work_groups_count_.y = fb_op->work_groups_count()->y();
674   op->work_groups_count_.z = fb_op->work_groups_count()->z();
675   op->CalculateConstArgsSize();
676   return absl::OkStatus();
677 }
678 
Encode(const GPUOperation & op,flatbuffers::FlatBufferBuilder * builder)679 flatbuffers::Offset<data::GPUOperation> Encode(
680     const GPUOperation& op, flatbuffers::FlatBufferBuilder* builder) {
681   auto args_fb = Encode(op.args_, builder);
682   auto work_group_size_fb = Encode(op.work_group_size_, builder);
683 
684   auto def_fb = Encode(op.definition_, builder);
685   auto work_group_launch_order_fb =
686       Encode(op.work_group_launch_order_, builder);
687   auto grid_size_fb = Encode(op.grid_size_, builder);
688   auto work_groups_count_fb = Encode(op.work_groups_count_, builder);
689 
690   std::vector<flatbuffers::Offset<flatbuffers::String>> src_names_fb;
691   src_names_fb.reserve(op.src_tensors_names_.size());
692   for (auto& name : op.src_tensors_names_) {
693     src_names_fb.push_back(builder->CreateString(name));
694   }
695   auto src_names_fb_vec = builder->CreateVector(src_names_fb);
696 
697   std::vector<flatbuffers::Offset<flatbuffers::String>> dst_names_fb;
698   dst_names_fb.reserve(op.dst_tensors_names_.size());
699   for (auto& name : op.dst_tensors_names_) {
700     dst_names_fb.push_back(builder->CreateString(name));
701   }
702   auto dst_names_fb_vec = builder->CreateVector(dst_names_fb);
703 
704   data::GPUOperationBuilder op_builder(*builder);
705   op_builder.add_arguments(args_fb);
706   op_builder.add_work_group_size(work_group_size_fb);
707   op_builder.add_tensor_to_grid(ToFB(op.tensor_to_grid_));
708   op_builder.add_flops(op.flops_);
709   op_builder.add_definition(def_fb);
710   op_builder.add_grid_dimension(op.grid_dimension_);
711   op_builder.add_work_group_launch_order(work_group_launch_order_fb);
712   op_builder.add_grid_size(grid_size_fb);
713   op_builder.add_src_tensors_names(src_names_fb_vec);
714   op_builder.add_dst_tensors_names(dst_names_fb_vec);
715   op_builder.add_work_groups_count(work_groups_count_fb);
716   return op_builder.Finish();
717 }
718 
719 }  // namespace gpu
720 }  // namespace tflite
721