1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"

#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/core/tpu/tpu_api.h"
24
25 namespace tensorflow {
26 namespace tpu {
27
// Unique platform id used to register/look up this platform with
// stream_executor's MultiPlatformManager.
const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
// Singleton set by RegisterTpuPlatform(). Non-owning alias: ownership of the
// object is transferred to MultiPlatformManager at registration time.
TpuPlatform* tpu_registered_platform = nullptr;

// Local aliases for stream_executor's status types used throughout this file.
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
34
TpuPlatform()35 TpuPlatform::TpuPlatform() : name_("TPU") {
36 platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
37 CHECK(platform_ != nullptr);
38 }
39
GetRegisteredPlatform()40 TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
41 return tpu_registered_platform;
42 }
43
Initialize(const std::map<std::string,std::string> & platform_options)44 Status TpuPlatform::Initialize(
45 const std::map<std::string, std::string>& platform_options) {
46 StatusHelper status;
47
48 size_t options_size = platform_options.size();
49 const char** options_key =
50 static_cast<const char**>(malloc(sizeof(const char*) * options_size));
51 const char** options_value =
52 static_cast<const char**>(malloc(sizeof(const char*) * options_size));
53
54 size_t i = 0;
55 for (const auto& option : platform_options) {
56 options_key[i] = option.first.c_str();
57 options_value[i] = option.second.c_str();
58 i++;
59 }
60
61 tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
62 platform_, options_size, options_key, options_value, status.c_status);
63
64 free(options_key);
65 free(options_value);
66
67 return status.status();
68 }
69
Initialized() const70 bool TpuPlatform::Initialized() const {
71 return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
72 }
73
~TpuPlatform()74 TpuPlatform::~TpuPlatform() {
75 tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
76 }
77
VisibleDeviceCount() const78 int TpuPlatform::VisibleDeviceCount() const {
79 return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
80 }
81
GetExecutor(const::stream_executor::StreamExecutorConfig & config)82 StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
83 const ::stream_executor::StreamExecutorConfig& config) {
84 return executor_cache_.GetOrCreate(
85 config, [&]() { return GetUncachedExecutor(config); });
86 }
87
88 StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
GetUncachedExecutor(const::stream_executor::StreamExecutorConfig & config)89 TpuPlatform::GetUncachedExecutor(
90 const ::stream_executor::StreamExecutorConfig& config) {
91 SE_StreamExecutorConfig* c_config =
92 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
93
94 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
95 config.ordinal);
96
97 StatusHelper status;
98 SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
99 platform_, c_config, status.c_status);
100 tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
101 if (!status.ok()) {
102 return status.status();
103 }
104 return std::make_unique<stream_executor::StreamExecutor>(
105 this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
106 }
107
id() const108 ::stream_executor::Platform::Id TpuPlatform::id() const {
109 return TpuPlatform::kId;
110 }
111
Name() const112 const std::string& TpuPlatform::Name() const { return name_; }
113
TpuMemoryLimit()114 int64_t TpuPlatform::TpuMemoryLimit() {
115 return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
116 }
117
ShouldRegisterTpuDeviceToDeviceCopy()118 bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
119 return tpu::ExecutorApiFn()
120 ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
121 }
122
GetTopologyPtr()123 const tensorflow::tpu::TpuTopologyPtr TpuPlatform::GetTopologyPtr() {
124 return tpu::ExecutorApiFn()->TpuPlatform_GetTopologyPtrFn(platform_);
125 }
126
GetTpuHostLocation() const127 const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
128 const {
129 return tpu::TpuHostLocationExternal(
130 tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
131 }
132
version() const133 TpuRuntimeVersion TpuPlatform::version() const {
134 return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
135 }
136
InsertEvent(stream_executor::internal::EventInterface * key,SE_Event * val)137 void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
138 SE_Event* val) {
139 absl::MutexLock lock(&event_map_mu_);
140 event_map_[key] = val;
141 }
142
LookupEvent(stream_executor::internal::EventInterface * key)143 SE_Event* TpuPlatform::LookupEvent(
144 stream_executor::internal::EventInterface* key) {
145 absl::ReaderMutexLock lock(&event_map_mu_);
146 return event_map_.at(key);
147 }
148
EraseEvent(stream_executor::internal::EventInterface * key)149 void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {
150 absl::MutexLock lock(&event_map_mu_);
151 event_map_.erase(key);
152 }
153
TpusPerHost(int * tpus)154 Status TpuPlatform::TpusPerHost(int* tpus) {
155 TF_Status* status = TF_NewStatus();
156
157 if (tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn == nullptr) {
158 *tpus = 0;
159 return OkStatus();
160 }
161
162 tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
163 auto ret_status = StatusFromTF_Status(status);
164 TF_DeleteStatus(status);
165 return ret_status;
166 }
167
TpuMemoryLimit(int64_t * memory_limit)168 Status TpuPlatform::TpuMemoryLimit(int64_t* memory_limit) {
169 TF_Status* status = TF_NewStatus();
170
171 if (tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn == nullptr) {
172 *memory_limit = 0;
173 return OkStatus();
174 }
175
176 tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
177 reinterpret_cast<int64_t*>(memory_limit), status);
178 auto ret_status = StatusFromTF_Status(status);
179 TF_DeleteStatus(status);
180 return ret_status;
181 }
182
RegisterTpuPlatform()183 bool RegisterTpuPlatform() {
184 // Silently bail if the underlying TPU C API isn't initialized. This is useful
185 // for code that unconditionally calls RegisterTpuPlatform() but doesn't link
186 // in the underlying TPU library when not running on TPU.
187 if (!tpu::IsStreamExecutorEnabled(tpu::ExecutorApiFn())) {
188 return true;
189 }
190 static bool tpu_platform_registered = false;
191 if (!tpu_platform_registered) {
192 tpu_registered_platform = new TpuPlatform();
193 std::unique_ptr<stream_executor::Platform> platform(
194 tpu_registered_platform);
195 SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
196 std::move(platform)));
197 tpu_platform_registered = true;
198 }
199 return true;
200 }
201
202 } // namespace tpu
203 } // namespace tensorflow
204