xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/monitoring/cell_reader-inl.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/lib/monitoring/cell_reader-inl.h"
16 
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
22 
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
25 #include "tensorflow/core/lib/monitoring/collection_registry.h"
26 #include "tensorflow/core/lib/monitoring/metric_def.h"
27 #include "tensorflow/core/lib/monitoring/test_utils.h"
28 #include "tensorflow/core/lib/monitoring/types.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/statusor.h"
31 
32 namespace tensorflow {
33 namespace monitoring {
34 namespace testing {
35 namespace internal {
36 namespace {
37 
38 // Returns the labels of `point` as a vector of strings.
GetLabels(const monitoring::Point & point)39 std::vector<std::string> GetLabels(const monitoring::Point& point) {
40   std::vector<std::string> labels;
41   labels.reserve(point.labels.size());
42   for (const monitoring::Point::Label& label : point.labels) {
43     labels.push_back(label.value);
44   }
45   return labels;
46 }
47 }  // namespace
48 
CollectMetrics()49 std::unique_ptr<CollectedMetrics> CollectMetrics() {
50   CollectionRegistry::CollectMetricsOptions options;
51   return CollectionRegistry::Default()->CollectMetrics(options);
52 }
53 
GetMetricKind(const CollectedMetrics & metrics,const std::string & metric_name)54 MetricKind GetMetricKind(const CollectedMetrics& metrics,
55                          const std::string& metric_name) {
56   auto metric_descriptor = metrics.metric_descriptor_map.find(metric_name);
57   if (metric_descriptor == metrics.metric_descriptor_map.end()) {
58     return MetricKind::kCumulative;
59   }
60   return metric_descriptor->second->metric_kind;
61 }
62 
GetPoints(const CollectedMetrics & metrics,const std::string & metric_name,const std::vector<std::string> & labels)63 StatusOr<std::vector<Point>> GetPoints(const CollectedMetrics& metrics,
64                                        const std::string& metric_name,
65                                        const std::vector<std::string>& labels) {
66   auto metric_descriptor = metrics.metric_descriptor_map.find(metric_name);
67   if (metric_descriptor == metrics.metric_descriptor_map.end()) {
68     return errors::NotFound("Metric descriptor is not found for metric ",
69                             metric_name, ".");
70   }
71   const std::vector<string>& label_names =
72       metric_descriptor->second->label_names;
73   if (label_names.size() != labels.size()) {
74     return errors::InvalidArgument(
75         "Metric ", metric_name, " has ", label_names.size(), " labels: [",
76         absl::StrJoin(label_names, ", "), "]. Got label values [",
77         absl::StrJoin(labels, ", "), "].");
78   }
79   auto point_set = metrics.point_set_map.find(metric_name);
80   if (point_set == metrics.point_set_map.end()) {
81     return errors::NotFound("Metric point set is not found for metric ",
82                             metric_name, ".");
83   }
84 
85   std::vector<Point> points;
86   for (const std::unique_ptr<Point>& point : point_set->second->points) {
87     if (GetLabels(*point) == labels) {
88       points.push_back(*point);
89     }
90   }
91   return points;
92 }
93 
GetLatestPoint(const CollectedMetrics & metrics,const std::string & metric_name,const std::vector<std::string> & labels)94 StatusOr<Point> GetLatestPoint(const CollectedMetrics& metrics,
95                                const std::string& metric_name,
96                                const std::vector<std::string>& labels) {
97   TF_ASSIGN_OR_RETURN(std::vector<Point> points,
98                       GetPoints(metrics, metric_name, labels));
99   if (points.empty()) {
100     return errors::Unavailable("No data collected for metric ", metric_name,
101                                " with labels [", absl::StrJoin(labels, ", "),
102                                "].");
103   }
104 
105   bool same_start_time =
106       std::all_of(points.begin(), points.end(), [&points](const Point& point) {
107         return point.start_timestamp_millis == points[0].start_timestamp_millis;
108       });
109   if (!same_start_time) {
110     return errors::Internal(
111         "Collected cumulative metrics should have the same start timestamp "
112         "(the registration timestamp). This error implies a bug in the "
113         "`tensorflow::monitoring::testing::CellReader` library.");
114   }
115 
116   std::sort(points.begin(), points.end(), [](const Point& a, const Point& b) {
117     return a.end_timestamp_millis < b.end_timestamp_millis;
118   });
119   return points.back();
120 }
121 
122 template <>
GetValue(const Point & point)123 int64_t GetValue(const Point& point) {
124   return point.int64_value;
125 }
126 
127 template <>
GetValue(const Point & point)128 std::string GetValue(const Point& point) {
129   return point.string_value;
130 }
131 
132 template <>
GetValue(const Point & point)133 bool GetValue(const Point& point) {
134   return point.bool_value;
135 }
136 
137 template <>
GetValue(const Point & point)138 Histogram GetValue(const Point& point) {
139   return Histogram(point.histogram_value);
140 }
141 
142 template <>
GetValue(const Point & point)143 Percentiles GetValue(const Point& point) {
144   return Percentiles(point.percentiles_value);
145 }
146 
147 template <>
GetDelta(const int64_t & a,const int64_t & b)148 int64_t GetDelta(const int64_t& a, const int64_t& b) {
149   return a - b;
150 }
151 
152 template <>
GetDelta(const Histogram & a,const Histogram & b)153 Histogram GetDelta(const Histogram& a, const Histogram& b) {
154   StatusOr<Histogram> result = a.Subtract(b);
155   if (!result.ok()) {
156     LOG(FATAL) << "Failed to compute the delta between histograms: "
157                << result.status();
158   }
159   return *result;
160 }
161 
162 template <>
GetDelta(const Percentiles & a,const Percentiles & b)163 Percentiles GetDelta(const Percentiles& a, const Percentiles& b) {
164   return a.Subtract(b);
165 }
166 
167 template <>
GetDelta(const std::string & a,const std::string & b)168 std::string GetDelta(const std::string& a, const std::string& b) {
169   LOG(FATAL) << "`CellReader<std::string>` does not support `Delta`. "
170              << "Please use `Read` instead.";
171 }
172 
173 template <>
GetDelta(const bool & a,const bool & b)174 bool GetDelta(const bool& a, const bool& b) {
175   LOG(FATAL) << "`CellReader<bool>` does not support `Delta`. "
176              << "Please use `Read` instead.";
177 }
178 
179 }  // namespace internal
180 }  // namespace testing
181 }  // namespace monitoring
182 }  // namespace tensorflow
183