xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
23 #include "tensorflow/compiler/xla/status.h"
24 
25 namespace pjrt {
26 
27 using PJRT_ClientDeleter = std::function<void(PJRT_Client*)>;
28 
29 // Pass in an API pointer; receive a custom deleter for smart pointers.
30 // The lifetime of the Api pointed to must be longer than the client.
31 PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api);
32 
33 using PJRT_ErrorDeleter = std::function<void(PJRT_Error*)>;
34 
35 // Pass in an API pointer; receive a custom deleter for smart pointers.
36 // The lifetime of the Api pointed to must be longer than the error.
37 PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api);
38 
39 using PJRT_BufferDeleter = std::function<void(PJRT_Buffer*)>;
40 
41 // Pass in an API pointer; receive a custom deleter for smart pointers.
42 // The lifetime of the Api pointed to must be longer than the buffer.
43 PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api);
44 
45 // Fatal error logging if status is not success. This terminates the process
46 // and frees the PJRT_Error passed in.
47 void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api);
48 
49 absl::string_view GetPjrtErrorMessage(const PJRT_Error* error,
50                                       const PJRT_Api* api);
51 
52 xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api);
53 
54 tensorflow::error::Code PjrtErrorToStatusCode(const PJRT_Error* error,
55                                               const PJRT_Api* api);
56 
57 PJRT_Error_Code StatusCodeToPjrtErrorCode(tensorflow::error::Code code);
58 
59 using PJRT_EventDeleter = std::function<void(PJRT_Event*)>;
60 
61 // Pass in an API pointer; receive a custom deleter for smart pointers.
62 // The lifetime of the Api pointed to must be longer than the event.
63 PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api);
64 
65 }  // namespace pjrt
66 
67 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_C_PJRT_C_API_HELPERS_H_
68