xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/bounds_inference.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <unordered_map>
4 #include <vector>
5 
6 #include <torch/csrc/Export.h>
7 #include <torch/csrc/jit/tensorexpr/mem_dependency_checker.h>
8 
9 namespace torch {
10 namespace jit {
11 namespace tensorexpr {
12 
// Forward declarations of tensorexpr IR node classes.
// NOTE(review): the smart-pointer aliases used below (ExprPtr, BufPtr,
// StmtPtr, StorePtr, LoadPtr) are not declared in this header; presumably they
// arrive via mem_dependency_checker.h — confirm before removing that include.
class Expr;
class Buf;
class Stmt;
16 
// How a buffer region is accessed: kLoad (read), kStore (write), or kMutate
// (presumably both read and written — NOTE(review): confirm in
// bounds_inference.cpp). This is an unscoped enum, so the enumerators are
// visible unqualified throughout this namespace; changing it to `enum class`
// would break callers that write `kLoad` etc. directly.
enum C10_API_ENUM TensorAccessKind { kLoad, kStore, kMutate };
18 
19 struct TORCH_API TensorAccessBoundsInfo {
20   TensorAccessKind kind;
21   std::vector<ExprPtr> start;
22   std::vector<ExprPtr> stop;
23 };
24 
// Result of bounds inference: each buffer maps to the list of access bounds
// recorded for it. How many entries a buffer gets presumably depends on the
// `distinctAccessKinds` flag passed to the inference functions below.
using BoundsInfo =
    std::unordered_map<BufPtr, std::vector<TensorAccessBoundsInfo>>;
27 
// Infers per-buffer access bounds (see BoundsInfo) for statement `s`.
// `distinctAccessKinds` defaults to true. It presumably controls whether
// accesses of different kinds to the same buffer are reported as separate
// entries or merged into one.
// NOTE(review): confirm this in bounds_inference.cpp.
TORCH_API BoundsInfo
inferBounds(const StmtPtr& s, bool distinctAccessKinds = true);
30 
// Bounds inference that reuses the analysis already held by `analyzer`
// instead of computing it from scratch. The MemDependencyChecker must already
// have been run (presumably over IR containing `s` / `e` — confirm).
// `distinctAccessKinds` has the same meaning as in inferBounds.

// Overload: bounds of the accesses within statement `s`.
TORCH_API BoundsInfo getInferredBounds(
    analysis::MemDependencyChecker& analyzer,
    const StmtPtr& s,
    bool distinctAccessKinds = true);
// Overload: bounds of the accesses within expression `e`.
TORCH_API BoundsInfo getInferredBounds(
    analysis::MemDependencyChecker& analyzer,
    const ExprPtr& e,
    bool distinctAccessKinds = true);
41 
// Prints the contents of `v`. This is presumably a debugging aid; the output
// stream is not visible from this header.
TORCH_API void printBoundsInfo(const BoundsInfo& v);
43 
// Returns extent expressions covering all of the access bounds in `infos`
// (e.g. the entries recorded for one buffer in a BoundsInfo).
// NOTE(review): presumably one extent per dimension, spanning the union of the
// start/stop ranges — confirm in bounds_inference.cpp.
TORCH_API std::vector<ExprPtr> getBoundExtents(
    const std::vector<TensorAccessBoundsInfo>& infos);
46 
// The kind of dependency found between two statements (as reported by
// getPotentialHazards), in increasing order of exclusivity.
// NOTE(review): since the declared order is documented as meaningful, callers
// may compare enumerators by order — do not reorder without checking.
enum class HazardKind {
  ReadAfterWrite,
  WriteAfterRead,
  WriteAfterWrite,
  NoDependency,
};
// Classifies the potential memory hazard between statements A and B, based on
// the dependency information in `analyzer`.
// Returns HazardKind::NoDependency when no hazard kind applies.
// NOTE(review): presumably A is expected to execute before B, and the analyzer
// must already have been run over IR containing both — confirm.
TORCH_API HazardKind getPotentialHazards(
    analysis::MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B);
58 
// Returns true if there is a conflicting overlap between accesses in
// statements A and B. A conflicting overlap is an overlap in buffer accesses
// where at least one of the two accesses is a store; overlapping loads alone
// do not conflict.
TORCH_API bool hasConflictingOverlap(
    analysis::MemDependencyChecker& analyzer,
    const StmtPtr& A,
    const StmtPtr& B);
// Returns true if the buffer accesses of stores S1 and S2 overlap. Both
// accesses are stores, so any such overlap is a conflicting one.
TORCH_API bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S1,
    const StorePtr& S2);
// Returns true if the buffer access of store S overlaps that of load L. One of
// the two accesses is a store, so any such overlap is a conflicting one.
TORCH_API bool isOverlapping(
    analysis::MemDependencyChecker& analyzer,
    const StorePtr& S,
    const LoadPtr& L);
76 
77 } // namespace tensorexpr
78 } // namespace jit
79 } // namespace torch
80