xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/backends/gpu/rocm_tracer.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/profiler/backends/gpu/rocm_tracer.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/node_hash_map.h"
20 #include "rocm/rocm_config.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/gtl/cleanup.h"
23 #include "tensorflow/core/lib/hash/hash.h"
24 #include "tensorflow/core/platform/env.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/mem.h"
28 #include "tensorflow/core/profiler/backends/cpu/annotation_stack.h"
29 #include "tensorflow/core/profiler/utils/time_utils.h"
30 
31 namespace tensorflow {
32 namespace profiler {
33 
// Out-of-class definition of the static constexpr member
// RocmTracerEvent::kInvalidDeviceId. No initializer is given here, so the
// value has to come from the in-class declaration (presumably in
// rocm_tracer.h -- confirm).
// NOTE(review): a definition like this is only required when the constant is
// ODR-used in a pre-C++17 build; verify whether it is still needed.
34 constexpr uint32_t RocmTracerEvent::kInvalidDeviceId;
35 
// Evaluates the roctracer call `expr` exactly once. When the result is not
// ROCTRACER_STATUS_SUCCESS, logs the failing expression together with the text
// from wrap::roctracer_error_string() and makes the enclosing function return
// errors::Internal(...). On success execution simply continues.
#define RETURN_IF_ROCTRACER_ERROR(expr)                                     \
  do {                                                                      \
    /* Trailing-underscore names keep `expr` from capturing these locals. */ \
    const roctracer_status_t roctracer_status_ = (expr);                    \
    if (roctracer_status_ != ROCTRACER_STATUS_SUCCESS) {                    \
      const char* roctracer_errstr_ = wrap::roctracer_error_string();       \
      LOG(ERROR) << "function " << #expr << " failed with error "           \
                 << roctracer_errstr_;                                      \
      return errors::Internal(                                              \
          absl::StrCat("roctracer call error: ", roctracer_errstr_));       \
    }                                                                       \
  } while (false)
45 
46 namespace {
47 
48 // GetCachedTID() caches the thread ID in thread-local storage (which is a
49 // userspace construct) to avoid unnecessary system calls. Without this caching,
50 // it can take roughly 98ns, while it takes roughly 1ns with this caching.
GetCachedTID()51 int32_t GetCachedTID() {
52   static thread_local int32_t current_thread_id =
53       Env::Default()->GetCurrentThreadId();
54   return current_thread_id;
55 }
56 
GetActivityDomainName(uint32_t domain)57 const char* GetActivityDomainName(uint32_t domain) {
58   switch (domain) {
59     case ACTIVITY_DOMAIN_HSA_API:
60       return "HSA API";
61     case ACTIVITY_DOMAIN_HSA_OPS:
62       return "HSA OPS";
63     case ACTIVITY_DOMAIN_HIP_OPS:
64       return "HIP OPS/HCC/VDI";
65     case ACTIVITY_DOMAIN_HIP_API:
66       return "HIP API";
67     case ACTIVITY_DOMAIN_KFD_API:
68       return "KFD API";
69     case ACTIVITY_DOMAIN_EXT_API:
70       return "EXT API";
71     case ACTIVITY_DOMAIN_ROCTX:
72       return "ROCTX";
73     default:
74       DCHECK(false);
75       return "";
76   }
77   return "";
78 }
79 
GetActivityDomainOpName(uint32_t domain,uint32_t op)80 string GetActivityDomainOpName(uint32_t domain, uint32_t op) {
81   std::ostringstream oss;
82   oss << GetActivityDomainName(domain) << " - ";
83   switch (domain) {
84     case ACTIVITY_DOMAIN_HIP_API:
85       oss << hip_api_name(op);
86       break;
87     default:
88       oss << op;
89       break;
90   }
91   return oss.str();
92 }
93 
GetActivityPhaseName(uint32_t phase)94 const char* GetActivityPhaseName(uint32_t phase) {
95   switch (phase) {
96     case ACTIVITY_API_PHASE_ENTER:
97       return "ENTER";
98     case ACTIVITY_API_PHASE_EXIT:
99       return "EXIT";
100     default:
101       DCHECK(false);
102       return "";
103   }
104   return "";
105 }
106 
DumpApiCallbackData(uint32_t domain,uint32_t cbid,const void * cbdata)107 inline void DumpApiCallbackData(uint32_t domain, uint32_t cbid,
108                                 const void* cbdata) {
109   std::ostringstream oss;
110   oss << "API callback for " << GetActivityDomainName(domain);
111   if (domain == ACTIVITY_DOMAIN_HIP_API) {
112     const hip_api_data_t* data =
113         reinterpret_cast<const hip_api_data_t*>(cbdata);
114     oss << " - " << hip_api_name(cbid);
115     oss << ", correlation_id=" << data->correlation_id;
116     oss << ", phase=" << GetActivityPhaseName(data->phase);
117     switch (cbid) {
118       case HIP_API_ID_hipModuleLaunchKernel:
119       case HIP_API_ID_hipExtModuleLaunchKernel:
120       case HIP_API_ID_hipHccModuleLaunchKernel:
121       case HIP_API_ID_hipLaunchKernel:
122         break;
123       case HIP_API_ID_hipMemcpyDtoH:
124         oss << ", sizeBytes=" << data->args.hipMemcpyDtoH.sizeBytes;
125         break;
126       case HIP_API_ID_hipMemcpyDtoHAsync:
127         oss << ", sizeBytes=" << data->args.hipMemcpyDtoHAsync.sizeBytes;
128         break;
129       case HIP_API_ID_hipMemcpyHtoD:
130         oss << ", sizeBytes=" << data->args.hipMemcpyHtoD.sizeBytes;
131         break;
132       case HIP_API_ID_hipMemcpyHtoDAsync:
133         oss << ", sizeBytes=" << data->args.hipMemcpyHtoDAsync.sizeBytes;
134         break;
135       case HIP_API_ID_hipMemcpyDtoD:
136         oss << ", sizeBytes=" << data->args.hipMemcpyDtoD.sizeBytes;
137         break;
138       case HIP_API_ID_hipMemcpyDtoDAsync:
139         oss << ", sizeBytes=" << data->args.hipMemcpyDtoDAsync.sizeBytes;
140         break;
141       case HIP_API_ID_hipMemcpyAsync:
142         oss << ", sizeBytes=" << data->args.hipMemcpyAsync.sizeBytes;
143         break;
144       case HIP_API_ID_hipMemsetD32:
145         oss << ", value=" << data->args.hipMemsetD32.value;
146         oss << ", count=" << data->args.hipMemsetD32.count;
147         break;
148       case HIP_API_ID_hipMemsetD32Async:
149         oss << ", value=" << data->args.hipMemsetD32Async.value;
150         oss << ", count=" << data->args.hipMemsetD32Async.count;
151         break;
152       case HIP_API_ID_hipMemsetD8:
153         oss << ", value=" << data->args.hipMemsetD8.value;
154         oss << ", count=" << data->args.hipMemsetD8.count;
155         break;
156       case HIP_API_ID_hipMemsetD8Async:
157         oss << ", value=" << data->args.hipMemsetD8Async.value;
158         oss << ", count=" << data->args.hipMemsetD8Async.count;
159         break;
160       case HIP_API_ID_hipMalloc:
161         oss << ", size=" << data->args.hipMalloc.size;
162         break;
163       case HIP_API_ID_hipFree:
164         oss << ", ptr=" << data->args.hipFree.ptr;
165         break;
166       case HIP_API_ID_hipStreamSynchronize:
167         break;
168       default:
169         DCHECK(false);
170         break;
171     }
172   } else {
173     oss << ": " << cbid;
174   }
175   VLOG(3) << oss.str();
176 }
177 
DumpActivityRecord(const roctracer_record_t * record,std::string extra_info)178 void DumpActivityRecord(const roctracer_record_t* record,
179                         std::string extra_info) {
180   std::ostringstream oss;
181   oss << "Activity callback for " << GetActivityDomainName(record->domain);
182   oss << ", op name= "
183       << wrap::roctracer_op_string(record->domain, record->op, record->kind);
184   oss << ", correlation_id=" << record->correlation_id;
185   oss << ", begin_ns=" << record->begin_ns;
186   oss << ", end_ns=" << record->end_ns;
187   oss << ", duration=" << record->end_ns - record->begin_ns;
188   oss << ", device_id=" << record->device_id;
189   oss << ", queue_id=" << record->queue_id;
190   oss << ", process_id=" << record->process_id;
191   oss << ", thread_id=" << record->thread_id;
192   oss << ", external_id=" << record->external_id;
193   oss << ", bytes=" << record->bytes;
194   oss << ", domain=" << record->domain;
195   oss << ", op=" << record->op;
196   oss << ", kind=" << record->kind;
197   oss << ", extra_info=" << extra_info;
198   VLOG(3) << oss.str();
199 }
200 
201 }  // namespace
202 
GetRocmTracerEventTypeName(const RocmTracerEventType & type)203 const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type) {
204   switch (type) {
205     case RocmTracerEventType::MemcpyH2D:
206       return "MemcpyH2D";
207     case RocmTracerEventType::MemcpyD2H:
208       return "MemcpyD2H";
209     case RocmTracerEventType::MemcpyD2D:
210       return "MemcpyD2D";
211     case RocmTracerEventType::MemcpyP2P:
212       return "MemcpyP2P";
213     case RocmTracerEventType::MemcpyOther:
214       return "MemcpyOther";
215     case RocmTracerEventType::Kernel:
216       return "Kernel";
217     case RocmTracerEventType::MemoryAlloc:
218       return "MemoryAlloc";
219     case RocmTracerEventType::Generic:
220       return "Generic";
221     case RocmTracerEventType::Synchronization:
222       return "Synchronization";
223     case RocmTracerEventType::Memset:
224       return "Memset";
225     default:
226       DCHECK(false);
227       return "";
228   }
229   return "";
230 }
231 
GetRocmTracerEventSourceName(const RocmTracerEventSource & source)232 const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source) {
233   switch (source) {
234     case RocmTracerEventSource::ApiCallback:
235       return "ApiCallback";
236       break;
237     case RocmTracerEventSource::Activity:
238       return "Activity";
239       break;
240     case RocmTracerEventSource::Invalid:
241       return "Invalid";
242       break;
243     default:
244       DCHECK(false);
245       return "";
246   }
247   return "";
248 }
249 
250 // FIXME(rocm-profiler): These domain names are not consistent with the
251 // GetActivityDomainName function
GetRocmTracerEventDomainName(const RocmTracerEventDomain & domain)252 const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain) {
253   switch (domain) {
254     case RocmTracerEventDomain::HIP_API:
255       return "HIP_API";
256       break;
257     case RocmTracerEventDomain::HCC_OPS:
258       return "HCC_OPS";
259       break;
260     default:
261       DCHECK(false);
262       return "";
263   }
264   return "";
265 }
266 
DumpRocmTracerEvent(const RocmTracerEvent & event,uint64_t start_walltime_ns,uint64_t start_gputime_ns,const string & message)267 void DumpRocmTracerEvent(const RocmTracerEvent& event,
268                          uint64_t start_walltime_ns, uint64_t start_gputime_ns,
269                          const string& message) {
270   std::ostringstream oss;
271   oss << "correlation_id=" << event.correlation_id;
272   oss << ",type=" << GetRocmTracerEventTypeName(event.type);
273   oss << ",source=" << GetRocmTracerEventSourceName(event.source);
274   oss << ",domain=" << GetRocmTracerEventDomainName(event.domain);
275   oss << ",name=" << event.name;
276   oss << ",annotation=" << event.annotation;
277   oss << ",start_time_us="
278       << (start_walltime_ns + (start_gputime_ns - event.start_time_ns)) / 1000;
279   oss << ",duration=" << (event.end_time_ns - event.start_time_ns) / 1000;
280   oss << ",device_id=" << event.device_id;
281   oss << ",thread_id=" << event.thread_id;
282   oss << ",stream_id=" << event.stream_id;
283 
284   switch (event.type) {
285     case RocmTracerEventType::Kernel:
286       break;
287     case RocmTracerEventType::MemcpyD2H:
288     case RocmTracerEventType::MemcpyH2D:
289     case RocmTracerEventType::MemcpyD2D:
290     case RocmTracerEventType::MemcpyP2P:
291       oss << ",num_bytes=" << event.memcpy_info.num_bytes;
292       oss << ",destination=" << event.memcpy_info.destination;
293       oss << ",async=" << event.memcpy_info.async;
294       break;
295     case RocmTracerEventType::MemoryAlloc:
296       oss << ",num_bytes=" << event.memalloc_info.num_bytes;
297       break;
298     case RocmTracerEventType::Synchronization:
299       break;
300     case RocmTracerEventType::Generic:
301       break;
302     default:
303       DCHECK(false);
304       break;
305   }
306   oss << message;
307   VLOG(3) << oss.str();
308 }
309 
operator ()(uint32_t domain,uint32_t cbid,const void * cbdata)310 Status RocmApiCallbackImpl::operator()(uint32_t domain, uint32_t cbid,
311                                        const void* cbdata) {
312   /* Some APIs such as hipMalloc, implicitly work on th devices set by the
313     user using APIs such as hipSetDevice. API callbacks and activity records
314     for functions like hipMalloc does not return the device id (CUDA does). To
315     solve this we need to track the APIs that select the device (such as
316     hipSetDevice) for each thread.
317     */
318 
319   thread_local uint32_t default_device = 0;
320 
321   // DumpApiCallbackData(domain, cbid, cbdata);
322 
323   if (domain != ACTIVITY_DOMAIN_HIP_API) return Status::OK();
324 
325   const hip_api_data_t* data = reinterpret_cast<const hip_api_data_t*>(cbdata);
326 
327   if (data->phase == ACTIVITY_API_PHASE_ENTER) {
328     if (options_.api_tracking_set.find(cbid) !=
329         options_.api_tracking_set.end()) {
330       mutex_lock lock(api_call_start_mutex_);
331       api_call_start_time_.emplace(data->correlation_id,
332                                    RocmTracer::GetTimestamp());
333     }
334 
335     if (cbid == HIP_API_ID_hipSetDevice) {
336       default_device = data->args.hipSetDevice.deviceId;
337     }
338   } else if (data->phase == ACTIVITY_API_PHASE_EXIT) {
339     uint64_t enter_time = 0, exit_time = 0;
340 
341     if (options_.api_tracking_set.find(cbid) !=
342         options_.api_tracking_set.end()) {
343       mutex_lock lock(api_call_start_mutex_);
344       if (api_call_start_time_.find(data->correlation_id) !=
345           api_call_start_time_.end()) {
346         enter_time = api_call_start_time_.at(data->correlation_id);
347         api_call_start_time_.erase(data->correlation_id);
348       } else {
349         LOG(WARNING) << "An API exit callback received without API enter "
350                         "with same correlation id. Event droped!";
351         return Status::OK();  // This API does not belong to us.
352       }
353       exit_time = RocmTracer::GetTimestamp();
354     }
355     // Set up the map from correlation id to annotation string.
356     const std::string& annotation = AnnotationStack::Get();
357     if (!annotation.empty()) {
358       collector_->annotation_map()->Add(data->correlation_id, annotation);
359     }
360 
361     if (options_.api_tracking_set.find(cbid) ==
362         options_.api_tracking_set.end()) {
363       VLOG(3) << "API callback is from the auxilarity list. Corr. id="
364               << data->correlation_id;
365     }
366     DumpApiCallbackData(domain, cbid, cbdata);
367 
368     switch (cbid) {
369       // star in comments means it does not exist in the driver wrapper
370       case HIP_API_ID_hipModuleLaunchKernel:
371       case HIP_API_ID_hipExtModuleLaunchKernel:  // *
372       case HIP_API_ID_hipHccModuleLaunchKernel:  // *
373       case HIP_API_ID_hipLaunchKernel:           // *
374 
375         this->AddKernelEventUponApiExit(cbid, data, enter_time, exit_time);
376 
377         // Add the correlation_ids for these events to the pending set
378         // so that we can explicitly wait for their corresponding
379         // HIP runtime activity records, before exporting the trace data
380         tracer_->AddToPendingActivityRecords(data->correlation_id);
381         break;
382       case HIP_API_ID_hipMemcpy:
383       case HIP_API_ID_hipMemcpyDtoH:
384       case HIP_API_ID_hipMemcpyDtoHAsync:
385       case HIP_API_ID_hipMemcpyHtoD:
386       case HIP_API_ID_hipMemcpyHtoDAsync:
387       case HIP_API_ID_hipMemcpyDtoD:
388       case HIP_API_ID_hipMemcpyDtoDAsync:
389       case HIP_API_ID_hipMemcpyAsync:
390         this->AddNormalMemcpyEventUponApiExit(cbid, data, enter_time,
391                                               exit_time);
392         tracer_->AddToPendingActivityRecords(data->correlation_id);
393         break;
394       case HIP_API_ID_hipMemset:
395       case HIP_API_ID_hipMemsetAsync:
396       case HIP_API_ID_hipMemsetD32:
397       case HIP_API_ID_hipMemsetD32Async:
398       case HIP_API_ID_hipMemsetD16:
399       case HIP_API_ID_hipMemsetD16Async:
400       case HIP_API_ID_hipMemsetD8:
401       case HIP_API_ID_hipMemsetD8Async:
402         this->AddMemsetEventUponApiExit(cbid, data, enter_time, exit_time);
403         break;
404       case HIP_API_ID_hipMalloc:
405       case HIP_API_ID_hipMallocPitch:
406       case HIP_API_ID_hipHostMalloc:
407       case HIP_API_ID_hipFree:
408       case HIP_API_ID_hipHostFree:
409         this->AddMallocFreeEventUponApiExit(cbid, data, default_device,
410                                             enter_time, exit_time);
411         break;
412       case HIP_API_ID_hipStreamSynchronize:
413       case HIP_API_ID_hipStreamWaitEvent:
414         // case HIP_API_ID_hipEventSynchronize:
415         this->AddSynchronizeEventUponApiExit(cbid, data, enter_time, exit_time);
416         break;
417       case HIP_API_ID_hipSetDevice:
418         // we track this ID only to find the device ID
419         //  for the current thread.
420         break;
421       default:
422         //
423         LOG(WARNING) << "API call "
424                      << wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid,
425                                                   0)
426                      << ", corr. id=" << data->correlation_id
427                      << " dropped. No capturing function was found!";
428         // AddGenericEventUponApiExit(cbid, data);
429         break;
430     }
431   }
432   return Status::OK();
433 }
434 
AddKernelEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,const uint64_t enter_time,const uint64_t exit_time)435 void RocmApiCallbackImpl::AddKernelEventUponApiExit(uint32_t cbid,
436                                                     const hip_api_data_t* data,
437                                                     const uint64_t enter_time,
438                                                     const uint64_t exit_time) {
439   /*
440   extra fields:
441     kernel_info, domain
442 
443   missing fields:
444     context_id
445   */
446   RocmTracerEvent event;
447 
448   event.domain = RocmTracerEventDomain::HIP_API;
449   event.type = RocmTracerEventType::Kernel;
450   event.source = RocmTracerEventSource::ApiCallback;
451   event.thread_id = GetCachedTID();
452   event.correlation_id = data->correlation_id;
453   event.start_time_ns = enter_time;
454   event.end_time_ns = exit_time;
455 
456   switch (cbid) {
457     case HIP_API_ID_hipModuleLaunchKernel: {
458       const hipFunction_t kernelFunc = data->args.hipModuleLaunchKernel.f;
459       if (kernelFunc != nullptr) event.name = hipKernelNameRef(kernelFunc);
460 
461       event.kernel_info.dynamic_shared_memory_usage =
462           data->args.hipModuleLaunchKernel.sharedMemBytes;
463       event.kernel_info.block_x = data->args.hipModuleLaunchKernel.blockDimX;
464       event.kernel_info.block_y = data->args.hipModuleLaunchKernel.blockDimY;
465       event.kernel_info.block_z = data->args.hipModuleLaunchKernel.blockDimZ;
466       event.kernel_info.grid_x = data->args.hipModuleLaunchKernel.gridDimX;
467       event.kernel_info.grid_y = data->args.hipModuleLaunchKernel.gridDimY;
468       event.kernel_info.grid_z = data->args.hipModuleLaunchKernel.gridDimZ;
469       event.kernel_info.func_ptr = kernelFunc;
470       const hipStream_t& stream = data->args.hipModuleLaunchKernel.stream;
471       // TODO(rocm-profiler): wrap this API if possible.
472       event.device_id = hipGetStreamDeviceId(stream);
473     } break;
474     case HIP_API_ID_hipExtModuleLaunchKernel: {
475       const hipFunction_t kernelFunc = data->args.hipExtModuleLaunchKernel.f;
476       if (kernelFunc != nullptr) event.name = hipKernelNameRef(kernelFunc);
477 
478       event.kernel_info.dynamic_shared_memory_usage =
479           data->args.hipExtModuleLaunchKernel.sharedMemBytes;
480       unsigned int blockDimX =
481           data->args.hipExtModuleLaunchKernel.localWorkSizeX;
482       unsigned int blockDimY =
483           data->args.hipExtModuleLaunchKernel.localWorkSizeY;
484       unsigned int blockDimZ =
485           data->args.hipExtModuleLaunchKernel.localWorkSizeZ;
486 
487       event.kernel_info.block_x = blockDimX;
488       event.kernel_info.block_y = blockDimY;
489       event.kernel_info.block_z = blockDimZ;
490       event.kernel_info.grid_x =
491           data->args.hipExtModuleLaunchKernel.globalWorkSizeX / blockDimX;
492       event.kernel_info.grid_y =
493           data->args.hipExtModuleLaunchKernel.globalWorkSizeY / blockDimY;
494       event.kernel_info.grid_z =
495           data->args.hipExtModuleLaunchKernel.globalWorkSizeZ / blockDimZ;
496       event.kernel_info.func_ptr = kernelFunc;
497       const hipStream_t& stream = data->args.hipExtModuleLaunchKernel.hStream;
498       event.device_id = hipGetStreamDeviceId(stream);
499     } break;
500     case HIP_API_ID_hipHccModuleLaunchKernel: {
501       const hipFunction_t kernelFunc = data->args.hipHccModuleLaunchKernel.f;
502       if (kernelFunc != nullptr) event.name = hipKernelNameRef(kernelFunc);
503 
504       event.kernel_info.dynamic_shared_memory_usage =
505           data->args.hipHccModuleLaunchKernel.sharedMemBytes;
506       event.kernel_info.block_x = data->args.hipHccModuleLaunchKernel.blockDimX;
507       event.kernel_info.block_y = data->args.hipHccModuleLaunchKernel.blockDimY;
508       event.kernel_info.block_z = data->args.hipHccModuleLaunchKernel.blockDimZ;
509       event.kernel_info.grid_x =
510           data->args.hipHccModuleLaunchKernel.globalWorkSizeX /
511           event.kernel_info.block_x;
512       event.kernel_info.grid_y =
513           data->args.hipHccModuleLaunchKernel.globalWorkSizeY /
514           event.kernel_info.block_y;
515       event.kernel_info.grid_z =
516           data->args.hipHccModuleLaunchKernel.globalWorkSizeZ /
517           event.kernel_info.block_z;
518       event.kernel_info.func_ptr = kernelFunc;
519       const hipStream_t& stream = data->args.hipHccModuleLaunchKernel.hStream;
520       event.device_id = hipGetStreamDeviceId(stream);
521       event.kernel_info.dynamic_shared_memory_usage =
522           data->args.hipHccModuleLaunchKernel.sharedMemBytes;
523     } break;
524     case HIP_API_ID_hipLaunchKernel: {
525       const void* func_addr = data->args.hipLaunchKernel.function_address;
526       hipStream_t stream = data->args.hipLaunchKernel.stream;
527       if (func_addr != nullptr)
528         event.name = hipKernelNameRefByPtr(func_addr, stream);
529 
530       event.kernel_info.dynamic_shared_memory_usage =
531           data->args.hipLaunchKernel.sharedMemBytes;
532       event.kernel_info.block_x = data->args.hipLaunchKernel.dimBlocks.x;
533       event.kernel_info.block_y = data->args.hipLaunchKernel.dimBlocks.y;
534       event.kernel_info.block_z = data->args.hipLaunchKernel.dimBlocks.z;
535       event.kernel_info.grid_x = data->args.hipLaunchKernel.numBlocks.x;
536       event.kernel_info.grid_y = data->args.hipLaunchKernel.numBlocks.y;
537       event.kernel_info.grid_z = data->args.hipLaunchKernel.numBlocks.z;
538       event.kernel_info.func_ptr = (void*)func_addr;
539       event.device_id = hipGetStreamDeviceId(stream);
540     } break;
541   }
542   bool is_auxiliary =
543       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
544   collector_->AddEvent(std::move(event), is_auxiliary);
545 }
546 
AddNormalMemcpyEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,uint64_t enter_time,uint64_t exit_time)547 void RocmApiCallbackImpl::AddNormalMemcpyEventUponApiExit(
548     uint32_t cbid, const hip_api_data_t* data, uint64_t enter_time,
549     uint64_t exit_time) {
550   /*
551     missing:
552       device_id(partially, have only for async), context_id,
553     memcpy_info.kind(CUPTI puts CUPTI_ACTIVITY_MEMCPY_KIND_UNKNOWN),
554       memcpy_info.destenation(partially, only for async)( CUPTI puts device_id),
555 
556     extra:
557       domain, name,
558   */
559   // for CUDA, it does NOT capture stream id for these types
560 
561   RocmTracerEvent event;
562   event.domain = RocmTracerEventDomain::HIP_API;
563   event.name = wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid, 0);
564   event.source = RocmTracerEventSource::ApiCallback;
565   event.thread_id = GetCachedTID();
566   event.correlation_id = data->correlation_id;
567   event.start_time_ns = enter_time;
568   event.end_time_ns = exit_time;
569 
570   /* The general hipMemcpy or hipMemcpyAsync can support any kind of memory
571   copy operation, such as H2D, D2D, P2P, and D2H. Here we use MemcpyOther for
572   all api calls with HipMemcpy(+Async) to carry-on this generality.
573   We also assume that if we want to copy data BETWEEN devices, we do not use
574   hipMemcpy(+Async) or hipMemcpyDtoD(+Async) as we explicitly always set the
575   destenation as the source device id). Ultimately, to figure out the actual
576   device we can use hipPointerGetAttributes but we do not do that now .In the
577   other words, we assume we use hipMemcpyPeer to achieve the copy between
578   devices.
579   */
580 
581   switch (cbid) {
582     case HIP_API_ID_hipMemcpyDtoH: {
583       event.type = RocmTracerEventType::MemcpyD2H;
584       event.memcpy_info.num_bytes = data->args.hipMemcpyDtoH.sizeBytes;
585       event.memcpy_info.async = false;
586     } break;
587     case HIP_API_ID_hipMemcpyDtoHAsync: {
588       event.type = RocmTracerEventType::MemcpyD2H;
589       const hipStream_t& stream = data->args.hipMemcpyDtoHAsync.stream;
590       event.device_id = hipGetStreamDeviceId(stream);
591       event.memcpy_info.num_bytes = data->args.hipMemcpyDtoHAsync.sizeBytes;
592       event.memcpy_info.async = true;
593       event.memcpy_info.destination = event.device_id;
594     } break;
595     case HIP_API_ID_hipMemcpyHtoD: {
596       event.type = RocmTracerEventType::MemcpyH2D;
597       event.memcpy_info.num_bytes = data->args.hipMemcpyHtoD.sizeBytes;
598       event.memcpy_info.async = false;
599       // we set the destenattion device id for it using the device id we get
600       // from activities when they exchange information before flushing
601     } break;
602     case HIP_API_ID_hipMemcpyHtoDAsync: {
603       event.type = RocmTracerEventType::MemcpyH2D;
604       const hipStream_t& stream = data->args.hipMemcpyHtoDAsync.stream;
605       event.device_id = hipGetStreamDeviceId(stream);
606       event.memcpy_info.num_bytes = data->args.hipMemcpyHtoDAsync.sizeBytes;
607       event.memcpy_info.async = true;
608       event.memcpy_info.destination = event.device_id;
609     } break;
610     case HIP_API_ID_hipMemcpyDtoD: {
611       event.type = RocmTracerEventType::MemcpyD2D;
612       event.memcpy_info.num_bytes = data->args.hipMemcpyDtoD.sizeBytes;
613       event.memcpy_info.async = false;
614     } break;
615     case HIP_API_ID_hipMemcpyDtoDAsync: {
616       event.type = RocmTracerEventType::MemcpyD2D;
617       const hipStream_t& stream = data->args.hipMemcpyDtoDAsync.stream;
618       event.device_id = hipGetStreamDeviceId(stream);
619       event.memcpy_info.num_bytes = data->args.hipMemcpyDtoDAsync.sizeBytes;
620       event.memcpy_info.async = true;
621       event.memcpy_info.destination = event.device_id;
622     } break;
623     case HIP_API_ID_hipMemcpy: {
624       event.type = RocmTracerEventType::MemcpyOther;
625       event.memcpy_info.num_bytes = data->args.hipMemcpy.sizeBytes;
626       event.memcpy_info.async = false;
627     } break;
628     case HIP_API_ID_hipMemcpyAsync: {
629       event.type = RocmTracerEventType::MemcpyOther;
630       const hipStream_t& stream = data->args.hipMemcpyAsync.stream;
631       event.device_id = hipGetStreamDeviceId(stream);
632       event.memcpy_info.num_bytes = data->args.hipMemcpyAsync.sizeBytes;
633       event.memcpy_info.async = true;
634       event.memcpy_info.destination = event.device_id;
635     } break;
636     default:
637       LOG(WARNING) << "Unsupported Memcpy API for profiling observed for cbid="
638                    << cbid << ". Event dropped!";
639       return;
640       break;
641   }
642 
643   bool is_auxiliary =
644       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
645   collector_->AddEvent(std::move(event), is_auxiliary);
646 }
AddMemcpyPeerEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,uint64_t enter_time,uint64_t exit_time)647 void RocmApiCallbackImpl::AddMemcpyPeerEventUponApiExit(
648     uint32_t cbid, const hip_api_data_t* data, uint64_t enter_time,
649     uint64_t exit_time) {
650   /*
651     missing: context_id, memcpy_info.kind
652 
653     extra: domain, name,
654   */
655 
656   RocmTracerEvent event;
657   event.type = RocmTracerEventType::MemcpyP2P;
658   event.domain = RocmTracerEventDomain::HIP_API;
659   event.name = wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid, 0);
660   event.source = RocmTracerEventSource::ApiCallback;
661   event.thread_id = GetCachedTID();
662   event.correlation_id = data->correlation_id;
663   event.start_time_ns = enter_time;
664   event.end_time_ns = exit_time;
665 
666   switch (cbid) {
667     case HIP_API_ID_hipMemcpyPeer:
668       event.device_id = data->args.hipMemcpyPeer.srcDeviceId;
669       event.memcpy_info.destination = data->args.hipMemcpyPeer.dstDeviceId;
670       event.memcpy_info.num_bytes = data->args.hipMemcpyPeer.sizeBytes;
671       event.memcpy_info.async = false;
672       break;
673     case HIP_API_ID_hipMemcpyPeerAsync:
674       event.device_id = data->args.hipMemcpyPeerAsync.srcDevice;
675       event.memcpy_info.destination = data->args.hipMemcpyPeerAsync.dstDeviceId;
676       event.memcpy_info.num_bytes = data->args.hipMemcpyPeerAsync.sizeBytes;
677       event.memcpy_info.async = true;
678       break;
679     default:
680       LOG(WARNING)
681           << "Unsupported MemcpyPeer API for profiling observed for cbid="
682           << cbid << ". Event dropped!";
683       return;
684       break;
685   }
686 
687   bool is_auxiliary =
688       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
689   collector_->AddEvent(std::move(event), is_auxiliary);
690 }
AddMemsetEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,uint64_t enter_time,uint64_t exit_time)691 void RocmApiCallbackImpl::AddMemsetEventUponApiExit(uint32_t cbid,
692                                                     const hip_api_data_t* data,
693                                                     uint64_t enter_time,
694                                                     uint64_t exit_time) {
695   /*
696     misses:
697       device_id(only avail. for async), context_id
698 
699     extras:
700       domain, name
701   */
702 
703   RocmTracerEvent event;
704   event.domain = RocmTracerEventDomain::HIP_API;
705   event.name = wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid, 0);
706   event.source = RocmTracerEventSource::ApiCallback;
707   event.thread_id = GetCachedTID();
708   event.correlation_id = data->correlation_id;
709   event.start_time_ns = enter_time;
710   event.end_time_ns = exit_time;
711 
712   switch (cbid) {
713     case HIP_API_ID_hipMemsetD8:
714       event.type = RocmTracerEventType::Memset;
715       event.memset_info.num_bytes = data->args.hipMemsetD8.count;
716       event.memset_info.async = false;
717       break;
718     case HIP_API_ID_hipMemsetD8Async: {
719       event.type = RocmTracerEventType::Memset;
720       event.memset_info.num_bytes = data->args.hipMemsetD8Async.count;
721       event.memset_info.async = true;
722       const hipStream_t& stream = data->args.hipMemsetD8Async.stream;
723       event.device_id = hipGetStreamDeviceId(stream);
724     } break;
725     case HIP_API_ID_hipMemsetD16:
726       event.type = RocmTracerEventType::Memset;
727       event.memset_info.num_bytes = 2 * data->args.hipMemsetD16.count;
728       event.memset_info.async = false;
729       break;
730     case HIP_API_ID_hipMemsetD16Async: {
731       event.type = RocmTracerEventType::Memset;
732       event.memset_info.num_bytes = 2 * data->args.hipMemsetD16Async.count;
733       event.memset_info.async = true;
734       const hipStream_t& stream = data->args.hipMemsetD16Async.stream;
735       event.device_id = hipGetStreamDeviceId(stream);
736     } break;
737     case HIP_API_ID_hipMemsetD32:
738       event.type = RocmTracerEventType::Memset;
739       event.memset_info.num_bytes = 4 * data->args.hipMemsetD32.count;
740       event.memset_info.async = false;
741       break;
742     case HIP_API_ID_hipMemsetD32Async: {
743       event.type = RocmTracerEventType::Memset;
744       event.memset_info.num_bytes = 4 * data->args.hipMemsetD32Async.count;
745       event.memset_info.async = true;
746       const hipStream_t& stream = data->args.hipMemsetD32Async.stream;
747       event.device_id = hipGetStreamDeviceId(stream);
748     } break;
749     case HIP_API_ID_hipMemset:
750       event.type = RocmTracerEventType::Memset;
751       event.memset_info.num_bytes = data->args.hipMemset.sizeBytes;
752       event.memset_info.async = false;
753       break;
754     case HIP_API_ID_hipMemsetAsync: {
755       event.type = RocmTracerEventType::Memset;
756       event.memset_info.num_bytes = data->args.hipMemsetAsync.sizeBytes;
757       event.memset_info.async = true;
758       const hipStream_t& stream = data->args.hipMemsetAsync.stream;
759       event.device_id = hipGetStreamDeviceId(stream);
760     } break;
761     default:
762       LOG(WARNING) << "Unsupported Memset API for profiling observed for cbid="
763                    << cbid << ". Event dropped!";
764       return;
765       break;
766   }
767 
768   bool is_auxiliary =
769       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
770   collector_->AddEvent(std::move(event), is_auxiliary);
771 }
772 
AddMallocFreeEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,uint32_t device_id,uint64_t enter_time,uint64_t exit_time)773 void RocmApiCallbackImpl::AddMallocFreeEventUponApiExit(
774     uint32_t cbid, const hip_api_data_t* data, uint32_t device_id,
775     uint64_t enter_time, uint64_t exit_time) {
776   /*
777     misses: context_id
778 
779     extras: domain
780   */
781 
782   RocmTracerEvent event;
783   event.domain = RocmTracerEventDomain::HIP_API;
784   event.type = (cbid == HIP_API_ID_hipFree || cbid == HIP_API_ID_hipHostFree)
785                    ? RocmTracerEventType::MemoryFree
786                    : RocmTracerEventType::MemoryAlloc;
787   event.source = RocmTracerEventSource::ApiCallback;
788   event.name = wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid, 0);
789   event.device_id = device_id;
790   event.thread_id = GetCachedTID();
791   // We do not set stream_id (probably to zero as Malloc etc. commands seems
792   // to run on  default stream). Later we use the unassigned stream_id as a
793   // feature to assign events to host or device.
794   event.correlation_id = data->correlation_id;
795   event.start_time_ns = enter_time;
796   event.end_time_ns = exit_time;
797 
798   switch (cbid) {
799     case HIP_API_ID_hipMalloc:
800       event.memalloc_info.num_bytes = data->args.hipMalloc.size;
801       break;
802     case HIP_API_ID_hipMallocPitch:
803       event.memalloc_info.num_bytes = data->args.hipMallocPitch.pitch__val *
804                                       data->args.hipMallocPitch.height;
805       break;
806     case HIP_API_ID_hipHostMalloc:
807       event.memalloc_info.num_bytes = data->args.hipHostMalloc.size;
808       break;
809     case HIP_API_ID_hipFree:
810     case HIP_API_ID_hipHostFree:
811       event.memalloc_info.num_bytes = 0;
812       break;
813     default:
814       LOG(WARNING)
815           << "Unsupported Malloc/Free API for profiling observed for cbid="
816           << cbid << ". Event dropped!";
817       return;
818       break;
819   }
820 
821   bool is_auxiliary =
822       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
823   collector_->AddEvent(std::move(event), is_auxiliary);
824 }
825 
AddSynchronizeEventUponApiExit(uint32_t cbid,const hip_api_data_t * data,uint64_t enter_time,uint64_t exit_time)826 void RocmApiCallbackImpl::AddSynchronizeEventUponApiExit(
827     uint32_t cbid, const hip_api_data_t* data, uint64_t enter_time,
828     uint64_t exit_time) {
829   // TODO(rocm-profiler): neither CUDA and nor we capture annotaint for this
830   // event
831   /*
832     misses: context_id
833 
834     extras: domain,
835   */
836 
837   RocmTracerEvent event;
838   event.domain = RocmTracerEventDomain::HIP_API;
839   event.type = RocmTracerEventType::Synchronization;
840   event.source = RocmTracerEventSource::ApiCallback;
841   event.name = wrap::roctracer_op_string(ACTIVITY_DOMAIN_HIP_API, cbid, 0);
842   event.thread_id = GetCachedTID();
843   event.correlation_id = data->correlation_id;
844   event.start_time_ns = enter_time;
845   event.end_time_ns = exit_time;
846 
847   switch (cbid) {
848     case HIP_API_ID_hipStreamSynchronize: {
849       event.synchronization_info.sync_type =
850           RocmTracerSyncTypes::StreamSynchronize;
851       const hipStream_t& stream = data->args.hipStreamSynchronize.stream;
852       event.device_id = hipGetStreamDeviceId(stream);
853     } break;
854     case HIP_API_ID_hipStreamWaitEvent: {
855       event.synchronization_info.sync_type = RocmTracerSyncTypes::StreamWait;
856       const hipStream_t& stream = data->args.hipStreamWaitEvent.stream;
857       event.device_id = hipGetStreamDeviceId(stream);
858     } break;
859     default:
860       LOG(WARNING)
861           << "Unsupported Synchronization API for profiling observed for cbid="
862           << cbid << ". Event dropped!";
863       return;
864       break;
865   }
866   bool is_auxiliary =
867       options_.api_tracking_set.find(cbid) == options_.api_tracking_set.end();
868   collector_->AddEvent(std::move(event), is_auxiliary);
869 }
870 
operator ()(const char * begin,const char * end)871 Status RocmActivityCallbackImpl::operator()(const char* begin,
872                                             const char* end) {
873   // we do not dump activities in this set in logger
874 
875   static std::set<activity_op_t> dump_excluded_activities = {
876       HIP_API_ID_hipGetDevice,
877       HIP_API_ID_hipSetDevice,
878       HIP_API_ID___hipPushCallConfiguration,
879       HIP_API_ID___hipPopCallConfiguration,
880       HIP_API_ID_hipEventQuery,
881       HIP_API_ID_hipCtxSetCurrent,
882       HIP_API_ID_hipEventRecord,
883       HIP_API_ID_hipEventQuery,
884       HIP_API_ID_hipGetDeviceProperties,
885       HIP_API_ID_hipPeekAtLastError,
886       HIP_API_ID_hipModuleGetFunction,
887       HIP_API_ID_hipEventCreateWithFlags};
888 
889   const roctracer_record_t* record =
890       reinterpret_cast<const roctracer_record_t*>(begin);
891   const roctracer_record_t* end_record =
892       reinterpret_cast<const roctracer_record_t*>(end);
893 
894   while (record < end_record) {
895     // DumpActivityRecord(record);
896 
897     switch (record->domain) {
898       // HIP API activities.
899       case ACTIVITY_DOMAIN_HIP_API:
900         switch (record->op) {
901           case HIP_API_ID_hipModuleLaunchKernel:
902           case HIP_API_ID_hipExtModuleLaunchKernel:
903           case HIP_API_ID_hipHccModuleLaunchKernel:
904           case HIP_API_ID_hipLaunchKernel:
905             DumpActivityRecord(record, std::to_string(__LINE__));
906             AddHipKernelActivityEvent(record);
907             break;
908           case HIP_API_ID_hipMemcpyDtoH:
909           case HIP_API_ID_hipMemcpyHtoD:
910           case HIP_API_ID_hipMemcpyDtoD:
911           case HIP_API_ID_hipMemcpyDtoHAsync:
912           case HIP_API_ID_hipMemcpyHtoDAsync:
913           case HIP_API_ID_hipMemcpyDtoDAsync:
914           case HIP_API_ID_hipMemcpyAsync:
915           case HIP_API_ID_hipMemcpy:
916             DumpActivityRecord(record, std::to_string(__LINE__));
917             AddNormalHipMemcpyActivityEvent(record);
918             break;
919           case HIP_API_ID_hipMemset:
920           case HIP_API_ID_hipMemsetAsync:
921           case HIP_API_ID_hipMemsetD32:
922           case HIP_API_ID_hipMemsetD32Async:
923           case HIP_API_ID_hipMemsetD16:
924           case HIP_API_ID_hipMemsetD16Async:
925           case HIP_API_ID_hipMemsetD8:
926           case HIP_API_ID_hipMemsetD8Async:
927             DumpActivityRecord(record, std::to_string(__LINE__));
928             AddHipMemsetActivityEvent(record);
929             break;
930 
931           case HIP_API_ID_hipMalloc:
932           case HIP_API_ID_hipMallocPitch:
933           case HIP_API_ID_hipHostMalloc:
934           case HIP_API_ID_hipFree:
935           case HIP_API_ID_hipHostFree:
936             DumpActivityRecord(record, std::to_string(__LINE__));
937             AddHipMallocActivityEvent(record);
938             break;
939           case HIP_API_ID_hipStreamSynchronize:
940           case HIP_API_ID_hipStreamWaitEvent:
941             // case HIP_API_ID_hipStreamWaitEvent:
942             DumpActivityRecord(record, std::to_string(__LINE__));
943             AddHipStreamSynchronizeActivityEvent(record);
944             break;
945 
946           default:
947             if (dump_excluded_activities.find(record->op) ==
948                 dump_excluded_activities.end()) {
949               string drop_message(
950                   "\nNot in the API tracked activities. Dropped!");
951               DumpActivityRecord(record, drop_message);
952             }
953             break;
954         }  // switch (record->op).
955         break;
956 
957       // HCC ops activities.
958       case ACTIVITY_DOMAIN_HIP_OPS:
959 
960         switch (record->op) {
961           case HIP_OP_ID_DISPATCH:
962             DumpActivityRecord(record, std::to_string(__LINE__));
963             AddHccKernelActivityEvent(record);
964             tracer_->RemoveFromPendingActivityRecords(record->correlation_id);
965             break;
966           case HIP_OP_ID_COPY:
967             switch (record->kind) {
968               // TODO(rocm-profiler): use enum instead.
969               case 4595:   /*CopyDeviceToHost*/
970               case 4596:   /*CopyDeviceToDevice*/
971               case 4597: { /*CopyHostToDevice*/
972                 /*MEMCPY*/
973                 // roctracer returns CopyHostToDevice for hipMemcpyDtoD API
974                 //  Please look at the issue #53 in roctracer GitHub repo.
975                 DumpActivityRecord(record, "");
976                 AddNormalHipOpsMemcpyActivityEvent(record);
977                 tracer_->RemoveFromPendingActivityRecords(
978                     record->correlation_id);
979               } break;
980               case 4615: /*FillBuffer*/
981                 /*MEMSET*/
982                 DumpActivityRecord(record, "");
983                 AddHipOpsMemsetActivityEvent(record);
984                 break;
985               case 4606: /*MARKER*/
986                 // making the log shorter.
987                 // markers are with 0ns duration.
988                 break;
989               default:
990                 string drop_message(
991                     "\nNot in the HIP-OPS-COPY tracked activities. Dropeed!");
992                 DumpActivityRecord(record, drop_message);
993                 break;
994             }  // switch (record->kind)
995             break;
996           default:
997             string drop_message(
998                 "\nNot in the HIP-OPS tracked activities. Dropped!");
999             DumpActivityRecord(record, drop_message);
1000             break;
1001         }  // switch (record->op).
1002         break;
1003       default:
1004         string drop_message("\nNot in the tracked domain activities. Dropped!");
1005         DumpActivityRecord(record, drop_message);
1006         break;
1007     }
1008 
1009     RETURN_IF_ROCTRACER_ERROR(static_cast<roctracer_status_t>(
1010         roctracer_next_record(record, &record)));
1011   }
1012 
1013   return Status::OK();
1014 }
1015 
AddHipKernelActivityEvent(const roctracer_record_t * record)1016 void RocmActivityCallbackImpl::AddHipKernelActivityEvent(
1017     const roctracer_record_t* record) {
1018   /*
1019   missing:
1020    name, device_id(got from hcc), context_id, stream_id(got from hcc),
1021  nvtx_range, kernel_info
1022 
1023   extra:
1024    domain
1025  activity record contains process/thread ID
1026  */
1027   RocmTracerEvent event;
1028   event.domain = RocmTracerEventDomain::HIP_API;
1029   event.type = RocmTracerEventType::Kernel;
1030   event.source = RocmTracerEventSource::Activity;
1031   // event.name =  /* we use the API name instead*/
1032   //    wrap::roctracer_op_string(record->domain, record->op, record->kind);
1033   event.correlation_id = record->correlation_id;
1034   // TODO(rocm-profiler): CUDA uses device id and correlation ID for finding
1035   // annotations.
1036   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1037 
1038   event.start_time_ns = record->begin_ns;
1039   event.end_time_ns = record->end_ns;
1040 
1041   collector_->AddEvent(std::move(event), false);
1042 }
1043 
AddNormalHipMemcpyActivityEvent(const roctracer_record_t * record)1044 void RocmActivityCallbackImpl::AddNormalHipMemcpyActivityEvent(
1045     const roctracer_record_t* record) {
1046   /*
1047   ---------------NormalMemcpy-------------------
1048     misses:context_id, memcpy_info.kind, memcpy_info.srckind,
1049   memcpy_info.dstkind, memcpy_info.num_bytes, memcpy_info.destenation,
1050   device_id, stream_id,
1051 
1052     extras: domain
1053   ---------------PeerMemcpy---------------------
1054     misses: device_id, context_id, stream_id, memcpy_info.kind,
1055       memcpy_info.num_bytes, memcpy_info.destination,
1056     extras:
1057       domain,
1058   */
1059 
1060   RocmTracerEvent event;
1061   event.domain = RocmTracerEventDomain::HIP_API;
1062   event.source = RocmTracerEventSource::Activity;
1063   event.start_time_ns = record->begin_ns;
1064   event.end_time_ns = record->end_ns;
1065   event.correlation_id = record->correlation_id;
1066   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1067   // TODO(roc-profiler): record->bytes is not a valid value
1068   // event.memcpy_info.num_bytes = record->bytes;
1069   event.name =
1070       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1071   switch (record->op) {
1072     case HIP_API_ID_hipMemcpyDtoH:
1073     case HIP_API_ID_hipMemcpyDtoHAsync:
1074       event.type = RocmTracerEventType::MemcpyD2H;
1075       event.memcpy_info.async =
1076           (record->op == HIP_API_ID_hipMemcpyDtoHAsync) ? true : false;
1077       break;
1078     case HIP_API_ID_hipMemcpyHtoD:
1079     case HIP_API_ID_hipMemcpyHtoDAsync:
1080       event.type = RocmTracerEventType::MemcpyH2D;
1081       event.memcpy_info.async =
1082           (record->op == HIP_API_ID_hipMemcpyHtoDAsync) ? true : false;
1083       break;
1084     case HIP_API_ID_hipMemcpyDtoD:
1085     case HIP_API_ID_hipMemcpyDtoDAsync:
1086       event.type = RocmTracerEventType::MemcpyD2D;
1087       event.memcpy_info.async =
1088           (record->op == HIP_API_ID_hipMemcpyDtoDAsync) ? true : false;
1089       break;
1090     case HIP_API_ID_hipMemcpy:
1091     case HIP_API_ID_hipMemcpyAsync:
1092       event.type = RocmTracerEventType::MemcpyOther;
1093       event.memcpy_info.async =
1094           (record->op == HIP_API_ID_hipMemcpyAsync) ? true : false;
1095       break;
1096     case HIP_API_ID_hipMemcpyPeer:
1097     case HIP_API_ID_hipMemcpyPeerAsync:
1098       event.type = RocmTracerEventType::MemcpyP2P;
1099       event.memcpy_info.async =
1100           (record->op == HIP_API_ID_hipMemcpyPeerAsync) ? true : false;
1101       break;
1102     default:
1103       LOG(WARNING) << "Unsupported Memcpy/MemcpyPeer activity for profiling "
1104                       "observed for cbid="
1105                    << record->op << ". Event dropped!";
1106       return;
1107       break;
1108   }
1109 
1110   collector_->AddEvent(std::move(event), false);
1111 }
1112 
AddHipMemsetActivityEvent(const roctracer_record_t * record)1113 void RocmActivityCallbackImpl::AddHipMemsetActivityEvent(
1114     const roctracer_record_t* record) {
1115   /*
1116     misses:
1117       device_id, context_id, stram_id, memset_info.num_bytes
1118       memset_info.kind
1119 
1120     extras:
1121       domain, annotation
1122   */
1123 
1124   RocmTracerEvent event;
1125   event.domain = RocmTracerEventDomain::HIP_API;
1126   event.source = RocmTracerEventSource::Activity;
1127   event.name =
1128       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1129   event.correlation_id = record->correlation_id;
1130   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1131 
1132   event.type = RocmTracerEventType::Memset;
1133 
1134   switch (record->op) {
1135     case HIP_API_ID_hipMemset:
1136       event.memset_info.async = false;
1137       break;
1138     case HIP_API_ID_hipMemsetAsync:
1139       event.memset_info.async = true;
1140       break;
1141     case HIP_API_ID_hipMemsetD8:
1142       event.memset_info.async = false;
1143       break;
1144     case HIP_API_ID_hipMemsetD8Async:
1145       event.memset_info.async = true;
1146       break;
1147     case HIP_API_ID_hipMemsetD16:
1148       event.memset_info.async = false;
1149       break;
1150     case HIP_API_ID_hipMemsetD16Async:
1151       event.memset_info.async = true;
1152       break;
1153     case HIP_API_ID_hipMemsetD32:
1154       event.memset_info.async = false;
1155       break;
1156     case HIP_API_ID_hipMemsetD32Async:
1157       event.memset_info.async = true;
1158       break;
1159   }
1160 
1161   event.start_time_ns = record->begin_ns;
1162   event.end_time_ns = record->end_ns;
1163 
1164   collector_->AddEvent(std::move(event), false);
1165 }
1166 
AddHipMallocActivityEvent(const roctracer_record_t * record)1167 void RocmActivityCallbackImpl::AddHipMallocActivityEvent(
1168     const roctracer_record_t* record) {
1169   /*
1170     misses: device_id, context_id, memory_residency_info (num_byts, kind,
1171     address)
1172 
1173     extras:
1174       annotation, domain,
1175   */
1176 
1177   RocmTracerEvent event;
1178   event.domain = RocmTracerEventDomain::HIP_API;
1179   event.type = RocmTracerEventType::MemoryAlloc;
1180   event.source = RocmTracerEventSource::Activity;
1181   event.name =
1182       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1183   event.correlation_id = record->correlation_id;
1184   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1185   // similar to CUDA we set this to the default stream
1186   event.stream_id = 0;
1187   event.start_time_ns = record->begin_ns;
1188   // making sure it does not have 0ns duration. Otherwise, it may not show up in
1189   // the trace view
1190   event.end_time_ns = std::max(record->end_ns, record->begin_ns + 1);
1191 
1192   collector_->AddEvent(std::move(event), false);
1193 }
1194 
AddHipStreamSynchronizeActivityEvent(const roctracer_record_t * record)1195 void RocmActivityCallbackImpl::AddHipStreamSynchronizeActivityEvent(
1196     const roctracer_record_t* record) {
1197   /*
1198   misses: context_id, device_id (cuda also does not provide but we can get from
1199   API-CB)
1200 
1201   extras: domain, synchronization_info.sync_type, annotation
1202   */
1203 
1204   RocmTracerEvent event;
1205   event.domain = RocmTracerEventDomain::HIP_API;
1206   event.type = RocmTracerEventType::Synchronization;
1207   event.source = RocmTracerEventSource::Activity;
1208   event.name =
1209       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1210   event.correlation_id = record->correlation_id;
1211   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1212   event.start_time_ns = record->begin_ns;
1213 
1214   // making sure it does not have 0ns duration. Otherwise, it may not show up in
1215   // the trace view
1216   event.end_time_ns = std::max(record->end_ns, record->begin_ns + 1);
1217 
1218   switch (record->op) {
1219     case HIP_API_ID_hipStreamSynchronize:
1220       event.synchronization_info.sync_type =
1221           RocmTracerSyncTypes::StreamSynchronize;
1222       break;
1223     case HIP_API_ID_hipStreamWaitEvent:
1224       event.synchronization_info.sync_type = RocmTracerSyncTypes::StreamWait;
1225       break;
1226     default:
1227       event.synchronization_info.sync_type = RocmTracerSyncTypes::InvalidSync;
1228       break;
1229   }
1230   collector_->AddEvent(std::move(event), false);
1231 }
1232 
1233 // TODO(rocm-profiler): rename this function. this is HIP-OP
AddHccKernelActivityEvent(const roctracer_record_t * record)1234 void RocmActivityCallbackImpl::AddHccKernelActivityEvent(
1235     const roctracer_record_t* record) {
1236   /*
1237    missing:
1238      name, context_id, nvtx_range, kernel_info
1239 
1240    extra:
1241      domain (thread id from the HIP activity)
1242 
1243    activity record contains device/stream ID
1244  */
1245   RocmTracerEvent event;
1246   event.domain = RocmTracerEventDomain::HCC_OPS;
1247   event.type = RocmTracerEventType::Kernel;
1248   event.source = RocmTracerEventSource::Activity;
1249   event.correlation_id = record->correlation_id;
1250   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1251   event.start_time_ns = record->begin_ns;
1252   event.end_time_ns = record->end_ns;
1253   event.device_id = record->device_id;
1254   event.stream_id = record->queue_id;
1255 
1256   collector_->AddEvent(std::move(event), false);
1257 }
1258 
AddNormalHipOpsMemcpyActivityEvent(const roctracer_record_t * record)1259 void RocmActivityCallbackImpl::AddNormalHipOpsMemcpyActivityEvent(
1260     const roctracer_record_t* record) {
1261   /*
1262     misses:
1263       type, name(the name set here is not clear enough but we keep it for
1264     debug), context_id, memcpy_info.kind, memcpy_info.num_bytes,
1265     memcpy_info.async, memcpy_info.src_mem_kind, memcpy_info.dst_mem_kind
1266 
1267     extras:
1268       domain,
1269 
1270   */
1271 
1272   RocmTracerEvent event;
1273   event.domain = RocmTracerEventDomain::HCC_OPS;
1274   event.source = RocmTracerEventSource::Activity;
1275   event.name =  // name is stored for debug
1276       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1277   event.correlation_id = record->correlation_id;
1278   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1279 
1280   event.start_time_ns = record->begin_ns;
1281   event.end_time_ns = record->end_ns;
1282   event.device_id = record->device_id;
1283   event.memcpy_info.destination = event.device_id;
1284   event.stream_id = record->queue_id;
1285 
1286   // we set the type as MemcpyOther as HIP-OPS activity record does not carry
1287   // this information
1288   event.type = RocmTracerEventType::MemcpyOther;
1289 
1290   collector_->AddEvent(std::move(event), false);
1291 }
1292 
AddHipOpsMemsetActivityEvent(const roctracer_record_t * record)1293 void RocmActivityCallbackImpl::AddHipOpsMemsetActivityEvent(
1294     const roctracer_record_t* record) {
1295   /*
1296     misses:
1297       name (name recorder here is not clear enough for Memset. We only capture
1298     it for debug), context_id, memset_info.kind, memset_info.num_bytes,
1299     memset_info.async
1300 
1301     extras:
1302       dommain, annotation,
1303 
1304   */
1305 
1306   RocmTracerEvent event;
1307   event.domain = RocmTracerEventDomain::HCC_OPS;
1308   event.source = RocmTracerEventSource::Activity;
1309   event.name =  // name is stored for debug
1310       wrap::roctracer_op_string(record->domain, record->op, record->kind);
1311   event.correlation_id = record->correlation_id;
1312   event.annotation = collector_->annotation_map()->LookUp(event.correlation_id);
1313 
1314   event.start_time_ns = record->begin_ns;
1315   event.end_time_ns = record->end_ns;
1316   event.device_id = record->device_id;
1317   event.stream_id = record->queue_id;
1318 
1319   event.type = RocmTracerEventType::Memset;
1320 
1321   collector_->AddEvent(std::move(event), false);
1322 }
1323 
Add(uint32_t correlation_id,const std::string & annotation)1324 void AnnotationMap::Add(uint32_t correlation_id,
1325                         const std::string& annotation) {
1326   if (annotation.empty()) return;
1327   VLOG(3) << "Add annotation: "
1328           << " correlation_id=" << correlation_id
1329           << ", annotation: " << annotation;
1330   absl::MutexLock lock(&map_.mutex);
1331   if (map_.annotations.size() < max_size_) {
1332     absl::string_view annotation_str =
1333         *map_.annotations.insert(annotation).first;
1334     map_.correlation_map.emplace(correlation_id, annotation_str);
1335   }
1336 }
1337 
LookUp(uint32_t correlation_id)1338 absl::string_view AnnotationMap::LookUp(uint32_t correlation_id) {
1339   absl::MutexLock lock(&map_.mutex);
1340   auto it = map_.correlation_map.find(correlation_id);
1341   return it != map_.correlation_map.end() ? it->second : absl::string_view();
1342 }
1343 
GetRocmTracerSingleton()1344 /* static */ RocmTracer* RocmTracer::GetRocmTracerSingleton() {
1345   static auto* singleton = new RocmTracer();
1346   return singleton;
1347 }
1348 
1349 // FIXME(rocm-profiler): we should also check if we have AMD GPUs
IsAvailable() const1350 bool RocmTracer::IsAvailable() const {
1351   return !activity_tracing_enabled_ && !api_tracing_enabled_;  // &&NumGpus()
1352 }
1353 
NumGpus()1354 int RocmTracer::NumGpus() {
1355   static int num_gpus = []() -> int {
1356     if (hipInit(0) != hipSuccess) {
1357       return 0;
1358     }
1359     int gpu_count;
1360     if (hipGetDeviceCount(&gpu_count) != hipSuccess) {
1361       return 0;
1362     }
1363     LOG(INFO) << "Profiler found " << gpu_count << " GPUs";
1364     return gpu_count;
1365   }();
1366   return num_gpus;
1367 }
1368 
Enable(const RocmTracerOptions & options,RocmTraceCollector * collector)1369 void RocmTracer::Enable(const RocmTracerOptions& options,
1370                         RocmTraceCollector* collector) {
1371   options_ = options;
1372   collector_ = collector;
1373   api_cb_impl_ = new RocmApiCallbackImpl(options, this, collector);
1374   activity_cb_impl_ = new RocmActivityCallbackImpl(options, this, collector);
1375 
1376   // From ROCm 3.5 onwards, the following call is required.
1377   // don't quite know what it does (no documentation!), only that without it
1378   // the call to enable api/activity tracing will run into a segfault
1379   wrap::roctracer_set_properties(ACTIVITY_DOMAIN_HIP_API, nullptr);
1380 
1381   EnableApiTracing().IgnoreError();
1382   EnableActivityTracing().IgnoreError();
1383   LOG(INFO) << "GpuTracer started";
1384 }
1385 
Disable()1386 void RocmTracer::Disable() {
1387   // TODO(rocm-profiler): TF has a SyncAndFlush() function
1388   // to be called before disabling. It makes sure all the contexts
1389   // has finished all the tasks before shutting down the profiler
1390   DisableApiTracing().IgnoreError();
1391   DisableActivityTracing().IgnoreError();
1392   delete api_cb_impl_;
1393   delete activity_cb_impl_;
1394   collector_->Flush();
1395   collector_ = nullptr;
1396   options_.reset();
1397   LOG(INFO) << "GpuTracer stopped";
1398 }
1399 
ApiCallback(uint32_t domain,uint32_t cbid,const void * cbdata,void * user_data)1400 void ApiCallback(uint32_t domain, uint32_t cbid, const void* cbdata,
1401                  void* user_data) {
1402   RocmTracer* tracer = reinterpret_cast<RocmTracer*>(user_data);
1403   tracer->ApiCallbackHandler(domain, cbid, cbdata);
1404 }
1405 
ApiCallbackHandler(uint32_t domain,uint32_t cbid,const void * cbdata)1406 void RocmTracer::ApiCallbackHandler(uint32_t domain, uint32_t cbid,
1407                                     const void* cbdata) {
1408   if (api_tracing_enabled_) (*api_cb_impl_)(domain, cbid, cbdata);
1409 }
1410 
EnableApiTracing()1411 Status RocmTracer::EnableApiTracing() {
1412   if (api_tracing_enabled_) return Status::OK();
1413   api_tracing_enabled_ = true;
1414 
1415   for (auto& iter : options_->api_callbacks) {
1416     activity_domain_t domain = iter.first;
1417     std::vector<uint32_t>& ops = iter.second;
1418     if (ops.size() == 0) {
1419       VLOG(3) << "Enabling API tracing for domain "
1420               << GetActivityDomainName(domain);
1421       RETURN_IF_ROCTRACER_ERROR(
1422           wrap::roctracer_enable_domain_callback(domain, ApiCallback, this));
1423     } else {
1424       VLOG(3) << "Enabling API tracing for " << ops.size() << " ops in domain "
1425               << GetActivityDomainName(domain);
1426       for (auto& op : ops) {
1427         VLOG(3) << "Enabling API tracing for "
1428                 << GetActivityDomainOpName(domain, op);
1429         RETURN_IF_ROCTRACER_ERROR(
1430             wrap::roctracer_enable_op_callback(domain, op, ApiCallback, this));
1431       }
1432     }
1433   }
1434   return Status::OK();
1435 }
1436 
DisableApiTracing()1437 Status RocmTracer::DisableApiTracing() {
1438   if (!api_tracing_enabled_) return Status::OK();
1439   api_tracing_enabled_ = false;
1440 
1441   for (auto& iter : options_->api_callbacks) {
1442     activity_domain_t domain = iter.first;
1443     std::vector<uint32_t>& ops = iter.second;
1444     if (ops.size() == 0) {
1445       VLOG(3) << "Disabling API tracing for domain "
1446               << GetActivityDomainName(domain);
1447       RETURN_IF_ROCTRACER_ERROR(
1448           wrap::roctracer_disable_domain_callback(domain));
1449     } else {
1450       VLOG(3) << "Disabling API tracing for " << ops.size() << " ops in domain "
1451               << GetActivityDomainName(domain);
1452       for (auto& op : ops) {
1453         VLOG(3) << "Disabling API tracing for "
1454                 << GetActivityDomainOpName(domain, op);
1455         RETURN_IF_ROCTRACER_ERROR(
1456             wrap::roctracer_disable_op_callback(domain, op));
1457       }
1458     }
1459   }
1460   return Status::OK();
1461 }
1462 
ActivityCallback(const char * begin,const char * end,void * user_data)1463 void ActivityCallback(const char* begin, const char* end, void* user_data) {
1464   RocmTracer* tracer = reinterpret_cast<RocmTracer*>(user_data);
1465   tracer->ActivityCallbackHandler(begin, end);
1466 }
1467 
ActivityCallbackHandler(const char * begin,const char * end)1468 void RocmTracer::ActivityCallbackHandler(const char* begin, const char* end) {
1469   if (activity_tracing_enabled_) {
1470     (*activity_cb_impl_)(begin, end);
1471   } else {
1472     LOG(WARNING) << "ActivityCallbackHandler called when "
1473                     "activity_tracing_enabled_ is false";
1474 
1475     VLOG(3) << "Dropped Activity Records Start";
1476     const roctracer_record_t* record =
1477         reinterpret_cast<const roctracer_record_t*>(begin);
1478     const roctracer_record_t* end_record =
1479         reinterpret_cast<const roctracer_record_t*>(end);
1480     while (record < end_record) {
1481       DumpActivityRecord(record,
1482                          "activity_tracing_enabled_ is false. Dropped!");
1483       roctracer_next_record(record, &record);
1484     }
1485     VLOG(3) << "Dropped Activity Records End";
1486   }
1487 }
1488 
EnableActivityTracing()1489 Status RocmTracer::EnableActivityTracing() {
1490   if (activity_tracing_enabled_) return Status::OK();
1491   activity_tracing_enabled_ = true;
1492 
1493   if (!options_->activity_tracing.empty()) {
1494     // Create the memory pool to store activity records in
1495     if (wrap::roctracer_default_pool_expl(nullptr) == NULL) {
1496       roctracer_properties_t properties{};
1497       properties.buffer_size = 0x1000;
1498       properties.buffer_callback_fun = ActivityCallback;
1499       properties.buffer_callback_arg = this;
1500       VLOG(3) << "Creating roctracer activity buffer";
1501       RETURN_IF_ROCTRACER_ERROR(
1502           wrap::roctracer_open_pool_expl(&properties, nullptr));
1503     }
1504   }
1505 
1506   for (auto& iter : options_->activity_tracing) {
1507     activity_domain_t domain = iter.first;
1508     std::vector<uint32_t>& ops = iter.second;
1509     if (ops.size() == 0) {
1510       VLOG(3) << "Enabling Activity tracing for domain "
1511               << GetActivityDomainName(domain);
1512       RETURN_IF_ROCTRACER_ERROR(
1513           wrap::roctracer_enable_domain_activity_expl(domain, nullptr));
1514     } else {
1515       VLOG(3) << "Enabling Activity tracing for " << ops.size()
1516               << " ops in domain " << GetActivityDomainName(domain);
1517       for (auto& op : ops) {
1518         VLOG(3) << "Enabling Activity tracing for "
1519                 << GetActivityDomainOpName(domain, op);
1520         // roctracer library has not exported "roctracer_enable_op_activity"
1521         RETURN_IF_ROCTRACER_ERROR(
1522             wrap::roctracer_enable_op_activity_expl(domain, op, nullptr));
1523       }
1524     }
1525   }
1526 
1527   return Status::OK();
1528 }
1529 
DisableActivityTracing()1530 Status RocmTracer::DisableActivityTracing() {
1531   if (!activity_tracing_enabled_) return Status::OK();
1532 
1533   for (auto& iter : options_->activity_tracing) {
1534     activity_domain_t domain = iter.first;
1535     std::vector<uint32_t>& ops = iter.second;
1536     if (ops.size() == 0) {
1537       VLOG(3) << "Disabling Activity tracing for domain "
1538               << GetActivityDomainName(domain);
1539       RETURN_IF_ROCTRACER_ERROR(
1540           wrap::roctracer_disable_domain_activity(domain));
1541     } else {
1542       VLOG(3) << "Disabling Activity tracing for " << ops.size()
1543               << " ops in domain " << GetActivityDomainName(domain);
1544       for (auto& op : ops) {
1545         VLOG(3) << "Disabling Activity tracing for "
1546                 << GetActivityDomainOpName(domain, op);
1547         RETURN_IF_ROCTRACER_ERROR(
1548             wrap::roctracer_disable_op_activity(domain, op));
1549       }
1550     }
1551   }
1552 
1553   // TODO(rocm-profiler): this stopping mechanism needs improvement.
1554   // Flush the activity buffer BEFORE setting the activity_tracing_enable_
1555   // flag to FALSE. This is because the activity record callback routine is
1556   // gated by the same flag
1557   VLOG(3) << "Flushing roctracer activity buffer";
1558   RETURN_IF_ROCTRACER_ERROR(wrap::roctracer_flush_activity_expl(nullptr));
1559   // roctracer_flush_buf();
1560 
1561   // Explicitly wait for (almost) all pending activity records
1562   // The choice of all of the following is based what seemed to work
1563   // best when enabling tracing on a large testcase (BERT)
1564   // * 100 ms as the initial sleep duration AND
1565   // * 1 as the initial threshold value
1566   // * 6 as the maximum number of iterations
1567   int duration_ms = 100;
1568   size_t threshold = 1;
1569   for (int i = 0; i < 6; i++, duration_ms *= 2, threshold *= 2) {
1570     if (GetPendingActivityRecordsCount() < threshold) break;
1571     VLOG(3) << "Wait for pending activity records :"
1572             << " Pending count = " << GetPendingActivityRecordsCount()
1573             << ", Threshold = " << threshold;
1574     VLOG(3) << "Wait for pending activity records : sleep for " << duration_ms
1575             << " ms";
1576     tensorflow::profiler::SleepForMillis(duration_ms);
1577   }
1578   ClearPendingActivityRecordsCount();
1579 
1580   activity_tracing_enabled_ = false;
1581 
1582   return Status::OK();
1583 }
1584 
GetTimestamp()1585 /*static*/ uint64_t RocmTracer::GetTimestamp() {
1586   uint64_t ts;
1587   if (wrap::roctracer_get_timestamp(&ts) != ROCTRACER_STATUS_SUCCESS) {
1588     const char* errstr = wrap::roctracer_error_string();
1589     LOG(ERROR) << "function roctracer_get_timestamp failed with error "
1590                << errstr;
1591     // Return 0 on error.
1592     return 0;
1593   }
1594   return ts;
1595 }
1596 
1597 }  // namespace profiler
1598 }  // namespace tensorflow
1599