// Original TunableOp is from onnxruntime.
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
// https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
//
// Adapting TunableOp into PyTorch
// Copyright (c) Advanced Micro Devices, Inc.
//
#pragma once

#include <c10/util/CallOnce.h>
#include <c10/util/StringUtil.h>

#include <array>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

namespace at::cuda::tunable {

namespace detail {

struct MaybeDelete {
  bool owns_pointer;
  void operator()(std::ostream* os) const { if (owns_pointer) delete os; }
};

using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;

static OstreamPtr get_stream(std::string filename) {
  if (filename.compare("out") == 0) {
    return OstreamPtr { &std::cout, MaybeDelete {false} };
  }
  else if (filename.compare("err") == 0) {
    return OstreamPtr { &std::cerr, MaybeDelete {false} };
  }
  else {
    return OstreamPtr { new std::ofstream {filename.c_str()}, MaybeDelete {true} };
  }
}

} // namespace detail

static void TunableLog(int level, const std::string& msg) {
  static const char *env_file = std::getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
  static const char *env_verbose = std::getenv("PYTORCH_TUNABLEOP_VERBOSE");
  static int level_user = env_verbose ? std::atoi(env_verbose) : 0;
  static auto streamptr = detail::get_stream(env_file ? env_file : "err");
  if (level_user >= level) {
    (*streamptr) << msg << std::endl;
  }
}
#define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, c10::str(__VA_ARGS__))
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)

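// Usage sketch for the logging macros above (values are examples only): with
// PYTORCH_TUNABLEOP_VERBOSE=2 in the environment, TUNABLE_LOG1 and TUNABLE_LOG2
// messages are emitted while TUNABLE_LOG3 is suppressed. Output goes to stderr
// unless PYTORCH_TUNABLEOP_VERBOSE_FILENAME names a file, with the special
// values "out"/"err" mapping to stdout/stderr.
//
//   TUNABLE_LOG1("reading tuning results from ", filename);  // printed at level >= 1
//   TUNABLE_LOG3("validator ", key, "=", value);              // needs level >= 3
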
enum TORCH_CUDA_CPP_API TuningStatus {
  OK = 0,
  FAIL = 1,
  UNSUPPORTED = 2,
};

// Mapping from params signature to kernel id
class TORCH_CUDA_CPP_API ResultEntry {
  public:
    explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
    bool operator==(const ResultEntry& other) { return key_ == other.key_; }
    bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
    operator std::string () { return key_; }
    std::string GetKey() const { return key_; }
    double GetTime() const { return time_; }
    friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
    static ResultEntry Null() { return ResultEntry("Null", 0.0); }
    static ResultEntry Default() { return ResultEntry("Default", 0.0); }

  private:
    std::string key_;
    double time_;
};

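// Illustrative sketch (not part of the API surface): a ResultEntry pairs a
// kernel identifier with its measured time, and ResultEntry::Null() serves as
// the "no result" sentinel. The kernel id below is made up for the example.
//
//   ResultEntry best{"GemmTunableOp_algo_42", 0.153};
//   if (best != ResultEntry::Null()) {
//     std::string chosen = best.GetKey();  // "GemmTunableOp_algo_42"
//   }
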
typedef std::unordered_map<std::string, ResultEntry> KernelMap;
typedef std::unordered_map<std::string, KernelMap> ResultsMap;

struct TORCH_CUDA_CPP_API TuningResults {
  // Validator key/value pairs used to check whether these results are
  // compatible with the current libraries
  std::unordered_map<std::string, std::string> validators;

  // Mapping from Callable signature to Callable's tuning result
  ResultsMap results;
};

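// Illustrative sketch of how the two maps nest (all names and values are
// invented for the example): the outer key is an op signature, the inner key a
// params signature, and the mapped value the tuned ResultEntry.
//
//   TuningResults tr;
//   tr.validators["PT_VERSION"] = "2.4.0";
//   tr.results["GemmTunableOp_float_NN"].emplace(
//       "nn_1024_1024_1024", ResultEntry{"Gemm_algo_7", 0.21});
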
class TORCH_CUDA_CPP_API TuningResultsManager {
  public:
    TuningResultsManager() = default;
    ~TuningResultsManager() = default;

    KernelMap Lookup(const std::string& op_signature);

    ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);

    inline void AddImpl(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best,
        KernelMap& kernel_map);

    void Add(const std::string& op_signature,
        const std::string& params_signature,
        ResultEntry best);

    void Delete(const std::string& op_signature, const std::string& params_signature);

    inline void DisjointMergeImpl(
        const std::string& op_signature,
        const KernelMap& kernel_map,
        /*out*/ ResultsMap& results);

    void Load(const ResultsMap& results_to_load);

    ResultsMap Dump();

    void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);

    size_t GetSize();

  private:
    std::mutex lock_;
    ResultsMap results_;
};

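// Illustrative sketch (exact semantics live in the .cpp): the manager is a
// mutex-guarded store keyed by op signature then params signature. The
// signatures below are made up for the example; a lookup-then-add path could
// look like:
//
//   auto& mgr = getTuningContext()->GetTuningResultsManager();
//   ResultEntry hit = mgr.Lookup("GemmTunableOp_float_NN", "nn_1024_1024_1024");
//   if (hit == ResultEntry::Null()) {
//     mgr.Add("GemmTunableOp_float_NN", "nn_1024_1024_1024",
//             ResultEntry{"Gemm_algo_7", 0.21});
//   }
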
class TORCH_CUDA_CPP_API TuningResultsValidator {
  public:
    using GetFunc = std::function<std::string()>;
    using ValidateFunc = std::function<TuningStatus(const std::string&)>;
    using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;

    TuningResultsValidator();
    ~TuningResultsValidator() = default;

    std::unordered_map<std::string, std::string> GetAllValidators() const;
    TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
    void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);

  protected:
    std::string GetPyTorchVersion() const;
    TuningStatus ValidatePyTorchVersion(const std::string& value) const;

  public:
    static constexpr const std::array mandatory_keys{"PT_VERSION"};

  private:
    GetValidateFuncs validators_;
};

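// Illustrative sketch of registering a custom validator (the key and version
// strings are invented for the example): the GetFunc produces the value stored
// alongside tuned results, and the ValidateFunc decides whether a previously
// saved value is still acceptable when results are loaded back.
//
//   auto& validator = getTuningContext()->GetTuningResultsValidator();
//   validator.RegisterValidator(
//       "MY_LIB_VERSION",
//       []() -> std::string { return "1.2.3"; },
//       [](const std::string& saved) { return saved == "1.2.3" ? OK : FAIL; });
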
class TORCH_CUDA_CPP_API TuningContext {
  public:
    TuningContext();
    ~TuningContext();
    TuningContext(TuningContext &) = delete;
    TuningContext(TuningContext &&) = delete;
    TuningContext &operator=(TuningContext &) = delete;
    TuningContext &operator=(TuningContext &&) = delete;

    void EnableTunableOp(bool value);
    bool IsTunableOpEnabled() const;

    void EnableTuning(bool value);
    bool IsTuningEnabled() const;

    void EnableNumericsCheck(bool value);
    bool IsNumericsCheckEnabled() const;

    void SetMaxTuningDurationMs(int max_duration_ms);
    int GetMaxTuningDurationMs() const;

    void SetMaxTuningIterations(int max_iter);
    int GetMaxTuningIterations() const;

    void SetMaxWarmupDurationMs(int max_duration_ms);
    int GetMaxWarmupDurationMs() const;

    void SetMaxWarmupIterations(int max_iter);
    int GetMaxWarmupIterations() const;

    void EnableICacheFlush(bool value);
    bool IsICacheFlushEnabled() const;

    void SetRotatingBufferSize(int size);
    int GetRotatingBufferSize() const;

    TuningResultsManager& GetTuningResultsManager();

    TuningResultsValidator& GetTuningResultsValidator();

    TuningResults GetTuningResults();

    TuningStatus LoadTuningResults(const TuningResults& tr);

    void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
    std::string GetFilename() const;

    void WriteFileOnExit(bool value);

    bool ReadFile(const std::string& filename={});
    bool WriteFile(const std::string& filename={});

  private:
    bool enable_;
    bool tuning_enable_;
    bool manager_initialized_;
    bool write_file_on_exit_;
    bool numerics_check_enable_;
    int max_tuning_duration_ms_;
    int max_tuning_iterations_;
    int max_warmup_duration_ms_;
    int max_warmup_iterations_;
    bool icache_flush_;
    int rotating_buffer_size_;
    mutable TuningResultsManager manager_;
    mutable c10::once_flag manager_init_once_;
    TuningResultsValidator validator_;
    std::string filename_;
    size_t results_count_from_input_file_;
};

TORCH_CUDA_CPP_API TuningContext* getTuningContext();

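// Illustrative sketch of configuring TunableOp from C++ (the filename and
// limits are example values; the same switches are typically driven through
// PYTORCH_TUNABLEOP_* environment variables):
//
//   TuningContext* ctx = getTuningContext();
//   ctx->EnableTunableOp(true);          // use tuned kernels when available
//   ctx->EnableTuning(true);             // allow new tuning, not just lookup
//   ctx->SetMaxTuningIterations(100);    // example cap per candidate kernel
//   ctx->SetFilename("tunableop_results.csv", /*insert_device_ordinal=*/true);
//   ctx->WriteFileOnExit(true);
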
class ITimer {
  public:
    ITimer() = default;
    virtual ~ITimer() = default;

    virtual void Start() = 0;
    virtual void End() = 0;

    /// Computes the elapsed time in milliseconds between Start() and End()
    virtual float Duration() = 0;
};

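// Illustrative sketch of an ITimer implementation (a CPU wall-clock variant
// requiring <chrono>, shown only to make the interface concrete; the
// GPU-event-based timers actually used for tuning are defined elsewhere):
//
//   class WallClockTimer : public ITimer {
//     public:
//       void Start() override { start_ = std::chrono::steady_clock::now(); }
//       void End() override { end_ = std::chrono::steady_clock::now(); }
//       float Duration() override {
//         return std::chrono::duration<float, std::milli>(end_ - start_).count();
//       }
//     private:
//       std::chrono::steady_clock::time_point start_;
//       std::chrono::steady_clock::time_point end_;
//   };
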
} // namespace at::cuda::tunable