#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>

// Shader op that runs the "clamp" compute pipeline, reading from one MPSImage
// and targeting another. The lower/upper bounds are kept boxed as NSNumbers
// and unboxed to float when the pipeline is specialized.
@implementation MPSCNNClampOp {
  MPSImage* _X; // textures[0]; bound at texture index 0
  MPSImage* _Y; // textures[1]; bound at texture index 1, also sizes the dispatch
  NSNumber* _min; // args[0]
  NSNumber* _max; // args[1]
}

// Builds a clamp op.
//   textures: at least two images; [0] is stored as X and [1] as Y.
//   args:     at least two numbers; [0] is the min bound and [1] the max bound.
// Subscripting a shorter array raises, so callers must supply both entries.
+ (id<MPSCNNShaderOp>)newWithTextures:(NSArray<MPSImage*>*)textures
                                 Args:(NSArray<NSNumber*>*)args {
  MPSCNNClampOp* clampOp = [[MPSCNNClampOp alloc] init];
  clampOp->_X = [textures objectAtIndexedSubscript:0];
  clampOp->_Y = [textures objectAtIndexedSubscript:1];
  clampOp->_min = [args objectAtIndexedSubscript:0];
  clampOp->_max = [args objectAtIndexedSubscript:1];

  return clampOp;
}

// Encodes one compute pass for this op onto the given command buffer and ends
// the encoder before returning. The command buffer is not committed here.
- (void)encode:(id<MTLCommandBuffer>)cb {
  id<MTLComputeCommandEncoder> computeEncoder = [cb computeCommandEncoder];

  // Specialization constants, in the order handed to the pipeline lookup:
  // min, max, X's feature channel count, X's image count.
  NSArray<NSNumber*>* specializationConstants = @[
    @([_min floatValue]),
    @([_max floatValue]),
    @([_X featureChannels]),
    @([_X numberOfImages])
  ];
  id<MTLComputePipelineState> pipeline =
      [[MetalContext sharedInstance] specializedPipelineState:"clamp"
                                                    Constants:specializationConstants];

  [computeEncoder setComputePipelineState:pipeline];
  [computeEncoder setTexture:_X.texture atIndex:0];
  [computeEncoder setTexture:_Y.texture atIndex:1];

  // Grid/threadgroup sizes are derived from the pipeline state and Y.
  const auto& dispatchParams =
      at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(pipeline, _Y);
  [computeEncoder dispatchThreadgroups:dispatchParams.threadgroupsPerGrid
                 threadsPerThreadgroup:dispatchParams.threadsPerThreadgroup];
  [computeEncoder endEncoding];
}

@end