#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h>

/// Feature provider backed by a mutable name -> MLFeatureValue map, so the same
/// instance can be refilled with new input tensors between predictions.
@implementation PTMCoreMLFeatureProvider {
  /// Feature values keyed by feature name; filled by -setInputTensor:forFeatureName:
  /// and emptied by -clearInputTensors.
  NSMutableDictionary<NSString *, MLFeatureValue *> *_featureValuesForName;
}

// Explicit synthesis kept from the original; presumably `featureNames` is declared by
// a protocol (e.g. MLFeatureProvider), whose properties are not auto-synthesized.
@synthesize featureNames = _featureNames;

/// Creates an empty provider advertising `featureNames`.
/// @param featureNames The feature names reported by the `featureNames` property.
- (instancetype)initWithFeatureNames:(NSSet<NSString *> *)featureNames {
  self = [super init];
  if (self) {
    // Copy so a caller-supplied NSMutableSet cannot change the advertised names later.
    _featureNames = [featureNames copy];
    _featureValuesForName = [NSMutableDictionary dictionary];
  }
  return self;
}

/// Removes every feature value previously stored by -setInputTensor:forFeatureName:.
- (void)clearInputTensors {
  [_featureValuesForName removeAllObjects];
}

/// Wraps `tensor` in a Float32 MLMultiArray (sharing its memory, not copying it) and
/// stores it as the feature value for `name`.
/// @param tensor Input tensor; its data is read via `mutable_data_ptr<float>()`, so it is
///        expected to hold float elements. Shape and strides are forwarded as-is.
/// @param name Feature name under which the value is stored.
/// If the multi-array or feature value cannot be created, nothing is stored and any value
/// previously stored under `name` is left in place.
- (void)setInputTensor:(const at::Tensor&)tensor forFeatureName:(NSString *)name {
  const NSUInteger rank = static_cast<NSUInteger>(tensor.dim());

  NSMutableArray<NSNumber *> *shape = [NSMutableArray arrayWithCapacity:rank];
  for (const auto dim : tensor.sizes()) {
    [shape addObject:@(dim)];
  }

  NSMutableArray<NSNumber *> *strides = [NSMutableArray arrayWithCapacity:rank];
  for (const auto step : tensor.strides()) {
    [strides addObject:@(step)];
  }

  // MLMultiArray references the tensor's buffer directly. Capturing a copy of the tensor
  // handle in the deallocator keeps the underlying storage alive for as long as the
  // multi-array (and thus the stored feature value) holds on to the block.
  at::Tensor retainedTensor = tensor;
  NSError *error = nil;
  MLMultiArray *mlArray =
      [[MLMultiArray alloc] initWithDataPointer:retainedTensor.mutable_data_ptr<float>()
                                          shape:shape
                                       dataType:MLMultiArrayDataTypeFloat32
                                        strides:strides
                                    deallocator:^(void *bytes) {
                                      (void)bytes;
                                      (void)retainedTensor;
                                    }
                                          error:&error];
  if (mlArray == nil) {
    // Creation failed (reason in `error`); do not pass nil to the nonnull factory below.
    return;
  }

  MLFeatureValue *value = [MLFeatureValue featureValueWithMultiArray:mlArray];
  if (value) {
    _featureValuesForName[name] = value;
  }
}

/// Returns the value stored for `featureName`, or nil if none has been set.
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
  return _featureValuesForName[featureName];
}

@end