xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/op_registration/op_allowlist.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
4 #ifdef TEMPLATE_SELECTIVE_BUILD
5 #include <ATen/selected_mobile_ops.h>
6 #endif
7 
8 /**
9  * This header implements functionality to build PyTorch with only a certain
10  * set of operators (+ dependencies) included.
11  *
12  * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
13  *   two ops will be included in your build.  The allowlist records operators
14  *   only, no overloads; if you include aten::add, all overloads of aten::add
15  *   will be included.
16  *
17  * Internally, this is done by removing the operator registration calls
18  * using compile time programming, and the linker will then prune all
19  * operator functions that weren't registered.
20  * See Note [Selective build] for more details
21  *
22  * WARNING: The allowlist mechanism doesn't work for all ways you could go about
23  * registering an operator.  If the dispatch key / operator name is not
24  * sufficiently obvious at compile time, then the allowlisting mechanism
25  * will fail (and the operator will be included in the binary anyway).
26  */
27 
28 #include <c10/util/string_view.h>
29 #include <c10/core/DispatchKey.h>
30 #include <c10/macros/Macros.h>
31 
32 
33 #if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
34 #include <ATen/record_function.h>
35 #endif
36 
37 namespace c10 {
38 
39 namespace impl {
40 
41 constexpr bool allowlist_contains(string_view allowlist, string_view item);  // Forward Declare
42 
/**
 * Compile-time query: is the build feature called `name` part of this build?
 *
 * Selective build mode (ENABLE_RECORD_KERNEL_FUNCTION_DTYPE not defined):
 *   - if TORCH_BUILD_FEATURE_ALLOWLIST is defined, the answer is whether
 *     `name` is one of the entries of that stringized allowlist;
 *   - otherwise nothing is filtered and every feature is reported available.
 *
 * Instrumenting mode (tracing mode): always returns true and has no side
 * effects.
 */
constexpr bool is_build_feature_available(const char* name) {
#if defined(TORCH_BUILD_FEATURE_ALLOWLIST) && \
    !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
  // Selective build with an explicit feature allowlist.
  return allowlist_contains(C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), name);
#else
  // Tracing mode, or selective build without an allowlist: `name` is not
  // consulted at all, so silence the unused-parameter warning.
  (void)name;
  return true;
#endif
}
68 
69 [[noreturn]] void build_feature_required_feature_not_available(const char* feature);
70 
/**
 * Use the BUILD_FEATURE_REQUIRED macro in user code.
 *
 * Selective build mode:
 *   - With TORCH_BUILD_FEATURE_ALLOWLIST defined, the macro is a no-op when
 *     the named feature is available. When it is not, the macro calls the
 *     [[noreturn]] build_feature_required_feature_not_available(), which
 *     throws an exception (c10::Error). Because that call never returns, the
 *     compiler is able to perform dead code elimination on the code that
 *     follows the macro.
 *   - Without an allowlist every feature is trivially selected and the macro
 *     expands to nothing.
 *
 * Instrumenting mode (tracing mode): registers, as a side effect, that this
 * specific build feature was triggered.
 *
 * Both helpers are spelled with a leading `::` so the expansion always refers
 * to the global c10 namespace, regardless of the scope the macro is used in.
 */
#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)  // selective build mode

#if defined(TORCH_BUILD_FEATURE_ALLOWLIST)
#define BUILD_FEATURE_REQUIRED(NAME)                                 \
  if (!::c10::impl::is_build_feature_available(NAME)) {              \
    ::c10::impl::build_feature_required_feature_not_available(NAME); \
  }
#else  // Everything trivially selected
#define BUILD_FEATURE_REQUIRED(NAME)

#endif

#else  // trace mode
#define BUILD_FEATURE_REQUIRED(NAME)  \
  RECORD_FUNCTION_WITH_SCOPE(         \
      at::RecordScope::BUILD_FEATURE, \
      std::string(NAME),              \
      {});
#endif
101 
// Call sites should go through this macro rather than calling
// is_build_feature_available() directly.
#define BUILD_FEATURE_AVAILABLE(NAME) \
  ::c10::impl::is_build_feature_available(NAME)
104 
// Returns true iff `item` is exactly equal to one of the ';'-separated
// entries of `allowlist`, e.g.
//   allowlist_contains("a;bc;d", "bc") == true
//   allowlist_contains("a;bc;d", "b")  == false
// Entries are compared verbatim (no trimming). An empty entry -- an empty
// allowlist, a trailing ';', or two adjacent ';' -- matches an empty item.
constexpr bool allowlist_contains(string_view allowlist, string_view item) {
  size_t entry_begin = 0;
  while (entry_begin <= allowlist.size()) {
    const size_t separator = allowlist.find(';', entry_begin);
    if (separator == string_view::npos) {
      // No further separator: whatever remains is the final entry.
      return allowlist.substr(entry_begin).compare(item) == 0;
    }
    if (allowlist.substr(entry_begin, separator - entry_begin).compare(item) == 0) {
      return true;
    }
    // Continue with the entry that starts right after the separator.
    entry_begin = separator + 1;
  }
  return false;
}
127 
// Decides whether the operator named `op_name` is on the allowlist and
// therefore should be registered.
//
// `op_name` is expected to be a namespaced operator name (it must contain
// "::") and not a full schema (it must not contain "("); both expectations
// are checked with assert().
//
// When TORCH_OPERATOR_WHITELIST is not defined, all ops are to be registered
// and the function returns true.
constexpr bool op_allowlist_check([[maybe_unused]] string_view op_name) {
  // assert() is used instead of throw() due to a gcc bug. See:
  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
  // https://github.com/fmtlib/fmt/issues/682
  assert(op_name.find("::") != string_view::npos);
  assert(op_name.find("(") == string_view::npos);
#if defined(TORCH_OPERATOR_WHITELIST)
  // This function is majorly used for mobile selective build with root
  // operators, where the overload is included in the allowlist, so the name
  // is matched as given.  Stripping the overload first
  // (OperatorNameView::parse(op_name).name) is intentionally not done here;
  // another function based on this one may be added when there is usage on
  // op names without overload.
  return allowlist_contains(C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), op_name);
#else
  return true;
#endif
}
152 
153 // Returns true iff the given schema string is on the allowlist
154 // and should be registered
schema_allowlist_check(string_view schema)155 constexpr bool schema_allowlist_check(string_view schema) {
156 #if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
157   return true;
158 #else
159   return op_allowlist_check(schema.substr(0, schema.find("(")));
160 #endif
161 }
162 
// Decides whether the custom class named `custom_class_name` is on the
// allowlist and should be registered.
constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
#if defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
  // Match the name against the entries of the stringized allowlist macro.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), custom_class_name);
#else
  // No TORCH_CUSTOM_CLASS_ALLOWLIST: all custom classes are to be registered,
  // and the name itself is never inspected.
  (void)custom_class_name;
  return true;
#endif
}
177 
178 // schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
179 // Add this API to pass arbitrary allowlist.
op_allowlist_contains_name_in_schema(string_view allowlist,string_view schema)180 constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {
181   return allowlist_contains(allowlist, schema.substr(0, schema.find("(")));
182 }
183 
184 // Returns true iff the given dispatch key is on the allowlist
185 // and should be registered.  When we turn this on, the list of valid
186 // mobile dispatch keys is hard coded (but you need to make sure
187 // that you have the correct set of dispatch keys for this).
dispatch_key_allowlist_check(DispatchKey)188 constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
189 #ifdef C10_MOBILE
190   return true;
191   // Disabled for now: to be enabled later!
192   // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
193 #else
194   return true;
195 #endif
196 }
197 
198 } // namespace impl
199 } // namespace c10
200