/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_

#import <Metal/Metal.h>

#include <functional>
#include <list>
#include <map>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_model.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_model_generated.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_hints.h"
#include "tensorflow/lite/delegates/gpu/common/precision.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h"
#include "tensorflow/lite/delegates/gpu/common/task/tuning_type.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task.h"
#include "tensorflow/lite/delegates/gpu/metal/inference_context_generated.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_device.h"
#include "tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h"

namespace tflite {
namespace gpu {
namespace metal {

struct MetalNode {
  ComputeTask task;
  std::vector<ValueId> inputs;
  std::vector<ValueId> outputs;

  // Mostly for debug purposes.
  std::string name;

  MetalNode() = default;

  MetalNode(MetalNode&& node) = default;
  MetalNode& operator=(MetalNode&& node) = default;
  MetalNode(const MetalNode&) = delete;
  MetalNode& operator=(const MetalNode&) = delete;
};

class InferenceContext {
 public:
  InferenceContext() = default;
  // IMPORTANT: If InitFromGraph is used, RunGraphTransforms must be applied
  // to the graph beforehand; otherwise correct behavior is not guaranteed.
  absl::Status InitFromGraph(const CreateGpuModelInfo& create_info,
                             const GraphFloat32& graph, id<MTLDevice> device_id,
                             std::vector<uint8_t>* serialized_model = nullptr);
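
  // Example of the InitFromGraph flow, as a minimal sketch. `graph` and
  // `device` are assumed caller-side names; RETURN_IF_ERROR is the
  // status-propagation macro used throughout the delegate.
  //
  //   CreateGpuModelInfo create_info;
  //   create_info.precision = CalculationsPrecision::F16;
  //   InferenceContext context;
  //   RETURN_IF_ERROR(context.InitFromGraph(create_info, graph, device));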

  // Applies specific transformations to the graph before initialization.
  // These transformations are either impossible or useless in other
  // backends.
  absl::Status InitFromGraphWithTransforms(
      const CreateGpuModelInfo& create_info, GraphFloat32* graph,
      id<MTLDevice> device_id,
      std::vector<uint8_t>* serialized_model = nullptr);
77 
78   absl::Status RestoreDeserialized(
79       const absl::Span<const uint8_t> serialized_model, id<MTLDevice> device_id,
80       CreateGpuModelInfo* create_info = nullptr);
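
  // Example of the serialize/restore round trip, as a minimal sketch
  // (`graph` and `device` are assumed caller-side names):
  //
  //   std::vector<uint8_t> serialized_model;
  //   RETURN_IF_ERROR(context.InitFromGraph(create_info, graph, device,
  //                                         &serialized_model));
  //   // Later, e.g. after persisting the bytes to disk:
  //   InferenceContext restored;
  //   RETURN_IF_ERROR(restored.RestoreDeserialized(
  //       absl::MakeConstSpan(serialized_model), device));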

  /// Inserts all GPU compute tasks into the given command encoder.
  /// @param command_encoder Encoder into which all tasks are recorded.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  ///             resources must be created with the same device that was used
  ///             during initialization (InitFromGraph).
  void EncodeWithEncoder(id<MTLComputeCommandEncoder> command_encoder);
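
  // Example (a minimal sketch; `command_queue` is an assumed caller-provided
  // MTLCommandQueue created from the same device used at init):
  //
  //   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  //   id<MTLComputeCommandEncoder> encoder =
  //       [command_buffer computeCommandEncoder];
  //   context.EncodeWithEncoder(encoder);
  //   [encoder endEncoding];
  //   [command_buffer commit];
  //   [command_buffer waitUntilCompleted];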

  /// Inserts all GPU compute tasks into the given command buffer. A separate
  /// encoder is used for every task.
  /// @param command_buffer Buffer into which all tasks are recorded.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  ///             resources must be created with the same device that was used
  ///             during initialization (InitFromGraph).
  void EncodeWithCommandBuffer(id<MTLCommandBuffer> command_buffer);
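
  // Example (a minimal sketch; encoders are managed internally, so the
  // caller only creates and commits the command buffer):
  //
  //   id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
  //   context.EncodeWithCommandBuffer(command_buffer);
  //   [command_buffer commit];
  //   [command_buffer waitUntilCompleted];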

  /// Adds all GPU compute tasks to the command queue. A separate encoder is
  /// used for every task, and every flush_period encoders are batched into a
  /// command buffer that is committed for execution.
  /// @param command_queue Queue to which the work is submitted.
  /// @param flush_period Number of encoders batched per command buffer.
  /// @discussion No GPU synchronization functions are used inside. All GPU
  ///             resources must be created with the same device that was used
  ///             during initialization (InitFromGraph).
  void EncodeWithCommandQueue(id<MTLCommandQueue> command_queue,
                              int flush_period);
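
  // Example (a minimal sketch; the flush_period value of 8 is an arbitrary
  // assumption, not a tuned recommendation):
  //
  //   context.EncodeWithCommandQueue(command_queue, /*flush_period=*/8);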

  API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
  void AddResources(id<MTLComputeCommandEncoder> command_encoder);
  API_AVAILABLE(ios(13.0), macos(11.00), tvos(13.0))
  void EncodeWithICB(id<MTLComputeCommandEncoder> command_encoder);
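
  // Example of the indirect-command-buffer path (a minimal sketch; requires
  // the OS versions listed above):
  //
  //   if (@available(iOS 13.0, macOS 11.0, tvOS 13.0, *)) {
  //     id<MTLComputeCommandEncoder> encoder =
  //         [command_buffer computeCommandEncoder];
  //     context.AddResources(encoder);
  //     context.EncodeWithICB(encoder);
  //     [encoder endEncoding];
  //   }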

  void Profile(id<MTLDevice> device, ProfilingInfo* result);
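
  // Example (a minimal sketch; GetDetailedReport is provided by
  // ProfilingInfo from common/task/profiling_info.h):
  //
  //   ProfilingInfo profiling_info;
  //   context.Profile(device, &profiling_info);
  //   NSLog(@"%s", profiling_info.GetDetailedReport().c_str());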

  // Returns the size in bytes of all intermediate (runtime) tensors owned by
  // this inference context. Does not include constant tensors.
  uint64_t GetIntermediateTensorsSize() const;
  uint64_t GetConstantTensorsSize() const;

  // Can be used only with ids from external_mutable_tensors in create_info.
  // Must be called after initialization and before execution.
  absl::Status SetTensor(const ValueId& tensor_id,
                         MetalSpatialTensor* tensor_ptr);
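
  // Example (a minimal sketch; `external_id` must be one of the keys of
  // create_info.external_mutable_tensors passed at init time, and
  // `external_tensor` is a caller-owned MetalSpatialTensor):
  //
  //   RETURN_IF_ERROR(context.SetTensor(external_id, &external_tensor));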

  MetalSpatialTensor* GetTensor(ValueId tensor_id);
  absl::Status SetInputTensor(ValueId id, const TensorFloat32& tensor);
  absl::Status GetOutputTensor(ValueId id, TensorFloat32* result);
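
  // Example of a CPU-side round trip (a minimal sketch; the shape and the
  // `input_id`/`output_id` values are assumptions):
  //
  //   TensorFloat32 input;
  //   input.shape = BHWC(1, 224, 224, 3);
  //   input.data.resize(input.shape.DimensionsProduct());
  //   RETURN_IF_ERROR(context.SetInputTensor(input_id, input));
  //   // ... encode and run with one of the Encode* methods, then wait ...
  //   TensorFloat32 output;
  //   RETURN_IF_ERROR(context.GetOutputTensor(output_id, &output));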

 private:
  enum class TensorMemoryType {
    kStrongShape,
    kBuffer,
    kVariable,
    kConst,
    kExternal
  };

  flatbuffers::Offset<data::InferenceContext> Encode(
      MetalDevice* device,
      flatbuffers::Offset<tflite::gpu::data::GpuModel> gpu_model_fb,
      flatbuffers::FlatBufferBuilder* builder);

  absl::Status Decode(MetalDevice* device,
                      const data::InferenceContext* fb_inference);

  void CopyFromGpuModel(GpuModel* gpu_model);
  absl::Status CompileOperations(MetalDevice* device);
  void PrepareExternal();

  absl::Status AllocateTensors(MetalDevice* device);
  absl::Status AllocateMemoryForConstTensors(MetalDevice* device);
  absl::Status AllocateMemoryForBuffers(MetalDevice* device);
  absl::Status AllocateMemoryForStrongShapes(MetalDevice* device);
  void BindTensorsToOperations();
  absl::Status UpdateParams(const GpuInfo& gpu_info);
  void GetUsages(const std::function<bool(ValueId)>& functor,
                 std::map<ValueId, int2>* usages);
  TensorMemoryType GetTensorMemoryType(ValueId id);
  absl::Status Tune(TuningType tuning_type, MetalDevice* device);

  absl::flat_hash_map<ValueId, TensorDescriptor> tensors_descs_;

  std::vector<MetalNode> nodes_;
  std::vector<ValueId> input_ids_;
  std::vector<ValueId> output_ids_;

  absl::flat_hash_map<ValueId, MetalSpatialTensor*> external_immutable_tensors_;
  absl::flat_hash_map<ValueId, MetalSpatialTensor*> external_mutable_tensors_;
  absl::flat_hash_map<ValueId, std::vector<int>> external_tensor_to_nodes_;
  absl::flat_hash_map<ValueId, TensorDescriptor> const_tensors_descs_;
  std::map<ValueId, MetalSpatialTensor> const_tensors_;

  std::map<ValueId, int> graph_ids_to_shared_buffer_tensors_;
  std::vector<id<MTLBuffer>> shared_buffers_;
  std::vector<MetalSpatialTensor>
      shared_buffer_tensors_;  // use references to memory
                               // from shared_buffers_

  std::map<ValueId, MetalSpatialTensor> strong_shape_tensors_;
  std::map<ValueId, ValueId> graph_ids_to_strong_shape_tensors_;

  id<MTLIndirectCommandBuffer> icb_ = nullptr;
  id<MTLDevice> device_ = nullptr;
};

}  // namespace metal
}  // namespace gpu
}  // namespace tflite

#endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_INFERENCE_CONTEXT_H_