1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #if TENSORFLOW_USE_ROCM
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/container/fixed_array.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/abi.h"
29 #include "tensorflow/core/platform/env_time.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 #include "tensorflow/core/profiler/backends/cpu/annotation_stack.h"
34 #include "tensorflow/core/profiler/backends/gpu/rocm_tracer.h"
35 #include "tensorflow/core/profiler/lib/profiler_factory.h"
36 #include "tensorflow/core/profiler/lib/profiler_interface.h"
37 #include "tensorflow/core/profiler/utils/parse_annotation.h"
38 #include "tensorflow/core/profiler/utils/xplane_builder.h"
39 #include "tensorflow/core/profiler/utils/xplane_schema.h"
40 #include "tensorflow/core/profiler/utils/xplane_utils.h"
41 #include "tensorflow/core/util/env_var.h"
42
43 namespace tensorflow {
44 namespace profiler {
45
46 namespace {
47 // Set the all XLines of specified XPlane to starting walltime.
48 // Events time in both host and device planes are CUTPI timestamps.
49 // We set initial RocmTracer timestamp as start time for all lines to reflect
50 // this fact. Eventually we change line start time to corresponding
51 // start_walltime_ns to normalize with CPU wall time.
NormalizeTimeStamps(XPlaneBuilder * plane,uint64_t start_walltime_ns)52 static void NormalizeTimeStamps(XPlaneBuilder* plane,
53 uint64_t start_walltime_ns) {
54 plane->ForEachLine([&](tensorflow::profiler::XLineBuilder line) {
55 line.SetTimestampNs(start_walltime_ns);
56 });
57 }
58
GetDeviceXLineName(int64_t stream_id,absl::flat_hash_set<RocmTracerEventType> & event_types)59 std::string GetDeviceXLineName(
60 int64_t stream_id, absl::flat_hash_set<RocmTracerEventType>& event_types) {
61 std::string line_name = absl::StrCat("Stream #", stream_id);
62 event_types.erase(RocmTracerEventType::Unsupported);
63 if (event_types.empty()) return line_name;
64 std::vector<const char*> type_names;
65 for (const auto event_type : event_types) {
66 type_names.emplace_back(GetRocmTracerEventTypeName(event_type));
67 }
68 return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")");
69 }
70
71 } // namespace
72
73 class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
74 public:
  // Creates a collector with one PerDeviceCollector per GPU.
  // start_walltime_ns / start_gputime_ns are the wall time and raw GPU
  // timestamp captured at trace start; they anchor GPU timestamps to wall
  // time when events are exported.
  RocmTraceCollectorImpl(const RocmTraceCollectorOptions& options,
                         uint64_t start_walltime_ns, uint64_t start_gputime_ns)
      : RocmTraceCollector(options),
        num_callback_events_(0),
        num_activity_events_(0),
        start_walltime_ns_(start_walltime_ns),
        start_gputime_ns_(start_gputime_ns),
        per_device_collector_(options.num_gpus) {}
83
  // Stages one tracer event into the correlation-id-keyed map that matches its
  // source/domain. API-callback events go to api_events_map_ (or, when
  // is_auxiliary, to auxiliary_api_events_map_); HIP_API activities go to
  // activity_api_events_map_; HCC_OPS activities are appended to a per-id
  // vector (one API call can trigger several ops). Events beyond the
  // configured capacity limits, or with a duplicate correlation id, are
  // dropped with a log message.
  void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) override {
    mutex_lock lock(event_maps_mutex_);

    // Capacity enforcement. Note: auxiliary callbacks and HIP_OPS activities
    // are intentionally not counted.
    if (event.source == RocmTracerEventSource::ApiCallback && !is_auxiliary) {
      if (num_callback_events_ > options_.max_callback_api_events) {
        OnEventsDropped("max callback event capacity reached",
                        event.correlation_id);
        DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
        return;
      }
      num_callback_events_++;
    } else if (event.source == RocmTracerEventSource::Activity &&
               event.domain == RocmTracerEventDomain::HIP_API) {
      // we do not count HIP_OPS activities.
      if (num_activity_events_ > options_.max_activity_api_events) {
        OnEventsDropped("max activity event capacity reached",
                        event.correlation_id);
        DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
        return;
      }
      num_activity_events_++;
    }

    bool emplace_result = false;
    if (event.source == RocmTracerEventSource::ApiCallback) {
      auto& target_api_event_map =
          (is_auxiliary) ? auxiliary_api_events_map_ : api_events_map_;
      std::tie(std::ignore, emplace_result) =
          target_api_event_map.emplace(event.correlation_id, std::move(event));
    } else if (event.source == RocmTracerEventSource::Activity) {
      if (event.domain == RocmTracerEventDomain::HIP_API) {
        std::tie(std::ignore, emplace_result) =
            activity_api_events_map_.emplace(event.correlation_id,
                                             std::move(event));
      } else if (event.domain == RocmTracerEventDomain::HCC_OPS) {
        auto result = activity_ops_events_map_.emplace(
            event.correlation_id, std::vector<RocmTracerEvent>{});
        result.first->second.push_back(std::move(event));
        emplace_result = true;  // we always accept Hip-Ops events
      }
    }
    if (!emplace_result) {
      // NOTE(review): `event` may have been moved-from by a failed emplace
      // above (unspecified for duplicate keys), so the dump below may print a
      // partially-empty event — confirm whether that is acceptable.
      OnEventsDropped("event with duplicate correlation_id was received.",
                      event.correlation_id);
      DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
    }
  }
131
OnEventsDropped(const std::string & reason,uint32_t correlation_id)132 void OnEventsDropped(const std::string& reason,
133 uint32_t correlation_id) override {
134 LOG(INFO) << "RocmTracerEvent dropped (correlation_id=" << correlation_id
135 << ",) : " << reason << ".";
136 }
137
Flush()138 void Flush() override {
139 mutex_lock lock(event_maps_mutex_);
140 auto& aggregated_events_ = ApiActivityInfoExchange();
141
142 VLOG(3) << "RocmTraceCollector collected " << num_callback_events_
143 << " callback events, " << num_activity_events_
144 << " activity events, and aggregated them into "
145 << aggregated_events_.size() << " events.";
146
147 for (auto& event : aggregated_events_) {
148 if (event.device_id >= options_.num_gpus) {
149 OnEventsDropped("device id >= num gpus", event.correlation_id);
150 DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
151 LOG(WARNING) << "A ROCm profiler event record with wrong device ID "
152 "dropped! Type="
153 << GetRocmTracerEventTypeName(event.type);
154 continue;
155 }
156
157 activity_api_events_map_.clear();
158 activity_ops_events_map_.clear();
159 api_events_map_.clear();
160 auxiliary_api_events_map_.clear();
161
162 per_device_collector_[event.device_id].AddEvent(event);
163 }
164
165 for (int i = 0; i < options_.num_gpus; ++i) {
166 per_device_collector_[i].SortByStartTime();
167 }
168 }
169
Export(XSpace * space)170 void Export(XSpace* space) {
171 uint64_t end_gputime_ns = RocmTracer::GetTimestamp();
172 XPlaneBuilder host_plane(
173 FindOrAddMutablePlaneWithName(space, kRoctracerApiPlaneName));
174 for (int i = 0; i < options_.num_gpus; ++i) {
175 std::string name = GpuPlaneName(i);
176 XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name));
177 device_plane.SetId(i);
178 // Calculate device capabilities before flushing, so that device
179 // properties are available to the occupancy calculator in export().
180 per_device_collector_[i].GetDeviceCapabilities(i, &device_plane);
181 per_device_collector_[i].Export(start_walltime_ns_, start_gputime_ns_,
182 end_gputime_ns, &device_plane,
183 &host_plane);
184
185 NormalizeTimeStamps(&device_plane, start_walltime_ns_);
186 }
187 NormalizeTimeStamps(&host_plane, start_walltime_ns_);
188 }
189
 private:
  // Running counts used to enforce the configured capacity limits in
  // AddEvent(); auxiliary callbacks and HIP_OPS activities are not counted.
  std::atomic<int> num_callback_events_;
  std::atomic<int> num_activity_events_;
  // Wall time and raw GPU timestamp captured at trace start; used to rebase
  // device timestamps onto wall time at export.
  uint64_t start_walltime_ns_;
  uint64_t start_gputime_ns_;

  // Guards all staging maps below. Each map is keyed by correlation id.
  mutex event_maps_mutex_;
  absl::flat_hash_map<uint32, RocmTracerEvent> api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);
  absl::flat_hash_map<uint32, RocmTracerEvent> activity_api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);

  /* Some apis such as MEMSETD32 (based on an observation with ResNet50),
    trigger multiple HIP ops domain activities. We keep them in a vector and
    merge them with api activities at flush time.
  */
  absl::flat_hash_map<uint32, std::vector<RocmTracerEvent>>
      activity_ops_events_map_ TF_GUARDED_BY(event_maps_mutex_);
  // This is for the APIs that we track because we need some information from
  // them to populate the corresponding activity that we actually track.
  absl::flat_hash_map<uint32, RocmTracerEvent> auxiliary_api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);
212
ApiActivityInfoExchange()213 const std::vector<RocmTracerEvent> ApiActivityInfoExchange() {
214 /* Different from CUDA, roctracer activity records are not enough to fill a
215 TF event. For most of the activities, we need to enable the corresponding
216 API callsbacks (we call them auxiliary API callbacks) to capture the
217 necessary fields from them using the correlation id. The purpose of this
218 function is to let APIs and activities exchange information to reach a
219 state very similar to TF CUDA and getting ready to dump the event.
220 */
221
222 // Copying info from HIP-OPS activities to HIP-API activities
223 /*HIP-API activities <<==== HIP-OPS activities*/
224 auto activity_api_events_map_iter = activity_api_events_map_.begin();
225 while (activity_api_events_map_iter != activity_api_events_map_.end()) {
226 uint32_t activity_corr_id = activity_api_events_map_iter->first;
227 RocmTracerEvent& activity_api_event =
228 activity_api_events_map_iter->second;
229
230 bool result = false;
231 switch (activity_api_event.type) {
232 case RocmTracerEventType::Kernel:
233 case RocmTracerEventType::Memset: {
234 // KERNEL & MEMSET
235 auto iter =
236 activity_ops_events_map_.find(activity_api_event.correlation_id);
237 result = (iter != activity_ops_events_map_.end());
238 if (result) {
239 // since the key exist in the map, there should be at least one item
240 // in the vector
241 activity_api_event.device_id = iter->second.front().device_id;
242 activity_api_event.stream_id = iter->second.front().stream_id;
243 // we initialize the start time and end time based on the first
244 // element
245 activity_api_event.start_time_ns =
246 iter->second.front().start_time_ns;
247 activity_api_event.end_time_ns = iter->second.front().end_time_ns;
248 for (auto& kernel_activity_op : iter->second) {
249 activity_api_event.start_time_ns =
250 std::min(activity_api_event.start_time_ns,
251 kernel_activity_op.start_time_ns);
252 activity_api_event.end_time_ns =
253 std::max(activity_api_event.end_time_ns,
254 kernel_activity_op.end_time_ns);
255 }
256 }
257 break;
258 }
259 case RocmTracerEventType::MemcpyD2D:
260 case RocmTracerEventType::MemcpyH2D:
261 case RocmTracerEventType::MemcpyD2H:
262 case RocmTracerEventType::MemcpyOther: {
263 // MEMCPY
264 auto iter =
265 activity_ops_events_map_.find(activity_api_event.correlation_id);
266 result = (iter != activity_ops_events_map_.end());
267 if (result) {
268 // since the key exist in the map, there should be at least one item
269 // in the vector
270 activity_api_event.device_id = iter->second.front().device_id;
271 activity_api_event.memcpy_info.destination =
272 iter->second.front()
273 .memcpy_info.destination; // similar to CUDA, it is the
274 // same as device_id
275 activity_api_event.stream_id = iter->second.front().stream_id;
276 /* IMPORTANT: it seems that the HCC timing is only valid for
277 * Synchronous memcpy activities*/
278 if (!activity_api_event.memcpy_info.async) {
279 activity_api_event.start_time_ns =
280 iter->second.front().start_time_ns;
281 activity_api_event.end_time_ns = iter->second.front().end_time_ns;
282 for (auto& kernel_activity_op : iter->second) {
283 activity_api_event.start_time_ns =
284 std::min(activity_api_event.start_time_ns,
285 kernel_activity_op.start_time_ns);
286 activity_api_event.end_time_ns =
287 std::max(activity_api_event.end_time_ns,
288 kernel_activity_op.end_time_ns);
289 }
290 }
291 }
292 break;
293 }
294 default:
295 // nothing to do for the rest
296 result = true;
297 break;
298 }
299 if (!result) {
300 OnEventsDropped(
301 "A HIP-API activity with missing HIP-OPS activity was found",
302 activity_api_event.correlation_id);
303 DumpRocmTracerEvent(activity_api_event, 0, 0, ". Dropped!");
304 activity_api_events_map_.erase(activity_api_events_map_iter++);
305 } else {
306 ++activity_api_events_map_iter;
307 }
308 }
309
310 // the event vector to be returned
311 std::vector<RocmTracerEvent> aggregated_events;
312
313 // Copying info from HIP activities to HIP API callbacks
314 /*HIP-API call backs <<==== HIP-API activities*/
315 for (auto& api_iter : api_events_map_) {
316 RocmTracerEvent& api_event = api_iter.second;
317 auto iter = activity_api_events_map_.find(api_event.correlation_id);
318 switch (api_event.type) {
319 /*KERNEL API*/
320 case RocmTracerEventType::Kernel: {
321 aggregated_events.push_back(api_event);
322 break;
323 }
324 /*MEMCPY API*/
325 case RocmTracerEventType::MemcpyD2H:
326 case RocmTracerEventType::MemcpyH2D:
327 case RocmTracerEventType::MemcpyD2D:
328 case RocmTracerEventType::MemcpyOther: {
329 if (iter != activity_api_events_map_.end()) {
330 api_event.device_id = iter->second.device_id;
331 api_event.memcpy_info.destination =
332 api_event.device_id; // Similar to CUDA
333 aggregated_events.push_back(api_event);
334 } else {
335 OnEventsDropped(
336 "A Memcpy event from HIP API discarded."
337 " Could not find the counterpart activity.",
338 api_event.correlation_id);
339 DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
340 }
341 break;
342 }
343 /*MEMSET API*/
344 case RocmTracerEventType::Memset: {
345 if (iter != activity_api_events_map_.end()) {
346 api_event.device_id = iter->second.device_id;
347
348 aggregated_events.push_back(api_event);
349 } else {
350 OnEventsDropped(
351 "A Memset event from HIP API discarded."
352 " Could not find the counterpart activity.",
353 api_event.correlation_id);
354 DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
355 }
356 break;
357 }
358 /*MALLOC API, FREE API*/
359 case RocmTracerEventType::MemoryAlloc:
360 case RocmTracerEventType::MemoryFree: {
361 // no missing info
362 aggregated_events.push_back(api_event);
363 break;
364 }
365 /*SYNCHRONIZATION API*/
366 case RocmTracerEventType::Synchronization: {
367 // no missing info
368 aggregated_events.push_back(api_event);
369 break;
370 }
371 default:
372 OnEventsDropped("Missing API-Activity information exchange. Dropped!",
373 api_event.correlation_id);
374 DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
375 LOG(WARNING) << "A ROCm API event type with unimplemented activity "
376 "merge dropped! "
377 "Type="
378 << GetRocmTracerEventTypeName(api_event.type);
379 break;
380 } // end switch(api_event.type)
381 }
382
383 // Copying info from HIP API callbacks to HIP API activities
384 // API ACTIVITIES<<====API-CB
385 for (auto& activity_iter : activity_api_events_map_) {
386 RocmTracerEvent& activity_event = activity_iter.second;
387 // finding the corresponding activity either in the api_call backs or the
388 // axuilarities
389 auto iter = api_events_map_.find(activity_event.correlation_id);
390
391 iter = (iter == api_events_map_.end())
392 ? auxiliary_api_events_map_.find(activity_event.correlation_id)
393 : iter;
394 switch (activity_event.type) {
395 /*KERNEL ACTIVITY*/
396 case RocmTracerEventType::Kernel: {
397 if (iter != api_events_map_.end() ||
398 iter != auxiliary_api_events_map_.end()) {
399 activity_event.name = iter->second.name;
400 activity_event.kernel_info = iter->second.kernel_info;
401 aggregated_events.push_back(activity_event);
402 } else {
403 OnEventsDropped(
404 "A Kernel event activity was discarded."
405 " Could not find the counterpart API callback.",
406 activity_event.correlation_id);
407 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
408 }
409 break;
410 }
411 /*MEMCPY ACTIVITY*/
412 case RocmTracerEventType::MemcpyD2H:
413 case RocmTracerEventType::MemcpyH2D:
414 case RocmTracerEventType::MemcpyD2D:
415 case RocmTracerEventType::MemcpyOther: {
416 if (iter != api_events_map_.end() ||
417 iter != auxiliary_api_events_map_.end()) {
418 activity_event.memcpy_info = iter->second.memcpy_info;
419 aggregated_events.push_back(activity_event);
420 } else {
421 OnEventsDropped(
422 "A Memcpy event activity was discarded."
423 " Could not find the counterpart API callback.",
424 activity_event.correlation_id);
425 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
426 }
427 break;
428 }
429 /*MEMSET ACTIVITY*/
430 case RocmTracerEventType::Memset: {
431 if (iter != api_events_map_.end() ||
432 iter != auxiliary_api_events_map_.end()) {
433 activity_event.memset_info = iter->second.memset_info;
434 aggregated_events.push_back(activity_event);
435
436 } else {
437 OnEventsDropped(
438 "A Memset event activity was discarded."
439 " Could not find the counterpart API callback.",
440 activity_event.correlation_id);
441 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
442 }
443 break;
444 }
445 /*MALLOC ACTIVITY, FREE ACTIVITY*/
446 case RocmTracerEventType::MemoryAlloc:
447 case RocmTracerEventType::MemoryFree: {
448 if (iter != api_events_map_.end() ||
449 iter != auxiliary_api_events_map_.end()) {
450 activity_event.device_id = iter->second.device_id;
451 aggregated_events.push_back(activity_event);
452 } else {
453 OnEventsDropped(
454 "A Malloc/Free activity was discarded."
455 " Could not find the counterpart API callback.",
456 activity_event.correlation_id);
457 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
458 }
459 break;
460 }
461 /*SYNCHRONIZATION ACTIVITY*/
462 case RocmTracerEventType::Synchronization: {
463 if (iter != api_events_map_.end() ||
464 iter != auxiliary_api_events_map_.end()) {
465 // CUDA does not provide device ID for these activities.
466 // Interestingly, TF-profiler by default set the device id to 0 for
467 // CuptiTracerEvent.
468 // RocmTracerEvent type, set device by default to an unvalid
469 // device-id value. To be consistent with CUDA (in terms of having a
470 // logically valid value for device id) we update the device-id to
471 // its correct value
472 activity_event.device_id = iter->second.device_id;
473 aggregated_events.push_back(activity_event);
474 } else {
475 OnEventsDropped(
476 "A sync event activity was discarded."
477 " Could not find the counterpart API callback.",
478 activity_event.correlation_id);
479 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
480 }
481 break;
482 }
483 default:
484 OnEventsDropped("Missing API-Activity information exchange. Dropped!",
485 activity_event.correlation_id);
486 DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
487 LOG(WARNING) << "A ROCm activity event with unimplemented API "
488 "callback merge dropped! "
489 "Type="
490 << GetRocmTracerEventTypeName(activity_event.type);
491 break;
492 } // end switch(activity_event.type)
493 }
494
495 return aggregated_events;
496 }
497 struct RocmDeviceOccupancyParams {
498 hipFuncAttributes attributes = {};
499 int block_size = 0;
500 size_t dynamic_smem_size = 0;
501 void* func_ptr;
502
operator ==(const RocmDeviceOccupancyParams & lhs,const RocmDeviceOccupancyParams & rhs)503 friend bool operator==(const RocmDeviceOccupancyParams& lhs,
504 const RocmDeviceOccupancyParams& rhs) {
505 return 0 == memcmp(&lhs, &rhs, sizeof(lhs));
506 }
507
508 template <typename H>
AbslHashValue(H hash_state,const RocmDeviceOccupancyParams & params)509 friend H AbslHashValue(H hash_state,
510 const RocmDeviceOccupancyParams& params) {
511 return H::combine(
512 std::move(hash_state), params.attributes.maxThreadsPerBlock,
513 params.attributes.numRegs, params.attributes.sharedSizeBytes,
514 params.attributes.maxDynamicSharedSizeBytes, params.block_size,
515 params.dynamic_smem_size, params.func_ptr);
516 }
517 };
518
  // Result of a device-occupancy query for one kernel-launch configuration.
  struct OccupancyStats {
    double occupancy_pct = 0.0;    // theoretical occupancy, in percent
    int min_grid_size = 0;         // minimal grid size suggested by HIP
    int suggested_block_size = 0;  // block size suggested by HIP
  };
524 struct CorrelationInfo {
CorrelationInfotensorflow::profiler::RocmTraceCollectorImpl::CorrelationInfo525 CorrelationInfo(uint32_t t, uint32_t e)
526 : thread_id(t), enqueue_time_ns(e) {}
527 uint32_t thread_id;
528 uint64_t enqueue_time_ns;
529 };
530
531 struct PerDeviceCollector {
    // Records the device vendor stat and, when hipGetDeviceProperties()
    // succeeds, the hardware capability stats (clock rate, core count, memory
    // bandwidth/size, compute capability) on the device XPlane. Each stat is
    // only emitted when the queried value is non-zero. Also caches the device
    // properties in device_properties_ for later occupancy computations.
    void GetDeviceCapabilities(int32_t device_ordinal,
                               XPlaneBuilder* device_plane) {
      device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata(
                                     GetStatTypeStr(StatType::kDevVendor)),
                                 kDeviceVendorAMD);

      // If the property query fails, only the vendor stat is emitted.
      if (hipGetDeviceProperties(&device_properties_, device_ordinal) !=
          hipSuccess)
        return;

      auto clock_rate_in_khz =
          device_properties_.clockRate;  // this is also in Khz
      if (clock_rate_in_khz) {
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapClockRateKHz)),
            clock_rate_in_khz);
      }

      auto core_count = device_properties_.multiProcessorCount;
      if (core_count) {
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapCoreCount)),
            core_count);
      }

      auto mem_clock_khz = device_properties_.memoryClockRate;
      auto mem_bus_width_bits = device_properties_.memoryBusWidth;

      if (mem_clock_khz && mem_bus_width_bits) {
        // Times 2 because HBM is DDR memory; it gets two data bits per each
        // data lane.
        auto memory_bandwidth =
            uint64{2} * (mem_clock_khz)*1000 * (mem_bus_width_bits) / 8;
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapMemoryBandwidth)),
            memory_bandwidth);
      }

      size_t total_memory = device_properties_.totalGlobalMem;
      if (total_memory) {
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapMemorySize)),
            static_cast<uint64>(total_memory));
      }

      // NOTE(review): hipDeviceProp_t major/minor are reported under the CUDA
      // compute-capability stat keys for parity with the CUDA profiler path —
      // confirm this is the intended meaning on ROCm devices.
      auto compute_capability_major = device_properties_.major;
      if (compute_capability_major) {
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapComputeCapMajor)),
            compute_capability_major);
      }
      auto compute_capability_minor = device_properties_.minor;
      if (compute_capability_minor) {
        device_plane->AddStatValue(
            *device_plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kDevCapComputeCapMinor)),
            compute_capability_minor);
      }
    }
596
ToXStattensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector597 inline std::string ToXStat(const KernelDetails& kernel_info,
598 double occupancy_pct) {
599 return absl::StrCat(
600 "regs:", kernel_info.registers_per_thread,
601 " static_shared:", kernel_info.static_shared_memory_usage,
602 " dynamic_shared:", kernel_info.dynamic_shared_memory_usage,
603 " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",",
604 kernel_info.grid_z, " block:", kernel_info.block_x, ",",
605 kernel_info.block_y, ",", kernel_info.block_z,
606 " occ_pct:", occupancy_pct);
607 }
GetOccupancytensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector608 OccupancyStats GetOccupancy(const RocmDeviceOccupancyParams& params) const {
609 // TODO(rocm-profiler): hipOccupancyMaxActiveBlocksPerMultiprocessor only
610 // return hipSuccess for HIP_API_ID_hipLaunchKernel
611
612 OccupancyStats stats;
613 int number_of_active_blocks;
614 hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
615 &number_of_active_blocks, params.func_ptr, params.block_size,
616 params.dynamic_smem_size);
617
618 if (err != hipError_t::hipSuccess) {
619 return {};
620 }
621
622 stats.occupancy_pct = number_of_active_blocks * params.block_size * 100;
623 stats.occupancy_pct /= device_properties_.maxThreadsPerMultiProcessor;
624
625 err = hipOccupancyMaxPotentialBlockSize(
626 &stats.min_grid_size, &stats.suggested_block_size, params.func_ptr,
627 params.dynamic_smem_size, 0);
628
629 if (err != hipError_t::hipSuccess) {
630 return {};
631 }
632
633 return stats;
634 }
AddEventtensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector635 void AddEvent(const RocmTracerEvent& event) {
636 mutex_lock l(events_mutex);
637 if (event.source == RocmTracerEventSource::ApiCallback) {
638 // Cupti api callback events were used to populate launch times etc.
639 if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) {
640 correlation_info_.insert(
641 {event.correlation_id,
642 CorrelationInfo(event.thread_id, event.start_time_ns)});
643 }
644 events.emplace_back(std::move(event));
645 } else {
646 // Cupti activity events measure device times etc.
647 events.emplace_back(std::move(event));
648 }
649 }
650
SortByStartTimetensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector651 void SortByStartTime() {
652 mutex_lock lock(events_mutex);
653 std::sort(
654 events.begin(), events.end(),
655 [](const RocmTracerEvent& event1, const RocmTracerEvent& event2) {
656 return event1.start_time_ns < event2.start_time_ns;
657 });
658 }
659
    // Translates one RocmTracerEvent into an XEvent on `line`, attaching
    // type-specific stats:
    //  - Kernel activities: theoretical occupancy (memoized per launch
    //    configuration) and a kernel-details string.
    //  - Memcpy / MemoryAlloc / MemoryFree / Memset: "key:value" detail
    //    strings.
    //  - Annotation stack: a TF-op stat plus per-key metadata stats.
    // Events with timestamps outside [start_gpu_ns, end_gpu_ns], or with an
    // inverted start/end, are skipped with a VLOG.
    void CreateXEvent(const RocmTracerEvent& event, XPlaneBuilder* plane,
                      uint64_t start_gpu_ns, uint64_t end_gpu_ns,
                      XLineBuilder* line) {
      if (event.start_time_ns < start_gpu_ns ||
          event.end_time_ns > end_gpu_ns ||
          event.start_time_ns > event.end_time_ns) {
        VLOG(2) << "events have abnormal timestamps:" << event.name
                << " start time(ns): " << event.start_time_ns
                << " end time(ns): " << event.end_time_ns
                << " start gpu(ns):" << start_gpu_ns
                << " end gpu(ns):" << end_gpu_ns
                << " corr. id:" << event.correlation_id;
        return;
      }
      // Use the demangled kernel name for display; fall back to the
      // event-type name when the event carries no name.
      std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str());
      if (kernel_name.empty()) {
        kernel_name = GetRocmTracerEventTypeName(event.type);
      }
      XEventMetadata* event_metadata =
          plane->GetOrCreateEventMetadata(std::move(kernel_name));
      XEventBuilder xevent = line->AddEvent(*event_metadata);
      VLOG(7) << "Adding event to line=" << line->Id();
      xevent.SetTimestampNs(event.start_time_ns);
      xevent.SetEndTimestampNs(event.end_time_ns);
      if (event.source == RocmTracerEventSource::ApiCallback) {
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kDeviceId)),
                            event.device_id);
      }
      if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) {
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kCorrelationId)),
                            event.correlation_id);
      }
      if (!event.roctx_range.empty()) {
        // roctx ranges are reported under the NVTX stat key for parity with
        // the CUDA profiler path.
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kNVTXRange)),
                            *plane->GetOrCreateStatMetadata(event.roctx_range));
      }
      // if (event.context_id != CuptiTracerEvent::kInvalidContextId) {
      //   xevent.AddStatValue(
      //       *plane->GetOrCreateStatMetadata(
      //           GetStatTypeStr(StatType::kContextId)),
      //       absl::StrCat("$$", static_cast<uint64>(event.context_id)));
      // }

      if (event.type == RocmTracerEventType::Kernel &&
          event.source == RocmTracerEventSource::Activity) {
        RocmDeviceOccupancyParams params{};
        params.attributes.maxThreadsPerBlock = INT_MAX;
        params.attributes.numRegs =
            static_cast<int>(event.kernel_info.registers_per_thread);
        params.attributes.sharedSizeBytes =
            event.kernel_info.static_shared_memory_usage;
        // params.attributes.partitionedGCConfig = PARTITIONED_GC_OFF;
        // params.attributes.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
        params.attributes.maxDynamicSharedSizeBytes = 0;
        params.block_size = static_cast<int>(event.kernel_info.block_x *
                                             event.kernel_info.block_y *
                                             event.kernel_info.block_z);

        params.dynamic_smem_size =
            event.kernel_info.dynamic_shared_memory_usage;
        params.func_ptr = event.kernel_info.func_ptr;

        // Occupancy is memoized per launch configuration; a cached 0.0 is
        // treated as "not computed yet" and triggers a (re)query.
        OccupancyStats& occ_stats = occupancy_cache_[params];
        if (occ_stats.occupancy_pct == 0.0) {
          occ_stats = GetOccupancy(params);
        }
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
                                StatType::kTheoreticalOccupancyPct)),
                            occ_stats.occupancy_pct);
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
                                StatType::kOccupancyMinGridSize)),
                            static_cast<int32>(occ_stats.min_grid_size));
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
                                StatType::kOccupancySuggestedBlockSize)),
                            static_cast<int32>(occ_stats.suggested_block_size));
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kKernelDetails)),
                            *plane->GetOrCreateStatMetadata(ToXStat(
                                event.kernel_info, occ_stats.occupancy_pct)));
      } else if (event.type == RocmTracerEventType::MemcpyH2D ||
                 event.type == RocmTracerEventType::MemcpyD2H ||
                 event.type == RocmTracerEventType::MemcpyD2D ||
                 event.type == RocmTracerEventType::MemcpyP2P ||
                 event.type == RocmTracerEventType::MemcpyOther) {
        VLOG(7) << "Add Memcpy stat";
        const auto& memcpy_info = event.memcpy_info;
        std::string memcpy_details = absl::StrCat(
            // TODO(rocm-profiler): we need to discover the memory kind similar
            // to CUDA
            "kind:", "Unknown", " size:", memcpy_info.num_bytes,
            " dest:", memcpy_info.destination, " async:", memcpy_info.async);
        xevent.AddStatValue(
            *plane->GetOrCreateStatMetadata(
                GetStatTypeStr(StatType::kMemcpyDetails)),
            *plane->GetOrCreateStatMetadata(std::move(memcpy_details)));
      } else if (event.type == RocmTracerEventType::MemoryAlloc) {
        VLOG(7) << "Add MemAlloc stat";
        std::string value =
            // TODO(rocm-profiler): we need to discover the memory kind similar
            // to CUDA
            absl::StrCat("kind:", "Unknown",
                         " num_bytes:", event.memalloc_info.num_bytes);
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kMemallocDetails)),
                            *plane->GetOrCreateStatMetadata(std::move(value)));
      } else if (event.type == RocmTracerEventType::MemoryFree) {
        VLOG(7) << "Add MemFree stat";
        std::string value =
            // TODO(rocm-profiler): we need to discover the memory kind similar
            // to CUDA
            absl::StrCat("kind:", "Unknown",
                         " num_bytes:", event.memalloc_info.num_bytes);
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kMemFreeDetails)),
                            *plane->GetOrCreateStatMetadata(std::move(value)));
      } else if (event.type == RocmTracerEventType::Memset) {
        VLOG(7) << "Add Memset stat";
        auto value =
            // TODO(rocm-profiler): we need to discover the memory kind similar
            // to CUDA
            absl::StrCat("kind:", "Unknown",
                         " num_bytes:", event.memset_info.num_bytes,
                         " async:", event.memset_info.async);
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                                GetStatTypeStr(StatType::kMemsetDetails)),
                            *plane->GetOrCreateStatMetadata(std::move(value)));
      }
      // TODO(rocm-profiler): we need to support the following event type
      /* else if (event.type == CuptiTracerEventType::MemoryResidency) {
        VLOG(7) << "Add MemoryResidency stat";
        std::string value = absl::StrCat(
            "kind:", GetMemoryKindName(event.memory_residency_info.kind),
            " num_bytes:", event.memory_residency_info.num_bytes,
            " addr:", event.memory_residency_info.address);
        xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
                                StatType::kMemoryResidencyDetails)),
                            *plane->GetOrCreateStatMetadata(std::move(value)));
      } */

      std::vector<Annotation> annotation_stack =
          ParseAnnotationStack(event.annotation);
      if (!annotation_stack.empty()) {
        xevent.AddStatValue(
            *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)),
            *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name));
      }
      // If multiple metadata have the same key name, show the values from the
      // top of the stack (innermost annotation). Concatenate the values from
      // "hlo_op".
      absl::flat_hash_set<absl::string_view> key_set;

      for (auto annotation = annotation_stack.rbegin();
           annotation != annotation_stack.rend(); ++annotation) {
        for (const Annotation::Metadata& metadata : annotation->metadata) {
          if (key_set.insert(metadata.key).second) {
            xevent.ParseAndAddStatValue(
                *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
          }
        }
      }
    }
IsHostEventtensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector824 bool IsHostEvent(const RocmTracerEvent& event, int64* line_id) {
825 // DriverCallback(i.e. kernel launching) events are host events.
826 if (event.source == RocmTracerEventSource::ApiCallback) {
827 *line_id = event.thread_id;
828 return true;
829 } else { // activities
830 *line_id = event.stream_id;
831 return false;
832 }
833
834 // TODO(rocm-profiler): do we have such a report in rocm?
835 // Non-overhead activity events are device events.
836 /* if (event.type != CuptiTracerEventType::Overhead) {
837 *line_id = event.stream_id;
838 return false;
839 } */
840 // Overhead events can be associated with a thread or a stream, etc.
841 // If a valid thread id is specified, we consider it as a host event.
842 //
843
844 if (event.stream_id != RocmTracerEvent::kInvalidStreamId) {
845 *line_id = event.stream_id;
846 return false;
847 } else if (event.thread_id != RocmTracerEvent::kInvalidThreadId &&
848 event.thread_id != 0) {
849 *line_id = event.thread_id;
850 return true;
851 } else {
852 *line_id = kThreadIdOverhead;
853 return false;
854 }
855 }
Exporttensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector856 void Export(uint64_t start_walltime_ns, uint64_t start_gputime_ns,
857 uint64_t end_gputime_ns, XPlaneBuilder* device_plane,
858 XPlaneBuilder* host_plane) {
859 int host_ev_cnt = 0, dev_ev_cnt = 0;
860 mutex_lock l(events_mutex);
861 // Tracking event types per line.
862 absl::flat_hash_map<int64, absl::flat_hash_set<RocmTracerEventType>>
863 events_types_per_line;
864 for (const RocmTracerEvent& event : events) {
865 int64_t line_id = RocmTracerEvent::kInvalidThreadId;
866 bool is_host_event = IsHostEvent(event, &line_id);
867
868 if (is_host_event) {
869 host_ev_cnt++;
870 } else {
871 dev_ev_cnt++;
872 }
873
874 if (line_id == RocmTracerEvent::kInvalidThreadId ||
875 line_id == RocmTracerEvent::kInvalidStreamId) {
876 VLOG(3) << "Ignoring event, type=" << static_cast<int>(event.type);
877 continue;
878 }
879 auto* plane = is_host_event ? host_plane : device_plane;
880 VLOG(9) << "Event"
881 << " type=" << static_cast<int>(event.type)
882 << " line_id=" << line_id
883 << (is_host_event ? " host plane=" : " device plane=")
884 << plane->Name();
885 XLineBuilder line = plane->GetOrCreateLine(line_id);
886 line.SetTimestampNs(start_gputime_ns);
887 CreateXEvent(event, plane, start_gputime_ns, end_gputime_ns, &line);
888 events_types_per_line[line_id].emplace(event.type);
889 }
890 device_plane->ForEachLine([&](XLineBuilder line) {
891 line.SetName(
892 GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()]));
893 });
894 host_plane->ForEachLine([&](XLineBuilder line) {
895 line.SetName(absl::StrCat("Host Threads/", line.Id()));
896 });
897 size_t num_events = events.size();
898 events.clear();
899 }
900
  // Guards the per-device event buffer and the correlation map below.
  mutex events_mutex;
  // Events buffered since the last Export(); cleared by Export().
  std::vector<RocmTracerEvent> events TF_GUARDED_BY(events_mutex);
  // Maps correlation id -> launch-time info recorded from API callbacks.
  absl::flat_hash_map<uint32, CorrelationInfo> correlation_info_
      TF_GUARDED_BY(events_mutex);
  // Caches occupancy computations keyed by launch parameters.
  absl::flat_hash_map<RocmDeviceOccupancyParams, OccupancyStats>
      occupancy_cache_;
  // Device properties used for occupancy calculation.
  hipDeviceProp_t device_properties_;
908 };
909
910 absl::FixedArray<PerDeviceCollector> per_device_collector_;
911 };
912
913 // GpuTracer for ROCm GPU.
914 class GpuTracer : public profiler::ProfilerInterface {
915 public:
GpuTracer(RocmTracer * rocm_tracer)916 GpuTracer(RocmTracer* rocm_tracer) : rocm_tracer_(rocm_tracer) {
917 LOG(INFO) << "GpuTracer created.";
918 }
~GpuTracer()919 ~GpuTracer() override {}
920
921 // GpuTracer interface:
922 Status Start() override;
923 Status Stop() override;
924 Status CollectData(XSpace* space) override;
925
926 private:
927 Status DoStart();
928 Status DoStop();
929 Status DoCollectData(XSpace* space);
930
931 RocmTracerOptions GetRocmTracerOptions();
932
933 RocmTraceCollectorOptions GetRocmTraceCollectorOptions(uint32_t num_gpus);
934
935 enum State {
936 kNotStarted,
937 kStartedOk,
938 kStartedError,
939 kStoppedOk,
940 kStoppedError
941 };
942 State profiling_state_ = State::kNotStarted;
943
944 RocmTracer* rocm_tracer_;
945 std::unique_ptr<RocmTraceCollectorImpl> rocm_trace_collector_;
946 };
947
GetRocmTracerOptions()948 RocmTracerOptions GpuTracer::GetRocmTracerOptions() {
949 // TODO(rocm-profiler): We need support for context similar to CUDA
950 RocmTracerOptions options;
951 std::vector<uint32_t> empty_vec;
952
953 // clang formatting does not preserve one entry per line
954 // clang-format off
955 std::vector<uint32_t> hip_api_domain_ops{
956 // KERNEL
957 HIP_API_ID_hipExtModuleLaunchKernel,
958 HIP_API_ID_hipModuleLaunchKernel,
959 HIP_API_ID_hipHccModuleLaunchKernel,
960 HIP_API_ID_hipLaunchKernel,
961 // MEMCPY
962 HIP_API_ID_hipMemcpy,
963 HIP_API_ID_hipMemcpyAsync,
964 HIP_API_ID_hipMemcpyDtoD,
965 HIP_API_ID_hipMemcpyDtoDAsync,
966 HIP_API_ID_hipMemcpyDtoH,
967 HIP_API_ID_hipMemcpyDtoHAsync,
968 HIP_API_ID_hipMemcpyHtoD,
969 HIP_API_ID_hipMemcpyHtoDAsync,
970 HIP_API_ID_hipMemcpyPeer,
971 HIP_API_ID_hipMemcpyPeerAsync,
972
973 // MEMSet
974 HIP_API_ID_hipMemsetD32,
975 HIP_API_ID_hipMemsetD32Async,
976 HIP_API_ID_hipMemsetD16,
977 HIP_API_ID_hipMemsetD16Async,
978 HIP_API_ID_hipMemsetD8,
979 HIP_API_ID_hipMemsetD8Async,
980 HIP_API_ID_hipMemset,
981 HIP_API_ID_hipMemsetAsync,
982
983 // MEMAlloc
984 HIP_API_ID_hipMalloc,
985 HIP_API_ID_hipMallocPitch,
986 // MEMFree
987 HIP_API_ID_hipFree,
988 // GENERIC
989 HIP_API_ID_hipStreamSynchronize,
990 };
991 // clang-format on
992
993 options.api_tracking_set =
994 std::set<uint32_t>(hip_api_domain_ops.begin(), hip_api_domain_ops.end());
995
996 // These are the list of APIs we track since roctracer activity
997 // does not provide all the information necessary to fully populate the
998 // TF events. We need to track the APIs for those activities in API domain but
999 // we only use them for filling the missing items in their corresponding
1000 // activity (using correlation id).
1001 // clang-format off
1002 std::vector<uint32_t> hip_api_aux_ops{
1003 HIP_API_ID_hipStreamWaitEvent,
1004 // TODO(rocm-profiler): finding device ID from hipEventSynchronize need some
1005 // extra work, we ignore it for now.
1006 // HIP_API_ID_hipEventSynchronize,
1007 HIP_API_ID_hipHostFree,
1008 HIP_API_ID_hipHostMalloc,
1009 HIP_API_ID_hipSetDevice // added to track default device
1010 };
1011 // clang-format on
1012
1013 hip_api_domain_ops.insert(hip_api_domain_ops.end(), hip_api_aux_ops.begin(),
1014 hip_api_aux_ops.end());
1015
1016 options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, hip_api_domain_ops);
1017 // options.api_callbacks.emplace(ACTIVITY_DOMAIN_ROCTX, empty_vec);
1018 // options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec);
1019
1020 // options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API,
1021 // hip_api_domain_ops);
1022 options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec);
1023 options.activity_tracing.emplace(ACTIVITY_DOMAIN_HCC_OPS, empty_vec);
1024
1025 return options;
1026 }
1027
GetRocmTraceCollectorOptions(uint32_t num_gpus)1028 RocmTraceCollectorOptions GpuTracer::GetRocmTraceCollectorOptions(
1029 uint32_t num_gpus) {
1030 RocmTraceCollectorOptions options;
1031 options.max_callback_api_events = 2 * 1024 * 1024;
1032 options.max_activity_api_events = 2 * 1024 * 1024;
1033 options.max_annotation_strings = 1024 * 1024;
1034 options.num_gpus = num_gpus;
1035 return options;
1036 }
1037
DoStart()1038 Status GpuTracer::DoStart() {
1039 if (!rocm_tracer_->IsAvailable()) {
1040 return errors::Unavailable("Another profile session running.");
1041 }
1042
1043 AnnotationStack::Enable(true);
1044
1045 RocmTraceCollectorOptions trace_collector_options =
1046 GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus());
1047 uint64_t start_gputime_ns = RocmTracer::GetTimestamp();
1048 uint64_t start_walltime_ns = tensorflow::EnvTime::NowNanos();
1049 rocm_trace_collector_ = std::make_unique<RocmTraceCollectorImpl>(
1050 trace_collector_options, start_walltime_ns, start_gputime_ns);
1051
1052 RocmTracerOptions tracer_options = GetRocmTracerOptions();
1053 rocm_tracer_->Enable(tracer_options, rocm_trace_collector_.get());
1054
1055 return Status::OK();
1056 }
1057
Start()1058 Status GpuTracer::Start() {
1059 Status status = DoStart();
1060 if (status.ok()) {
1061 profiling_state_ = State::kStartedOk;
1062 return Status::OK();
1063 } else {
1064 profiling_state_ = State::kStartedError;
1065 return status;
1066 }
1067 }
1068
// Disables roctracer collection and turns off annotation capture.
// Only called from Stop() after a successful start.
Status GpuTracer::DoStop() {
  rocm_tracer_->Disable();
  AnnotationStack::Enable(false);
  return Status::OK();
}
1074
Stop()1075 Status GpuTracer::Stop() {
1076 if (profiling_state_ == State::kStartedOk) {
1077 Status status = DoStop();
1078 profiling_state_ = status.ok() ? State::kStoppedOk : State::kStoppedError;
1079 }
1080 return Status::OK();
1081 }
1082
DoCollectData(XSpace * space)1083 Status GpuTracer::DoCollectData(XSpace* space) {
1084 if (rocm_trace_collector_) rocm_trace_collector_->Export(space);
1085 return Status::OK();
1086 }
1087
CollectData(XSpace * space)1088 Status GpuTracer::CollectData(XSpace* space) {
1089 switch (profiling_state_) {
1090 case State::kNotStarted:
1091 VLOG(3) << "No trace data collected, session wasn't started";
1092 return Status::OK();
1093 case State::kStartedOk:
1094 return errors::FailedPrecondition("Cannot collect trace before stopping");
1095 case State::kStartedError:
1096 LOG(ERROR) << "Cannot collect, roctracer failed to start";
1097 return Status::OK();
1098 case State::kStoppedError:
1099 VLOG(3) << "No trace data collected";
1100 return Status::OK();
1101 case State::kStoppedOk: {
1102 DoCollectData(space);
1103 return Status::OK();
1104 }
1105 }
1106 return errors::Internal("Invalid profiling state: ", profiling_state_);
1107 }
1108
1109 // Not in anonymous namespace for testing purposes.
CreateGpuTracer(const ProfileOptions & options)1110 std::unique_ptr<profiler::ProfilerInterface> CreateGpuTracer(
1111 const ProfileOptions& options) {
1112 if (options.device_type() != ProfileOptions::GPU &&
1113 options.device_type() != ProfileOptions::UNSPECIFIED)
1114 return nullptr;
1115
1116 profiler::RocmTracer* rocm_tracer =
1117 profiler::RocmTracer::GetRocmTracerSingleton();
1118 if (!rocm_tracer->IsAvailable()) return nullptr;
1119
1120 return std::make_unique<profiler::GpuTracer>(rocm_tracer);
1121 }
1122
// Static registration: hooks the ROCm GPU tracer factory into the profiler
// framework at load time via an immediately-invoked lambda.
auto register_rocm_gpu_tracer_factory = [] {
  RegisterProfilerFactory(&CreateGpuTracer);
  return 0;
}();
1127
1128 } // namespace profiler
1129 } // namespace tensorflow
1130
1131 #endif // TENSORFLOW_USE_ROCM
1132