1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <memory>
16
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/framework/resource_op_kernel.h"
19 #include "tensorflow/core/framework/stats_aggregator.h"
20 #include "tensorflow/core/framework/summary.pb.h"
21 #include "tensorflow/core/kernels/summary_interface.h"
22 #include "tensorflow/core/lib/core/refcount.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/histogram/histogram.h"
25 #include "tensorflow/core/lib/monitoring/counter.h"
26 #include "tensorflow/core/lib/monitoring/gauge.h"
27 #include "tensorflow/core/lib/monitoring/sampler.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/util/events_writer.h"
30
31 namespace tensorflow {
32 namespace data {
33 namespace experimental {
34 namespace {
35
get_counters_map_lock()36 static mutex* get_counters_map_lock() {
37 static mutex counters_map_lock(LINKER_INITIALIZED);
38 return &counters_map_lock;
39 }
40
get_counters_map()41 static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
42 static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
43 new std::unordered_map<string, monitoring::Counter<1>*>;
44 return counters_map;
45 }
46
47 class StatsAggregatorImpl : public StatsAggregator {
48 public:
StatsAggregatorImpl()49 StatsAggregatorImpl() {}
50
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64_t steps)51 void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
52 const int64_t steps) override {
53 mutex_lock l(mu_);
54 histogram::Histogram& histogram = histograms_[name];
55 for (double value : values) {
56 histogram.Add(value);
57 }
58 }
59
AddScalar(const string & name,float value,const int64_t steps)60 void AddScalar(const string& name, float value,
61 const int64_t steps) override {
62 mutex_lock l(mu_);
63 scalars_[name] = value;
64 }
65
EncodeToProto(Summary * out_summary)66 void EncodeToProto(Summary* out_summary) override {
67 mutex_lock l(mu_);
68 for (const auto& pair : histograms_) {
69 const string& name = pair.first;
70 const histogram::Histogram& histogram = pair.second;
71
72 Summary::Value* value = out_summary->add_value();
73 value->set_tag(name);
74 histogram.EncodeToProto(value->mutable_histo(),
75 false /* doesn't preserve zero buckets */);
76 }
77 for (const auto& pair : scalars_) {
78 Summary::Value* value = out_summary->add_value();
79 value->set_tag(pair.first);
80 value->set_simple_value(pair.second);
81 }
82 }
83
84 // StatsAggregator implementation for V2 is based on push-based summary, no-op
85 // in V1.
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)86 Status SetSummaryWriter(
87 SummaryWriterInterface* summary_writer_interface) override {
88 return OkStatus();
89 }
90
IncrementCounter(const string & name,const string & label,int64_t val)91 void IncrementCounter(const string& name, const string& label,
92 int64_t val) override {
93 mutex_lock l(*get_counters_map_lock());
94 auto counters_map = get_counters_map();
95 if (counters_map->find(name) == counters_map->end()) {
96 counters_map->emplace(
97 name,
98 monitoring::Counter<1>::New(
99 /*streamz name*/ name,
100 /*streamz description*/
101 strings::StrCat(name, " generated or consumed by the component."),
102 /*streamz label name*/ "component_descriptor"));
103 }
104 counters_map->at(name)->GetCell(label)->IncrementBy(val);
105 }
106
107 private:
108 mutex mu_;
109 std::unordered_map<string, histogram::Histogram> histograms_
110 TF_GUARDED_BY(mu_);
111 std::unordered_map<string, float> scalars_ TF_GUARDED_BY(mu_);
112 TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
113 };
114
115 class StatsAggregatorHandleOp
116 : public ResourceOpKernel<StatsAggregatorResource> {
117 public:
StatsAggregatorHandleOp(OpKernelConstruction * ctx)118 explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
119 : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
120
121 private:
CreateResource(StatsAggregatorResource ** ret)122 Status CreateResource(StatsAggregatorResource** ret) override
123 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
124 *ret = new StatsAggregatorResource(std::make_unique<StatsAggregatorImpl>());
125 return OkStatus();
126 }
127 };
128
129 class StatsAggregatorImplV2 : public StatsAggregator {
130 public:
StatsAggregatorImplV2()131 StatsAggregatorImplV2() {}
132
~StatsAggregatorImplV2()133 ~StatsAggregatorImplV2() override {
134 if (summary_writer_interface_) {
135 summary_writer_interface_->Unref();
136 }
137 }
138
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64_t steps)139 void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
140 const int64_t steps) override {
141 mutex_lock l(mu_);
142 histogram::Histogram& histogram = histograms_[name];
143 for (double value : values) {
144 histogram.Add(value);
145 }
146 AddToEvents(name, steps, histogram);
147 }
148
AddScalar(const string & name,float value,const int64_t steps)149 void AddScalar(const string& name, float value,
150 const int64_t steps) override {
151 mutex_lock l(mu_);
152 AddToEvents(name, steps, value);
153 }
154
155 // TODO(b/116314787): expose this is public API to manually flush summary.
Flush()156 Status Flush() {
157 mutex_lock l(mu_);
158 if (summary_writer_interface_)
159 TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
160 return OkStatus();
161 }
162
IncrementCounter(const string & name,const string & label,int64_t val)163 void IncrementCounter(const string& name, const string& label,
164 int64_t val) override {
165 mutex_lock l(*get_counters_map_lock());
166 auto counters_map = get_counters_map();
167 if (counters_map->find(name) == counters_map->end()) {
168 counters_map->emplace(
169 name, monitoring::Counter<1>::New(
170 /*streamz name*/ "/tensorflow/" + name,
171 /*streamz description*/
172 name + " generated or consumed by the component.",
173 /*streamz label name*/ "component_descriptor"));
174 }
175 counters_map->at(name)->GetCell(label)->IncrementBy(val);
176 }
177
178 // StatsAggregator implementation for V1 is based on pull-based summary, no-op
179 // in V2.
EncodeToProto(Summary * out_summary)180 void EncodeToProto(Summary* out_summary) override {}
181
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)182 Status SetSummaryWriter(
183 SummaryWriterInterface* summary_writer_interface) override {
184 mutex_lock l(mu_);
185 if (summary_writer_interface_) {
186 summary_writer_interface_->Unref();
187 // If we create stats_aggregator twice in a program, we would end up with
188 // already existing resource. In this case emitting an error if a
189 // `summary_writer_resource` is present is not the intended behavior, we
190 // could either Unref the existing summary_writer_resource or not set the
191 // new resource at all.
192 }
193 summary_writer_interface_ = summary_writer_interface;
194 summary_writer_interface_->Ref();
195 return OkStatus();
196 }
197
198 private:
AddToEvents(const string & name,const int64_t steps,const float scalar_value)199 void AddToEvents(const string& name, const int64_t steps,
200 const float scalar_value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
201 if (summary_writer_interface_ == nullptr) {
202 return;
203 }
204 std::unique_ptr<Event> e{new Event};
205 e->set_step(steps);
206 e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
207 // maybe expose GetWallTime in SummaryWriterInterface
208 Summary::Value* v = e->mutable_summary()->add_value();
209 v->set_tag(name);
210 v->set_simple_value(scalar_value);
211 TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
212 }
213
AddToEvents(const string & name,const int64_t steps,const histogram::Histogram & histogram)214 void AddToEvents(const string& name, const int64_t steps,
215 const histogram::Histogram& histogram)
216 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
217 if (summary_writer_interface_ == nullptr) {
218 return;
219 }
220 std::unique_ptr<Event> e{new Event};
221 e->set_step(steps);
222 e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
223 Summary::Value* v = e->mutable_summary()->add_value();
224 v->set_tag(name);
225 histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
226 TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
227 }
228
229 mutex mu_;
230 SummaryWriterInterface* summary_writer_interface_ TF_GUARDED_BY(mu_) =
231 nullptr;
232 // not owned, we might be associating the default summary_writer from the
233 // context
234 std::unordered_map<string, histogram::Histogram> histograms_
235 TF_GUARDED_BY(mu_);
236 TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
237 };
238
239 class StatsAggregatorHandleOpV2
240 : public ResourceOpKernel<StatsAggregatorResource> {
241 public:
StatsAggregatorHandleOpV2(OpKernelConstruction * ctx)242 explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
243 : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
244
245 private:
CreateResource(StatsAggregatorResource ** ret)246 Status CreateResource(StatsAggregatorResource** ret) override
247 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
248 *ret =
249 new StatsAggregatorResource(std::make_unique<StatsAggregatorImplV2>());
250 return OkStatus();
251 }
252 };
253
254 class StatsAggregatorSummaryOp : public OpKernel {
255 public:
StatsAggregatorSummaryOp(OpKernelConstruction * ctx)256 explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
257 : OpKernel(ctx) {}
258
Compute(OpKernelContext * ctx)259 void Compute(OpKernelContext* ctx) override {
260 const Tensor& resource_handle_t = ctx->input(0);
261 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
262 errors::InvalidArgument("resource_handle must be a scalar"));
263
264 core::RefCountPtr<StatsAggregatorResource> resource;
265 OP_REQUIRES_OK(ctx,
266 LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
267
268 Tensor* summary_t;
269 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
270 Summary summary;
271 resource->stats_aggregator()->EncodeToProto(&summary);
272 summary_t->scalar<tstring>()() = summary.SerializeAsString();
273 }
274 };
275
276 class StatsAggregatorSetSummaryWriterOp : public OpKernel {
277 public:
StatsAggregatorSetSummaryWriterOp(OpKernelConstruction * ctx)278 explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
279 : OpKernel(ctx) {}
280
Compute(OpKernelContext * ctx)281 void Compute(OpKernelContext* ctx) override {
282 const Tensor& resource_handle_t = ctx->input(0);
283 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
284 errors::InvalidArgument("resource_handle must be a scalar"));
285
286 core::RefCountPtr<StatsAggregatorResource> resource;
287 OP_REQUIRES_OK(ctx,
288 LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
289
290 const Tensor& summary_resource_handle_t = ctx->input(1);
291 OP_REQUIRES(ctx,
292 TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
293 errors::InvalidArgument("resource_handle must be a scalar"));
294 core::RefCountPtr<SummaryWriterInterface> summary_resource;
295 OP_REQUIRES_OK(
296 ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &summary_resource));
297 TF_CHECK_OK(
298 resource->stats_aggregator()->SetSummaryWriter(summary_resource.get()));
299 }
300 };
301
// Register the CPU kernels. The "Experimental*" op names are legacy aliases
// kept for backward compatibility with graphs built against the old
// tf.contrib.data / experimental API.
REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
                        StatsAggregatorHandleOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
    StatsAggregatorHandleOp);

REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
                        StatsAggregatorHandleOpV2);

REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
                        StatsAggregatorSummaryOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
    StatsAggregatorSummaryOp);

REGISTER_KERNEL_BUILDER(
    Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
    StatsAggregatorSetSummaryWriterOp);
320
321 } // namespace
322 } // namespace experimental
323 } // namespace data
324 } // namespace tensorflow
325