xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/coreml/coreml_executor.mm (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#import "tensorflow/lite/delegates/coreml/coreml_executor.h"
16
17#import <CoreML/CoreML.h>
18#import <Foundation/Foundation.h>
19
20#include <fstream>
21#include <iostream>
22
namespace {
// Returns a file URL naming a fresh, uniquely named entry inside
// NSTemporaryDirectory(). Only the URL is produced: nothing is created on disk
// here, so the caller is responsible for writing to (and later removing) it.
NSURL* createTemporaryFile() {
  // globallyUniqueString gives each call a distinct file name.
  NSString* uniqueName = [[NSProcessInfo processInfo] globallyUniqueString];
  NSURL* tempDirectory = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES];
  return [tempDirectory URLByAppendingPathComponent:uniqueName];
}
}  // namespace
36
// Adapts a list of TFLite input tensors to Core ML's MLFeatureProvider protocol.
// Each TensorData is exposed as a Float32 MLMultiArray feature named after
// TensorData::name. Tensor buffers are wrapped, not copied, so the vector passed to
// -initWithInputs:coreMlVersion: must outlive this object and every feature value
// it hands out.
@interface MultiArrayFeatureProvider : NSObject <MLFeatureProvider> {
  // Borrowed pointer; the caller owns the vector.
  const std::vector<TensorData>* _inputs;
  // Lazily built cache of the input names; see -featureNames.
  NSSet* _featureNames;
}

// Returns nil if `inputs` is null or any input has an empty name.
// @param inputs Borrowed pointer to the input tensors; must outlive this provider.
// @param coreMlVersion Core ML version; >= 3 produces rank-4 arrays, otherwise rank 3.
- (instancetype)initWithInputs:(const std::vector<TensorData>*)inputs
                 coreMlVersion:(int)coreMlVersion;
- (MLFeatureValue*)featureValueForName:(NSString*)featureName API_AVAILABLE(ios(11));
- (NSSet<NSString*>*)featureNames;

@property(nonatomic, readonly) int coreMlVersion;

@end

@implementation MultiArrayFeatureProvider

- (instancetype)initWithInputs:(const std::vector<TensorData>*)inputs
                 coreMlVersion:(int)coreMlVersion {
  self = [super init];
  if (self == nil || inputs == nullptr) {
    return nil;
  }
  _inputs = inputs;
  _coreMlVersion = coreMlVersion;
  // Core ML looks features up by name, so every input must have one.
  for (const auto& input : *_inputs) {
    if (input.name.empty()) {
      return nil;
    }
  }
  return self;
}

// Returns the set of input names, decoded as UTF-8 so they agree with the lookup
// performed by -featureValueForName:. Names that are not valid UTF-8 are skipped.
- (NSSet<NSString*>*)featureNames {
  if (_featureNames == nil) {
    NSMutableSet<NSString*>* names = [NSMutableSet setWithCapacity:_inputs->size()];
    for (const auto& input : *_inputs) {
      NSString* name = [NSString stringWithUTF8String:input.name.c_str()];
      // stringWithUTF8String: returns nil for invalid UTF-8; adding nil would throw.
      if (name != nil) {
        [names addObject:name];
      }
    }
    _featureNames = [names copy];
  }
  return _featureNames;
}

// Wraps the buffer of the input named `featureName` in a Float32 MLMultiArray.
// Returns nil (and logs) if no input has that name, the input's shape has fewer
// dimensions than required, its data buffer is smaller than the shape describes, or
// Core ML rejects the array.
- (MLFeatureValue*)featureValueForName:(NSString*)featureName {
  const char* requestedName = [featureName UTF8String];
  if (requestedName != nullptr) {
    for (const auto& input : *_inputs) {
      if (input.name != requestedName) {
        continue;
      }
      // TODO(b/141492326): Update shape handling for higher ranks
      // Core ML 3+ models take rank-4 arrays; older ones take rank 3.
      const size_t rank = [self coreMlVersion] >= 3 ? 4 : 3;
      if (input.shape.size() < rank) {
        NSLog(@"Feature %@ has rank %zu but rank %zu is required", featureName,
              input.shape.size(), rank);
        return nil;
      }
      // Row-major, contiguous strides expressed in elements: the innermost dimension
      // has stride 1 and each outer stride is the product of the inner dimensions.
      NSMutableArray<NSNumber*>* shape = [NSMutableArray arrayWithCapacity:rank];
      NSMutableArray<NSNumber*>* strides = [NSMutableArray arrayWithCapacity:rank];
      NSInteger elementCount = 1;
      for (size_t i = rank; i-- > 0;) {
        [shape insertObject:@(input.shape[i]) atIndex:0];
        [strides insertObject:@(elementCount) atIndex:0];
        elementCount *= input.shape[i];
      }
      // The array aliases the tensor's buffer, so the buffer must cover the whole
      // shape or Core ML would read past the end of the vector.
      if (elementCount < 0 || static_cast<size_t>(elementCount) > input.data.size()) {
        NSLog(@"Feature %@ shape needs %ld elements but only %zu are available", featureName,
              (long)elementCount, input.data.size());
        return nil;
      }
      NSError* error = nil;
      // The no-op deallocator leaves ownership of the bytes with the caller's vector.
      MLMultiArray* mlArray =
          [[MLMultiArray alloc] initWithDataPointer:const_cast<float*>(input.data.data())
                                              shape:shape
                                           dataType:MLMultiArrayDataTypeFloat32
                                            strides:strides
                                        deallocator:^(void* bytes) {
                                        }
                                              error:&error];
      // Judge success by the returned object, not by the error out-parameter.
      if (mlArray == nil) {
        NSLog(@"Failed to create MLMultiArray for feature %@ error: %@", featureName,
              [error localizedDescription]);
        return nil;
      }
      return [MLFeatureValue featureValueWithMultiArray:mlArray];
    }
  }

  NSLog(@"Feature %@ not found", featureName);
  return nil;
}
@end
128
129@implementation CoreMlExecutor
// Runs one prediction on the compiled model.
// @param inputs Input tensors; each must have a non-empty name (see
//     MultiArrayFeatureProvider).
// @param outputs Output tensors, each naming a model output. Their data buffers are
//     overwritten in place with Float32 results (the const qualifier is cast away,
//     as the original implementation did).
// @return YES on success; NO if the model is not built, the inputs cannot be
//     wrapped, prediction fails, or an output is missing, is not Float32, or has
//     fewer elements than the destination buffer.
- (bool)invokeWithInputs:(const std::vector<TensorData>&)inputs
                 outputs:(const std::vector<TensorData>&)outputs {
  if (_model == nil) {
    return NO;
  }
  MultiArrayFeatureProvider* inputFeature =
      [[MultiArrayFeatureProvider alloc] initWithInputs:&inputs coreMlVersion:[self coreMlVersion]];
  if (inputFeature == nil) {
    NSLog(@"inputFeature is not initialized.");
    return NO;
  }
  MLPredictionOptions* options = [[MLPredictionOptions alloc] init];
  NSError* error = nil;
  id<MLFeatureProvider> outputFeature = [_model predictionFromFeatures:inputFeature
                                                               options:options
                                                                 error:&error];
  // Cocoa convention: a nil result signals failure; the error pointer alone is not reliable.
  if (outputFeature == nil) {
    NSLog(@"Error executing model: %@", [error localizedDescription]);
    return NO;
  }
  NSSet<NSString*>* outputFeatureNames = [outputFeature featureNames];
  for (const auto& output : outputs) {
    // Names come from the model spec and are UTF-8 encoded.
    NSString* outputName = [NSString stringWithUTF8String:output.name.c_str()];
    if (outputName == nil || ![outputFeatureNames containsObject:outputName]) {
      NSLog(@"Output %s not found in model outputs", output.name.c_str());
      return NO;
    }
    MLMultiArray* data = [[outputFeature featureValueForName:outputName] multiArrayValue];
    if (data == nil || data.dataPointer == nullptr) {
      NSLog(@"Output %@ has no multi-array value", outputName);
      return NO;
    }
    // The bytes are copied verbatim into a float buffer, so the element type must match.
    if (data.dataType != MLMultiArrayDataTypeFloat32) {
      NSLog(@"Output %@ is not a Float32 array", outputName);
      return NO;
    }
    // Refuse to read past the end of Core ML's buffer.
    if (data.count < 0 || static_cast<size_t>(data.count) < output.data.size()) {
      NSLog(@"Output %@ has %ld elements but %zu are expected", outputName, (long)data.count,
            output.data.size());
      return NO;
    }
    memcpy(const_cast<float*>(output.data.data()), data.dataPointer,
           output.data.size() * sizeof(output.data[0]));
  }
  return YES;
}
165
// Deletes the serialized model and the compiled model recorded by -build:.
// Both removals are attempted even if the first fails, so one failure does not leak
// the other artifact. Paths that were never recorded (build: not called, or it
// failed before recording them) are skipped.
// @return YES if every recorded artifact was removed, NO otherwise.
- (bool)cleanup {
  NSFileManager* fileManager = [NSFileManager defaultManager];
  bool success = true;
  if (_mlModelFilePath != nil) {
    NSError* error = nil;
    // Check the BOOL result; the error out-parameter is only meaningful on failure.
    if (![fileManager removeItemAtPath:_mlModelFilePath error:&error]) {
      NSLog(@"Failed cleaning up model: %@", [error localizedDescription]);
      success = false;
    }
  }
  if (_compiledModelFilePath != nil) {
    NSError* error = nil;
    if (![fileManager removeItemAtPath:_compiledModelFilePath error:&error]) {
      NSLog(@"Failed cleaning up compiled model: %@", [error localizedDescription]);
      success = false;
    }
  }
  return success;
}
180
// Serializes `model` to a new temporary file and records the matching Core ML
// version in _coreMlVersion (spec v3 -> Core ML 2, spec v4 -> Core ML 3).
// @return The file URL of the written model, or nil if the specification version is
//     unsupported or the file could not be fully written (a partial file is removed).
- (NSURL*)saveModel:(CoreML::Specification::Model*)model {
  if (model->specificationversion() == 3) {
    _coreMlVersion = 2;
  } else if (model->specificationversion() == 4) {
    _coreMlVersion = 3;
  } else {
    NSLog(@"Only Core ML models with specification version 3 or 4 are supported");
    return nil;
  }
  NSURL* modelUrl = createTemporaryFile();
  NSString* modelPath = [modelUrl path];
  // Flush data to file.
  // TODO(karimnosseir): Can we mmap this instead of actual writing it to phone ?
  std::ofstream file_stream([modelPath UTF8String], std::ios::out | std::ios::binary);
  const bool serialized = file_stream.is_open() && model->SerializeToOstream(&file_stream);
  // Close explicitly so buffered bytes reach disk, and any write error is observed,
  // before the URL is handed to the caller.
  file_stream.close();
  if (!serialized || file_stream.fail()) {
    NSLog(@"Failed to write Core ML model to %@", modelPath);
    // Best effort: do not leave a truncated model behind.
    [[NSFileManager defaultManager] removeItemAtPath:modelPath error:nil];
    return nil;
  }
  return modelUrl;
}
198
// Compiles the model file at `modelUrl` and loads the compiled model into _model.
// On iOS 12+ the model is loaded with MLComputeUnitsAll so Core ML may use any
// available compute unit. Records both the source and compiled paths for -cleanup
// once compilation succeeds.
// @return YES if the model was compiled and loaded, NO otherwise.
- (bool)build:(NSURL*)modelUrl {
  NSError* compileError = nil;
  NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&compileError];
  // A nil result is the failure signal; the error pointer is only valid on failure.
  if (compileUrl == nil) {
    NSLog(@"Error compiling model %@", [compileError localizedDescription]);
    return NO;
  }
  _mlModelFilePath = [modelUrl path];
  _compiledModelFilePath = [compileUrl path];

  NSError* loadError = nil;
  if (@available(iOS 12.0, *)) {
    // The configuration must be initialized, not merely allocated.
    MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
    config.computeUnits = MLComputeUnitsAll;
    _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&loadError];
  } else {
    _model = [MLModel modelWithContentsOfURL:compileUrl error:&loadError];
  }
  if (_model == nil) {
    NSLog(@"Error Creating MLModel %@", [loadError localizedDescription]);
    return NO;
  }
  return YES;
}
222@end
223