1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_ 18 19 #include <optional> 20 #include <utility> 21 22 #include "pybind11/pybind11.h" 23 #include "tensorflow/compiler/xla/pjrt/transpose.h" 24 #include "tensorflow/compiler/xla/python/py_values.h" 25 #include "tensorflow/compiler/xla/python/python_ref_manager.h" 26 #include "tensorflow/compiler/xla/service/custom_call_status.h" 27 #include "tensorflow/compiler/xla/types.h" 28 29 namespace xla { 30 31 class CpuCallback { 32 public: 33 struct Arg { 34 xla::PrimitiveType type; // XLA type 35 pybind11::dtype dtype; // NumPy type, for array types. 36 absl::InlinedVector<int64_t, 4> dims; // Dimensions, for array types. 37 std::vector<ssize_t> strides; // Byte strides, for array types. 38 size_t size_in_bytes; // Size of the array in bytes. 39 }; 40 struct Result { 41 xla::PrimitiveType type; // XLA type 42 // Expected output shape, for array types 43 absl::InlinedVector<int64_t, 4> expected_dims; 44 // Expected output byte strides, for array types. If the strides do not 45 // match the output will be transposed into the expected layout. 46 std::vector<int64_t> expected_strides; 47 // The desired order of output dimensions in major-to-minor order. 48 absl::InlinedVector<int64_t, 4> reversed_layout; 49 // Size of the array in bytes. 50 size_t size_in_bytes; 51 }; 52 CpuCallback(pybind11::function callable,std::vector<Arg> args,std::vector<Result> results)53 explicit CpuCallback(pybind11::function callable, std::vector<Arg> args, 54 std::vector<Result> results) 55 : callable_(std::move(callable)), 56 args_(std::move(args)), 57 results_(std::move(results)), 58 transpose_cache_(/*capacity=*/16) {} 59 ~CpuCallback()60 ~CpuCallback() { 61 // The destructor may be called without GIL held. In that case, we defer it 62 // to GlobalPyRefManager. 63 pybind11::object object = std::move(callable_); 64 GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&object, 1)); 65 } 66 args()67 const std::vector<Arg>& args() const { return args_; } num_args()68 size_t num_args() const { return args_.size(); } 69 results()70 const std::vector<Result>& results() const { return results_; } num_results()71 size_t num_results() const { return results_.size(); } 72 transpose_cache()73 xla::TransposePlanCache& transpose_cache() { return transpose_cache_; } 74 75 void PrepareAndCall(void* result, void** arg_ptrs, 76 XlaCustomCallStatus* status); 77 Status PrepareAndCall(void* result, void** arg_ptrs); 78 79 std::optional<pybind11::tuple> Call(pybind11::tuple args, 80 XlaCustomCallStatus* status); 81 StatusOr<pybind11::tuple> Call(pybind11::tuple args); 82 83 private: 84 Status PrepareAndCallInternal(void* result, void** arg_ptrs); 85 StatusOr<pybind11::tuple> CallInternal(pybind11::tuple args); 86 87 pybind11::function callable_; 88 std::vector<Arg> const args_; 89 std::vector<Result> const results_; 90 xla::TransposePlanCache transpose_cache_; 91 }; 92 93 void XlaPythonCpuCallback(void* output, void** inputs, 94 XlaCustomCallStatus* status); 95 96 } // namespace xla 97 98 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_CALLBACK_H_ 99