1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #include <cuda_runtime.h>
11
12 #include <ATen/cuda/CUDAContextLight.h>
13 #include <ATen/cuda/tunable/Tunable.h>
14 #include <c10/util/Exception.h>
15 #include <c10/util/StringUtil.h>
16 #include <torch/version.h>
17
18 #ifndef _WIN32
19 #include <cxxabi.h>
20 #endif
21
22 #include <chrono>
23 #include <fstream>
24 #include <functional>
25 #include <limits>
26 #include <memory>
27 #include <mutex>
28 #include <sstream>
29 #include <string>
30 #include <thread>
31 #include <type_traits>
32 #include <unordered_map>
33 #include <unordered_set>
34 #include <utility>
35 #include <vector>
36
37 // for validators
38 #ifdef USE_ROCM
39 #include <rocm-core/rocm_version.h>
40 #define ROCBLAS_BETA_FEATURES_API
41 #include <rocblas/rocblas.h>
42 #include <hipblaslt/hipblaslt.h>
43 #include <hipblaslt/hipblaslt-ext.hpp>
44 #endif
45
46 namespace at::cuda::tunable {
47
getTuningContext()48 TuningContext* getTuningContext() {
49 static TuningContext tuning_context;
50 return &tuning_context;
51 }
52
operator <<(std::ostream & stream,const ResultEntry & entry)53 std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
54 return stream << entry.key_ << "," << entry.time_;
55 }
56
57 // TuningResultsManager
58
Lookup(const std::string & op_signature)59 KernelMap TuningResultsManager::Lookup(const std::string& op_signature) {
60 std::scoped_lock l{lock_};
61 auto it = results_.find(op_signature);
62 if (it == results_.cend()) {
63 return {};
64 }
65 return it->second; // copied
66 }
67
Lookup(const std::string & op_signature,const std::string & params_signature)68 ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const std::string& params_signature) {
69 std::scoped_lock l{lock_};
70 auto kernel_map_it = results_.find(op_signature);
71 if (kernel_map_it == results_.cend()) {
72 TUNABLE_LOG3("missing op_signature, returning null ResultEntry for ", op_signature, ",", params_signature);
73 return ResultEntry::Null();
74 }
75
76 const auto& km = kernel_map_it->second;
77 auto it = km.find(params_signature);
78 if (it == km.cend()) {
79 TUNABLE_LOG3("missing params_signature, returning null ResultEntry for ", op_signature, ",", params_signature);
80 return ResultEntry::Null();
81 }
82 TUNABLE_LOG3("ResultEntry found for ", op_signature, ",", params_signature);
83 return it->second;
84 }
85
AddImpl(const std::string & op_signature,const std::string & params_signature,ResultEntry best,KernelMap & kernel_map)86 inline void TuningResultsManager::AddImpl(const std::string& op_signature,
87 const std::string& params_signature,
88 ResultEntry best,
89 KernelMap& kernel_map) {
90 auto it = kernel_map.find(params_signature);
91 if (it != kernel_map.end()) {
92 if (it->second != best) {
93 TUNABLE_LOG1(op_signature, "(", params_signature, ") already has a best kernel ",
94 "id=", it->second, " selected, want to add a different best kernel ", best,
95 ", the new kernel id will be ignored.");
96 }
97 return;
98 }
99
100 TUNABLE_LOG2(op_signature, "(", params_signature, ") -> ", best);
101 kernel_map.emplace(params_signature, best);
102 }
103
Add(const std::string & op_signature,const std::string & params_signature,ResultEntry best)104 void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) {
105 std::scoped_lock l{lock_};
106
107 auto it = results_.find(op_signature);
108 if (it == results_.end()) {
109 it = results_.insert({op_signature, {}}).first;
110 }
111
112 AddImpl(op_signature, params_signature, best, it->second);
113 }
114
Delete(const std::string & op_signature,const std::string & params_signature)115 void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) {
116 std::scoped_lock l{lock_};
117
118 auto it = results_.find(op_signature);
119 if (it == results_.end()) {
120 return;
121 }
122
123 auto it2 = it->second.find(params_signature);
124 if (it2 == it->second.end()) {
125 return;
126 }
127
128 TUNABLE_LOG2(op_signature, "(", params_signature, ")");
129 it->second.erase(it2);
130 }
131
DisjointMergeImpl(const std::string & op_signature,const KernelMap & kernel_map,std::unordered_map<std::string,KernelMap> & results)132 inline void TuningResultsManager::DisjointMergeImpl(
133 const std::string& op_signature,
134 const KernelMap& kernel_map,
135 /*out*/ std::unordered_map<std::string, KernelMap>& results) {
136 auto it = results.find(op_signature);
137 if (it == results.end()) {
138 for (const auto& [param_sig, kernel_id] : kernel_map) {
139 TUNABLE_LOG2(op_signature, "(", param_sig, ") -> ", kernel_id);
140 }
141 results[op_signature] = kernel_map;
142 return;
143 }
144
145 for (const auto& [params_signature, best] : kernel_map) {
146 AddImpl(op_signature, params_signature, best, it->second);
147 }
148 }
149
Load(const std::unordered_map<std::string,KernelMap> & results_to_load)150 void TuningResultsManager::Load(const std::unordered_map<std::string, KernelMap>& results_to_load) {
151 TUNABLE_LOG1("Loading results");
152 std::scoped_lock l{lock_};
153 for (const auto& [op_signature, kernel_map] : results_to_load) {
154 DisjointMergeImpl(op_signature, kernel_map, results_);
155 }
156 }
157
Dump()158 ResultsMap TuningResultsManager::Dump() {
159 std::scoped_lock l{lock_};
160 return results_;
161 }
162
DisjointMerge(const std::string & op_signature,const KernelMap & kernel_map)163 void TuningResultsManager::DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map) {
164 std::scoped_lock l{lock_};
165 DisjointMergeImpl(op_signature, kernel_map, results_);
166 }
167
GetSize()168 size_t TuningResultsManager::GetSize() {
169 size_t size = 0;
170 std::scoped_lock l{lock_};
171 for (const auto& [op_signature, kernel_map] : results_) {
172 size += kernel_map.size();
173 }
174 return size;
175 }
176
177 // TuningResultsValidator
178
TuningResultsValidator()179 TuningResultsValidator::TuningResultsValidator() {
180 RegisterValidator(
181 "PT_VERSION",
182 [this]() { return GetPyTorchVersion(); },
183 [this](auto&& k) { return ValidatePyTorchVersion(std::forward<decltype(k)>(k)); });
184 #ifdef USE_ROCM
185 // rocm
186 {
187 std::string rocm_version = ROCM_BUILD_INFO;
188 RegisterValidator(
189 "ROCM_VERSION",
190 [rocm_version]() { return rocm_version; },
191 [rocm_version](auto&& k) { return rocm_version == k ? OK : FAIL; });
192 }
193 // gfx arch
194 {
195 std::string gcn_arch_name = at::cuda::getCurrentDeviceProperties()->gcnArchName;
196 RegisterValidator(
197 "GCN_ARCH_NAME",
198 [gcn_arch_name]() { return gcn_arch_name; },
199 [gcn_arch_name](auto&& k) { return gcn_arch_name == k ? OK : FAIL; });
200 }
201 // rocblas
202 {
203 #define STRINGIFY(s) #s
204 #define XSTRINGIFY(s) STRINGIFY(s)
205 std::string rocblas_version = c10::str(
206 XSTRINGIFY(ROCBLAS_VERSION_MAJOR), ".",
207 XSTRINGIFY(ROCBLAS_VERSION_MINOR), ".",
208 XSTRINGIFY(ROCBLAS_VERSION_PATCH), "-",
209 XSTRINGIFY(ROCBLAS_VERSION_TWEAK));
210 #undef XSTRINGIFY
211 #undef STRINGIFY
212 RegisterValidator(
213 "ROCBLAS_VERSION",
214 [rocblas_version]() { return rocblas_version; },
215 [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
216 }
217 // hipblaslt
218 {
219 int version;
220 std::string revision(128, '\0');
221 auto handle = at::cuda::getCurrentCUDABlasLtHandle();
222 hipblasLtGetVersion(handle, &version);
223 hipblasLtGetGitRevision(handle, revision.data());
224 std::string hipblaslt_version =
225 c10::str(version, "-", revision.c_str());
226 RegisterValidator(
227 "HIPBLASLT_VERSION",
228 [hipblaslt_version]() { return hipblaslt_version; },
229 [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
230 }
231 #endif
232 }
233
GetAllValidators() const234 std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
235 std::unordered_map<std::string, std::string> ret;
236 for (const auto& [key, get_validate_func_pair] : validators_) {
237 const GetFunc& getter = get_validate_func_pair.first;
238 ret[key] = getter();
239 }
240 return ret;
241 }
242
CheckMandatoryKeys(const TuningResultsValidator::GetValidateFuncs & gv_funcs,const std::unordered_map<std::string,std::string> & to_check)243 static bool CheckMandatoryKeys(
244 const TuningResultsValidator::GetValidateFuncs& gv_funcs,
245 const std::unordered_map<std::string, std::string>& to_check) {
246 bool passed = true;
247 for (const auto& k : TuningResultsValidator::mandatory_keys) {
248 if (gv_funcs.find(k) == gv_funcs.end()) {
249 passed = false;
250 TUNABLE_LOG1("key=\"", k, "\" is not registered for Get and Validate. ");
251 }
252
253 if (to_check.find(k) == to_check.end()) {
254 passed = false;
255 TUNABLE_LOG1("key=\"", k, "\" is not provided for validation. ");
256 }
257 }
258 return passed;
259 }
260
CheckKeysMatching(const TuningResultsValidator::GetValidateFuncs & gv_funcs,const std::unordered_map<std::string,std::string> & to_check)261 static bool CheckKeysMatching(
262 const TuningResultsValidator::GetValidateFuncs& gv_funcs,
263 const std::unordered_map<std::string, std::string>& to_check) {
264 auto get_keys = [](const auto& it) -> std::string { return it.first; };
265 std::vector<std::string> required_keys;
266 std::vector<std::string> provided_keys;
267 std::transform(gv_funcs.cbegin(), gv_funcs.cend(), std::back_inserter(required_keys), get_keys);
268 std::transform(to_check.cbegin(), to_check.cend(), std::back_inserter(provided_keys), get_keys);
269 std::sort(required_keys.begin(), required_keys.end());
270 std::sort(provided_keys.begin(), provided_keys.end());
271
272 std::unordered_set<std::string> intersection;
273 std::set_intersection(required_keys.cbegin(), required_keys.cend(),
274 provided_keys.cbegin(), provided_keys.cend(),
275 std::inserter(intersection, intersection.end()));
276 bool matched = true;
277 if (intersection.size() != required_keys.size()) {
278 matched = false;
279 for (const auto& k : required_keys) {
280 if (intersection.find(k) == intersection.end()) {
281 TORCH_WARN("Unmatched validator: \"", k, "\" is required, but the tuning results does not provide it. ");
282 }
283 }
284 }
285 if (intersection.size() != provided_keys.size()) {
286 matched = false;
287 for (const auto& k : provided_keys) {
288 if (intersection.find(k) == intersection.end()) {
289 TORCH_WARN("Unmatched validator: \"", k, "\" is provided, but pytorch is unable to consume it. ");
290 }
291 }
292 }
293 return matched;
294 }
295
ValidateAll(const std::unordered_map<std::string,std::string> & to_validate) const296 TuningStatus TuningResultsValidator::ValidateAll(
297 const std::unordered_map<std::string, std::string>& to_validate) const {
298 if (!CheckMandatoryKeys(validators_, to_validate)) {
299 return FAIL;
300 }
301 if (!CheckKeysMatching(validators_, to_validate)) {
302 return FAIL;
303 }
304
305 for (const auto& [key, value] : to_validate) {
306 const auto& it = validators_.find(key);
307 if (it == validators_.cend()) {
308 TORCH_WARN("Failed to lookup validator using key ", key);
309 for (const auto& [key2, val2] : validators_) {
310 TORCH_WARN("available key ", key2);
311 }
312 return FAIL;
313 }
314 const ValidateFunc& validator = it->second.second;
315 if (validator(value) != OK) {
316 TORCH_WARN("Failed validator: ", key);
317 return FAIL;
318 }
319 }
320
321 return OK;
322 }
323
RegisterValidator(const std::string & key,const GetFunc & gf,const ValidateFunc & vf)324 void TuningResultsValidator::RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf) {
325 if (validators_.find(key) != validators_.end()) {
326 TORCH_WARN("Attempting to re-register validator with key ", key);
327 }
328 else {
329 validators_[key] = std::make_pair(gf, vf);
330 }
331 }
332
GetPyTorchVersion() const333 std::string TuningResultsValidator::GetPyTorchVersion() const {
334 return TORCH_VERSION;
335 }
336
ValidatePyTorchVersion(const std::string & value) const337 TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& value) const {
338 TUNABLE_LOG1("PT_VERSION validation: expect ", value, " to match ", GetPyTorchVersion());
339 if (value == GetPyTorchVersion()) {
340 return OK;
341 }
342 return FAIL;
343 }
344
345 // TuningContext
346
// Defaults for all tunable-op settings. Each value can be overridden at
// runtime via the corresponding setter, and most can also be overridden by
// a PYTORCH_TUNABLEOP_* environment variable (checked in the getters).
TuningContext::TuningContext() :
    enable_{false},                      // TunableOp off by default
    tuning_enable_{true},                // but tuning runs when TunableOp is enabled
    manager_initialized_{false},         // set by GetTuningResultsManager() on first use
    write_file_on_exit_{true},           // flush results in ~TuningContext()
    numerics_check_enable_{false},
    max_tuning_duration_ms_{30},
    max_tuning_iterations_{100},
    max_warmup_duration_ms_{0},
    max_warmup_iterations_{0},
    icache_flush_{true},
    rotating_buffer_size_{-1},           // negative: GetRotatingBufferSize() queries L2 cache size
    filename_{},
    results_count_from_input_file_{0}    // baseline for "new results" check at exit
{
}
363
// Flushes tuning results to the configured file at teardown, but only when
// tuning actually produced results beyond what was loaded from the input file.
TuningContext::~TuningContext() {
  if (!manager_initialized_) {
    // TuningResultsManager was never initialized, no tuning requested or performed.
    // This can happen in a DDP job where a python process spawns other workers
    // but doesn't do any computation itself.
    return;
  }
  auto filename = GetFilename();
  // Write only if TunableOp and tuning are on, a filename is configured, and
  // write-on-exit was not disabled via WriteFileOnExit(false).
  if (IsTunableOpEnabled() && IsTuningEnabled() && !filename.empty() && write_file_on_exit_) {
    // Only rewrite when there are more entries than were read in at startup.
    if (results_count_from_input_file_ < GetTuningResultsManager().GetSize()) {
      if (results_count_from_input_file_ > 0) {
        TUNABLE_LOG1("additional tuning results available, rewriting file ", filename);
      }
      else {
        TUNABLE_LOG1("writing file ", filename);
      }
      if (!WriteFile(filename)) {
        TUNABLE_LOG1("failed to write file ", filename);
      }
    }
  }
}
386
EnableTunableOp(bool value)387 void TuningContext::EnableTunableOp(bool value) {
388 enable_ = value;
389 if (value) {
390 TUNABLE_LOG1("Enable TunableOp");
391 }
392 else {
393 TUNABLE_LOG1("Disable TunableOp");
394 }
395 }
396
IsTunableOpEnabled() const397 bool TuningContext::IsTunableOpEnabled() const {
398 static const char *env = std::getenv("PYTORCH_TUNABLEOP_ENABLED");
399 if (env != nullptr && strcmp(env, "1") == 0) {
400 return true;
401 }
402 return enable_;
403 }
404
EnableTuning(bool value)405 void TuningContext::EnableTuning(bool value) {
406 tuning_enable_ = value;
407 if (value) {
408 TUNABLE_LOG1("Enable Tuning for TunableOp");
409 }
410 else {
411 TUNABLE_LOG1("Disable Tuning for TunableOp");
412 }
413 }
414
IsTuningEnabled() const415 bool TuningContext::IsTuningEnabled() const {
416 static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING");
417 if (env != nullptr && strcmp(env, "0") == 0) {
418 return false;
419 }
420 return tuning_enable_;
421 }
422
// Controls whether ~TuningContext() writes accumulated tuning results to the
// results file; disable to manage persistence manually via WriteFile().
void TuningContext::WriteFileOnExit(bool value) {
  write_file_on_exit_ = value;
}
426
// Programmatic toggle for numerics checking; the
// PYTORCH_TUNABLEOP_NUMERICAL_CHECK env var (see IsNumericsCheckEnabled)
// can still force it on.
void TuningContext::EnableNumericsCheck(bool value) {
  numerics_check_enable_ = value;
}
430
IsNumericsCheckEnabled() const431 bool TuningContext::IsNumericsCheckEnabled() const {
432 const char *env = getenv("PYTORCH_TUNABLEOP_NUMERICAL_CHECK");
433 if (env != nullptr && strcmp(env, "1") == 0) {
434 return true;
435 }
436 return numerics_check_enable_;
437 }
438
SetMaxTuningDurationMs(int max_duration_ms)439 void TuningContext::SetMaxTuningDurationMs(int max_duration_ms) {
440 max_tuning_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms;
441 }
442
GetMaxTuningDurationMs() const443 int TuningContext::GetMaxTuningDurationMs() const {
444 static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_DURATION_MS");
445 if (env != nullptr) {
446 int val = atoi(env);
447 return val < 0 ? 0 : val;
448 }
449 return max_tuning_duration_ms_;
450 }
451
SetMaxTuningIterations(int max_iter)452 void TuningContext::SetMaxTuningIterations(int max_iter) {
453 max_tuning_iterations_ = max_iter < 0 ? 0 : max_iter;
454 }
455
GetMaxTuningIterations() const456 int TuningContext::GetMaxTuningIterations() const {
457 static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_TUNING_ITERATIONS");
458 if (env != nullptr) {
459 int val = atoi(env);
460 return val < 0 ? 0 : val;
461 }
462 return max_tuning_iterations_;
463 }
464
SetMaxWarmupDurationMs(int max_duration_ms)465 void TuningContext::SetMaxWarmupDurationMs(int max_duration_ms) {
466 max_warmup_duration_ms_ = max_duration_ms < 0 ? 0 : max_duration_ms;
467 }
468
GetMaxWarmupDurationMs() const469 int TuningContext::GetMaxWarmupDurationMs() const {
470 static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS");
471 if (env != nullptr) {
472 int val = atoi(env);
473 return val < 0 ? 0 : val;
474 }
475 return max_warmup_duration_ms_;
476 }
477
SetMaxWarmupIterations(int max_iter)478 void TuningContext::SetMaxWarmupIterations(int max_iter) {
479 max_warmup_iterations_ = max_iter < 0 ? 0 : max_iter;
480 }
481
GetMaxWarmupIterations() const482 int TuningContext::GetMaxWarmupIterations() const {
483 static const char *env = std::getenv("PYTORCH_TUNABLEOP_MAX_WARMUP_ITERATIONS");
484 if (env != nullptr) {
485 int val = atoi(env);
486 return val < 0 ? 0 : val;
487 }
488 return max_warmup_iterations_;
489 }
490
// Programmatic toggle for instruction-cache flushing during tuning;
// PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED=0 (see IsICacheFlushEnabled)
// can still force it off.
void TuningContext::EnableICacheFlush(bool value) {
  icache_flush_ = value;
}
494
IsICacheFlushEnabled() const495 bool TuningContext::IsICacheFlushEnabled() const {
496 static const char *env = std::getenv("PYTORCH_TUNABLEOP_ICACHE_FLUSH_ENABLED");
497 if (env != nullptr && strcmp(env, "0") == 0) {
498 return false;
499 }
500 return icache_flush_;
501 }
502
SetRotatingBufferSize(int size)503 void TuningContext::SetRotatingBufferSize(int size) {
504 rotating_buffer_size_ = size < 0 ? 0 : size;
505 }
506
GetRotatingBufferSize() const507 int TuningContext::GetRotatingBufferSize() const {
508 static const char *env = std::getenv("PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE");
509 if (env != nullptr) {
510 constexpr int MB = 1024 * 1024;
511 int val = atoi(env);
512 return val < 0 ? 0 : val * MB; // env var is specified as MB, returned as bytes
513 }
514 else {
515 if (rotating_buffer_size_ < 0) {
516 // negative buffer size (default) means query for L2 cache size
517 int l2_cache_size = at::cuda::getCurrentDeviceProperties()->l2CacheSize;
518 return l2_cache_size;
519 }
520 else {
521 return rotating_buffer_size_;
522 }
523 }
524 }
525
// Lazily initializes the results manager exactly once (thread-safe via
// c10::call_once): resolves the filename (explicit SetFilename > env var >
// default), loads existing results from it, and probes writability so a
// bad path surfaces now rather than at process exit.
TuningResultsManager& TuningContext::GetTuningResultsManager() {
  c10::call_once(manager_init_once_, [this]() {
    // flag consulted by ~TuningContext() to decide whether to write results
    manager_initialized_ = true;
    if (GetFilename().empty()) {
      // if SetFilename() was not already called, call it now with the default or env var
      const char *env = std::getenv("PYTORCH_TUNABLEOP_FILENAME");
      std::string filename = (env == nullptr) ? "tunableop_results.csv" : env;
      // `true`: insert the device ordinal into the filename
      SetFilename(filename, true);
    }
    auto filename = GetFilename();
    if (!filename.empty()) {
      ReadFile(filename);
      // attempt immediately to open file for writing to catch errors early
      std::ofstream file(filename, std::ios::out | std::ios::app);
      if (!file.good()) {
        TORCH_WARN("failed to open file '", filename, "' for writing; your tuning results will not be saved");
      }
    }
  });
  return manager_;
}
547
// Accessor for the validator registry used when loading/saving results.
TuningResultsValidator& TuningContext::GetTuningResultsValidator() {
  return validator_;
}
551
GetTuningResults()552 TuningResults TuningContext::GetTuningResults() {
553 TuningResults tr;
554 tr.validators = GetTuningResultsValidator().GetAllValidators();
555 tr.results = GetTuningResultsManager().Dump();
556 return tr;
557 }
558
LoadTuningResults(const TuningResults & tr)559 TuningStatus TuningContext::LoadTuningResults(const TuningResults& tr) {
560 TORCH_CHECK(GetTuningResultsValidator().ValidateAll(tr.validators));
561 GetTuningResultsManager().Load(tr.results);
562 return OK;
563 }
564
SetFilename(const std::string & filename,bool insert_device_ordinal)565 void TuningContext::SetFilename(const std::string& filename, bool insert_device_ordinal) {
566 filename_ = filename;
567
568 if (filename_.empty()) {
569 return;
570 }
571
572 if (insert_device_ordinal) {
573 // differentiate filename based on device ordinal to avoid
574 // use case of one process per device writing to same file
575 std::string device = c10::str(int(c10::cuda::current_device()));
576
577 // does filename contain %d to insert device ordinal in specific location?
578 const std::string TOKEN("%d");
579 std::size_t found = filename_.find(TOKEN);
580 if (found != std::string::npos) {
581 filename_.replace(found, TOKEN.length(), device);
582 }
583 else {
584 // no %d present, so append device ordinal before final '.'
585 found = filename_.rfind('.');
586 if (found != std::string::npos) {
587 filename_.insert(found, device);
588 }
589 else {
590 // all else fails, just append
591 filename_.append(device);
592 }
593 }
594 }
595 }
596
// Returns the currently configured results filename; empty if never set.
std::string TuningContext::GetFilename() const {
  return filename_;
}
600
// Parses a tuning-results CSV into the manager. Two record shapes:
//   Validator,<key>,<value>                   -- validator entries
//   <op_sig>,<params_sig>,<kernel>[,<time>]   -- tuning results
// Results are loaded only if all parsed validators pass ValidateAll.
// Returns false when the file can't be opened or validation fails.
bool TuningContext::ReadFile(const std::string& filename_) {
  // empty argument means "use the configured filename"
  std::string filename = filename_.empty() ? GetFilename() : filename_;
  TUNABLE_LOG1("reading tuning results from ", filename);
  ResultsMap results;
  std::unordered_map<std::string, std::string> validators;
  std::string line;
  std::ifstream file(filename);
  if (!file) {
    TUNABLE_LOG1("could not open ", filename, " for reading tuning results");
    return false;
  }
  while (std::getline(file, line)) {
    if (line.empty()) {
      continue;
    }
    // split on commas; a non-empty line always yields at least one part,
    // so parts[0] below is safe
    std::string part;
    std::vector<std::string> parts;
    std::stringstream line_as_stream(line);
    while (std::getline(line_as_stream, part, ',')) {
      parts.push_back(part);
    }
    if (parts[0] == "Validator" && parts.size() >= 3) {
      validators[parts[1]] = parts[2];
      TUNABLE_LOG1("Validator ", parts[1], "=", parts[2]);
    }
    else if (parts.size() >= 4) {
      results[parts[0]].emplace(parts[1], ResultEntry(parts[2], atof(parts[3].c_str())));
    }
    else if (parts.size() >= 3) {
      // the timestamp from the file is optional
      results[parts[0]].emplace(parts[1], ResultEntry(parts[2], 0));
    }
    else {
      TUNABLE_LOG1("could not parse line: ", line);
    }
  }
  if (GetTuningResultsValidator().ValidateAll(validators) != FAIL) {
    manager_.Load(results);
    // remember the baseline so ~TuningContext() only rewrites when tuning
    // produced additional entries
    results_count_from_input_file_ = manager_.GetSize();
  }
  else {
    TUNABLE_LOG1("results validator check failed");
    return false;
  }
  return true;
}
647
WriteFile(const std::string & filename_)648 bool TuningContext::WriteFile(const std::string& filename_) {
649 std::string filename = filename_.empty() ? GetFilename() : filename_;
650 std::ofstream file(filename, std::ios::out | std::ios::trunc);
651 if (!file.good()) {
652 TUNABLE_LOG1("error opening tuning results file for writing ", filename);
653 return false;
654 }
655 auto validators = GetTuningResultsValidator().GetAllValidators();
656 for (const auto& [key, val] : validators) {
657 file << "Validator," << key << "," << val << std::endl;
658 }
659 auto results = GetTuningResultsManager().Dump();
660 for (const auto& [op_sig, kernelmap] : results) {
661 for (const auto& [param_sig, result] : kernelmap) {
662 file << op_sig << "," << param_sig << "," << result << std::endl;
663 }
664 }
665 file.close();
666 return true;
667 }
668
669 } // namespace at::cuda::tunable
670