xref: /aosp_15_r20/external/tensorflow/tensorflow/c/kernels/summary_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 
2 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3 
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7 
8     http://www.apache.org/licenses/LICENSE-2.0
9 
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16 
17 #include <sstream>
18 #include <string>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/c/kernels.h"
22 #include "tensorflow/c/kernels/tensor_shape_utils.h"
23 #include "tensorflow/c/tf_status.h"
24 #include "tensorflow/c/tf_tensor.h"
25 #include "tensorflow/core/framework/registration/registration.h"
26 #include "tensorflow/core/framework/summary.pb.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/platform/bfloat16.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/strcat.h"
33 #include "tensorflow/core/platform/tstring.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace {
37 
38 // Struct that stores the status and TF_Tensor inputs to the opkernel.
39 // Used to delete tensor and status in its destructor upon kernel return.
40 struct Params {
41   TF_Tensor* tags;
42   TF_Tensor* values;
43   TF_Status* status;
Params__anon214280000111::Params44   explicit Params(TF_OpKernelContext* ctx)
45       : tags(nullptr), values(nullptr), status(nullptr) {
46     status = TF_NewStatus();
47     TF_GetInput(ctx, 0, &tags, status);
48     if (TF_GetCode(status) == TF_OK) {
49       TF_GetInput(ctx, 1, &values, status);
50     }
51   }
~Params__anon214280000111::Params52   ~Params() {
53     TF_DeleteStatus(status);
54     TF_DeleteTensor(tags);
55     TF_DeleteTensor(values);
56   }
57 };
58 
59 // dummy functions used for kernel registration
ScalarSummaryOp_Create(TF_OpKernelConstruction * ctx)60 void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; }
61 
// Deletion hook handed to TF_NewKernelBuilder. ScalarSummaryOp_Create never
// allocates anything, so there is nothing to release here.
void ScalarSummaryOp_Delete(void* kernel) {
  static_cast<void>(kernel);  // Unused.
}
63 
64 // Helper functions for compute method
65 bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2);
66 // Returns a string representation of a single tag or empty string if there
67 // are multiple tags
68 std::string SingleTag(TF_Tensor* tags);
69 
70 template <typename T>
ScalarSummaryOp_Compute(void * kernel,TF_OpKernelContext * ctx)71 void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
72   Params params(ctx);
73   if (TF_GetCode(params.status) != TF_OK) {
74     TF_OpKernelContext_Failure(ctx, params.status);
75     return;
76   }
77   if (!IsSameSize(params.tags, params.values)) {
78     std::ostringstream err;
79     err << "tags and values are not the same shape: "
80         << tensorflow::ShapeDebugString(params.tags)
81         << " != " << tensorflow::ShapeDebugString(params.values)
82         << SingleTag(params.tags);
83     TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str());
84     TF_OpKernelContext_Failure(ctx, params.status);
85     return;
86   }
87   // Convert tags and values tensor to array to access elements by index
88   tensorflow::Summary s;
89   auto tags_array =
90       static_cast<tensorflow::tstring*>(TF_TensorData(params.tags));
91   auto values_array = static_cast<T*>(TF_TensorData(params.values));
92   // Copy tags and values into summary protobuf
93   for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) {
94     tensorflow::Summary::Value* v = s.add_value();
95     const tensorflow::tstring& Ttags_i = tags_array[i];
96     v->set_tag(Ttags_i.data(), Ttags_i.size());
97     v->set_simple_value(static_cast<float>(values_array[i]));
98   }
99   TF_Tensor* summary_tensor =
100       TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0,
101                         sizeof(tensorflow::tstring), params.status);
102   if (TF_GetCode(params.status) != TF_OK) {
103     TF_DeleteTensor(summary_tensor);
104     TF_OpKernelContext_Failure(ctx, params.status);
105     return;
106   }
107   tensorflow::tstring* output_tstring =
108       reinterpret_cast<tensorflow::tstring*>(TF_TensorData(summary_tensor));
109   CHECK(SerializeToTString(s, output_tstring));
110   TF_DeleteTensor(summary_tensor);
111 }
112 
IsSameSize(TF_Tensor * tensor1,TF_Tensor * tensor2)113 bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) {
114   if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) {
115     return false;
116   }
117   for (int d = 0; d < TF_NumDims(tensor1); d++) {
118     if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) {
119       return false;
120     }
121   }
122   return true;
123 }
124 
SingleTag(TF_Tensor * tags)125 std::string SingleTag(TF_Tensor* tags) {
126   if (TF_TensorElementCount(tags) == 1) {
127     const char* single_tag =
128         static_cast<tensorflow::tstring*>(TF_TensorData(tags))->c_str();
129     return tensorflow::strings::StrCat(" (tag '", single_tag, "')");
130   } else {
131     return "";
132   }
133 }
134 
135 template <typename T>
RegisterScalarSummaryOpKernel()136 void RegisterScalarSummaryOpKernel() {
137   TF_Status* status = TF_NewStatus();
138   {
139     auto* builder = TF_NewKernelBuilder(
140         "ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create,
141         &ScalarSummaryOp_Compute<T>, &ScalarSummaryOp_Delete);
142     TF_KernelBuilder_TypeConstraint(
143         builder, "T",
144         static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
145     CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
146     TF_RegisterKernelBuilder("ScalarSummary", builder, status);
147     CHECK_EQ(TF_OK, TF_GetCode(status))
148         << "Error while registering Scalar Summmary kernel";
149   }
150   TF_DeleteStatus(status);
151 }
152 
153 // A dummy static variable initialized by a lambda whose side-effect is to
154 // register the ScalarSummary kernel.
__anon214280000202() 155 TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() {
156   if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) {
157     RegisterScalarSummaryOpKernel<int64_t>();
158     RegisterScalarSummaryOpKernel<tensorflow::uint64>();
159     RegisterScalarSummaryOpKernel<tensorflow::int32>();
160     RegisterScalarSummaryOpKernel<tensorflow::uint32>();
161     RegisterScalarSummaryOpKernel<tensorflow::uint16>();
162     RegisterScalarSummaryOpKernel<tensorflow::int16>();
163     RegisterScalarSummaryOpKernel<tensorflow::int8>();
164     RegisterScalarSummaryOpKernel<tensorflow::uint8>();
165     RegisterScalarSummaryOpKernel<Eigen::half>();
166     RegisterScalarSummaryOpKernel<tensorflow::bfloat16>();
167     RegisterScalarSummaryOpKernel<float>();
168     RegisterScalarSummaryOpKernel<double>();
169   }
170   return true;
171 }();
172 }  // namespace
173