xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/results.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_RESULTS_H_
17 #define XLA_RUNTIME_RESULTS_H_
18 
#include <cassert>

#include "llvm/Support/Error.h"
#include "tensorflow/compiler/xla/runtime/logical_result.h"
#include "tensorflow/compiler/xla/runtime/types.h"
22 
23 namespace xla {
24 namespace runtime {
25 
26 //===----------------------------------------------------------------------===//
27 // Conversions from XLA executable results to C++ types.
28 //===----------------------------------------------------------------------===//
29 
// The result type defines its own ABI as a required number of bytes and
// alignment, and the executable returns results by writing into the requested
// memory allocated in the call frame. The user is responsible for providing
// a conversion function that converts this opaque memory back to the C++
// data type. For example, memrefs are returned as a `StridedMemrefType`
// structure, and it is the user's responsibility to define a conversion
// function that can convert a memref to the run time Tensor/Buffer type.
37 //
// It is important that the type that is written into the call frame memory has
// a standard layout, because we rely on `reinterpret_cast` to reinterpret
// the opaque bytes as a C struct.
41 //
42 // See https://en.cppreference.com/w/cpp/types/is_standard_layout
43 
44 // Result converter is responsible for taking a pointer to the memory location
45 // where the executable wrote the result, and converting it to the corresponding
46 // run time value expected by the caller (e.g. memref descriptor to Tensor).
47 class ResultConverter {
48  public:
49   virtual ~ResultConverter() = default;
50 
51   // Converts value `ret` of type `runtime_type` (runtime type derived from the
52   // original `type`) returned from the executable at `result_index` result
53   // position using registered conversion functions. Returns a logical result
54   // telling if the conversion was successful.
55   virtual LogicalResult ReturnValue(unsigned result_index, const Type* type,
56                                     const Type* runtime_type,
57                                     void* ret) const = 0;
58 
59   // Returns error for all results.
60   virtual void ReturnError(const llvm::Error& error) const = 0;
61 };
62 
63 //===----------------------------------------------------------------------===//
64 // Result converter for functions without results (returning void).
65 //===----------------------------------------------------------------------===//
66 
67 struct NoResultConverter : public ResultConverter {
68   LLVM_ATTRIBUTE_ALWAYS_INLINE
ReturnValueNoResultConverter69   LogicalResult ReturnValue(unsigned, const Type*, const Type*,
70                             void*) const final {
71     assert(false && "no result converter must never be called");
72     return failure();
73   }
74 
ReturnErrorNoResultConverter75   void ReturnError(const llvm::Error&) const final {}
76 };
77 
78 }  // namespace runtime
79 }  // namespace xla
80 
81 #endif  // XLA_RUNTIME_RESULTS_H_
82