xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/experimental/stats_aggregator_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16 
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/resource_op_kernel.h"
19 #include "tensorflow/core/framework/stats_aggregator.h"
20 #include "tensorflow/core/framework/summary.pb.h"
21 #include "tensorflow/core/kernels/summary_interface.h"
22 #include "tensorflow/core/lib/core/refcount.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/histogram/histogram.h"
25 #include "tensorflow/core/lib/monitoring/counter.h"
26 #include "tensorflow/core/lib/monitoring/gauge.h"
27 #include "tensorflow/core/lib/monitoring/sampler.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/util/events_writer.h"
30 
31 namespace tensorflow {
32 namespace data {
33 namespace experimental {
34 namespace {
35 
// Returns the process-wide mutex guarding the counters map returned by
// get_counters_map() below. LINKER_INITIALIZED avoids dynamic initialization,
// so the mutex is usable even during early program startup.
static mutex* get_counters_map_lock() {
  static mutex counters_map_lock(LINKER_INITIALIZED);
  return &counters_map_lock;
}
40 
// Returns the process-wide map from counter name to its monitoring Counter.
// The map and its Counter values are intentionally leaked (never deleted):
// monitoring counters must outlive every aggregator instance.
// Callers must hold *get_counters_map_lock() while accessing the map.
static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
  static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
      new std::unordered_map<string, monitoring::Counter<1>*>;
  return counters_map;
}
46 
47 class StatsAggregatorImpl : public StatsAggregator {
48  public:
StatsAggregatorImpl()49   StatsAggregatorImpl() {}
50 
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64_t steps)51   void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
52                       const int64_t steps) override {
53     mutex_lock l(mu_);
54     histogram::Histogram& histogram = histograms_[name];
55     for (double value : values) {
56       histogram.Add(value);
57     }
58   }
59 
AddScalar(const string & name,float value,const int64_t steps)60   void AddScalar(const string& name, float value,
61                  const int64_t steps) override {
62     mutex_lock l(mu_);
63     scalars_[name] = value;
64   }
65 
EncodeToProto(Summary * out_summary)66   void EncodeToProto(Summary* out_summary) override {
67     mutex_lock l(mu_);
68     for (const auto& pair : histograms_) {
69       const string& name = pair.first;
70       const histogram::Histogram& histogram = pair.second;
71 
72       Summary::Value* value = out_summary->add_value();
73       value->set_tag(name);
74       histogram.EncodeToProto(value->mutable_histo(),
75                               false /* doesn't preserve zero buckets */);
76     }
77     for (const auto& pair : scalars_) {
78       Summary::Value* value = out_summary->add_value();
79       value->set_tag(pair.first);
80       value->set_simple_value(pair.second);
81     }
82   }
83 
84   // StatsAggregator implementation for V2 is based on push-based summary, no-op
85   // in V1.
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)86   Status SetSummaryWriter(
87       SummaryWriterInterface* summary_writer_interface) override {
88     return OkStatus();
89   }
90 
IncrementCounter(const string & name,const string & label,int64_t val)91   void IncrementCounter(const string& name, const string& label,
92                         int64_t val) override {
93     mutex_lock l(*get_counters_map_lock());
94     auto counters_map = get_counters_map();
95     if (counters_map->find(name) == counters_map->end()) {
96       counters_map->emplace(
97           name,
98           monitoring::Counter<1>::New(
99               /*streamz name*/ name,
100               /*streamz description*/
101               strings::StrCat(name, " generated or consumed by the component."),
102               /*streamz label name*/ "component_descriptor"));
103     }
104     counters_map->at(name)->GetCell(label)->IncrementBy(val);
105   }
106 
107  private:
108   mutex mu_;
109   std::unordered_map<string, histogram::Histogram> histograms_
110       TF_GUARDED_BY(mu_);
111   std::unordered_map<string, float> scalars_ TF_GUARDED_BY(mu_);
112   TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
113 };
114 
// Resource op kernel producing a StatsAggregatorResource backed by the
// pull-based (V1) StatsAggregatorImpl.
class StatsAggregatorHandleOp
    : public ResourceOpKernel<StatsAggregatorResource> {
 public:
  explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
      : ResourceOpKernel<StatsAggregatorResource>(ctx) {}

 private:
  // Invoked by ResourceOpKernel to lazily create the shared resource;
  // ownership of the new resource transfers to the resource manager.
  Status CreateResource(StatsAggregatorResource** ret) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *ret = new StatsAggregatorResource(std::make_unique<StatsAggregatorImpl>());
    return OkStatus();
  }
};
128 
// Push-based (V2) stats aggregator: instead of buffering state for later
// export, it writes each histogram/scalar update as an Event through a
// SummaryWriterInterface (if one has been attached via SetSummaryWriter).
class StatsAggregatorImplV2 : public StatsAggregator {
 public:
  StatsAggregatorImplV2() {}

  ~StatsAggregatorImplV2() override {
    // Drop the reference taken in SetSummaryWriter(); the writer itself is
    // owned elsewhere (possibly the default writer from the context).
    if (summary_writer_interface_) {
      summary_writer_interface_->Unref();
    }
  }

  // Accumulates `values` into the histogram for `name` and immediately
  // pushes the updated histogram as an Event at step `steps`.
  void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
                      const int64_t steps) override {
    mutex_lock l(mu_);
    histogram::Histogram& histogram = histograms_[name];
    for (double value : values) {
      histogram.Add(value);
    }
    AddToEvents(name, steps, histogram);
  }

  // Pushes `value` as a simple-value Event at step `steps`. Unlike V1, the
  // scalar is not retained locally.
  void AddScalar(const string& name, float value,
                 const int64_t steps) override {
    mutex_lock l(mu_);
    AddToEvents(name, steps, value);
  }

  // TODO(b/116314787): expose this is public API to manually flush summary.
  Status Flush() {
    mutex_lock l(mu_);
    if (summary_writer_interface_)
      TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
    return OkStatus();
  }

  // Increments the process-wide streamz counter for `name`/`label`,
  // registering it on first use. Note: the streamz name is prefixed with
  // "/tensorflow/" here, unlike the V1 implementation, but both share the
  // same counters map keyed by the unprefixed `name`.
  void IncrementCounter(const string& name, const string& label,
                        int64_t val) override {
    mutex_lock l(*get_counters_map_lock());
    auto counters_map = get_counters_map();
    if (counters_map->find(name) == counters_map->end()) {
      counters_map->emplace(
          name, monitoring::Counter<1>::New(
                    /*streamz name*/ "/tensorflow/" + name,
                    /*streamz description*/
                    name + " generated or consumed by the component.",
                    /*streamz label name*/ "component_descriptor"));
    }
    counters_map->at(name)->GetCell(label)->IncrementBy(val);
  }

  // StatsAggregator implementation for V1 is based on pull-based summary,
  // no-op in V2.
  void EncodeToProto(Summary* out_summary) override {}

  // Attaches (or replaces) the summary writer used to push events. Takes a
  // new reference on `summary_writer_interface` and releases the reference
  // on any previously attached writer.
  Status SetSummaryWriter(
      SummaryWriterInterface* summary_writer_interface) override {
    mutex_lock l(mu_);
    if (summary_writer_interface_) {
      summary_writer_interface_->Unref();
      // If we create stats_aggregator twice in a program, we would end up with
      // already existing resource. In this case emitting an error if a
      // `summary_writer_resource` is present is not the intended behavior, we
      // could either Unref the existing summary_writer_resource or not set the
      // new resource at all.
    }
    summary_writer_interface_ = summary_writer_interface;
    summary_writer_interface_->Ref();
    return OkStatus();
  }

 private:
  // Writes a simple-value Event for `name` at step `steps`. Silently a no-op
  // when no summary writer has been attached yet.
  void AddToEvents(const string& name, const int64_t steps,
                   const float scalar_value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (summary_writer_interface_ == nullptr) {
      return;
    }
    std::unique_ptr<Event> e{new Event};
    e->set_step(steps);
    // Wall time is recorded in seconds (micros / 1e6).
    e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
    // maybe expose GetWallTime in SummaryWriterInterface
    Summary::Value* v = e->mutable_summary()->add_value();
    v->set_tag(name);
    v->set_simple_value(scalar_value);
    // NOTE(review): TF_CHECK_OK aborts the process if the write fails;
    // presumably acceptable here since there is no error path to the caller.
    TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
  }

  // Writes a histogram Event for `name` at step `steps`. Silently a no-op
  // when no summary writer has been attached yet.
  void AddToEvents(const string& name, const int64_t steps,
                   const histogram::Histogram& histogram)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (summary_writer_interface_ == nullptr) {
      return;
    }
    std::unique_ptr<Event> e{new Event};
    e->set_step(steps);
    e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
    Summary::Value* v = e->mutable_summary()->add_value();
    v->set_tag(name);
    histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
    TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
  }

  mutex mu_;
  SummaryWriterInterface* summary_writer_interface_ TF_GUARDED_BY(mu_) =
      nullptr;
  // not owned, we might be associating the default summary_writer from the
  // context
  std::unordered_map<string, histogram::Histogram> histograms_
      TF_GUARDED_BY(mu_);
  TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
};
238 
// Resource op kernel producing a StatsAggregatorResource backed by the
// push-based (V2) StatsAggregatorImplV2.
class StatsAggregatorHandleOpV2
    : public ResourceOpKernel<StatsAggregatorResource> {
 public:
  explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
      : ResourceOpKernel<StatsAggregatorResource>(ctx) {}

 private:
  // Invoked by ResourceOpKernel to lazily create the shared resource;
  // ownership of the new resource transfers to the resource manager.
  Status CreateResource(StatsAggregatorResource** ret) override
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *ret =
        new StatsAggregatorResource(std::make_unique<StatsAggregatorImplV2>());
    return OkStatus();
  }
};
253 
254 class StatsAggregatorSummaryOp : public OpKernel {
255  public:
StatsAggregatorSummaryOp(OpKernelConstruction * ctx)256   explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
257       : OpKernel(ctx) {}
258 
Compute(OpKernelContext * ctx)259   void Compute(OpKernelContext* ctx) override {
260     const Tensor& resource_handle_t = ctx->input(0);
261     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
262                 errors::InvalidArgument("resource_handle must be a scalar"));
263 
264     core::RefCountPtr<StatsAggregatorResource> resource;
265     OP_REQUIRES_OK(ctx,
266                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
267 
268     Tensor* summary_t;
269     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
270     Summary summary;
271     resource->stats_aggregator()->EncodeToProto(&summary);
272     summary_t->scalar<tstring>()() = summary.SerializeAsString();
273   }
274 };
275 
276 class StatsAggregatorSetSummaryWriterOp : public OpKernel {
277  public:
StatsAggregatorSetSummaryWriterOp(OpKernelConstruction * ctx)278   explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
279       : OpKernel(ctx) {}
280 
Compute(OpKernelContext * ctx)281   void Compute(OpKernelContext* ctx) override {
282     const Tensor& resource_handle_t = ctx->input(0);
283     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
284                 errors::InvalidArgument("resource_handle must be a scalar"));
285 
286     core::RefCountPtr<StatsAggregatorResource> resource;
287     OP_REQUIRES_OK(ctx,
288                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
289 
290     const Tensor& summary_resource_handle_t = ctx->input(1);
291     OP_REQUIRES(ctx,
292                 TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
293                 errors::InvalidArgument("resource_handle must be a scalar"));
294     core::RefCountPtr<SummaryWriterInterface> summary_resource;
295     OP_REQUIRES_OK(
296         ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &summary_resource));
297     TF_CHECK_OK(
298         resource->stats_aggregator()->SetSummaryWriter(summary_resource.get()));
299   }
300 };
301 
// Kernel registrations. The "Experimental*" op names are retained for
// backward compatibility and map to the same kernels as their
// non-experimental counterparts.
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
                        StatsAggregatorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
    StatsAggregatorHandleOp);

REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
                        StatsAggregatorHandleOpV2);

REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
                        StatsAggregatorSummaryOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
    StatsAggregatorSummaryOp);

REGISTER_KERNEL_BUILDER(
    Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
    StatsAggregatorSetSummaryWriterOp);
320 
321 }  // namespace
322 }  // namespace experimental
323 }  // namespace data
324 }  // namespace tensorflow
325