#pragma once

#include <type_traits>

#include <ATen/core/ivalue.h>
#include <c10/util/Deprecated.h>
#include <c10/util/irange.h>

// TODO move this to c10 namespace

namespace torch::jit {

using c10::IValue;
using Stack = std::vector<IValue>;

class Operation {
  template <typename F, typename Arg>
  using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;

 public:
  template <typename F,
            std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
  C10_DEPRECATED_MESSAGE("Please use void(Stack&) to register operator instead.")
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  Operation(F&& raw): op_([raw = std::forward<F>(raw)](Stack& stack) {
    raw(&stack);
  }) {}

  template <typename F,
            std::enable_if_t<accepts<F, Stack&>::value &&
                !std::is_same_v<std::decay_t<F>, Operation>, int> = 0>
  Operation(F&& op): op_(std::forward<F>(op)) {}

  Operation(std::nullptr_t) noexcept {}

  explicit operator bool() const noexcept {
    return op_ ? true : false;
  }

  void operator()(Stack& stack) {
    op_(stack);
  }

  template <typename T>
  T* target() noexcept {
    return op_.target<T>();
  }

 private:
  std::function<void(Stack&)> op_;
};
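
// Example (illustrative, not part of this header): an Operation wraps any
// callable taking Stack&; the older Stack* form is also accepted but
// deprecated.
//
//   Operation op([](Stack& stack) {
//     // inspect or mutate the stack here
//   });
//   Stack stack;
//   op(stack);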

// An operation with N inputs and M outputs pops the last N inputs off
// the stack and pushes its M outputs onto the stack:
// before: <other stack items> I0, I1, ... I(N-1) <- stack.back()
// after:  <other stack items> O0, O1, ... O(M-1) <- stack.back()
// Operations are defined this way so that ownership of inputs can be
// transferred to the operation, which can then incrementally drop ownership
// of tensors as they become unneeded. For large operations, like 'run an
// entire subgraph', this is very important for minimizing GPU memory usage.
// The return value is the relative 'offset' to jump to for the next
// operation:
//   pc += 1 + offset
// so a return value of 0 goes to the next instruction.

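// Example (illustrative sketch): an operation with N = 2 inputs and M = 1
// output that pops two ints and pushes their sum, following the convention
// above (pop/push are defined later in this header):
//
//   Operation add_op([](Stack& stack) {
//     int64_t b = pop(stack).toInt(); // topmost input (I1)
//     int64_t a = pop(stack).toInt(); // next input (I0)
//     push(stack, a + b);             // single output O0
//   });
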
// treat the last N elements of the stack as a list, looking up
// element i
inline IValue& peek(Stack& stack, size_t i, size_t N) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  return *(stack.end() - N + i);
}
inline IValue& peek(Stack* stack, size_t i, size_t N) {
  return peek(*stack, i, N);
}
inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  return *(stack.end() - N + i);
}
inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
  return peek(*stack, i, N);
}
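// For example, with stack contents [..., a, b, c] and N = 3:
//   peek(stack, 0, 3) refers to a, and peek(stack, 2, 3) refers to c.
// The returned reference stays valid only until the stack is resized.
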
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
inline at::ArrayRef<IValue> peekSlice(
    const Stack& stack,
    size_t i,
    size_t len,
    size_t N) {
  return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
}
inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
  return peekSlice(stack, 0, N, N);
}
inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
  return last(*stack, N);
}
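// For example, last(stack, 2) views the top two elements as a non-owning
// ArrayRef, without copying or popping them:
//   at::ArrayRef<IValue> args = last(stack, 2);
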
inline void drop(Stack& stack, size_t n) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  stack.erase(stack.end() - n, stack.end());
}
inline void drop(Stack* stack, size_t n) {
  drop(*stack, n);
}
inline IValue pop(Stack& stack) {
  auto r = std::move(stack.back());
  stack.pop_back();
  return r;
}
inline IValue pop(Stack* stack) {
  return pop(*stack);
}
inline std::vector<IValue> pop(Stack& stack, size_t n) {
  std::vector<IValue> result;
  result.reserve(n);
  for (const auto i : c10::irange(n)) {
    result.push_back(std::move(peek(stack, i, n)));
  }
  drop(stack, n);
  return result;
}
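// For example, pop(stack, 2) removes the top two values and returns them
// in stack order: result[0] is the deeper of the two, result[1] the top.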

// variadic pop:
// int64_t a; at::Tensor b;
// pop(stack, a, b);
// equivalent to:
// b = pop(stack).toTensor();
// a = pop(stack).toInt();
template <typename... Types>
inline void pop(Stack& stack, Types&... args) {
  size_t i = 0;
  constexpr size_t N = sizeof...(args);
  (void)std::initializer_list<int>{
      (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
  drop(stack, N);
}
template <typename... Types>
inline void pop(Stack* stack, Types&... args) {
  pop(*stack, args...);
}
template <typename Type>
inline void push_one(Stack& stack, Type&& arg) {
  stack.emplace_back(std::forward<Type>(arg));
}

inline void push_one(Stack& stack, c10::TensorOptions options) {
  stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
  stack.emplace_back(options.layout());
  stack.emplace_back(options.device());
  stack.emplace_back(options.pinned_memory());
}

template <typename... Types>
inline void push(Stack& stack, Types&&... args) {
  (void)std::initializer_list<int>{(push_one(stack, std::forward<Types>(args)), 0)...};
}
template <typename... Types>
inline void push(Stack* stack, Types&&... args) {
  return push(*stack, std::forward<Types>(args)...);
}
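
// Note that pushing a c10::TensorOptions expands into four IValues (dtype as
// a ScalarType, layout, device, pin_memory) via the push_one overload above.
// For example,
//   push(stack, 1.0, at::TensorOptions().dtype(at::kFloat));
// pushes five IValues in total.
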
template <class T>
inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
  for (T elem : elements) {
    stack.push_back(std::move(elem));
  }
}

// The packer here is carefully written not to make any unnecessary
// copies.

// pack takes the return values of aten functions and pushes them onto the
// stack
template <typename T>
inline void pack(Stack& stack, T&& v) {
  stack.emplace_back(std::forward<T>(v));
}
template <typename T>
inline void pack(Stack* stack, T&& v) {
  pack(*stack, std::forward<T>(v));
}

template <std::size_t remaining, typename... Args>
struct TuplePacker {
  // NB: *Not* a universal reference.
  static void execute(Stack& stack, std::tuple<Args...>&& t) {
    // NB: The move here does not "destroy" the entire tuple, that is
    // not what std::move does; only the particular tuple index
    // processed here gets stolen.
    pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
    TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
  }
};

template <typename... Args>
struct TuplePacker<0, Args...> {
  // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/) {}
};

template <typename... Args>
inline void pack(Stack& stack, std::tuple<Args...>&& t) {
  TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
}
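
// Example (illustrative): packing a multi-output kernel's return value.
// A kernel returning std::tuple<at::Tensor, at::Tensor> can push both
// outputs onto the stack, in tuple order, with a single call:
//
//   std::tuple<at::Tensor, at::Tensor> result = /* ... */;
//   pack(stack, std::move(result));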

} // namespace torch::jit