xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/backends/gpu/rocm_tracer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_
17 #define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_
18 
19 #include "absl/container/fixed_array.h"
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/container/node_hash_set.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/types.h"
28 #include "tensorflow/stream_executor/rocm/roctracer_wrapper.h"
29 
30 namespace tensorflow {
31 namespace profiler {
32 
// Payload attached to a trace event that describes a memory copy.
struct MemcpyDetails {
  // Number of bytes transferred by the copy.
  size_t num_bytes;
  // Destination device of a peer-to-peer copy. The source device is not
  // stored: it is implicitly the current device.
  uint32_t destination;
  // True when the copy was issued asynchronously.
  bool async;
};
42 
// Payload attached to a trace event that describes a memset.
struct MemsetDetails {
  // Size of the region being set.
  // NOTE(review): the previous comment called this "the number of memory
  // elements", while the field is named num_bytes -- confirm the unit against
  // the code that fills it in.
  size_t num_bytes;
  // True when the memset was issued asynchronously.
  bool async;
};
49 
// Payload attached to a trace event that describes a memory allocation.
struct MemAllocDetails {
  // Number of bytes requested by the allocation call.
  // NOTE(review): the previous comment referred to "cudaMalloc events";
  // presumably the HIP equivalent is meant in this ROCm tracer -- confirm.
  uint64_t num_bytes;
};
54 
// Payload attached to a trace event that describes a kernel launch.
struct KernelDetails {
  // Registers used by the kernel, per thread.
  uint32_t registers_per_thread;
  // Shared memory used by one thread block: static part.
  uint32_t static_shared_memory_usage;
  // Shared memory used by one thread block: dynamic part.
  uint32_t dynamic_shared_memory_usage;

  // Thread-block dimensions (x, y, z).
  uint32_t block_x;
  uint32_t block_y;
  uint32_t block_z;

  // Grid dimensions (x, y, z).
  uint32_t grid_x;
  uint32_t grid_y;
  uint32_t grid_z;

  // Address of the kernel; used for calculating core occupancy.
  void* func_ptr;
};
78 
// Forward declaration of RocmTracerSyncTypes; its enumerators are defined
// later in this header.
enum class RocmTracerSyncTypes;

// Payload attached to a trace event that describes a synchronization.
struct SynchronizationDetails {
  // Which kind of synchronization this event represents.
  RocmTracerSyncTypes sync_type;
};
84 
// Classifies a RocmTracerEvent. The type also tells which member of the
// event's detail union is meaningful (see the comments on that union in
// RocmTracerEvent).
enum class RocmTracerEventType {
  Unsupported = 0,
  Kernel,
  // Memory copies.
  MemcpyH2D,
  MemcpyD2H,
  MemcpyD2D,
  MemcpyP2P,
  MemcpyOther,
  // Memory management.
  MemoryAlloc,
  MemoryFree,
  Memset,
  Synchronization,
  Generic,
};

// Returns a C string naming `type`.
const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type);
101 
// Identifies which tracing path produced a RocmTracerEvent.
enum class RocmTracerEventSource {
  Invalid = 0,
  ApiCallback,
  Activity,
};

// Returns a C string naming `source`.
const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source);
109 
// Tracing domain a RocmTracerEvent belongs to.
enum class RocmTracerEventDomain {
  InvalidDomain = 0,
  HIP_API,
  HCC_OPS,  // TODO(rocm-profiler): rename this to HIP_OPS
};

// Returns a C string naming `domain`.
const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain);

// Kinds of synchronization recorded in SynchronizationDetails.
enum class RocmTracerSyncTypes {
  InvalidSync = 0,
  StreamSynchronize,  // caller thread waits for the stream to become empty
  EventSynchronize,   // caller thread blocks until the event happens
  StreamWait,         // compute stream waits for the event to happen
};
123 
124 struct RocmTracerEvent {
125   static constexpr uint32_t kInvalidDeviceId =
126       std::numeric_limits<uint32_t>::max();
127   static constexpr uint32_t kInvalidThreadId =
128       std::numeric_limits<uint32_t>::max();
129   static constexpr uint32_t kInvalidCorrelationId =
130       std::numeric_limits<uint32_t>::max();
131   static constexpr uint64_t kInvalidStreamId =
132       std::numeric_limits<uint64_t>::max();
133   RocmTracerEventType type;
134   RocmTracerEventSource source = RocmTracerEventSource::Invalid;
135   RocmTracerEventDomain domain;
136   std::string name;
137   // This points to strings in AnnotationMap, which should outlive the point
138   // where serialization happens.
139   absl::string_view annotation;
140   absl::string_view roctx_range;
141   uint64_t start_time_ns = 0;
142   uint64_t end_time_ns = 0;
143   uint32_t device_id = kInvalidDeviceId;
144   uint32_t correlation_id = kInvalidCorrelationId;
145   uint32_t thread_id = kInvalidThreadId;
146   int64_t stream_id = kInvalidStreamId;
147   union {
148     MemcpyDetails memcpy_info;                    // If type == Memcpy*
149     MemsetDetails memset_info;                    // If type == Memset*
150     MemAllocDetails memalloc_info;                // If type == MemoryAlloc
151     KernelDetails kernel_info;                    // If type == Kernel
152     SynchronizationDetails synchronization_info;  // If type == Synchronization
153   };
154 };
155 
156 void DumpRocmTracerEvent(const RocmTracerEvent& event,
157                          uint64_t start_walltime_ns, uint64_t start_gputime_ns,
158                          const string& message);
159 
160 struct RocmTracerOptions {
161   std::set<uint32_t> api_tracking_set;  // actual api set we want to profile
162 
163   // map of domain --> ops for which we need to enable the API callbacks
164   // If the ops vector is empty, then enable API callbacks for entire domain
165   absl::flat_hash_map<activity_domain_t, std::vector<uint32_t> > api_callbacks;
166 
167   // map of domain --> ops for which we need to enable the Activity records
168   // If the ops vector is empty, then enable Activity records for entire domain
169   absl::flat_hash_map<activity_domain_t, std::vector<uint32_t> >
170       activity_tracing;
171 };
172 
// Limits applied by a RocmTraceCollector.
//
// The event limits are unsigned: the "-1" mentioned below is stored as the
// largest uint64_t value.
struct RocmTraceCollectorOptions {
  // Upper bound on events collected from the callback API; -1 means no limit.
  // With 0 the callback API is still enabled to build a correlation map, but
  // no events are collected.
  uint64_t max_callback_api_events;
  // Upper bound on events collected from the activity API; -1 means no limit.
  uint64_t max_activity_api_events;
  // Upper bound on the number of annotation strings we can accommodate.
  uint64_t max_annotation_strings;
  // Number of GPUs involved.
  uint32_t num_gpus;
};
185 
186 class AnnotationMap {
187  public:
AnnotationMap(uint64_t max_size)188   explicit AnnotationMap(uint64_t max_size) : max_size_(max_size) {}
189   void Add(uint32_t correlation_id, const std::string& annotation);
190   absl::string_view LookUp(uint32_t correlation_id);
191 
192  private:
193   struct AnnotationMapImpl {
194     // The population/consumption of annotations might happen from multiple
195     // callback/activity api related threads.
196     absl::Mutex mutex;
197     // Annotation tends to be repetitive, use a hash_set to store the strings,
198     // an use the reference to the string in the map.
199     absl::node_hash_set<std::string> annotations;
200     absl::flat_hash_map<uint32_t, absl::string_view> correlation_map;
201   };
202   const uint64_t max_size_;
203   AnnotationMapImpl map_;
204 
205  public:
206   // Disable copy and move.
207   AnnotationMap(const AnnotationMap&) = delete;
208   AnnotationMap& operator=(const AnnotationMap&) = delete;
209 };
210 
211 class RocmTraceCollector {
212  public:
RocmTraceCollector(const RocmTraceCollectorOptions & options)213   explicit RocmTraceCollector(const RocmTraceCollectorOptions& options)
214       : options_(options), annotation_map_(options.max_annotation_strings) {}
~RocmTraceCollector()215   virtual ~RocmTraceCollector() {}
216 
217   virtual void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) = 0;
218   virtual void OnEventsDropped(const std::string& reason,
219                                uint32_t num_events) = 0;
220   virtual void Flush() = 0;
221 
annotation_map()222   AnnotationMap* annotation_map() { return &annotation_map_; }
223 
224  protected:
225   RocmTraceCollectorOptions options_;
226 
227  private:
228   AnnotationMap annotation_map_;
229 
230  public:
231   // Disable copy and move.
232   RocmTraceCollector(const RocmTraceCollector&) = delete;
233   RocmTraceCollector& operator=(const RocmTraceCollector&) = delete;
234 };
235 
236 class RocmTracer;
237 
238 class RocmApiCallbackImpl {
239  public:
RocmApiCallbackImpl(const RocmTracerOptions & options,RocmTracer * tracer,RocmTraceCollector * collector)240   RocmApiCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer,
241                       RocmTraceCollector* collector)
242       : options_(options), tracer_(tracer), collector_(collector) {}
243 
244   Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata);
245 
246  private:
247   void AddKernelEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
248                                  uint64_t enter_time, uint64_t exit_time);
249   void AddNormalMemcpyEventUponApiExit(uint32_t cbid,
250                                        const hip_api_data_t* data,
251                                        uint64_t enter_time, uint64_t exit_time);
252   void AddMemcpyPeerEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
253                                      uint64_t enter_time, uint64_t exit_time);
254   void AddMemsetEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
255                                  uint64_t enter_time, uint64_t exit_time);
256   void AddMallocFreeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
257                                      uint32_t device_id, uint64_t enter_time,
258                                      uint64_t exit_time);
259   void AddStreamSynchronizeEventUponApiExit(uint32_t cbid,
260                                             const hip_api_data_t* data,
261                                             uint64_t enter_time,
262                                             uint64_t exit_time);
263   void AddSynchronizeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data,
264                                       uint64_t enter_time, uint64_t exit_time);
265 
266   RocmTracerOptions options_;
267   RocmTracer* tracer_ = nullptr;
268   RocmTraceCollector* collector_ = nullptr;
269   mutex api_call_start_mutex_;
270   // TODO(rocm-profiler): replace this with absl hashmap
271   // keep a map from the corr. id to enter time for API callbacks.
272   std::map<uint32_t, uint64_t> api_call_start_time_
273       TF_GUARDED_BY(api_call_start_mutex_);
274 };
275 
276 class RocmActivityCallbackImpl {
277  public:
RocmActivityCallbackImpl(const RocmTracerOptions & options,RocmTracer * tracer,RocmTraceCollector * collector)278   RocmActivityCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer,
279                            RocmTraceCollector* collector)
280       : options_(options), tracer_(tracer), collector_(collector) {}
281 
282   Status operator()(const char* begin, const char* end);
283 
284  private:
285   void AddHipKernelActivityEvent(const roctracer_record_t* record);
286   void AddNormalHipMemcpyActivityEvent(const roctracer_record_t* record);
287   void AddHipMemsetActivityEvent(const roctracer_record_t* record);
288   void AddHipMallocActivityEvent(const roctracer_record_t* record);
289   void AddHipStreamSynchronizeActivityEvent(const roctracer_record_t* record);
290   void AddHccKernelActivityEvent(const roctracer_record_t* record);
291   void AddNormalHipOpsMemcpyActivityEvent(const roctracer_record_t* record);
292   void AddHipOpsMemsetActivityEvent(const roctracer_record_t* record);
293   RocmTracerOptions options_;
294   RocmTracer* tracer_ = nullptr;
295   RocmTraceCollector* collector_ = nullptr;
296 };
297 
298 // The class use to enable cupti callback/activity API and forward the collected
299 // trace events to RocmTraceCollector. There should be only one RocmTracer
300 // per process.
301 class RocmTracer {
302  public:
303   // Returns a pointer to singleton RocmTracer.
304   static RocmTracer* GetRocmTracerSingleton();
305 
306   // Only one profile session can be live in the same time.
307   bool IsAvailable() const;
308 
309   void Enable(const RocmTracerOptions& options, RocmTraceCollector* collector);
310   void Disable();
311 
312   void ApiCallbackHandler(uint32_t domain, uint32_t cbid, const void* cbdata);
313   void ActivityCallbackHandler(const char* begin, const char* end);
314 
315   static uint64_t GetTimestamp();
316   static int NumGpus();
317 
AddToPendingActivityRecords(uint32_t correlation_id)318   void AddToPendingActivityRecords(uint32_t correlation_id) {
319     pending_activity_records_.Add(correlation_id);
320   }
321 
RemoveFromPendingActivityRecords(uint32_t correlation_id)322   void RemoveFromPendingActivityRecords(uint32_t correlation_id) {
323     pending_activity_records_.Remove(correlation_id);
324   }
325 
ClearPendingActivityRecordsCount()326   void ClearPendingActivityRecordsCount() { pending_activity_records_.Clear(); }
327 
GetPendingActivityRecordsCount()328   size_t GetPendingActivityRecordsCount() {
329     return pending_activity_records_.Count();
330   }
331 
332  protected:
333   // protected constructor for injecting mock cupti interface for testing.
RocmTracer()334   explicit RocmTracer() : num_gpus_(NumGpus()) {}
335 
336  private:
337   Status EnableApiTracing();
338   Status DisableApiTracing();
339 
340   Status EnableActivityTracing();
341   Status DisableActivityTracing();
342 
343   int num_gpus_;
344   absl::optional<RocmTracerOptions> options_;
345   RocmTraceCollector* collector_ = nullptr;
346 
347   bool api_tracing_enabled_ = false;
348   bool activity_tracing_enabled_ = false;
349 
350   RocmApiCallbackImpl* api_cb_impl_;
351   RocmActivityCallbackImpl* activity_cb_impl_;
352 
353   class PendingActivityRecords {
354    public:
355     // add a correlation id to the pending set
Add(uint32_t correlation_id)356     void Add(uint32_t correlation_id) {
357       absl::MutexLock lock(&mutex);
358       pending_set.insert(correlation_id);
359     }
360     // remove a correlation id from the pending set
Remove(uint32_t correlation_id)361     void Remove(uint32_t correlation_id) {
362       absl::MutexLock lock(&mutex);
363       pending_set.erase(correlation_id);
364     }
365     // clear the pending set
Clear()366     void Clear() {
367       absl::MutexLock lock(&mutex);
368       pending_set.clear();
369     }
370     // count the number of correlation ids in the pending set
Count()371     size_t Count() {
372       absl::MutexLock lock(&mutex);
373       return pending_set.size();
374     }
375 
376    private:
377     // set of co-relation ids for which the hcc activity record is pending
378     absl::flat_hash_set<uint32_t> pending_set;
379     // the callback which processes the activity records (and consequently
380     // removes items from the pending set) is called in a separate thread
381     // from the one that adds item to the list.
382     absl::Mutex mutex;
383   };
384   PendingActivityRecords pending_activity_records_;
385 
386  public:
387   // Disable copy and move.
388   RocmTracer(const RocmTracer&) = delete;
389   RocmTracer& operator=(const RocmTracer&) = delete;
390 };
391 
392 }  // namespace profiler
393 }  // namespace tensorflow
394 #endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_
395