1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h" 23 #include "tensorflow/compiler/xla/status.h" 24 25 namespace pjrt { 26 27 using PJRT_ClientDeleter = std::function<void(PJRT_Client*)>; 28 29 // Pass in an API pointer; receive a custom deleter for smart pointers. 30 // The lifetime of the Api pointed to must be longer than the client. 31 PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api); 32 33 using PJRT_ErrorDeleter = std::function<void(PJRT_Error*)>; 34 35 // Pass in an API pointer; receive a custom deleter for smart pointers. 36 // The lifetime of the Api pointed to must be longer than the error. 37 PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api); 38 39 using PJRT_BufferDeleter = std::function<void(PJRT_Buffer*)>; 40 41 // Pass in an API pointer; receive a custom deleter for smart pointers. 42 // The lifetime of the Api pointed to must be longer than the buffer. 43 PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api); 44 45 // Fatal error logging if status is not success. This terminates the process 46 // and frees the PJRT_Error passed in. 47 void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api); 48 49 absl::string_view GetPjrtErrorMessage(const PJRT_Error* error, 50 const PJRT_Api* api); 51 52 xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api); 53 54 tensorflow::error::Code PjrtErrorToStatusCode(const PJRT_Error* error, 55 const PJRT_Api* api); 56 57 PJRT_Error_Code StatusCodeToPjrtErrorCode(tensorflow::error::Code code); 58 59 using PJRT_EventDeleter = std::function<void(PJRT_Event*)>; 60 61 // Pass in an API pointer; receive a custom deleter for smart pointers. 62 // The lifetime of the Api pointed to must be longer than the event. 63 PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api); 64 65 } // namespace pjrt 66 67 #endif // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_ 68