/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/callback.h"

#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <utility>

// For absl::bit_cast, absl::StrJoin, and absl::Span, which are used directly
// below (otherwise they arrive only transitively via callback.h).
#include "absl/base/casts.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/exceptions.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace py = pybind11;

namespace xla {

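// Shared implementation for both PrepareAndCall overloads: wraps the raw
// input buffers as read-only NumPy arrays (tokens become None), invokes the
// Python callable via CallInternal, and copies each result into the
// XLA-provided output buffer, using a cached TransposePlan when the result's
// strides differ from the expected layout.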
Status CpuCallback::PrepareAndCallInternal(void* result, void** arg_ptrs) {
  absl::Span<void* const> inputs(arg_ptrs, args_.size());
  absl::Span<void* const> outputs(reinterpret_cast<void**>(result),
                                  results_.size());

  py::gil_scoped_acquire gil;
  py::tuple args(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (args_[i].type == xla::TOKEN) {
      args[i] = py::none();
    } else {
      args[i] = py::array(args_[i].dtype, args_[i].dims, args_[i].strides,
                          const_cast<void*>(inputs[i]));
      args[i].attr("flags").attr("writeable") = Py_False;
    }
  }

  TF_ASSIGN_OR_RETURN(auto result_tuple, CallInternal(std::move(args)));

  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    if (results_[i].type == xla::TOKEN) {
      // Token results carry no data; there is nothing to copy out. Without
      // this guard (mirroring CallInternal, which validates that token
      // outputs are None), the py::array cast below would throw.
      continue;
    }
    py::array array = py::cast<py::array>(std::move(output));
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    if (strides == results_[i].expected_strides) {
      std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes);
    } else {
      xla::StatusOr<std::shared_ptr<xla::TransposePlan>> plan =
          transpose_cache_.GetOrCreate(
              xla::primitive_util::ByteWidth(results_[i].type), dims,
              results_[i].reversed_layout,
              /*input_layout=*/xla::TransposePlan::Striding{strides});
      if (!plan.ok()) {
        return std::move(plan).status();
      }
      plan.ValueOrDie()->Execute(array.data(), outputs[i]);
    }
  }

  return Status::OK();
}

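// Variant used by the XLA CPU runtime: failures are reported through the
// XlaCustomCallStatus rather than returned to the caller.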
void CpuCallback::PrepareAndCall(void* result, void** arg_ptrs,
                                 XlaCustomCallStatus* status) {
  auto s = PrepareAndCallInternal(result, arg_ptrs);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, s.error_message().c_str(),
                                  s.error_message().length());
    return;
  }
}

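// Status-returning variant for callers that propagate errors themselves.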
Status CpuCallback::PrepareAndCall(void* result, void** arg_ptrs) {
  return PrepareAndCallInternal(result, arg_ptrs);
}

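// Invokes the Python callable and validates its result: it must be a tuple
// with one entry per expected result, token results must be None, and each
// array result must have the expected dimensions. Python exceptions are
// converted into InternalError statuses.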
StatusOr<py::tuple> CpuCallback::CallInternal(py::tuple args) {
  py::object result_object;
  try {
    result_object = callable_(*py::reinterpret_borrow<py::args>(args));
  } catch (py::error_already_set& e) {
    PyErr_Clear();
    std::string error_message = e.what();
    return InternalError("CpuCallback error: %s", error_message);
  }
  if (!PyTuple_Check(result_object.ptr())) {
    return InternalError("CPU callback expected a tuple result, got %s",
                         static_cast<std::string>(py::repr(result_object)));
  }
  if (PyTuple_Size(result_object.ptr()) != results_.size()) {
    return InternalError(
        "CPU callback expected a tuple with %d results, got %d",
        results_.size(), PyTuple_Size(result_object.ptr()));
  }
  py::tuple result_tuple = py::cast<py::tuple>(result_object);
  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    if (results_[i].type == xla::TOKEN) {
      if (!output.is_none()) {
        return InternalError(
            "Token output from Python callback should be None, got %s",
            static_cast<std::string>(py::repr(output)));
      }
      continue;
    }
    py::array array = py::cast<py::array>(std::move(output));
    static_assert(sizeof(ssize_t) == sizeof(int64_t),
                  "Expected ssize_t to be of equal size to int64_t");
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    if (dims != results_[i].expected_dims) {
      return InternalError(
          "Mismatched result shape for %d-th return value from CPU callback; "
          "expected array with dimensions %s, got %s",
          i, absl::StrJoin(results_[i].expected_dims, ","),
          absl::StrJoin(dims, ","));
    }
  }
  return result_tuple;
}

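// Public wrapper around CallInternal for Status-based error handling.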
StatusOr<py::tuple> CpuCallback::Call(py::tuple args) {
  return CallInternal(std::move(args));
}

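// Variant used from the XLA runtime: on failure, records the error in the
// XlaCustomCallStatus and returns std::nullopt.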
std::optional<py::tuple> CpuCallback::Call(py::tuple args,
                                           XlaCustomCallStatus* status) {
  auto statusor = CallInternal(std::move(args));
  if (!statusor.ok()) {
    XlaCustomCallStatusSetFailure(status,
                                  statusor.status().error_message().c_str(),
                                  statusor.status().error_message().length());
    return std::nullopt;
  }
  return std::move(statusor).value();
}

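// C entry point for the CPU custom call. The first operand is a descriptor:
// inputs[0] points at a uintptr_t holding the CpuCallback*, and the real
// operand buffers follow it. A hypothetical call site, purely for
// illustration (variable names are assumptions; in practice the XLA runtime
// packs the operands via the custom-call machinery):
//
//   CpuCallback* cb = ...;  // owned elsewhere, outlives the computation
//   uintptr_t descriptor = absl::bit_cast<uintptr_t>(cb);
//   void* inputs[] = {&descriptor, arg0_buffer};
//   XlaPythonCpuCallback(output_buffer, inputs, &status);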
void XlaPythonCpuCallback(void* output, void** inputs,
                          XlaCustomCallStatus* status) {
  CpuCallback* callback =
      absl::bit_cast<CpuCallback*>(*static_cast<uintptr_t*>(inputs[0]));
  callback->PrepareAndCall(output, inputs + 1, status);
}

}  // namespace xla