xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/backend_delegate.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 //
2 // backend_delegate.h
3 //
4 // Copyright © 2024 Apple Inc. All rights reserved.
5 //
6 // Please refer to the license found in the LICENSE file in the root directory of the source tree.
7 
8 #pragma once
9 
#include <model_logging_options.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <system_error>
#include <unordered_map>
#include <utility>
#include <vector>
14 
15 namespace executorchcoreml {
16 
// Forward declarations of project types that appear in `BackendDelegate`'s
// method signatures. Declaring (not defining) those functions does not
// require the complete types.
class Buffer;
class ModelEventLogger;
class MultiArray;
20 
21 /// An abstract class for a CoreML delegate to implement.
22 class BackendDelegate {
23 public:
24     /// The model handle.
25     using Handle = void;
26 
27     struct Config {
28         // Max models cache size in bytes.
29         size_t max_models_cache_size = 10 * size_t(1024) * size_t(1024) * size_t(1024);
30         // If set to `true`, delegate pre-warms the most recently used asset.
31         bool should_prewarm_asset = false;
32         // If set to `true`, delegate pre-warms the model in `init`.
33         bool should_prewarm_model = true;
34     };
35 
36     /// The error codes for the `BackendDelegate`.
37     enum class ErrorCode : int8_t {
38         CorruptedData = 1, // AOT blob can't be parsed.
39         CorruptedMetadata, // AOT blob has incorrect or missing metadata.
40         CorruptedModel, // AOT blob has incorrect or missing CoreML model.
41         BrokenModel, // Model doesn't match the input and output specifications.
42         CompilationFailed, // Model failed to compile.
43         ModelSaveFailed, // Failed to save Model to disk.
44         ModelCacheCreationFailed // Failed to create models cache.
45     };
46 
47     /// The error category for `BackendDelegate` errors.
48     struct ErrorCategory final : public std::error_category {
49     public:
50         /// Returns the name of the category.
namefinal51         inline const char* name() const noexcept override { return "CoreMLBackend"; }
52 
53         /// Returns a message from the error code.
54         std::string message(int code) const override;
55     };
56 
BackendDelegate()57     inline BackendDelegate() noexcept { }
58     virtual ~BackendDelegate() noexcept = default;
59 
60     BackendDelegate(BackendDelegate const&) = delete;
61     BackendDelegate& operator=(BackendDelegate const&) = delete;
62 
63     BackendDelegate(BackendDelegate&&) = default;
64     BackendDelegate& operator=(BackendDelegate&&) = default;
65 
66     /// Must initialize a CoreML model.
67     ///
68     /// The method receives the AOT blob that's embedded in the executorch
69     /// Program. The implementation must initialize the model and prepare it for
70     /// execution.
71     ///
72     /// @param processed The AOT blob.
73     /// @param specs The specs at the time of compilation.
74     /// @retval An opaque handle to the initialized blob or `nullptr` if the
75     /// initialization failed.
76     virtual Handle* init(Buffer processed, const std::unordered_map<std::string, Buffer>& specs) const noexcept = 0;
77 
78     /// Must execute the CoreML model with the specified handle.
79     ///
80     /// The `args` are inputs and outputs combined. It's the responsibility of the
81     /// implementation to find the inputs and the outputs from `args`. The
82     /// implementation must execute the model with the inputs and must populate
83     /// the outputs from the model prediction outputs.
84     ///
85     /// @param handle The model handle.
86     /// @param args The inputs and outputs to the model.
87     /// @param logging_options The model logging options.
88     /// @param event_logger The model event logger.
89     /// @param error   On failure, error is filled with the failure information.
90     /// @retval `true` if the execution succeeded otherwise `false`.
91     virtual bool execute(Handle* handle,
92                          const std::vector<MultiArray>& args,
93                          const ModelLoggingOptions& logging_options,
94                          ModelEventLogger* event_logger,
95                          std::error_code& error) const noexcept = 0;
96 
97     /// Must return `true` if the delegate is available for execution otherwise
98     /// `false`.
99     virtual bool is_available() const noexcept = 0;
100 
101     /// Must returns the number of inputs and the number of outputs for the
102     /// specified handle.
103     ///
104     /// The returned pair's first value is the number of inputs and the second
105     /// value is the number of outputs.
106     ///
107     /// @param handle The model handle.
108     /// @retval A pair with the number of inputs and the number of outputs.
109     virtual std::pair<size_t, size_t> get_num_arguments(Handle* handle) const noexcept = 0;
110 
111     /// Checks if the model handle is valid.
112     ///
113     /// @param handle The model handle.
114     /// @retval `true` if the model handle is valid otherwise `false`.
115     virtual bool is_valid_handle(Handle* handle) const noexcept = 0;
116 
117     /// Must unload the CoreML model with the specified handle.
118     ///
119     /// The returned pair's first value is the number of inputs and the second
120     /// value is the number of outputs.
121     ///
122     /// @param handle The model handle.
123     virtual void destroy(Handle* handle) const noexcept = 0;
124 
125     /// Purges the models cache.
126     ///
127     /// Compiled models are stored on-disk to improve the model load time. The
128     /// method tries to remove all the models that are not currently in-use.
129     virtual bool purge_models_cache() const noexcept = 0;
130 
131     /// Returns a delegate implementation with the specified config.
132     ///
133     /// @param config The delegate config.
134     /// @retval A delegate implementation.
135     static std::shared_ptr<BackendDelegate> make(const Config& config);
136 };
137 
138 /// Constructs a `std::error_code` from  a`BackendDelegate::ErrorCode`.
139 ///
140 /// @param code The backend error code.
141 /// @retval A `std::error_code` constructed from
142 /// the`BackendDelegate::ErrorCode`.
make_error_code(BackendDelegate::ErrorCode code)143 inline std::error_code make_error_code(BackendDelegate::ErrorCode code) {
144     static BackendDelegate::ErrorCategory errorCategory;
145     return { static_cast<int>(code), errorCategory };
146 }
147 } // namespace executorchcoreml
148 
149 namespace std {
150 template <> struct is_error_code_enum<executorchcoreml::BackendDelegate::ErrorCode> : true_type { };
151 } // namespace std
152