xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <ATen/native/metal/MetalContext.h>
2#import <ATen/native/metal/MetalTensorUtils.h>
3#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
4#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
5#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
6
@implementation MPSCNNClampOp {
  // Texture bound at kernel texture index 0 (presumably the clamp input).
  // Its featureChannels / numberOfImages are baked into the pipeline.
  MPSImage* _X;
  // Texture bound at kernel texture index 1 (presumably the clamp output).
  // The dispatch grid is derived from this image.
  MPSImage* _Y;
  // Clamp bounds; only their floatValue is consumed (nil reads as 0.0).
  NSNumber* _min;
  NSNumber* _max;
}

// Creates a clamp op from its textures and scalar arguments.
//   textures[0] -> _X, textures[1] -> _Y
//   args[0]     -> _min, args[1]   -> _max
// Both arrays are indexed directly; an array with fewer than two elements
// makes NSArray subscripting raise.
+ (id<MPSCNNShaderOp>)newWithTextures:(NSArray<MPSImage*>*)textures
                                 Args:(NSArray<NSNumber*>*)args {
  MPSCNNClampOp* clampOp = [[MPSCNNClampOp alloc] init];
  clampOp->_X = textures[0];
  clampOp->_Y = textures[1];
  clampOp->_min = args[0];
  clampOp->_max = args[1];
  return clampOp;
}

// Encodes a single dispatch of the "clamp" compute kernel into `cb`.
// The pipeline is fetched from the shared MetalContext, specialized on
// (min, max, X.featureChannels, X.numberOfImages) in that order; X and Y are
// bound to texture slots 0 and 1, and the threadgroup layout comes from Y.
- (void)encode:(id<MTLCommandBuffer>)cb {
  id<MTLComputeCommandEncoder> computeEncoder = [cb computeCommandEncoder];

  // Order matters: the kernel's specialization constants are positional.
  NSArray<NSNumber*>* specializationConstants = @[
    @(_min.floatValue),
    @(_max.floatValue),
    @(_X.featureChannels),
    @(_X.numberOfImages)
  ];
  id<MTLComputePipelineState> pipeline =
      [[MetalContext sharedInstance] specializedPipelineState:"clamp"
                                                    Constants:specializationConstants];
  [computeEncoder setComputePipelineState:pipeline];

  [computeEncoder setTexture:_X.texture atIndex:0];
  [computeEncoder setTexture:_Y.texture atIndex:1];

  // Grid sizing is driven by the image written at slot 1.
  const auto& launch =
      at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(pipeline, _Y);
  [computeEncoder dispatchThreadgroups:launch.threadgroupsPerGrid
                 threadsPerThreadgroup:launch.threadsPerThreadgroup];
  [computeEncoder endEncoding];
}

@end
46