xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#import <torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h>
2
@implementation PTMCoreMLFeatureProvider {
  // Input feature values keyed by feature name. Filled by
  // -setInputTensor:forFeatureName: and emptied by -clearInputTensors.
  NSMutableDictionary<NSString *, MLFeatureValue *> *_featureValuesForName;
}

@synthesize featureNames = _featureNames;

/// Creates a provider that advertises `featureNames` and holds no values yet.
/// @param featureNames The names reported through the `featureNames` property.
///        The set is copied so later mutation by the caller has no effect here.
- (instancetype)initWithFeatureNames:(NSSet<NSString *> *)featureNames {
  self = [super init];
  if (self) {
    _featureNames = [featureNames copy];
    _featureValuesForName = [NSMutableDictionary dictionary];
  }
  return self;
}

/// Removes every value previously stored with -setInputTensor:forFeatureName:.
- (void)clearInputTensors {
  [_featureValuesForName removeAllObjects];
}

/// Wraps `tensor`'s buffer (without copying it) in an MLMultiArray described as
/// Float32 with the tensor's own sizes and strides, and stores it as the value
/// for `name`, replacing any earlier value for that name.
///
/// NOTE(review): the data is read via `mutable_data_ptr<float>()` and labelled
/// Float32 for Core ML, so this assumes a float32 tensor — confirm callers
/// guarantee that.
///
/// If the MLMultiArray or MLFeatureValue cannot be created, any earlier value
/// for `name` is removed so a stale input is never served.
///
/// @param tensor Tensor whose storage backs the value. A ref-counted copy of it
///        is captured by the array's deallocator, keeping the storage alive for
///        as long as Core ML holds the array.
/// @param name Feature name to store the value under; must not be nil.
- (void)setInputTensor:(const at::Tensor&)tensor forFeatureName:(NSString *)name {
  NSParameterAssert(name != nil);

  // Iterate the IntArrayRefs directly instead of materializing std::vector copies.
  const auto sizes = tensor.sizes();
  NSMutableArray<NSNumber *> *shape = [NSMutableArray arrayWithCapacity:sizes.size()];
  for (const auto dim : sizes) {
    [shape addObject:@(dim)];
  }

  const auto tensorStrides = tensor.strides();
  NSMutableArray<NSNumber *> *strides =
      [NSMutableArray arrayWithCapacity:tensorStrides.size()];
  for (const auto step : tensorStrides) {
    [strides addObject:@(step)];
  }

  // MLMultiArray does not copy the buffer. Capturing this tensor by value in the
  // deallocator block holds a reference on its storage until Core ML releases
  // the array (the block, and with it the tensor copy, is destroyed then).
  at::Tensor retainedTensor = tensor;
  NSError *error = nil;
  MLMultiArray *mlArray =
      [[MLMultiArray alloc] initWithDataPointer:tensor.mutable_data_ptr<float>()
                                          shape:shape
                                       dataType:MLMultiArrayDataTypeFloat32
                                        strides:strides
                                    deallocator:^(void *bytes) {
                                      (void)bytes;
                                      (void)retainedTensor;
                                    }
                                          error:&error];

  // Check the return value, not `error`; never pass nil to the nonnull
  // +featureValueWithMultiArray: parameter.
  MLFeatureValue *value =
      mlArray != nil ? [MLFeatureValue featureValueWithMultiArray:mlArray] : nil;
  if (value == nil) {
    // Conversion failed: drop any earlier value so it is not reused by mistake.
    [_featureValuesForName removeObjectForKey:name];
    return;
  }
  _featureValuesForName[name] = value;
}

/// MLFeatureProvider: returns the value stored for `featureName`, or nil if none.
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
  return _featureValuesForName[featureName];
}

@end
52