#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/jit/runtime/static/impl.h>

namespace torch::jit {

// The following class facilitates code reuse between ProcessedNodeInputWrapper
// and ProcessedNodeOutputWrapper via CRTP.
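// Derived classes are expected to provide size() and operator[]; the base
// class builds its iterators and empty() on top of those two members.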
template <typename DerivedWrapper>
class ProcessedNodeWrapperBase {
 public:
  class ProcessedNodeWrapperBaseIter {
   public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = at::Tensor;
    using difference_type = size_t;
    using pointer = const at::Tensor*;
    using reference = const at::Tensor&;

    ProcessedNodeWrapperBaseIter() = default;

    ProcessedNodeWrapperBaseIter(
        const DerivedWrapper* container,
        size_t start_idx)
        : container_(container), idx_(start_idx) {}

    ProcessedNodeWrapperBaseIter& operator++() {
      TORCH_DCHECK_NE(idx_, container_->size());
      ++idx_;
      return *this;
    }

    ProcessedNodeWrapperBaseIter operator++(int) {
      ProcessedNodeWrapperBaseIter old = *this;
      ++(*this);
      return old;
    }

    reference operator*() const {
      TORCH_CHECK(container_ != nullptr);
      return (*container_)[idx_];
    }

    pointer operator->() const {
      TORCH_CHECK(container_ != nullptr);
      return &(*container_)[idx_];
    }

    friend bool operator==(
        ProcessedNodeWrapperBaseIter lhs,
        ProcessedNodeWrapperBaseIter rhs) {
      TORCH_DCHECK_EQ(lhs.container_, rhs.container_);
      return lhs.idx_ == rhs.idx_;
    }

    friend bool operator!=(
        ProcessedNodeWrapperBaseIter lhs,
        ProcessedNodeWrapperBaseIter rhs) {
      return !(lhs == rhs);
    }

   private:
    const DerivedWrapper* container_ = nullptr;
    size_t idx_ = 0;
  };

  // NB: to mimic the behavior of at::ArrayRef, both iterators are
  // the const version.
  using iterator = ProcessedNodeWrapperBaseIter;
  using const_iterator = ProcessedNodeWrapperBaseIter;
  using size_type = size_t;
  using value_type = at::Tensor;

  explicit ProcessedNodeWrapperBase(ProcessedNode& pnode) : pnode_(pnode) {}

  iterator begin() {
    return ProcessedNodeWrapperBaseIter(static_cast<DerivedWrapper*>(this), 0);
  }
  iterator end() {
    return ProcessedNodeWrapperBaseIter(
        static_cast<DerivedWrapper*>(this),
        static_cast<DerivedWrapper*>(this)->size());
  }

  const_iterator begin() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this), 0);
  }
  const_iterator end() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this),
        static_cast<const DerivedWrapper*>(this)->size());
  }

  const_iterator cbegin() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this), 0);
  }
  const_iterator cend() const {
    return ProcessedNodeWrapperBaseIter(
        static_cast<const DerivedWrapper*>(this),
        static_cast<const DerivedWrapper*>(this)->size());
  }

  bool empty() const {
    return static_cast<const DerivedWrapper*>(this)->size() == 0;
  }

 protected:
  ProcessedNode& pnode_;
};

// A ProcessedNodeWrapperBase lets us use ProcessedNode directly in a context
// where a container of IValues is expected. This trick is handy for avoiding
// refcount bumps in perf-sensitive native ops. For example, suppose we have an
// op that takes a list of tensors as an argument and we've turned the op into
// a variadic variant in static runtime. To use the PyTorch library
// implementation of the op, we would have to pack the variadic arguments into
// a list:
//   std::vector<at::Tensor> tensor_list;
//   tensor_list.reserve(pnode->num_inputs());
//   for (const auto i : c10::irange(pnode->num_inputs())) {
//     tensor_list.push_back(pnode->Input(i).toTensor());
//   }
//   op_impl(tensor_list);
// Using ProcessedNodeWrapperBase, we can avoid this round of refcount bumps.
// All we need to do is turn `op_impl` into a template and pass it
// ProcessedNodeInputWrapper(*pnode)!
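//
// For instance, a minimal sketch of such a template (`total_numel` and
// `pnode` are illustrative names, not part of this header):
//   template <typename TensorContainer>
//   int64_t total_numel(const TensorContainer& tensors) {
//     int64_t total = 0;
//     for (const auto& tensor : tensors) {
//       // Dereferencing the wrapper's iterator yields a const at::Tensor&
//       // aliasing the IValue owned by the ProcessedNode, so no refcount
//       // bump happens here.
//       total += tensor.numel();
//     }
//     return total;
//   }
//   // The same template accepts std::vector<at::Tensor> and
//   // ProcessedNodeInputWrapper alike:
//   total_numel(ProcessedNodeInputWrapper(*pnode));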
class ProcessedNodeInputWrapper
    : public ProcessedNodeWrapperBase<ProcessedNodeInputWrapper> {
 public:
  // The last `back_elements_ignored` elements are not considered.
  // Same for the first `front_elements_ignored` elements.
  // This is useful for ops where only the first N elements are tensors
  // (N < inputs.size()). For instance, the last argument to VarStack is an
  // integer dimension.
  explicit ProcessedNodeInputWrapper(
      ProcessedNode& pnode,
      size_t front_elements_ignored = 0,
      size_t back_elements_ignored = 1)
      : ProcessedNodeWrapperBase<ProcessedNodeInputWrapper>(pnode),
        front_elements_ignored_(front_elements_ignored),
        back_elements_ignored_(back_elements_ignored) {
    TORCH_CHECK(front_elements_ignored_ <= pnode_.num_inputs());
    TORCH_CHECK(
        back_elements_ignored_ <=
        pnode_.num_inputs() - front_elements_ignored_);
  }

  size_t size() const {
    return pnode_.num_inputs() - back_elements_ignored_ -
        front_elements_ignored_;
  }

  const at::Tensor& operator[](size_t idx) const {
    TORCH_CHECK(idx < size());
    return pnode_.Input(front_elements_ignored_ + idx).toTensor();
  }

  const at::Tensor& front() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access front() of empty ProcessedNodeInputWrapper");
    return pnode_.Input(front_elements_ignored_).toTensor();
  }

  const at::Tensor& back() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access back() of empty ProcessedNodeInputWrapper");
    return pnode_.Input(pnode_.num_inputs() - back_elements_ignored_ - 1)
        .toTensor();
  }

 private:
  size_t front_elements_ignored_;
  size_t back_elements_ignored_;
};
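
// For example (illustrative; `pnode` is a ProcessedNode* as in the comment
// above): for a VarStack-like op whose inputs are (t0, ..., tN, dim), the
// defaults view exactly the tensor inputs:
//   ProcessedNodeInputWrapper tensors(*pnode); // skips the trailing dim
// An op whose single non-tensor argument comes first would instead use:
//   ProcessedNodeInputWrapper tensors(
//       *pnode,
//       /*front_elements_ignored=*/1,
//       /*back_elements_ignored=*/0);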

// Similar to ProcessedNodeInputWrapper, but wraps outputs and allows for
// writing.
class ProcessedNodeOutputWrapper
    : public ProcessedNodeWrapperBase<ProcessedNodeOutputWrapper> {
 public:
  using ProcessedNodeWrapperBase<
      ProcessedNodeOutputWrapper>::ProcessedNodeWrapperBase;

  size_t size() const {
    return pnode_.num_outputs();
  }

  at::Tensor& operator[](size_t idx) const {
    TORCH_CHECK(idx < size());
    return pnode_.Output(idx).toTensor();
  }

  at::Tensor& front() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access front() of empty ProcessedNodeOutputWrapper");
    return pnode_.Output(0).toTensor();
  }

  at::Tensor& back() const {
    TORCH_CHECK(
        !empty(),
        "Attempted to access back() of empty ProcessedNodeOutputWrapper");
    return pnode_.Output(size() - 1).toTensor();
  }
};
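
// For example (illustrative): operator[] returns a mutable at::Tensor&, so an
// op can write results directly into the node's outputs:
//   ProcessedNodeOutputWrapper outputs(*pnode);
//   for (const auto i : c10::irange(outputs.size())) {
//     outputs[i].zero_(); // assumes each output IValue already holds a tensor
//   }
// Per the NB in ProcessedNodeWrapperBase, the iterators are const, so
// mutation must go through operator[], front(), or back().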

} // namespace torch::jit