xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/metal/metal_spatial_tensor.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
17 #define TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
18 
19 #import <Metal/Metal.h>
20 
21 #include "tensorflow/lite/delegates/gpu/common/status.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/gpu_tensor.h"
23 #include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h"
24 #include "tensorflow/lite/delegates/gpu/common/util.h"
25 #include "tensorflow/lite/delegates/gpu/metal/common.h"
26 #include "tensorflow/lite/delegates/gpu/metal/gpu_object.h"
27 
28 namespace tflite {
29 namespace gpu {
30 namespace metal {
31 
// A spatial tensor held in Metal memory: an MTLBuffer handle and/or an
// MTLTexture handle together with the TensorDescriptor that describes it.
//
// The type is move-only (copy construction and copy assignment are deleted)
// and its destructor calls Release().
class MetalSpatialTensor : public GPUObject, public GpuSpatialTensor {
 public:
  // Creates an empty tensor: no buffer, no texture, both ownership flags set
  // to true.
  MetalSpatialTensor()
      : memory_(nullptr),
        texture_mem_(nullptr),
        memory_owner_(true),
        texture_mem_owner_(true) {}

  // Builds a tensor around existing handles. Defined out of line.
  // NOTE(review): presumably `memory_owner` / `texture_mem_owner` say whether
  // Release() is responsible for the respective handle - confirm against the
  // implementation file.
  MetalSpatialTensor(id<MTLBuffer> buffer, id<MTLTexture> texture,
                     bool memory_owner, bool texture_mem_owner,
                     const TensorDescriptor& descriptor);

  // Move only
  MetalSpatialTensor(MetalSpatialTensor&& tensor);
  MetalSpatialTensor& operator=(MetalSpatialTensor&& tensor);
  MetalSpatialTensor(const MetalSpatialTensor&) = delete;
  MetalSpatialTensor& operator=(const MetalSpatialTensor&) = delete;

  ~MetalSpatialTensor() override { Release(); }

  // GPUObject interface. Defined out of line.
  absl::Status GetGPUResources(const GPUObjectDescriptor* obj_ptr,
                               GPUResourcesWithValue* resources) const override;

  // Shape accessors; each one reads a single field of the descriptor's BHWDC
  // shape.
  int Width() const override { return descriptor_.GetBHWDCShape().w; }
  int Height() const override { return descriptor_.GetBHWDCShape().h; }
  int Depth() const override { return descriptor_.GetBHWDCShape().d; }
  int Channels() const override { return descriptor_.GetBHWDCShape().c; }
  // Channel count passed through DivideRoundUp with a divisor of 4.
  int Slices() const override {
    return DivideRoundUp(descriptor_.GetBHWDCShape().c, 4);
  }
  int Batch() const override { return descriptor_.GetBHWDCShape().b; }

  // Returns a copy of the stored descriptor.
  TensorDescriptor GetDescriptor() const override { return descriptor_; }
  // The three getters below forward to the stored descriptor.
  DataType GetDataType() const { return descriptor_.GetDataType(); }
  TensorStorageType GetStorageType() const {
    return descriptor_.GetStorageType();
  }
  uint64_t GetMemorySizeInBytes() const {
    return descriptor_.GetMemorySizeInBytes();
  }

  // Descriptor <-> device transfer helpers. Defined out of line.
  // NOTE(review): exact allocation / upload / readback semantics are not
  // visible in this header - confirm against the implementation file.
  absl::Status CreateFromDescriptor(const TensorDescriptor& desc,
                                    id<MTLDevice> device);
  absl::Status UploadDescriptorData(const TensorDescriptor& desc,
                                    id<MTLDevice> device);
  absl::Status ToDescriptor(TensorDescriptor* desc, id<MTLDevice> device) const;

  // Accessors for the underlying MTLBuffer handle. Defined out of line.
  absl::Status SetBufferHandle(id<MTLBuffer> buffer);
  id<MTLBuffer> GetBufferHandle() const;

 private:
  // The shared-buffer factory functions need access to private state.
  friend absl::Status CreateTensorSharedBuffer(
      id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
      MetalSpatialTensor* result, uint64_t buffer_offset);

  friend absl::Status CreateTensorSharedImage2DBuffer(
      id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
      int row_bytes_alignment, MetalSpatialTensor* result,
      uint64_t buffer_offset);

  absl::Status WriteData(id<MTLDevice> device, const void* ptr);
  absl::Status ReadData(id<MTLDevice> device, void* ptr) const;

  // Called from the destructor. Defined out of line.
  void Release();

  id<MTLBuffer> memory_;
  id<MTLTexture> texture_mem_;
  bool memory_owner_;
  bool texture_mem_owner_;
  TensorDescriptor descriptor_;
  // for use with TEXTURE_2D and when texture created from buffer.
  // Zero-initialized so that a tensor built by a constructor that does not set
  // it never holds an indeterminate value.
  int aligned_texture_width_ = 0;
  // used when created from shared buffer
  uint64_t buffer_offset_ = 0;
};
106 
// Builds a MetalSpatialTensor for `descriptor` on `device` and hands it back
// through `result`; the returned status reports success or failure.
// NOTE(review): the definition is not in this header - which storage
// (buffer vs. texture) is allocated should be confirmed there.
absl::Status CreateTensor(
    id<MTLDevice> device, const TensorDescriptor& descriptor,
    MetalSpatialTensor* result);
110 
// Builds a MetalSpatialTensor on top of an existing `buffer` and hands it back
// through `result`. `buffer_offset` defaults to 0.
// NOTE(review): presumably the tensor does not take ownership of `buffer` and
// `buffer_offset` is in bytes - confirm against the definition.
absl::Status CreateTensorSharedBuffer(
    id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
    MetalSpatialTensor* result, uint64_t buffer_offset = 0);
115 
// Builds a MetalSpatialTensor on top of an existing `buffer`, taking an
// additional `row_bytes_alignment`, and hands it back through `result`.
// `buffer_offset` defaults to 0.
// NOTE(review): presumably this creates a 2D texture view over `buffer` whose
// rows are padded to `row_bytes_alignment` - confirm against the definition.
absl::Status CreateTensorSharedImage2DBuffer(
    id<MTLBuffer> buffer, const TensorDescriptor& descriptor,
    int row_bytes_alignment, MetalSpatialTensor* result,
    uint64_t buffer_offset = 0);
121 
// Maps a GpuInfo to a TensorStorageType.
// NOTE(review): the name suggests this picks the best-performing storage for
// the given GPU; the selection criteria live in the definition - confirm there.
TensorStorageType GetFastestStorageType(
    const GpuInfo& gpu_info);
123 
124 }  // namespace metal
125 }  // namespace gpu
126 }  // namespace tflite
127 
128 #endif  // TENSORFLOW_LITE_DELEGATES_GPU_METAL_METAL_SPATIAL_TENSOR_H_
129