1 #pragma once
2
3 // TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
4 #ifdef TEMPLATE_SELECTIVE_BUILD
5 #include <ATen/selected_mobile_ops.h>
6 #endif
7
8 /**
9 * This header implements functionality to build PyTorch with only a certain
10 * set of operators (+ dependencies) included.
11 *
12 * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
13 * two ops will be included in your build. The allowlist records operators
14 * only, no overloads; if you include aten::add, all overloads of aten::add
15 * will be included.
16 *
17 * Internally, this is done by removing the operator registration calls
18 * using compile time programming, and the linker will then prune all
19 * operator functions that weren't registered.
20 * See Note [Selective build] for more details
21 *
22 * WARNING: The allowlist mechanism doesn't work for all ways you could go about
23 * registering an operator. If the dispatch key / operator name is not
24 * sufficiently obvious at compile time, then the allowlisting mechanism
25 * will fail (and the operator will be included in the binary anyway).
26 */
27
28 #include <c10/util/string_view.h>
29 #include <c10/core/DispatchKey.h>
30 #include <c10/macros/Macros.h>
31
32
33 #if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
34 #include <ATen/record_function.h>
35 #endif
36
37 namespace c10 {
38
39 namespace impl {
40
41 constexpr bool allowlist_contains(string_view allowlist, string_view item); // Forward Declare
42
/**
 * Reports whether the build feature `name` is available in this binary.
 *
 * - Instrumenting / tracing mode (ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is
 *   defined): always true, with no side effects.
 * - Selective build mode without TORCH_BUILD_FEATURE_ALLOWLIST: always true.
 * - Selective build mode with TORCH_BUILD_FEATURE_ALLOWLIST: true iff
 *   allowlist_contains() finds `name` in the stringized allowlist.
 *
 * User code should go through BUILD_FEATURE_AVAILABLE rather than calling
 * this function directly.
 */
constexpr bool is_build_feature_available(const char* name) {
#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) && \
    defined(TORCH_BUILD_FEATURE_ALLOWLIST)
  // Selective build with an explicit feature allowlist.
  return allowlist_contains(C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), name);
#else
  // Instrumenting mode, or selective build without an allowlist: nothing is
  // filtered. The cast only marks the parameter as intentionally unused.
  (void)name;
  return true;
#endif
}
68
// Failure path invoked by BUILD_FEATURE_REQUIRED when a feature is missing.
// Only declared in this header (no definition is visible here); it never
// returns to its caller.
[[noreturn]] void build_feature_required_feature_not_available(const char* feature);

/**
 * Use the BUILD_FEATURE_REQUIRED macro in user code.
 *
 * Selective build mode (ENABLE_RECORD_KERNEL_FUNCTION_DTYPE not defined):
 *  - with TORCH_BUILD_FEATURE_ALLOWLIST: a no-op if the feature passed in is
 *    available; otherwise throws an exception (c10::Error). Because the
 *    failure path is [[noreturn]], the compiler is able to perform dead code
 *    elimination for code following the macro when the feature is not
 *    available.
 *  - without an allowlist: expands to nothing (everything trivially
 *    selected).
 *
 * Instrumenting mode (tracing mode): registers, as a side effect, that this
 * specific build feature was triggered.
 *
 * Every name in the expansions is anchored with a leading `::` where it
 * refers to c10, so the macro resolves to the same functions even when it is
 * expanded inside a namespace that has its own nested `c10`.
 *
 * NOTE: the expansions are complete statements (or empty); callers do not
 * need to add a trailing semicolon.
 */
#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // trace mode
#define BUILD_FEATURE_REQUIRED(NAME)  \
  RECORD_FUNCTION_WITH_SCOPE(         \
      at::RecordScope::BUILD_FEATURE, \
      std::string(NAME),              \
      {});
#elif defined(TORCH_BUILD_FEATURE_ALLOWLIST) // selective build, allowlist given
#define BUILD_FEATURE_REQUIRED(NAME)                                 \
  if (!::c10::impl::is_build_feature_available(NAME)) {              \
    ::c10::impl::build_feature_required_feature_not_available(NAME); \
  }
#else // selective build, everything trivially selected
#define BUILD_FEATURE_REQUIRED(NAME)
#endif

// Use this macro, and not is_build_feature_available, to query a feature.
// It expands to a call of the constexpr ::c10::impl::is_build_feature_available.
#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
104
// Returns true iff `allowlist`, a ';'-separated list of entries, contains an
// entry that is exactly equal to `item`.
//   allowlist_contains("a;bc;d", "bc") == true
//   allowlist_contains("a;bc;d", "b")  == false
// Entries are compared verbatim (no trimming). An empty allowlist, a trailing
// ';' or two adjacent ';' each yield an empty entry, which matches an empty
// `item`.
constexpr bool allowlist_contains(string_view allowlist, string_view item) {
  // `cur` is the start of the entry being examined; it never exceeds
  // allowlist.size() because it is only ever set to one past a ';' that was
  // found inside the string.
  size_t cur = 0;
  for (;;) {
    const size_t sep = allowlist.find(';', cur);
    if (sep == string_view::npos) {
      // No further separator: the remainder is the last (or only) entry.
      return allowlist.substr(cur).compare(item) == 0;
    }
    if (allowlist.substr(cur, sep - cur).compare(item) == 0) {
      return true;
    }
    cur = sep + 1; // continue right after the separator
  }
}
127
// Decides whether the operator called `op_name` is on the operator allowlist
// and should therefore be registered.
//
// `op_name` is expected to contain a "::" namespace separator and no "("
// (i.e. it is an operator name, not a full schema string); both expectations
// are checked with assert().
//
// Without TORCH_OPERATOR_WHITELIST every operator is accepted. With it,
// `op_name` is looked up verbatim in the stringized allowlist.
constexpr bool op_allowlist_check(string_view op_name [[maybe_unused]]) {
  // assert() is used instead of throw() due to a gcc bug. See:
  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
  // https://github.com/fmtlib/fmt/issues/682
  assert(op_name.find("::") != string_view::npos);
  assert(op_name.find("(") == string_view::npos);
#if defined(TORCH_OPERATOR_WHITELIST)
  // This function is mostly used for mobile selective build with root
  // operators, where the overload is included in the allowlist, so the
  // overload name is deliberately NOT stripped before the lookup.
  // Should a lookup on names without overloads be needed, another function
  // based on this one may be added; it would pass
  //   OperatorNameView::parse(op_name).name
  // instead of op_name.
  return allowlist_contains(C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), op_name);
#else
  // No operator allowlist configured: all ops are to be registered.
  return true;
#endif
}
152
153 // Returns true iff the given schema string is on the allowlist
154 // and should be registered
schema_allowlist_check(string_view schema)155 constexpr bool schema_allowlist_check(string_view schema) {
156 #if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
157 return true;
158 #else
159 return op_allowlist_check(schema.substr(0, schema.find("(")));
160 #endif
161 }
162
// Decides whether the custom class called `custom_class_name` is on the
// custom-class allowlist and should therefore be registered.
constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
#if defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
  // Look the name up verbatim in the stringized allowlist.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), custom_class_name);
#else
  // TORCH_CUSTOM_CLASS_ALLOWLIST is not defined, so all custom classes are
  // to be registered. The cast marks the parameter as intentionally unused.
  (void)custom_class_name;
  return true;
#endif
}
177
178 // schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
179 // Add this API to pass arbitrary allowlist.
op_allowlist_contains_name_in_schema(string_view allowlist,string_view schema)180 constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {
181 return allowlist_contains(allowlist, schema.substr(0, schema.find("(")));
182 }
183
184 // Returns true iff the given dispatch key is on the allowlist
185 // and should be registered. When we turn this on, the list of valid
186 // mobile dispatch keys is hard coded (but you need to make sure
187 // that you have the correct set of dispatch keys for this).
dispatch_key_allowlist_check(DispatchKey)188 constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
189 #ifdef C10_MOBILE
190 return true;
191 // Disabled for now: to be enabled later!
192 // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
193 #else
194 return true;
195 #endif
196 }
197
198 } // namespace impl
199 } // namespace c10
200