/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/metal/inference_context.h"

#include <algorithm>
#include <cstdint>
#include <cstring>  // for std::memcpy
#include <functional>
#include <map>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "absl/time/clock.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management.h"
#include "tensorflow/lite/delegates/gpu/common/memory_management/types.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/operation_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/special_selector.h"
#include "tensorflow/lite/delegates/gpu/common/selectors/subgraph.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/serialization_base.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

namespace tflite {
namespace gpu {
namespace metal {
namespace {

// Returns true if the actual memory for this storage type is a buffer (2D
// textures are allocated on top of shared buffers in this backend).
bool IsBufferBased(const TensorStorageType& type) {
  return type == TensorStorageType::BUFFER ||
         type == TensorStorageType::IMAGE_BUFFER ||
         type == TensorStorageType::TEXTURE_2D ||
         type == TensorStorageType::SINGLE_TEXTURE_2D;
}

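// Records the first (.x) and last (.y) task index at which tensor `id` is
// used, extending an existing interval if one is already present.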
void AddUsage(ValueId id, int task_index,
              std::map<ValueId, int2>* usage_records) {
  auto it = usage_records->find(id);
  if (it == usage_records->end()) {
    // initializing start index(.x) and end index(.y)
    (*usage_records)[id].x = task_index;
    (*usage_records)[id].y = task_index;
  } else {
    // updating end index(.y)
    (*usage_records)[id].y = task_index;
  }
}

// Calculates the total size of the assignment.
size_t TotalSize(const ObjectsAssignment<size_t>& assignment,
                 size_t alignment = 1) {
  size_t total_size = 0;
  for (auto object_size : assignment.object_sizes) {
    total_size += AlignByN(object_size, alignment);
  }
  return total_size;
}

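// Serializes the Metal kernel source and its preprocessor define map
// (name -> expression) into a data::MetalProgram flatbuffer table.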
flatbuffers::Offset<data::MetalProgram> EncodeProgram(
    const std::string& code, const std::map<std::string, std::string>& defines,
    flatbuffers::FlatBufferBuilder* builder) {
  std::vector<flatbuffers::Offset<flatbuffers::String>> names_fb;
  std::vector<flatbuffers::Offset<flatbuffers::String>> expressions_fb;
  for (auto& define : defines) {
    names_fb.push_back(builder->CreateString(define.first));
    expressions_fb.push_back(builder->CreateString(define.second));
  }
  auto names_fb_vec = builder->CreateVector(names_fb);
  auto expressions_fb_vec = builder->CreateVector(expressions_fb);
  auto code_fb = builder->CreateString(code);
  data::MetalProgramBuilder program_builder(*builder);
  program_builder.add_define_names(names_fb_vec);
  program_builder.add_define_expressions(expressions_fb_vec);
  program_builder.add_code(code_fb);
  return program_builder.Finish();
}

void DecodeProgram(const data::MetalProgram* metal_program, std::string* code,
                   std::map<std::string, std::string>* defines) {
  *code = std::string(metal_program->code()->c_str(),
                      metal_program->code()->size());
  for (int i = 0; i < metal_program->define_names()->size(); ++i) {
    std::string key((*metal_program->define_names())[i]->c_str(),
                    (*metal_program->define_names())[i]->size());
    std::string value((*metal_program->define_expressions())[i]->c_str(),
                      (*metal_program->define_expressions())[i]->size());
    (*defines)[key] = value;
  }
}
}  // namespace

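// Runs the standard GPU graph transforms and then initializes the context
// from the transformed graph.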
absl::Status InferenceContext::InitFromGraphWithTransforms(
    const CreateGpuModelInfo& create_info, GraphFloat32* graph,
    id<MTLDevice> device_id, std::vector<uint8_t>* serialized_model) {
  RETURN_IF_ERROR(RunGraphTransformsForGpuModel(graph));
  RETURN_IF_ERROR(
      InitFromGraph(create_info, *graph, device_id, serialized_model));
  return absl::OkStatus();
}

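// Moves the operations and tensor descriptors out of a GpuModel produced by
// GraphToGpuModel into this context's node list.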
void InferenceContext::CopyFromGpuModel(GpuModel* gpu_model) {
  for (const auto& input : gpu_model->input_ids_and_refs) {
    input_ids_.push_back(input.first);
  }
  for (const auto& output : gpu_model->output_ids_and_refs) {
    output_ids_.push_back(output.first);
  }
  nodes_.resize(gpu_model->nodes.size());
  for (int i = 0; i < gpu_model->nodes.size(); ++i) {
    nodes_[i].task.Init(std::move(gpu_model->nodes[i].gpu_operation));
    nodes_[i].inputs = gpu_model->nodes[i].inputs;
    nodes_[i].outputs = gpu_model->nodes[i].outputs;
    nodes_[i].name = gpu_model->nodes[i].name;
  }
  const_tensors_descs_ = std::move(gpu_model->const_tensors);
  tensors_descs_ = std::move(gpu_model->tensors);
}

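// Builds a ready-to-run context from a GraphFloat32: converts the graph to a
// GpuModel, registers external tensors, compiles the Metal kernels, allocates
// tensor memory, binds tensors to operations and fast-tunes work group sizes.
// If `serialized_model` is non-null, the resulting context is also serialized
// into it so it can later be restored with RestoreDeserialized().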
absl::Status InferenceContext::InitFromGraph(
    const CreateGpuModelInfo& create_info, const GraphFloat32& graph,
    id<MTLDevice> device_id, std::vector<uint8_t>* serialized_model) {
  device_ = device_id;
  MetalDevice metal_device(device_id);
  GpuModel gpu_model;
  RETURN_IF_ERROR(
      GraphToGpuModel(graph, create_info, metal_device.GetInfo(), &gpu_model));
  flatbuffers::FlatBufferBuilder builder;
  flatbuffers::Offset<tflite::gpu::data::GpuModel> gpu_model_fb;
  if (serialized_model) {
    gpu_model_fb = tflite::gpu::Encode(gpu_model, &builder);
  }
  CopyFromGpuModel(&gpu_model);

  for (const auto& external_tensor : create_info.external_immutable_tensors) {
    auto* metal_spatial_tensor =
        dynamic_cast<MetalSpatialTensor*>(external_tensor.second);
    if (!metal_spatial_tensor) {
      return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
    }
    external_immutable_tensors_[external_tensor.first] = metal_spatial_tensor;
  }
  std::map<ValueId, MetalSpatialTensor> temp_external_tensors;
  for (const auto& external_tensor : create_info.external_mutable_tensors) {
    RETURN_IF_ERROR(
        CreateTensor(device_id, tensors_descs_[external_tensor.first],
                     &temp_external_tensors[external_tensor.first]));
    external_mutable_tensors_[external_tensor.first] =
        &temp_external_tensors[external_tensor.first];
  }
  PrepareExternal();
  RETURN_IF_ERROR(CompileOperations(&metal_device));
  RETURN_IF_ERROR(AllocateTensors(&metal_device));
  BindTensorsToOperations();
  RETURN_IF_ERROR(UpdateParams(metal_device.GetInfo()));
  RETURN_IF_ERROR(Tune(TuningType::kFast, &metal_device));

  for (auto& external_tensor : external_mutable_tensors_) {
    external_tensor.second = nullptr;
  }

  if (serialized_model) {
    auto encoded_fb = Encode(&metal_device, gpu_model_fb, &builder);
    data::FinishInferenceContextBuffer(builder, encoded_fb);
    serialized_model->resize(builder.GetSize());
    std::memcpy(serialized_model->data(), builder.GetBufferPointer(),
                builder.GetSize());
  }

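  // Optionally pre-encode all dispatches into an MTLIndirectCommandBuffer.
  // The `false &&` below keeps this path disabled. Note that it is also gated
  // on having no mutable external tensors: the ICB commands are encoded once
  // here, so rebinding tensors later via SetTensor() would not be reflected.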
  bool add_icb_support = false && external_mutable_tensors_.empty();
  if (add_icb_support) {
    if (@available(macOS 11.00, iOS 13.0, tvOS 13.0, *)) {
      MTLIndirectCommandBufferDescriptor* icb_desc =
          [[MTLIndirectCommandBufferDescriptor alloc] init];
      icb_desc.commandTypes = MTLIndirectCommandTypeConcurrentDispatch;
      icb_desc.inheritBuffers = NO;
      icb_desc.inheritPipelineState = NO;
      icb_desc.maxKernelBufferBindCount = 1;

      icb_ = [device_id newIndirectCommandBufferWithDescriptor:icb_desc
                                               maxCommandCount:nodes_.size()
                                                       options:0];

      for (int i = 0; i < nodes_.size(); ++i) {
        id<MTLIndirectComputeCommand> icb_command =
            [icb_ indirectComputeCommandAtIndex:i];
        auto& node = nodes_[i];
        node.task.EncodeToICB(icb_command);
      }
    }
  }
  return absl::OkStatus();
}

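// Restores a context from a buffer previously produced by InitFromGraph()
// with a non-null `serialized_model`: the flatbuffer is verified, nodes and
// tensors are rebuilt, and stored kernel sources plus tuned work group sizes
// are reused, so graph conversion and tuning are skipped.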
absl::Status InferenceContext::RestoreDeserialized(
    const absl::Span<const uint8_t> serialized_model, id<MTLDevice> device_id,
    CreateGpuModelInfo* create_info) {
  flatbuffers::Verifier verifier(serialized_model.data(),
                                 serialized_model.size());
  if (!data::VerifyInferenceContextBuffer(verifier)) {
    return absl::DataLossError("Deserialization failed.");
  }
  auto decoded_fb = data::GetInferenceContext(serialized_model.data());
  device_ = device_id;
  MetalDevice metal_device(device_id);
  RETURN_IF_ERROR(Decode(&metal_device, decoded_fb));

  std::map<ValueId, MetalSpatialTensor> temp_external_tensors;
  if (create_info) {
    for (const auto& external_tensor :
         create_info->external_immutable_tensors) {
      auto* metal_spatial_tensor =
          dynamic_cast<MetalSpatialTensor*>(external_tensor.second);
      if (!metal_spatial_tensor) {
        return absl::InvalidArgumentError("Expected MetalSpatialTensor.");
      }
      external_immutable_tensors_[external_tensor.first] = metal_spatial_tensor;
    }
    for (const auto& external_tensor : create_info->external_mutable_tensors) {
      RETURN_IF_ERROR(
          CreateTensor(device_id, tensors_descs_[external_tensor.first],
                       &temp_external_tensors[external_tensor.first]));
      external_mutable_tensors_[external_tensor.first] =
          &temp_external_tensors[external_tensor.first];
    }
  }
  PrepareExternal();

  RETURN_IF_ERROR(AllocateTensors(&metal_device));
  BindTensorsToOperations();

  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.RestoreDeserialized(&metal_device));
  }
  RETURN_IF_ERROR(UpdateParams(metal_device.GetInfo()));
  for (auto& external_tensor : external_mutable_tensors_) {
    external_tensor.second = nullptr;
  }
  return absl::OkStatus();
}

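// Serializes the context: the GpuModel plus, per node, the tuned work group
// size and the generated Metal program (kernel source and defines).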
flatbuffers::Offset<data::InferenceContext> InferenceContext::Encode(
    MetalDevice* device,
    flatbuffers::Offset<tflite::gpu::data::GpuModel> gpu_model_fb,
    flatbuffers::FlatBufferBuilder* builder) {
  std::vector<flatbuffers::Offset<tflite::gpu::data::Int3>> work_groups_fb;
  for (int i = 0; i < nodes_.size(); ++i) {
    auto work_group_fb =
        tflite::gpu::Encode(nodes_[i].task.GetWorkGroupSize(), builder);
    work_groups_fb.push_back(work_group_fb);
  }
  auto work_groups_fb_vec = builder->CreateVector(work_groups_fb);

  std::vector<flatbuffers::Offset<data::MetalProgram>> programs_fb;
  for (int i = 0; i < nodes_.size(); ++i) {
    auto program_fb = EncodeProgram(nodes_[i].task.GetCode(),
                                    nodes_[i].task.GetDefines(), builder);
    programs_fb.push_back(program_fb);
  }
  auto programs_fb_vec = builder->CreateVector(programs_fb);

  data::InferenceContextBuilder inf_builder(*builder);
  inf_builder.add_gpu_model(gpu_model_fb);
  inf_builder.add_tuned_work_group_sizes_per_node(work_groups_fb_vec);
  inf_builder.add_metal_programs(programs_fb_vec);
  return inf_builder.Finish();
}

absl::Status InferenceContext::Decode(
    MetalDevice* device, const data::InferenceContext* fb_inference) {
  GpuModel gpu_model;
  RETURN_IF_ERROR(tflite::gpu::Decode(fb_inference->gpu_model(), &gpu_model));
  CopyFromGpuModel(&gpu_model);

  for (int i = 0; i < nodes_.size(); ++i) {
    std::string code;
    std::map<std::string, std::string> defines;
    DecodeProgram((*fb_inference->metal_programs())[i], &code, &defines);
    RETURN_IF_ERROR(nodes_[i].task.Init(device, code, defines));

    int3 wg_size;
    wg_size.x = (*fb_inference->tuned_work_group_sizes_per_node())[i]->x();
    wg_size.y = (*fb_inference->tuned_work_group_sizes_per_node())[i]->y();
    wg_size.z = (*fb_inference->tuned_work_group_sizes_per_node())[i]->z();
    nodes_[i].task.SetWorkGroupSize(wg_size);
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::CompileOperations(MetalDevice* device) {
  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.Compile(device));
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::AllocateTensors(MetalDevice* device) {
  RETURN_IF_ERROR(AllocateMemoryForConstTensors(device));
  RETURN_IF_ERROR(AllocateMemoryForBuffers(device));
  RETURN_IF_ERROR(AllocateMemoryForStrongShapes(device));
  return absl::OkStatus();
}

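// Resolves a graph ValueId to its backing tensor, checking external
// (immutable, then mutable), constant, shared-buffer and strong-shape pools
// in that order. Returns nullptr if the id is unknown.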
MetalSpatialTensor* InferenceContext::GetTensor(ValueId tensor_id) {
  if (external_immutable_tensors_.find(tensor_id) !=
      external_immutable_tensors_.end()) {
    return external_immutable_tensors_[tensor_id];
  } else if (external_mutable_tensors_.find(tensor_id) !=
             external_mutable_tensors_.end()) {
    return external_mutable_tensors_[tensor_id];
  } else if (const_tensors_.find(tensor_id) != const_tensors_.end()) {
    return &const_tensors_[tensor_id];
  } else if (graph_ids_to_shared_buffer_tensors_.find(tensor_id) !=
             graph_ids_to_shared_buffer_tensors_.end()) {
    return &shared_buffer_tensors_
        [graph_ids_to_shared_buffer_tensors_[tensor_id]];
  } else if (graph_ids_to_strong_shape_tensors_.find(tensor_id) !=
             graph_ids_to_strong_shape_tensors_.end()) {
    return &strong_shape_tensors_
        [graph_ids_to_strong_shape_tensors_[tensor_id]];
  }
  return nullptr;
}

absl::Status InferenceContext::SetInputTensor(ValueId id,
                                              const TensorFloat32& tensor) {
  MetalSpatialTensor* gpu_tensor = GetTensor(id);
  TensorDescriptor descriptor_with_data = gpu_tensor->GetDescriptor();
  descriptor_with_data.UploadData(tensor);
  return gpu_tensor->UploadDescriptorData(descriptor_with_data, device_);
}

absl::Status InferenceContext::GetOutputTensor(ValueId id,
                                               TensorFloat32* result) {
  const MetalSpatialTensor* gpu_tensor = GetTensor(id);
  const auto dst_shape = BHWC(gpu_tensor->Batch(), gpu_tensor->Height(),
                              gpu_tensor->Width(), gpu_tensor->Channels());
  result->id = id;
  result->shape = dst_shape;
  result->data.resize(dst_shape.DimensionsProduct());

  TensorDescriptor desc;
  RETURN_IF_ERROR(gpu_tensor->ToDescriptor(&desc, device_));
  desc.DownloadData(result);
  return absl::OkStatus();
}

void InferenceContext::BindTensorsToOperations() {
  for (auto& node : nodes_) {
    const auto& src_ids = node.inputs;
    for (int i = 0; i < src_ids.size(); ++i) {
      node.task.SetSrcTensor(GetTensor(src_ids[i]), i);
    }
    const auto& dst_ids = node.outputs;
    for (int i = 0; i < dst_ids.size(); ++i) {
      node.task.SetDstTensor(GetTensor(dst_ids[i]), i);
    }
  }
}

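// Refreshes per-operation uniform parameters after tensors are (re)bound.
// Note: the src/dst shape vectors collected below are currently unused;
// node.task.UpdateParams() takes no arguments.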
absl::Status InferenceContext::UpdateParams(const GpuInfo& gpu_info) {
  for (auto& node : nodes_) {
    std::vector<BHWC> src_shapes;
    std::vector<BHWC> dst_shapes;
    for (const auto& in_id : node.inputs) {
      const auto& shape = tensors_descs_[in_id].GetBHWDCShape();
      src_shapes.push_back(BHWC(shape.b, shape.h, shape.w, shape.c));
    }
    for (const auto& out_id : node.outputs) {
      const auto& shape = tensors_descs_[out_id].GetBHWDCShape();
      dst_shapes.push_back(BHWC(shape.b, shape.h, shape.w, shape.c));
    }
    RETURN_IF_ERROR(node.task.UpdateParams());
  }
  return absl::OkStatus();
}

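// Classifies where a tensor's memory comes from: caller-provided (external),
// constant, the shared buffer pool, or an individually allocated
// "strong shape" tensor.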
InferenceContext::TensorMemoryType InferenceContext::GetTensorMemoryType(
    ValueId id) {
  if (external_immutable_tensors_.find(id) !=
      external_immutable_tensors_.end()) {
    return TensorMemoryType::kExternal;
  } else if (external_mutable_tensors_.find(id) !=
             external_mutable_tensors_.end()) {
    return TensorMemoryType::kExternal;
  } else if (const_tensors_.find(id) != const_tensors_.end()) {
    return TensorMemoryType::kConst;
  } else if (IsBufferBased(tensors_descs_[id].GetStorageType())) {
    return TensorMemoryType::kBuffer;
  } else {
    return TensorMemoryType::kStrongShape;
  }
}

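// Builds [first_task, last_task] usage intervals for every tensor accepted by
// `functor`. Graph inputs count as used by task 0 and graph outputs by a
// virtual task after the last node.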
void InferenceContext::GetUsages(const std::function<bool(ValueId)>& functor,
                                 std::map<ValueId, int2>* usages) {
  for (ValueId in_id : input_ids_) {
    if (functor(in_id)) {
      AddUsage(in_id, 0, usages);
    }
  }
  for (int op_index = 0; op_index < nodes_.size(); ++op_index) {
    for (auto& tensor_id : nodes_[op_index].inputs) {
      if (functor(tensor_id)) {
        AddUsage(tensor_id, op_index, usages);
      }
    }
    for (auto& tensor_id : nodes_[op_index].outputs) {
      if (functor(tensor_id)) {
        AddUsage(tensor_id, op_index, usages);
      }
    }
  }
  for (ValueId out_id : output_ids_) {
    if (functor(out_id)) {
      AddUsage(out_id, nodes_.size(), usages);
    }
  }
}

absl::Status InferenceContext::AllocateMemoryForConstTensors(
    MetalDevice* device) {
  for (auto& description : const_tensors_descs_) {
    RETURN_IF_ERROR(const_tensors_[description.first].CreateFromDescriptor(
        description.second, device->device()));
  }
  const_tensors_descs_.clear();
  return absl::OkStatus();
}

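// Allocates memory for all buffer-backed intermediate tensors. Usage
// intervals are turned into per-tensor byte sizes (with row alignment for
// texture-on-buffer storage), then two packing strategies are compared: a
// single shared buffer with per-tensor offsets (GREEDY_BY_SIZE) versus
// several reusable buffers (GREEDY_BEST). The offset-based layout is used
// when it is no larger and fits within the device's maxBufferLength.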
absl::Status InferenceContext::AllocateMemoryForBuffers(MetalDevice* device) {
  std::map<ValueId, int2> buffer_usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kBuffer;
      },
      &buffer_usages);

  if (buffer_usages.empty()) {
    return absl::OkStatus();
  }

  // From Apple documentation:
  // For buffers in the device address space, align the offset to the data type
  // consumed by the compute function (which is always less than or equal to 16
  // bytes).
  // For buffers in the constant address space, align the offset to 256
  // bytes in macOS. In iOS, align the offset to the maximum of either the data
  // type consumed by the compute function, or 4 bytes. A 16-byte alignment is
  // safe in iOS if you don't need to consider the data type.
#if defined(TARGET_IOS) || defined(TARGET_TVOS)
  const size_t kConstAlignment = 16;
#elif defined(TARGET_MACOS)
  const size_t kConstAlignment = 256;
#else
  const size_t kConstAlignment = 256;
#endif
  size_t min_common_alignment = kConstAlignment;
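  // Worked example (hypothetical numbers, for illustration only): an FP16
  // TEXTURE_2D tensor with shape B=1, H=8, W=10, C=6 on a device whose
  // minimumLinearTextureAlignmentForPixelFormat is 64 bytes gives
  //   bytes_per_row = 2 * 1 * 10 * 4 = 80  -> aligned up to 128
  //   height        = 8 * DivideRoundUp(6, 4) = 16
  //   buffer_size   = 128 * 16 = 2048 bytes.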
  std::vector<TensorUsageRecord<size_t>> buffer_usage_records;
  for (auto& usage : buffer_usages) {
    const auto& t = tensors_descs_[usage.first];
    const auto& shape = t.GetBHWDCShape();
    const auto& descriptor = t;
    const size_t element_size = SizeOf(descriptor.GetDataType());
    size_t buffer_size;
    size_t row_bytes_alignment = [device->device()
        minimumLinearTextureAlignmentForPixelFormat:DataTypeToRGBAPixelFormat(
                                                        descriptor
                                                            .GetDataType(),
                                                        false)];
    if (descriptor.GetStorageType() == TensorStorageType::TEXTURE_2D) {
      min_common_alignment =
          std::lcm(min_common_alignment, row_bytes_alignment);
      const size_t bytes_per_row = element_size * shape.b * shape.w * 4;
      const size_t height = shape.h * DivideRoundUp(shape.c, 4);
      buffer_size = AlignByN(bytes_per_row, row_bytes_alignment) * height;
    } else if (descriptor.GetStorageType() ==
               TensorStorageType::SINGLE_TEXTURE_2D) {
      min_common_alignment =
          std::lcm(min_common_alignment, row_bytes_alignment);
      const size_t bytes_per_row = element_size * shape.b * shape.w * shape.c;
      const size_t height = shape.h;
      buffer_size = AlignByN(bytes_per_row, row_bytes_alignment) * height;
    } else {
      buffer_size =
          shape.b * shape.w * shape.h * AlignByN(shape.c, 4) * element_size;
    }
    graph_ids_to_shared_buffer_tensors_[usage.first] =
        buffer_usage_records.size();
    buffer_usage_records.push_back({buffer_size,
                                    static_cast<TaskId>(usage.second.x),
                                    static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<size_t> buffer_assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      buffer_usage_records, MemoryStrategy::GREEDY_BEST, &buffer_assignment));

  OffsetsAssignment offset_assignment;
  RETURN_IF_ERROR(AssignOffsetsToTensors(
      buffer_usage_records, MemoryStrategy::GREEDY_BY_SIZE, &offset_assignment,
      min_common_alignment));

  bool use_offset_assignment = false;
  if (offset_assignment.total_size <= TotalSize(buffer_assignment) &&
      offset_assignment.total_size <= device->GetInfo().GetMaxBufferSize()) {
    use_offset_assignment = true;
  }

  if (use_offset_assignment) {
    shared_buffers_.resize(1);
    shared_buffers_[0] =
        [device->device() newBufferWithLength:offset_assignment.total_size
                                      options:MTLResourceStorageModeShared];
  } else {
    shared_buffers_.resize(buffer_assignment.object_sizes.size());
    for (int i = 0; i < buffer_assignment.object_sizes.size(); ++i) {
      // Initialize metal buffer
      NSUInteger bufferSize = buffer_assignment.object_sizes[i];

      if (bufferSize > device->GetInfo().GetMaxBufferSize()) {
        std::string error("Tensor id: ");
        error += std::to_string(buffer_assignment.object_ids[i]) +
                 " with size: " + std::to_string(bufferSize) +
                 " exceeds MTLDevice maxBufferLength: " +
                 std::to_string(device->GetInfo().GetMaxBufferSize());
        return absl::ResourceExhaustedError(error);
      }

      shared_buffers_[i] =
          [device->device() newBufferWithLength:bufferSize
                                        options:MTLResourceStorageModeShared];
    }
  }

  std::vector<bool> created_tensors(buffer_usage_records.size(), false);
  shared_buffer_tensors_.resize(buffer_usage_records.size());
  for (auto& node : nodes_) {
    std::vector<ValueId> all_ids = node.inputs;
    all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
    for (auto& tensor_id : all_ids) {
      if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kBuffer) {
        continue;
      }
      const int tensor_index = graph_ids_to_shared_buffer_tensors_[tensor_id];
      if (created_tensors[tensor_index]) continue;
      const auto& tensor_dummy = tensors_descs_[tensor_id];
      const int buffer_index = buffer_assignment.object_ids[tensor_index];
      uint64_t base_buffer_offset = 0;
      id<MTLBuffer> base_buffer;
      if (use_offset_assignment) {
        base_buffer = shared_buffers_[0];
        base_buffer_offset = offset_assignment.offsets[tensor_index];
      } else {
        base_buffer = shared_buffers_[buffer_index];
        base_buffer_offset = 0;
      }
      if (tensor_dummy.GetStorageType() == TensorStorageType::TEXTURE_2D ||
          tensor_dummy.GetStorageType() ==
              TensorStorageType::SINGLE_TEXTURE_2D) {
        size_t row_bytes_alignment = [device->device()
            minimumLinearTextureAlignmentForPixelFormat:
                DataTypeToRGBAPixelFormat(tensor_dummy.GetDataType(), false)];
        RETURN_IF_ERROR(CreateTensorSharedImage2DBuffer(
            base_buffer, tensor_dummy, row_bytes_alignment,
            &shared_buffer_tensors_[tensor_index], base_buffer_offset));
      } else {
        RETURN_IF_ERROR(CreateTensorSharedBuffer(
            base_buffer, tensor_dummy, &shared_buffer_tensors_[tensor_index],
            base_buffer_offset));
      }
      created_tensors[tensor_index] = true;
    }
  }
  return absl::OkStatus();
}

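// Allocates the remaining (non buffer-backed) intermediate tensors. Tensors
// with identical descriptors and shapes whose lifetimes do not overlap are
// assigned to the same MetalSpatialTensor via the EQUALITY strategy.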
absl::Status InferenceContext::AllocateMemoryForStrongShapes(
    MetalDevice* device) {
  std::map<ValueId, int2> usages;
  GetUsages(
      [this](ValueId id) {
        return GetTensorMemoryType(id) == TensorMemoryType::kStrongShape;
      },
      &usages);

  struct TensorDescComparator {
    TensorDescriptor tensor_desc;

    bool operator==(const TensorDescComparator& t) const {
      return tensor_desc == t.tensor_desc &&
             tensor_desc.GetBHWDCShape() == t.tensor_desc.GetBHWDCShape();
    }
  };

  std::vector<TensorUsageRecord<TensorDescComparator>> usage_records;
  std::map<ValueId, ValueId> remap_from_graph_ids;
  for (auto& usage : usages) {
    remap_from_graph_ids[usage.first] = usage_records.size();
    usage_records.push_back({{tensors_descs_[usage.first]},
                             static_cast<TaskId>(usage.second.x),
                             static_cast<TaskId>(usage.second.y)});
  }

  ObjectsAssignment<TensorDescComparator> assignment;
  RETURN_IF_ERROR(AssignObjectsToTensors(
      usage_records, MemoryStrategy::EQUALITY, &assignment));

  for (auto& node : nodes_) {
    std::vector<ValueId> all_ids = node.inputs;
    all_ids.insert(all_ids.end(), node.outputs.begin(), node.outputs.end());
    for (auto& tensor_id : all_ids) {
      const auto& tensor_dummy = tensors_descs_[tensor_id];
      if (GetTensorMemoryType(tensor_id) != TensorMemoryType::kStrongShape) {
        continue;
      }
      const auto id = assignment.object_ids[remap_from_graph_ids[tensor_id]];
      graph_ids_to_strong_shape_tensors_[tensor_id] = id;
      const auto& it = strong_shape_tensors_.find(id);
      if (it == strong_shape_tensors_.end()) {
        RETURN_IF_ERROR(CreateTensor(device->device(), tensor_dummy,
                                     &strong_shape_tensors_[id]));
      }
    }
  }
  return absl::OkStatus();
}

absl::Status InferenceContext::Tune(TuningType tuning_type,
                                    MetalDevice* device) {
  for (auto& node : nodes_) {
    RETURN_IF_ERROR(node.task.Tune(tuning_type, device));
  }
  return absl::OkStatus();
}

void InferenceContext::EncodeWithEncoder(
    id<MTLComputeCommandEncoder> command_encoder) {
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& task = nodes_[i].task;
    task.Encode(command_encoder);
  }
}

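// Makes the resources referenced by the pre-encoded ICB commands resident on
// the given encoder; intended to be called before EncodeWithICB().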
API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
void InferenceContext::AddResources(
    id<MTLComputeCommandEncoder> command_encoder) {
  for (int i = 0; i < nodes_.size(); ++i) {
    auto& task = nodes_[i].task;
    task.AddResourcesToEncoder(command_encoder);
  }
}

API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
void InferenceContext::EncodeWithICB(
    id<MTLComputeCommandEncoder> command_encoder) {
  [command_encoder executeCommandsInBuffer:icb_
                                 withRange:NSMakeRange(0, nodes_.size())];
}

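// Measures per-node GPU time: each task is encoded kRuns times into its own
// command buffer and the wall-clock time of commit/waitUntilCompleted is
// averaged. FLOPs and read/write memory sizes are also recorded per dispatch.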
void InferenceContext::Profile(id<MTLDevice> device, ProfilingInfo* result) {
  result->dispatches.resize(nodes_.size());
  id<MTLCommandQueue> command_queue = [device newCommandQueue];
  for (int k = 0; k < nodes_.size(); ++k) {
    @autoreleasepool {
      id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
      id<MTLComputeCommandEncoder> encoder =
          [command_buffer computeCommandEncoder];
      auto& task = nodes_[k].task;
      const int kRuns = 500;
      for (int i = 0; i < kRuns; ++i) {
        task.Encode(encoder);
      }
      [encoder endEncoding];
      auto start = absl::Now();
      [command_buffer commit];
      [command_buffer waitUntilCompleted];
      auto end = absl::Now();
      auto& dispatch_info = result->dispatches[k];
      dispatch_info.label = nodes_[k].name;
      dispatch_info.duration = (end - start) / static_cast<float>(kRuns);

      uint64_t read_size = 0;
      for (auto& src_id : nodes_[k].inputs) {
        read_size += GetTensor(src_id)->GetMemorySizeInBytes();
      }
      const auto& gpu_op = nodes_[k].task.GetGpuOperation();
      read_size += gpu_op.const_args_size_;
      uint64_t write_size = 0;
      for (auto& dst_id : nodes_[k].outputs) {
        write_size += GetTensor(dst_id)->GetMemorySizeInBytes();
      }
      dispatch_info.flops = gpu_op.flops_;
      dispatch_info.read_mem_size = read_size;
      dispatch_info.write_mem_size = write_size;
    }
  }
}

uint64_t InferenceContext::GetIntermediateTensorsSize() const {
  uint64_t total_memory = 0;
  for (const auto& t : strong_shape_tensors_) {
    total_memory += t.second.GetMemorySizeInBytes();
  }
  for (const auto& b : shared_buffers_) {
    total_memory += [b length];
  }

  return total_memory;
}

uint64_t InferenceContext::GetConstantTensorsSize() const {
  uint64_t total_size = 0;
  for (const auto& node : nodes_) {
    total_size += node.task.GetGpuOperation().const_args_size_;
  }
  for (const auto& t : const_tensors_) {
    total_size += t.second.GetMemorySizeInBytes();
  }
  return total_size;
}

void InferenceContext::EncodeWithCommandBuffer(
    id<MTLCommandBuffer> command_buffer) {
  for (int i = 0; i < nodes_.size(); ++i) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    auto& task = nodes_[i].task;
    task.Encode(encoder);
    [encoder endEncoding];
  }
}

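// Encodes every node on the given queue, committing the current command
// buffer and starting a new one after every `flush_period` nodes so the GPU
// can start working before the whole model is encoded.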
void InferenceContext::EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                                              int flush_period) {
  id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  for (int i = 0; i < nodes_.size(); ++i) {
    id<MTLComputeCommandEncoder> encoder =
        [command_buffer computeCommandEncoder];
    auto& task = nodes_[i].task;
    task.Encode(encoder);
    [encoder endEncoding];
    if (i % flush_period == (flush_period - 1)) {
      [command_buffer commit];
      command_buffer = [command_queue commandBuffer];
    }
  }
  [command_buffer commit];
}

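// Rebinds a mutable external tensor: stores the new pointer and updates the
// src/dst bindings of every node that reads or writes `tensor_id` (the node
// list is built by PrepareExternal()).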
absl::Status InferenceContext::SetTensor(const ValueId& tensor_id,
                                         MetalSpatialTensor* tensor_ptr) {
  auto it = external_mutable_tensors_.find(tensor_id);
  if (it == external_mutable_tensors_.end()) {
    return absl::InvalidArgumentError("No external tensor with this id.");
  }
  external_mutable_tensors_[tensor_id] = tensor_ptr;
  for (int node_index : external_tensor_to_nodes_[tensor_id]) {
    auto& node = nodes_[node_index];
    for (int i = 0; i < node.inputs.size(); ++i) {
      if (node.inputs[i] == tensor_id) {
        node.task.SetSrcTensor(tensor_ptr, i);
      }
    }
    for (int i = 0; i < node.outputs.size(); ++i) {
      if (node.outputs[i] == tensor_id) {
        node.task.SetDstTensor(tensor_ptr, i);
      }
    }
  }
  return absl::OkStatus();
}

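// For every mutable external tensor, records the indices of the nodes that
// consume or produce it, so SetTensor() can later rebind just those nodes.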
void InferenceContext::PrepareExternal() {
  for (auto& external : external_mutable_tensors_) {
    for (int i = 0; i < nodes_.size(); ++i) {
      bool has_tensor = false;
      for (ValueId src_id : nodes_[i].inputs) {
        if (src_id == external.first) {
          has_tensor = true;
        }
      }
      for (ValueId dst_id : nodes_[i].outputs) {
        if (dst_id == external.first) {
          has_tensor = true;
        }
      }
      if (has_tensor) {
        external_tensor_to_nodes_[external.first].push_back(i);
      }
    }
  }
}

}  // namespace metal
}  // namespace gpu
}  // namespace tflite