1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_H_
18
#include <utility>

#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
23
24 namespace xla {
25
26 // Helper that converts a failing StatusOr to an exception.
27 // For use only inside pybind11 code.
28 template <typename T>
ValueOrThrow(StatusOr<T> v)29 T ValueOrThrow(StatusOr<T> v) {
30 if (!v.ok()) {
31 throw xla::XlaRuntimeError(v.status());
32 }
33 return std::move(v).value();
34 }
35
36 } // namespace xla
37
38 // This namespace is a documented pybind11 extension point.
39 // Caution: Unusually for Google code, this code uses C++ exceptions because
40 // they are the only mechanism for reporting cast failures to pybind11. However,
41 // the exceptions are local to the binding code.
42 namespace pybind11 {
43 namespace detail {
44
45 // Status, StatusOr. Failing statuses become Python exceptions; Status::OK()
46 // becomes None.
47 template <>
48 struct type_caster<xla::Status> {
49 public:
50 PYBIND11_TYPE_CASTER(xla::Status, _("Status"));
51
52 static handle cast(xla::Status src, return_value_policy /* policy */,
53 handle /* parent */) {
54 if (!src.ok()) {
55 throw xla::XlaRuntimeError(src);
56 }
57 return none().inc_ref();
58 }
59 };
60
61 template <typename T>
62 struct type_caster<xla::StatusOr<T>> {
63 public:
64 using value_conv = make_caster<T>;
65
66 PYBIND11_TYPE_CASTER(xla::StatusOr<T>,
67 _("StatusOr[") + value_conv::name + _("]"));
68
69 static handle cast(xla::StatusOr<T> src, return_value_policy policy,
70 handle parent) {
71 if (!src.ok()) {
72 throw xla::XlaRuntimeError(src.status());
73 }
74 return value_conv::cast(std::forward<xla::StatusOr<T>>(src).ValueOrDie(),
75 policy, parent);
76 }
77 };
78
79 } // namespace detail
80 } // namespace pybind11
81
82 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_STATUS_CASTERS_H_
83