xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/c_api_conversions.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
18 
#include <array>
#include <cstdint>
#include <memory>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"
33 
34 // APIs for converting between internal and external versions of
35 // XLA/StreamExecutor data structures.
36 namespace ApiConverter {
37 
38 absl::Span<const float> MakeSpan(const FloatList& src_list);
39 void CreateVector(const absl::Span<const float> src, FloatList* dst);
40 void Destroy(FloatList* float_list);
41 
42 absl::Span<const int64_t> MakeSpan(const Int64List& src_list);
43 void CreateVector(const absl::Span<const int64_t> src, Int64List* dst);
44 
45 absl::Span<const bool> MakeSpan(const BoolList& src_list);
46 void CreateVector(const absl::Span<const bool> src, BoolList* dst);
47 
48 // se::DeviceMemoryBase
49 SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base);
50 void ToC(const stream_executor::DeviceMemoryBase& base,
51          SE_DeviceMemoryBase* se_base);
52 stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
53 void Destroy(SE_DeviceMemoryBase*);
54 
55 // xla::Shape
56 xla::Shape FromC(const XLA_Shape* c_shape);
57 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
58 void Destroy(XLA_Shape* c_shape);
59 
60 // xla::Layout
61 xla::Layout FromC(const XLA_Layout* c_layout);
62 void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout);
63 void Destroy(XLA_Layout* c_layout);
64 
65 // xla::Tile
66 xla::Tile FromC(const XLA_Tile* c_tile);
67 void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile);
68 void Destroy(XLA_Tile* c_tile);
69 
70 // xla::ShapeIndex
71 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape);
72 xla::ShapeIndex FromC(XLA_ShapeIndex* c_shape);
73 void Destroy(XLA_ShapeIndex*);
74 
75 // Literal
76 void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
77 xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
78 void Destroy(XLA_Literal* c_literal);
79 
80 // ShapedBuffer
81 void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer);
82 xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer);
83 void Destroy(XLA_ShapedBuffer* c_buffer);
84 
85 // se::DeviceMemoryBase
86 SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base);
87 stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
88 void Destroy(SE_DeviceMemoryBase*);
89 
90 // Literal
91 void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
92 xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
93 void Destroy(XLA_Literal* c_literal);
94 
95 // ShapedBuffer
96 void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer);
97 xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer);
98 void Destroy(XLA_ShapedBuffer* c_buffer);
99 
100 // TpuEmbeddingEngineParametersData
101 struct TpuEmbeddingEngineParametersData {
102   // Backing vector for struct
103   std::array<std::vector<FloatListRef*>, 8> vectors;
104   TpuEmbeddingEngineParameters c_params;
105 };
106 
107 std::unique_ptr<TpuEmbeddingEngineParametersData> Create(int num_tables);
108 
109 xla::MaybeOwningDeviceMemory FromC(
110     SE_MaybeOwningDeviceMemory* se_mem,
111     stream_executor::DeviceMemoryAllocator* allocator);
112 
113 // DeviceMemoryAllocator
114 SE_DeviceMemoryAllocator ToC(stream_executor::DeviceMemoryAllocator* allocator);
115 
116 // OwningDeviceMemory
117 SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem);
118 // mem.HasOwnership() may be true if the buffer is aliased and shouldn't be
119 // released. 'aliased' should be true in this case. 'aliased' has no effect if
120 // 'mem' is unowned.
121 SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased);
122 
123 // HloModule
124 XLA_HloModule ToC(const xla::HloModule& module);
125 xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
126     const XLA_HloModule& c_module);
127 void Destroy(XLA_HloModule* c_module);
128 
129 // HloModuleConfig
130 XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config);
131 xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config);
132 void Destroy(XLA_HloModuleConfig* c_config);
133 
134 // Helper for managing stack based C -> C++ conversions.
135 template <class CType>
136 struct StackHelper {
StackHelperStackHelper137   explicit StackHelper() {}
138 
139   template <class CppType>
StackHelperStackHelper140   explicit StackHelper(const CppType& t) {
141     ::ApiConverter::ToC(t, &value);
142   }
~StackHelperStackHelper143   ~StackHelper() { ::ApiConverter::Destroy(&value); }
144 
145   template <class CppType>
AsCppStackHelper146   CppType AsCpp() const {
147     return ::ApiConverter::FromC(&value);
148   }
149 
150   mutable CType value;
151 };
152 
153 }  // namespace ApiConverter
154 
155 #endif
156