1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_RUNTIME_CUSTOM_CALL_REGISTRY_H_ 17 #define XLA_RUNTIME_CUSTOM_CALL_REGISTRY_H_ 18 19 #include <memory> 20 21 #include "llvm/ADT/StringRef.h" 22 #include "tensorflow/compiler/xla/runtime/custom_call.h" 23 24 namespace xla { 25 namespace runtime { 26 27 // Custom call registry is a container for the custom calls that looks up the 28 // handler implementing the custom call by name at run time. It is used to 29 // implement a generic `rt.custom_call` runtime intrinsic. 30 // 31 // For low overhead custom calls prefer direct custom calls that linked with the 32 // compiled executable and bypass by-name look up (see DirectCustomCallLibrary). 33 // 34 // TODO(ezhulenev): Consider removing this registry, because we'll likely not 35 // need it for any of the practical purposes, and it's currently used only in 36 // tests. We also likely don't need the generic custom call API. 37 class CustomCallRegistry { 38 public: 39 // The type for custom call registration functions. 40 using RegistrationFunction = void (*)(CustomCallRegistry*); 41 42 CustomCallRegistry(); 43 ~CustomCallRegistry() = default; 44 45 CustomCallRegistry(const CustomCallRegistry&) = delete; 46 CustomCallRegistry& operator=(const CustomCallRegistry&) = delete; 47 48 void Register(std::unique_ptr<class CustomCall> custom_call); 49 50 class CustomCall* Find(llvm::StringRef callee) const; 51 52 private: 53 class Impl; 54 std::unique_ptr<Impl> impl_; 55 }; 56 57 // TODO(ezhulenev): Remove static registration mechanism, and pass custom call 58 // registry explicitly to the JitExecutable. It should be passed around as a 59 // part of XLA runtime execution context. 60 61 // Use this macro to add a function that will register custom calls that are 62 // statically linked in the binary. FUNC should be a function pointer with the 63 // prototype given by the CustomCallRegistry::RegistrationFunction alias. 64 #define XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION(FUNC) \ 65 XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION_IMPL(FUNC, __COUNTER__) 66 #define XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION_IMPL(FUNC, N) \ 67 XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION_IMPL_EXPAND(FUNC, N) 68 #define XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION_IMPL_EXPAND(FUNC, N) \ 69 static bool XLA_RUNTIME_static_custom_call_##N##_registered_ = []() { \ 70 ::xla::runtime::AddStaticCustomCallRegistration(FUNC); \ 71 return true; \ 72 }() 73 74 // Registers all statically linked custom calls in the given registry. 75 void RegisterStaticCustomCalls(CustomCallRegistry* custom_call_registry); 76 77 // Adds a custom call registration function to the registry. This should not be 78 // used directly; use XLA_RUNTIME_STATIC_CUSTOM_CALL_REGISTRATION instead. 79 void AddStaticCustomCallRegistration( 80 CustomCallRegistry::RegistrationFunction registration); 81 82 } // namespace runtime 83 } // namespace xla 84 85 #endif // XLA_RUNTIME_CUSTOM_CALL_REGISTRY_H_ 86