xref: /aosp_15_r20/external/tensorflow/tensorflow/core/lib/monitoring/counter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 
28 #include "tensorflow/core/lib/core/status.h"
29 #include "tensorflow/core/platform/macros.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace tensorflow {
33 namespace monitoring {
34 
// CounterCell which has a null implementation: increments are discarded and
// value() always reports 0.
class CounterCell {
 public:
  CounterCell() = default;
  ~CounterCell() = default;

  // Cells are handed out by pointer and must never be duplicated.
  CounterCell(const CounterCell&) = delete;
  CounterCell& operator=(const CounterCell&) = delete;

  // No-op: the step is ignored on mobile platforms. Uses int64_t to match the
  // signature of the non-mobile CounterCell::IncrementBy.
  void IncrementBy(int64_t /*step*/) {}

  // Always 0, because nothing is ever recorded.
  int64_t value() const { return 0; }
};
47 
48 // Counter which has a null implementation.
49 template <int NumLabels>
50 class Counter {
51  public:
~Counter()52   ~Counter() {}
53 
54   template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)55   static Counter* New(MetricDefArgs&&... metric_def_args) {
56     return new Counter<NumLabels>();
57   }
58 
59   template <typename... Labels>
GetCell(const Labels &...labels)60   CounterCell* GetCell(const Labels&... labels) {
61     return &default_counter_cell_;
62   }
63 
GetStatus()64   Status GetStatus() { return Status::OK(); }
65 
66  private:
Counter()67   Counter() {}
68 
69   CounterCell default_counter_cell_;
70 
71   TF_DISALLOW_COPY_AND_ASSIGN(Counter);
72 };
73 
74 }  // namespace monitoring
75 }  // namespace tensorflow
76 
77 #else  // IS_MOBILE_PLATFORM
78 
79 #include <array>
80 #include <atomic>
81 #include <map>
82 
83 #include "tensorflow/core/lib/core/status.h"
84 #include "tensorflow/core/lib/monitoring/collection_registry.h"
85 #include "tensorflow/core/lib/monitoring/metric_def.h"
86 #include "tensorflow/core/platform/logging.h"
87 #include "tensorflow/core/platform/macros.h"
88 #include "tensorflow/core/platform/mutex.h"
89 #include "tensorflow/core/platform/thread_annotations.h"
90 
91 namespace tensorflow {
92 namespace monitoring {
93 
94 // CounterCell stores each value of an Counter.
95 //
96 // A cell can be passed off to a module which may repeatedly update it without
97 // needing further map-indexing computations. This improves both encapsulation
98 // (separate modules can own a cell each, without needing to know about the map
99 // to which both cells belong) and performance (since map indexing and
100 // associated locking are both avoided).
101 //
102 // This class is thread-safe.
103 class CounterCell {
104  public:
CounterCell(int64_t value)105   explicit CounterCell(int64_t value) : value_(value) {}
~CounterCell()106   ~CounterCell() {}
107 
108   // Atomically increments the value by step.
109   // REQUIRES: Step be non-negative.
110   void IncrementBy(int64_t step);
111 
112   // Retrieves the current value.
113   int64_t value() const;
114 
115  private:
116   std::atomic<int64_t> value_;
117 
118   TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
119 };
120 
121 // A stateful class for updating a cumulative integer metric.
122 //
123 // This class encapsulates a set of values (or a single value for a label-less
124 // metric). Each value is identified by a tuple of labels. The class allows the
125 // user to increment each value.
126 //
127 // Counter allocates storage and maintains a cell for each value. You can
128 // retrieve an individual cell using a label-tuple and update it separately.
129 // This improves performance since operations related to retrieval, like
130 // map-indexing and locking, are avoided.
131 //
132 // This class is thread-safe.
133 template <int NumLabels>
134 class Counter {
135  public:
~Counter()136   ~Counter() {
137     // Deleted here, before the metric_def is destroyed.
138     registration_handle_.reset();
139   }
140 
141   // Creates the metric based on the metric-definition arguments.
142   //
143   // Example;
144   // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
145   //   "Tensorflow counter", "MyLabelName");
146   template <typename... MetricDefArgs>
147   static Counter* New(MetricDefArgs&&... metric_def_args);
148 
149   // Retrieves the cell for the specified labels, creating it on demand if
150   // not already present.
151   template <typename... Labels>
152   CounterCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
153 
GetStatus()154   Status GetStatus() { return status_; }
155 
156  private:
Counter(const MetricDef<MetricKind::kCumulative,int64_t,NumLabels> & metric_def)157   explicit Counter(
158       const MetricDef<MetricKind::kCumulative, int64_t, NumLabels>& metric_def)
159       : metric_def_(metric_def),
160         registration_handle_(CollectionRegistry::Default()->Register(
161             &metric_def_, [&](MetricCollectorGetter getter) {
162               auto metric_collector = getter.Get(&metric_def_);
163 
164               mutex_lock l(mu_);
165               for (const auto& cell : cells_) {
166                 metric_collector.CollectValue(cell.first, cell.second.value());
167               }
168             })) {
169     if (registration_handle_) {
170       status_ = OkStatus();
171     } else {
172       status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
173                        "Another metric with the same name already exists.");
174     }
175   }
176 
177   mutable mutex mu_;
178 
179   Status status_;
180 
181   // The metric definition. This will be used to identify the metric when we
182   // register it for collection.
183   const MetricDef<MetricKind::kCumulative, int64_t, NumLabels> metric_def_;
184 
185   std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
186 
187   using LabelArray = std::array<string, NumLabels>;
188   std::map<LabelArray, CounterCell> cells_ TF_GUARDED_BY(mu_);
189 
190   TF_DISALLOW_COPY_AND_ASSIGN(Counter);
191 };
192 
193 ////
194 //  Implementation details follow. API readers may skip.
195 ////
196 
IncrementBy(const int64_t step)197 inline void CounterCell::IncrementBy(const int64_t step) {
198   DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
199   value_ += step;
200 }
201 
value()202 inline int64_t CounterCell::value() const { return value_; }
203 
204 template <int NumLabels>
205 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)206 Counter<NumLabels>* Counter<NumLabels>::New(
207     MetricDefArgs&&... metric_def_args) {
208   return new Counter<NumLabels>(
209       MetricDef<MetricKind::kCumulative, int64_t, NumLabels>(
210           std::forward<MetricDefArgs>(metric_def_args)...));
211 }
212 
213 template <int NumLabels>
214 template <typename... Labels>
GetCell(const Labels &...labels)215 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
216     TF_LOCKS_EXCLUDED(mu_) {
217   // Provides a more informative error message than the one during array
218   // construction below.
219   static_assert(sizeof...(Labels) == NumLabels,
220                 "Mismatch between Counter<NumLabels> and number of labels "
221                 "provided in GetCell(...).");
222 
223   const LabelArray& label_array = {{labels...}};
224   mutex_lock l(mu_);
225   const auto found_it = cells_.find(label_array);
226   if (found_it != cells_.end()) {
227     return &(found_it->second);
228   }
229   return &(cells_
230                .emplace(std::piecewise_construct,
231                         std::forward_as_tuple(label_array),
232                         std::forward_as_tuple(0))
233                .first->second);
234 }
235 
236 }  // namespace monitoring
237 }  // namespace tensorflow
238 
239 #endif  // IS_MOBILE_PLATFORM
240 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
241