/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/callback.h"

#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace py = pybind11;

namespace xla {

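// Shared implementation for both PrepareAndCall overloads: wraps the raw
// argument buffers in read-only NumPy arrays (zero-copy), invokes the Python
// callable, and then copies or transposes each result array back into the
// output buffers provided by the XLA runtime.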
Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) {
  absl::Span<void* const> inputs(arg_ptrs, args_.size());
  absl::Span<void* const> outputs(reinterpret_cast<void**>(result),
                                  results_.size());

  py::gil_scoped_acquire gil;
  // Wrap each input buffer in a NumPy array without copying; token arguments
  // carry no data and are passed to the callable as None.
  py::tuple args(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (args_[i].type == xla::TOKEN) {
      args[i] = py::none();
    } else {
      args[i] = py::array(args_[i].dtype, args_[i].dims, args_[i].strides,
                          const_cast<void*>(inputs[i]));
      args[i].attr("flags").attr("writeable") = Py_False;
    }
  }

  TF_ASSIGN_OR_RETURN(auto result_tuple, CallInternal(std::move(args)));

  // Copy each result back into its output buffer. If the returned array's
  // strides already match the expected layout, a plain memcpy suffices;
  // otherwise a (cached) transpose plan relayouts the data.
  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    py::array array = py::cast<py::array>(std::move(output));
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    if (strides == results_[i].expected_strides) {
      std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes);
    } else {
      xla::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
          transpose_cache_.GetOrCreate(
              xla::primitive_util::ByteWidth(results_[i].type), dims,
              results_[i].reversed_layout,
              /*input_layout=*/xla::TransposePlan::Striding{strides});
      if (!plan.ok()) {
        return std::move(plan).status();
      }
      plan.ValueOrDie()->Execute(array.data(), outputs[i]);
    }
  }

  return Status::OK();
}

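// Entry point used by the CPU custom call: any failure is reported through
// the XlaCustomCallStatus rather than returned to the caller.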
void CpuCallback::PrepareAndCall(void* result, void** arg_ptrs,
                                 XlaCustomCallStatus* status) {
  auto s = PrepareAndCallInternal(result, arg_ptrs);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
    return;
  }
}

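// Variant that surfaces failures directly as a Status.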
Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) {
  return PrepareAndCallInternal(result, arg_ptrs);
}

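// Invokes the Python callable and validates the shape of its result: the
// callable must return a tuple with one entry per expected result, where
// token results are None and array results match the expected dimensions.
// The caller must hold the GIL.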
StatusOr<py::tuple> CpuCallback::CallInternal(py::tuple args) {
  py::object result_object;
  try {
    result_object = callable_(*py::reinterpret_borrow<py::args>(args));
  } catch (py::error_already_set& e) {
    PyErr_Clear();
    std::string error_message = e.what();
    return InternalError("CpuCallback error: %s", error_message);
  }
  if (!PyTuple_Check(result_object.ptr())) {
    return InternalError("CPU callback expected a tuple result, got %s",
                         static_cast<std::string>(py::repr(result_object)));
  }
  if (PyTuple_Size(result_object.ptr()) != results_.size()) {
    return InternalError(
        "CPU callback expected a tuple with %d results, got %d",
        results_.size(), PyTuple_Size(result_object.ptr()));
  }
  py::tuple result_tuple = py::cast<py::tuple>(result_object);
  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    if (results_[i].type == xla::TOKEN) {
      // Token results carry no data; the callback signals them with None.
      if (!output.is_none()) {
        return InternalError(
            "Token output from Python callback should be None, got %s",
            static_cast<std::string>(py::repr(output)));
      }
      continue;
    }
    py::array array = py::cast<py::array>(std::move(output));
    // NumPy reports shapes as ssize_t; they are reinterpreted as int64_t
    // below, which the static_assert guards.
    static_assert(sizeof(ssize_t) == sizeof(int64_t),
                  "Expected ssize_t to be of equal size to int64_t");
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    if (dims != results_[i].expected_dims) {
      return InternalError(
          "Mismatched result shape for %d-th return value from CPU callback; "
          "expected array with dimensions %s, got %s",
          i, absl::StrJoin(results_[i].expected_dims, ","),
          absl::StrJoin(dims, ","));
    }
  }
  return result_tuple;
}

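// Status-returning wrapper around CallInternal.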
StatusOr<py::tuple> CpuCallback::Call(py::tuple args) {
  return CallInternal(std::move(args));
}

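// Wrapper around CallInternal that reports failures through
// XlaCustomCallStatus and returns std::nullopt on error.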
std::optional<py::tuple> CpuCallback::Call(py::tuple args,
                                           XlaCustomCallStatus* status) {
  auto statusor = CallInternal(std::move(args));
  if (!statusor.ok()) {
    XlaCustomCallStatusSetFailure(status,
                                  statusor.status().error_message().c_str(),
                                  statusor.status().error_message().length());
    return std::nullopt;
  }
  return std::move(statusor).value();
}

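// C entry point for the Python CPU callback custom call. The first input
// encodes the CpuCallback* as a uintptr_t; the remaining inputs are
// forwarded to the callback as its operands.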
void XlaPythonCpuCallback(void* output, void** inputs,
                          XlaCustomCallStatus* status) {
  CpuCallback* callback =
      absl::bit_cast<CpuCallback*>(*static_cast<uintptr_t*>(inputs[0]));
  callback->PrepareAndCall(output, inputs + 1, status);
}

}  // namespace xla