// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/functorch/BatchedTensorImpl.h>

namespace at::functorch {

// This file contains the legacy (now-deprecated) batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).

// This file contains abstractions used for transforming *logical* vmap arguments
// into *physical* arguments. (Keep reading for definitions of these terms.)

// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0, but the physical dim is dim 1
// (the first non-batch dimension).
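//
// A minimal, hedged sketch of that mapping in terms of the helpers declared
// later in this file (`tensor` stands in for the BatchedTensor from the
// example above; the surrounding batching-rule plumbing is elided):
//
//   VmapPhysicalView view = MultiBatchVmapTransform::logicalToPhysical(tensor);
//   // logicalToPhysical moves both batch dims to the front, so the physical
//   // tensor now has size [2, 4, 3] and the logical reduction dim 0 maps to
//   // physical dim 2.
//   int64_t dim = view.getPhysicalDim(/*logical_dim=*/0);  // dim == 2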

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
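
// For illustration only (hedged; not part of this header's API): these
// SmallVectors behave like std::vector but store up to their static size
// inline, so the common case avoids a heap allocation:
//
//   VmapDimVector dims = {0, 1, 2};  // stored inline, no heap allocation
//   dims.push_back(3);               // still inline while size() <= 8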

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical results back to logical
// arguments.

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
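
// A hedged sketch of the pipeline from NOTE: [What is a VmapTransform?],
// written as a hypothetical unary batching rule (the name `sum_batching_rule`
// is illustrative and not part of this header):
//
//   Tensor sum_batching_rule(const Tensor& logical, int64_t logical_dim) {
//     // logical -> physical: batch dims are moved to the front
//     auto view = MultiBatchVmapTransform::logicalToPhysical(logical);
//     // translate the user-facing logical dim into a physical dim
//     int64_t physical_dim = view.getPhysicalDim(logical_dim);
//     // run the real at:: operator on the unwrapped physical tensor
//     Tensor result = view.tensor().sum(physical_dim);
//     // physical -> logical: rewrap the result as a BatchedTensor
//     return view.getPhysicalToLogicalMap().apply(result);
//   }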

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch dimensions
//   so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
// of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
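
// A hedged sketch of how a binary batching rule might use this transform
// (the name `add_batching_rule` is illustrative and not part of this header):
//
//   Tensor add_batching_rule(const Tensor& self, const Tensor& other) {
//     // align both inputs to a common physical layout
//     auto views = BroadcastingVmapTransform::logicalToPhysical({self, other});
//     Tensor result = at::add(views[0].tensor(), views[1].tensor());
//     // both views share the same levels, so either one can rewrap the result
//     return views[0].getPhysicalToLogicalMap().apply(result);
//   }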

// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented.)
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor`, and the rightmost
// set bit of `levels` gives the highest vmap level that batches this tensor at
// this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is bit 3, indicating that this tensor is
// batched by vmap levels less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  }

  Tensor& tensor() { return tensor_; }
  const Tensor& tensor() const { return tensor_; }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the number of set bits in `levels` tells us that the
  // first two dimensions of `tensor_` are batch dimensions, so a logical dim
  // of `n` is actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by prepending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};
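
// For illustration only (hedged; mirrors the example in
// NOTE: [What is a VmapPhysicalView?]): building the `levels` bitset by hand
// and querying the resulting view:
//
//   std::bitset<kVmapNumLevels> levels;
//   levels[1] = true;
//   levels[3] = true;  // levels = {1, 3}: two set bits => two batch dims
//   VmapPhysicalView view(at::ones({2, 3, 4, 5}), levels);
//   view.numBatchDims();            // 2
//   view.getPhysicalDim(0);         // 2 (the first non-batch dimension)
//   view.getPhysicalShape({7, 7});  // {2, 3, 7, 7}: batch sizes prepended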

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do the
// mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};
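
// For illustration only (hedged sketch; `view`, `out0`, and `out1` stand in
// for locals of a hypothetical multi-output batching rule): a rule that
// produces several physical results can rewrap all of them in one pass:
//
//   auto map = view.getPhysicalToLogicalMap();
//   std::vector<Tensor> physical_results = {out0, out1};
//   map.applyInplace(physical_results);  // each entry is now a BatchedTensor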

} // namespace at::functorch