xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform.h"

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/status_helper.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/core/tpu/tpu_api.h"
24 
25 namespace tensorflow {
26 namespace tpu {
27 
// Process-wide id under which the TPU platform is registered with the
// MultiPlatformManager.
const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
// Singleton set by RegisterTpuPlatform(); ownership is transferred to the
// MultiPlatformManager at registration time (see RegisterTpuPlatform below).
TpuPlatform* tpu_registered_platform = nullptr;

// Local aliases for the StreamExecutor status types used throughout this file.
using Status = ::stream_executor::port::Status;
template <typename T>
using StatusOr = ::stream_executor::port::StatusOr<T>;
34 
TpuPlatform()35 TpuPlatform::TpuPlatform() : name_("TPU") {
36   platform_ = tpu::ExecutorApiFn()->TpuPlatform_NewFn();
37   CHECK(platform_ != nullptr);
38 }
39 
GetRegisteredPlatform()40 TpuPlatform* TpuPlatform::GetRegisteredPlatform() {
41   return tpu_registered_platform;
42 }
43 
Initialize(const std::map<std::string,std::string> & platform_options)44 Status TpuPlatform::Initialize(
45     const std::map<std::string, std::string>& platform_options) {
46   StatusHelper status;
47 
48   size_t options_size = platform_options.size();
49   const char** options_key =
50       static_cast<const char**>(malloc(sizeof(const char*) * options_size));
51   const char** options_value =
52       static_cast<const char**>(malloc(sizeof(const char*) * options_size));
53 
54   size_t i = 0;
55   for (const auto& option : platform_options) {
56     options_key[i] = option.first.c_str();
57     options_value[i] = option.second.c_str();
58     i++;
59   }
60 
61   tpu::ExecutorApiFn()->TpuPlatform_InitializeFn(
62       platform_, options_size, options_key, options_value, status.c_status);
63 
64   free(options_key);
65   free(options_value);
66 
67   return status.status();
68 }
69 
Initialized() const70 bool TpuPlatform::Initialized() const {
71   return tpu::ExecutorApiFn()->TpuPlatform_InitializedFn(platform_);
72 }
73 
~TpuPlatform()74 TpuPlatform::~TpuPlatform() {
75   tpu::ExecutorApiFn()->TpuPlatform_FreeFn(platform_);
76 }
77 
VisibleDeviceCount() const78 int TpuPlatform::VisibleDeviceCount() const {
79   return tpu::ExecutorApiFn()->TpuPlatform_VisibleDeviceCountFn(platform_);
80 }
81 
GetExecutor(const::stream_executor::StreamExecutorConfig & config)82 StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor(
83     const ::stream_executor::StreamExecutorConfig& config) {
84   return executor_cache_.GetOrCreate(
85       config, [&]() { return GetUncachedExecutor(config); });
86 }
87 
88 StatusOr<std::unique_ptr<::stream_executor::StreamExecutor>>
GetUncachedExecutor(const::stream_executor::StreamExecutorConfig & config)89 TpuPlatform::GetUncachedExecutor(
90     const ::stream_executor::StreamExecutorConfig& config) {
91   SE_StreamExecutorConfig* c_config =
92       tpu::ExecutorApiFn()->TpuStreamExecutorConfig_DefaultFn();
93 
94   tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn(c_config,
95                                                              config.ordinal);
96 
97   StatusHelper status;
98   SE_StreamExecutor* executor = tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn(
99       platform_, c_config, status.c_status);
100   tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn(c_config);
101   if (!status.ok()) {
102     return status.status();
103   }
104   return std::make_unique<stream_executor::StreamExecutor>(
105       this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
106 }
107 
id() const108 ::stream_executor::Platform::Id TpuPlatform::id() const {
109   return TpuPlatform::kId;
110 }
111 
Name() const112 const std::string& TpuPlatform::Name() const { return name_; }
113 
TpuMemoryLimit()114 int64_t TpuPlatform::TpuMemoryLimit() {
115   return tpu::ExecutorApiFn()->TpuPlatform_TpuMemoryLimitFn(platform_);
116 }
117 
ShouldRegisterTpuDeviceToDeviceCopy()118 bool TpuPlatform::ShouldRegisterTpuDeviceToDeviceCopy() {
119   return tpu::ExecutorApiFn()
120       ->TpuPlatform_ShouldRegisterTpuDeviceToDeviceCopyFn(platform_);
121 }
122 
GetTopologyPtr()123 const tensorflow::tpu::TpuTopologyPtr TpuPlatform::GetTopologyPtr() {
124   return tpu::ExecutorApiFn()->TpuPlatform_GetTopologyPtrFn(platform_);
125 }
126 
GetTpuHostLocation() const127 const tensorflow::tpu::TpuHostLocationExternal TpuPlatform::GetTpuHostLocation()
128     const {
129   return tpu::TpuHostLocationExternal(
130       tpu::ExecutorApiFn()->TpuPlatform_GetHostLocationFn(platform_));
131 }
132 
version() const133 TpuRuntimeVersion TpuPlatform::version() const {
134   return tpu::ExecutorApiFn()->TpuPlatform_GetRuntimeVersionFn(platform_);
135 }
136 
InsertEvent(stream_executor::internal::EventInterface * key,SE_Event * val)137 void TpuPlatform::InsertEvent(stream_executor::internal::EventInterface* key,
138                               SE_Event* val) {
139   absl::MutexLock lock(&event_map_mu_);
140   event_map_[key] = val;
141 }
142 
LookupEvent(stream_executor::internal::EventInterface * key)143 SE_Event* TpuPlatform::LookupEvent(
144     stream_executor::internal::EventInterface* key) {
145   absl::ReaderMutexLock lock(&event_map_mu_);
146   return event_map_.at(key);
147 }
148 
EraseEvent(stream_executor::internal::EventInterface * key)149 void TpuPlatform::EraseEvent(stream_executor::internal::EventInterface* key) {
150   absl::MutexLock lock(&event_map_mu_);
151   event_map_.erase(key);
152 }
153 
TpusPerHost(int * tpus)154 Status TpuPlatform::TpusPerHost(int* tpus) {
155   TF_Status* status = TF_NewStatus();
156 
157   if (tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn == nullptr) {
158     *tpus = 0;
159     return OkStatus();
160   }
161 
162   tpu::OpsApiFn()->TpuConfigurationApi_TpusPerHostFn(tpus, status);
163   auto ret_status = StatusFromTF_Status(status);
164   TF_DeleteStatus(status);
165   return ret_status;
166 }
167 
TpuMemoryLimit(int64_t * memory_limit)168 Status TpuPlatform::TpuMemoryLimit(int64_t* memory_limit) {
169   TF_Status* status = TF_NewStatus();
170 
171   if (tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn == nullptr) {
172     *memory_limit = 0;
173     return OkStatus();
174   }
175 
176   tpu::OpsApiFn()->TpuConfigurationApi_TpuMemoryLimitFn(
177       reinterpret_cast<int64_t*>(memory_limit), status);
178   auto ret_status = StatusFromTF_Status(status);
179   TF_DeleteStatus(status);
180   return ret_status;
181 }
182 
RegisterTpuPlatform()183 bool RegisterTpuPlatform() {
184   // Silently bail if the underlying TPU C API isn't initialized. This is useful
185   // for code that unconditionally calls RegisterTpuPlatform() but doesn't link
186   // in the underlying TPU library when not running on TPU.
187   if (!tpu::IsStreamExecutorEnabled(tpu::ExecutorApiFn())) {
188     return true;
189   }
190   static bool tpu_platform_registered = false;
191   if (!tpu_platform_registered) {
192     tpu_registered_platform = new TpuPlatform();
193     std::unique_ptr<stream_executor::Platform> platform(
194         tpu_registered_platform);
195     SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
196         std::move(platform)));
197     tpu_platform_registered = true;
198   }
199   return true;
200 }
201 
202 }  // namespace tpu
203 }  // namespace tensorflow
204