1 //
2 // backend_delegate.hpp
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8 #pragma once
9
#include <model_logging_options.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <vector>
14
15 namespace executorchcoreml {
16
17 class ModelEventLogger;
18 class MultiArray;
19 class Buffer;
20
21 /// An abstract class for a CoreML delegate to implement.
22 class BackendDelegate {
23 public:
24 /// The model handle.
25 using Handle = void;
26
27 struct Config {
28 // Max models cache size in bytes.
29 size_t max_models_cache_size = 10 * size_t(1024) * size_t(1024) * size_t(1024);
30 // If set to `true`, delegate pre-warms the most recently used asset.
31 bool should_prewarm_asset = false;
32 // If set to `true`, delegate pre-warms the model in `init`.
33 bool should_prewarm_model = true;
34 };
35
36 /// The error codes for the `BackendDelegate`.
37 enum class ErrorCode : int8_t {
38 CorruptedData = 1, // AOT blob can't be parsed.
39 CorruptedMetadata, // AOT blob has incorrect or missing metadata.
40 CorruptedModel, // AOT blob has incorrect or missing CoreML model.
41 BrokenModel, // Model doesn't match the input and output specifications.
42 CompilationFailed, // Model failed to compile.
43 ModelSaveFailed, // Failed to save Model to disk.
44 ModelCacheCreationFailed // Failed to create models cache.
45 };
46
47 /// The error category for `BackendDelegate` errors.
48 struct ErrorCategory final : public std::error_category {
49 public:
50 /// Returns the name of the category.
namefinal51 inline const char* name() const noexcept override { return "CoreMLBackend"; }
52
53 /// Returns a message from the error code.
54 std::string message(int code) const override;
55 };
56
BackendDelegate()57 inline BackendDelegate() noexcept { }
58 virtual ~BackendDelegate() noexcept = default;
59
60 BackendDelegate(BackendDelegate const&) = delete;
61 BackendDelegate& operator=(BackendDelegate const&) = delete;
62
63 BackendDelegate(BackendDelegate&&) = default;
64 BackendDelegate& operator=(BackendDelegate&&) = default;
65
66 /// Must initialize a CoreML model.
67 ///
68 /// The method receives the AOT blob that's embedded in the executorch
69 /// Program. The implementation must initialize the model and prepare it for
70 /// execution.
71 ///
72 /// @param processed The AOT blob.
73 /// @param specs The specs at the time of compilation.
74 /// @retval An opaque handle to the initialized blob or `nullptr` if the
75 /// initialization failed.
76 virtual Handle* init(Buffer processed, const std::unordered_map<std::string, Buffer>& specs) const noexcept = 0;
77
78 /// Must execute the CoreML model with the specified handle.
79 ///
80 /// The `args` are inputs and outputs combined. It's the responsibility of the
81 /// implementation to find the inputs and the outputs from `args`. The
82 /// implementation must execute the model with the inputs and must populate
83 /// the outputs from the model prediction outputs.
84 ///
85 /// @param handle The model handle.
86 /// @param args The inputs and outputs to the model.
87 /// @param logging_options The model logging options.
88 /// @param event_logger The model event logger.
89 /// @param error On failure, error is filled with the failure information.
90 /// @retval `true` if the execution succeeded otherwise `false`.
91 virtual bool execute(Handle* handle,
92 const std::vector<MultiArray>& args,
93 const ModelLoggingOptions& logging_options,
94 ModelEventLogger* event_logger,
95 std::error_code& error) const noexcept = 0;
96
97 /// Must return `true` if the delegate is available for execution otherwise
98 /// `false`.
99 virtual bool is_available() const noexcept = 0;
100
101 /// Must returns the number of inputs and the number of outputs for the
102 /// specified handle.
103 ///
104 /// The returned pair's first value is the number of inputs and the second
105 /// value is the number of outputs.
106 ///
107 /// @param handle The model handle.
108 /// @retval A pair with the number of inputs and the number of outputs.
109 virtual std::pair<size_t, size_t> get_num_arguments(Handle* handle) const noexcept = 0;
110
111 /// Checks if the model handle is valid.
112 ///
113 /// @param handle The model handle.
114 /// @retval `true` if the model handle is valid otherwise `false`.
115 virtual bool is_valid_handle(Handle* handle) const noexcept = 0;
116
117 /// Must unload the CoreML model with the specified handle.
118 ///
119 /// The returned pair's first value is the number of inputs and the second
120 /// value is the number of outputs.
121 ///
122 /// @param handle The model handle.
123 virtual void destroy(Handle* handle) const noexcept = 0;
124
125 /// Purges the models cache.
126 ///
127 /// Compiled models are stored on-disk to improve the model load time. The
128 /// method tries to remove all the models that are not currently in-use.
129 virtual bool purge_models_cache() const noexcept = 0;
130
131 /// Returns a delegate implementation with the specified config.
132 ///
133 /// @param config The delegate config.
134 /// @retval A delegate implementation.
135 static std::shared_ptr<BackendDelegate> make(const Config& config);
136 };
137
138 /// Constructs a `std::error_code` from a`BackendDelegate::ErrorCode`.
139 ///
140 /// @param code The backend error code.
141 /// @retval A `std::error_code` constructed from
142 /// the`BackendDelegate::ErrorCode`.
make_error_code(BackendDelegate::ErrorCode code)143 inline std::error_code make_error_code(BackendDelegate::ErrorCode code) {
144 static BackendDelegate::ErrorCategory errorCategory;
145 return { static_cast<int>(code), errorCategory };
146 }
147 } // namespace executorchcoreml
148
149 namespace std {
150 template <> struct is_error_code_enum<executorchcoreml::BackendDelegate::ErrorCode> : true_type { };
151 } // namespace std
152