// torch/csrc/profiler/util.cpp
// (AOSP xref revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/util.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#include <fmt/ranges.h>

#include <algorithm>
#include <iterator>
#include <sstream>

#ifdef USE_KINETO
#include <libkineto.h>
#endif
#ifdef USE_DISTRIBUTED
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_DISTRIBUTED

namespace torch::profiler::impl {

namespace {
std::optional<bool> soft_assert_raises_;
} // namespace

void setSoftAssertRaises(std::optional<bool> value) {
  soft_assert_raises_ = value;
}

bool softAssertRaises() {
  return soft_assert_raises_.value_or(false);
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const char* args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

void logSoftAssert(
    const char* func,
    const char* file,
    uint32_t line,
    const char* cond,
    const std::string& args) {
#ifdef USE_KINETO
  std::string error = fmt::format(
      "{} SOFT ASSERT FAILED at {}:{}, func: {}, args: {}",
      cond,
      file,
      line,
      func,
      args);
  // TODO: Implement profile_id and group_profile_id as 3rd/4th arguments.
  kineto::logInvariantViolation(cond, error, "", "");
#endif
}

// ----------------------------------------------------------------------------
// -- NVTX --------------------------------------------------------------------
// ----------------------------------------------------------------------------
std::string getNvtxStr(
    const char* name,
    int64_t sequence_nr,
    const std::vector<std::vector<int64_t>>& shapes,
    at::RecordFunctionHandle op_id,
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  if (sequence_nr >= -1 || !shapes.empty()) {
    std::string str;
    if (sequence_nr >= 0) {
      str = fmt::format("{}, seq = {}", name, sequence_nr);
    } else if (sequence_nr == -1) {
      str = name;
    } else {
#if defined(USE_ROCM)
      // Only ROCm supports sequence_nr < -1.
      str = name;
#endif
    }
    if (op_id > 0) {
      str = fmt::format("{}, op_id = {}", str, op_id);
    }
    if (!shapes.empty()) {
      str = fmt::format("{}, sizes = {}", str, shapesToStr(shapes));
    }
    // Include the op ids of the input edges so the network graph can be
    // reconstructed from the trace.
    if (!input_op_ids.empty()) {
      str = fmt::format(
          "{}, input_op_ids = {}", str, inputOpIdsToStr(input_op_ids));
    }
    return str;
  } else {
    return name;
  }
}

// ----------------------------------------------------------------------------
// -- Op context (shapes, call stack) -----------------------------------------
// ----------------------------------------------------------------------------
std::vector<FileLineFunc> prepareCallstack(
    const std::vector<jit::StackEntry>& cs) {
  std::vector<FileLineFunc> entries;
  entries.reserve(cs.size());
  for (const auto& entry : cs) {
    auto& range = entry.range;
    if (range.source()) {
      auto& src = range.source();
      if (src && src->filename()) {
        auto line =
            src->starting_line_no() + src->lineno_for_offset(range.start());
        entries.emplace_back(
            // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
            FileLineFunc{*(src->filename()), line, entry.filename});
      }
    }
  }
  return entries;
}

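// Renders each FileLineFunc frame as "filename(line): funcname".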
std::vector<std::string> callstackStr(const std::vector<FileLineFunc>& cs) {
  std::vector<std::string> cs_str;
  cs_str.reserve(cs.size());
  for (const auto& entry : cs) {
    std::stringstream loc;
    loc << entry.filename << "(" << entry.line << "): " << entry.funcname;
    cs_str.push_back(loc.str());
  }
  return cs_str;
}

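// Joins the frames with `delim` and wraps the result in double quotes. Note
// that std::ostream_iterator appends the delimiter after every element,
// including the last: stacksToStr({"a", "b"}, ";") yields "\"a;b;\"".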
std::string stacksToStr(
    const std::vector<std::string>& stacks,
    const char* delim) {
  std::ostringstream oss;
  std::transform(
      stacks.begin(),
      stacks.end(),
      std::ostream_iterator<std::string>(oss, delim),
      [](std::string s) -> std::string {
#ifdef _WIN32
        // Replace the Windows backslash with a forward slash.
        std::replace(s.begin(), s.end(), '\\', '/');
#endif
        return s;
      });
  auto rc = oss.str();
  return "\"" + rc + "\"";
}

static std::vector<std::vector<int64_t>> flattenList(
    const c10::List<c10::IValue>& list) {
  std::vector<std::vector<int64_t>> tensor_dims;
  for (const c10::IValue& input : list) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        tensor_dims.push_back(tensor.sizes().vec());
      }
    }
  }
  return tensor_dims;
}

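// Collects one shape per input of the RecordFunction. Undefined tensors and
// non-tensor inputs contribute an empty entry; when flatten_list_enabled is
// set, tensor lists are expanded into one shape per contained tensor.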
std::vector<std::vector<int64_t>> inputSizes(
    const at::RecordFunction& fn,
    bool flatten_list_enabled) {
  std::vector<std::vector<int64_t>> sizes;
  sizes.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        sizes.push_back(tensor.sizes().vec());
      } else {
        sizes.emplace_back();
      }
    } else if (input.isList()) {
      std::vector<std::vector<int64_t>> tmp_sizes;
      if (flatten_list_enabled) {
        tmp_sizes = flattenList(input.toList());
      }
      // Extend sizes with the shapes of the flattened list, if any.
      if (!tmp_sizes.empty()) {
        sizes.insert(sizes.end(), tmp_sizes.begin(), tmp_sizes.end());
      } else {
        sizes.emplace_back();
      }
    } else {
      sizes.emplace_back();
    }
  }
  return sizes;
}

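// e.g. shapesToStr({{2, 3}, {}}) == "[[2, 3], []]"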
std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
  std::string str("[");
  for (const auto t_idx : c10::irange(shapes.size())) {
    if (t_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}{}", str, shapeToStr(shapes[t_idx]));
  }
  str = fmt::format("{}]", str);
  return str;
}

std::string variantShapesToStr(const std::vector<shape>& shapes) {
  std::string str("[");
  for (const auto t_idx : c10::irange(shapes.size())) {
    if (t_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    if (std::holds_alternative<std::vector<int64_t>>(shapes[t_idx])) {
      const auto& shape = std::get<std::vector<int64_t>>(shapes[t_idx]);
      str = fmt::format("{}{}", str, shapeToStr(shape));
    } else if (std::holds_alternative<std::vector<std::vector<int64_t>>>(
                   shapes[t_idx])) {
      const auto& tensor_shape =
          std::get<std::vector<std::vector<int64_t>>>(shapes[t_idx]);
      if (tensor_shape.size() > TENSOR_LIST_DISPLAY_LENGTH_LIMIT) {
        // Skip if the tensor list is too long.
        str = fmt::format("{}[]", str);
        continue;
      }
      str = fmt::format("{}[", str);
      for (const auto s_idx : c10::irange(tensor_shape.size())) {
        if (s_idx > 0) {
          str = fmt::format("{}, ", str);
        }
        str = fmt::format("{}{}", str, shapeToStr(tensor_shape[s_idx]));
      }
      str = fmt::format("{}]", str);
    }
  }
  str = fmt::format("{}]", str);
  return str;
}

std::string shapeToStr(const std::vector<int64_t>& shape) {
  std::string str("[");
  for (const auto s_idx : c10::irange(shape.size())) {
    if (s_idx > 0) {
      str = fmt::format("{}, ", str);
    }
    str = fmt::format("{}{}", str, shape[s_idx]);
  }
  str = fmt::format("{}]", str);
  return str;
}

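// Renders each (op id, output index) pair, e.g. "[(12,0), (13,1)]".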
std::string inputOpIdsToStr(
    const std::list<std::pair<at::RecordFunctionHandle, int>>& input_op_ids) {
  std::string str("[");
  int idx = 0;

  for (const auto& op_id_info_pair : input_op_ids) {
    if (idx++ > 0) {
      str = fmt::format("{}, ", str);
    }
    // (OpId,OutputNr)
    str = fmt::format(
        "{}({},{})", str, op_id_info_pair.first, op_id_info_pair.second);
  }
  str = fmt::format("{}]", str);
  return str;
}

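// e.g. strListToStr({"float", "int"}) == "[\"float\", \"int\"]"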
std::string strListToStr(const std::vector<std::string>& types) {
  if (types.empty()) {
    return "[]";
  } else {
    std::ostringstream oss;
    std::transform(
        types.begin(),
        types.end(),
        std::ostream_iterator<std::string>(oss, ", "),
        [](const std::string& s) -> std::string { return "\"" + s + "\""; });
    auto rc = oss.str();
    rc.erase(rc.length() - 2); // remove the trailing ", "
    return "[" + rc + "]";
  }
}
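
// Stringifies a single IValue, optionally quoting it. Values whose printed
// form contains extra double quotes would corrupt the Chrome trace JSON, so
// they are replaced by "None".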
std::string ivalueToStr(const c10::IValue& val, bool isString) {
  std::stringstream ss;
  if (val.isNone()) {
    return "\"None\"";
  } else {
    ss.str("");
    if (isString) {
      ss << "\"";
    }
    ss << val;
    if (isString) {
      ss << "\"";
    }
    std::string mystr = ss.str();

    // A double quote can break Chrome tracing's JSON parsing, so reject any
    // value that contains more than the two quotes this function adds.
    auto count = std::count(mystr.begin(), mystr.end(), '\"');
    return count > 2 ? "\"None\"" : mystr;
  }
}

std::string ivalueListToStr(const std::vector<c10::IValue>& list) {
  std::vector<std::string> concrete_str_inputs;
  std::stringstream ss;
  for (const auto& val : list) {
    if (val.isNone()) {
      concrete_str_inputs.emplace_back("");
    } else {
      ss.str("");
      ss << val;
      concrete_str_inputs.emplace_back(ss.str());
    }
  }
  return strListToStr(concrete_str_inputs);
}

std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
  std::vector<std::string> types;
  types.reserve(fn.inputs().size());
  for (const c10::IValue& input : fn.inputs()) {
    if (input.isTensor()) {
      const at::Tensor& tensor = input.toTensor();
      if (tensor.defined()) {
        types.push_back(static_cast<std::string>(tensor.dtype().name()));
      } else {
        types.emplace_back();
      }
    } else if (input.isScalar() || input.isList()) {
      types.push_back(input.tagKind());
    } else {
      types.emplace_back();
    }
  }
  return types;
}

// ----------------------------------------------------------------------------
// -- NCCL Metadata -----------------------------------------------------------
// ----------------------------------------------------------------------------

static constexpr int32_t kTruncateLength = 30;

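// Formats a list-like value as a quoted, comma-separated string, e.g.
// format_list(std::vector<int>{1, 2, 3}, false) == "\"[1, 2, 3]\"". With
// truncate set, lists longer than kTruncateLength are cut to the first
// kTruncateLength elements followed by ", ...".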
template <typename ListLikeType>
inline std::string format_list(ListLikeType list, bool truncate) {
  if (truncate && list.size() > kTruncateLength) {
    return fmt::format(
        "\"[{}, ...]\"",
        fmt::join(list.begin(), list.begin() + kTruncateLength, ", "));
  }
  return fmt::format("\"[{}]\"", fmt::join(list.begin(), list.end(), ", "));
}

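// Collects collective-communication metadata (collective name, message sizes,
// ranks, process group info) from the thread-local ParamCommsDebugInfo. The
// map keys are the kComms*/kGroup*/kRank* constants declared in the profiler
// headers; string values arrive pre-quoted for JSON-style emission.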
std::unordered_map<std::string, std::string> saveNcclMeta(
    const at::RecordFunction& fn,
    bool truncate) {
  std::unordered_map<std::string, std::string> map;
#ifdef USE_DISTRIBUTED
  auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
  if (debugInfo == nullptr) {
    LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
                 << fn.name();
    return map;
  }

  auto& collective_name = debugInfo->getCollectiveName();
  map.emplace(kCommsName, fmt::format("\"{}\"", collective_name));
  map.emplace(
      kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
  map.emplace(kInMsgNelems, std::to_string(debugInfo->getInMessageNelems()));
  map.emplace(kOutMsgNelems, std::to_string(debugInfo->getOutMessageNelems()));

  auto& inSplitSizes = debugInfo->getInputSplitSizes();
  map.emplace(kInSplit, format_list(inSplitSizes, truncate));

  auto& outSplitSizes = debugInfo->getOutputSplitSizes();
  map.emplace(kOutSplit, format_list(outSplitSizes, truncate));

  auto globalRankStart = debugInfo->getGlobalRankStart();
  if (globalRankStart >= 0) {
    map.emplace(kGlobalRankStart, std::to_string(globalRankStart));
  }
  auto globalRankStride = debugInfo->getGlobalRankStride();
  if (globalRankStride > 0) {
    map.emplace(kGlobalRankStride, std::to_string(globalRankStride));
  }
  map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
  auto& group_name = debugInfo->getProcessGroupName();
  if (!group_name.empty()) {
    map.emplace(kProcessGroupName, fmt::format("\"{}\"", group_name));
  }
  auto& group_desc = debugInfo->getProcessGroupDesc();
  if (!group_desc.empty()) {
    map.emplace(kProcessGroupDesc, fmt::format("\"{}\"", group_desc));
  }
  auto& groupRanks = debugInfo->getGroupRanks();
  map.emplace(kGroupRanks, format_list(groupRanks, truncate));

  auto rank = debugInfo->getRank();
  map.emplace(kRank, std::to_string(rank));
  int nRanks = static_cast<int>(groupRanks.size());
  if (collective_name == "send") {
    if (rank >= 0 && rank < nRanks) {
      map.emplace(kP2pDst, std::to_string(groupRanks[rank]));
    }
  } else if (collective_name == "recv") {
    if (rank >= 0 && rank < nRanks) {
      map.emplace(kP2pSrc, std::to_string(groupRanks[rank]));
    }
  }
#endif // USE_DISTRIBUTED
  return map;
}

// ----------------------------------------------------------------------------
// -- FLOPS -------------------------------------------------------------------
// ----------------------------------------------------------------------------
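// Positional indices into the argument list of
// aten::conv2d(input, weight, bias, stride, padding, dilation, groups).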
static constexpr auto kConv2dStride = 3;
static constexpr auto kConv2dPadding = 4;
static constexpr auto kConv2dDilation = 5;
static constexpr auto kConv2dGroups = 6;

// List of supported operators
static constexpr auto kConv2dOp = "aten::conv2d";
static constexpr auto kMMOp = "aten::mm";
static constexpr auto kAddMMOp = "aten::addmm";
static constexpr auto kMulOp = "aten::mul";
static constexpr auto kAddOp = "aten::add";
static constexpr auto kBMMOp = "aten::bmm";
static constexpr auto kBAddBMMOp = "aten::baddbmm";

static constexpr auto kInputSize = "input_size";
static constexpr auto kWeightSize = "weight_size";
static constexpr auto kGroups = "groups";
static constexpr auto kPadding = "padding";
static constexpr auto kStride = "stride";
static constexpr auto kDilation = "dilation";
static constexpr auto kMatSize = "mat_size";
static constexpr auto kMat1Size = "mat1_size";
static constexpr auto kMat2Size = "mat2_size";

static std::vector<c10::IntArrayRef> getInputSizes(
    const std::string& op_name,
    size_t min_size,
    c10::ArrayRef<const c10::IValue> inputs,
    const c10::ArrayRef<int>& should_be_tensor) {
  std::stringstream ss;
  if (inputs.size() < min_size) {
    ss << "Failed to save extra arguments for flops computation of op "
       << op_name << ", min size: " << min_size
       << ", actual size: " << inputs.size();
    TORCH_WARN(ss.str());
    return {};
  }
  std::vector<c10::IntArrayRef> inputSizes = {};
  for (auto index : should_be_tensor) {
    if (!inputs[index].isTensor()) {
      ss << "Failed to save extra arguments for flops computation of op "
         << op_name << ", input[" << index << "] must be a tensor.";
      TORCH_WARN(ss.str());
      return {};
    }
    at::Tensor t = inputs[index].toTensor();
    if (t.is_nested()) {
      ss << "Failed to save extra arguments for flops computation of op "
         << op_name << " with input[" << index << "] as nested tensor.";
      TORCH_WARN(ss.str());
      return {};
    }
    inputSizes.emplace_back(t.sizes());
  }
  return inputSizes;
}

std::unordered_map<std::string, c10::IValue> saveExtraArgs(
    const at::RecordFunction& fn) {
  // For supported ops, save the extra arguments needed to compute FLOPs.
  std::unordered_map<std::string, c10::IValue> map;
  auto inputs = fn.inputs();
  std::string fname(fn.name());

  if (inputs.empty()) {
    // Input shape is unavailable, return an empty map.
    return map;
  }

  if (fname == kConv2dOp) {
    const auto inputSizes =
        getInputSizes(fname, kConv2dGroups + 1, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }
    if (inputSizes[1].size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires a 4D kernel tensor.");
      return map;
    }
    map[kInputSize] = at::IValue(inputSizes[0]);
    map[kWeightSize] = at::IValue(inputSizes[1]);
    map[kStride] = inputs[kConv2dStride];
    map[kPadding] = inputs[kConv2dPadding];
    map[kDilation] = inputs[kConv2dDilation];
    map[kGroups] = inputs[kConv2dGroups];
  } else if (fname == kMMOp) {
    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }

    map[kMat1Size] = at::IValue(inputSizes[0]);
    map[kMat2Size] = at::IValue(inputSizes[1]);
  } else if (fname == kAddMMOp) {
    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
    if (inputSizes.empty()) {
      return map;
    }
    // The exact FLOP count depends on the scaling factors alpha and beta;
    // assume alpha = beta = 1 (similar to
    // http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM).
    map[kMat1Size] = at::IValue(inputSizes[1]);
    map[kMat2Size] = at::IValue(inputSizes[2]);
  } else if (fname == kMulOp) {
    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
    if (inputSizes.empty()) {
      return map;
    }
    map[kMatSize] = at::IValue(inputSizes[0]);
  } else if (fname == kAddOp) {
    const auto inputSizes = getInputSizes(fname, 1, inputs, {0});
    if (inputSizes.empty()) {
      return map;
    }
    map[kMatSize] = at::IValue(inputSizes[0]);
  } else if (fname == kBMMOp) {
    const auto inputSizes = getInputSizes(fname, 2, inputs, {0, 1});
    if (inputSizes.empty()) {
      return map;
    }

    map[kMat1Size] = at::IValue(inputSizes[0]);
    map[kMat2Size] = at::IValue(inputSizes[1]);
  } else if (fname == kBAddBMMOp) {
    const auto inputSizes = getInputSizes(fname, 3, inputs, {0, 1, 2});
    if (inputSizes.empty()) {
      return map;
    }

    // The exact FLOP count depends on the scaling factors alpha and beta;
    // assume alpha = beta = 1 (similar to
    // http://www.netlib.org/lapack/lawnspdf/lawn41.pdf,
    // "Operations Count for the BLAS and LAPACK", Table 3, SGEMM).
    map[kMat1Size] = at::IValue(inputSizes[1]);
    map[kMat2Size] = at::IValue(inputSizes[2]);
  }

  return map;
}

uint64_t computeFlops(
    const std::string& op_name,
    const std::unordered_map<std::string, c10::IValue>& extra_args) {
  if (op_name == kConv2dOp) {
    if (extra_args.find(kInputSize) == extra_args.end() ||
        extra_args.find(kWeightSize) == extra_args.end() ||
        extra_args.find(kGroups) == extra_args.end() ||
        extra_args.find(kPadding) == extra_args.end() ||
        extra_args.find(kStride) == extra_args.end() ||
        extra_args.find(kDilation) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::conv2d requires groups, padding, stride, dilation, input_size, and weight_size in saved arguments.");
      return 0;
    }
    auto input_sizes_ref = extra_args.at(kInputSize);
    auto kernel_sizes_ref = extra_args.at(kWeightSize);
    auto groups_ref = extra_args.at(kGroups);
    auto padding_ref = extra_args.at(kPadding);
    auto stride_ref = extra_args.at(kStride);
    auto dilation_ref = extra_args.at(kDilation);
    if (!input_sizes_ref.isIntList() || !kernel_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires input and weight tensor sizes.");
      return 0;
    }
    if (!padding_ref.isIntList() || !stride_ref.isIntList() ||
        !dilation_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because it requires padding, stride, and dilation values.");
      return 0;
    }

    const auto input_sizes = input_sizes_ref.toDimVector();
    const auto kernel_sizes = kernel_sizes_ref.toDimVector();
    const uint64_t groups = groups_ref.toInt();
    const std::vector<int64_t> padding = padding_ref.toIntVector();
    const std::vector<int64_t> stride = stride_ref.toIntVector();
    const std::vector<int64_t> dilation = dilation_ref.toIntVector();
    if (input_sizes.size() != 4 || kernel_sizes.size() != 4) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both input and weight must be size 4.");
      return 0;
    }
    if (!groups) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because group size must not be 0.");
      return 0;
    }
    if (padding.size() != 2 || dilation.size() != 2) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because both padding and dilation must be size 2.");
      return 0;
    }
    if (stride.size() != 2 || (stride[0] * stride[1] == 0)) {
      TORCH_WARN(
          "Failed to compute flops for op aten::conv2d because stride must be size 2 and cannot be 0.");
      return 0;
    }
    // The input format is defined in
    // torch.ao.nn.quantized.functional.conv2d().
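    // Worked example (illustrative): input [1, 3, 224, 224], weight
    // [64, 3, 7, 7], stride [2, 2], padding [3, 3], dilation [1, 1],
    // groups 1 gives output_h = output_w = (224 + 6 - 6 - 1) / 2 + 1 = 112,
    // so flops = 2 * 1 * 112 * 112 * 7 * 7 * 3 * 64 / 1 = 236,027,904.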
    const uint64_t conv2d_multiply_factor = 2;
    auto [minibatch, in_channels, input_h, input_w] = std::make_tuple(
        input_sizes[0], input_sizes[1], input_sizes[2], input_sizes[3]);
    auto [out_channels, _, kernel_h, kernel_w] = std::make_tuple(
        kernel_sizes[0], kernel_sizes[1], kernel_sizes[2], kernel_sizes[3]);
    uint64_t output_h =
        (input_h + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) /
            stride[0] +
        1;
    uint64_t output_w =
        (input_w + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) /
            stride[1] +
        1;

    return conv2d_multiply_factor * minibatch * output_h * output_w * kernel_h *
        kernel_w * in_channels * out_channels / groups;
  } else if (op_name == kMMOp || op_name == kAddMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

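    // For [M, K] x [K, N] this evaluates to 2 * M * K * N; e.g.
    // [64, 128] x [128, 256] gives 2 * 64 * 128 * 256 = 4,194,304 FLOPs.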
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kBMMOp || op_name == kBAddBMMOp) {
    if (extra_args.find(kMat1Size) == extra_args.end() ||
        extra_args.find(kMat2Size) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for ",
          op_name,
          " requires mat1_size and mat2_size in saved arguments.");
      return 0;
    }
    auto mat1_sizes_ref = extra_args.at(kMat1Size);
    auto mat2_sizes_ref = extra_args.at(kMat2Size);
    if (!mat1_sizes_ref.isIntList() || !mat2_sizes_ref.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op ",
          op_name,
          " because it requires mat1_size and mat2_size to be IntList.");
      return 0;
    }

    const auto mat1_size = mat1_sizes_ref.toDimVector();
    const auto mat2_size = mat2_sizes_ref.toDimVector();
    if (mat1_size.empty()) {
      return 0;
    }

    int64_t batch_size = mat1_size.front();
    if (batch_size == 0) {
      return 0;
    }

    int64_t overlap_dim = mat1_size.back();
    if (overlap_dim == 0) {
      return 0;
    }

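    // For batched [B, M, K] x [B, K, N] this evaluates to 2 * B * M * K * N.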
    const uint64_t gemm_multiply_factor = 2;
    uint64_t flops = 1;
    for (int64_t dim : mat1_size) {
      flops *= dim;
    }
    flops /= overlap_dim;
    flops /= batch_size;
    for (int64_t dim : mat2_size) {
      flops *= dim;
    }
    flops *= gemm_multiply_factor;
    return flops;
  } else if (op_name == kMulOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::mul.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::mul because it requires mat_size to be IntList.");
      return 0;
    }

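    // Element-wise multiply: one FLOP per element, i.e. flops = numel(mat).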
    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  } else if (op_name == kAddOp) {
    if (extra_args.find(kMatSize) == extra_args.end()) {
      TORCH_WARN(
          "Calculating flops for aten::add.Tensor requires mat_size in saved arguments.");
      return 0;
    }
    auto mat_sizes = extra_args.at(kMatSize);
    if (!mat_sizes.isIntList()) {
      TORCH_WARN(
          "Failed to compute flops for op aten::add because it requires mat_size to be IntList.");
      return 0;
    }

    const auto mat_size = mat_sizes.toDimVector();
    uint64_t flops = 1;
    for (int64_t dim : mat_size) {
      flops *= dim;
    }
    return flops;
  }
  return 0;
}
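
// Minimal usage sketch (illustrative; assumes a RecordFunction profiling
// callback, not an API exported by this file):
//   auto extra_args = saveExtraArgs(fn);           // while inputs are alive
//   uint64_t flops = computeFlops(fn.name(), extra_args);
// computeFlops() returns 0 for unsupported ops or missing/invalid arguments.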

} // namespace torch::profiler::impl