xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/callback.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_
18 
19 #include <optional>
20 #include <utility>
21 
22 #include "pybind11/pybind11.h"
23 #include "tensorflow/compiler/xla/pjrt/transpose.h"
24 #include "tensorflow/compiler/xla/python/py_values.h"
25 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
26 #include "tensorflow/compiler/xla/service/custom_call_status.h"
27 #include "tensorflow/compiler/xla/types.h"
28 
29 namespace xla {
30 
31 class CpuCallback {
32  public:
33   struct Arg {
34     xla::PrimitiveType type;               // XLA type
35     pybind11::dtype dtype;                 // NumPy type, for array types.
36     absl::InlinedVector<int64_t, 4> dims;  // Dimensions, for array types.
37     std::vector<ssize_t> strides;          // Byte strides, for array types.
38     size_t size_in_bytes;                  // Size of the array in bytes.
39   };
40   struct Result {
41     xla::PrimitiveType type;  // XLA type
42     // Expected output shape, for array types
43     absl::InlinedVector<int64_t, 4> expected_dims;
44     // Expected output byte strides, for array types. If the strides do not
45     // match the output will be transposed into the expected layout.
46     std::vector<int64_t> expected_strides;
47     // The desired order of output dimensions in major-to-minor order.
48     absl::InlinedVector<int64_t, 4> reversed_layout;
49     // Size of the array in bytes.
50     size_t size_in_bytes;
51   };
52 
CpuCallback(pybind11::function callable,std::vector<Arg> args,std::vector<Result> results)53   explicit CpuCallback(pybind11::function callable, std::vector<Arg> args,
54                        std::vector<Result> results)
55       : callable_(std::move(callable)),
56         args_(std::move(args)),
57         results_(std::move(results)),
58         transpose_cache_(/*capacity=*/16) {}
59 
~CpuCallback()60   ~CpuCallback() {
61     // The destructor may be called without GIL held. In that case, we defer it
62     // to GlobalPyRefManager.
63     pybind11::object object = std::move(callable_);
64     GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&object, 1));
65   }
66 
args()67   const std::vector<Arg>& args() const { return args_; }
num_args()68   size_t num_args() const { return args_.size(); }
69 
results()70   const std::vector<Result>& results() const { return results_; }
num_results()71   size_t num_results() const { return results_.size(); }
72 
transpose_cache()73   xla::TransposePlanCache& transpose_cache() { return transpose_cache_; }
74 
75   void PrepareAndCall(void* result, void** arg_ptrs,
76                       XlaCustomCallStatus* status);
77   Status PrepareAndCall(void* result, void** arg_ptrs);
78 
79   std::optional<pybind11::tuple> Call(pybind11::tuple args,
80                                       XlaCustomCallStatus* status);
81   StatusOr<pybind11::tuple> Call(pybind11::tuple args);
82 
83  private:
84   Status PrepareAndCallInternal(void* result, void** arg_ptrs);
85   StatusOr<pybind11::tuple> CallInternal(pybind11::tuple args);
86 
87   pybind11::function callable_;
88   std::vector<Arg> const args_;
89   std::vector<Result> const results_;
90   xla::TransposePlanCache transpose_cache_;
91 };
92 
93 void XlaPythonCpuCallback(void* output, void** inputs,
94                           XlaCustomCallStatus* status);
95 
96 }  // namespace xla
97 
98 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_
99