1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_ 17 #define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_ 18 19 #include "absl/container/fixed_array.h" 20 #include "absl/container/flat_hash_map.h" 21 #include "absl/container/flat_hash_set.h" 22 #include "absl/container/node_hash_set.h" 23 #include "absl/types/optional.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/types.h" 28 #include "tensorflow/stream_executor/rocm/roctracer_wrapper.h" 29 30 namespace tensorflow { 31 namespace profiler { 32 33 struct MemcpyDetails { 34 // The amount of data copied for memcpy events. 35 size_t num_bytes; 36 // The destination device for peer-2-peer communication (memcpy). The source 37 // device is implicit: it's the current device. 38 uint32_t destination; 39 // Whether or not the memcpy is asynchronous. 40 bool async; 41 }; 42 43 struct MemsetDetails { 44 // The number of memory elements getting set 45 size_t num_bytes; 46 // Whether or not the memset is asynchronous. 47 bool async; 48 }; 49 50 struct MemAllocDetails { 51 // The amount of data requested for cudaMalloc events. 52 uint64_t num_bytes; 53 }; 54 55 struct KernelDetails { 56 // The number of registers used in this kernel. 57 uint32_t registers_per_thread; 58 // The amount of shared memory space used by a thread block. 59 uint32_t static_shared_memory_usage; 60 // The amount of dynamic memory space used by a thread block. 61 uint32_t dynamic_shared_memory_usage; 62 // X-dimension of a thread block. 63 uint32_t block_x; 64 // Y-dimension of a thread block. 65 uint32_t block_y; 66 // Z-dimension of a thread block. 67 uint32_t block_z; 68 // X-dimension of a grid. 69 uint32_t grid_x; 70 // Y-dimension of a grid. 71 uint32_t grid_y; 72 // Z-dimension of a grid. 73 uint32_t grid_z; 74 75 // kernel address. Used for calculating core occupancy 76 void* func_ptr; 77 }; 78 79 // RocmTracerSyncTypes forward decleration 80 enum class RocmTracerSyncTypes; 81 struct SynchronizationDetails { 82 RocmTracerSyncTypes sync_type; 83 }; 84 85 enum class RocmTracerEventType { 86 Unsupported = 0, 87 Kernel, 88 MemcpyH2D, 89 MemcpyD2H, 90 MemcpyD2D, 91 MemcpyP2P, 92 MemcpyOther, 93 MemoryAlloc, 94 MemoryFree, 95 Memset, 96 Synchronization, 97 Generic, 98 }; 99 100 const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type); 101 102 enum class RocmTracerEventSource { 103 Invalid = 0, 104 ApiCallback, 105 Activity, 106 }; 107 108 const char* GetRocmTracerEventSourceName(const RocmTracerEventSource& source); 109 110 enum class RocmTracerEventDomain { 111 InvalidDomain = 0, 112 HIP_API, 113 HCC_OPS, // TODO(rocm-profiler): renme this to HIP_OPS 114 }; 115 enum class RocmTracerSyncTypes { 116 InvalidSync = 0, 117 StreamSynchronize, // caller thread wait stream to become empty 118 EventSynchronize, // caller thread will block until event happens 119 StreamWait // compute stream will wait for event to happen 120 }; 121 122 const char* GetRocmTracerEventDomainName(const RocmTracerEventDomain& domain); 123 124 struct RocmTracerEvent { 125 static constexpr uint32_t kInvalidDeviceId = 126 std::numeric_limits<uint32_t>::max(); 127 static constexpr uint32_t kInvalidThreadId = 128 std::numeric_limits<uint32_t>::max(); 129 static constexpr uint32_t kInvalidCorrelationId = 130 std::numeric_limits<uint32_t>::max(); 131 static constexpr uint64_t kInvalidStreamId = 132 std::numeric_limits<uint64_t>::max(); 133 RocmTracerEventType type; 134 RocmTracerEventSource source = RocmTracerEventSource::Invalid; 135 RocmTracerEventDomain domain; 136 std::string name; 137 // This points to strings in AnnotationMap, which should outlive the point 138 // where serialization happens. 139 absl::string_view annotation; 140 absl::string_view roctx_range; 141 uint64_t start_time_ns = 0; 142 uint64_t end_time_ns = 0; 143 uint32_t device_id = kInvalidDeviceId; 144 uint32_t correlation_id = kInvalidCorrelationId; 145 uint32_t thread_id = kInvalidThreadId; 146 int64_t stream_id = kInvalidStreamId; 147 union { 148 MemcpyDetails memcpy_info; // If type == Memcpy* 149 MemsetDetails memset_info; // If type == Memset* 150 MemAllocDetails memalloc_info; // If type == MemoryAlloc 151 KernelDetails kernel_info; // If type == Kernel 152 SynchronizationDetails synchronization_info; // If type == Synchronization 153 }; 154 }; 155 156 void DumpRocmTracerEvent(const RocmTracerEvent& event, 157 uint64_t start_walltime_ns, uint64_t start_gputime_ns, 158 const string& message); 159 160 struct RocmTracerOptions { 161 std::set<uint32_t> api_tracking_set; // actual api set we want to profile 162 163 // map of domain --> ops for which we need to enable the API callbacks 164 // If the ops vector is empty, then enable API callbacks for entire domain 165 absl::flat_hash_map<activity_domain_t, std::vector<uint32_t> > api_callbacks; 166 167 // map of domain --> ops for which we need to enable the Activity records 168 // If the ops vector is empty, then enable Activity records for entire domain 169 absl::flat_hash_map<activity_domain_t, std::vector<uint32_t> > 170 activity_tracing; 171 }; 172 173 struct RocmTraceCollectorOptions { 174 // Maximum number of events to collect from callback API; if -1, no limit. 175 // if 0, the callback API is enabled to build a correlation map, but no 176 // events are collected. 177 uint64_t max_callback_api_events; 178 // Maximum number of events to collect from activity API; if -1, no limit. 179 uint64_t max_activity_api_events; 180 // Maximum number of annotation strings that we can accommodate. 181 uint64_t max_annotation_strings; 182 // Number of GPUs involved. 183 uint32_t num_gpus; 184 }; 185 186 class AnnotationMap { 187 public: AnnotationMap(uint64_t max_size)188 explicit AnnotationMap(uint64_t max_size) : max_size_(max_size) {} 189 void Add(uint32_t correlation_id, const std::string& annotation); 190 absl::string_view LookUp(uint32_t correlation_id); 191 192 private: 193 struct AnnotationMapImpl { 194 // The population/consumption of annotations might happen from multiple 195 // callback/activity api related threads. 196 absl::Mutex mutex; 197 // Annotation tends to be repetitive, use a hash_set to store the strings, 198 // an use the reference to the string in the map. 199 absl::node_hash_set<std::string> annotations; 200 absl::flat_hash_map<uint32_t, absl::string_view> correlation_map; 201 }; 202 const uint64_t max_size_; 203 AnnotationMapImpl map_; 204 205 public: 206 // Disable copy and move. 207 AnnotationMap(const AnnotationMap&) = delete; 208 AnnotationMap& operator=(const AnnotationMap&) = delete; 209 }; 210 211 class RocmTraceCollector { 212 public: RocmTraceCollector(const RocmTraceCollectorOptions & options)213 explicit RocmTraceCollector(const RocmTraceCollectorOptions& options) 214 : options_(options), annotation_map_(options.max_annotation_strings) {} ~RocmTraceCollector()215 virtual ~RocmTraceCollector() {} 216 217 virtual void AddEvent(RocmTracerEvent&& event, bool is_auxiliary) = 0; 218 virtual void OnEventsDropped(const std::string& reason, 219 uint32_t num_events) = 0; 220 virtual void Flush() = 0; 221 annotation_map()222 AnnotationMap* annotation_map() { return &annotation_map_; } 223 224 protected: 225 RocmTraceCollectorOptions options_; 226 227 private: 228 AnnotationMap annotation_map_; 229 230 public: 231 // Disable copy and move. 232 RocmTraceCollector(const RocmTraceCollector&) = delete; 233 RocmTraceCollector& operator=(const RocmTraceCollector&) = delete; 234 }; 235 236 class RocmTracer; 237 238 class RocmApiCallbackImpl { 239 public: RocmApiCallbackImpl(const RocmTracerOptions & options,RocmTracer * tracer,RocmTraceCollector * collector)240 RocmApiCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer, 241 RocmTraceCollector* collector) 242 : options_(options), tracer_(tracer), collector_(collector) {} 243 244 Status operator()(uint32_t domain, uint32_t cbid, const void* cbdata); 245 246 private: 247 void AddKernelEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, 248 uint64_t enter_time, uint64_t exit_time); 249 void AddNormalMemcpyEventUponApiExit(uint32_t cbid, 250 const hip_api_data_t* data, 251 uint64_t enter_time, uint64_t exit_time); 252 void AddMemcpyPeerEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, 253 uint64_t enter_time, uint64_t exit_time); 254 void AddMemsetEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, 255 uint64_t enter_time, uint64_t exit_time); 256 void AddMallocFreeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, 257 uint32_t device_id, uint64_t enter_time, 258 uint64_t exit_time); 259 void AddStreamSynchronizeEventUponApiExit(uint32_t cbid, 260 const hip_api_data_t* data, 261 uint64_t enter_time, 262 uint64_t exit_time); 263 void AddSynchronizeEventUponApiExit(uint32_t cbid, const hip_api_data_t* data, 264 uint64_t enter_time, uint64_t exit_time); 265 266 RocmTracerOptions options_; 267 RocmTracer* tracer_ = nullptr; 268 RocmTraceCollector* collector_ = nullptr; 269 mutex api_call_start_mutex_; 270 // TODO(rocm-profiler): replace this with absl hashmap 271 // keep a map from the corr. id to enter time for API callbacks. 272 std::map<uint32_t, uint64_t> api_call_start_time_ 273 TF_GUARDED_BY(api_call_start_mutex_); 274 }; 275 276 class RocmActivityCallbackImpl { 277 public: RocmActivityCallbackImpl(const RocmTracerOptions & options,RocmTracer * tracer,RocmTraceCollector * collector)278 RocmActivityCallbackImpl(const RocmTracerOptions& options, RocmTracer* tracer, 279 RocmTraceCollector* collector) 280 : options_(options), tracer_(tracer), collector_(collector) {} 281 282 Status operator()(const char* begin, const char* end); 283 284 private: 285 void AddHipKernelActivityEvent(const roctracer_record_t* record); 286 void AddNormalHipMemcpyActivityEvent(const roctracer_record_t* record); 287 void AddHipMemsetActivityEvent(const roctracer_record_t* record); 288 void AddHipMallocActivityEvent(const roctracer_record_t* record); 289 void AddHipStreamSynchronizeActivityEvent(const roctracer_record_t* record); 290 void AddHccKernelActivityEvent(const roctracer_record_t* record); 291 void AddNormalHipOpsMemcpyActivityEvent(const roctracer_record_t* record); 292 void AddHipOpsMemsetActivityEvent(const roctracer_record_t* record); 293 RocmTracerOptions options_; 294 RocmTracer* tracer_ = nullptr; 295 RocmTraceCollector* collector_ = nullptr; 296 }; 297 298 // The class use to enable cupti callback/activity API and forward the collected 299 // trace events to RocmTraceCollector. There should be only one RocmTracer 300 // per process. 301 class RocmTracer { 302 public: 303 // Returns a pointer to singleton RocmTracer. 304 static RocmTracer* GetRocmTracerSingleton(); 305 306 // Only one profile session can be live in the same time. 307 bool IsAvailable() const; 308 309 void Enable(const RocmTracerOptions& options, RocmTraceCollector* collector); 310 void Disable(); 311 312 void ApiCallbackHandler(uint32_t domain, uint32_t cbid, const void* cbdata); 313 void ActivityCallbackHandler(const char* begin, const char* end); 314 315 static uint64_t GetTimestamp(); 316 static int NumGpus(); 317 AddToPendingActivityRecords(uint32_t correlation_id)318 void AddToPendingActivityRecords(uint32_t correlation_id) { 319 pending_activity_records_.Add(correlation_id); 320 } 321 RemoveFromPendingActivityRecords(uint32_t correlation_id)322 void RemoveFromPendingActivityRecords(uint32_t correlation_id) { 323 pending_activity_records_.Remove(correlation_id); 324 } 325 ClearPendingActivityRecordsCount()326 void ClearPendingActivityRecordsCount() { pending_activity_records_.Clear(); } 327 GetPendingActivityRecordsCount()328 size_t GetPendingActivityRecordsCount() { 329 return pending_activity_records_.Count(); 330 } 331 332 protected: 333 // protected constructor for injecting mock cupti interface for testing. RocmTracer()334 explicit RocmTracer() : num_gpus_(NumGpus()) {} 335 336 private: 337 Status EnableApiTracing(); 338 Status DisableApiTracing(); 339 340 Status EnableActivityTracing(); 341 Status DisableActivityTracing(); 342 343 int num_gpus_; 344 absl::optional<RocmTracerOptions> options_; 345 RocmTraceCollector* collector_ = nullptr; 346 347 bool api_tracing_enabled_ = false; 348 bool activity_tracing_enabled_ = false; 349 350 RocmApiCallbackImpl* api_cb_impl_; 351 RocmActivityCallbackImpl* activity_cb_impl_; 352 353 class PendingActivityRecords { 354 public: 355 // add a correlation id to the pending set Add(uint32_t correlation_id)356 void Add(uint32_t correlation_id) { 357 absl::MutexLock lock(&mutex); 358 pending_set.insert(correlation_id); 359 } 360 // remove a correlation id from the pending set Remove(uint32_t correlation_id)361 void Remove(uint32_t correlation_id) { 362 absl::MutexLock lock(&mutex); 363 pending_set.erase(correlation_id); 364 } 365 // clear the pending set Clear()366 void Clear() { 367 absl::MutexLock lock(&mutex); 368 pending_set.clear(); 369 } 370 // count the number of correlation ids in the pending set Count()371 size_t Count() { 372 absl::MutexLock lock(&mutex); 373 return pending_set.size(); 374 } 375 376 private: 377 // set of co-relation ids for which the hcc activity record is pending 378 absl::flat_hash_set<uint32_t> pending_set; 379 // the callback which processes the activity records (and consequently 380 // removes items from the pending set) is called in a separate thread 381 // from the one that adds item to the list. 382 absl::Mutex mutex; 383 }; 384 PendingActivityRecords pending_activity_records_; 385 386 public: 387 // Disable copy and move. 388 RocmTracer(const RocmTracer&) = delete; 389 RocmTracer& operator=(const RocmTracer&) = delete; 390 }; 391 392 } // namespace profiler 393 } // namespace tensorflow 394 #endif // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_ROCM_TRACER_H_ 395