#pragma once

#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/core/IListRef.h>

namespace at {

// This file contains abstractions used for transforming *logical* vmap
// arguments into *physical* arguments. (Keep reading for definitions of these
// terms).
// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0, but the physical dim is dim 1
// (the first non-batch dimension).
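//
// To make the logical-to-physical dim mapping concrete: once a VmapTransform
// (defined below) has moved all batch dims to the front of the tensor, a
// logical dim maps to `logical_dim + num_batch_dims` (see
// VmapPhysicalView::getPhysicalDims below). For instance, with 2 batch dims
// at the front, logical dim 0 becomes physical dim 2.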

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec =
    SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical results back to
// logical results.
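//
// A minimal sketch of this pattern, modeled on the batching rules in
// LegacyBatchingRegistrations.cpp (the rule name and exact signature here
// are illustrative):
//
//   Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto dims_physical = self_physical.getPhysicalDims(dims);
//     auto result = at::sum(self_physical.tensor(), dims_physical, keepdim);
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }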

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
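
// For example, reusing the tensor from NOTE: [Logical vs physical args]:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
// logicalToPhysical should return a VmapPhysicalView wrapping a tensor of
// size [2, 4, 3] (both batch dims moved to the front, ordered by level),
// with levels = {1, 2}.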

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch
//   dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
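
// A hedged sketch of how a pointwise batching rule might use this transform
// (modeled on the binary pointwise rules in LegacyBatchingRegistrations.cpp;
// the name `add_batching_rule` is illustrative):
//
//   Tensor add_batching_rule(const Tensor& self, const Tensor& other, const Scalar& alpha) {
//     auto physical_views = BroadcastingVmapTransform::logicalToPhysical({self, other});
//     auto result = at::add(physical_views[0].tensor(), physical_views[1].tensor(), alpha);
//     return physical_views[0].getPhysicalToLogicalMap().apply(result);
//   }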

// Forward declared; if you're reading this file head to toe, don't worry
// about it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor`, and the rightmost
// set bit of `levels` specifies the maximum number of nested vmaps we are in
// at this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is at position 3, indicating that the
// number of nested vmaps is less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  }

  Tensor& tensor() {
    return tensor_;
  }
  const Tensor& tensor() const {
    return tensor_;
  }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the two set bits of `levels` tell us that the first two
  // dimensions of `tensor_` are batch dimensions, so a logical dim of `n` is
  // actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by prepending the batch
  // sizes to the logical shape.
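  // For example, if `tensor_` has batch sizes (2, 3) at the front, then
  // getPhysicalShape({5, 7}) should return {2, 3, 5, 7}.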
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
      : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};
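
// A hedged end-to-end sketch for a multi-output operator, tying the pieces
// above together (modeled on the batching rules in
// LegacyBatchingRegistrations.cpp; the name `chunk_batching_rule` is
// illustrative):
//
//   std::vector<Tensor> chunk_batching_rule(
//       const Tensor& self, int64_t chunks, int64_t dim) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
//     self_physical.getPhysicalToLogicalMap().applyInplace(result);
//     return result;
//   }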

} // namespace at