//
// ETCoreMLModelStructurePath.mm
//
// Copyright © 2024 Apple Inc. All rights reserved.
//
// Please refer to the license found in the LICENSE file in the root directory of the source tree.

#import "ETCoreMLModelStructurePath.h"

#import "objc_safe_cast.h"

namespace {
using namespace executorchcoreml::modelstructure;

/// Discriminator used by `to_path` to dispatch a serialized component to the matching
/// `append_component` specialization.
enum ComponentType: uint8_t {
    program,
    function,
    operation,
    block
};

/// Returns the UTF-8 bytes of `string`, or an empty C string when `string` is nil or
/// `-UTF8String` returns NULL (the property is declared nullable).
///
/// Constructing a C++ string from a null `const char *` is undefined behavior, so every
/// NSString -> C string conversion in this file goes through this helper.
const char *utf8_string_or_empty(NSString *string) {
    const char *result = string.UTF8String;
    return (result != nullptr) ? result : "";
}

/// Appends the path component described by `component` to `path`; specialized per component kind.
template<typename T> void append_component(NSDictionary<NSString *, id> *component, Path& path);

template<> void append_component<Path::Program>(NSDictionary<NSString *, id> * __unused component, Path& path) {
    path.append_component(Path::Program());
}

template<> void append_component<Path::Program::Function>(NSDictionary<NSString *, id> *component, Path& path) {
    NSString *name = SAFE_CAST(component[@(Path::Program::Function::kNameKeyName)], NSString);
    NSCAssert(name != nil, @"Component=%@ is missing %s key.", component, Path::Program::Function::kNameKeyName);
    // Assertions are compiled out in release builds; fall back to an empty name instead of
    // handing a null pointer to the `Function` constructor.
    path.append_component(Path::Program::Function(utf8_string_or_empty(name)));
}

template<> void append_component<Path::Program::Block>(NSDictionary<NSString *, id> *component, Path& path) {
    NSNumber *index = SAFE_CAST(component[@(Path::Program::Block::kIndexKeyName)], NSNumber);
    // A missing or non-numeric index is recorded as -1.
    NSInteger indexValue = (index != nil) ? index.integerValue : -1;
    path.append_component(Path::Program::Block(indexValue));
}

template<> void append_component<Path::Program::Operation>(NSDictionary<NSString *, id> *component, Path& path) {
    NSString *output_name = SAFE_CAST(component[@(Path::Program::Operation::kOutputKeyName)], NSString) ?: @"";
    NSCAssert(output_name.length > 0, @"Component=%@ is missing %s key.", component, Path::Program::Operation::kOutputKeyName);
    path.append_component(Path::Program::Operation(utf8_string_or_empty(output_name)));
}

/// Lazily built, process-wide table mapping each component type name to its `ComponentType`.
NSDictionary<NSString *, NSNumber *> *component_types() {
    static NSDictionary<NSString *, NSNumber *> *result = nil;
    static dispatch_once_t onceToken;
    dispatch_once(&onceToken, ^{
        result = @{
            @(Path::Program::kTypeName): @(ComponentType::program),
            @(Path::Program::Function::kTypeName): @(ComponentType::function),
            @(Path::Program::Block::kTypeName): @(ComponentType::block),
            @(Path::Program::Operation::kTypeName): @(ComponentType::operation)
        };
    });

    return result;
}

/// Builds a `Path` from an array of component dictionaries (the same shape `to_array` emits).
///
/// Components whose type key is missing or unrecognized trigger an assertion in debug builds
/// and are skipped otherwise.
Path to_path(NSArray<NSDictionary<NSString *, id> *> *components) {
    Path path;
    NSDictionary<NSString *, NSNumber *> *types = component_types();
    for (NSDictionary<NSString *, id> *component in components) {
        NSString *type = SAFE_CAST(component[@(Path::kTypeKeyName)], NSString);
        NSCAssert(type.length > 0, @"Component=%@ is missing %s key.", component, Path::kTypeKeyName);
        // `types[type].intValue` would evaluate to 0 (== ComponentType::program) for a missing or
        // unknown type because messaging nil returns 0. Use an out-of-range sentinel instead so
        // such components reach the `default` branch rather than becoming a Program component.
        NSNumber *typeValue = (type != nil) ? types[type] : nil;
        const int typeIndex = (typeValue != nil) ? typeValue.intValue : -1;
        switch (typeIndex) {
            case ComponentType::program: {
                append_component<Path::Program>(component, path);
                break;
            }
            case ComponentType::function: {
                append_component<Path::Program::Function>(component, path);
                break;
            }
            case ComponentType::block: {
                append_component<Path::Program::Block>(component, path);
                break;
            }
            case ComponentType::operation: {
                append_component<Path::Program::Operation>(component, path);
                break;
            }
            default: {
                // An empty/missing type was already reported by the assertion above.
                NSCAssert(type.length == 0, @"Component=%@ has invalid type=%@.", component, type);
                break;
            }
        }
    }

    return path;
}

/// Serializes a Program component; it carries only its type.
NSDictionary<NSString *, id> *to_dictionary(const Path::Program& __unused program) {
    return @{@(Path::kTypeKeyName) : @(Path::Program::kTypeName)};
}

/// Serializes a Function component as its type plus its name.
NSDictionary<NSString *, id> *to_dictionary(const Path::Program::Function& function) {
    return @{
        @(Path::kTypeKeyName) : @(Path::Program::Function::kTypeName),
        @(Path::Program::Function::kNameKeyName) : @(function.name.c_str())
    };
}

/// Serializes a Block component as its type plus its index.
NSDictionary<NSString *, id> *to_dictionary(const Path::Program::Block& block) {
    return @{
        @(Path::kTypeKeyName) : @(Path::Program::Block::kTypeName),
        @(Path::Program::Block::kIndexKeyName) : @(block.index)
    };
}

/// Serializes an Operation component as its type plus its output name.
NSDictionary<NSString *, id> *to_dictionary(const Path::Program::Operation& operation) {
    return @{
        @(Path::kTypeKeyName) : @(Path::Program::Operation::kTypeName),
        @(Path::Program::Operation::kOutputKeyName) : @(operation.output_name.c_str())
    };
}

/// Serializes every component of `path`, in order, into an array of dictionaries.
NSArray<NSDictionary<NSString *, id> *> *to_array(const Path& path) {
    NSMutableArray<NSDictionary<NSString *, id> *> *result = [NSMutableArray arrayWithCapacity:path.size()];
    for (const auto& component : path.components()) {
        NSDictionary<NSString *, id> *value = std::visit([](auto&& arg){
            return to_dictionary(arg);
        }, component);
        [result addObject:value];
    }

    return result;
}
}

@implementation ETCoreMLModelStructurePath

- (instancetype)initWithUnderlyingValue:(executorchcoreml::modelstructure::Path)underlyingValue {
    self = [super init];
    if (self) {
        _underlyingValue = std::move(underlyingValue);
    }

    return self;
}

- (instancetype)initWithComponents:(NSArray<NSDictionary<NSString *, id> *> *)components {
    auto underlyingValue = to_path(components);
    return [self initWithUnderlyingValue:std::move(underlyingValue)];
}

/// Equal when `object` is an instance of the receiver's class (or a subclass) and the
/// underlying paths compare equal.
- (BOOL)isEqual:(id)object {
    if (object == self) {
        return YES;
    }

    if (![object isKindOfClass:self.class]) {
        return NO;
    }

    return _underlyingValue == ((ETCoreMLModelStructurePath *)object)->_underlyingValue;
}

/// Delegates to `std::hash<Path>`; assumes that specialization is consistent with `Path::operator==`.
- (NSUInteger)hash {
    return std::hash<executorchcoreml::modelstructure::Path>()(_underlyingValue);
}

- (instancetype)copyWithZone:(NSZone *)zone {
    // Allocate via `[self class]` so copies of subclasses keep their dynamic type.
    return [[[self class] allocWithZone:zone] initWithUnderlyingValue:_underlyingValue];
}

/// Returns the output name when the last component of the path is an Operation, otherwise nil
/// (including when the path has no components).
- (nullable NSString *)operationOutputName {
    using namespace executorchcoreml::modelstructure;
    // Bind by reference so the element pointer stays valid even if `components()` returns by value.
    const auto& components = _underlyingValue.components();
    if (components.empty()) {
        // Calling `back()` on an empty container is undefined behavior.
        return nil;
    }

    auto operation = std::get_if<Path::Program::Operation>(&components.back());
    if (operation == nullptr) {
        return nil;
    }

    return @(operation->output_name.c_str());
}

- (NSArray<NSDictionary<NSString *, id> *> *)components {
    return to_array(_underlyingValue);
}

- (NSString *)description {
    return [NSString stringWithFormat:@"<MLModelStructurePath: %p> %@", (void *)self, self.components];
}


@end