/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_TPU_C_API_CONVERSIONS_H_

#include <array>
#include <memory>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory.h"
#include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor_c_api.h"

// APIs for converting between internal and external versions of
// XLA/StreamExecutor data structures.
36 namespace ApiConverter { 37 38 absl::Span<const float> MakeSpan(const FloatList& src_list); 39 void CreateVector(const absl::Span<const float> src, FloatList* dst); 40 void Destroy(FloatList* float_list); 41 42 absl::Span<const int64_t> MakeSpan(const Int64List& src_list); 43 void CreateVector(const absl::Span<const int64_t> src, Int64List* dst); 44 45 absl::Span<const bool> MakeSpan(const BoolList& src_list); 46 void CreateVector(const absl::Span<const bool> src, BoolList* dst); 47 48 // se::DeviceMemoryBase 49 SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base); 50 void ToC(const stream_executor::DeviceMemoryBase& base, 51 SE_DeviceMemoryBase* se_base); 52 stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); 53 void Destroy(SE_DeviceMemoryBase*); 54 55 // xla::Shape 56 xla::Shape FromC(const XLA_Shape* c_shape); 57 void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape); 58 void Destroy(XLA_Shape* c_shape); 59 60 // xla::Layout 61 xla::Layout FromC(const XLA_Layout* c_layout); 62 void ToC(const xla::Layout& xla_layout, XLA_Layout* c_layout); 63 void Destroy(XLA_Layout* c_layout); 64 65 // xla::Tile 66 xla::Tile FromC(const XLA_Tile* c_tile); 67 void ToC(const xla::Tile& xla_tile, XLA_Tile* c_tile); 68 void Destroy(XLA_Tile* c_tile); 69 70 // xla::ShapeIndex 71 XLA_ShapeIndex ToC(const xla::ShapeIndex& xla_shape); 72 xla::ShapeIndex FromC(XLA_ShapeIndex* c_shape); 73 void Destroy(XLA_ShapeIndex*); 74 75 // Literal 76 void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal); 77 xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal); 78 void Destroy(XLA_Literal* c_literal); 79 80 // ShapedBuffer 81 void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer); 82 xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer); 83 void Destroy(XLA_ShapedBuffer* c_buffer); 84 85 // se::DeviceMemoryBase 86 SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base); 87 
stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base); 88 void Destroy(SE_DeviceMemoryBase*); 89 90 // Literal 91 void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal); 92 xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal); 93 void Destroy(XLA_Literal* c_literal); 94 95 // ShapedBuffer 96 void ToC(const xla::ShapedBuffer& buffer, XLA_ShapedBuffer* c_device_buffer); 97 xla::ShapedBuffer FromC(XLA_ShapedBuffer* c_buffer); 98 void Destroy(XLA_ShapedBuffer* c_buffer); 99 100 // TpuEmbeddingEngineParametersData 101 struct TpuEmbeddingEngineParametersData { 102 // Backing vector for struct 103 std::array<std::vector<FloatListRef*>, 8> vectors; 104 TpuEmbeddingEngineParameters c_params; 105 }; 106 107 std::unique_ptr<TpuEmbeddingEngineParametersData> Create(int num_tables); 108 109 xla::MaybeOwningDeviceMemory FromC( 110 SE_MaybeOwningDeviceMemory* se_mem, 111 stream_executor::DeviceMemoryAllocator* allocator); 112 113 // DeviceMemoryAllocator 114 SE_DeviceMemoryAllocator ToC(stream_executor::DeviceMemoryAllocator* allocator); 115 116 // OwningDeviceMemory 117 SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem); 118 // mem.HasOwnership() may be true if the buffer is aliased and shouldn't be 119 // released. 'aliased' should be true in this case. 'aliased' has no effect if 120 // 'mem' is unowned. 121 SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased); 122 123 // HloModule 124 XLA_HloModule ToC(const xla::HloModule& module); 125 xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC( 126 const XLA_HloModule& c_module); 127 void Destroy(XLA_HloModule* c_module); 128 129 // HloModuleConfig 130 XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config); 131 xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config); 132 void Destroy(XLA_HloModuleConfig* c_config); 133 134 // Helper for managing stack based C -> C++ conversions. 
135 template <class CType> 136 struct StackHelper { StackHelperStackHelper137 explicit StackHelper() {} 138 139 template <class CppType> StackHelperStackHelper140 explicit StackHelper(const CppType& t) { 141 ::ApiConverter::ToC(t, &value); 142 } ~StackHelperStackHelper143 ~StackHelper() { ::ApiConverter::Destroy(&value); } 144 145 template <class CppType> AsCppStackHelper146 CppType AsCpp() const { 147 return ::ApiConverter::FromC(&value); 148 } 149 150 mutable CType value; 151 }; 152 153 } // namespace ApiConverter 154 155 #endif 156