#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/symbol.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dtype_analysis.h>
#include <torch/csrc/jit/passes/utils/op_registry.h>
#include <torch/library.h>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#endif

#include <algorithm>
#include <memory>
#include <stdexcept>

namespace torch::jit {

namespace {

using Tensor = at::Tensor;
using ScalarType = at::ScalarType;

// ----------------------------------------------------------------------------------
// Metatensor Inference for Dtype
// ----------------------------------------------------------------------------------

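// The core trick: run the node's meta kernel on "meta" tensors, which carry
// shape and dtype but no storage, and read the inferred dtype off the output.
// A minimal standalone sketch of the idea (illustrative only, not part of
// this pass):
//
//   at::Tensor a =
//       at::empty({1}, at::TensorOptions(at::kMeta).dtype(at::kFloat));
//   at::Tensor b =
//       at::empty({1}, at::TensorOptions(at::kMeta).dtype(at::kDouble));
//   at::Tensor out = a.add(b); // meta kernel runs; no data is touched
//   // out.scalar_type() == at::kDouble, per standard type promotion
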
std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) {
  auto stack = std::make_unique<std::vector<IValue>>();
  for (Value* inp : n->inputs()) {
    if (auto tp = inp->type()->cast<TensorType>()) {
      // Zero-dim tensors have special type promotion behavior, hence the need
      // for rank.
      auto rank = tp->symbolic_sizes().rank(); // Validity checked earlier
      auto tensor_size = std::vector<int64_t>(rank.value(), 1);
      stack->emplace_back(at::empty(
          tensor_size, at::TensorOptions(at::kMeta).dtype(*tp->scalarType())));
      continue;
    }
    // TODO (someday): Fill in concrete values that we know.
    if (inp->type() == FloatType::get()) {
      stack->emplace_back(1.);
    } else if (inp->type() == IntType::get()) {
      stack->emplace_back(1);
    } else if (inp->type() == BoolType::get()) {
      throw std::runtime_error(
          "Bool currently unsupported, need to verify it's safe to add for all ops");
    } else {
      // Lists of values are deliberately not handled, since naive default
      // values would likely be incorrect for most ops.
      throw std::runtime_error("Unsupported input type for Tensor argument");
    }
  }
  return stack;
}
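
// As an example of what the helper above produces (hypothetical IR, for
// illustration): given
//   %out : Tensor = aten::add(%a : Float(*, *), %b : Float(*, *), %alpha : int)
// the stack would hold two {1, 1} meta tensors with dtype Float, followed by
// the int 1 standing in for %alpha.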

bool MTensorNodeArgValid(Value* value) {
  auto tensor_type = value->type()->cast<TensorType>();
  if (!tensor_type) {
    return true;
  }
  if (!tensor_type->scalarType().has_value()) {
    GRAPH_DEBUG("Argument missing Dtype");
    return false;
  }
  auto rank = tensor_type->symbolic_sizes().rank();
  return rank.has_value();
}

static bool canBeInferredWithMetaTensor(Node* n) {
  // Not a guarantee that the metatensor run will not error out; we do not
  // have an allowlist for now and let things error out in execution.
  // Whether the node has a Tensor output is checked elsewhere.
  bool args_valid =
      std::all_of(n->inputs().begin(), n->inputs().end(), MTensorNodeArgValid);

  if (!args_valid) {
    return false;
  }
  if (n->outputs().size() != 1) {
    // Multiple outputs are currently not supported
    return false;
  }
  auto opt_op = n->maybeOperator();
  if (!opt_op) {
    GRAPH_DEBUG("not registered with Meta");
    return false;
  }
  return true;
}

std::optional<Tensor> inferWithMetaTensor(Node* n) {
  GRAPH_DEBUG("inferWithMetaTensor ", getHeader(n));
  if (!canBeInferredWithMetaTensor(n)) {
    return std::nullopt;
  }
  Operation op = n->getOperation();
  try {
    auto stack = MTensorArgumentCreator(n);
    GRAPH_DEBUG("Running op for ", getHeader(n));
    op(*stack);
    GRAPH_DEBUG("op ran successfully ", getHeader(n));
    return stack->back().toTensor();
  } catch (...) {
    GRAPH_DEBUG("caught exception with Metatensor run!");
  }
  return std::nullopt;
}

bool setDtype(
    Value* value,
    ScalarType scalarType,
    bool can_overwrite_dtype = false) {
  auto tensor_type = value->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
  if (!tensor_type->scalarType().has_value()) {
    value->setType(tensor_type->withScalarType(scalarType));
    return true;
  }
  if (tensor_type->scalarType().value() != scalarType) {
    TORCH_INTERNAL_ASSERT(
        can_overwrite_dtype,
        "Expected tensor type to be ",
        scalarType,
        " but found ",
        tensor_type->scalarType().value());
    value->setType(tensor_type->withScalarType(scalarType));
    return true;
  }
  return false;
}

bool tryApplyDtypeMetaTensor(Node* n) {
  // Returns true if anything was changed
  auto return_tensor = inferWithMetaTensor(n);
  if (!return_tensor) {
    return false;
  }
  GRAPH_DEBUG("Received ", toString(return_tensor->scalar_type()));
  return setDtype(n->output(), return_tensor->scalar_type());
}

// ----------------------------------------------------------------------------------
// Custom Rules for Dtype
// ----------------------------------------------------------------------------------

// Function to propagate dtype information for a node
// Returns true if the dtype information was changed
using DtypePropRule = std::function<bool(Node*)>;

bool setIfAllDtypeMatch(Node* n) {
  // Sets all tensor outputs to the dtype of the first input, but only if all
  // tensor inputs share that dtype; otherwise does nothing.
  TORCH_INTERNAL_ASSERT(!n->inputs().empty());
  auto first_arg = n->inputs().at(0);
  auto tensor_type = first_arg->type()->cast<TensorType>();
  TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
  auto scalar_type = tensor_type->scalarType();
  if (!scalar_type.has_value()) {
    return false;
  }
  for (auto arg : n->inputs()) {
    tensor_type = arg->type()->cast<TensorType>();
    if (!tensor_type) {
      continue;
    }
    auto arg_scalar_type = tensor_type->scalarType();

    if (!arg_scalar_type.has_value()) { // Allow None for optional args
      continue;
    }
    if (arg_scalar_type != scalar_type) {
      return false;
    }
  }

  bool changed = false;
  for (auto output : n->outputs()) {
    if (output->type()->cast<TensorType>()) {
      changed |= setDtype(output, scalar_type.value());
    }
  }
  return changed;
}
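
// A worked example of the rule above (hypothetical IR, for illustration):
//   %r : Tensor = aten::relu(%x : Float(*, *))
// has a single tensor input with dtype Float, so %r is set to Float and the
// rule returns true. By contrast, for an op whose tensor inputs mix Float and
// Double, the dtypes disagree and the rule leaves the outputs untouched.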

// DtypePropagationPass is an analysis pass that walks through a graph in
// topological order and forward-propagates dtypes (ScalarTypes) from graph
// inputs (as specified on the input tensor types) to all tensor outputs in
// the graph.
struct DtypePropagationPass {
  explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {
    buildDtypeRuleRegistry();
  }

  // Returns true if the scalar type was set on at least one tensor value
  bool run() {
    return processBlocks(graph_->block());
  }

 private:
  bool processBlocks(at::ArrayRef<Block*> blocks) {
    bool changed = false;
    for (auto block : blocks) {
      changed |= processBlock(block);
    }
    return changed;
  }

  bool processBlock(Block* block) {
    GRAPH_DEBUG("processBlock");
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
      changed |= processNode(*it);
    }
    return changed;
  }

  bool processNode(Node* n) {
    GRAPH_DEBUG("processNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
      case prim::CallMethod:
      case prim::CallFunction:
        TORCH_INTERNAL_ASSERT(false, "Loop/Call not handled yet");
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // If the output contains no tensor, there is nothing to propagate
      return false;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This has already been propagated by something else during freezing
        return false;
      case prim::ListConstruct:
      case prim::ListUnpack:
        TORCH_INTERNAL_ASSERT(
            false,
            "ListConstruct and ListUnpack are not supported in Dtype Propagation");
        break;
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          TORCH_INTERNAL_ASSERT(
              false,
              n->kind().toDisplayString(),
              " op is not supported in Dtype Propagation");
        }
    }
    return false;
  }

  bool mergeTensorProperties(
      const at::ArrayRef<Value*>& list1,
      const at::ArrayRef<Value*>& list2) {
    // This is currently a placeholder for MobileNet
    // After Month1: implement the merge function
    TORCH_INTERNAL_ASSERT(list1.empty(), "Not implemented yet");
    return false;
  }

  bool processIf(Node* node) {
    GRAPH_DEBUG("processIf");
    bool changed = false;
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    changed |= processBlock(true_block);
    changed |= processBlock(false_block);

    changed |=
        mergeTensorProperties(true_block->outputs(), false_block->outputs());

    return changed;
  }

  // for efficiency
  bool processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    if (auto prop_fn = dtype_prop_registry_->find(n->getOperator())) {
      DtypePropRule rule = *prop_fn;
      return rule(n);
    }
    return tryApplyDtypeMetaTensor(n);
  }

  void buildDtypeRuleRegistry() {
    // building a registry for all of the custom dtype rules
    dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();

    dtype_prop_registry_->insert(
        *nn_ops_first_input_preserving(), setIfAllDtypeMatch);
    dtype_prop_registry_->insert(
        *ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
  }
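
  // Extending the registry with another rule would follow the same shape
  // (a hypothetical sketch; `my_op_list` and `myDtypeRule` are illustrative
  // names, not real helpers):
  //
  //   dtype_prop_registry_->insert(*my_op_list(), myDtypeRule);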

  std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
  std::shared_ptr<Graph> graph_;
};

} // anonymous namespace

// This analysis propagates input dtypes (if any) throughout the
// graph.
bool DtypePropagation(std::shared_ptr<Graph>& graph) {
  DtypePropagationPass tp(graph);
  bool changed = tp.run();
  if (changed) {
    GRAPH_DUMP("After TensorPropertyPropagation pass:", graph);
  }
  return changed;
}
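
// A minimal usage sketch (assuming a graph whose tensor inputs already carry
// dtypes, e.g. set via Value::setType on graph->inputs(); illustrative only):
//
//   std::shared_ptr<Graph> graph = ...; // e.g. a frozen module's forward graph
//   bool changed = DtypePropagation(graph);
//   // Tensor outputs now carry any ScalarTypes the pass could infer.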

} // namespace torch::jit