xref: /aosp_15_r20/external/executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 
2 //
3 //  Copyright (c) 2023 Apple Inc. All rights reserved.
4 //  Provided subject to the LICENSE file in the top level directory.
5 //
6 
7 #pragma once
8 
9 #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
10 
// Fallback type declarations for SDKs that predate macOS 13 (Ventura).
//
// This section is skipped when `__MAC_13_0` is defined -- presumably by a macOS 13+
// SDK that already provides these names, so they are not defined twice.
// NOTE(review): macOS 11+ SDKs appear to spell their version macros
// `MAC_OS_VERSION_*` (no `_X_`). If `MAC_OS_X_VERSION_13_0` is never defined, the
// parenthesised clause below is always true and `__MAC_13_0` alone decides -- confirm.
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))

/// Rounding policy taken by the nearest-neighbour resize, nearest-neighbour
/// resize-gradient, and nearest-rounding grid-sampling selectors of the
/// `VenturaOps` category declared below.
/// The raw values are spelled out explicitly; presumably they have to match the
/// SDK enum of the same name -- TODO confirm before changing any of them.
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
  MPSGraphResizeNearestRoundingModeRoundPreferCeil  = 0L,
  MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
  MPSGraphResizeNearestRoundingModeCeil             = 2L,
  MPSGraphResizeNearestRoundingModeFloor            = 3L,
  MPSGraphResizeNearestRoundingModeRoundToEven      = 4L,
  MPSGraphResizeNearestRoundingModeRoundToOdd       = 5L,
};

// Complex MPSDataType values for macOS 12 builds. Each one ORs together the float
// bit, the complex bit and a width tag (32 for Float16, 64 for Float32); the tag
// looks like the combined real+imaginary bit width -- confirm against the SDK.
#define MPSDataTypeComplexBit 0x01000000
#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))

#endif  // pre-macOS 13 fallback types

/// Extra MPSGraph selector declarations for this backend.
/// Only declarations are visible in this header; presumably the implementations
/// are supplied by the MetalPerformanceShadersGraph framework at runtime -- TODO confirm.
@interface MPSGraph (VenturaOps)

#pragma mark - Scan

- (MPSGraphTensor *_Nonnull)cumulativeSumWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                               axis:(NSInteger)axis
                                               name:(NSString *_Nullable)name;

#pragma mark - Sorting

// `sortWithTensor:` and `argSortWithTensor:` each come in four overloads: the axis
// is passed either as an NSInteger (`axis:`) or as a tensor (`axisTensor:`), each
// with and without an explicit `descending:` flag.

- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                      axis:(NSInteger)axis
                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                      axis:(NSInteger)axis
                                descending:(BOOL)descending
                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
                                descending:(BOOL)descending
                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                         axis:(NSInteger)axis
                                         name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                         axis:(NSInteger)axis
                                   descending:(BOOL)descending
                                         name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                   axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
                                         name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                   axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
                                   descending:(BOOL)descending
                                         name:(NSString *_Nullable)name;

#pragma mark - Linear algebra

- (MPSGraphTensor *_Nonnull)inverseOfTensor:(MPSGraphTensor *_Nonnull)inputTensor
                                       name:(NSString *_Nullable)name;

#pragma mark - Resize (forward)

// Each resize mode has two overloads: one taking `centerResult:` / `alignCorners:`
// flags and one taking a `scaleOffsetTensor:` instead.

- (MPSGraphTensor *_Nonnull)resizeNearestWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
                                         sizeTensor:(MPSGraphTensor *_Nonnull)size
                                nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
                                       centerResult:(BOOL)centerResult
                                       alignCorners:(BOOL)alignCorners
                                             layout:(MPSGraphTensorNamedDataLayout)layout
                                               name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeNearestWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
                                         sizeTensor:(MPSGraphTensor *_Nonnull)size
                                  scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
                                nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
                                             layout:(MPSGraphTensorNamedDataLayout)layout
                                               name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeBilinearWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
                                          sizeTensor:(MPSGraphTensor *_Nonnull)size
                                        centerResult:(BOOL)centerResult
                                        alignCorners:(BOOL)alignCorners
                                              layout:(MPSGraphTensorNamedDataLayout)layout
                                                name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeBilinearWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
                                          sizeTensor:(MPSGraphTensor *_Nonnull)size
                                   scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
                                              layout:(MPSGraphTensorNamedDataLayout)layout
                                                name:(NSString *_Nullable)name;

#pragma mark - Resize (gradient)

// Gradient counterparts of the forward resize overloads; each also takes the
// forward `input:` tensor.

- (MPSGraphTensor *_Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
                                                      input:(MPSGraphTensor *_Nonnull)input
                                        nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
                                               centerResult:(BOOL)centerResult
                                               alignCorners:(BOOL)alignCorners
                                                     layout:(MPSGraphTensorNamedDataLayout)layout
                                                       name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
                                                      input:(MPSGraphTensor *_Nonnull)input
                                          scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
                                        nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
                                                     layout:(MPSGraphTensorNamedDataLayout)layout
                                                       name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
                                                       input:(MPSGraphTensor *_Nonnull)input
                                                centerResult:(BOOL)centerResult
                                                alignCorners:(BOOL)alignCorners
                                                      layout:(MPSGraphTensorNamedDataLayout)layout
                                                        name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
                                                       input:(MPSGraphTensor *_Nonnull)input
                                           scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
                                                      layout:(MPSGraphTensorNamedDataLayout)layout
                                                        name:(NSString *_Nullable)name;

#pragma mark - Grid sampling

// Two overloads that differ only in the second-to-last-but-one keyword:
// `samplingMode:` (an MPSGraphResizeMode) vs. `nearestRoundingMode:`.

- (MPSGraphTensor *_Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor *_Nonnull)source
                                      coordinateTensor:(MPSGraphTensor *_Nonnull)coordinates
                                                layout:(MPSGraphTensorNamedDataLayout)layout
                                  normalizeCoordinates:(BOOL)normalizeCoordinates
                                   relativeCoordinates:(BOOL)relativeCoordinates
                                          alignCorners:(BOOL)alignCorners
                                           paddingMode:(MPSGraphPaddingMode)paddingMode
                                          samplingMode:(MPSGraphResizeMode)samplingMode
                                         constantValue:(double)constantValue
                                                  name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor *_Nonnull)source
                                      coordinateTensor:(MPSGraphTensor *_Nonnull)coordinates
                                                layout:(MPSGraphTensorNamedDataLayout)layout
                                  normalizeCoordinates:(BOOL)normalizeCoordinates
                                   relativeCoordinates:(BOOL)relativeCoordinates
                                          alignCorners:(BOOL)alignCorners
                                           paddingMode:(MPSGraphPaddingMode)paddingMode
                                   nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
                                         constantValue:(double)constantValue
                                                  name:(NSString *_Nullable)name;

#pragma mark - Elementwise and layout

- (MPSGraphTensor *_Nonnull)truncateWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                          name:(NSString *_Nullable)name;

/// The permutation is passed as an NSArray of NSNumber.
- (MPSGraphTensor *_Nonnull)transposeTensor:(MPSGraphTensor *_Nonnull)tensor
                                permutation:(NSArray<NSNumber *> *_Nonnull)permutation
                                       name:(NSString *_Nullable)name;

#pragma mark - Bitwise

- (MPSGraphTensor *_Nonnull)bitwiseANDWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
                                        secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
                                                   name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)bitwiseORWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
                                       secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
                                                  name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)bitwiseXORWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
                                        secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
                                                   name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nonnull)bitwiseNOTWithTensor:(MPSGraphTensor *_Nonnull)tensor
                                            name:(NSString *_Nullable)name;

#pragma mark - Shape, pooling and split (pre-macOS 12.2 fallback)

// Compiled when `MAC_OS_X_VERSION_12_2` is undefined or the deployment target is
// below it. Unlike the rest of this category, these declarations mark their
// parameters and return values `_Nullable`.
// NOTE(review): if the SDK spells this macro `MAC_OS_VERSION_12_2`, then
// `MAC_OS_X_VERSION_12_2` is never defined and this section is always compiled -- confirm.
#if !defined(MAC_OS_X_VERSION_12_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_12_2)

- (MPSGraphTensor *_Nullable)expandDimsOfTensor:(MPSGraphTensor *_Nullable)tensor
                                           axis:(NSInteger)axis
                                           name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nullable)expandDimsOfTensor:(MPSGraphTensor *_Nullable)tensor
                                           axes:(NSArray<NSNumber *> *_Nullable)axes
                                           name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nullable)squeezeTensor:(MPSGraphTensor *_Nullable)tensor
                                      axis:(NSInteger)axis
                                      name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nullable)squeezeTensor:(MPSGraphTensor *_Nullable)tensor
                                      axes:(NSArray<NSNumber *> *_Nullable)axes
                                      name:(NSString *_Nullable)name;

- (NSArray<MPSGraphTensor *> *_Nullable)
    maxPooling2DReturnIndicesWithSourceTensor:(MPSGraphTensor *_Nullable)source
                                   descriptor:(MPSGraphPooling2DOpDescriptor *_Nullable)descriptor
                                         name:(NSString *_Nullable)name;

- (MPSGraphTensor *_Nullable)coordinateAlongAxis:(NSInteger)axis
                                 withShapeTensor:(MPSGraphTensor *_Nullable)shapeTensor
                                            name:(NSString *_Nullable)name;

// `splitTensor:` takes the split sizes either as a tensor (`splitSizesTensor:`)
// or as an NSArray of NSNumber (`splitSizes:`).
- (NSArray<MPSGraphTensor *> *_Nullable)splitTensor:(MPSGraphTensor *_Nullable)tensor
                                   splitSizesTensor:(MPSGraphTensor *_Nullable)splitSizesTensor
                                               axis:(NSInteger)axis
                                               name:(NSString *_Nullable)name;

- (NSArray<MPSGraphTensor *> *_Nullable)splitTensor:(MPSGraphTensor *_Nullable)tensor
                                         splitSizes:(NSArray<NSNumber *> *_Nullable)splitSizes
                                               axis:(NSInteger)axis
                                               name:(NSString *_Nullable)name;

#endif  // pre-macOS 12.2 fallback

@end
208