xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/compiler/xla/pjrt/c/pjrt_c_api.h"
21 
22 namespace pjrt {
23 
MakeClientDeleter(const PJRT_Api * api)24 PJRT_ClientDeleter MakeClientDeleter(const PJRT_Api* api) {
25   return [api](PJRT_Client* client) -> void {
26     PJRT_Client_Destroy_Args destroy_args;
27     destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE;
28     destroy_args.priv = nullptr;
29     destroy_args.client = client;
30 
31     PJRT_Error* error = api->PJRT_Client_Destroy(&destroy_args);
32     // TODO(b/236710439): handle the error and remove this CHECK() call
33     CHECK(error == nullptr);
34   };
35 }
36 
MakeErrorDeleter(const PJRT_Api * api)37 PJRT_ErrorDeleter MakeErrorDeleter(const PJRT_Api* api) {
38   return [api](PJRT_Error* error) -> void {
39     PJRT_Error_Destroy_Args destroy_args;
40     destroy_args.struct_size = PJRT_Error_Destroy_Args_STRUCT_SIZE;
41     destroy_args.priv = nullptr;
42     destroy_args.error = error;
43 
44     api->PJRT_Error_Destroy(&destroy_args);
45   };
46 }
47 
MakeBufferDeleter(const PJRT_Api * api)48 PJRT_BufferDeleter MakeBufferDeleter(const PJRT_Api* api) {
49   return [api](PJRT_Buffer* buffer) -> void {
50     PJRT_Buffer_Destroy_Args destroy_args;
51     destroy_args.struct_size = PJRT_Buffer_Destroy_Args_STRUCT_SIZE;
52     destroy_args.priv = nullptr;
53     destroy_args.buffer = buffer;
54 
55     pjrt::LogFatalIfPjrtError(api->PJRT_Buffer_Destroy(&destroy_args), api);
56   };
57 }
58 
PjrtErrorToStatus(const PJRT_Error * error,const PJRT_Api * api)59 xla::Status PjrtErrorToStatus(const PJRT_Error* error, const PJRT_Api* api) {
60   xla::Status status;
61   if (error != nullptr) {
62     status = xla::Status(PjrtErrorToStatusCode(error, api),
63                          GetPjrtErrorMessage(error, api));
64   }
65   return status;
66 }
67 
PjrtErrorToStatusCode(const PJRT_Error * error,const PJRT_Api * api)68 tensorflow::error::Code PjrtErrorToStatusCode(const PJRT_Error* error,
69                                               const PJRT_Api* api) {
70   PJRT_Error_GetCode_Args args;
71   args.struct_size = PJRT_Error_GetCode_Args_STRUCT_SIZE;
72   args.priv = nullptr;
73   args.error = error;
74   api->PJRT_Error_GetCode(&args);
75   PJRT_Error_Code code = args.code;
76   switch (code) {
77     case PJRT_Error_Code_CANCELLED:
78       return tensorflow::error::CANCELLED;
79     case PJRT_Error_Code_UNKNOWN:
80       return tensorflow::error::UNKNOWN;
81     case PJRT_Error_Code_INVALID_ARGUMENT:
82       return tensorflow::error::INVALID_ARGUMENT;
83     case PJRT_Error_Code_DEADLINE_EXCEEDED:
84       return tensorflow::error::DEADLINE_EXCEEDED;
85     case PJRT_Error_Code_NOT_FOUND:
86       return tensorflow::error::NOT_FOUND;
87     case PJRT_Error_Code_ALREADY_EXISTS:
88       return tensorflow::error::ALREADY_EXISTS;
89     case PJRT_Error_Code_PERMISSION_DENIED:
90       return tensorflow::error::PERMISSION_DENIED;
91     case PJRT_Error_Code_RESOURCE_EXHAUSTED:
92       return tensorflow::error::RESOURCE_EXHAUSTED;
93     case PJRT_Error_Code_FAILED_PRECONDITION:
94       return tensorflow::error::FAILED_PRECONDITION;
95     case PJRT_Error_Code_ABORTED:
96       return tensorflow::error::ABORTED;
97     case PJRT_Error_Code_OUT_OF_RANGE:
98       return tensorflow::error::OUT_OF_RANGE;
99     case PJRT_Error_Code_UNIMPLEMENTED:
100       return tensorflow::error::UNIMPLEMENTED;
101     case PJRT_Error_Code_INTERNAL:
102       return tensorflow::error::INTERNAL;
103     case PJRT_Error_Code_UNAVAILABLE:
104       return tensorflow::error::UNAVAILABLE;
105     case PJRT_Error_Code_DATA_LOSS:
106       return tensorflow::error::DATA_LOSS;
107     case PJRT_Error_Code_UNAUTHENTICATED:
108       return tensorflow::error::UNAUTHENTICATED;
109   }
110 }
111 
StatusCodeToPjrtErrorCode(tensorflow::error::Code code)112 PJRT_Error_Code StatusCodeToPjrtErrorCode(tensorflow::error::Code code) {
113   switch (code) {
114     case tensorflow::error::CANCELLED:
115       return PJRT_Error_Code::PJRT_Error_Code_CANCELLED;
116     case tensorflow::error::UNKNOWN:
117       return PJRT_Error_Code::PJRT_Error_Code_UNKNOWN;
118     case tensorflow::error::INVALID_ARGUMENT:
119       return PJRT_Error_Code::PJRT_Error_Code_INVALID_ARGUMENT;
120     case tensorflow::error::DEADLINE_EXCEEDED:
121       return PJRT_Error_Code::PJRT_Error_Code_DEADLINE_EXCEEDED;
122     case tensorflow::error::NOT_FOUND:
123       return PJRT_Error_Code::PJRT_Error_Code_NOT_FOUND;
124     case tensorflow::error::ALREADY_EXISTS:
125       return PJRT_Error_Code::PJRT_Error_Code_ALREADY_EXISTS;
126     case tensorflow::error::PERMISSION_DENIED:
127       return PJRT_Error_Code::PJRT_Error_Code_PERMISSION_DENIED;
128     case tensorflow::error::UNAUTHENTICATED:
129       return PJRT_Error_Code::PJRT_Error_Code_UNAUTHENTICATED;
130     case tensorflow::error::RESOURCE_EXHAUSTED:
131       return PJRT_Error_Code::PJRT_Error_Code_RESOURCE_EXHAUSTED;
132     case tensorflow::error::FAILED_PRECONDITION:
133       return PJRT_Error_Code::PJRT_Error_Code_FAILED_PRECONDITION;
134     case tensorflow::error::ABORTED:
135       return PJRT_Error_Code::PJRT_Error_Code_ABORTED;
136     case tensorflow::error::OUT_OF_RANGE:
137       return PJRT_Error_Code::PJRT_Error_Code_OUT_OF_RANGE;
138     case tensorflow::error::UNIMPLEMENTED:
139       return PJRT_Error_Code::PJRT_Error_Code_UNIMPLEMENTED;
140     case tensorflow::error::INTERNAL:
141       return PJRT_Error_Code::PJRT_Error_Code_INTERNAL;
142     case tensorflow::error::UNAVAILABLE:
143       return PJRT_Error_Code::PJRT_Error_Code_UNAVAILABLE;
144     case tensorflow::error::DATA_LOSS:
145       return PJRT_Error_Code::PJRT_Error_Code_DATA_LOSS;
146     case tensorflow::error::OK:
147       CHECK(false) << "Status::OK() cannot be converted to PJRT_Error code, "
148                       "use nullptr instead";
149     case tensorflow::error::
150         DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_:
151       CHECK(false) << "got DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_"
152                       "USE_DEFAULT_IN_SWITCH_INSTEAD_";
153     case tensorflow::error::Code_INT_MIN_SENTINEL_DO_NOT_USE_:
154       CHECK(false) << "got Code_INT_MIN_SENTINEL_DO_NOT_USE_";
155     case tensorflow::error::Code_INT_MAX_SENTINEL_DO_NOT_USE_:
156       CHECK(false) << "got Code_INT_MAX_SENTINEL_DO_NOT_USE_";
157   }
158 }
159 
GetPjrtErrorMessage(const PJRT_Error * error,const PJRT_Api * api)160 absl::string_view GetPjrtErrorMessage(const PJRT_Error* error,
161                                       const PJRT_Api* api) {
162   PJRT_Error_Message_Args message_args;
163   message_args.struct_size = PJRT_Error_Message_Args_STRUCT_SIZE;
164   message_args.priv = nullptr;
165   message_args.error = error;
166   api->PJRT_Error_Message(&message_args);
167   return absl::string_view(message_args.message, message_args.message_size);
168 }
169 
LogFatalIfPjrtError(PJRT_Error * error,const PJRT_Api * api)170 void LogFatalIfPjrtError(PJRT_Error* error, const PJRT_Api* api) {
171   std::unique_ptr<PJRT_Error, pjrt::PJRT_ErrorDeleter> _error(
172       error, MakeErrorDeleter(api));
173   xla::Status _status = PjrtErrorToStatus(_error.get(), api);
174   if (!_status.ok()) {
175     LOG(FATAL) << "Unexpected error status " << _status.error_message();
176   }
177 }
178 
MakeEventDeleter(const PJRT_Api * api)179 PJRT_EventDeleter MakeEventDeleter(const PJRT_Api* api) {
180   CHECK(api != nullptr);
181   return [api](PJRT_Event* managed) {
182     PJRT_Event_Destroy_Args args;
183     args.struct_size = PJRT_Event_Destroy_Args_STRUCT_SIZE;
184     args.priv = nullptr;
185     args.event = managed;
186 
187     LogFatalIfPjrtError(api->PJRT_Event_Destroy(&args), api);
188   };
189 }
190 
191 }  // namespace pjrt
192