xref: /aosp_15_r20/external/pytorch/binaries/compare_models_torch.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /**
2  * Copyright (c) 2016-present, Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <cmath>
#include <iomanip>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <ATen/ATen.h>
#include <caffe2/core/timer.h>
#include <caffe2/utils/string_utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/script.h>

#include <c10/mobile/CPUCachingAllocator.h>
29 
30 C10_DEFINE_string(
31     refmodel,
32     "",
33     "The reference torch script model to compare against.");
34 C10_DEFINE_string(
35     model,
36     "",
37     "The torch script model to compare to the reference model.");
38 C10_DEFINE_string(
39     input_dims,
40     "",
41     "Alternate to input_files, if all inputs are simple "
42     "float TensorCPUs, specify the dimension using comma "
43     "separated numbers. If multiple input needed, use "
44     "semicolon to separate the dimension of different "
45     "tensors.");
46 C10_DEFINE_string(input_type, "", "Input type (uint8_t/float)");
47 C10_DEFINE_string(
48     input_memory_format,
49     "contiguous_format",
50     "Input memory format (contiguous_format/channels_last)");
51 C10_DEFINE_int(input_max, 1, "The maximum value inputs should have");
52 C10_DEFINE_int(input_min, -1, "The minimum value inputs should have");
53 C10_DEFINE_bool(
54     no_inputs,
55     false,
56     "Whether the model has any input. Will ignore other input arguments if true");
57 C10_DEFINE_bool(
58     use_caching_allocator,
59     false,
60     "Whether to cache allocations between inference iterations");
61 C10_DEFINE_bool(
62     print_output,
63     false,
64     "Whether to print output with all one input tensor.");
65 C10_DEFINE_int(iter, 10, "The number of iterations to run.");
66 C10_DEFINE_int(report_freq, 1000, "An update will be reported every n iterations");
67 C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
68 C10_DEFINE_string(
69     backend,
70     "cpu",
71     "what backend to use for model (vulkan, cpu, metal) (default=cpu)");
72 C10_DEFINE_string(
73     refbackend,
74     "cpu",
75     "what backend to use for model (vulkan, cpu, metal) (default=cpu)");
76 C10_DEFINE_string(tolerance, "1e-5", "tolerance to use for comparison");
77 C10_DEFINE_int(nthreads, 1, "Number of threads to launch. Useful for checking correct concurrent behaviour.");
78 C10_DEFINE_bool(
79     report_failures,
80     true,
81     "Whether to report error during failed iterations");
82 
checkRtol(const at::Tensor & diff,const std::vector<at::Tensor> & inputs,float tolerance,bool report)83 bool checkRtol(
84     const at::Tensor& diff,
85     const std::vector<at::Tensor>& inputs,
86     float tolerance,
87     bool report) {
88   float maxValue = 0.0f;
89 
90   for (const auto& tensor : inputs) {
91     maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
92   }
93   float threshold = tolerance * maxValue;
94   float maxDiff = diff.abs().max().item<float>();
95 
96   bool passed = maxDiff < threshold;
97   if (!passed && report) {
98     std::cout << "Check FAILED!      Max diff allowed: "
99               << std::setw(10) << std::setprecision(5) << threshold
100               << "     max diff: "
101               << std::setw(10) << std::setprecision(5) << maxDiff
102               << std::endl;
103   }
104 
105   return passed;
106 }
107 
report_pass_rate(int passed,int total)108 void report_pass_rate(int passed, int total) {
109   int pass_rate = static_cast<int>(static_cast<float>(passed) / static_cast<float>(total) * 100);
110   std::cout << "Output was equal within tolerance " << passed << "/"
111             << total
112             << " times. Pass rate: " << pass_rate
113             << std::setprecision(2) << "%" << std::endl;
114 }
115 
split(char separator,const std::string & string,bool ignore_empty=true)116 std::vector<std::string> split(
117     char separator,
118     const std::string& string,
119     bool ignore_empty = true) {
120   std::vector<std::string> pieces;
121   std::stringstream ss(string);
122   std::string item;
123   while (getline(ss, item, separator)) {
124     if (!ignore_empty || !item.empty()) {
125       pieces.push_back(std::move(item));
126     }
127   }
128   return pieces;
129 }
130 
create_inputs(std::vector<c10::IValue> & refinputs,std::vector<c10::IValue> & inputs,std::string & refbackend,std::string & backend,const int range_min,const int range_max)131 std::vector<c10::IValue> create_inputs(
132     std::vector<c10::IValue>& refinputs,
133     std::vector<c10::IValue>& inputs,
134     std::string& refbackend,
135     std::string& backend,
136     const int range_min,
137     const int range_max) {
138   if (FLAGS_no_inputs) {
139     return {};
140   }
141 
142   CAFFE_ENFORCE_GE(FLAGS_input_dims.size(), 0, "Input dims must be specified.");
143   CAFFE_ENFORCE_GE(FLAGS_input_type.size(), 0, "Input type must be specified.");
144 
145   std::vector<std::string> input_dims_list = split(';', FLAGS_input_dims);
146   std::vector<std::string> input_type_list = split(';', FLAGS_input_type);
147   std::vector<std::string> input_memory_format_list =
148       split(';', FLAGS_input_memory_format);
149 
150   CAFFE_ENFORCE_GE(
151       input_dims_list.size(), 0, "Input dims not specified correctly.");
152   CAFFE_ENFORCE_GE(
153       input_type_list.size(), 0, "Input type not specified correctly.");
154   CAFFE_ENFORCE_GE(
155       input_memory_format_list.size(),
156       0,
157       "Input format list not specified correctly.");
158 
159   CAFFE_ENFORCE_EQ(
160       input_dims_list.size(),
161       input_type_list.size(),
162       "Input dims and type should have the same number of items.");
163   CAFFE_ENFORCE_EQ(
164       input_dims_list.size(),
165       input_memory_format_list.size(),
166       "Input dims and format should have the same number of items.");
167 
168   for (size_t i = 0; i < input_dims_list.size(); ++i) {
169     auto input_dims_str = split(',', input_dims_list[i]);
170     std::vector<int64_t> input_dims;
171     input_dims.reserve(input_dims_str.size());
172     for (const auto& s : input_dims_str) {
173       input_dims.push_back(std::stoi(s));
174     }
175 
176     at::ScalarType input_type;
177     if (input_type_list[i] == "float") {
178       input_type = at::ScalarType::Float;
179     } else if (input_type_list[i] == "uint8_t") {
180       input_type = at::ScalarType::Byte;
181     } else if (input_type_list[i] == "int64") {
182       input_type = at::ScalarType::Long;
183     } else {
184       CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
185     }
186 
187     at::MemoryFormat input_memory_format;
188     if (input_memory_format_list[i] == "channels_last") {
189       if (input_dims.size() != 4u) {
190         CAFFE_THROW(
191             "channels_last memory format only available on 4D tensors!");
192       }
193       input_memory_format = at::MemoryFormat::ChannelsLast;
194     } else if (input_memory_format_list[i] == "contiguous_format") {
195       input_memory_format = at::MemoryFormat::Contiguous;
196     } else {
197       CAFFE_THROW(
198           "Unsupported input memory format: ", input_memory_format_list[i]);
199     }
200 
201     const auto input_tensor = torch::rand(
202         input_dims,
203         at::TensorOptions(input_type).memory_format(input_memory_format))*(range_max - range_min) - range_min;
204 
205     if (refbackend == "vulkan") {
206       refinputs.emplace_back(input_tensor.vulkan());
207     } else {
208       refinputs.emplace_back(input_tensor);
209     }
210 
211     if (backend == "vulkan") {
212       inputs.emplace_back(input_tensor.vulkan());
213     } else {
214       inputs.emplace_back(input_tensor);
215     }
216   }
217 
218   if (FLAGS_pytext_len > 0) {
219     auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
220     if (refbackend == "vulkan") {
221       refinputs.emplace_back(stensor.vulkan());
222     } else {
223       refinputs.emplace_back(stensor);
224     }
225 
226     if (backend == "vulkan") {
227       inputs.emplace_back(stensor.vulkan());
228     } else {
229       inputs.emplace_back(stensor);
230     }
231   }
232 
233   return inputs;
234 }
235 
run_check(float tolerance)236 void run_check(float tolerance) {
237   torch::jit::Module module = torch::jit::load(FLAGS_model);
238   torch::jit::Module refmodule = torch::jit::load(FLAGS_refmodel);
239 
240   module.eval();
241   refmodule.eval();
242 
243   std::thread::id this_id = std::this_thread::get_id();
244   std::cout << "Running check on thread " << this_id << "." << std::endl;
245 
246   int passed = 0;
247   for (int i = 0; i < FLAGS_iter; ++i) {
248     std::vector<c10::IValue> refinputs;
249     std::vector<c10::IValue> inputs;
250     create_inputs(
251         refinputs, inputs,
252         FLAGS_refbackend, FLAGS_backend,
253         FLAGS_input_min, FLAGS_input_max);
254 
255     const auto refoutput = refmodule.forward(refinputs).toTensor().cpu();
256     const auto output = module.forward(inputs).toTensor().cpu();
257 
258     bool check = checkRtol(
259         refoutput-output,
260         {refoutput, output},
261         tolerance,
262         FLAGS_report_failures);
263 
264     if (check) {
265       passed += 1;
266     }
267     else if (FLAGS_report_failures) {
268       std::cout << " (Iteration " << i << " failed)" << std::endl;
269     }
270 
271     if (i > 0 && (i+1) % FLAGS_report_freq == 0) {
272       report_pass_rate(passed, i+1);
273     }
274   }
275   report_pass_rate(passed, FLAGS_iter);
276 }
277 
main(int argc,char ** argv)278 int main(int argc, char** argv) {
279   c10::SetUsageMessage(
280       "Run accuracy comparison to a reference model for a pytorch model.\n"
281       "Example usage:\n"
282       "./compare_models_torch"
283       " --refmodel=<ref_model_file>"
284       " --model=<model_file>"
285       " --iter=20");
286   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
287     std::cerr << "Failed to parse command line flags!" << std::endl;
288     return 1;
289   }
290 
291   if (FLAGS_input_min >= FLAGS_input_max) {
292     std::cerr << "Input min: " << FLAGS_input_min
293               << " should be less than input max: "
294               << FLAGS_input_max << std::endl;
295     return 1;
296   }
297 
298   std::stringstream ss(FLAGS_tolerance);
299   float tolerance = 0;
300   ss >> tolerance;
301   std::cout << "tolerance: " << tolerance << std::endl;
302 
303   c10::InferenceMode mode;
304   torch::autograd::AutoGradMode guard(false);
305   torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard(false);
306 
307   c10::CPUCachingAllocator caching_allocator;
308   std::optional<c10::WithCPUCachingAllocatorGuard> caching_allocator_guard;
309   if (FLAGS_use_caching_allocator) {
310     caching_allocator_guard.emplace(&caching_allocator);
311   }
312 
313   std::vector<std::thread> check_threads;
314   check_threads.reserve(FLAGS_nthreads);
315   for (int i = 0; i < FLAGS_nthreads; ++i) {
316     check_threads.emplace_back(std::thread(run_check, tolerance));
317   }
318 
319   for (std::thread& th : check_threads) {
320     if (th.joinable()) {
321       th.join();
322     }
323   }
324 
325   return 0;
326 }
327