xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/backends/gpu/device_tracer_rocm.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #if TENSORFLOW_USE_ROCM
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/container/fixed_array.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/str_format.h"
26 #include "absl/strings/str_join.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/platform/abi.h"
29 #include "tensorflow/core/platform/env_time.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 #include "tensorflow/core/profiler/backends/cpu/annotation_stack.h"
34 #include "tensorflow/core/profiler/backends/gpu/rocm_tracer.h"
35 #include "tensorflow/core/profiler/lib/profiler_factory.h"
36 #include "tensorflow/core/profiler/lib/profiler_interface.h"
37 #include "tensorflow/core/profiler/utils/parse_annotation.h"
38 #include "tensorflow/core/profiler/utils/xplane_builder.h"
39 #include "tensorflow/core/profiler/utils/xplane_schema.h"
40 #include "tensorflow/core/profiler/utils/xplane_utils.h"
41 #include "tensorflow/core/util/env_var.h"
42 
43 namespace tensorflow {
44 namespace profiler {
45 
46 namespace {
47 // Set the all XLines of specified XPlane to starting walltime.
48 // Events time in both host and device planes are CUTPI timestamps.
49 // We set initial RocmTracer timestamp as start time for all lines to reflect
50 // this fact. Eventually we change line start time to corresponding
51 // start_walltime_ns to normalize with CPU wall time.
NormalizeTimeStamps(XPlaneBuilder * plane,uint64_t start_walltime_ns)52 static void NormalizeTimeStamps(XPlaneBuilder* plane,
53                                 uint64_t start_walltime_ns) {
54   plane->ForEachLine([&](tensorflow::profiler::XLineBuilder line) {
55     line.SetTimestampNs(start_walltime_ns);
56   });
57 }
58 
GetDeviceXLineName(int64_t stream_id,absl::flat_hash_set<RocmTracerEventType> & event_types)59 std::string GetDeviceXLineName(
60     int64_t stream_id, absl::flat_hash_set<RocmTracerEventType>& event_types) {
61   std::string line_name = absl::StrCat("Stream #", stream_id);
62   event_types.erase(RocmTracerEventType::Unsupported);
63   if (event_types.empty()) return line_name;
64   std::vector<const char*> type_names;
65   for (const auto event_type : event_types) {
66     type_names.emplace_back(GetRocmTracerEventTypeName(event_type));
67   }
68   return absl::StrCat(line_name, "(", absl::StrJoin(type_names, ","), ")");
69 }
70 
71 }  // namespace
72 
73 class RocmTraceCollectorImpl : public profiler::RocmTraceCollector {
74  public:
RocmTraceCollectorImpl(const RocmTraceCollectorOptions & options,uint64_t start_walltime_ns,uint64_t start_gputime_ns)75   RocmTraceCollectorImpl(const RocmTraceCollectorOptions& options,
76                          uint64_t start_walltime_ns, uint64_t start_gputime_ns)
77       : RocmTraceCollector(options),
78         num_callback_events_(0),
79         num_activity_events_(0),
80         start_walltime_ns_(start_walltime_ns),
81         start_gputime_ns_(start_gputime_ns),
82         per_device_collector_(options.num_gpus) {}
83 
AddEvent(RocmTracerEvent && event,bool is_auxiliary)84   void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) override {
85     mutex_lock lock(event_maps_mutex_);
86 
87     if (event.source == RocmTracerEventSource::ApiCallback && !is_auxiliary) {
88       if (num_callback_events_ > options_.max_callback_api_events) {
89         OnEventsDropped("max callback event capacity reached",
90                         event.correlation_id);
91         DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
92         return;
93       }
94       num_callback_events_++;
95     } else if (event.source == RocmTracerEventSource::Activity &&
96                event.domain == RocmTracerEventDomain::HIP_API) {
97       // we do not count HIP_OPS activities.
98       if (num_activity_events_ > options_.max_activity_api_events) {
99         OnEventsDropped("max activity event capacity reached",
100                         event.correlation_id);
101         DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
102         return;
103       }
104       num_activity_events_++;
105     }
106 
107     bool emplace_result = false;
108     if (event.source == RocmTracerEventSource::ApiCallback) {
109       auto& target_api_event_map =
110           (is_auxiliary) ? auxiliary_api_events_map_ : api_events_map_;
111       std::tie(std::ignore, emplace_result) =
112           target_api_event_map.emplace(event.correlation_id, std::move(event));
113     } else if (event.source == RocmTracerEventSource::Activity) {
114       if (event.domain == RocmTracerEventDomain::HIP_API) {
115         std::tie(std::ignore, emplace_result) =
116             activity_api_events_map_.emplace(event.correlation_id,
117                                              std::move(event));
118       } else if (event.domain == RocmTracerEventDomain::HCC_OPS) {
119         auto result = activity_ops_events_map_.emplace(
120             event.correlation_id, std::vector<RocmTracerEvent>{});
121         result.first->second.push_back(std::move(event));
122         emplace_result = true;  // we always accept Hip-Ops events
123       }
124     }
125     if (!emplace_result) {
126       OnEventsDropped("event with duplicate correlation_id was received.",
127                       event.correlation_id);
128       DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
129     }
130   }
131 
OnEventsDropped(const std::string & reason,uint32_t correlation_id)132   void OnEventsDropped(const std::string& reason,
133                        uint32_t correlation_id) override {
134     LOG(INFO) << "RocmTracerEvent dropped (correlation_id=" << correlation_id
135               << ",) : " << reason << ".";
136   }
137 
Flush()138   void Flush() override {
139     mutex_lock lock(event_maps_mutex_);
140     auto& aggregated_events_ = ApiActivityInfoExchange();
141 
142     VLOG(3) << "RocmTraceCollector collected " << num_callback_events_
143             << " callback events, " << num_activity_events_
144             << " activity events, and aggregated them into "
145             << aggregated_events_.size() << " events.";
146 
147     for (auto& event : aggregated_events_) {
148       if (event.device_id >= options_.num_gpus) {
149         OnEventsDropped("device id >= num gpus", event.correlation_id);
150         DumpRocmTracerEvent(event, 0, 0, ". Dropped!");
151         LOG(WARNING) << "A ROCm profiler event record with wrong device ID "
152                         "dropped! Type="
153                      << GetRocmTracerEventTypeName(event.type);
154         continue;
155       }
156 
157       activity_api_events_map_.clear();
158       activity_ops_events_map_.clear();
159       api_events_map_.clear();
160       auxiliary_api_events_map_.clear();
161 
162       per_device_collector_[event.device_id].AddEvent(event);
163     }
164 
165     for (int i = 0; i < options_.num_gpus; ++i) {
166       per_device_collector_[i].SortByStartTime();
167     }
168   }
169 
Export(XSpace * space)170   void Export(XSpace* space) {
171     uint64_t end_gputime_ns = RocmTracer::GetTimestamp();
172     XPlaneBuilder host_plane(
173         FindOrAddMutablePlaneWithName(space, kRoctracerApiPlaneName));
174     for (int i = 0; i < options_.num_gpus; ++i) {
175       std::string name = GpuPlaneName(i);
176       XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name));
177       device_plane.SetId(i);
178       // Calculate device capabilities before flushing, so that device
179       // properties are available to the occupancy calculator in export().
180       per_device_collector_[i].GetDeviceCapabilities(i, &device_plane);
181       per_device_collector_[i].Export(start_walltime_ns_, start_gputime_ns_,
182                                       end_gputime_ns, &device_plane,
183                                       &host_plane);
184 
185       NormalizeTimeStamps(&device_plane, start_walltime_ns_);
186     }
187     NormalizeTimeStamps(&host_plane, start_walltime_ns_);
188   }
189 
 private:
  // Counts of accepted non-auxiliary API callback events and HIP_API activity
  // events; AddEvent compares them against the capacity limits in options_.
  std::atomic<int> num_callback_events_;
  std::atomic<int> num_activity_events_;
  // Session start timestamps as passed to the constructor; both are forwarded
  // to each PerDeviceCollector::Export, and start_walltime_ns_ is also used to
  // rebase plane lines in Export.
  uint64_t start_walltime_ns_;
  uint64_t start_gputime_ns_;

  // Guards all of the per-correlation-id event maps below.
  mutex event_maps_mutex_;
  // Non-auxiliary HIP API callback events, keyed by correlation id.
  absl::flat_hash_map<uint32, RocmTracerEvent> api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);
  // HIP_API-domain activity events, keyed by correlation id.
  absl::flat_hash_map<uint32, RocmTracerEvent> activity_api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);

  /* Some APIs, such as MEMSETD32 (observed with ResNet50), trigger multiple
    HIP-OPS-domain activities. We keep them in a vector per correlation id and
    merge them with the API activities at flush time.
  */
  absl::flat_hash_map<uint32, std::vector<RocmTracerEvent>>
      activity_ops_events_map_ TF_GUARDED_BY(event_maps_mutex_);
  // This is for the APIs that we track because we need some information from
  // them to populate the corresponding activity that we actually track.
  absl::flat_hash_map<uint32, RocmTracerEvent> auxiliary_api_events_map_
      TF_GUARDED_BY(event_maps_mutex_);
212 
ApiActivityInfoExchange()213   const std::vector<RocmTracerEvent> ApiActivityInfoExchange() {
214     /* Different from CUDA, roctracer activity records are not enough to fill a
215       TF event. For most of the activities, we need to enable the corresponding
216       API callsbacks (we call them auxiliary API callbacks) to capture the
217       necessary fields from them using the correlation id. The purpose of this
218       function is to let APIs and activities exchange information to reach a
219       state very similar to TF CUDA and getting ready to dump the event.
220     */
221 
222     // Copying info from HIP-OPS activities to HIP-API activities
223     /*HIP-API activities <<==== HIP-OPS activities*/
224     auto activity_api_events_map_iter = activity_api_events_map_.begin();
225     while (activity_api_events_map_iter != activity_api_events_map_.end()) {
226       uint32_t activity_corr_id = activity_api_events_map_iter->first;
227       RocmTracerEvent& activity_api_event =
228           activity_api_events_map_iter->second;
229 
230       bool result = false;
231       switch (activity_api_event.type) {
232         case RocmTracerEventType::Kernel:
233         case RocmTracerEventType::Memset: {
234           // KERNEL & MEMSET
235           auto iter =
236               activity_ops_events_map_.find(activity_api_event.correlation_id);
237           result = (iter != activity_ops_events_map_.end());
238           if (result) {
239             // since the key exist in the map, there should be at least one item
240             // in the vector
241             activity_api_event.device_id = iter->second.front().device_id;
242             activity_api_event.stream_id = iter->second.front().stream_id;
243             // we initialize the start time and end time based on the first
244             // element
245             activity_api_event.start_time_ns =
246                 iter->second.front().start_time_ns;
247             activity_api_event.end_time_ns = iter->second.front().end_time_ns;
248             for (auto& kernel_activity_op : iter->second) {
249               activity_api_event.start_time_ns =
250                   std::min(activity_api_event.start_time_ns,
251                            kernel_activity_op.start_time_ns);
252               activity_api_event.end_time_ns =
253                   std::max(activity_api_event.end_time_ns,
254                            kernel_activity_op.end_time_ns);
255             }
256           }
257           break;
258         }
259         case RocmTracerEventType::MemcpyD2D:
260         case RocmTracerEventType::MemcpyH2D:
261         case RocmTracerEventType::MemcpyD2H:
262         case RocmTracerEventType::MemcpyOther: {
263           // MEMCPY
264           auto iter =
265               activity_ops_events_map_.find(activity_api_event.correlation_id);
266           result = (iter != activity_ops_events_map_.end());
267           if (result) {
268             // since the key exist in the map, there should be at least one item
269             // in the vector
270             activity_api_event.device_id = iter->second.front().device_id;
271             activity_api_event.memcpy_info.destination =
272                 iter->second.front()
273                     .memcpy_info.destination;  // similar to CUDA, it is the
274                                                // same as device_id
275             activity_api_event.stream_id = iter->second.front().stream_id;
276             /* IMPORTANT: it seems that the HCC timing is only valid for
277              * Synchronous memcpy activities*/
278             if (!activity_api_event.memcpy_info.async) {
279               activity_api_event.start_time_ns =
280                   iter->second.front().start_time_ns;
281               activity_api_event.end_time_ns = iter->second.front().end_time_ns;
282               for (auto& kernel_activity_op : iter->second) {
283                 activity_api_event.start_time_ns =
284                     std::min(activity_api_event.start_time_ns,
285                              kernel_activity_op.start_time_ns);
286                 activity_api_event.end_time_ns =
287                     std::max(activity_api_event.end_time_ns,
288                              kernel_activity_op.end_time_ns);
289               }
290             }
291           }
292           break;
293         }
294         default:
295           // nothing to do for the rest
296           result = true;
297           break;
298       }
299       if (!result) {
300         OnEventsDropped(
301             "A HIP-API activity with missing HIP-OPS activity was found",
302             activity_api_event.correlation_id);
303         DumpRocmTracerEvent(activity_api_event, 0, 0, ". Dropped!");
304         activity_api_events_map_.erase(activity_api_events_map_iter++);
305       } else {
306         ++activity_api_events_map_iter;
307       }
308     }
309 
310     // the event vector to be returned
311     std::vector<RocmTracerEvent> aggregated_events;
312 
313     // Copying info from HIP activities to HIP API callbacks
314     /*HIP-API call backs <<==== HIP-API activities*/
315     for (auto& api_iter : api_events_map_) {
316       RocmTracerEvent& api_event = api_iter.second;
317       auto iter = activity_api_events_map_.find(api_event.correlation_id);
318       switch (api_event.type) {
319         /*KERNEL API*/
320         case RocmTracerEventType::Kernel: {
321           aggregated_events.push_back(api_event);
322           break;
323         }
324         /*MEMCPY API*/
325         case RocmTracerEventType::MemcpyD2H:
326         case RocmTracerEventType::MemcpyH2D:
327         case RocmTracerEventType::MemcpyD2D:
328         case RocmTracerEventType::MemcpyOther: {
329           if (iter != activity_api_events_map_.end()) {
330             api_event.device_id = iter->second.device_id;
331             api_event.memcpy_info.destination =
332                 api_event.device_id;  // Similar to CUDA
333             aggregated_events.push_back(api_event);
334           } else {
335             OnEventsDropped(
336                 "A Memcpy event from HIP API discarded."
337                 " Could not find the counterpart activity.",
338                 api_event.correlation_id);
339             DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
340           }
341           break;
342         }
343         /*MEMSET API*/
344         case RocmTracerEventType::Memset: {
345           if (iter != activity_api_events_map_.end()) {
346             api_event.device_id = iter->second.device_id;
347 
348             aggregated_events.push_back(api_event);
349           } else {
350             OnEventsDropped(
351                 "A Memset event from HIP API discarded."
352                 " Could not find the counterpart activity.",
353                 api_event.correlation_id);
354             DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
355           }
356           break;
357         }
358         /*MALLOC API, FREE API*/
359         case RocmTracerEventType::MemoryAlloc:
360         case RocmTracerEventType::MemoryFree: {
361           // no missing info
362           aggregated_events.push_back(api_event);
363           break;
364         }
365         /*SYNCHRONIZATION API*/
366         case RocmTracerEventType::Synchronization: {
367           // no missing info
368           aggregated_events.push_back(api_event);
369           break;
370         }
371         default:
372           OnEventsDropped("Missing API-Activity information exchange. Dropped!",
373                           api_event.correlation_id);
374           DumpRocmTracerEvent(api_event, 0, 0, ". Dropped!");
375           LOG(WARNING) << "A ROCm API event type with unimplemented activity "
376                           "merge dropped! "
377                           "Type="
378                        << GetRocmTracerEventTypeName(api_event.type);
379           break;
380       }  // end switch(api_event.type)
381     }
382 
383     // Copying info from HIP API callbacks to HIP API activities
384     //  API ACTIVITIES<<====API-CB
385     for (auto& activity_iter : activity_api_events_map_) {
386       RocmTracerEvent& activity_event = activity_iter.second;
387       // finding the corresponding activity either in the api_call backs or the
388       // axuilarities
389       auto iter = api_events_map_.find(activity_event.correlation_id);
390 
391       iter = (iter == api_events_map_.end())
392                  ? auxiliary_api_events_map_.find(activity_event.correlation_id)
393                  : iter;
394       switch (activity_event.type) {
395         /*KERNEL ACTIVITY*/
396         case RocmTracerEventType::Kernel: {
397           if (iter != api_events_map_.end() ||
398               iter != auxiliary_api_events_map_.end()) {
399             activity_event.name = iter->second.name;
400             activity_event.kernel_info = iter->second.kernel_info;
401             aggregated_events.push_back(activity_event);
402           } else {
403             OnEventsDropped(
404                 "A Kernel event activity was discarded."
405                 " Could not find the counterpart API callback.",
406                 activity_event.correlation_id);
407             DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
408           }
409           break;
410         }
411         /*MEMCPY ACTIVITY*/
412         case RocmTracerEventType::MemcpyD2H:
413         case RocmTracerEventType::MemcpyH2D:
414         case RocmTracerEventType::MemcpyD2D:
415         case RocmTracerEventType::MemcpyOther: {
416           if (iter != api_events_map_.end() ||
417               iter != auxiliary_api_events_map_.end()) {
418             activity_event.memcpy_info = iter->second.memcpy_info;
419             aggregated_events.push_back(activity_event);
420           } else {
421             OnEventsDropped(
422                 "A Memcpy event activity was discarded."
423                 " Could not find the counterpart API callback.",
424                 activity_event.correlation_id);
425             DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
426           }
427           break;
428         }
429         /*MEMSET ACTIVITY*/
430         case RocmTracerEventType::Memset: {
431           if (iter != api_events_map_.end() ||
432               iter != auxiliary_api_events_map_.end()) {
433             activity_event.memset_info = iter->second.memset_info;
434             aggregated_events.push_back(activity_event);
435 
436           } else {
437             OnEventsDropped(
438                 "A Memset event activity was discarded."
439                 " Could not find the counterpart API callback.",
440                 activity_event.correlation_id);
441             DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
442           }
443           break;
444         }
445         /*MALLOC ACTIVITY, FREE ACTIVITY*/
446         case RocmTracerEventType::MemoryAlloc:
447         case RocmTracerEventType::MemoryFree: {
448           if (iter != api_events_map_.end() ||
449               iter != auxiliary_api_events_map_.end()) {
450             activity_event.device_id = iter->second.device_id;
451             aggregated_events.push_back(activity_event);
452           } else {
453             OnEventsDropped(
454                 "A Malloc/Free activity was discarded."
455                 " Could not find the counterpart API callback.",
456                 activity_event.correlation_id);
457             DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
458           }
459           break;
460         }
461         /*SYNCHRONIZATION ACTIVITY*/
462         case RocmTracerEventType::Synchronization: {
463           if (iter != api_events_map_.end() ||
464               iter != auxiliary_api_events_map_.end()) {
465             // CUDA does not provide device ID for these activities.
466             // Interestingly, TF-profiler by default set the device id to 0 for
467             // CuptiTracerEvent.
468             // RocmTracerEvent type, set device by default to an unvalid
469             // device-id value. To be consistent with CUDA (in terms of having a
470             // logically valid value for device id) we update the device-id to
471             // its correct value
472             activity_event.device_id = iter->second.device_id;
473             aggregated_events.push_back(activity_event);
474           } else {
475             OnEventsDropped(
476                 "A sync event activity was discarded."
477                 " Could not find the counterpart API callback.",
478                 activity_event.correlation_id);
479             DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
480           }
481           break;
482         }
483         default:
484           OnEventsDropped("Missing API-Activity information exchange. Dropped!",
485                           activity_event.correlation_id);
486           DumpRocmTracerEvent(activity_event, 0, 0, ". Dropped!");
487           LOG(WARNING) << "A ROCm activity event with unimplemented API "
488                           "callback merge dropped! "
489                           "Type="
490                        << GetRocmTracerEventTypeName(activity_event.type);
491           break;
492       }  // end switch(activity_event.type)
493     }
494 
495     return aggregated_events;
496   }
497   struct RocmDeviceOccupancyParams {
498     hipFuncAttributes attributes = {};
499     int block_size = 0;
500     size_t dynamic_smem_size = 0;
501     void* func_ptr;
502 
operator ==(const RocmDeviceOccupancyParams & lhs,const RocmDeviceOccupancyParams & rhs)503     friend bool operator==(const RocmDeviceOccupancyParams& lhs,
504                            const RocmDeviceOccupancyParams& rhs) {
505       return 0 == memcmp(&lhs, &rhs, sizeof(lhs));
506     }
507 
508     template <typename H>
AbslHashValue(H hash_state,const RocmDeviceOccupancyParams & params)509     friend H AbslHashValue(H hash_state,
510                            const RocmDeviceOccupancyParams& params) {
511       return H::combine(
512           std::move(hash_state), params.attributes.maxThreadsPerBlock,
513           params.attributes.numRegs, params.attributes.sharedSizeBytes,
514           params.attributes.maxDynamicSharedSizeBytes, params.block_size,
515           params.dynamic_smem_size, params.func_ptr);
516     }
517   };
518 
519   struct OccupancyStats {
520     double occupancy_pct = 0.0;
521     int min_grid_size = 0;
522     int suggested_block_size = 0;
523   };
524   struct CorrelationInfo {
CorrelationInfotensorflow::profiler::RocmTraceCollectorImpl::CorrelationInfo525     CorrelationInfo(uint32_t t, uint32_t e)
526         : thread_id(t), enqueue_time_ns(e) {}
527     uint32_t thread_id;
528     uint64_t enqueue_time_ns;
529   };
530 
531   struct PerDeviceCollector {
GetDeviceCapabilitiestensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector532     void GetDeviceCapabilities(int32_t device_ordinal,
533                                XPlaneBuilder* device_plane) {
534       device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata(
535                                      GetStatTypeStr(StatType::kDevVendor)),
536                                  kDeviceVendorAMD);
537 
538       if (hipGetDeviceProperties(&device_properties_, device_ordinal) !=
539           hipSuccess)
540         return;
541 
542       auto clock_rate_in_khz =
543           device_properties_.clockRate;  // this is also in Khz
544       if (clock_rate_in_khz) {
545         device_plane->AddStatValue(
546             *device_plane->GetOrCreateStatMetadata(
547                 GetStatTypeStr(StatType::kDevCapClockRateKHz)),
548             clock_rate_in_khz);
549       }
550 
551       auto core_count = device_properties_.multiProcessorCount;
552       if (core_count) {
553         device_plane->AddStatValue(
554             *device_plane->GetOrCreateStatMetadata(
555                 GetStatTypeStr(StatType::kDevCapCoreCount)),
556             core_count);
557       }
558 
559       auto mem_clock_khz = device_properties_.memoryClockRate;
560       auto mem_bus_width_bits = device_properties_.memoryBusWidth;
561 
562       if (mem_clock_khz && mem_bus_width_bits) {
563         // Times 2 because HBM is DDR memory; it gets two data bits per each
564         // data lane.
565         auto memory_bandwidth =
566             uint64{2} * (mem_clock_khz)*1000 * (mem_bus_width_bits) / 8;
567         device_plane->AddStatValue(
568             *device_plane->GetOrCreateStatMetadata(
569                 GetStatTypeStr(StatType::kDevCapMemoryBandwidth)),
570             memory_bandwidth);
571       }
572 
573       size_t total_memory = device_properties_.totalGlobalMem;
574       if (total_memory) {
575         device_plane->AddStatValue(
576             *device_plane->GetOrCreateStatMetadata(
577                 GetStatTypeStr(StatType::kDevCapMemorySize)),
578             static_cast<uint64>(total_memory));
579       }
580 
581       auto compute_capability_major = device_properties_.major;
582       if (compute_capability_major) {
583         device_plane->AddStatValue(
584             *device_plane->GetOrCreateStatMetadata(
585                 GetStatTypeStr(StatType::kDevCapComputeCapMajor)),
586             compute_capability_major);
587       }
588       auto compute_capability_minor = device_properties_.minor;
589       if (compute_capability_minor) {
590         device_plane->AddStatValue(
591             *device_plane->GetOrCreateStatMetadata(
592                 GetStatTypeStr(StatType::kDevCapComputeCapMinor)),
593             compute_capability_minor);
594       }
595     }
596 
ToXStattensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector597     inline std::string ToXStat(const KernelDetails& kernel_info,
598                                double occupancy_pct) {
599       return absl::StrCat(
600           "regs:", kernel_info.registers_per_thread,
601           " static_shared:", kernel_info.static_shared_memory_usage,
602           " dynamic_shared:", kernel_info.dynamic_shared_memory_usage,
603           " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",",
604           kernel_info.grid_z, " block:", kernel_info.block_x, ",",
605           kernel_info.block_y, ",", kernel_info.block_z,
606           " occ_pct:", occupancy_pct);
607     }
GetOccupancytensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector608     OccupancyStats GetOccupancy(const RocmDeviceOccupancyParams& params) const {
609       // TODO(rocm-profiler): hipOccupancyMaxActiveBlocksPerMultiprocessor only
610       // return hipSuccess for HIP_API_ID_hipLaunchKernel
611 
612       OccupancyStats stats;
613       int number_of_active_blocks;
614       hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
615           &number_of_active_blocks, params.func_ptr, params.block_size,
616           params.dynamic_smem_size);
617 
618       if (err != hipError_t::hipSuccess) {
619         return {};
620       }
621 
622       stats.occupancy_pct = number_of_active_blocks * params.block_size * 100;
623       stats.occupancy_pct /= device_properties_.maxThreadsPerMultiProcessor;
624 
625       err = hipOccupancyMaxPotentialBlockSize(
626           &stats.min_grid_size, &stats.suggested_block_size, params.func_ptr,
627           params.dynamic_smem_size, 0);
628 
629       if (err != hipError_t::hipSuccess) {
630         return {};
631       }
632 
633       return stats;
634     }
AddEventtensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector635     void AddEvent(const RocmTracerEvent& event) {
636       mutex_lock l(events_mutex);
637       if (event.source == RocmTracerEventSource::ApiCallback) {
638         // Cupti api callback events were used to populate launch times etc.
639         if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) {
640           correlation_info_.insert(
641               {event.correlation_id,
642                CorrelationInfo(event.thread_id, event.start_time_ns)});
643         }
644         events.emplace_back(std::move(event));
645       } else {
646         // Cupti activity events measure device times etc.
647         events.emplace_back(std::move(event));
648       }
649     }
650 
SortByStartTimetensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector651     void SortByStartTime() {
652       mutex_lock lock(events_mutex);
653       std::sort(
654           events.begin(), events.end(),
655           [](const RocmTracerEvent& event1, const RocmTracerEvent& event2) {
656             return event1.start_time_ns < event2.start_time_ns;
657           });
658     }
659 
CreateXEventtensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector660     void CreateXEvent(const RocmTracerEvent& event, XPlaneBuilder* plane,
661                       uint64_t start_gpu_ns, uint64_t end_gpu_ns,
662                       XLineBuilder* line) {
663       if (event.start_time_ns < start_gpu_ns ||
664           event.end_time_ns > end_gpu_ns ||
665           event.start_time_ns > event.end_time_ns) {
666         VLOG(2) << "events have abnormal timestamps:" << event.name
667                 << " start time(ns): " << event.start_time_ns
668                 << " end time(ns): " << event.end_time_ns
669                 << " start gpu(ns):" << start_gpu_ns
670                 << " end gpu(ns):" << end_gpu_ns
671                 << " corr. id:" << event.correlation_id;
672         return;
673       }
674       std::string kernel_name = port::MaybeAbiDemangle(event.name.c_str());
675       if (kernel_name.empty()) {
676         kernel_name = GetRocmTracerEventTypeName(event.type);
677       }
678       XEventMetadata* event_metadata =
679           plane->GetOrCreateEventMetadata(std::move(kernel_name));
680       XEventBuilder xevent = line->AddEvent(*event_metadata);
681       VLOG(7) << "Adding event to line=" << line->Id();
682       xevent.SetTimestampNs(event.start_time_ns);
683       xevent.SetEndTimestampNs(event.end_time_ns);
684       if (event.source == RocmTracerEventSource::ApiCallback) {
685         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
686                                 GetStatTypeStr(StatType::kDeviceId)),
687                             event.device_id);
688       }
689       if (event.correlation_id != RocmTracerEvent::kInvalidCorrelationId) {
690         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
691                                 GetStatTypeStr(StatType::kCorrelationId)),
692                             event.correlation_id);
693       }
694       if (!event.roctx_range.empty()) {
695         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
696                                 GetStatTypeStr(StatType::kNVTXRange)),
697                             *plane->GetOrCreateStatMetadata(event.roctx_range));
698       }
699       // if (event.context_id != CuptiTracerEvent::kInvalidContextId) {
700       //   xevent.AddStatValue(
701       //       *plane->GetOrCreateStatMetadata(
702       //           GetStatTypeStr(StatType::kContextId)),
703       //       absl::StrCat("$$", static_cast<uint64>(event.context_id)));
704       // }
705 
706       if (event.type == RocmTracerEventType::Kernel &&
707           event.source == RocmTracerEventSource::Activity) {
708         RocmDeviceOccupancyParams params{};
709         params.attributes.maxThreadsPerBlock = INT_MAX;
710         params.attributes.numRegs =
711             static_cast<int>(event.kernel_info.registers_per_thread);
712         params.attributes.sharedSizeBytes =
713             event.kernel_info.static_shared_memory_usage;
714         // params.attributes.partitionedGCConfig = PARTITIONED_GC_OFF;
715         // params.attributes.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
716         params.attributes.maxDynamicSharedSizeBytes = 0;
717         params.block_size = static_cast<int>(event.kernel_info.block_x *
718                                              event.kernel_info.block_y *
719                                              event.kernel_info.block_z);
720 
721         params.dynamic_smem_size =
722             event.kernel_info.dynamic_shared_memory_usage;
723         params.func_ptr = event.kernel_info.func_ptr;
724 
725         OccupancyStats& occ_stats = occupancy_cache_[params];
726         if (occ_stats.occupancy_pct == 0.0) {
727           occ_stats = GetOccupancy(params);
728         }
729         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
730                                 StatType::kTheoreticalOccupancyPct)),
731                             occ_stats.occupancy_pct);
732         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
733                                 StatType::kOccupancyMinGridSize)),
734                             static_cast<int32>(occ_stats.min_grid_size));
735         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
736                                 StatType::kOccupancySuggestedBlockSize)),
737                             static_cast<int32>(occ_stats.suggested_block_size));
738         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
739                                 GetStatTypeStr(StatType::kKernelDetails)),
740                             *plane->GetOrCreateStatMetadata(ToXStat(
741                                 event.kernel_info, occ_stats.occupancy_pct)));
742       } else if (event.type == RocmTracerEventType::MemcpyH2D ||
743                  event.type == RocmTracerEventType::MemcpyD2H ||
744                  event.type == RocmTracerEventType::MemcpyD2D ||
745                  event.type == RocmTracerEventType::MemcpyP2P ||
746                  event.type == RocmTracerEventType::MemcpyOther) {
747         VLOG(7) << "Add Memcpy stat";
748         const auto& memcpy_info = event.memcpy_info;
749         std::string memcpy_details = absl::StrCat(
750             // TODO(rocm-profiler): we need to discover the memory kind similar
751             // to CUDA
752             "kind:", "Unknown", " size:", memcpy_info.num_bytes,
753             " dest:", memcpy_info.destination, " async:", memcpy_info.async);
754         xevent.AddStatValue(
755             *plane->GetOrCreateStatMetadata(
756                 GetStatTypeStr(StatType::kMemcpyDetails)),
757             *plane->GetOrCreateStatMetadata(std::move(memcpy_details)));
758       } else if (event.type == RocmTracerEventType::MemoryAlloc) {
759         VLOG(7) << "Add MemAlloc stat";
760         std::string value =
761             // TODO(rocm-profiler): we need to discover the memory kind similar
762             // to CUDA
763             absl::StrCat("kind:", "Unknown",
764                          " num_bytes:", event.memalloc_info.num_bytes);
765         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
766                                 GetStatTypeStr(StatType::kMemallocDetails)),
767                             *plane->GetOrCreateStatMetadata(std::move(value)));
768       } else if (event.type == RocmTracerEventType::MemoryFree) {
769         VLOG(7) << "Add MemFree stat";
770         std::string value =
771             // TODO(rocm-profiler): we need to discover the memory kind similar
772             // to CUDA
773             absl::StrCat("kind:", "Unknown",
774                          " num_bytes:", event.memalloc_info.num_bytes);
775         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
776                                 GetStatTypeStr(StatType::kMemFreeDetails)),
777                             *plane->GetOrCreateStatMetadata(std::move(value)));
778       } else if (event.type == RocmTracerEventType::Memset) {
779         VLOG(7) << "Add Memset stat";
780         auto value =
781             // TODO(rocm-profiler): we need to discover the memory kind similar
782             // to CUDA
783             absl::StrCat("kind:", "Unknown",
784                          " num_bytes:", event.memset_info.num_bytes,
785                          " async:", event.memset_info.async);
786         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
787                                 GetStatTypeStr(StatType::kMemsetDetails)),
788                             *plane->GetOrCreateStatMetadata(std::move(value)));
789       }
790       // TODO(rocm-profiler): we need to support the following event type
791       /* else if (event.type == CuptiTracerEventType::MemoryResidency) {
792         VLOG(7) << "Add MemoryResidency stat";
793         std::string value = absl::StrCat(
794             "kind:", GetMemoryKindName(event.memory_residency_info.kind),
795             " num_bytes:", event.memory_residency_info.num_bytes,
796             " addr:", event.memory_residency_info.address);
797         xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
798                                 StatType::kMemoryResidencyDetails)),
799                             *plane->GetOrCreateStatMetadata(std::move(value)));
800       } */
801 
802       std::vector<Annotation> annotation_stack =
803           ParseAnnotationStack(event.annotation);
804       if (!annotation_stack.empty()) {
805         xevent.AddStatValue(
806             *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kTfOp)),
807             *plane->GetOrCreateStatMetadata(annotation_stack.begin()->name));
808       }
809       // If multiple metadata have the same key name, show the values from the
810       // top of the stack (innermost annotation). Concatenate the values from
811       // "hlo_op".
812       absl::flat_hash_set<absl::string_view> key_set;
813 
814       for (auto annotation = annotation_stack.rbegin();
815            annotation != annotation_stack.rend(); ++annotation) {
816         for (const Annotation::Metadata& metadata : annotation->metadata) {
817           if (key_set.insert(metadata.key).second) {
818             xevent.ParseAndAddStatValue(
819                 *plane->GetOrCreateStatMetadata(metadata.key), metadata.value);
820           }
821         }
822       }
823     }
IsHostEventtensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector824     bool IsHostEvent(const RocmTracerEvent& event, int64* line_id) {
825       // DriverCallback(i.e. kernel launching) events are host events.
826       if (event.source == RocmTracerEventSource::ApiCallback) {
827         *line_id = event.thread_id;
828         return true;
829       } else {  // activities
830         *line_id = event.stream_id;
831         return false;
832       }
833 
834       // TODO(rocm-profiler): do we have such a report in rocm?
835       // Non-overhead activity events are device events.
836       /* if (event.type != CuptiTracerEventType::Overhead) {
837         *line_id = event.stream_id;
838         return false;
839       } */
840       // Overhead events can be associated with a thread or a stream, etc.
841       // If a valid thread id is specified, we consider it as a host event.
842       //
843 
844       if (event.stream_id != RocmTracerEvent::kInvalidStreamId) {
845         *line_id = event.stream_id;
846         return false;
847       } else if (event.thread_id != RocmTracerEvent::kInvalidThreadId &&
848                  event.thread_id != 0) {
849         *line_id = event.thread_id;
850         return true;
851       } else {
852         *line_id = kThreadIdOverhead;
853         return false;
854       }
855     }
Exporttensorflow::profiler::RocmTraceCollectorImpl::PerDeviceCollector856     void Export(uint64_t start_walltime_ns, uint64_t start_gputime_ns,
857                 uint64_t end_gputime_ns, XPlaneBuilder* device_plane,
858                 XPlaneBuilder* host_plane) {
859       int host_ev_cnt = 0, dev_ev_cnt = 0;
860       mutex_lock l(events_mutex);
861       // Tracking event types per line.
862       absl::flat_hash_map<int64, absl::flat_hash_set<RocmTracerEventType>>
863           events_types_per_line;
864       for (const RocmTracerEvent& event : events) {
865         int64_t line_id = RocmTracerEvent::kInvalidThreadId;
866         bool is_host_event = IsHostEvent(event, &line_id);
867 
868         if (is_host_event) {
869           host_ev_cnt++;
870         } else {
871           dev_ev_cnt++;
872         }
873 
874         if (line_id == RocmTracerEvent::kInvalidThreadId ||
875             line_id == RocmTracerEvent::kInvalidStreamId) {
876           VLOG(3) << "Ignoring event, type=" << static_cast<int>(event.type);
877           continue;
878         }
879         auto* plane = is_host_event ? host_plane : device_plane;
880         VLOG(9) << "Event"
881                 << " type=" << static_cast<int>(event.type)
882                 << " line_id=" << line_id
883                 << (is_host_event ? " host plane=" : " device plane=")
884                 << plane->Name();
885         XLineBuilder line = plane->GetOrCreateLine(line_id);
886         line.SetTimestampNs(start_gputime_ns);
887         CreateXEvent(event, plane, start_gputime_ns, end_gputime_ns, &line);
888         events_types_per_line[line_id].emplace(event.type);
889       }
890       device_plane->ForEachLine([&](XLineBuilder line) {
891         line.SetName(
892             GetDeviceXLineName(line.Id(), events_types_per_line[line.Id()]));
893       });
894       host_plane->ForEachLine([&](XLineBuilder line) {
895         line.SetName(absl::StrCat("Host Threads/", line.Id()));
896       });
897       size_t num_events = events.size();
898       events.clear();
899     }
900 
901     mutex events_mutex;
902     std::vector<RocmTracerEvent> events TF_GUARDED_BY(events_mutex);
903     absl::flat_hash_map<uint32, CorrelationInfo> correlation_info_
904         TF_GUARDED_BY(events_mutex);
905     absl::flat_hash_map<RocmDeviceOccupancyParams, OccupancyStats>
906         occupancy_cache_;
907     hipDeviceProp_t device_properties_;
908   };
909 
910   absl::FixedArray<PerDeviceCollector> per_device_collector_;
911 };
912 
913 // GpuTracer for ROCm GPU.
914 class GpuTracer : public profiler::ProfilerInterface {
915  public:
GpuTracer(RocmTracer * rocm_tracer)916   GpuTracer(RocmTracer* rocm_tracer) : rocm_tracer_(rocm_tracer) {
917     LOG(INFO) << "GpuTracer created.";
918   }
~GpuTracer()919   ~GpuTracer() override {}
920 
921   // GpuTracer interface:
922   Status Start() override;
923   Status Stop() override;
924   Status CollectData(XSpace* space) override;
925 
926  private:
927   Status DoStart();
928   Status DoStop();
929   Status DoCollectData(XSpace* space);
930 
931   RocmTracerOptions GetRocmTracerOptions();
932 
933   RocmTraceCollectorOptions GetRocmTraceCollectorOptions(uint32_t num_gpus);
934 
935   enum State {
936     kNotStarted,
937     kStartedOk,
938     kStartedError,
939     kStoppedOk,
940     kStoppedError
941   };
942   State profiling_state_ = State::kNotStarted;
943 
944   RocmTracer* rocm_tracer_;
945   std::unique_ptr<RocmTraceCollectorImpl> rocm_trace_collector_;
946 };
947 
GetRocmTracerOptions()948 RocmTracerOptions GpuTracer::GetRocmTracerOptions() {
949   // TODO(rocm-profiler): We need support for context similar to CUDA
950   RocmTracerOptions options;
951   std::vector<uint32_t> empty_vec;
952 
953   // clang formatting does not preserve one entry per line
954   // clang-format off
955   std::vector<uint32_t> hip_api_domain_ops{
956       // KERNEL
957       HIP_API_ID_hipExtModuleLaunchKernel,
958       HIP_API_ID_hipModuleLaunchKernel,
959       HIP_API_ID_hipHccModuleLaunchKernel,
960       HIP_API_ID_hipLaunchKernel,
961       // MEMCPY
962       HIP_API_ID_hipMemcpy,
963       HIP_API_ID_hipMemcpyAsync,
964       HIP_API_ID_hipMemcpyDtoD,
965       HIP_API_ID_hipMemcpyDtoDAsync,
966       HIP_API_ID_hipMemcpyDtoH,
967       HIP_API_ID_hipMemcpyDtoHAsync,
968       HIP_API_ID_hipMemcpyHtoD,
969       HIP_API_ID_hipMemcpyHtoDAsync,
970       HIP_API_ID_hipMemcpyPeer,
971       HIP_API_ID_hipMemcpyPeerAsync,
972 
973       // MEMSet
974       HIP_API_ID_hipMemsetD32,
975       HIP_API_ID_hipMemsetD32Async,
976       HIP_API_ID_hipMemsetD16,
977       HIP_API_ID_hipMemsetD16Async,
978       HIP_API_ID_hipMemsetD8,
979       HIP_API_ID_hipMemsetD8Async,
980       HIP_API_ID_hipMemset,
981       HIP_API_ID_hipMemsetAsync,
982 
983       // MEMAlloc
984       HIP_API_ID_hipMalloc,
985       HIP_API_ID_hipMallocPitch,
986       // MEMFree
987       HIP_API_ID_hipFree,
988       // GENERIC
989       HIP_API_ID_hipStreamSynchronize,
990   };
991   // clang-format on
992 
993   options.api_tracking_set =
994       std::set<uint32_t>(hip_api_domain_ops.begin(), hip_api_domain_ops.end());
995 
996   // These are the list of APIs we track since roctracer activity
997   // does not provide all the information necessary to fully populate the
998   // TF events. We need to track the APIs for those activities in API domain but
999   // we only use them for filling the missing items in their corresponding
1000   // activity (using correlation id).
1001   // clang-format off
1002   std::vector<uint32_t> hip_api_aux_ops{
1003     HIP_API_ID_hipStreamWaitEvent,
1004     // TODO(rocm-profiler): finding device ID from hipEventSynchronize need some
1005     // extra work, we ignore it for now.
1006     // HIP_API_ID_hipEventSynchronize,
1007     HIP_API_ID_hipHostFree,
1008     HIP_API_ID_hipHostMalloc,
1009     HIP_API_ID_hipSetDevice  //  added to track default device
1010   };
1011   // clang-format on
1012 
1013   hip_api_domain_ops.insert(hip_api_domain_ops.end(), hip_api_aux_ops.begin(),
1014                             hip_api_aux_ops.end());
1015 
1016   options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, hip_api_domain_ops);
1017   // options.api_callbacks.emplace(ACTIVITY_DOMAIN_ROCTX, empty_vec);
1018   // options.api_callbacks.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec);
1019 
1020   // options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API,
1021   // hip_api_domain_ops);
1022   options.activity_tracing.emplace(ACTIVITY_DOMAIN_HIP_API, empty_vec);
1023   options.activity_tracing.emplace(ACTIVITY_DOMAIN_HCC_OPS, empty_vec);
1024 
1025   return options;
1026 }
1027 
GetRocmTraceCollectorOptions(uint32_t num_gpus)1028 RocmTraceCollectorOptions GpuTracer::GetRocmTraceCollectorOptions(
1029     uint32_t num_gpus) {
1030   RocmTraceCollectorOptions options;
1031   options.max_callback_api_events = 2 * 1024 * 1024;
1032   options.max_activity_api_events = 2 * 1024 * 1024;
1033   options.max_annotation_strings = 1024 * 1024;
1034   options.num_gpus = num_gpus;
1035   return options;
1036 }
1037 
DoStart()1038 Status GpuTracer::DoStart() {
1039   if (!rocm_tracer_->IsAvailable()) {
1040     return errors::Unavailable("Another profile session running.");
1041   }
1042 
1043   AnnotationStack::Enable(true);
1044 
1045   RocmTraceCollectorOptions trace_collector_options =
1046       GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus());
1047   uint64_t start_gputime_ns = RocmTracer::GetTimestamp();
1048   uint64_t start_walltime_ns = tensorflow::EnvTime::NowNanos();
1049   rocm_trace_collector_ = std::make_unique<RocmTraceCollectorImpl>(
1050       trace_collector_options, start_walltime_ns, start_gputime_ns);
1051 
1052   RocmTracerOptions tracer_options = GetRocmTracerOptions();
1053   rocm_tracer_->Enable(tracer_options, rocm_trace_collector_.get());
1054 
1055   return Status::OK();
1056 }
1057 
Start()1058 Status GpuTracer::Start() {
1059   Status status = DoStart();
1060   if (status.ok()) {
1061     profiling_state_ = State::kStartedOk;
1062     return Status::OK();
1063   } else {
1064     profiling_state_ = State::kStartedError;
1065     return status;
1066   }
1067 }
1068 
DoStop()1069 Status GpuTracer::DoStop() {
1070   rocm_tracer_->Disable();
1071   AnnotationStack::Enable(false);
1072   return Status::OK();
1073 }
1074 
Stop()1075 Status GpuTracer::Stop() {
1076   if (profiling_state_ == State::kStartedOk) {
1077     Status status = DoStop();
1078     profiling_state_ = status.ok() ? State::kStoppedOk : State::kStoppedError;
1079   }
1080   return Status::OK();
1081 }
1082 
DoCollectData(XSpace * space)1083 Status GpuTracer::DoCollectData(XSpace* space) {
1084   if (rocm_trace_collector_) rocm_trace_collector_->Export(space);
1085   return Status::OK();
1086 }
1087 
CollectData(XSpace * space)1088 Status GpuTracer::CollectData(XSpace* space) {
1089   switch (profiling_state_) {
1090     case State::kNotStarted:
1091       VLOG(3) << "No trace data collected, session wasn't started";
1092       return Status::OK();
1093     case State::kStartedOk:
1094       return errors::FailedPrecondition("Cannot collect trace before stopping");
1095     case State::kStartedError:
1096       LOG(ERROR) << "Cannot collect, roctracer failed to start";
1097       return Status::OK();
1098     case State::kStoppedError:
1099       VLOG(3) << "No trace data collected";
1100       return Status::OK();
1101     case State::kStoppedOk: {
1102       DoCollectData(space);
1103       return Status::OK();
1104     }
1105   }
1106   return errors::Internal("Invalid profiling state: ", profiling_state_);
1107 }
1108 
1109 // Not in anonymous namespace for testing purposes.
CreateGpuTracer(const ProfileOptions & options)1110 std::unique_ptr<profiler::ProfilerInterface> CreateGpuTracer(
1111     const ProfileOptions& options) {
1112   if (options.device_type() != ProfileOptions::GPU &&
1113       options.device_type() != ProfileOptions::UNSPECIFIED)
1114     return nullptr;
1115 
1116   profiler::RocmTracer* rocm_tracer =
1117       profiler::RocmTracer::GetRocmTracerSingleton();
1118   if (!rocm_tracer->IsAvailable()) return nullptr;
1119 
1120   return std::make_unique<profiler::GpuTracer>(rocm_tracer);
1121 }
1122 
__anon9b30bc4e0602null1123 auto register_rocm_gpu_tracer_factory = [] {
1124   RegisterProfilerFactory(&CreateGpuTracer);
1125   return 0;
1126 }();
1127 
1128 }  // namespace profiler
1129 }  // namespace tensorflow
1130 
1131 #endif  // TENSORFLOW_USE_ROCM
1132