#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch::jit {

// The following class facilitates code reuse between ProcessedNodeInputWrapper
// and ProcessedNodeOutputWrapper via CRTP
template <typename DerivedWrapper>
class ProcessedNodeWrapperBase {
 public:
  class ProcessedNodeWrapperBaseIter {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = at::Tensor;
    using difference_type = size_t;
    using pointer = const at::Tensor*;
    using reference = const at::Tensor&;

    ProcessedNodeWrapperBaseIter() = default;

    ProcessedNodeWrapperBaseIter(
        const DerivedWrapper* container,
        size_t start_idx)
        : container_(container), idx_(start_idx) {}

    ProcessedNodeWrapperBaseIter& operator++() {
      TORCH_DCHECK_NE(idx_, container_->size());
      ++idx_;
      return *this;
    }

    ProcessedNodeWrapperBaseIter operator++(int) {
      ProcessedNodeWrapperBaseIter old = *this;
      ++(*this);
      return old;
    }

    reference operator*() const {
      TORCH_CHECK(container_ != nullptr);
      return (*container_)[idx_];
    }

    pointer operator->() const {
      TORCH_CHECK(container_ != nullptr);
      return &(*container_)[idx_];
    }

    friend bool operator==(
        ProcessedNodeWrapperBaseIter lhs,
        ProcessedNodeWrapperBaseIter rhs) {
      TORCH_DCHECK_EQ(lhs.container_, rhs.container_);
      return lhs.idx_ == rhs.idx_;
    }

    friend bool operator!=(
        ProcessedNodeWrapperBaseIter lhs,
        ProcessedNodeWrapperBaseIter rhs) {
      return !(lhs == rhs);
    }

   private:
    const DerivedWrapper* container_ = nullptr;
    size_t idx_ = 0;
  };

  // NB: to mimic the behavior of at::ArrayRef, both iterators are
  // the const version.
  using iterator = ProcessedNodeWrapperBaseIter;
  using const_iterator = ProcessedNodeWrapperBaseIter;
  using size_type = size_t;
  using value_type = at::Tensor;

  explicit ProcessedNodeWrapperBase(ProcessedNode& pnode) : pnode_(pnode) {}

  iterator begin() {
    return ProcessedNodeWrapperBaseIter(static_cast<DerivedWrapper*>(this), 0);
  }
  iterator end() {
    return ProcessedNodeWrapperBaseIter(
        static_cast<DerivedWrapper*>(this),
        static_cast<DerivedWrapper*>(this)->size());
  }

  const_iterator begin() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this), 0);
  }
  const_iterator end() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this),
        static_cast<const DerivedWrapper*>(this)->size());
  }

  const_iterator cbegin() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this), 0);
  }
  const_iterator cend() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this),
        static_cast<const DerivedWrapper*>(this)->size());
  }

  bool empty() const {
    return static_cast<const DerivedWrapper*>(this)->size() == 0;
  }

 protected:
  ProcessedNode& pnode_;
};

// A ProcessedNodeWrapperBase lets us use ProcessedNode directly in a context
// where a container of IValues is expected. This trick is handy for avoiding
// refcount bumps in perf-sensitive native ops.
// For example, suppose we have an op that takes a list of tensors as an
// argument and we've turned the op into a variadic variant in static runtime.
// To use the PyTorch library implementation of the op, we would have to pack
// the variadic arguments into a list:
//   std::vector<Tensor> tensor_list;
//   tensor_list.reserve(pnode->num_inputs());
//   for (const auto i : c10::irange(pnode->num_inputs())) {
//     tensor_list.push_back(pnode->Input(i).toTensor());
//   }
//   op_impl(tensor_list);
// Using ProcessedNodeWrapperBase, we can avoid this round of refcount bumps.
// All we need to do is turn `op_impl` into a template and pass it
// ProcessedNodeInputWrapper(*pnode)!
class ProcessedNodeInputWrapper
    : public ProcessedNodeWrapperBase<ProcessedNodeInputWrapper> {
 public:
  // The last `back_elements_ignored` elements are not considered.
  // Same for the first `front_elements_ignored` elements.
  // This is useful for ops where
  // only the first N elements are tensors (N < inputs.size()).
  // For instance, the last argument to VarStack is an integer dimension.
  explicit ProcessedNodeInputWrapper(
      ProcessedNode& pnode,
      size_t front_elements_ignored = 0,
      size_t back_elements_ignored = 1)
      : ProcessedNodeWrapperBase<ProcessedNodeInputWrapper>(pnode),
        front_elements_ignored_(front_elements_ignored),
        back_elements_ignored_(back_elements_ignored) {
    TORCH_CHECK(front_elements_ignored_ <= pnode_.num_inputs());
    TORCH_CHECK(
        back_elements_ignored_ <=
        pnode_.num_inputs() - front_elements_ignored_);
  }

  size_t size() const {
    return pnode_.num_inputs() - back_elements_ignored_ -
        front_elements_ignored_;
  }

  const at::Tensor& operator[](size_t idx) const {
    TORCH_CHECK(idx < size());
    return pnode_.Input(front_elements_ignored_ + idx).toTensor();
  }

  const at::Tensor& front() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access front() of empty ProcessedNodeInputWrapper");
    return pnode_.Input(front_elements_ignored_).toTensor();
  }

  const at::Tensor& back() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access back() of empty ProcessedNodeInputWrapper");
    return pnode_.Input(pnode_.num_inputs() - back_elements_ignored_ - 1)
        .toTensor();
  }

 private:
  size_t front_elements_ignored_;
  size_t back_elements_ignored_;
};

// Similar to ProcessedNodeInputWrapper, but wraps outputs and allows for
// writing.
class ProcessedNodeOutputWrapper
    : public ProcessedNodeWrapperBase<ProcessedNodeOutputWrapper> {
 public:
  using ProcessedNodeWrapperBase<
      ProcessedNodeOutputWrapper>::ProcessedNodeWrapperBase;

  size_t size() const {
    return pnode_.num_outputs();
  }

  at::Tensor& operator[](size_t idx) const {
    TORCH_CHECK(idx < size());
    return pnode_.Output(idx).toTensor();
  }

  at::Tensor& front() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access front() of empty ProcessedNodeOutputWrapper");
    return pnode_.Output(0).toTensor();
  }

  at::Tensor& back() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access back() of empty ProcessedNodeOutputWrapper");
    return pnode_.Output(size() - 1).toTensor();
  }
};

} // namespace torch::jit
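
// Hedged usage sketch (illustration only, not part of this header's API):
// `op_impl` and `p_node` below are hypothetical names. The sketch shows what
// the "turn op_impl into a template" comment above means: because both
// wrappers expose size(), operator[], empty(), and begin()/end() from the
// CRTP base, a kernel written against a generic tensor-container type can
// consume a ProcessedNodeInputWrapper directly, with no intermediate
// std::vector<at::Tensor> and therefore no extra refcount bumps.
//
//   template <typename TensorContainer>
//   at::Tensor op_impl(const TensorContainer& tensors) {
//     TORCH_CHECK(!tensors.empty());
//     at::Tensor result = tensors[0].clone();
//     // Range-for works via ProcessedNodeWrapperBaseIter.
//     for (const at::Tensor& t : tensors) {
//       result.add_(t);
//     }
//     return result;
//   }
//
//   // Inside a static runtime op implementation (hypothetical call site);
//   // note the default back_elements_ignored = 1 skips the last input:
//   //   p_node->Output(0) = op_impl(ProcessedNodeInputWrapper(*p_node));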