xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/tunable/Tunable.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Original TunableOp is from onnxruntime.
2 // https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/framework/tunable.h
3 // https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/core/providers/rocm/tunable
4 // Copyright (c) Microsoft Corporation.
5 // Licensed under the MIT license.
6 //
7 // Adapting TunableOp into PyTorch
8 // Copyright (c) Advanced Micro Devices, Inc.
9 //
10 #pragma once
11 
#include <c10/util/CallOnce.h>

#include <array>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
24 
25 namespace at::cuda::tunable {
26 
namespace detail {

// Deleter for OstreamPtr: deletes the stream only when get_stream() allocated
// it. std::cout / std::cerr are borrowed and must never be deleted.
struct MaybeDelete {
  bool owns_pointer = false;
  void operator()(std::ostream* os) const {
    if (owns_pointer) {
      delete os;
    }
  }
};

// Output stream handle that may or may not own the underlying stream.
using OstreamPtr = std::unique_ptr<std::ostream, MaybeDelete>;

// Maps a destination name to a stream:
//   "out"     -> std::cout (borrowed)
//   "err"     -> std::cerr (borrowed)
//   otherwise -> a new std::ofstream opened on `filename` (owned). The default
//                ofstream open mode truncates an existing file. If opening
//                fails, the stream is left in a failed state and writes are
//                dropped.
// `inline` rather than `static`: a header-level `static` function gives every
// translation unit its own private copy.
inline OstreamPtr get_stream(const std::string& filename) {
  if (filename == "out") {
    return OstreamPtr{&std::cout, MaybeDelete{false}};
  }
  if (filename == "err") {
    return OstreamPtr{&std::cerr, MaybeDelete{false}};
  }
  return OstreamPtr{new std::ofstream{filename}, MaybeDelete{true}};
}

// Verbosity threshold, read once from PYTORCH_TUNABLEOP_VERBOSE (0 if unset).
inline int verbosity_level() {
  static const int level = [] {
    const char* env_verbose = std::getenv("PYTORCH_TUNABLEOP_VERBOSE");
    return env_verbose ? std::atoi(env_verbose) : 0;
  }();
  return level;
}

} // namespace detail

// Returns true when TunableLog would emit a message logged at `level`.
inline bool TunableLogEnabled(int level) {
  return detail::verbosity_level() >= level;
}

// Writes `msg` followed by a newline to the TunableOp log, but only if `level`
// does not exceed the verbosity from PYTORCH_TUNABLEOP_VERBOSE.
// The destination comes from PYTORCH_TUNABLEOP_VERBOSE_FILENAME ("out", "err",
// or a file path; default "err"). It is opened on the first message actually
// written, so no file is created while logging is disabled.
//
// This must be `inline`, not `static`. With `static`, each translation unit had
// its own function-local stream, so a log *file* was opened (and truncated)
// once per translation unit, and their outputs overwrote one another.
inline void TunableLog(int level, const std::string& msg) {
  if (!TunableLogEnabled(level)) {
    return;
  }
  static const detail::OstreamPtr stream = [] {
    const char* env_file = std::getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
    return detail::get_stream(env_file ? env_file : "err");
  }();
  // std::endl flushes after every message, so output appears immediately.
  (*stream) << msg << std::endl;
}

// The level is checked before c10::str() runs, so disabled log statements no
// longer pay for formatting. Their arguments are not evaluated in that case.
#define TUNABLE_LOGV(LEVEL, ...)                                \
  (TunableLogEnabled(LEVEL)                                     \
       ? TunableLog((LEVEL), c10::str(__VA_ARGS__))             \
       : void())
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)
63 
64 enum TORCH_CUDA_CPP_API TuningStatus {
65   OK = 0,
66   FAIL = 1,
67   UNSUPPORTED = 2,
68 };
69 
70 // Mapping from params signature to kernel id
71 class TORCH_CUDA_CPP_API ResultEntry {
72   public:
ResultEntry(const std::string & key,double time)73     explicit ResultEntry(const std::string& key, double time) : key_(key), time_(time) {}
74     bool operator==(const ResultEntry& other) { return key_ == other.key_; }
75     bool operator!=(const ResultEntry& other) { return key_ != other.key_; }
string()76     operator std::string () { return key_; }
GetKey()77     std::string GetKey() const { return key_; }
GetTime()78     double GetTime() const { return time_; }
79     friend std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry);
Null()80     static ResultEntry Null() { return ResultEntry("Null", 0.0); }
Default()81     static ResultEntry Default() { return ResultEntry("Default", 0.0); }
82 
83   private:
84     std::string key_;
85     double time_;
86 };
87 
88 typedef std::unordered_map<std::string, ResultEntry> KernelMap;
89 typedef std::unordered_map<std::string, KernelMap> ResultsMap;
90 
91 struct TORCH_CUDA_CPP_API TuningResults {
92   // Validates if these results are compatible with the libraries
93   std::unordered_map<std::string, std::string> validators;
94 
95   // Mapping from Callable signature to Callable's tuning result
96   ResultsMap results;
97 };
98 
99 class TORCH_CUDA_CPP_API TuningResultsManager {
100   public:
101     TuningResultsManager() = default;
102     ~TuningResultsManager() = default;
103 
104     KernelMap Lookup(const std::string& op_signature);
105 
106     ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature);
107 
108     inline void AddImpl(const std::string& op_signature,
109         const std::string& params_signature,
110         ResultEntry best,
111         KernelMap& kernel_map);
112 
113     void Add(const std::string& op_signature,
114         const std::string& params_signature,
115         ResultEntry best);
116 
117     void Delete(const std::string& op_signature, const std::string& params_signature);
118 
119     inline void DisjointMergeImpl(
120         const std::string& op_signature,
121         const KernelMap& kernel_map,
122         /*out*/ ResultsMap& results);
123 
124     void Load(const ResultsMap& results_to_load);
125 
126     ResultsMap Dump();
127 
128     void DisjointMerge(const std::string& op_signature, const KernelMap& kernel_map);
129 
130     size_t GetSize();
131 
132   private:
133     std::mutex lock_;
134     ResultsMap results_;
135 };
136 
137 class TORCH_CUDA_CPP_API TuningResultsValidator {
138   public:
139     using GetFunc = std::function<std::string()>;
140     using ValidateFunc = std::function<TuningStatus(const std::string&)>;
141     using GetValidateFuncs = std::unordered_map<std::string, std::pair<GetFunc, ValidateFunc>>;
142 
143     TuningResultsValidator();
144     ~TuningResultsValidator() = default;
145 
146     std::unordered_map<std::string, std::string> GetAllValidators() const;
147     TuningStatus ValidateAll(const std::unordered_map<std::string, std::string>& to_validate) const;
148     void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf);
149 
150   protected:
151     std::string GetPyTorchVersion() const;
152     TuningStatus ValidatePyTorchVersion(const std::string& value) const;
153 
154   public:
155     static constexpr const std::array mandatory_keys{"PT_VERSION"};
156 
157   private:
158     GetValidateFuncs validators_;
159 };
160 
161 class TORCH_CUDA_CPP_API TuningContext {
162   public:
163     TuningContext();
164     ~TuningContext();
165     TuningContext(TuningContext &) = delete;
166     TuningContext(TuningContext &&) = delete;
167     TuningContext &operator=(TuningContext &) = delete;
168     TuningContext &operator=(TuningContext &&) = delete;
169 
170     void EnableTunableOp(bool value);
171     bool IsTunableOpEnabled() const;
172 
173     void EnableTuning(bool value);
174     bool IsTuningEnabled() const;
175 
176     void EnableNumericsCheck(bool value);
177     bool IsNumericsCheckEnabled() const;
178 
179     void SetMaxTuningDurationMs(int max_duration_ms);
180     int GetMaxTuningDurationMs() const;
181 
182     void SetMaxTuningIterations(int max_iter);
183     int GetMaxTuningIterations() const;
184 
185     void SetMaxWarmupDurationMs(int max_duration_ms);
186     int GetMaxWarmupDurationMs() const;
187 
188     void SetMaxWarmupIterations(int max_iter);
189     int GetMaxWarmupIterations() const;
190 
191     void EnableICacheFlush(bool value);
192     bool IsICacheFlushEnabled() const;
193 
194     void SetRotatingBufferSize(int size);
195     int GetRotatingBufferSize() const;
196 
197     TuningResultsManager& GetTuningResultsManager();
198 
199     TuningResultsValidator& GetTuningResultsValidator();
200 
201     TuningResults GetTuningResults();
202 
203     TuningStatus LoadTuningResults(const TuningResults& tr);
204 
205     void SetFilename(const std::string& filename, bool insert_device_ordinal=false);
206     std::string GetFilename() const;
207 
208     void WriteFileOnExit(bool value);
209 
210     bool ReadFile(const std::string& filename={});
211     bool WriteFile(const std::string& filename={});
212 
213   private:
214     bool enable_;
215     bool tuning_enable_;
216     bool manager_initialized_;
217     bool write_file_on_exit_;
218     bool numerics_check_enable_;
219     int max_tuning_duration_ms_;
220     int max_tuning_iterations_;
221     int max_warmup_duration_ms_;
222     int max_warmup_iterations_;
223     bool icache_flush_;
224     int rotating_buffer_size_;
225     mutable TuningResultsManager manager_;
226     mutable c10::once_flag manager_init_once_;
227     TuningResultsValidator validator_;
228     std::string filename_;
229     size_t results_count_from_input_file_;
230 };
231 
232 TORCH_CUDA_CPP_API TuningContext* getTuningContext();
233 
// Abstract timer interface. Start() and End() delimit an interval; Duration()
// reports it. Polymorphic deletion through ITimer* is safe (virtual destructor).
class ITimer {
  public:
    ITimer() = default;
    virtual ~ITimer() = default;

    virtual void Start() = 0;  // beginning of the measured interval
    virtual void End() = 0;    // end of the measured interval

    /// Computes the elapsed time in milliseconds between Start() and End()
    virtual float Duration() = 0;
};
245 
246 } // namespace at::cuda::tunable
247