xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/ir_util.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <unordered_map>
4 #include <vector>
5 
6 #include <torch/csrc/lazy/core/ir.h>
7 
8 namespace torch {
9 namespace lazy {
10 
11 class TORCH_API Util {
12  public:
13   // Tracks the emission status of the nodes during the post-order generation.
14   // It helps tracking loops within the computation graphs.
15   enum EmitStatus {
16     kNotEmitted,
17     kEmitting,
18     kEmitted,
19   };
20 
21   using EmissionMap = std::unordered_map<const Node*, EmitStatus>;
22 
23   // Computes the post order from the given node, without using recursion. The
24   // emission map can be used as saved state, for multiple separate calls to
25   // this API. The returned post-order can be empty if the node has already been
26   // emitted inside the emission map. An error is generated if a loop is
27   // detected.
28   static std::vector<const Node*> ComputePostOrder(
29       const Node* node,
30       EmissionMap* emap);
31 
32   static std::vector<const Node*> ComputePostOrder(
33       c10::ArrayRef<const Node*> nodes,
34       EmissionMap* emap);
35 
36   // Same as above, but computes the post order on the set of nodes specified as
37   // argument.
38   static std::vector<const Node*> ComputePostOrder(
39       c10::ArrayRef<const Node*> nodes);
40 
41   // Retrieves the number of nodes within the graph whose sink are passed in the
42   // nodes argument.
43   static size_t GetGraphSize(c10::ArrayRef<const Node*> nodes);
44 };
45 
46 } // namespace lazy
47 } // namespace torch
48