xref: /aosp_15_r20/external/webrtc/test/testsupport/perf_test.cc (revision d9f758449e529ab9291ac668be2861e7a55c2422)
/*
 *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */
10 
#include "test/testsupport/perf_test.h"

#include <stdio.h>

#include <algorithm>
#include <cmath>
#include <fstream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "api/numerics/samples_stats_counter.h"
#include "rtc_base/checks.h"
#include "rtc_base/strings/string_builder.h"
#include "rtc_base/synchronization/mutex.h"
#include "test/testsupport/file_utils.h"
#include "test/testsupport/perf_test_histogram_writer.h"
28 
29 namespace webrtc {
30 namespace test {
31 
32 namespace {
33 
UnitWithDirection(absl::string_view units,webrtc::test::ImproveDirection improve_direction)34 std::string UnitWithDirection(
35     absl::string_view units,
36     webrtc::test::ImproveDirection improve_direction) {
37   switch (improve_direction) {
38     case webrtc::test::ImproveDirection::kNone:
39       return std::string(units);
40     case webrtc::test::ImproveDirection::kSmallerIsBetter:
41       return std::string(units) + "_smallerIsBetter";
42     case webrtc::test::ImproveDirection::kBiggerIsBetter:
43       return std::string(units) + "_biggerIsBetter";
44   }
45 }
46 
GetSortedSamples(const SamplesStatsCounter & counter)47 std::vector<SamplesStatsCounter::StatsSample> GetSortedSamples(
48     const SamplesStatsCounter& counter) {
49   rtc::ArrayView<const SamplesStatsCounter::StatsSample> view =
50       counter.GetTimedSamples();
51   std::vector<SamplesStatsCounter::StatsSample> out(view.begin(), view.end());
52   std::stable_sort(out.begin(), out.end(),
53                    [](const SamplesStatsCounter::StatsSample& a,
54                       const SamplesStatsCounter::StatsSample& b) {
55                      return a.time < b.time;
56                    });
57   return out;
58 }
59 
// Writes the elements of `values` to `*ostream` separated by commas, with no
// surrounding brackets and no trailing separator. Formatting state (e.g.
// precision) is whatever the caller configured on the stream.
template <typename Container>
void OutputListToStream(std::ostream* ostream, const Container& values) {
  const char* separator = "";
  std::for_each(std::begin(values), std::end(values),
                [&](const auto& element) {
                  *ostream << separator << element;
                  separator = ",";
                });
}
68 
69 struct PlottableCounter {
70   std::string graph_name;
71   std::string trace_name;
72   webrtc::SamplesStatsCounter counter;
73   std::string units;
74 };
75 
76 class PlottableCounterPrinter {
77  public:
PlottableCounterPrinter()78   PlottableCounterPrinter() : output_(stdout) {}
79 
SetOutput(FILE * output)80   void SetOutput(FILE* output) {
81     MutexLock lock(&mutex_);
82     output_ = output;
83   }
84 
AddCounter(absl::string_view graph_name,absl::string_view trace_name,const webrtc::SamplesStatsCounter & counter,absl::string_view units)85   void AddCounter(absl::string_view graph_name,
86                   absl::string_view trace_name,
87                   const webrtc::SamplesStatsCounter& counter,
88                   absl::string_view units) {
89     MutexLock lock(&mutex_);
90     plottable_counters_.push_back({std::string(graph_name),
91                                    std::string(trace_name), counter,
92                                    std::string(units)});
93   }
94 
Print(const std::vector<std::string> & desired_graphs_raw) const95   void Print(const std::vector<std::string>& desired_graphs_raw) const {
96     std::set<std::string> desired_graphs(desired_graphs_raw.begin(),
97                                          desired_graphs_raw.end());
98     MutexLock lock(&mutex_);
99     for (auto& counter : plottable_counters_) {
100       if (!desired_graphs.empty()) {
101         auto it = desired_graphs.find(counter.graph_name);
102         if (it == desired_graphs.end()) {
103           continue;
104         }
105       }
106 
107       std::ostringstream value_stream;
108       value_stream.precision(8);
109       value_stream << R"({"graph_name":")" << counter.graph_name << R"(",)";
110       value_stream << R"("trace_name":")" << counter.trace_name << R"(",)";
111       value_stream << R"("units":")" << counter.units << R"(",)";
112       if (!counter.counter.IsEmpty()) {
113         value_stream << R"("mean":)" << counter.counter.GetAverage() << ',';
114         value_stream << R"("std":)" << counter.counter.GetStandardDeviation()
115                      << ',';
116       }
117       value_stream << R"("samples":[)";
118       const char* sep = "";
119       for (const auto& sample : counter.counter.GetTimedSamples()) {
120         value_stream << sep << R"({"time":)" << sample.time.us() << ','
121                      << R"("value":)" << sample.value << '}';
122         sep = ",";
123       }
124       value_stream << "]}";
125 
126       fprintf(output_, "PLOTTABLE_DATA: %s\n", value_stream.str().c_str());
127     }
128   }
129 
130  private:
131   mutable Mutex mutex_;
132   std::vector<PlottableCounter> plottable_counters_ RTC_GUARDED_BY(&mutex_);
133   FILE* output_ RTC_GUARDED_BY(&mutex_);
134 };
135 
GetPlottableCounterPrinter()136 PlottableCounterPrinter& GetPlottableCounterPrinter() {
137   static PlottableCounterPrinter* printer_ = new PlottableCounterPrinter();
138   return *printer_;
139 }
140 
141 class ResultsLinePrinter {
142  public:
ResultsLinePrinter()143   ResultsLinePrinter() : output_(stdout) {}
144 
SetOutput(FILE * output)145   void SetOutput(FILE* output) {
146     MutexLock lock(&mutex_);
147     output_ = output;
148   }
149 
PrintResult(absl::string_view graph_name,absl::string_view trace_name,const double value,absl::string_view units,bool important,ImproveDirection improve_direction)150   void PrintResult(absl::string_view graph_name,
151                    absl::string_view trace_name,
152                    const double value,
153                    absl::string_view units,
154                    bool important,
155                    ImproveDirection improve_direction) {
156     std::ostringstream value_stream;
157     value_stream.precision(8);
158     value_stream << value;
159 
160     PrintResultImpl(graph_name, trace_name, value_stream.str(), std::string(),
161                     std::string(), UnitWithDirection(units, improve_direction),
162                     important);
163   }
164 
PrintResultMeanAndError(absl::string_view graph_name,absl::string_view trace_name,const double mean,const double error,absl::string_view units,bool important,ImproveDirection improve_direction)165   void PrintResultMeanAndError(absl::string_view graph_name,
166                                absl::string_view trace_name,
167                                const double mean,
168                                const double error,
169                                absl::string_view units,
170                                bool important,
171                                ImproveDirection improve_direction) {
172     std::ostringstream value_stream;
173     value_stream.precision(8);
174     value_stream << mean << ',' << error;
175     PrintResultImpl(graph_name, trace_name, value_stream.str(), "{", "}",
176                     UnitWithDirection(units, improve_direction), important);
177   }
178 
PrintResultList(absl::string_view graph_name,absl::string_view trace_name,const rtc::ArrayView<const double> values,absl::string_view units,const bool important,webrtc::test::ImproveDirection improve_direction)179   void PrintResultList(absl::string_view graph_name,
180                        absl::string_view trace_name,
181                        const rtc::ArrayView<const double> values,
182                        absl::string_view units,
183                        const bool important,
184                        webrtc::test::ImproveDirection improve_direction) {
185     std::ostringstream value_stream;
186     value_stream.precision(8);
187     OutputListToStream(&value_stream, values);
188     PrintResultImpl(graph_name, trace_name, value_stream.str(), "[", "]", units,
189                     important);
190   }
191 
192  private:
PrintResultImpl(absl::string_view graph_name,absl::string_view trace_name,absl::string_view values,absl::string_view prefix,absl::string_view suffix,absl::string_view units,bool important)193   void PrintResultImpl(absl::string_view graph_name,
194                        absl::string_view trace_name,
195                        absl::string_view values,
196                        absl::string_view prefix,
197                        absl::string_view suffix,
198                        absl::string_view units,
199                        bool important) {
200     MutexLock lock(&mutex_);
201     rtc::StringBuilder message;
202     message << (important ? "*" : "") << "RESULT " << graph_name << ": "
203             << trace_name << "= " << prefix << values << suffix << " " << units;
204     // <*>RESULT <graph_name>: <trace_name>= <value> <units>
205     // <*>RESULT <graph_name>: <trace_name>= {<mean>, <std deviation>} <units>
206     // <*>RESULT <graph_name>: <trace_name>= [<value>,value,value,...,] <units>
207     fprintf(output_, "%s\n", message.str().c_str());
208   }
209 
210   Mutex mutex_;
211   FILE* output_ RTC_GUARDED_BY(&mutex_);
212 };
213 
GetResultsLinePrinter()214 ResultsLinePrinter& GetResultsLinePrinter() {
215   static ResultsLinePrinter* const printer_ = new ResultsLinePrinter();
216   return *printer_;
217 }
218 
GetPerfWriter()219 PerfTestResultWriter& GetPerfWriter() {
220   static PerfTestResultWriter* writer = CreateHistogramWriter();
221   return *writer;
222 }
223 
224 }  // namespace
225 
ClearPerfResults()226 void ClearPerfResults() {
227   GetPerfWriter().ClearResults();
228 }
229 
SetPerfResultsOutput(FILE * output)230 void SetPerfResultsOutput(FILE* output) {
231   GetPlottableCounterPrinter().SetOutput(output);
232   GetResultsLinePrinter().SetOutput(output);
233 }
234 
GetPerfResults()235 std::string GetPerfResults() {
236   return GetPerfWriter().Serialize();
237 }
238 
PrintPlottableResults(const std::vector<std::string> & desired_graphs)239 void PrintPlottableResults(const std::vector<std::string>& desired_graphs) {
240   GetPlottableCounterPrinter().Print(desired_graphs);
241 }
242 
WritePerfResults(const std::string & output_path)243 bool WritePerfResults(const std::string& output_path) {
244   std::string results = GetPerfResults();
245   CreateDir(DirName(output_path));
246   FILE* output = fopen(output_path.c_str(), "wb");
247   if (output == NULL) {
248     printf("Failed to write to %s.\n", output_path.c_str());
249     return false;
250   }
251   size_t written =
252       fwrite(results.c_str(), sizeof(char), results.size(), output);
253   fclose(output);
254 
255   if (written != results.size()) {
256     long expected = results.size();
257     printf("Wrote %zu, tried to write %lu\n", written, expected);
258     return false;
259   }
260 
261   return true;
262 }
263 
PrintResult(absl::string_view measurement,absl::string_view modifier,absl::string_view trace,const double value,absl::string_view units,bool important,ImproveDirection improve_direction)264 void PrintResult(absl::string_view measurement,
265                  absl::string_view modifier,
266                  absl::string_view trace,
267                  const double value,
268                  absl::string_view units,
269                  bool important,
270                  ImproveDirection improve_direction) {
271   rtc::StringBuilder graph_name;
272   graph_name << measurement << modifier;
273   RTC_CHECK(std::isfinite(value))
274       << "Expected finite value for graph " << graph_name.str()
275       << ", trace name " << trace << ", units " << units << ", got " << value;
276   GetPerfWriter().LogResult(graph_name.str(), trace, value, units, important,
277                             improve_direction);
278   GetResultsLinePrinter().PrintResult(graph_name.str(), trace, value, units,
279                                       important, improve_direction);
280 }
281 
PrintResult(absl::string_view measurement,absl::string_view modifier,absl::string_view trace,const SamplesStatsCounter & counter,absl::string_view units,const bool important,ImproveDirection improve_direction)282 void PrintResult(absl::string_view measurement,
283                  absl::string_view modifier,
284                  absl::string_view trace,
285                  const SamplesStatsCounter& counter,
286                  absl::string_view units,
287                  const bool important,
288                  ImproveDirection improve_direction) {
289   rtc::StringBuilder graph_name;
290   graph_name << measurement << modifier;
291   GetPlottableCounterPrinter().AddCounter(graph_name.str(), trace, counter,
292                                           units);
293 
294   double mean = counter.IsEmpty() ? 0 : counter.GetAverage();
295   double error = counter.IsEmpty() ? 0 : counter.GetStandardDeviation();
296 
297   std::vector<SamplesStatsCounter::StatsSample> timed_samples =
298       GetSortedSamples(counter);
299   std::vector<double> samples(timed_samples.size());
300   for (size_t i = 0; i < timed_samples.size(); ++i) {
301     samples[i] = timed_samples[i].value;
302   }
303   // If we have an empty counter, default it to 0.
304   if (samples.empty()) {
305     samples.push_back(0);
306   }
307 
308   GetPerfWriter().LogResultList(graph_name.str(), trace, samples, units,
309                                 important, improve_direction);
310   GetResultsLinePrinter().PrintResultMeanAndError(graph_name.str(), trace, mean,
311                                                   error, units, important,
312                                                   improve_direction);
313 }
314 
PrintResultMeanAndError(absl::string_view measurement,absl::string_view modifier,absl::string_view trace,const double mean,const double error,absl::string_view units,bool important,ImproveDirection improve_direction)315 void PrintResultMeanAndError(absl::string_view measurement,
316                              absl::string_view modifier,
317                              absl::string_view trace,
318                              const double mean,
319                              const double error,
320                              absl::string_view units,
321                              bool important,
322                              ImproveDirection improve_direction) {
323   RTC_CHECK(std::isfinite(mean));
324   RTC_CHECK(std::isfinite(error));
325 
326   rtc::StringBuilder graph_name;
327   graph_name << measurement << modifier;
328   GetPerfWriter().LogResultMeanAndError(graph_name.str(), trace, mean, error,
329                                         units, important, improve_direction);
330   GetResultsLinePrinter().PrintResultMeanAndError(graph_name.str(), trace, mean,
331                                                   error, units, important,
332                                                   improve_direction);
333 }
334 
PrintResultList(absl::string_view measurement,absl::string_view modifier,absl::string_view trace,const rtc::ArrayView<const double> values,absl::string_view units,bool important,ImproveDirection improve_direction)335 void PrintResultList(absl::string_view measurement,
336                      absl::string_view modifier,
337                      absl::string_view trace,
338                      const rtc::ArrayView<const double> values,
339                      absl::string_view units,
340                      bool important,
341                      ImproveDirection improve_direction) {
342   for (double v : values) {
343     RTC_CHECK(std::isfinite(v));
344   }
345 
346   rtc::StringBuilder graph_name;
347   graph_name << measurement << modifier;
348   GetPerfWriter().LogResultList(graph_name.str(), trace, values, units,
349                                 important, improve_direction);
350   GetResultsLinePrinter().PrintResultList(graph_name.str(), trace, values,
351                                           units, important, improve_direction);
352 }
353 
354 }  // namespace test
355 }  // namespace webrtc
356