/* Copyright 2016 The TensorFlow Authors All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/internal/tfprof_code.h"

#include <stdio.h>

#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/profiler/internal/tfprof_constants.h"

namespace tensorflow {
namespace tfprof {
namespace {

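// Appended (in TFCode::Build) to the trace names of gradient ops so that
// their cost shows up as separate "(gradient)" entries in the code view.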
const char* const kGradientSuffix = " (gradient)";

// Converts a Trace proto into a short, readable string.
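// For example, a trace at line 42 of .../my_model.py inside build_graph()
// becomes "my_model.py:42:build_graph"; function names of 20 characters or
// more are truncated to their first 17 characters plus "...".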
std::string GetTraceString(const CallStack::Trace& trace) {
  std::string ntrace =
      absl::StrCat(io::Basename(trace.file()), ":", trace.lineno());
  if (trace.function().length() < 20) {
    absl::StrAppend(&ntrace, ":", trace.function());
  } else {
    absl::StrAppend(&ntrace, ":", trace.function().substr(0, 17), "...");
  }
  return ntrace;
}

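// For example, the gradient node "gradients/inp_1/Conv2D_grad/..." maps back
// to the forward op "inp_1/Conv2D".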
bool IsGradNode(const string& name, string* forward_name) {
  // Given a forward operation named "op", its gradient op has the following
  // name: ...gradients/op_grad/...
  // TODO(xpan): This is hacky.
  auto grad_prefix = name.find("gradients/");
  auto grad_suffix = name.find("_grad/");
  if (grad_prefix == name.npos || grad_suffix == name.npos) {
    return false;
  }
  auto start = grad_prefix + string("gradients/").length();
  if (grad_suffix <= start) {
    // "_grad/" does not follow "gradients/"; the unsigned subtraction below
    // would otherwise wrap around.
    return false;
  }
  auto len = grad_suffix - start;
  *forward_name = name.substr(start, len);
  return true;
}

// StringTable maps each string to an id.
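// The pprof Profile proto stores each string once in its string_table and
// refers to it elsewhere by index; this class assigns those indices.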
class StringTable {
 public:
  StringTable() {
    // Pprof requires the first entry in string_table to be ''.
    string_id_[""] = 0;
    all_strings_.push_back("");
  }

  // Returns the index of a string. If not found, inserts the string and
  // returns the inserted index.
  uint64 GetIndex(const string& str) {
    auto idx = string_id_.find(str);
    if (idx != string_id_.end()) {
      return idx->second;
    }
    all_strings_.push_back(str);
    return string_id_.insert(std::pair<string, int64_t>(str, string_id_.size()))
        .first->second;
  }

  const std::vector<string>& strings() const { return all_strings_; }

 private:
  std::map<string, uint64> string_id_;
  std::vector<string> all_strings_;
};

// FunctionTable maps each function to an id.
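// Each distinct (file path, function name, start line) tuple gets its own
// pprof::Function proto; ids start at 1 because id 0 is not a valid pprof
// function id.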
class FunctionTable {
 public:
  explicit FunctionTable(StringTable* string_table)
      : string_table_(string_table) {}

  // Returns the index of a function. If not found, adds a function proto
  // and returns the function index.
  uint64 GetIndex(const string& file_path, const string& func_name,
                  uint64 func_start_line) {
    auto key = std::tuple<string, string, uint64>(file_path, func_name,
                                                  func_start_line);
    auto idx = function_table_.find(key);
    if (idx != function_table_.end()) {
      return idx->second.id();
    }
    pprof::Function* func_pb = &function_table_[key];
    // Function ids start from 1.
    func_pb->set_id(function_table_.size());

    string file_base(io::Basename(file_path));
    file_base = file_base.substr(0, file_base.find_last_of('.'));
    func_pb->set_name(
        string_table_->GetIndex(absl::StrCat(file_base, ":", func_name)));
    func_pb->set_filename(string_table_->GetIndex(file_path));
    func_pb->set_start_line(func_start_line);
    return func_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Function>&
  functions() const {
    return function_table_;
  }

 private:
  StringTable* string_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
};

// LocationTable maps each function call to an id.
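// A location represents a call site, keyed by (caller file, called function,
// caller line number); each location holds a single Line entry that points at
// the called function in the FunctionTable.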
class LocationTable {
 public:
  explicit LocationTable(FunctionTable* function_table)
      : function_table_(function_table) {}

  // Returns the index of a function call location. If not found, adds a
  // location proto and returns the location index.
  uint64 GetIndex(const string& file_path, uint64 line_number,
                  const string& called_function_name,
                  const string& called_file_path,
                  uint64 called_func_start_line) {
    auto key = std::tuple<string, string, uint64>(
        file_path, called_function_name, line_number);

    auto idx = location_table_.find(key);
    if (idx != location_table_.end()) {
      return idx->second.id();
    }
    pprof::Location* location_pb = &location_table_[key];
    location_pb->set_id(location_table_.size());
    pprof::Line* line_pb = location_pb->add_line();
    line_pb->set_function_id(function_table_->GetIndex(
        called_file_path, called_function_name, called_func_start_line));
    line_pb->set_line(line_number);
    return location_pb->id();
  }

  const std::map<std::tuple<string, string, uint64>, pprof::Location>&
  locations() const {
    return location_table_;
  }

 private:
  FunctionTable* function_table_;
  std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
};

// Samples stores samples of all calls. A sample is a single call trace,
// that is, the call path from top caller to the leaf callee.
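// Each sample records the pprof location ids of its call stack plus two
// values: a count of 1 and the metric selected with -select.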
class Samples {
 public:
  explicit Samples(StringTable* string_table, const Options* opts)
      : string_table_(string_table), opts_(opts) {}

  // 'node' is the leaf of the displayed trace and includes all graph nodes
  // created by it. 'location_ids' contains the call stack, ordered from
  // callee to caller.
  // This method adds the statistics of the graph nodes created by the
  // Python call.
  void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    // The displayed leaf might not be a true leaf. Retrieve the true leaves
    // for the stats.
    std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    CHECK(!all_leaf.empty()) << node->name();

    for (const CodeNode* cn : all_leaf) {
      for (const auto& gn_it : cn->node->graph_nodes()) {
        const TFGraphNode* gn = gn_it.second;
        string name = gn->name();
        // Generate a new trace name, in case the name is taken.
        while (sample_table_.find(name) != sample_table_.end()) {
          name += '@';
        }
        pprof::Sample* sample_pb = &sample_table_[name];
        for (uint64 id : location_ids) {
          sample_pb->mutable_location_id()->Add(id);
        }
        pprof::Label* label_pb = sample_pb->mutable_label()->Add();
        label_pb->set_key(string_table_->GetIndex("graph node:"));
        label_pb->set_str(string_table_->GetIndex(gn->name()));

        sample_pb->mutable_value()->Add(1);
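        // The second value is the metric chosen via -select; these branches
        // mirror the sample types configured in PprofProfileImpl::Build().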
        string type = *opts_->select.begin();
        if (type == kShown[1]) {
          sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
        } else if (type == kShown[9]) {
          sample_pb->mutable_value()->Add(
              gn->accelerator_exec_micros(node->node->step()));
        } else if (type == kShown[10]) {
          sample_pb->mutable_value()->Add(
              gn->cpu_exec_micros(node->node->step()));
        } else if (type == kShown[0]) {
          sample_pb->mutable_value()->Add(
              gn->requested_bytes(node->node->step()));
        } else if (type == kShown[11]) {
          sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
        } else if (type == kShown[12]) {
          sample_pb->mutable_value()->Add(
              gn->residual_bytes(node->node->step()));
        } else if (type == kShown[13]) {
          sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
        } else if (type == kShown[2]) {
          sample_pb->mutable_value()->Add(gn->parameters());
        } else if (type == kShown[3]) {
          sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
        } else {
          absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", type);
        }
      }
    }
  }

  const std::map<string, pprof::Sample>& samples() const {
    return sample_table_;
  }

 private:
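  // Recursively collects the leaf CodeNodes under 'root' (nodes with no
  // children); if 'root' itself has no children, returns {root}.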
  std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    if (root->children.empty()) {
      return {root};
    }
    std::vector<const CodeNode*> ret;
    for (auto& n : root->children) {
      std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
      ret.insert(ret.end(), nodes.begin(), nodes.end());
    }
    return ret;
  }

  StringTable* string_table_;
  const Options* opts_;
  std::map<string, pprof::Sample> sample_table_;
};

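// Builds a pprof Profile proto for the code view and writes it out as a
// gzip-compressed file: AddLocation registers call sites, AddSample records
// the statistics of displayed leaves, and WritePprofProfile serializes the
// result.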
class PprofProfileImpl : public PprofProfile {
 public:
  explicit PprofProfileImpl(const Options* opts)
      : opts_(opts),
        func_table_(new FunctionTable(&string_table_)),
        loc_table_(new LocationTable(func_table_.get())),
        samples_(new Samples(&string_table_, opts)) {}

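  // Registers the call site at which 'caller' invokes 'callee' and returns
  // its location id.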
  uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
    const string& file_path = caller->file();
    uint64 lineno = caller->lineno();
    const string& callee_file_path = callee->file();
    const string& callee_function = callee->function();
    uint64 callee_func_start_line = callee->func_start_line();

    return loc_table_->GetIndex(file_path, lineno, callee_function,
                                callee_file_path, callee_func_start_line);
  }

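  // Records a sample for 'leaf'. 'call_ids' is ordered from the top caller
  // down to the leaf; it is reversed here because Samples::Add expects the
  // stack from callee to caller.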
  void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
    std::vector<uint64> reversed_call_ids;
    std::reverse_copy(call_ids->begin(), call_ids->end(),
                      std::back_inserter(reversed_call_ids));
    samples_->Add(leaf, reversed_call_ids);
  }

  Status WritePprofProfile(const string& filename) override {
    pprof::Profile profile_pb;
    Build(&profile_pb);

    std::unique_ptr<WritableFile> file;
    Status s = Env::Default()->NewWritableFile(filename, &file);
    if (!s.ok()) return s;

    int32_t buf_size = 1024 * 1024;
    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
    s = zlib_output_buffer->Init();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    s = zlib_output_buffer->Close();
    if (!s.ok()) {
      delete zlib_output_buffer;
      return s;
    }
    absl::FPrintF(stdout,
                  "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
                  filename);
    delete zlib_output_buffer;
    return s;
  }

 private:
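  // Fills in the Profile proto: two sample types (a plain count and the
  // -select metric with its unit and an explanatory comment), followed by the
  // accumulated string, sample, function and location tables.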
  void Build(pprof::Profile* profile_pb) {
    string sample_type_description = "count";
    auto sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    sample_type->set_unit(string_table_.GetIndex("count"));

    string type = *opts_->select.begin();
    sample_type_description = type;
    sample_type = profile_pb->mutable_sample_type()->Add();
    sample_type->set_type(string_table_.GetIndex(sample_type_description));
    if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
      sample_type->set_unit(string_table_.GetIndex("microseconds"));
      if (type == kShown[1]) {
        profile_pb->mutable_comment()->Add(string_table_.GetIndex(
            "Sum of accelerator execution time and cpu execution time."));
      } else if (type == kShown[9]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("Accelerator execution time."));
      } else if (type == kShown[10]) {
        profile_pb->mutable_comment()->Add(
            string_table_.GetIndex("CPU execution time."));
      }
    } else if (type == kShown[0]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation total memory requests, "
                                 "excluding deallocations."));
    } else if (type == kShown[11]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation peak memory usage."));
    } else if (type == kShown[12]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Sum of operation allocated memory after finish."));
    } else if (type == kShown[13]) {
      sample_type->set_unit(string_table_.GetIndex("bytes"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Sum of operation output size."));
    } else if (type == kShown[2]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(
          string_table_.GetIndex("Model parameters."));
    } else if (type == kShown[3]) {
      sample_type->set_unit(string_table_.GetIndex("count"));
      profile_pb->mutable_comment()->Add(string_table_.GetIndex(
          "Model float operations (Only available if defined)."));
    } else {
      absl::FPrintF(stderr, "pprof doesn't support selecting: %s\n", type);
    }

    for (const string& str : string_table_.strings()) {
      *profile_pb->mutable_string_table()->Add() = str;
    }
    for (const auto& sample_it : samples_->samples()) {
      // TODO(xpan): Consider swap.
      profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
    }
    for (const auto& function_it : func_table_->functions()) {
      profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
    }
    for (const auto& location_it : loc_table_->locations()) {
      profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
    }
  }

  const Options* opts_;
  StringTable string_table_;
  std::unique_ptr<FunctionTable> func_table_;
  std::unique_ptr<LocationTable> loc_table_;
  std::unique_ptr<Samples> samples_;
};
}  // namespace

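// Registers a graph node under the code view, keyed by its Python call stack.
// Gradient nodes are stashed in grad_nodes_ and attached to their forward
// op's call stack later, in TFCode::Build().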
void TFCode::AddNode(TFGraphNode* node) {
  if (!node->call_stack() || node->call_stack()->traces().empty()) {
    return;
  }
  // We infer the forward operation name from the gradient op name, so we can
  // map gradient op traces to forward op traces.
  // E.g. the gradient node of 'inp_1/Conv2D' would be
  // 'gradients/inp_1/Conv2D_grad'.
  string forward_name;
  if (IsGradNode(node->name(), &forward_name)) {
    auto grad_nodes_it = grad_nodes_.find(forward_name);
    if (grad_nodes_it != grad_nodes_.end()) {
      grad_nodes_it->second.push_back(node);
    } else {
      grad_nodes_.insert(
          std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
    }
    return;
  } else {
    forward_nodes_[node->name()] = node;
  }

  if (!root_) {
    graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
    root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
  }

  CodeNode* pre_code_node = root_.get();
  // TODO(xpan): Consider releasing CodeDef after TFCode is built. It
  // takes a lot of memory.
  std::set<string> traces;
  for (int i = 0, end = node->call_stack()->traces().size(); i < end; ++i) {
    // Unlike an op name, which is globally unique, a trace name is only
    // unique w.r.t. its parent.
    const string& trace = GetTraceString(node->call_stack()->traces().at(i));
    traces.insert(trace);
    pre_code_node = pre_code_node->AddChildren(
        trace, &node->call_stack()->traces().at(i), "");
    const int64_t last_index = node->call_stack()->traces().size() - 1;
    if (i == last_index) {
      pre_code_node->node->AddGraphNode(node);
    }
  }
}

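// Attaches the gradient nodes collected by AddNode to the call stack of their
// forward op, appending kGradientSuffix to each trace name so that gradients
// appear as separate entries in the code view.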
void TFCode::Build() {
  int64_t unaccounted_nodes = 0;
  for (const auto& it : grad_nodes_) {
    const string& forward_name = it.first;
    auto forward_it = forward_nodes_.find(forward_name);
    if (forward_it == forward_nodes_.end()) {
      unaccounted_nodes += 1;
      continue;
    }
    TFGraphNode* fn = forward_it->second;
    CodeNode* leaf = nullptr;
    CodeNode* pre_code_node = root_.get();
    for (int i = 0, end = fn->call_stack()->traces().size(); i < end; ++i) {
      const string& trace =
          GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
      pre_code_node = pre_code_node->AddChildren(
          trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
      const int64_t last_trace = fn->call_stack()->traces().size() - 1;
      if (i == last_trace) {
        leaf = pre_code_node;
      }
    }
    for (TFGraphNode* gn : it.second) {
      leaf->node->AddGraphNode(gn);
    }
  }
  if (unaccounted_nodes > 0) {
    absl::FPrintF(stderr, "%d gradient nodes not accounted\n",
                  unaccounted_nodes);
  }
}

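// Builds the tree of nodes to display for the given options: accounts and
// aggregates stats, applies start_name_regexes, formats the output, and for
// pprof output writes the profile (otherwise it can generate a code
// timeline).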
const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
                                          Timeline* timeline) {
  root_->ResetTotalStats();
  if (opts.output_type == kOutput[3]) {
    if (opts.select.size() != 1) {
      absl::FPrintF(stderr, "Can only select 1 attribute for pprof output.\n");
      return root_.get();
    }
    string select = *opts.select.begin();
    if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
        select != kShown[3] && select != kShown[9] && select != kShown[10] &&
        select != kShown[11] && select != kShown[12] && select != kShown[13]) {
      absl::FPrintF(stderr, "pprof doesn't support -select=%s\n", select);
      return root_.get();
    }
  }
  if (opts.account_displayed_op_only) {
    absl::FPrintF(stderr,
                  "Note: code view ignores account_displayed_op_only\n");
  }

  std::vector<CodeNode*> roots = Account(root_->children, opts);
  root_->show_children.clear();
  for (CodeNode* n : roots) {
    root_->AggregateTotalStats(n);
  }

  if (opts.start_name_regexes.size() != 1 ||
      opts.start_name_regexes[0] != ".*") {
    roots = SearchRoot(roots, opts.start_name_regexes);
  }

  root_->show_children.assign(roots.begin(), roots.end());

  CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];

  root->formatted_str = FormatLegend(opts) + root->formatted_str;

  if (opts.output_type == kOutput[3]) {
    std::vector<uint64> call_ids;
    pprof_profile_.reset(new PprofProfileImpl(&opts));
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), &call_ids);
    Status s = pprof_profile_->WritePprofProfile(
        opts.output_options.at(kPprofOpts[0]));
    if (!s.ok()) {
      absl::FPrintF(stderr, "%s\n", s.ToString());
    }
  } else {
    Format(root, root->show_children, opts, &root->formatted_str,
           root->mutable_proto(), nullptr);
    if (timeline) {
      timeline->GenerateCodeTimeline(root);
    }
  }
  return root;
}

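// Depth-first walk over the displayed tree that appends each node's formatted
// string and proto to its parent. For pprof output, 'call_ids' tracks the
// current stack of location ids and a sample is emitted at every displayed
// leaf that has a trace.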
void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
                    const Options& opts, string* display_str,
                    MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
  if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    pprof_profile_->AddSample(root, call_ids);
  }

  for (CodeNode* node : nodes) {
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      uint64 loc_id = pprof_profile_->AddLocation(node, root);
      call_ids->push_back(loc_id);
    }
    display_str->append(node->formatted_str);
    MultiGraphNodeProto* child = proto->add_children();
    child->MergeFrom(node->proto());
    Format(node, node->show_children, opts, display_str, child, call_ids);
    if (root->has_trace() && opts.output_type == kOutput[3]) {
      call_ids->pop_back();
    }
  }
}

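// Returns the topmost nodes whose names fully match one of 'regexes'; the
// search does not descend below a matching node.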
std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
                                          const std::vector<string>& regexes) {
  std::vector<CodeNode*> res;
  if (roots.empty()) {
    return res;
  }
  for (CodeNode* root : roots) {
    bool match_start_node = false;
    for (const string& regex : regexes) {
      if (RE2::FullMatch(root->name(), regex)) {
        res.push_back(root);
        match_start_node = true;
        break;
      }
    }
    if (match_start_node) {
      // Found a start node at this branch, no need to continue.
      continue;
    }
    std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
    res.insert(res.end(), nroots.begin(), nroots.end());
  }
  return res;
}

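// Recursively selects and formats the nodes to display. Trimmed nodes and
// nodes beyond max_depth are skipped; a shown node adopts its shown
// descendants (sorted) as children and is formatted at the current indent,
// while a hidden node passes its shown descendants up to the caller.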
std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
                                          const Options& opts, int depth,
                                          int last_ident) {
  std::vector<CodeNode*> show_nodes;

  for (CodeNode* node : roots) {
    if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
      continue;
    }
    int ident = last_ident;
    bool show = ShouldShow(node, opts, depth);
    if (show) ident += 2;

    std::vector<CodeNode*> show_cnodes =
        PrintScope(node->show_children, opts, depth + 1, ident);
    if (show) {
      node->show_children.clear();

      show_cnodes = SortNodes(show_cnodes, opts);
      for (CodeNode* sc : show_cnodes) {
        node->show_children.push_back(sc);
      }

      node->formatted_str = FormatNode(node, opts, last_ident);

      if (opts.select.find(kShown[4]) != opts.select.end()) {
        absl::FPrintF(stderr, "code view has no tensor value to show\n");
      }
      show_nodes.push_back(node);
    } else {
      show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
                        show_cnodes.end());
    }
  }
  return show_nodes;
}

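// Recursively decides which nodes are accounted (ReAccount). A node is kept
// if it is accounted itself or has accounted descendants; its total stats are
// re-aggregated from itself plus the kept children.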
std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
                                       const Options& opts) {
  std::vector<CodeNode*> act_nodes;

  for (CodeNode* node : roots) {
    node->ResetTotalStats();
    std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
    node->account = ReAccount(node, opts);
    if (node->account || !act_cnodes.empty()) {
      node->show_children.clear();
      node->ResetTotalStats();
      node->AddSelfToTotalStats();
      for (CodeNode* c : act_cnodes) {
        node->AggregateTotalStats(c);
        node->show_children.push_back(c);
      }
      act_nodes.push_back(node);
    }
  }
  return act_nodes;
}

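// Formats a memory attribute as "<self>/<total>"; "--" is used for the self
// part when the node is not accounted.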
string TFCode::FormatNodeMemory(CodeNode* node, int64_t bytes,
                                int64_t total_bytes) const {
  string memory = FormatMemory(total_bytes);
  if (node->account) {
    memory = FormatMemory(bytes) + "/" + memory;
  } else {
    memory = "--/" + memory;
  }
  return memory;
}

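// Renders a single line of the code view: the node name followed by the
// attributes selected via -select (memory, time, parameters, flops, devices,
// op types), indented by 'indent' spaces.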
string TFCode::FormatNode(CodeNode* node, const Options& opts,
                          int64_t indent) const {
  std::vector<string> attrs;
  if (opts.select.find(kShown[0]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
                                     node->proto().total_requested_bytes()));
  }
  if (opts.select.find(kShown[11]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
                                     node->proto().total_peak_bytes()));
  }
  if (opts.select.find(kShown[12]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
                                     node->proto().total_residual_bytes()));
  }
  if (opts.select.find(kShown[13]) != opts.select.end()) {
    attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
                                     node->proto().total_output_bytes()));
  }

  std::vector<string> time_attrs = FormatTimes(node, opts);
  attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());

  if (opts.select.find(kShown[2]) != opts.select.end()) {
    string params = FormatNumber(node->proto().total_parameters()) + " params";
    if (node->account) {
      params = FormatNumber(node->proto().parameters()) + "/" + params;
    } else {
      params = "--/" + params;
    }
    attrs.push_back(params);
  }

  if (opts.select.find(kShown[3]) != opts.select.end()) {
    string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
    if (node->account) {
      fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
    } else {
      fops = "--/" + fops;
    }
    attrs.push_back(fops);
  }

  if (opts.select.find(kShown[5]) != opts.select.end() &&
      !node->node->devices().empty()) {
    attrs.push_back(absl::StrJoin(node->node->devices(), "|"));
  }
  if (opts.select.find(kShown[6]) != opts.select.end()) {
    std::set<string> op_types = node->node->op_types();
    attrs.push_back(absl::StrJoin(op_types, "|"));
  }
  if (opts.select.find(kShown[7]) != opts.select.end()) {
    // TODO(xpan): Make op count available in code view?
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[7]));
  }
  if (opts.select.find(kShown[8]) != opts.select.end()) {
    attrs.push_back(absl::StrFormat("%s N/A in code view", kShown[8]));
  }

  return absl::StrFormat("%s%s (%s)\n", std::string(indent, ' '), node->name(),
                         absl::StrJoin(attrs, ", "));
}
}  // namespace tfprof
}  // namespace tensorflow