1
2 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #include <sstream>
18 #include <string>
19
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/c/kernels.h"
22 #include "tensorflow/c/kernels/tensor_shape_utils.h"
23 #include "tensorflow/c/tf_status.h"
24 #include "tensorflow/c/tf_tensor.h"
25 #include "tensorflow/core/framework/registration/registration.h"
26 #include "tensorflow/core/framework/summary.pb.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/platform/bfloat16.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/strcat.h"
33 #include "tensorflow/core/platform/tstring.h"
34 #include "tensorflow/core/platform/types.h"
35
36 namespace {
37
38 // Struct that stores the status and TF_Tensor inputs to the opkernel.
39 // Used to delete tensor and status in its destructor upon kernel return.
40 struct Params {
41 TF_Tensor* tags;
42 TF_Tensor* values;
43 TF_Status* status;
Params__anon214280000111::Params44 explicit Params(TF_OpKernelContext* ctx)
45 : tags(nullptr), values(nullptr), status(nullptr) {
46 status = TF_NewStatus();
47 TF_GetInput(ctx, 0, &tags, status);
48 if (TF_GetCode(status) == TF_OK) {
49 TF_GetInput(ctx, 1, &values, status);
50 }
51 }
~Params__anon214280000111::Params52 ~Params() {
53 TF_DeleteStatus(status);
54 TF_DeleteTensor(tags);
55 TF_DeleteTensor(values);
56 }
57 };
58
59 // dummy functions used for kernel registration
ScalarSummaryOp_Create(TF_OpKernelConstruction * ctx)60 void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
61
// Kernel-deletion callback required by TF_NewKernelBuilder. The create
// callback allocates nothing, so there is nothing to release here.
void ScalarSummaryOp_Delete(void* /*kernel*/) {
  // Intentionally empty.
}
63
64 // Helper functions for compute method
65 bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
66 // Returns a string representation of a single tag or empty string if there
67 // are multiple tags
68 std::string SingleTag(TF_Tensor* tags);
69
70 template <typename T>
ScalarSummaryOp_Compute(void * kernel,TF_OpKernelContext * ctx)71 void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
72 Params params(ctx);
73 if (TF_GetCode(params.status) != TF_OK) {
74 TF_OpKernelContext_Failure(ctx, params.status);
75 return;
76 }
77 if (!IsSameSize(params.tags, params.values)) {
78 std::ostringstream err;
79 err << "tags and values are not the same shape: "
80 << tensorflow::ShapeDebugString(params.tags)
81 << " != " << tensorflow::ShapeDebugString(params.values)
82 << SingleTag(params.tags);
83 TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
84 TF_OpKernelContext_Failure(ctx, params.status);
85 return;
86 }
87 // Convert tags and values tensor to array to access elements by index
88 tensorflow::Summary s;
89 auto tags_array =
90 static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
91 auto values_array = static_cast<T*>(TF_TensorData(params.values));
92 // Copy tags and values into summary protobuf
93 for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
94 tensorflow::Summary::Value* v = s.add_value();
95 const tensorflow::tstring& Ttags_i = tags_array[i];
96 v->set_tag(Ttags_i.data(), Ttags_i.size());
97 v->set_simple_value(static_cast<float>(values_array[i]));
98 }
99 TF_Tensor* summary_tensor =
100 TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
101 sizeof(tensorflow::tstring), params.status);
102 if (TF_GetCode(params.status) != TF_OK) {
103 TF_DeleteTensor(summary_tensor);
104 TF_OpKernelContext_Failure(ctx, params.status);
105 return;
106 }
107 tensorflow::tstring* output_tstring =
108 reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
109 CHECK(SerializeToTString(s, output_tstring));
110 TF_DeleteTensor(summary_tensor);
111 }
112
IsSameSize(TF_Tensor * tensor1,TF_Tensor * tensor2)113 bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
114 if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
115 return false;
116 }
117 for (int d = 0; d < TF_NumDims(tensor1); d++) {
118 if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
119 return false;
120 }
121 }
122 return true;
123 }
124
SingleTag(TF_Tensor * tags)125 std::string SingleTag(TF_Tensor* tags) {
126 if (TF_TensorElementCount(tags) == 1) {
127 const char* single_tag =
128 static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
129 return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
130 } else {
131 return "";
132 }
133 }
134
135 template <typename T>
RegisterScalarSummaryOpKernel()136 void RegisterScalarSummaryOpKernel() {
137 TF_Status* status = TF_NewStatus();
138 {
139 auto* builder = TF_NewKernelBuilder(
140 "ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
141 &ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
142 TF_KernelBuilder_TypeConstraint(
143 builder, "T",
144 static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
145 CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
146 TF_RegisterKernelBuilder("ScalarSummary", builder, status);
147 CHECK_EQ(TF_OK, TF_GetCode(status))
148 << "Error while registering Scalar Summmary kernel";
149 }
150 TF_DeleteStatus(status);
151 }
152
153 // A dummy static variable initialized by a lambda whose side-effect is to
154 // register the ScalarSummary kernel.
__anon214280000202() 155 TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
156 if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
157 RegisterScalarSummaryOpKernel<int64_t>();
158 RegisterScalarSummaryOpKernel<tensorflow::uint64>();
159 RegisterScalarSummaryOpKernel<tensorflow::int32>();
160 RegisterScalarSummaryOpKernel<tensorflow::uint32>();
161 RegisterScalarSummaryOpKernel<tensorflow::uint16>();
162 RegisterScalarSummaryOpKernel<tensorflow::int16>();
163 RegisterScalarSummaryOpKernel<tensorflow::int8>();
164 RegisterScalarSummaryOpKernel<tensorflow::uint8>();
165 RegisterScalarSummaryOpKernel<Eigen::half>();
166 RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
167 RegisterScalarSummaryOpKernel<float>();
168 RegisterScalarSummaryOpKernel<double>();
169 }
170 return true;
171 }();
172 } // namespace
173