xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/tensorflow_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/tensorflow_util.h"
16 
17 #include <string.h>
18 
19 #include <memory>
20 #include <set>
21 #include <string>
22 
23 #ifdef GOOGLE_PLATFORM
24 #include "file/logging/log_lines.h"
25 #endif
26 #include "google/protobuf/map.h"
27 #include "absl/strings/str_split.h"
28 #include "absl/strings/string_view.h"
29 #include "tensorflow/lite/toco/toco_port.h"
30 #include "tensorflow/lite/toco/tooling_util.h"
31 #include "tensorflow/core/framework/attr_value.pb.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace toco {
38 
39 using tensorflow::AttrValue;
40 using tensorflow::GraphDef;
41 
LogDumpGraphDef(int log_level,const std::string & message,const GraphDef & tf_graph)42 void LogDumpGraphDef(int log_level, const std::string& message,
43                      const GraphDef& tf_graph) {
44   if (!VLOG_IS_ON(log_level)) {
45     return;
46   }
47   std::set<std::string> ops;
48   for (const auto& node : tf_graph.node()) {
49     ops.insert(node.op());
50   }
51   std::string dump;
52   toco::port::AppendF(&dump, R"MSG(
53 BEGIN DUMP OF TENSORFLOW GRAPHDEF (%s)
54 There are %d nodes.
55 There are %zu different op types:
56 )MSG",
57                       message, tf_graph.node_size(), ops.size());
58   for (const auto& op : ops) {
59     toco::port::AppendF(&dump, "  %s\n", op);
60   }
61   dump.append(R"MSG(
62 PROTO DUMP
63 )MSG");
64   for (const auto& node : tf_graph.node()) {
65     toco::port::AppendF(&dump, R"MSG(
66 BEGIN NODE: name = %s
67   op = %s
68   inputs = [
69 )MSG",
70                         node.name(), node.op());
71     for (const auto& input : node.input()) {
72       toco::port::AppendF(&dump, "    %s\n", input);
73     }
74     dump.append("  ]\n");
75     for (const auto& attr : node.attr()) {
76       toco::port::AppendF(&dump, "  ATTR: name = %s\n", attr.first);
77       if (attr.second.value_case() == AttrValue::kFunc) {
78         dump.append("    func\n");
79       } else if (attr.second.value_case() == AttrValue::kPlaceholder) {
80         toco::port::AppendF(&dump, "    placeholder: %s\n",
81                             attr.second.placeholder());
82       } else if (attr.second.value_case() == AttrValue::kS) {
83         dump.append("    string:\n");
84         dump.append(R"MSG(
85       BEGIN EMBEDDED STRING
86 )MSG");
87         const auto& lines = absl::StrSplit(attr.second.s(), '\n');
88         for (const auto& line : lines) {
89           toco::port::AppendF(&dump, "      %s\n", line);
90         }
91         dump.append(R"MSG(
92       END EMBEDDED STRING
93 )MSG");
94       } else if (attr.second.value_case() == AttrValue::kI) {
95         toco::port::AppendF(&dump, "    int: %lld\n", attr.second.i());
96       } else if (attr.second.value_case() == AttrValue::kF) {
97         toco::port::AppendF(&dump, "    float: %g\n", attr.second.f());
98       } else if (attr.second.value_case() == AttrValue::kB) {
99         toco::port::AppendF(&dump, "    bool: %s\n",
100                             attr.second.b() ? "true" : "false");
101       } else if (attr.second.value_case() == AttrValue::kType) {
102         toco::port::AppendF(&dump, "    type: %s\n",
103                             tensorflow::DataType_Name(attr.second.type()));
104       } else if (attr.second.value_case() == AttrValue::kShape) {
105         dump.append("    shape: [ ");
106         const auto& shape = attr.second.shape();
107         for (int i = 0; i < shape.dim_size(); i++) {
108           toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
109         }
110         dump.append("]\n");
111       } else if (attr.second.value_case() == AttrValue::kTensor) {
112         const auto& tensor = attr.second.tensor();
113         dump.append("    TENSOR:\n");
114         toco::port::AppendF(&dump, "      type: %s\n",
115                             tensorflow::DataType_Name(tensor.dtype()));
116         const auto& shape = tensor.tensor_shape();
117         dump.append("      shape: [ ");
118         for (int i = 0; i < shape.dim_size(); i++) {
119           toco::port::AppendF(&dump, "%lld ", shape.dim(i).size());
120         }
121         dump.append("]\n");
122         if (!tensor.tensor_content().empty()) {
123           toco::port::AppendF(&dump, "      tensor_content: %zu bytes\n",
124                               tensor.tensor_content().size());
125         }
126         if (tensor.dtype() == tensorflow::DT_INT32) {
127           CHECK_EQ(0, tensor.tensor_content().size() % sizeof(int32));
128           const int size = tensor.tensor_content().size() / sizeof(int32);
129           std::vector<int32> data(size);
130           toco::port::CopyToBuffer(tensor.tensor_content(),
131                                    reinterpret_cast<char*>(data.data()));
132           const int kMaxValsToPrint = 4;
133           dump.append("        tensor_content as ints: [ ");
134           for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
135             toco::port::AppendF(&dump, "%d ", data[i]);
136           }
137           if (size > kMaxValsToPrint) {
138             dump.append("... ");
139           }
140           dump.append("]\n");
141         }
142         if (tensor.dtype() == tensorflow::DT_FLOAT) {
143           CHECK_EQ(0, tensor.tensor_content().size() % sizeof(float));
144           const int size = tensor.tensor_content().size() / sizeof(float);
145           std::vector<float> data(size);
146           toco::port::CopyToBuffer(tensor.tensor_content(),
147                                    reinterpret_cast<char*>(data.data()));
148           const int kMaxValsToPrint = 4;
149           dump.append("        tensor_content as floats: [ ");
150           for (int i = 0; i < kMaxValsToPrint && i < size; i++) {
151             toco::port::AppendF(&dump, "%g ", data[i]);
152           }
153           if (size > kMaxValsToPrint) {
154             dump.append("... ");
155           }
156           dump.append("]\n");
157         }
158         if (tensor.int_val_size()) {
159           toco::port::AppendF(&dump, "      int_val: %d ints: [ ",
160                               tensor.int_val_size());
161           const int kMaxValsToPrint = 4;
162           for (int i = 0; i < kMaxValsToPrint && i < tensor.int_val_size();
163                i++) {
164             toco::port::AppendF(&dump, "%d ", tensor.int_val(i));
165           }
166           if (tensor.int_val_size() > kMaxValsToPrint) {
167             dump.append("... ");
168           }
169           dump.append("]\n");
170         }
171         if (tensor.float_val_size()) {
172           toco::port::AppendF(&dump, "      float_val: %d floats: [ ",
173                               tensor.float_val_size());
174           const int kMaxValsToPrint = 4;
175           for (int i = 0; i < kMaxValsToPrint && i < tensor.float_val_size();
176                i++) {
177             toco::port::AppendF(&dump, "%g ", tensor.float_val(i));
178           }
179           if (tensor.float_val_size() > kMaxValsToPrint) {
180             dump.append("... ");
181           }
182           dump.append("]\n");
183         }
184         if (tensor.string_val_size()) {
185           toco::port::AppendF(&dump, "      string_val: %d strings\n",
186                               tensor.string_val_size());
187         }
188       } else if (attr.second.value_case() == AttrValue::kList) {
189         dump.append("  LIST\n");
190       }
191     }
192     dump.append("END NODE\n");
193   }
194   toco::port::AppendF(&dump, "END DUMP OF TENSORFLOW GRAPHDEF (%s)\n", message);
195 #if defined(GOOGLE_PLATFORM)
196   VLOG_LINES(log_level, dump);
197 #else
198   VLOG(log_level) << dump;
199 #endif
200 }
201 }  // namespace toco
202