xref: /aosp_15_r20/external/pytorch/c10/core/SymbolicShapeMeta.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/SymBool.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/macros/Export.h>
5 #include <c10/macros/Macros.h>
6 #include <c10/util/DimVector.h>
7 
8 #include <atomic>
9 #include <cstdint>
10 #include <mutex>
11 #include <utility>
12 
13 namespace c10 {
14 
15 class C10_API SymbolicShapeMeta {
16  public:
17   // Basic metadata from which other quantities are derived
18   SymDimVector sizes_ = {0};
19   SymDimVector strides_ = {1};
20   SymInt storage_offset_ = 0;
21 
22   bool strides_valid_ = true; // e.g. for sparse where there are no strides
23 
24   SymbolicShapeMeta() = default;
25   SymbolicShapeMeta(const SymbolicShapeMeta& other);
26   SymbolicShapeMeta& operator=(const SymbolicShapeMeta& other) = delete;
27   SymbolicShapeMeta& operator=(SymbolicShapeMeta&& other) = delete;
28 
refresh_numel()29   void refresh_numel() {
30     // Non-const, don't need to hold mutables_ lock
31     available_.fetch_and(~numel_avail);
32     numel_ = 1;
33   }
34 
refresh_contiguous()35   void refresh_contiguous() {
36     // Non-const, don't need to hold mutables_ lock
37     available_.fetch_and(numel_avail);
38     is_contiguous_ = false;
39     is_channels_last_contiguous_ = false;
40     is_channels_last_3d_contiguous_ = false;
41     is_channels_last_ = false;
42     is_channels_last_3d_ = false;
43     is_non_overlapping_and_dense_ = false;
44   }
45 
dim()46   int64_t dim() const {
47     return static_cast<int64_t>(sizes_.size());
48   }
49 
50   // Accessors for derived quantities, computed lazily on first access
51 
has_numel()52   bool has_numel() const {
53     return available_.load() & numel_avail;
54   }
has_is_contiguous()55   bool has_is_contiguous() const {
56     return available_.load() & is_contiguous_avail;
57   }
has_is_channels_last_contiguous()58   bool has_is_channels_last_contiguous() const {
59     return available_.load() & is_channels_last_contiguous_avail;
60   }
has_is_channels_last_3d_contiguous()61   bool has_is_channels_last_3d_contiguous() const {
62     return available_.load() & is_channels_last_3d_contiguous_avail;
63   }
has_is_channels_last()64   bool has_is_channels_last() const {
65     return available_.load() & is_channels_last_avail;
66   }
has_is_channels_last_3d()67   bool has_is_channels_last_3d() const {
68     return available_.load() & is_channels_last_3d_avail;
69   }
has_is_non_overlapping_and_dense()70   bool has_is_non_overlapping_and_dense() const {
71     return available_.load() & is_non_overlapping_and_dense_avail;
72   }
73 
74   // Accessors to cached derived properties
75   // DO NOT call with mutables_ lock held
numel()76   const SymInt& numel() const {
77     if (C10_UNLIKELY(!has_numel())) {
78       init_numel();
79     }
80     return numel_;
81   }
82 
is_contiguous()83   const SymBool& is_contiguous() const {
84     if (C10_UNLIKELY(!has_is_contiguous())) {
85       init_is_contiguous();
86     }
87     return is_contiguous_;
88   }
89 
is_channels_last_contiguous()90   const SymBool& is_channels_last_contiguous() const {
91     if (C10_UNLIKELY(!has_is_channels_last_contiguous())) {
92       init_is_channels_last_contiguous();
93     }
94     return is_channels_last_contiguous_;
95   }
96 
is_channels_last_3d_contiguous()97   const SymBool& is_channels_last_3d_contiguous() const {
98     if (C10_UNLIKELY(!has_is_channels_last_3d_contiguous())) {
99       init_is_channels_last_3d_contiguous();
100     }
101     return is_channels_last_3d_contiguous_;
102   }
103 
is_channels_last()104   const SymBool& is_channels_last() const {
105     if (C10_UNLIKELY(!has_is_channels_last())) {
106       init_is_channels_last();
107     }
108     return is_channels_last_;
109   }
110 
is_channels_last_3d()111   const SymBool& is_channels_last_3d() const {
112     if (C10_UNLIKELY(!has_is_channels_last_3d())) {
113       init_is_channels_last_3d();
114     }
115     return is_channels_last_3d_;
116   }
117 
is_non_overlapping_and_dense()118   const SymBool& is_non_overlapping_and_dense() const {
119     if (C10_UNLIKELY(!has_is_non_overlapping_and_dense())) {
120       init_is_non_overlapping_and_dense();
121     }
122     return is_non_overlapping_and_dense_;
123   }
124 
125   // Assumptions so we can short-circuit computation
126   // NOTE: Don't need to lock mutables_ since these aren't const
127   void assume_contiguous(SymBool val = true) {
128     is_contiguous_ = std::move(val);
129     available_.fetch_or(is_contiguous_avail);
130   }
131   void assume_channels_last_contiguous(SymBool val = true) {
132     is_contiguous_ = std::move(val);
133     available_.fetch_or(is_channels_last_contiguous_avail);
134   }
135   void assume_channels_last_3d_contiguous(SymBool val = true) {
136     is_channels_last_3d_contiguous_ = std::move(val);
137     available_.fetch_or(is_channels_last_3d_contiguous_avail);
138   }
139   void assume_channels_last(SymBool val = true) {
140     is_channels_last_ = std::move(val);
141     available_.fetch_or(is_channels_last_avail);
142   }
143   void assume_channels_last_3d(SymBool val = true) {
144     is_channels_last_3d_ = std::move(val);
145     available_.fetch_or(is_channels_last_3d_avail);
146   }
147   void assume_non_overlapping_and_dense(SymBool val = true) {
148     is_non_overlapping_and_dense_ = std::move(val);
149     available_.fetch_or(is_non_overlapping_and_dense_avail);
150   }
151 
152  private:
153   SymBool compute_contiguous() const;
154   SymBool compute_channels_last_contiguous_2d() const;
155   SymBool compute_channels_last_contiguous_3d() const;
156   SymBool compute_strides_like_channels_last_2d() const;
157   SymBool compute_strides_like_channels_last_3d() const;
158   SymBool compute_non_overlapping_and_dense() const;
159 
160   // These are little wrappers over the real compute_ functions that
161   // can make use of other contiguity fields to short circuit.
162   // They need to be implemented separately for SymBool, as SymBool does
163   // not short circuit.
164   // TODO: should the SymBool cases avoid the short circuit?  Need to reason
165   // if its correct, and reason if the simpler expressions are better for
166   // analysis (maybe not!)
167 
168   SymBool compute_channels_last_contiguous_3d_dim5() const;
169   SymBool compute_channels_last_2d_dim5() const;
170   SymBool compute_channels_last_3d_dim5() const;
171   SymBool compute_is_non_overlapping_and_dense_dim4() const;
172   SymBool compute_is_non_overlapping_and_dense_dim5() const;
173   SymBool compute_is_non_overlapping_and_dense_anydim() const;
174 
175   void init_numel() const;
176   void init_is_contiguous() const;
177   void init_is_channels_last_contiguous() const;
178   void init_is_channels_last_3d_contiguous() const;
179   void init_is_channels_last() const;
180   void init_is_channels_last_3d() const;
181   void init_is_non_overlapping_and_dense() const;
182 
183   // NOTE: These only set if !has_foo()
184   void set_numel(SymInt val) const;
185   void set_is_contiguous(SymBool val) const;
186   void set_is_channels_last_contiguous(SymBool val) const;
187   void set_is_channels_last_3d_contiguous(SymBool val) const;
188   void set_is_channels_last(SymBool val) const;
189   void set_is_channels_last_3d(SymBool val) const;
190   void set_is_non_overlapping_and_dense(SymBool val) const;
191 
192   // Lazily initialized variables, with the corresponding available_ flag
193   // indicating whether the value has been initialized
194   mutable std::atomic<int> available_{0};
195   enum avail {
196     numel_avail = 1 << 0,
197     is_contiguous_avail = 1 << 1,
198     is_channels_last_contiguous_avail = 1 << 2,
199     is_channels_last_3d_contiguous_avail = 1 << 3,
200     is_channels_last_avail = 1 << 4,
201     is_channels_last_3d_avail = 1 << 5,
202     is_non_overlapping_and_dense_avail = 1 << 6,
203   };
204 
205   // Mutex to prevent races when initializing the variable from const accessors
206   mutable std::mutex mutables_;
207   mutable SymInt numel_ = 1;
208   mutable SymBool is_contiguous_{true};
209   mutable SymBool is_channels_last_contiguous_{false};
210   mutable SymBool is_channels_last_3d_contiguous_{false};
211   mutable SymBool is_channels_last_{false};
212   mutable SymBool is_channels_last_3d_{false};
213   mutable SymBool is_non_overlapping_and_dense_{true};
214 };
215 
216 } // namespace c10
217