// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/functorch/BatchedTensorImpl.h>

namespace at::functorch {

// This file contains the legacy (now-deprecated) batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).

// This file contains abstractions used for transforming *logical* vmap arguments
// into *physical* arguments. (Keep reading for definitions of these terms.)

// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0 but the physical dim is dim 1
// (the first non-batch dimension).

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
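
// To illustrate the round trip described above, here is a minimal sketch of a
// legacy batching rule built on this transform, modeled on the legacy vmap
// batching rules. The name `sum_batching_rule` and its exact signature are
// illustrative only; nothing below is declared in this header.
//
//   Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim) {
//     // Logical -> physical: all batch dims move to the front of the tensor.
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     // Shift the user-specified (logical) dims past the batch dims.
//     auto dims_physical = self_physical.getPhysicalDims(dims);
//     // Run the regular at:: operator on the physical tensor.
//     auto result = at::sum(self_physical.tensor(), dims_physical, keepdim);
//     // Physical -> logical: rewrap the result as a BatchedTensor.
//     return self_physical.getPhysicalToLogicalMap().apply(result);
//   }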

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch
//   dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2), where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
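
// A minimal sketch of how a binary pointwise batching rule might use this
// transform, modeled on the legacy vmap batching rules. The name
// `add_batching_rule` and its signature are illustrative only and are not
// declared in this header.
//
//   Tensor add_batching_rule(const Tensor& self, const Tensor& other, const Scalar& alpha) {
//     // Align batch dims and pad non-batch dims so broadcasting works.
//     auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
//     auto result = at::add(physical_args[0].tensor(), physical_args[1].tensor(), alpha);
//     // Either physical view can rewrap the result; after the transform,
//     // all of the views share the same levels.
//     return physical_args[0].getPhysicalToLogicalMap().apply(result);
//   }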

// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented.)
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// set bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is 3, indicating that the number of nested
// vmaps is less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  }

  Tensor& tensor() { return tensor_; }
  const Tensor& tensor() const { return tensor_; }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the number of set bits in `levels` tells us that the first
  // two dimensions of `tensor_` are batch dimensions, so a logical dim of `n`
  // is actually a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by prepending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels) : levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1,lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};

} // namespace at::functorch
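
// Putting the pieces together: a sketch of a multi-output legacy batching rule
// that uses VmapPhysicalToLogicalMap::applyInplace to rewrap every physical
// output, modeled on the legacy vmap batching rules. The name
// `chunk_batching_rule` and its signature are illustrative only and are not
// declared in this header.
//
//   std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto dim_physical = self_physical.getPhysicalDim(dim);
//     auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
//     // Rewrap each physical output as a BatchedTensor, reusing the vector
//     // to avoid extra dynamic allocations.
//     self_physical.getPhysicalToLogicalMap().applyInplace(result);
//     return result;
//   }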