xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/dtype_analysis.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/function_schema.h>
2 #include <ATen/core/jit_type.h>
3 #include <ATen/core/symbol.h>
4 #include <c10/core/ScalarType.h>
5 #include <c10/util/ArrayRef.h>
6 #include <torch/csrc/jit/ir/alias_analysis.h>
7 #include <torch/csrc/jit/ir/ir.h>
8 #include <torch/csrc/jit/jit_log.h>
9 #include <torch/csrc/jit/passes/dtype_analysis.h>
10 #include <torch/csrc/jit/passes/utils/op_registry.h>
11 #include <torch/library.h>
12 #include <optional>
13 
14 #ifndef AT_PER_OPERATOR_HEADERS
15 #include <ATen/Functions.h>
16 #else
17 #include <ATen/ops/empty.h>
18 #endif
19 
20 #include <algorithm>
21 #include <memory>
22 #include <stdexcept>
23 
24 namespace torch::jit {
25 
26 namespace {
27 
28 using Tensor = at::Tensor;
29 using ScalarType = at::ScalarType;
30 
31 // ----------------------------------------------------------------------------------
32 // Metatensor Inference for Dtype
33 // ----------------------------------------------------------------------------------
34 
MTensorArgumentCreator(Node * n)35 std::unique_ptr<Stack> MTensorArgumentCreator(Node* n) {
36   auto stack = std::make_unique<std::vector<IValue>>();
37   for (Value* inp : n->inputs()) {
38     if (auto tp = inp->type()->cast<TensorType>()) {
39       // Zero-dim tensors have special type promotion behavior, hence the need
40       // for rank.
41       auto rank = tp->symbolic_sizes().rank(); // Validity checked earlier
42       auto tensor_size = std::vector<int64_t>(rank.value(), 1);
43       stack->emplace_back(at::empty(
44           tensor_size, at::TensorOptions(at::kMeta).dtype(*tp->scalarType())));
45       continue;
46     }
47     // Someday Todo: Fill in concrete values that we know.
48     if (inp->type() == FloatType::get()) {
49       stack->emplace_back(1.);
50     } else if (inp->type() == IntType::get()) {
51       stack->emplace_back(1);
52     } else if (inp->type() == BoolType::get()) {
53       throw std::runtime_error(
54           "Bool currently unsupported, need to verify it's safe to add for all ops");
55       stack->emplace_back(false);
56     } else {
57       // Arrays of values are specifically not handled due
58       // to the fact that naive default values would likely be
59       // incorrect anyways.
60       throw std::runtime_error("Unsupported input type for Tensor argument");
61     }
62   }
63   return stack;
64 };
65 
MTensorNodeArgValid(Value * value)66 bool MTensorNodeArgValid(Value* value) {
67   auto tensor_type = value->type()->cast<TensorType>();
68   if (!tensor_type) {
69     return true;
70   }
71   if (!tensor_type->scalarType().has_value()) {
72     GRAPH_DEBUG("Argument missing Dtype");
73     return false;
74   }
75   auto rank = tensor_type->symbolic_sizes().rank();
76   return rank.has_value();
77 }
78 
canBeInferredWithMetaTensor(Node * n)79 static bool canBeInferredWithMetaTensor(Node* n) {
80   // Not a guarantee that the metatensor will not error out
81   // Do not have a allowlist for now and let things error out in execution.
82   // Has Tensor output is checked in another place
83   bool args_valid =
84       std::all_of(n->inputs().begin(), n->inputs().end(), MTensorNodeArgValid);
85 
86   if (!args_valid) {
87     return false;
88   }
89   if (n->outputs().size() != 1) {
90     // Currently not supporting multiple outputs
91     return false;
92   }
93   auto opt_op = n->maybeOperator();
94   if (!opt_op) {
95     GRAPH_DEBUG("not registered with Meta");
96     return false;
97   }
98   return true;
99 }
100 
inferWithMetaTensor(Node * n)101 std::optional<Tensor> inferWithMetaTensor(Node* n) {
102   GRAPH_DEBUG("inferWithMetaTensor", getHeader(n));
103   if (!canBeInferredWithMetaTensor(n)) {
104     return std::nullopt;
105   }
106   Operation op = n->getOperation();
107   try {
108     auto stack = MTensorArgumentCreator(n);
109     GRAPH_DEBUG("Running op for ", getHeader(n));
110     op(*stack);
111     GRAPH_DEBUG("op run successfully", getHeader(n));
112     GRAPH_DEBUG("After receive!");
113     return stack->back().toTensor();
114 
115   } catch (...) {
116     GRAPH_DEBUG("caught exception with Metatensor run!");
117   };
118   return std::nullopt;
119 }
120 
setDtype(Value * value,ScalarType scalarType,bool can_overwrite_dtype=false)121 bool setDtype(
122     Value* value,
123     ScalarType scalarType,
124     bool can_overwrite_dtype = false) {
125   auto tensor_type = value->type()->cast<TensorType>();
126   TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
127   if (!tensor_type->scalarType().has_value()) {
128     value->setType(tensor_type->withScalarType(scalarType));
129     return true;
130   }
131   if (tensor_type->scalarType().value() != scalarType) {
132     TORCH_INTERNAL_ASSERT(
133         can_overwrite_dtype,
134         "Expected tensor type to be ",
135         scalarType,
136         " but found ",
137         tensor_type->scalarType().value());
138     value->setType(tensor_type->withScalarType(scalarType));
139     return true;
140   }
141   return false;
142 }
143 
tryApplyDtypeMetaTensor(Node * n)144 bool tryApplyDtypeMetaTensor(Node* n) {
145   // returns if anything was changed
146   auto return_tensor = inferWithMetaTensor(n);
147   if (!return_tensor) {
148     return false;
149   }
150   GRAPH_DEBUG("Received ", toString(return_tensor->scalar_type()));
151   return setDtype(n->output(), return_tensor->scalar_type());
152 }
153 
154 // ----------------------------------------------------------------------------------
155 // Custom Rules for Dtype
156 // ----------------------------------------------------------------------------------
// A custom dtype-propagation rule: applied to a node, returns true if the
// dtype information on that node was changed.
using DtypePropRule = std::function<bool(Node*)>;
160 
setIfAllDtypeMatch(Node * n)161 bool setIfAllDtypeMatch(Node* n) {
162   // Sets all tensor outputs to the dtype of the first input
163   // only if all inputs are the same dtype, otherwise do nothing
164   TORCH_INTERNAL_ASSERT(!n->inputs().empty());
165   auto first_arg = n->inputs().at(0);
166   auto tensor_type = first_arg->type()->cast<TensorType>();
167   TORCH_INTERNAL_ASSERT(tensor_type, "Expecting a tensor type");
168   auto scalar_type = tensor_type->scalarType();
169   if (!scalar_type.has_value()) {
170     return false;
171   }
172   for (auto arg : n->inputs()) {
173     tensor_type = arg->type()->cast<TensorType>();
174     if (!tensor_type) {
175       continue;
176     }
177     auto arg_scalar_type = tensor_type->scalarType();
178 
179     if (!arg_scalar_type.has_value()) { // Allow None for optional args
180       continue;
181     }
182     if (arg_scalar_type != scalar_type) {
183       return false;
184     }
185   }
186 
187   bool changed = false;
188   for (auto output : n->outputs()) {
189     if (output->type()->cast<TensorType>()) {
190       changed |= setDtype(output, scalar_type.value());
191     }
192   }
193   return changed;
194 }
195 
196 // DtypePropagationPass is an analysis pass that walks through a graph in
197 // topological order and forward propagate Dtypes (ScalarTypes) from graph
198 // inputs (expressed in input_descriptors) to all output tensor nodes in the
199 // graph.
struct DtypePropagationPass {
  // Takes shared ownership of `graph` and eagerly builds the custom-rule
  // registry consulted by processAtenOps().
  explicit DtypePropagationPass(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {
    buildDtypeRuleRegistry();
  }

  // returns true if at least one node has its scalar type set on a tensor node
  bool run() {
    return processBlocks(graph_->block());
  }

 private:
  // Applies processBlock to each block; true if any block changed.
  bool processBlocks(at::ArrayRef<Block*> blocks) {
    bool changed = false;
    for (auto block : blocks) {
      changed |= processBlock(block);
    }
    return changed;
  }

  // Visits the block's nodes in order; true if any node changed.
  bool processBlock(Block* block) {
    GRAPH_DEBUG("processBlock");
    bool changed = false;
    for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
      changed |= processNode(*it);
    }
    return changed;
  }

  // Dispatches on node kind. Hard-asserts on constructs the pass does not
  // support (loops, calls, list construct/unpack); nodes with no tensor
  // output are skipped since there is nothing to propagate.
  bool processNode(Node* n) {
    GRAPH_DEBUG("processNode");
    switch (n->kind()) {
      case prim::If:
        return processIf(n);
      case prim::Loop:
      case prim::CallMethod:
      case prim::CallFunction:
        TORCH_INTERNAL_ASSERT(false, "Loop/Call not handled now");
      default:
        break;
    }

    bool has_tensor_output =
        std::any_of(n->outputs().begin(), n->outputs().end(), [](Value* v) {
          return (bool)v->type()->cast<TensorType>();
        });

    if (!has_tensor_output) {
      // if output contains no tensor, nothing to propagate
      return false;
    }

    switch (n->kind()) {
      case prim::Constant:
        // This is already been propagated by something else in freezing
        return false;
      case prim::ListConstruct:
      case prim::ListUnpack:
        TORCH_INTERNAL_ASSERT(
            false,
            "List Construct and Unpack is not supported in Dtype Propagation");
        break;
      default:
        if (n->kind().is_aten()) {
          return processAtenOps(n);
        } else {
          TORCH_INTERNAL_ASSERT(
              false,
              n->kind().toDisplayString(),
              "Op is not supported in Dtype Propagation");
        }
    }
    return false;
  }

  // Merges dtype facts from the two branches of an If. Currently a stub that
  // only tolerates empty output lists.
  bool mergeTensorProperties(
      const at::ArrayRef<Value*>& list1,
      const at::ArrayRef<Value*>& list2) {
    // This is currently a placeholder for MobileNet
    // After Month1: implement the merge function
    TORCH_INTERNAL_ASSERT(list1.empty(), "Not implemented yet");
    return false;
  }

  // Processes both branches of a prim::If and merges their outputs.
  bool processIf(Node* node) {
    GRAPH_DEBUG("processIf");
    bool changed = false;
    auto blocks = node->blocks();
    auto true_block = blocks.at(0);
    auto false_block = blocks.at(1);

    changed |= processBlock(true_block);
    changed |= processBlock(false_block);

    changed |=
        mergeTensorProperties(true_block->outputs(), false_block->outputs());

    return changed;
  }

  // for efficiency
  // Custom registry rules take precedence; everything else falls back to
  // meta-tensor execution.
  bool processAtenOps(Node* n) {
    GRAPH_DEBUG("processAtenOps");
    GRAPH_DEBUG("case = ", n->kind(), " ", *n);
    // Custom Rule Matching
    if (auto prop_fn = dtype_prop_registry_->find(n->getOperator())) {
      DtypePropRule rule = *prop_fn;
      return rule(n);
    }
    return tryApplyDtypeMetaTensor(n);
  }

  void buildDtypeRuleRegistry() {
    // building a registry for all of the custom dtype rules
    dtype_prop_registry_ = std::make_unique<OperatorMap<DtypePropRule>>();

    dtype_prop_registry_->insert(
        *nn_ops_first_input_preserving(), setIfAllDtypeMatch);
    dtype_prop_registry_->insert(
        *ops_one_tensor_in_shape_transform(), setIfAllDtypeMatch);
  }
  // Registry of per-operator custom propagation rules.
  std::unique_ptr<OperatorMap<DtypePropRule>> dtype_prop_registry_;
  std::shared_ptr<Graph> graph_;
};
324 
325 } // anonymous namespace
326 
327 // This analysis propagates input dtypes (if any) throughout the
328 // graph.
DtypePropagation(std::shared_ptr<Graph> & graph)329 bool DtypePropagation(std::shared_ptr<Graph>& graph) {
330   DtypePropagationPass tp = DtypePropagationPass(graph);
331   bool changed = tp.run();
332   if (changed) {
333     GRAPH_DUMP("After TensorPropertyPropagation pass:", graph);
334   }
335   return changed;
336 }
337 
338 } // namespace torch::jit
339