#pragma once

#include <type_traits>

#include <ATen/core/ivalue.h>
#include <c10/util/Deprecated.h>
#include <c10/util/irange.h>

// TODO move this to c10 namespace

namespace torch::jit {

using c10::IValue;
using Stack = std::vector<IValue>;

class Operation {
  template <typename F, typename Arg>
  using accepts = std::is_constructible<std::function<void(Arg)>, F&&>;

 public:
  template <
      typename F,
      std::enable_if_t<accepts<F, Stack*>::value, int> = 0>
  C10_DEPRECATED_MESSAGE(
      "Please use void(Stack&) to register operator instead.")
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  Operation(F&& raw) : op_([raw = std::forward<F>(raw)](Stack& stack) {
    raw(&stack);
  }) {}

  template <
      typename F,
      std::enable_if_t<
          accepts<F, Stack&>::value &&
              !std::is_same_v<std::decay_t<F>, Operation>,
          int> = 0>
  Operation(F&& op) : op_(std::forward<F>(op)) {}

  Operation(std::nullptr_t) noexcept {}

  explicit operator bool() const noexcept {
    return op_ ? true : false;
  }

  void operator()(Stack& stack) {
    op_(stack);
  }

  template <typename T>
  T* target() noexcept {
    return op_.target<T>();
  }

 private:
  std::function<void(Stack&)> op_;
};
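
// A minimal usage sketch (not part of this header): the preferred
// constructor takes any callable with signature void(Stack&); the
// Stack* form still compiles but is deprecated.
//
//   Operation op([](Stack& stack) {
//     // read and write `stack` here
//   });
//   if (op) {
//     Stack stack;
//     op(stack);
//   }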

// An operation with N inputs and M outputs pops the last N inputs off
// the stack and pushes its M outputs onto the stack
// before: <other stack items> I0, I1, ... I(N-1) <- stack.back()
// after:  <other stack items> O0, O1, ... O(M-1)
// Operations are defined this way so that ownership of inputs can be
// transferred to the operation, and it can incrementally drop ownership of
// tensors when they become unneeded. For large operations, like 'run an
// entire subgraph', this functionality is very important for minimizing GPU
// memory usage. The return value is the relative 'offset' to jump to for the
// next operation:
//   pc += 1 + offset
// so a return value of 0 goes to the next instruction.
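//
// For example, a hypothetical 2-input, 1-output "add" operation (a sketch,
// not an operator defined here) would transform the stack like this:
//
//   before: ... 3, 4   <- stack.back() == IValue(4)
//   after:  ... 7
//
//   Operation add([](Stack& stack) {
//     int64_t a = 0, b = 0;
//     pop(stack, a, b); // a = 3, b = 4 (see variadic pop below)
//     push(stack, a + b);
//   });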

// treat the last N elements of the stack as a list, looking up
// element i
inline IValue& peek(Stack& stack, size_t i, size_t N) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  return *(stack.end() - N + i);
}
inline IValue& peek(Stack* stack, size_t i, size_t N) {
  return peek(*stack, i, N);
}
inline const IValue& peek(const Stack& stack, size_t i, size_t N) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  return *(stack.end() - N + i);
}
inline const IValue& peek(const Stack* stack, size_t i, size_t N) {
  return peek(*stack, i, N);
}
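
// Example (sketch): with the stack holding ... 10, 20, 30 (30 on top),
// the last 3 values can be read in place, without popping:
//
//   Stack stack;
//   push(stack, 10, 20, 30);
//   int64_t first = peek(stack, 0, 3).toInt(); // 10
//   int64_t top   = peek(stack, 2, 3).toInt(); // 30
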
// treat the last N elements of the stack as a list, looking up the
// slice starting at index i and having length len
inline at::ArrayRef<IValue> peekSlice(
    const Stack& stack,
    size_t i,
    size_t len,
    size_t N) {
  return at::ArrayRef<IValue>(stack).slice(stack.size() - N + i, len);
}
inline at::ArrayRef<IValue> last(const Stack& stack, size_t N) {
  return peekSlice(stack, 0, N, N);
}
inline at::ArrayRef<IValue> last(const Stack* stack, size_t N) {
  return last(*stack, N);
}
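
// Example (sketch): `last` views all N trailing elements. The returned
// ArrayRef aliases the stack's storage, so any push/pop/drop invalidates it:
//
//   Stack stack;
//   push(stack, 1, 2, 3, 4);
//   at::ArrayRef<IValue> args = last(stack, 3);            // views 2, 3, 4
//   at::ArrayRef<IValue> tail = peekSlice(stack, 1, 2, 3); // views 3, 4
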
inline void drop(Stack& stack, size_t n) {
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions)
  stack.erase(stack.end() - n, stack.end());
}
inline void drop(Stack* stack, size_t n) {
  drop(*stack, n);
}
inline IValue pop(Stack& stack) {
  auto r = std::move(stack.back());
  stack.pop_back();
  return r;
}
inline IValue pop(Stack* stack) {
  return pop(*stack);
}
inline std::vector<IValue> pop(Stack& stack, size_t n) {
  std::vector<IValue> result;
  result.reserve(n);
  for (const auto i : c10::irange(n)) {
    result.push_back(std::move(peek(stack, i, n)));
  }
  drop(stack, n);
  return result;
}
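
// Example (sketch): `pop(stack)` removes and returns the top value,
// `pop(stack, n)` returns the last n values in stack order, and `drop`
// discards values without returning them:
//
//   Stack stack;
//   push(stack, 1, 2, 3);
//   IValue top = pop(stack);                  // 3; stack now holds 1, 2
//   std::vector<IValue> rest = pop(stack, 2); // {1, 2}; stack now empty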

// variadic pop:
// int64_t a; at::Tensor b;
// pop(stack, a, b);
// equivalent to:
// b = pop(stack).toTensor();
// a = pop(stack).toInt();
template <typename... Types>
inline void pop(Stack& stack, Types&... args) {
  size_t i = 0;
  constexpr size_t N = sizeof...(args);
  (void)std::initializer_list<int>{
      (args = std::move(peek(stack, i++, N)).template to<Types>(), 0)...};
  drop(stack, N);
}
template <typename... Types>
inline void pop(Stack* stack, Types&... args) {
  pop(*stack, args...);
}
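
// Example (sketch): popping heterogeneous arguments in one call; each
// target type must be convertible via IValue::to<T>():
//
//   Stack stack;
//   push(stack, 42, true);
//   int64_t n = 0;
//   bool flag = false;
//   pop(stack, n, flag); // n = 42, flag = true; stack now empty
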
template <typename Type>
inline void push_one(Stack& stack, Type&& arg) {
  stack.emplace_back(std::forward<Type>(arg));
}

inline void push_one(Stack& stack, c10::TensorOptions options) {
  stack.emplace_back(c10::typeMetaToScalarType(options.dtype()));
  stack.emplace_back(options.layout());
  stack.emplace_back(options.device());
  stack.emplace_back(options.pinned_memory());
}

template <typename... Types>
inline void push(Stack& stack, Types&&... args) {
  (void)std::initializer_list<int>{
      (push_one(stack, std::forward<Types>(args)), 0)...};
}
template <typename... Types>
inline void push(Stack* stack, Types&&... args) {
  return push(*stack, std::forward<Types>(args)...);
}
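
// Example (sketch): `push` wraps each argument in an IValue. Note that the
// TensorOptions overload above expands into four IValues (dtype, layout,
// device, pinned_memory), so a single argument can grow the stack by four:
//
//   Stack stack;
//   push(stack, 1.5, std::string("hi"));                 // two IValues
//   push(stack, c10::TensorOptions().dtype(at::kFloat)); // four IValues
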
template <class T>
inline void push_list_elements(Stack& stack, const c10::List<T>& elements) {
  for (T elem : elements) {
    stack.push_back(std::move(elem));
  }
}
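
// Example (sketch): splatting a c10::List onto the stack element by element:
//
//   c10::List<int64_t> xs({1, 2, 3});
//   Stack stack;
//   push_list_elements(stack, xs); // stack now holds 1, 2, 3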

// The packer here is carefully written not to make any unnecessary
// copies.

// pack takes the return values of aten functions and pushes them onto the
// stack
template <typename T>
inline void pack(Stack& stack, T&& v) {
  stack.emplace_back(std::forward<T>(v));
}
template <typename T>
inline void pack(Stack* stack, T&& v) {
  pack(*stack, std::forward<T>(v));
}

template <std::size_t remaining, typename... Args>
struct TuplePacker {
  // NB: *Not* a universal reference.
  static void execute(Stack& stack, std::tuple<Args...>&& t) {
    // NB: The move here does not "destroy" the entire tuple, that is
    // not what std::move does; only the particular tuple index
    // processed here gets stolen.
    pack(stack, std::get<sizeof...(Args) - remaining>(std::move(t)));
    TuplePacker<remaining - 1, Args...>::execute(stack, std::move(t));
  }
};

template <typename... Args>
struct TuplePacker<0, Args...> {
  // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  static void execute(Stack& /*stack*/, std::tuple<Args...>&& /*t*/) {}
};

template <typename... Args>
inline void pack(Stack& stack, std::tuple<Args...>&& t) {
  TuplePacker<sizeof...(Args), Args...>::execute(stack, std::move(t));
}
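
// Example (sketch): packing a multi-value return; the tuple elements land
// on the stack in order, each moved rather than copied:
//
//   Stack stack;
//   pack(stack, std::make_tuple(int64_t(1), std::string("out")));
//   // stack now holds IValue(1), IValue("out")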

} // namespace torch::jit