#pragma once

#include <bitset>
#include <ostream>

#include <ATen/ArrayRef.h>
#include <ATen/SmallVector.h>
#include <ATen/Tensor.h>

namespace at {

// We assume this in a few other places in the codebase,
// but there isn't a centralized definition.
constexpr int64_t kVmapMaxTensorDims = 64;

// Valid vmap levels are in the range [0, 64). This effectively means that we
// support a maximum of 64 nested vmaps.
constexpr int64_t kVmapNumLevels = 64;

// Store this number of elements of BatchDims on the stack. Most people will
// probably use <= 5 nested vmaps, but adjust this number as necessary.
constexpr int64_t kBatchDimsStackSize = 5;

// A BatchDim represents a "private" dimension on a Tensor created inside of
// vmap. It is a (level, dim) tuple: `dim` indicates which dimension is being
// vmap'ed over, and `level` identifies the vmap inside of which said
// dimension was created. `dim` corresponds to a "physical dim": it is a
// dimension index on the underlying physical tensor that is being vmapped
// over.
struct BatchDim {
  BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {}
  int64_t dim() const {
    return dim_;
  }
  int64_t level() const {
    return level_;
  }

 private:
  int64_t dim_;
  int64_t level_;
};
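
// For example, constructing
//    BatchDim bdim(/*level=*/1, /*dim=*/0);
// records that physical dim 0 is being vmapped over by the vmap at level 1:
// bdim.level() == 1 and bdim.dim() == 0.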

using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
using BatchDimsRef = ArrayRef<BatchDim>;
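
// For example (illustrative), a tensor inside two nested vmaps might carry
//    BatchDims bdims = {BatchDim(/*level=*/0, /*dim=*/0),
//                       BatchDim(/*level=*/1, /*dim=*/1)};
// which comfortably fits on the stack, since kBatchDimsStackSize is 5.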

// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDims.
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
//
// The batch dimensions are treated as being "private"; they are not
// user-visible. For example, in the following Tensor,
//    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
// dimensions 0 and 1 are batch dimensions.
//
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
// dim 0, which is equivalent to dim 2 in the underlying ones(2, 3, 5, 7)
// tensor.
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  explicit BatchedTensorImpl(Tensor value, BatchDims bdims);

  // Returns a reference to BatchDims that represent which dimensions of this
  // tensor are private.
  BatchDimsRef bdims() const {
    return bdims_;
  }

  // BatchedTensorImpl wraps a Tensor
  const Tensor& value() const {
    return value_;
  }

  // Given a public dimension index, return the dimension index in the
  // underlying value() tensor. For example, if we have
  //    bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=2)])
  // bt.actualDim(0) -> 1
  // bt.actualDim(1) -> 3
  // bt.actualDim(2) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
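  // With wrap_dim = true (the default), a negative `dim` is presumably
  // wrapped PyTorch-style against the number of public dims, so in the
  // example above bt.actualDim(-1) would map to 3 (an assumption; this
  // header does not spell it out).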

  // We have to override this because we opted into CustomStrides
  IntArrayRef strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to error out
  // when called.
  bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  // Note: [BatchedTensorImpl levels invariant]
  // There is an invariant that the BatchDims must be stored in increasing
  // `level` order. That is, for i < j, bdims_[i].level must be less than
  // bdims_[j].level.
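  // For example, [(lvl=1, dim=0), (lvl=2, dim=1)] satisfies the invariant,
  // while [(lvl=2, dim=1), (lvl=1, dim=0)] does not.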
  BatchDims bdims_;
};

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

// Returns nullptr if `tensor` is not batched, otherwise its BatchedTensorImpl.
inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}
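
// Typical usage (an illustrative sketch):
//    if (auto* batched = maybeGetBatchedImpl(tensor)) {
//      // operate on batched->value() and batched->bdims()
//    }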

// Returns a bitset. If bit i is set, then that means dim i is a batch dim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
    BatchDimsRef bdims) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  for (const auto& bdim : bdims) {
    is_bdim.set(bdim.dim());
  }
  return is_bdim;
}
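
// For example, bdims of [(lvl=1, dim=0), (lvl=2, dim=2)] produce a bitset
// with bits 0 and 2 set.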

// Creates a bitset for all of the levels present in `bdims`.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  std::bitset<kVmapNumLevels> result;
  for (const auto& bdim : bdims) {
    result.set(bdim.level());
  }
  return result;
}
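
// Analogously, bdims of [(lvl=1, dim=0), (lvl=2, dim=2)] produce a bitset
// with bits 1 and 2 set (one bit per vmap level present).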

inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
  return out;
}
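
// For example, streaming BatchDim(/*level=*/1, /*dim=*/0) prints
// "(lvl=1, dim=0)".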

// Use this to construct a BatchedTensor from a regular Tensor
TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);

// Adds a batch dim to `tensor`, returning a BatchedTensor
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
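
// For example (an illustrative sketch based on the semantics above),
//    Tensor bt = addBatchDim(ones({2, 3, 5}), /*level=*/1, /*dim=*/0);
// yields a BatchedTensor whose bdims() are [(lvl=1, dim=0)] and whose public
// bt.sizes() are (3, 5).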

// Checks if an inplace operation on self and other is "vmap compatible".
// See NOTE: [vmap-incompatible in-place operations] for the definition of this.
TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
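
// Roughly speaking (see the NOTE for the authoritative definition), writing
// a batched `other` into an unbatched `self` in-place cannot work, because
// `self` would need to grow a batch dimension.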

} // namespace at