1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
17
18 #include <memory>
19
20 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
21
22 namespace pjrt {
23
MakeClientDeleter(const PJRT_Api * api)24 PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) {
25 return [api](PJRT_Client* client) -> void {
26 PJRT_Client_Destroy_Args destroy_args;
27 destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE;
28 destroy_args.priv = nullptr;
29 destroy_args.client = client;
30
31 PJRT_Error* error = api->PJRT_Client_Destroy(&destroy_args);
32 // TODO(b/236710439): handle the error and remove this CHECK() call
33 CHECK(error == nullptr);
34 };
35 }
36
MakeErrorDeleter(const PJRT_Api * api)37 PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) {
38 return [api](PJRT_Error* error) -> void {
39 PJRT_Error_Destroy_Args destroy_args;
40 destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
41 destroy_args.priv = nullptr;
42 destroy_args.error = error;
43
44 api->PJRT_Error_Destroy(&destroy_args);
45 };
46 }
47
MakeBufferDeleter(const PJRT_Api * api)48 PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api) {
49 return [api](PJRT_Buffer* buffer) -> void {
50 PJRT_Buffer_Destroy_Args destroy_args;
51 destroy_args.struct_size = PJRT_Buffer_Destroy_Args_STRUCT_SIZE;
52 destroy_args.priv = nullptr;
53 destroy_args.buffer = buffer;
54
55 pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Destroy(&destroy_args), api);
56 };
57 }
58
PjrtErrorToStatus(const PJRT_Error * error,const PJRT_Api * api)59 xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api) {
60 xla::Status status;
61 if (error != nullptr) {
62 status = xla::Status(PjrtErrorToStatusCode(error, api),
63 GetPjrtErrorMessage(error, api));
64 }
65 return status;
66 }
67
PjrtErrorToStatusCode(const PJRT_Error * error,const PJRT_Api * api)68 tensorflow::error::Code PjrtErrorToStatusCode(const PJRT_Error* error,
69 const PJRT_Api* api) {
70 PJRT_Error_GetCode_Args args;
71 args.struct_size = PJRT_Error_GetCode_Args_STRUCT_SIZE;
72 args.priv = nullptr;
73 args.error = error;
74 api->PJRT_Error_GetCode(&args);
75 PJRT_Error_Code code = args.code;
76 switch (code) {
77 case PJRT_Error_Code_CANCELLED:
78 return tensorflow::error::CANCELLED;
79 case PJRT_Error_Code_UNKNOWN:
80 return tensorflow::error::UNKNOWN;
81 case PJRT_Error_Code_INVALID_ARGUMENT:
82 return tensorflow::error::INVALID_ARGUMENT;
83 case PJRT_Error_Code_DEADLINE_EXCEEDED:
84 return tensorflow::error::DEADLINE_EXCEEDED;
85 case PJRT_Error_Code_NOT_FOUND:
86 return tensorflow::error::NOT_FOUND;
87 case PJRT_Error_Code_ALREADY_EXISTS:
88 return tensorflow::error::ALREADY_EXISTS;
89 case PJRT_Error_Code_PERMISSION_DENIED:
90 return tensorflow::error::PERMISSION_DENIED;
91 case PJRT_Error_Code_RESOURCE_EXHAUSTED:
92 return tensorflow::error::RESOURCE_EXHAUSTED;
93 case PJRT_Error_Code_FAILED_PRECONDITION:
94 return tensorflow::error::FAILED_PRECONDITION;
95 case PJRT_Error_Code_ABORTED:
96 return tensorflow::error::ABORTED;
97 case PJRT_Error_Code_OUT_OF_RANGE:
98 return tensorflow::error::OUT_OF_RANGE;
99 case PJRT_Error_Code_UNIMPLEMENTED:
100 return tensorflow::error::UNIMPLEMENTED;
101 case PJRT_Error_Code_INTERNAL:
102 return tensorflow::error::INTERNAL;
103 case PJRT_Error_Code_UNAVAILABLE:
104 return tensorflow::error::UNAVAILABLE;
105 case PJRT_Error_Code_DATA_LOSS:
106 return tensorflow::error::DATA_LOSS;
107 case PJRT_Error_Code_UNAUTHENTICATED:
108 return tensorflow::error::UNAUTHENTICATED;
109 }
110 }
111
StatusCodeToPjrtErrorCode(tensorflow::error::Code code)112 PJRT_Error_Code StatusCodeToPjrtErrorCode(tensorflow::error::Code code) {
113 switch (code) {
114 case tensorflow::error::CANCELLED:
115 return PJRT_Error_Code::PJRT_Error_Code_CANCELLED;
116 case tensorflow::error::UNKNOWN:
117 return PJRT_Error_Code::PJRT_Error_Code_UNKNOWN;
118 case tensorflow::error::INVALID_ARGUMENT:
119 return PJRT_Error_Code::PJRT_Error_Code_INVALID_ARGUMENT;
120 case tensorflow::error::DEADLINE_EXCEEDED:
121 return PJRT_Error_Code::PJRT_Error_Code_DEADLINE_EXCEEDED;
122 case tensorflow::error::NOT_FOUND:
123 return PJRT_Error_Code::PJRT_Error_Code_NOT_FOUND;
124 case tensorflow::error::ALREADY_EXISTS:
125 return PJRT_Error_Code::PJRT_Error_Code_ALREADY_EXISTS;
126 case tensorflow::error::PERMISSION_DENIED:
127 return PJRT_Error_Code::PJRT_Error_Code_PERMISSION_DENIED;
128 case tensorflow::error::UNAUTHENTICATED:
129 return PJRT_Error_Code::PJRT_Error_Code_UNAUTHENTICATED;
130 case tensorflow::error::RESOURCE_EXHAUSTED:
131 return PJRT_Error_Code::PJRT_Error_Code_RESOURCE_EXHAUSTED;
132 case tensorflow::error::FAILED_PRECONDITION:
133 return PJRT_Error_Code::PJRT_Error_Code_FAILED_PRECONDITION;
134 case tensorflow::error::ABORTED:
135 return PJRT_Error_Code::PJRT_Error_Code_ABORTED;
136 case tensorflow::error::OUT_OF_RANGE:
137 return PJRT_Error_Code::PJRT_Error_Code_OUT_OF_RANGE;
138 case tensorflow::error::UNIMPLEMENTED:
139 return PJRT_Error_Code::PJRT_Error_Code_UNIMPLEMENTED;
140 case tensorflow::error::INTERNAL:
141 return PJRT_Error_Code::PJRT_Error_Code_INTERNAL;
142 case tensorflow::error::UNAVAILABLE:
143 return PJRT_Error_Code::PJRT_Error_Code_UNAVAILABLE;
144 case tensorflow::error::DATA_LOSS:
145 return PJRT_Error_Code::PJRT_Error_Code_DATA_LOSS;
146 case tensorflow::error::OK:
147 CHECK(false) << "Status::OK() cannot be converted to PJRT_Error code, "
148 "use nullptr instead";
149 case tensorflow::error::
150 DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_:
151 CHECK(false) << "got DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_"
152 "USE_DEFAULT_IN_SWITCH_INSTEAD_";
153 case tensorflow::error::Code_INT_MIN_SENTINEL_DO_NOT_USE_:
154 CHECK(false) << "got Code_INT_MIN_SENTINEL_DO_NOT_USE_";
155 case tensorflow::error::Code_INT_MAX_SENTINEL_DO_NOT_USE_:
156 CHECK(false) << "got Code_INT_MAX_SENTINEL_DO_NOT_USE_";
157 }
158 }
159
GetPjrtErrorMessage(const PJRT_Error * error,const PJRT_Api * api)160 absl::string_view GetPjrtErrorMessage(const PJRT_Error* error,
161 const PJRT_Api* api) {
162 PJRT_Error_Message_Args message_args;
163 message_args.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
164 message_args.priv = nullptr;
165 message_args.error = error;
166 api->PJRT_Error_Message(&message_args);
167 return absl::string_view(message_args.message, message_args.message_size);
168 }
169
LogFatalIfPjrtError(PJRT_Error * error,const PJRT_Api * api)170 void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api) {
171 std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> _error(
172 error, MakeErrorDeleter(api));
173 xla::Status _status = PjrtErrorToStatus(_error.get(), api);
174 if (!_status.ok()) {
175 LOG(FATAL) << "Unexpected error status " << _status.error_message();
176 }
177 }
178
MakeEventDeleter(const PJRT_Api * api)179 PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api) {
180 CHECK(api != nullptr);
181 return [api](PJRT_Event* managed) {
182 PJRT_Event_Destroy_Args args;
183 args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
184 args.priv = nullptr;
185 args.event = managed;
186
187 LogFatalIfPjrtError(api->PJRT_Event_Destroy(&args), api);
188 };
189 }
190
191 } // namespace pjrt
192