1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
17
18 #include <stdio.h>
19
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
24 #include "tensorflow/core/lib/io/path.h"
25 #include "tensorflow/core/lib/io/zlib_compression_options.h"
26 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
27 #include "tensorflow/core/platform/regexp.h"
28 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
29
30 namespace tensorflow {
31 namespace tfprof {
32 namespace {
33
34 const char* const kGradientSuffix = " (gradient)";
35
36 // Convert to Trace proto into a short readable string.
GetTraceString(const CallStack::Trace & trace)37 std::string GetTraceString(const CallStack::Trace& trace) {
38 std::string ntrace =
39 absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
40 if (trace.function().length() < 20) {
41 absl::StrAppend(&ntrace, ":", trace.function());
42 } else {
43 absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
44 }
45 return ntrace;
46 }
47
IsGradNode(const string & name,string * forward_name)48 bool IsGradNode(const string& name, string* forward_name) {
49 // Given a forward operation with name op, its gradient op has the following
50 // name: ...gradients/op_grad/...
51 // TODO(xpan): This is hacky.
52 auto grad_prefix = name.find("gradients/");
53 auto grad_suffix = name.find("_grad/");
54 if (grad_prefix == name.npos || grad_suffix == name.npos) {
55 return false;
56 }
57 auto start = grad_prefix + string("gradients/").length();
58 auto len = grad_suffix - start;
59 if (len <= 0) {
60 return false;
61 }
62 *forward_name = name.substr(start, len);
63 return true;
64 }
65
66 // StringTable maps each string to an id.
67 class StringTable {
68 public:
StringTable()69 StringTable() {
70 // Pprof requires first entry in string_table to be ''.
71 string_id_[""] = 0;
72 all_strings_.push_back("");
73 }
74
75 // Returns the index of a string. If not found, inserts the string and
76 // return the inserted index.
GetIndex(const string & str)77 uint64 GetIndex(const string& str) {
78 auto idx = string_id_.find(str);
79 if (idx != string_id_.end()) {
80 return idx->second;
81 }
82 all_strings_.push_back(str);
83 return string_id_.insert(std::pair<string, int64_t>(str, string_id_.size()))
84 .first->second;
85 }
86
strings() const87 const std::vector<string>& strings() const { return all_strings_; }
88
89 private:
90 std::map<string, uint64> string_id_;
91 std::vector<string> all_strings_;
92 };
93
94 // FunctionTable maps each function to an id.
95 class FunctionTable {
96 public:
FunctionTable(StringTable * string_table)97 explicit FunctionTable(StringTable* string_table)
98 : string_table_(string_table) {}
99
100 // Returns the index of a function. If not found, adds a function proto
101 // and returns the function index.
GetIndex(const string & file_path,const string & func_name,uint64 func_start_line)102 uint64 GetIndex(const string& file_path, const string& func_name,
103 uint64 func_start_line) {
104 auto key = std::tuple<string, string, uint64>(file_path, func_name,
105 func_start_line);
106 auto idx = function_table_.find(key);
107 if (idx != function_table_.end()) {
108 return idx->second.id();
109 }
110 pprof::Function* func_pb = &function_table_[key];
111 // function index should start from 1.
112 func_pb->set_id(function_table_.size());
113
114 string file_base(io::Basename(file_path));
115 file_base = file_base.substr(0, file_base.find_last_of('.'));
116 func_pb->set_name(
117 string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
118 func_pb->set_filename(string_table_->GetIndex(file_path));
119 func_pb->set_start_line(func_start_line);
120 return func_pb->id();
121 }
122
123 const std::map<std::tuple<string, string, uint64>, pprof::Function>&
functions() const124 functions() const {
125 return function_table_;
126 }
127
128 private:
129 StringTable* string_table_;
130 std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
131 };
132
133 // LocationTable maps each function call to an id.
134 class LocationTable {
135 public:
LocationTable(FunctionTable * function_table)136 explicit LocationTable(FunctionTable* function_table)
137 : function_table_(function_table) {}
138
139 // Returns the index of a function call location. If not found, adds a
140 // location proto and returns the location index.
GetIndex(const string & file_path,uint64 line_number,const string & called_function_name,const string & called_file_path,uint64 called_func_start_line)141 uint64 GetIndex(const string& file_path, uint64 line_number,
142 const string& called_function_name,
143 const string& called_file_path,
144 uint64 called_func_start_line) {
145 auto key = std::tuple<string, string, uint64>(
146 file_path, called_function_name, line_number);
147
148 auto idx = location_table_.find(key);
149 if (idx != location_table_.end()) {
150 return idx->second.id();
151 }
152 pprof::Location* location_pb = &location_table_[key];
153 location_pb->set_id(location_table_.size());
154 pprof::Line* line_pb = location_pb->add_line();
155 line_pb->set_function_id(function_table_->GetIndex(
156 called_file_path, called_function_name, called_func_start_line));
157 line_pb->set_line(line_number);
158 return location_pb->id();
159 }
160
161 const std::map<std::tuple<string, string, uint64>, pprof::Location>&
locations() const162 locations() const {
163 return location_table_;
164 }
165
166 private:
167 FunctionTable* function_table_;
168 std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
169 };
170
// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace. It includes all graph nodes
  // created by it. 'location_ids' contains
  // the call stack, from callee to caller.
  // This method adds the statistics of graph nodes created by the python
  // call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // displayed leaf might not be true leaf. Retrieve the true leaves for
    // stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    // One pprof sample is emitted per graph node under each true leaf.
    for (const CodeNode* cn : all_leaf) {
      for (const auto& gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        // Label each sample with the originating graph node name.
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        // First value: count (always 1 per graph node).
        sample_pb->mutable_value()->Add(1);
        // Second value: the metric selected via -select. The kShown indices
        // below must stay in sync with the sample_type units configured in
        // PprofProfileImpl::Build.
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
  // Returns all true leaves (nodes with no children) under `root`,
  // including `root` itself when it is childless.
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};
257
258 class PprofProfileImpl : public PprofProfile {
259 public:
PprofProfileImpl(const Options * opts)260 explicit PprofProfileImpl(const Options* opts)
261 : opts_(opts),
262 func_table_(new FunctionTable(&string_table_)),
263 loc_table_(new LocationTable(func_table_.get())),
264 samples_(new Samples(&string_table_, opts)) {}
265
AddLocation(const CodeNode * callee,const CodeNode * caller)266 uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
267 const string& file_path = caller->file();
268 uint64 lineno = caller->lineno();
269 const string& callee_file_path = callee->file();
270 const string& callee_function = callee->function();
271 uint64 callee_func_start_line = callee->func_start_line();
272
273 return loc_table_->GetIndex(file_path, lineno, callee_function,
274 callee_file_path, callee_func_start_line);
275 }
276
AddSample(const CodeNode * leaf,std::vector<uint64> * call_ids)277 void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
278 std::vector<uint64> reversed_call_ids;
279 std::reverse_copy(call_ids->begin(), call_ids->end(),
280 std::back_inserter(reversed_call_ids));
281 samples_->Add(leaf, reversed_call_ids);
282 }
283
WritePprofProfile(const string & filename)284 Status WritePprofProfile(const string& filename) override {
285 pprof::Profile profile_pb;
286 Build(&profile_pb);
287
288 std::unique_ptr<WritableFile> file;
289 Status s = Env::Default()->NewWritableFile(filename, &file);
290 if (!s.ok()) return s;
291
292 int32_t buf_size = 1024 * 1024;
293 io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
294 file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
295 s = zlib_output_buffer->Init();
296 if (!s.ok()) {
297 delete zlib_output_buffer;
298 return s;
299 }
300 s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
301 if (!s.ok()) {
302 delete zlib_output_buffer;
303 return s;
304 }
305 s = zlib_output_buffer->Close();
306 if (!s.ok()) {
307 delete zlib_output_buffer;
308 return s;
309 }
310 absl::FPrintF(stdout,
311 "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
312 filename);
313 delete zlib_output_buffer;
314 return s;
315 }
316
317 private:
Build(pprof::Profile * profile_pb)318 void Build(pprof::Profile* profile_pb) {
319 string sample_type_description = "count";
320 auto sample_type = profile_pb->mutable_sample_type()->Add();
321 sample_type->set_type(string_table_.GetIndex(sample_type_description));
322 sample_type->set_unit(string_table_.GetIndex("count"));
323
324 string type = *opts_->select.begin();
325 sample_type_description = type;
326 sample_type = profile_pb->mutable_sample_type()->Add();
327 sample_type->set_type(string_table_.GetIndex(sample_type_description));
328 if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
329 sample_type->set_unit(string_table_.GetIndex("microseconds"));
330 if (type == kShown[1]) {
331 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
332 "Sum of accelerator execution time and cpu execution time."));
333 } else if (type == kShown[9]) {
334 profile_pb->mutable_comment()->Add(
335 string_table_.GetIndex("Accelerator execution time."));
336 } else if (type == kShown[10]) {
337 profile_pb->mutable_comment()->Add(
338 string_table_.GetIndex("CPU execution time."));
339 }
340 } else if (type == kShown[0]) {
341 sample_type->set_unit(string_table_.GetIndex("bytes"));
342 profile_pb->mutable_comment()->Add(
343 string_table_.GetIndex("Sum of operation total memory requests, "
344 "excluding deallocations."));
345 } else if (type == kShown[11]) {
346 sample_type->set_unit(string_table_.GetIndex("bytes"));
347 profile_pb->mutable_comment()->Add(
348 string_table_.GetIndex("Sum of operation peak memory usage."));
349 } else if (type == kShown[12]) {
350 sample_type->set_unit(string_table_.GetIndex("bytes"));
351 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
352 "Sum of operation allocated memory after finish."));
353 } else if (type == kShown[13]) {
354 sample_type->set_unit(string_table_.GetIndex("bytes"));
355 profile_pb->mutable_comment()->Add(
356 string_table_.GetIndex("Sum of operation output size."));
357 } else if (type == kShown[2]) {
358 sample_type->set_unit(string_table_.GetIndex("count"));
359 profile_pb->mutable_comment()->Add(
360 string_table_.GetIndex("Model parameters."));
361 } else if (type == kShown[3]) {
362 sample_type->set_unit(string_table_.GetIndex("count"));
363 profile_pb->mutable_comment()->Add(string_table_.GetIndex(
364 "Model float operations (Only available if defined)."));
365 } else {
366 absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
367 }
368
369 for (const string& str : string_table_.strings()) {
370 *profile_pb->mutable_string_table()->Add() = str;
371 }
372 for (const auto& sample_it : samples_->samples()) {
373 // TODO(xpan): Consider swap.
374 profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
375 }
376 for (const auto& function_it : func_table_->functions()) {
377 profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
378 }
379 for (const auto& location_it : loc_table_->locations()) {
380 profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
381 }
382 }
383
384 const Options* opts_;
385 StringTable string_table_;
386 std::unique_ptr<FunctionTable> func_table_;
387 std::unique_ptr<LocationTable> loc_table_;
388 std::unique_ptr<Samples> samples_;
389 };
390 } // namespace
391
AddNode(TFGraphNode * node)392 void TFCode::AddNode(TFGraphNode* node) {
393 if (!node->call_stack() || node->call_stack()->traces().empty()) {
394 return;
395 }
396 // We infer the forward operation name from gradient op name. So, we can
397 // map gradient op traces to forward op traces.
398 // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
399 string forward_name;
400 if (IsGradNode(node->name(), &forward_name)) {
401 auto grad_nodes_it = grad_nodes_.find(forward_name);
402 if (grad_nodes_it != grad_nodes_.end()) {
403 grad_nodes_it->second.push_back(node);
404 } else {
405 grad_nodes_.insert(
406 std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
407 }
408 return;
409 } else {
410 forward_nodes_[node->name()] = node;
411 }
412
413 if (!root_) {
414 graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
415 root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
416 }
417
418 CodeNode* pre_code_node = root_.get();
419 // TODO(xpan): Consider to release CodeDef after TFCode is built. It
420 // takes a lot of memory.
421 std::set<string> traces;
422 for (int i = 0, end = node->call_stack()->traces().size(); i < end; ++i) {
423 // Unlike op name, which is globally unique, trace name is only unique
424 // w.r.t. it's parent.
425 const string& trace = GetTraceString(node->call_stack()->traces().at(i));
426 traces.insert(trace);
427 pre_code_node = pre_code_node->AddChildren(
428 trace, &node->call_stack()->traces().at(i), "");
429 const int64_t last_index = node->call_stack()->traces().size() - 1;
430 if (i == last_index) {
431 pre_code_node->node->AddGraphNode(node);
432 }
433 }
434 }
435
Build()436 void TFCode::Build() {
437 int64_t unaccounted_nodes = 0;
438 for (const auto& it : grad_nodes_) {
439 const string& forward_name = it.first;
440 auto forward_it = forward_nodes_.find(forward_name);
441 if (forward_it == forward_nodes_.end()) {
442 unaccounted_nodes += 1;
443 continue;
444 }
445 TFGraphNode* fn = forward_it->second;
446 CodeNode* leaf = nullptr;
447 CodeNode* pre_code_node = root_.get();
448 for (int i = 0, end = fn->call_stack()->traces().size(); i < end; ++i) {
449 const string& trace =
450 GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
451 pre_code_node = pre_code_node->AddChildren(
452 trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
453 const int64_t last_trace = fn->call_stack()->traces().size() - 1;
454 if (i == last_trace) {
455 leaf = pre_code_node;
456 }
457 }
458 for (TFGraphNode* gn : it.second) {
459 leaf->node->AddGraphNode(gn);
460 }
461 }
462 if (unaccounted_nodes > 0) {
463 absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
464 unaccounted_nodes);
465 }
466 }
467
// Runs the code-view pipeline for one query: validates pprof options,
// accounts stats bottom-up, selects display roots, formats the tree, and
// (for pprof output) writes the profile file. Returns the display root.
const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
  // kOutput[3] is the pprof output type: it supports exactly one of a
  // fixed set of -select attributes (see the matching list in Samples).
  if (opts.output_type == kOutput[3]) {
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  // Account stats bottom-up, then aggregate the surviving roots into the
  // overall total.
  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  // Narrow the display to the requested start nodes unless the default
  // match-everything regex is in effect.
  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  // PrintScope on the singleton {root_} always yields exactly one node.
  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    // pprof output: Format() feeds call sites/samples into pprof_profile_,
    // which is then serialized to the path given in -output options.
    std::vector<uint64> call_ids;
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}
525
Format(const CodeNode * root,const std::vector<CodeNode * > & nodes,const Options & opts,string * display_str,MultiGraphNodeProto * proto,std::vector<uint64> * call_ids)526 void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
527 const Options& opts, string* display_str,
528 MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
529 if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
530 pprof_profile_->AddSample(root, call_ids);
531 }
532
533 for (CodeNode* node : nodes) {
534 if (root->has_trace() && opts.output_type == kOutput[3]) {
535 uint64 loc_id = pprof_profile_->AddLocation(node, root);
536 call_ids->push_back(loc_id);
537 }
538 display_str->append(node->formatted_str);
539 MultiGraphNodeProto* child = proto->add_children();
540 child->MergeFrom(node->proto());
541 Format(node, node->show_children, opts, display_str, child, call_ids);
542 if (root->has_trace() && opts.output_type == kOutput[3]) {
543 call_ids->pop_back();
544 }
545 }
546 }
547
SearchRoot(std::vector<CodeNode * > roots,const std::vector<string> & regexes)548 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
549 const std::vector<string>& regexes) {
550 std::vector<CodeNode*> res;
551 if (roots.empty()) {
552 return res;
553 }
554 for (CodeNode* root : roots) {
555 bool match_start_node = false;
556 for (const string& regex : regexes) {
557 if (RE2::FullMatch(root->name(), regex)) {
558 res.push_back(root);
559 match_start_node = true;
560 break;
561 }
562 }
563 if (match_start_node) {
564 // Found a start node at this branch, no need to continue.
565 continue;
566 }
567 std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
568 res.insert(res.end(), nroots.begin(), nroots.end());
569 }
570 return res;
571 }
572
PrintScope(const std::vector<CodeNode * > roots,const Options & opts,int depth,int last_ident)573 std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
574 const Options& opts, int depth,
575 int last_ident) {
576 std::vector<CodeNode*> show_nodes;
577
578 for (CodeNode* node : roots) {
579 if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
580 continue;
581 }
582 int ident = last_ident;
583 bool show = ShouldShow(node, opts, depth);
584 if (show) ident += 2;
585
586 std::vector<CodeNode*> show_cnodes =
587 PrintScope(node->show_children, opts, depth + 1, ident);
588 if (show) {
589 node->show_children.clear();
590
591 show_cnodes = SortNodes(show_cnodes, opts);
592 for (CodeNode* sc : show_cnodes) {
593 node->show_children.push_back(sc);
594 }
595
596 node->formatted_str = FormatNode(node, opts, last_ident);
597
598 if (opts.select.find(kShown[4]) != opts.select.end()) {
599 absl::FPrintF(stderr, "code view has no tensor value to show\n");
600 }
601 show_nodes.push_back(node);
602 } else {
603 show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
604 show_cnodes.end());
605 }
606 }
607 return show_nodes;
608 }
609
Account(const std::vector<CodeNode * > & roots,const Options & opts)610 std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
611 const Options& opts) {
612 std::vector<CodeNode*> act_nodes;
613
614 for (CodeNode* node : roots) {
615 node->ResetTotalStats();
616 std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
617 node->account = ReAccount(node, opts);
618 if (node->account || !act_cnodes.empty()) {
619 node->show_children.clear();
620 node->ResetTotalStats();
621 node->AddSelfToTotalStats();
622 for (CodeNode* c : act_cnodes) {
623 node->AggregateTotalStats(c);
624 node->show_children.push_back(c);
625 }
626 act_nodes.push_back(node);
627 }
628 }
629 return act_nodes;
630 }
631
FormatNodeMemory(CodeNode * node,int64_t bytes,int64_t total_bytes) const632 string TFCode::FormatNodeMemory(CodeNode* node, int64_t bytes,
633 int64_t total_bytes) const {
634 string memory = FormatMemory(total_bytes);
635 if (node->account) {
636 memory = FormatMemory(bytes) + "/" + memory;
637 } else {
638 memory = "--/" + memory;
639 }
640 return memory;
641 }
642
// Renders one code node as "<indent><name> (<attr>, <attr>, ...)\n", where
// the attributes are chosen by opts.select. Accounted values display as
// "self/total"; unaccounted as "--/total" (see FormatNodeMemory).
string TFCode::FormatNode(CodeNode* node, const Options& opts,
                          int64_t indent) const {
  std::vector<string> attrs;
  // kShown[0]: requested bytes.
  if (opts.select.find(kShown[0]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
                                     node->proto().total_requested_bytes()));
  }
  // kShown[11]: peak bytes.
  if (opts.select.find(kShown[11]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
                                     node->proto().total_peak_bytes()));
  }
  // kShown[12]: residual bytes.
  if (opts.select.find(kShown[12]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
                                     node->proto().total_residual_bytes()));
  }
  // kShown[13]: output bytes.
  if (opts.select.find(kShown[13]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
                                     node->proto().total_output_bytes()));
  }

  // Timing attributes are delegated to the shared FormatTimes helper.
  std::vector<string> time_attrs = FormatTimes(node, opts);
  attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());

  // kShown[2]: parameter counts, self/total.
  if (opts.select.find(kShown[2]) != opts.select.end()) {
    string params = FormatNumber(node->proto().total_parameters()) + " params";
    if (node->account) {
      params = FormatNumber(node->proto().parameters()) + "/" + params;
    } else {
      params = "--/" + params;
    }
    attrs.push_back(params);
  }

  // kShown[3]: float operation counts, self/total.
  if (opts.select.find(kShown[3]) != opts.select.end()) {
    string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
    if (node->account) {
      fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
    } else {
      fops = "--/" + fops;
    }
    attrs.push_back(fops);
  }

  // kShown[5]: device placements ('|'-joined), only when non-empty.
  if (opts.select.find(kShown[5]) != opts.select.end() &&
      !node->node->devices().empty()) {
    attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
  }
  // kShown[6]: op types ('|'-joined).
  if (opts.select.find(kShown[6]) != opts.select.end()) {
    std::set<string> op_types = node->node->op_types();
    attrs.push_back(absl::StrJoin(op_types, "|"));
  }
  // kShown[7]/kShown[8] have no meaning in code view; say so explicitly.
  if (opts.select.find(kShown[7]) != opts.select.end()) {
    // TODO(xpan): Make op count available in code view?
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
  }
  if (opts.select.find(kShown[8]) != opts.select.end()) {
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
  }

  return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
                         absl::StrJoin(attrs, ", "));
}
705 } // namespace tfprof
706 } // namespace tensorflow
707