1 #include <torch/csrc/jit/passes/shape_analysis.h>
2
3 #include <c10/util/Exception.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/frontend/error_report.h>
6 #include <torch/csrc/jit/ir/alias_analysis.h>
7 #include <torch/csrc/jit/ir/constants.h>
8 #include <torch/csrc/jit/ir/ir.h>
9 #include <torch/csrc/jit/ir/ir_views.h>
10 #include <torch/csrc/jit/passes/utils/op_registry.h>
11 #include <torch/csrc/jit/runtime/exception_message.h>
12 #include <torch/csrc/jit/runtime/operator.h>
13
14 #include <torch/csrc/autograd/variable.h>
15
16 #include <ATen/DeviceGuard.h>
17 #include <ATen/ExpandUtils.h>
18 #include <ATen/core/symbol.h>
19
20 #ifndef AT_PER_OPERATOR_HEADERS
21 #include <ATen/Functions.h>
22 #else
23 #include <ATen/ops/empty_strided.h>
24 #endif
25
26 #include <exception>
27 #include <memory>
28 #include <sstream>
29 #include <utility>
30 #include <vector>
31
32 namespace torch::jit {
33
bool mergeTypes(
35 ArrayRef<Value*> lhs,
36 ArrayRef<Value*> rhs,
37 ArrayRef<Value*> outputs) {
38 AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
39 bool changed = false;
40 for (const auto i : c10::irange(lhs.size())) {
41 auto old_output_type = outputs[i]->type();
42 auto new_type =
43 unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_union=*/true);
44 AT_ASSERT(new_type);
45 outputs[i]->setType(*new_type);
46 if (*old_output_type != *outputs[i]->type())
47 changed = true;
48 }
49 return changed;
50 }
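// Illustrative sketch (assuming the usual JIT type lattice): if the two
// branches of an If produce Float(2, 3) and Float(4, 5) for the same output,
// unifyTypes widens the block output to a dimensioned-only Float tensor type,
// and `changed` reports that callers (e.g. the Loop fixed point below) must
// re-propagate with the widened type.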
51
static void applyTypes(ArrayRef<Value*> src, ArrayRef<Value*> dst) {
53 AT_ASSERT(src.size() == dst.size());
54 for (const auto i : c10::irange(src.size())) {
55 dst[i]->setType(src[i]->type());
56 }
57 }
58
void PropertyPropBase::propagateBlock(Block* block, bool insert_expands) {
60 for (Node* node : block->nodes()) {
61 try {
62 propagateNode(node, insert_expands);
63 } catch (propagation_error& e) {
64 setUnshapedType(node);
65 } catch (std::exception& e) {
66 throw(
67 ErrorReport(node->sourceRange())
68 << ExceptionMessage(e)
69 << "\nThe above operation failed shape propagation in this context");
70 }
71 }
72 }
73
void PropertyPropBase::processIf(Node* node) {
75 auto then_block = node->blocks().at(0);
76 auto else_block = node->blocks().at(1);
77 propagateBlock(then_block);
78 propagateBlock(else_block);
79 mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs());
80 }
81
void PropertyPropBase::processLoop(Node* node) {
83 LoopView loop(node);
84 // propagate counter type
85 loop.currentTripCount()->setType(loop.maxTripCount()->type());
86 applyTypes(loop.carriedInputs(), loop.bodyCarriedInputs());
87
88 do {
89 propagateBlock(loop.bodyBlock(), /*insert_expands=*/false);
90 // note: inserting expands is unsafe at this point, we don't know
91 // if the types are stable yet, so the arguments to expand may change
92 } while (mergeTypes(
93 loop.bodyCarriedInputs(),
94 loop.bodyCarriedOutputs(),
95 loop.bodyCarriedInputs()));
96
97 // now that the types are stable, we can insert the expands
98 propagateBlock(loop.bodyBlock(), /*insert_expands=*/true);
99 applyTypes(loop.bodyCarriedInputs(), loop.carriedOutputs());
100 }
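// Why the do/while above: a body pass may widen a carried type (say, a value
// entering as a 0-dim Long tensor but reassigned to a floating-point tensor
// inside the body), which in turn changes the block inputs for the next pass.
// Iterating mergeTypes until nothing changes reaches a fixed point, and only
// then is it safe to insert expand nodes against the stable types.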
101
void PropertyPropBase::setUnshapedType(Value* o) {
103 o->setType(unshapedType(o->type()));
104 }
105
void PropertyPropBase::setUnshapedType(Node* node) {
107 for (auto o : node->outputs()) {
108 setUnshapedType(o);
109 }
110 }
111
112 namespace prim {
113 using namespace ::c10::prim;
114 }
115
116 #define SHAPE_ASSERT(cond) \
117 if (!(cond)) \
118 throw propagation_error()
119
120 namespace {
121
bool isValidArgumentForRunning(Value* v) {
123 // allow constants
124 if (toIValue(v))
125 return true;
126 if (TensorTypePtr tt = v->type()->cast<TensorType>()) {
127 if (!tt->scalarType()) {
128 return false;
129 }
130 return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
131 }
132 return v->type()->isSubtypeOf(*FloatType::get());
133 }
134
bool isValidReturnForRunning(Value* v) {
136 return v->type()->isSubtypeOf(*TensorType::get()) ||
137 v->type()->isSubtypeOf(*NumberType::get());
138 }
139
bool containsTensorType(const TypePtr& t) {
141 auto n_contained = t->containedTypes().size();
142 if (n_contained == 1) {
143 return t->containedTypes().at(0)->isSubtypeOf(*TensorType::get());
144 } else if (n_contained > 1) {
145 return std::any_of(
146 t->containedTypes().begin(),
147 t->containedTypes().end(),
148 containsTensorType);
149 }
150 return false;
151 }
152
// For each argument in the schema with type Tensor, extract its TensorType.
// Returns std::nullopt if any Tensor in the schema does not have a known
// type (or a complete shape when `complete` is true); non-tensor inputs are
// ignored.
std::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
157 Node* node,
158 bool complete = false) {
159 std::vector<TensorTypePtr> tensor_types;
160
161 auto schema_opt = node->maybeSchema();
162 if (!schema_opt) {
163 return std::nullopt;
164 }
165 auto& schema = *schema_opt;
166 auto& args = schema.arguments();
167 // can't handle varargs primitives because we don't know what should be a
168 // Tensor
169 if (schema.is_vararg()) {
170 return std::nullopt;
171 }
172 for (const auto i : c10::irange(args.size())) {
173 if (args[i].type()->isSubtypeOf(*ListType::ofTensors())) {
174 return std::nullopt;
175 } else if (args[i].type()->isSubtypeOf(*TensorType::get())) {
176 if (auto type = node->input(i)->type()->cast<TensorType>()) {
177 if (complete && !type->isComplete()) {
178 return std::nullopt;
179 }
180 tensor_types.push_back(type);
181 } else {
182 return std::nullopt;
183 }
184 } else /* non-tensor type */ {
185 continue;
186 }
187 }
188 return tensor_types;
189 }
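// Illustrative example: for a call matching
// "aten::add(Tensor self, Tensor other, *, Scalar alpha)" this collects the
// TensorTypes of `self` and `other` and skips `alpha`; if either tensor
// input's type is unknown (or incomplete when `complete` is true), the whole
// call yields std::nullopt.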
190
int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
192 if (dim < 0) {
193 dim += (int64_t)sizes.size();
194 }
195 return dim;
196 }
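// e.g. wrapDim(-1, {2, 3, 4}) == 2; non-negative dims pass through unchanged
// (no bounds checking is done here).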
197
c10::ScalarType unionScalarTypes(
199 c10::ScalarType original,
200 c10::ScalarType next) {
201 if (original == c10::ScalarType::Undefined) {
202 return next;
203 } else {
204 return c10::promoteTypes(original, next);
205 }
206 }
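// e.g. unionScalarTypes(Undefined, Int) == Int, and
// unionScalarTypes(Int, Float) == Float via c10::promoteTypes.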
207
208 // Promotes result types for arithmetic operations on Tensor operands using
209 // new type promotion logic. See tensor_attributes.rst for details.
210 // This doesn't handle the case of arithmetic ops with Scalar arguments (when
211 // `Tensor.getUnsafeTensorImpl()->is_wrapped_number()` would return true)
std::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) {
213 c10::ScalarType dimmed = c10::ScalarType::Undefined;
214 c10::ScalarType zerodim = c10::ScalarType::Undefined;
// Binary arithmetic ops: only the first two inputs take part in promotion;
// any extra argument (e.g. alpha) is a scalar.
for (const auto i : c10::irange(2)) {
auto dtt = node->inputs()[i]->type()->expect<TensorType>();
if (!dtt) {
return std::nullopt;
}
auto inputDtype = dtt->scalarType();
if (!inputDtype) {
return std::nullopt;
}
222 if (dtt->dim() && *dtt->dim() > 0) {
223 dimmed = unionScalarTypes(dimmed, *inputDtype);
224 } else if (!isFloatingType(dimmed)) {
225 // if no dimensions
226 zerodim = unionScalarTypes(zerodim, *inputDtype);
227 }
228 }
229 // if a tensor with dimensions is already of the highest category, don't
230 // need to check zero-dim tensors.
231 if (isFloatingType(dimmed)) {
232 return dimmed;
233 }
234 // int_tensor * zero_dim_floating -> floating_tensor
235 if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) {
236 return zerodim;
237 }
238 // bool_tensor * non_bool_scalar -> non_bool_tensor
239 if (c10::ScalarType::Bool == dimmed &&
240 c10::ScalarType::Undefined != zerodim) {
241 return zerodim;
242 }
243 // types of dimensioned tensors generally take precedence over zero-dim
244 // tensors if not promoting due to category. e.g.:
245 // int_tensor * long -> int_tensor
246 if (c10::ScalarType::Undefined != dimmed) {
247 return dimmed;
248 }
249
250 // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor.
251 return zerodim;
252 }
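// Worked example of the rules above (illustrative): for
// int_tensor(2x2) * float_tensor() where the second operand is zero-dim,
// dimmed == Int and zerodim == Float, so the
// "int_tensor * zero_dim_floating" rule promotes the result to Float.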
253
254 class ShapePropagator : public PropertyPropBase {
255 public:
explicit ShapePropagator(const std::shared_ptr<Graph>& graph)
257 : PropertyPropBase(graph), aliasDb_(graph) {
258 collectResizeSet(graph->block());
259 }
260
261 private:
262 ValueSet resized_alias_set;
263 const AliasDb aliasDb_;
264
bool resizesInput(Node* n) {
266 static std::unordered_set<Symbol> resize_ops{
267 aten::resize_,
268 aten::resize_as_,
269 aten::copy_,
270 aten::set_,
271 aten::unsqueeze_,
272 aten::t_,
273 aten::transpose_,
274 };
275
276 if (resize_ops.count(n->kind()))
277 return true;
278
279 if (!n->maybeSchema())
280 return false;
281
282 // ops which take the result and write to input "out"
283 if (auto out_arg_index = n->schema().argumentIndexWithName("out")) {
284 auto arg = n->schema().arguments().at(*out_arg_index);
285 return arg.kwarg_only() && arg.type()->isSubtypeOf(*TensorType::get());
286 }
287 return false;
288 }
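// For example, an out-variant such as
// "aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
// has a kwarg-only Tensor argument named "out", so it is treated as
// (potentially) resizing that input.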
289
void collectResizeSet(Block* block) {
291 for (Node* n : block->nodes()) {
292 for (Block* b : n->blocks()) {
293 collectResizeSet(b);
294 }
295 if (resizesInput(n)) {
296 for (const auto input : n->inputs()) {
297 if (aliasDb_.writesToAlias(n, {input})) {
298 resized_alias_set.insert(input);
299 }
300 }
301 }
302 }
303 }
304
IValue representativeValue(Value* v) {
306 TypePtr type_ = v->type();
307 // if the value is actually constant, just use it!
308 if (auto iv = toIValue(v)) {
309 return *iv;
310 }
311 if (TensorTypePtr type = type_->cast<TensorType>()) {
312 if (type->isComplete()) {
313 at::DeviceGuard device_guard(*type->device());
314 return at::empty_strided(
315 *type->sizes().concrete_sizes(),
316 *type->strides().concrete_sizes(),
317 at::TensorOptions(*type->device()).dtype(type->scalarType()))
318 .zero_();
319 }
320 // fallthrough
321 } else if (type_->isSubtypeOf(*FloatType::get())) {
322 return 0.f;
323 }
324 // we should not get here because isValidArgumentForRunning should have
325 // prevented it
326 std::stringstream ss;
327 ss << "unable to create representative value for: " << type_->str()
328 << ". File a bug report";
329 throw std::runtime_error(ss.str());
330 }
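// e.g. for a complete Float tensor type with sizes [2, 3] and strides [3, 1]
// on CPU, this materializes a zero-filled 2x3 float tensor so the operator
// can actually be executed below, purely to observe its output shapes.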
331
void broadcastBinary(
333 Node* node,
334 std::vector<TensorTypePtr>& types,
335 size_t idx1,
336 size_t idx2) {
337 auto expected_size = at::infer_size(
338 *types[idx1]->sizes().concrete_sizes(),
339 *types[idx2]->sizes().concrete_sizes());
340 auto broadcast = [&](size_t input_idx) {
341 TensorTypePtr input_type = types.at(input_idx);
342 if (input_type->sizes() == expected_size)
343 return;
344 auto graph = node->owningGraph();
345 WithInsertPoint point_guard{node};
346 Node* expand = graph
347 ->create(
348 aten::expand,
349 {node->inputs().at(input_idx),
350 graph->insertConstant(expected_size),
351 graph->insertConstant(false)})
352 ->insertBefore(node);
353 propagateNode(expand);
354 node->replaceInput(input_idx, expand->output());
355 };
356 broadcast(idx1);
357 broadcast(idx2);
358 types[0] = node->inputs().at(idx1)->type()->expect<TensorType>();
359 types[1] = node->inputs().at(idx2)->type()->expect<TensorType>();
360 }
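// e.g. with input sizes [2, 1] and [1, 3], at::infer_size yields [2, 3] and
// an aten::expand to [2, 3] is inserted in front of each input whose size
// differs, so the downstream complete-shape logic sees matching shapes.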
361
362 OperatorSet cannot_propagate_shape_by_running_it = {
363 "aten::inverse(Tensor self) -> Tensor",
364 };
365
366 // Check if this node depends on a value that has been mutated previously. If
367 // it has, then it's not safe to run this node in isolation, since we don't
368 // know whether the dependency has been executed.
369 std::unordered_map<Node*, bool> dependsOnMutationMemo_;
bool dependsOnMutation(Node* node) {
371 if (dependsOnMutationMemo_.count(node) != 0) {
372 return dependsOnMutationMemo_[node];
373 }
374
375 if (aliasDb_.hasWriters(node)) {
376 // If something could have written to a value used by this node, we can't
377 // guarantee the result is the same when running it in isolation.
378 dependsOnMutationMemo_[node] = true;
379 return true;
380 }
381
382 // recursively check the producers of its inputs. We need to do this if the
383 // mutable value has been laundered through a pure function:
384 // a += 1
385 // c = a + b
386 // d = c + 1
387 // In this case, `d` cares whether `a` has been mutated even though it's not
388 // a direct input.
389 auto depends = false;
390 for (auto input : node->inputs()) {
391 depends |= dependsOnMutation(input->node());
392 }
393
394 dependsOnMutationMemo_[node] = depends;
395 return depends;
396 }
397
bool canPropagateShapeByRunningIt(Node* node) {
399 if (node->isMemberOf(cannot_propagate_shape_by_running_it)) {
400 return false;
401 }
402
403 if (dependsOnMutation(node)) {
404 return false;
405 }
406
407 bool valid_args = std::all_of(
408 node->inputs().begin(),
409 node->inputs().end(),
410 isValidArgumentForRunning);
411 if (!valid_args)
412 return false;
413
414 bool valid_returns = std::all_of(
415 node->outputs().begin(),
416 node->outputs().end(),
417 isValidReturnForRunning);
418 if (!valid_returns)
419 return false;
420
421 return true;
422 }
423
// If there are no Tensors in the outputs (e.g. float / float),
// we don't need to propagate shape.
bool DoesntRefineOutputs(Node* node) {
427 auto outputs = node->outputs();
428 for (auto& out : outputs) {
429 if (containsTensorType(out->type())) {
430 return false;
431 }
432 }
433 return true;
434 }
435
bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
437 if (!canPropagateShapeByRunningIt(node))
438 return false;
439
440 if (!op)
441 op = node->getOperation();
442
443 Stack stack;
444
445 for (auto input : node->inputs()) {
446 stack.push_back(representativeValue(input));
447 }
448
449 // XXX: we're not catching any exceptions from the op for now. This
450 // is to uncover any mistakes we could make when editing this code,
451 // and eventually it shouldn't matter, because this phase should be
452 // preceded by schema checking.
453 op(stack);
454
455 AT_ASSERT(stack.size() == node->outputs().size());
456 for (const auto i : c10::irange(stack.size())) {
457 // some ops may have mixed tensor/primitive outputs
458 // for primitives, we don't need to change the type because it is already
459 // its most constrained form.
460 auto tensor_type = node->outputs()[i]->type()->cast<TensorType>();
461 if (stack[i].isTensor() && tensor_type) {
462 // gradient information isn't always available or part of representative
463 // inputs, maintain original grad property
464 auto tensor_grad = tensor_type->requiresGrad();
465 node->outputs()[i]->setType(TensorType::create(stack[i].toTensor())
466 ->withRequiresGrad(tensor_grad));
467 }
468 }
469 return true;
470 }
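// In effect this infers shapes by running the op on representative inputs:
// e.g. running aten::mm on complete Float inputs of sizes [2, 3] and [3, 4]
// yields a [2, 4] tensor, and the output type is refined from that result
// (keeping the original requires_grad, which fake inputs cannot reproduce).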
471
void PropagateCatShape(Node* cat_node) {
473 static const auto propagate_complete =
474 [](Node* node, at::ArrayRef<Value*> tensors) -> bool {
475 auto input_types =
476 fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
477 if (!std::all_of(
478 input_types.begin(),
479 input_types.end(),
480 [](const TensorTypePtr& tp) {
481 return tp != nullptr && tp->isComplete();
482 })) {
483 return false;
484 }
485 if (!node->is_constant(attr::dim))
486 return false;
487 std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
488 const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
489 const int64_t ndim = (int64_t)sizes.size();
490
491 if (dim < 0 || dim >= ndim)
492 return false;
493
494 sizes[dim] = 0;
495 for (auto& tp : input_types) {
496 auto tp_sizes = tp->sizes().concrete_sizes().value();
497 if (sizes.size() != tp_sizes.size())
498 return false;
499 for (const auto i : c10::irange(ndim)) {
500 if (sizes[i] != tp_sizes[i] && i != dim) {
501 return false;
502 }
503 }
504 sizes[dim] += tp_sizes[dim];
505 }
506 node->output()->setType(input_types[0]->withSizes(sizes));
507 return true;
508 };
509 static const auto propagate = [](Node* node,
510 at::ArrayRef<Value*> tensors) -> bool {
511 for (Value* v : tensors) {
512 if (auto type = v->type()->cast<TensorType>()) {
513 node->output()->setType(type->dimensionedOnly());
514 return true;
515 }
516 }
517 return false;
518 };
519 auto list_node =
520 ((cat_node->kind() == prim::FusedConcat)
521 ? cat_node
522 : cat_node->namedInput(attr::tensors)->node());
523 if (list_node->kind() == prim::ListConstruct ||
524 cat_node->kind() == prim::FusedConcat) {
525 auto tensors = list_node->inputs();
526 if (!tensors.empty()) {
527 // NOLINTNEXTLINE(bugprone-branch-clone)
528 if (propagate_complete(cat_node, tensors)) {
529 return;
530 } else if (propagate(cat_node, tensors)) {
531 return;
532 }
533 }
534 }
535 setUnshapedType(cat_node);
536 }
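// Example of the complete case above: concatenating [2, 3] and [4, 3] along
// dim 0 gives an output typed [6, 3]; if shapes are unknown or disagree on a
// non-cat dimension, only the dimension count is propagated, and otherwise
// the output is left unshaped.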
537
void propagateTorchTensorShape(Node* node) {
539 auto input_type = node->inputs().at(0)->type();
540
541 size_t dims = 0;
542 auto input_base_type = input_type;
543 auto list_type = input_type->cast<ListType>();
544 while (list_type) {
545 dims++;
546 input_base_type = list_type->getElementType();
547 list_type = input_base_type->cast<ListType>();
548 }
549
550 std::optional<at::ScalarType> default_type =
551 tryScalarTypeFromJitType(*input_base_type);
if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) {
auto inp = toIValue(node->inputs().at(*dtype_index));
554 if (inp == std::nullopt) {
555 return;
556 } else if (!inp->isNone()) {
557 default_type = inp->toScalarType();
558 }
559 }
560
561 at::Device default_device = at::kCPU;
562 if (auto device_index = node->schema().argumentIndexWithName("device")) {
563 auto inp = toIValue(node->inputs().at(*device_index));
564 if (inp == std::nullopt) {
565 return;
566 } else if (!inp->isNone()) {
567 default_device = inp->toDevice();
568 }
569 }
570 node->output()->setType(TensorType::create(
571 default_type, default_device, dims, /*requires_grad=*/std::nullopt));
572 }
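// Rough example: for torch.tensor([[1, 2], [3, 4]]) the input is a list of
// lists, so dims == 2 and, absent explicit dtype/device arguments, the output
// is typed as a 2-dimensional Long tensor on CPU.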
573
574 // returns whether any such values were found
bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) {
576 bool in_resize = false;
577 for (auto v : vs) {
578 if (aliasDb_.mayAlias(ValueSet{v}, resized_alias_set)) {
579 setUnshapedType(v);
580 in_resize = true;
581 }
582 }
583 return in_resize;
584 }
585
void propagateNode(Node* node, bool insert_expands = true) override {
// Certain ops like resize_ change the input tensor's size. Because our
// analysis is flow invariant, we set any Tensor that can alias a resized
// Tensor to the base Tensor type without size information.
590 if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
591 return setUnshapedType(node);
592 }
593
594 // These don't require the types, and have complicated schema. Return early
595 // after we process them.
596 switch (node->kind()) {
597 case prim::If:
598 return processIf(node);
599 case prim::Loop: {
600 return processLoop(node);
601 }
602 case aten::Bool:
603 case aten::Int:
604 case aten::Float:
605 case aten::ScalarImplicit:
606 case aten::FloatImplicit:
607 case aten::IntImplicit:
608 return; // correct num type is already set
609 case prim::NumToTensor: {
610 TypePtr typ = node->input()->type();
611 if (typ->isSubtypeOf(*IntType::get()) ||
612 typ->isSubtypeOf(*BoolType::get())) {
613 node->output()->setType(TensorType::create(
614 at::kLong, at::kCPU, 0, /*requires_grad=*/std::nullopt));
615 } else if (node->input()->type()->isSubtypeOf(*FloatType::get())) {
616 node->output()->setType(TensorType::create(
617 at::kDouble, at::kCPU, 0, /*requires_grad=*/std::nullopt));
618 }
619 return;
620 }
621 case aten::tensor:
622 case aten::as_tensor: {
// as_tensor has an overloaded schema and can take either a tensor or
// a list as its first input; if the input is a tensor, we delegate
// shape propagation to PropagateTensorShapeOnNode.
626 if (node->inputs().at(0)->type()->isSubtypeOf(*TensorType::get())) {
627 break;
628 }
629 return propagateTorchTensorShape(node);
630 }
631 case prim::TupleConstruct: {
632 // We refresh the tuple type, because the input types could have been
633 // refined.
634 auto orig_type = node->output()->type()->expect<TupleType>();
635 auto new_types =
636 fmap(node->inputs(), [](Value* v) { return v->type(); });
637 node->output()->setType(
638 orig_type->createWithContained(std::move(new_types)));
639 return;
640 }
641 case prim::TupleUnpack: {
642 auto tuple_type = node->input()->type()->cast<TupleType>();
643 AT_ASSERT(
644 tuple_type &&
645 tuple_type->elements().size() == node->outputs().size());
646 auto elems = tuple_type->elements();
647 for (size_t i = 0; i < node->outputs().size(); ++i) {
648 node->output(i)->setType(elems[i]);
649 }
650 return;
651 }
652 case prim::Constant: {
653 if (node->output()->type()->isSubtypeOf(*TensorType::get())) {
654 node->output()->inferTypeFrom(node->t(attr::value));
655 }
656 return;
657 }
658 case prim::unchecked_unwrap_optional: {
659 // If we have specialized the optional type to the element type,
660 // we want to pass it down. We write this as input.isSubtypeOf(output)
661 // to be sure that we don't screw up nested optionals.
662 if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
663 node->output()->setType(node->input()->type());
664 }
665 return;
666 }
667 case prim::ConstantChunk: {
668 Value* tensor = node->input();
669 if (auto type = tensor->type()->cast<TensorType>()) {
670 type = type->dimensionedOnly();
671 for (Value* output : node->outputs()) {
672 output->setType(type);
673 }
674 } else {
675 setUnshapedType(node);
676 }
677 return;
678 }
679 case prim::grad: {
680 auto tt = node->input()->type()->expect<TensorType>();
681 // grad may be undefined
682 // requires_grad may be required
683 auto grad_type = TensorType::get()->withPossiblyUndefined();
684 node->output()->setType(std::move(grad_type));
685 return;
686 }
687 case prim::CallFunction:
688 case prim::CallMethod:
689 case prim::AutogradZero: {
690 setUnshapedType(node);
691 return;
692 }
693 case prim::GetAttr: {
694 auto cls = node->input()->type()->expect<ClassType>();
695 // propagate any type specializations encoded in the type of the class
696 node->output()->setType(cls->getAttribute(node->s(attr::name)));
697 return;
698 }
699 case aten::_unwrap_optional: {
700 // If we have specialized the optional type to the element type,
701 // we want to pass it down. We write this as input.isSubtypeOf(output)
702 // to be sure that we don't screw up nested optionals.
703 if (node->input()->type()->isSubtypeOf(*node->output()->type())) {
704 node->output()->setType(node->input()->type());
705 }
706 return;
707 }
708 default:
709 break; // fall-through
710 }
711
712 if (node->hasSideEffects()) {
713 return;
714 }
715
716 if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
717 node->kind() == prim::FusedConcat) {
718 return PropagateCatShape(node);
719 }
720
721 if (auto maybe_complete_types =
722 gatherTensorTypes(node, /*complete=*/true)) {
723 if (PropagateCompleteShapeOnNode(
724 node, insert_expands, std::move(*maybe_complete_types))) {
725 return;
726 }
727 }
728
729 if (PropagateTensorShapeOnNode(node, insert_expands)) {
730 return;
731 }
732
733 if (DoesntRefineOutputs(node)) {
734 return;
735 }
736
737 if (PropagateShapeOnNodeByRunningIt(node)) {
738 return;
739 }
740 return setUnshapedType(node);
741 }
742
static std::optional<size_t> determineListSize(Value* list) {
744 AT_ASSERT(list->type()->cast<ListType>());
745 if (auto shape = constant_as<c10::List<int64_t>>(list)) {
746 return shape->size();
747 }
748 auto input_node = list->node();
749 if (input_node->kind() == prim::ListConstruct) {
750 return input_node->inputs().size();
751 }
752 return std::nullopt;
753 }
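// e.g. a constant list [1, 2, 3] gives 3, and a prim::ListConstruct node with
// two inputs gives 2; any other producer yields std::nullopt.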
754
755 // is it ok to try to run the op
756 // If an input is a constant, then we assume that the input is valid
757 // and we can try to run it.
758 // Otherwise:
759 // Integral typed _inputs_ are often an indicator that we're indexing into
760 // a tensor, so we should special-case these ops in the shape propagation.
761 // Additionally, passing in a zero representative tensor into an integer
762 // division op causes divide-by-zero errors
763 // _Outputs_ must be tensors or primitives
764 // We will call inferTypeFrom on the tensors, and ignore the primitives.
765 // However, we allow primitive returns because we want to support mixed
766 // primitive/tensor outputs.
767
bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
769 static const auto broadcast =
770 [](std::vector<TensorTypePtr>& tensor_types,
771 std::optional<at::ScalarType> t) -> TensorTypePtr {
772 if (tensor_types.size() == 1) {
773 return tensor_types[0]->dimensionedOnly()->withScalarType(t);
774 }
775 AT_ASSERT(!tensor_types.empty());
776 auto any_type = tensor_types[0];
777 auto max_dims = any_type->dim();
778 for (auto& type : tensor_types) {
779 if (!max_dims || !type->dim()) {
780 max_dims = std::nullopt;
781 } else {
782 max_dims = std::max(*max_dims, *type->dim());
783 }
784 }
785 return TensorType::create(
786 t,
787 any_type->device(),
788 max_dims,
789 /*requires_grad=*/std::nullopt);
790 };
791
792 using type_vec_t = std::vector<TensorTypePtr>;
793 // Formula is expected to return a vector of length equal to the number of
794 // tensor outputs of the node, or an empty vector which implies that it
795 // failed to propagate.
796 using formula_t = std::function<type_vec_t(Node*)>;
797 static std::mutex shape_formulas_mutex;
798 static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
799 struct register_formula_for {
800 register_formula_for(OperatorSet operators, formula_t formula) {
801 std::unique_lock<std::mutex> lock{shape_formulas_mutex};
802 shape_formulas.emplace_back(std::move(operators), std::move(formula));
803 }
804 };
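// Each registration below pairs an OperatorSet with a shape formula; when the
// node matches one of the sets (checked near the end of this function), the
// formula's returned types are applied to the node's outputs, and an empty
// return means the formula could not propagate.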
805
806 // Requirements:
807 // dims : preserved
808 // scalar type : preserved
809 // device : preserved
810 // tensor inputs : 1
811 // tensor outputs : 1
812 // Additionally:
813 // - First input should be the only tensor input
814 static const register_formula_for simple_unary_ops{
815 {
816 "aten::acos(Tensor self) -> Tensor",
817 "aten::neg(Tensor self) -> Tensor",
818 "aten::t(Tensor self) -> Tensor",
819 "aten::sigmoid(Tensor self) -> Tensor",
820 "aten::logit(Tensor self, float? eps=None) -> Tensor",
821 "aten::tanh(Tensor self) -> Tensor",
822 "aten::relu(Tensor self) -> Tensor",
823 "aten::asin(Tensor self) -> Tensor",
824 "aten::atan(Tensor self) -> Tensor",
825 "aten::ceil(Tensor self) -> Tensor",
826 "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
827 "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
828 "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
829 "aten::celu(Tensor self, Scalar alpha) -> Tensor",
830 "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
831 "aten::clamp_max(Tensor self, Scalar max) -> Tensor",
832 "aten::clamp_min(Tensor self, Scalar min) -> Tensor",
833 "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
834 "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
835 "aten::cos(Tensor self) -> Tensor",
836 "aten::cosh(Tensor self) -> Tensor",
837 "aten::digamma(Tensor self) -> Tensor",
838 "aten::dropout(Tensor input, float p, bool train) -> Tensor",
839 "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
840 "aten::erf(Tensor self) -> Tensor",
841 "aten::erfc(Tensor self) -> Tensor",
842 "aten::erfinv(Tensor self) -> Tensor",
843 "aten::exp(Tensor self) -> Tensor",
844 "aten::expm1(Tensor self) -> Tensor",
845 "aten::log(Tensor self) -> Tensor",
846 "aten::log10(Tensor self) -> Tensor",
847 "aten::log1p(Tensor self) -> Tensor",
848 "aten::log2(Tensor self) -> Tensor",
849 "aten::log_sigmoid(Tensor self) -> Tensor",
850 "aten::floor(Tensor self) -> Tensor",
851 "aten::frac(Tensor self) -> Tensor",
852 "aten::flip(Tensor self, int[] dims) -> Tensor",
853 "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
854 "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
855 "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
856 "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
857 "aten::glu(Tensor self, int dim) -> Tensor",
858 "aten::inverse(Tensor self) -> Tensor",
859 "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
860 "aten::lgamma(Tensor self) -> Tensor",
861 "aten::mvlgamma(Tensor self, int p) -> Tensor",
862 "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
863 "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
864 "aten::permute(Tensor self, int[] dims) -> Tensor",
865 "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
866 "aten::pinverse(Tensor self, float rcond) -> Tensor",
867 "aten::reciprocal(Tensor self) -> Tensor",
868 "aten::relu(Tensor self) -> Tensor",
869 "aten::round(Tensor self) -> Tensor",
870 "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
871 "aten::rsqrt(Tensor self) -> Tensor",
872 "aten::selu(Tensor self) -> Tensor",
873 "aten::gelu(Tensor self, *, str approximate='none') -> Tensor",
874 "aten::sigmoid(Tensor self) -> Tensor",
875 "aten::sign(Tensor self) -> Tensor",
876 "aten::sin(Tensor self) -> Tensor",
877 "aten::sinh(Tensor self) -> Tensor",
878 "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
879 "aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
880 "aten::sqrt(Tensor self) -> Tensor",
881 "aten::tan(Tensor self) -> Tensor",
882 "aten::tanh(Tensor self) -> Tensor",
883 "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
884 "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
885 "aten::tril(Tensor self, int diagonal) -> Tensor",
886 "aten::triu(Tensor self, int diagonal) -> Tensor",
887 "aten::trunc(Tensor self) -> Tensor",
888 "aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
889 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
890 "aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor",
891 "aten::alias(Tensor self) -> Tensor",
892 },
893 [](Node* node) -> type_vec_t {
894 auto input_type = node->input(0)->type()->cast<TensorType>();
895 return input_type ? type_vec_t{input_type->dimensionedOnly()}
896 : type_vec_t{};
897 }};
898
899 // Requirements:
900 // dims : preserved
901 // scalar type : preserved, except complex maps to float
902 // device : preserved
903 // tensor inputs : 1
904 // tensor outputs : 1
905 // Additionally:
906 // - First input should be the only tensor input
907 static const register_formula_for simple_unary_ops_complex_to_float{
908 {
909 "aten::abs(Tensor self) -> Tensor",
910 },
911 [](Node* node) -> type_vec_t {
912 auto input_type = node->input(0)->type()->cast<TensorType>();
913
914 // Maps complex -> float
if (input_type && input_type->scalarType()) {
916 const auto scalar_type = *(input_type->scalarType());
917 if (isComplexType(scalar_type)) {
918 const auto out_type = c10::toRealValueType(scalar_type);
919 return type_vec_t{
920 input_type->dimensionedOnly()->withScalarType(out_type)};
921 }
922 }
923
924 return input_type ? type_vec_t{input_type->dimensionedOnly()}
925 : type_vec_t{};
926 }};
927
928 // Requirements:
929 // dims : broadcast all tensor args
930 // scalar type : promoted from input dtypes
931 // device : always matching and preserved
932 // tensor inputs : *
933 // tensor outputs : 1
934 static const register_formula_for broadcasting_ops_arithmetic{
935 {
936 // Tensor-Tensor operators
937 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
938 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
939 "aten::mul(Tensor self, Tensor other) -> Tensor",
940 "aten::div(Tensor self, Tensor other) -> Tensor",
941 },
942 [](Node* node) -> type_vec_t {
943 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
944 AT_ASSERT(maybe_tensor_types->size() >= 2);
945 auto dtype = getPromotedTypeForArithmeticOp(node);
946 return {broadcast(*maybe_tensor_types, dtype)};
947 }
948 return {};
949 }};
950
951 // Requirements:
952 // dims : broadcast all tensor args
953 // scalar type : always matching and preserved
954 // device : always matching and preserved
955 // tensor inputs : *
956 // tensor outputs : 1
957 static const register_formula_for broadcasting_ops{
958 {
959 "aten::pow(Tensor self, Tensor exponent) -> Tensor",
960 "aten::fmod(Tensor self, Tensor other) -> Tensor",
961 "aten::remainder(Tensor self, Tensor other) -> Tensor",
962 "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
963 "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
964 "aten::max(Tensor self, Tensor other) -> Tensor",
965 "aten::min(Tensor self, Tensor other) -> Tensor",
966 "aten::__and__(Tensor self, Tensor other) -> Tensor",
967 "aten::__or__(Tensor self, Tensor other) -> Tensor",
968 "aten::__xor__(Tensor self, Tensor other) -> Tensor",
969 "aten::__lshift__(Tensor self, Tensor other) -> Tensor",
970 "aten::__rshift__(Tensor self, Tensor other) -> Tensor",
971 "aten::__iand__(Tensor self, Tensor other) -> Tensor",
972 "aten::__ior__(Tensor self, Tensor other) -> Tensor",
973 "aten::__ixor__(Tensor self, Tensor other) -> Tensor",
974 "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
975 "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
976
977 // Ops with Tensor-Tensor overloads only
978 "aten::atan2(Tensor self, Tensor other) -> Tensor",
979 },
980 [](Node* node) -> type_vec_t {
981 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
982 AT_ASSERT(maybe_tensor_types->size() >= 2);
983 auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
984 auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
985 if (!first_scalar_type || !second_scalar_type) {
986 return {};
987 }
988 size_t arg_for_type = 0;
989 if (c10::promoteTypes(*first_scalar_type, *second_scalar_type) !=
990 first_scalar_type) {
991 arg_for_type = 1;
992 }
993 auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
994 return {broadcast(*maybe_tensor_types, t)};
995 }
996 return {};
997 }};
998
999 static const register_formula_for fused_accum_binary_ops{
1000 {
1001 // Non-binary ops
1002 "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1003 "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
1004 },
1005 [](Node* node) -> type_vec_t {
1006 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1007 auto dtype = (*maybe_tensor_types)[0]->scalarType();
1008 if (!dtype) {
1009 return {};
1010 }
1011 return {broadcast(*maybe_tensor_types, dtype)};
1012 }
1013 return {};
1014 }};
1015
1016 static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{
1017 {
1018 // Tensor-Scalar operators
1019 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1020 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
1021 "aten::mul(Tensor self, Scalar other) -> Tensor",
1022 "aten::div(Tensor self, Scalar other) -> Tensor",
1023 },
1024 [](Node* node) -> type_vec_t {
1025 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1026 auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
1027 auto second_scalar_type =
1028 tryScalarTypeFromJitType(*node->inputs()[1]->type());
1029 if (!first_scalar_type || !second_scalar_type) {
1030 return {};
1031 }
1032 if (isIntegralType(*first_scalar_type, false) &&
1033 isFloatingType(*second_scalar_type)) {
1034 auto default_dtype =
1035 at::typeMetaToScalarType(caffe2::get_default_dtype());
1036 return {broadcast(*maybe_tensor_types, default_dtype)};
1037 }
1038 if (c10::ScalarType::Bool == *first_scalar_type &&
1039 c10::ScalarType::Bool != *second_scalar_type) {
1040 auto result_type =
1041 c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1042 return {broadcast(*maybe_tensor_types, result_type)};
1043 }
1044 return {broadcast(*maybe_tensor_types, first_scalar_type)};
1045 }
1046 return {};
1047 }};
1048
1049 // NB: we always take the scalar type of the Tensor
1050 static const register_formula_for broadcasting_tensor_scalar_ops{
1051 {
1052
1053 "aten::pow(Tensor self, Scalar exponent) -> Tensor",
1054 "aten::fmod(Tensor self, Scalar other) -> Tensor",
1055 "aten::remainder(Tensor self, Scalar other) -> Tensor",
1056 "aten::pow(Scalar self, Tensor exponent) -> Tensor",
1057 "aten::__and__(Tensor self, Scalar other) -> Tensor",
1058 "aten::__or__(Tensor self, Scalar other) -> Tensor",
1059 "aten::__xor__(Tensor self, Scalar other) -> Tensor",
1060 "aten::__lshift__(Tensor self, Scalar other) -> Tensor",
1061 "aten::__rshift__(Tensor self, Scalar other) -> Tensor",
1062 "aten::__iand__(Tensor self, Scalar other) -> Tensor",
1063 "aten::__ior__(Tensor self, Scalar other) -> Tensor",
1064 "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
1065 "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
1066 "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
1067 },
1068 [](Node* node) -> type_vec_t {
1069 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1070 return {broadcast(
1071 *maybe_tensor_types, (*maybe_tensor_types)[0]->scalarType())};
1072 }
1073 return {};
1074 }};
1075
// aten::where is special in that its return type is the second argument's
// (self) type rather than that of the condition.
1078 static const register_formula_for where_op{
1079 {
1080 "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
1081 },
1082 [](Node* node) -> type_vec_t {
1083 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1084 return {broadcast(
1085 *maybe_tensor_types, (*maybe_tensor_types)[1]->scalarType())};
1086 }
1087 return {};
1088 }};
1089
1090 static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
1091 for (Value* input : node->inputs()) {
1092 if (auto type = input->type()->cast<TensorType>()) {
1093 if (type->dim().has_value()) {
1094 return type;
1095 }
1096 }
1097 }
1098 return nullptr;
1099 };
1100
1101 // Requirements:
1102 // dims : always matching and preserved
1103 // scalar type : always matching and preserved
1104 // device : always matching and preserved
1105 // tensor inputs : 2
1106 // tensor outputs : 1
1107 static const register_formula_for binary_ops_strict_match{
1108 {
1109 "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
1110 "aten::mm(Tensor self, Tensor mat2) -> Tensor",
1111 "aten::bmm(Tensor self, Tensor mat2) -> Tensor",
1112 },
1113 [](Node* node) -> type_vec_t {
1114 if (auto type = any_tensor_type(node)) {
1115 return {std::move(type)};
1116 }
1117 return {};
1118 }};
1119
1120 // Requirements:
1121 // dims : all tensor args are broadcast
// scalar type : bool
1123 // device : always matching and preserved
1124 // tensor inputs : *
1125 // tensor outputs : 1
1126 static const register_formula_for comparison_ops{
1127 {
1128 "aten::lt(Tensor self, Tensor other) -> Tensor",
1129 "aten::le(Tensor self, Tensor other) -> Tensor",
1130 "aten::gt(Tensor self, Tensor other) -> Tensor",
1131 "aten::ge(Tensor self, Tensor other) -> Tensor",
1132 "aten::eq(Tensor self, Tensor other) -> Tensor",
1133 "aten::ne(Tensor self, Tensor other) -> Tensor",
1134 "aten::lt(Tensor self, Scalar other) -> Tensor",
1135 "aten::le(Tensor self, Scalar other) -> Tensor",
1136 "aten::gt(Tensor self, Scalar other) -> Tensor",
1137 "aten::ge(Tensor self, Scalar other) -> Tensor",
1138 "aten::eq(Tensor self, Scalar other) -> Tensor",
1139 "aten::ne(Tensor self, Scalar other) -> Tensor",
1140 },
1141 [](Node* node) -> type_vec_t {
1142 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1143 return {broadcast(*maybe_tensor_types, at::kBool)};
1144 }
1145 return {};
1146 }};
1147
1148 static const register_formula_for nn_ops_first_input_formula{
1149 *nn_ops_first_input_preserving(), [](Node* node) -> type_vec_t {
1150 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1151 return {type->dimensionedOnly()};
1152 }
1153 return {};
1154 }};
1155
1156 // Requirements:
1157 // dims : 0
1158 // scalar type : preserved
1159 // device : preserved
1160 // tensor inputs : 1
1161 // tensor outputs : 1
1162 // Additionally:
1163 // - First input should be the only tensor input
1164 static const register_formula_for all_reduce_ops{
1165 {
1166 "aten::det(Tensor self) -> Tensor",
1167 "aten::logdet(Tensor self) -> Tensor",
1168 "aten::max(Tensor self) -> Tensor",
1169 "aten::min(Tensor self) -> Tensor",
1170 "aten::median(Tensor self) -> Tensor",
1171 "aten::nanmedian(Tensor self) -> Tensor",
1172 "aten::norm(Tensor self, Scalar p) -> Tensor",
1173 "aten::std(Tensor self, bool unbiased) -> Tensor",
1174 "aten::trace(Tensor self) -> Tensor",
1175 "aten::var(Tensor self, bool unbiased) -> Tensor",
1176 "aten::all(Tensor self) -> Tensor",
1177 "aten::any(Tensor self) -> Tensor",
1178 },
1179 [](Node* node) -> type_vec_t {
1180 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1181 return {type->withDim(0)};
1182 }
1183 return {};
1184 }};
1185
1186 // Requirements:
1187 // dims : 0
1188 // scalar type : dtype if specified, else preserved
1189 // device : preserved
1190 // tensor inputs : 1
1191 // tensor outputs : 1
1192 // Additionally:
1193 // - First input should be the only tensor input
1194 static const register_formula_for reduce_ops_with_opt_dtype{
1195 {"aten::mean(Tensor self, *, int? dtype) -> Tensor"},
1196 [](Node* node) -> type_vec_t {
1197 std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1198 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1199 auto ret = type->withDim(0);
1200 if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1201 return {ret->withScalarType(maybe_dtype_option->toScalarType())};
1202 } else {
1203 return {std::move(ret)};
1204 }
1205 }
1206 return {};
1207 }};
1208
1209 // Requirements:
1210 // dims : 0
// scalar type : dtype if specified, else preserved if floating point,
//               otherwise long/int64
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
1214 // Additionally:
1215 // - First input should be the only tensor input
1216 static const register_formula_for
1217 all_reduce_ops_with_integer_upcast_and_dtype{
1218 {
1219 "aten::sum(Tensor self, *, int? dtype) -> Tensor",
1220 "aten::prod(Tensor self, *, int? dtype) -> Tensor",
1221 },
1222 [](Node* node) -> type_vec_t {
1223 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1224 type = type->withDim(0);
1225 std::optional<IValue> maybe_dtype_option =
1226 node->get(attr::dtype);
1227 if (maybe_dtype_option && !maybe_dtype_option->isNone()) {
1228 return {
1229 type->withScalarType(maybe_dtype_option->toScalarType())};
1230 }
1231 if (type->scalarType()) {
1232 return {
1233 at::isFloatingType(*type->scalarType())
1234 ? std::move(type)
1235 : type->withScalarType(at::kLong)};
1236 } else {
1237 return {std::move(type)};
1238 }
1239 }
1240 return {};
1241 }};
1242
1243 static const auto reduce_op_handler = [](Node* node,
1244 int64_t num_reduced_dim = 0,
1245 bool upcast_integer = false,
1246 std::optional<IValue> opt_dtype =
1247 std::nullopt) -> type_vec_t {
1248 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1249 if (!type->scalarType() || !type->dim()) {
1250 return {};
1251 }
1252 if (opt_dtype && !opt_dtype->isNone()) {
1253 type = type->withScalarType(opt_dtype->toScalarType());
1254 } else if (upcast_integer && !at::isFloatingType(*type->scalarType())) {
1255 type = type->withScalarType(at::kLong);
1256 }
1257 if (static_cast<int64_t>(*type->dim()) >= num_reduced_dim &&
1258 num_reduced_dim > 0) {
1259 return {type->withDim(*type->dim() - num_reduced_dim)};
1260 } else {
1261 return {std::move(type)};
1262 }
1263 }
1264 return {};
1265 };
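// e.g. (illustrative) reducing one dimension of a 3-dim Int tensor with
// upcast_integer == true yields a 2-dim Long tensor type; with
// num_reduced_dim == 0 the dimension count is preserved.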
1266
1267 static const auto multidim_reduce_with_keepdim =
1268 [](Node* node,
1269 int64_t num_reduced_dim,
1270 bool upcast_integer) -> type_vec_t {
1271 auto maybe_keepdim = node->get<bool>(attr::keepdim);
1272 if (!maybe_keepdim)
1273 return {};
1274 return reduce_op_handler(
1275 node, *maybe_keepdim ? 0 : num_reduced_dim, upcast_integer);
1276 };
1277
1278 // Requirements:
// dims : 0 if dim is None; otherwise preserved if keepdim == true,
//        or 1 smaller if keepdim == false
// scalar type : preserved
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
1282 // Additionally:
1283 // - First input should be the only tensor input
1284 // - Has a bool keepdim argument
1285 static const register_formula_for argminmax{
1286 {
1287 "aten::argmax(Tensor self, int? dim, bool keepdim) -> Tensor",
1288 "aten::argmin(Tensor self, int? dim, bool keepdim) -> Tensor",
1289 },
1290 [](Node* node) -> type_vec_t {
1291 if (auto type = node->input(0)->type()->cast<TensorType>()) {
1292 if (node->input(1)->type()->kind() == c10::TypeKind::NoneType) {
1293 return {type->withDim(0)};
1294 } else {
1295 return multidim_reduce_with_keepdim(
1296 node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1297 }
1298 }
1299 return {};
1300 }};
1301
1302 // Requirements:
// dims : preserved if keepdim == true, 1 smaller otherwise
// scalar type : preserved for the first output; long/int64 for the
//               second (indices) output if it exists
// device : preserved
// tensor inputs : 1
// tensor outputs : 1 or 2
1307 // Additionally:
1308 // - First input should be the only tensor input
1309 // - Has a bool keepdim argument
1310 static const register_formula_for dim_reduce_ops{
1311 {
1312 "aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
1313 "aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
1314
1315 // Ops returning indices as second output
1316 "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
1317 "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1318 "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1319 "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1320 "aten::nanmedian(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1321 "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
1322 },
1323 [](Node* node) -> type_vec_t {
1324 // NB: Note that while this function is generally meant to be used
1325 // with ops that have a single output, we will fix up its return right
1326 // below.
1327 auto output_types = multidim_reduce_with_keepdim(
1328 node, /*num_reduced_dim=*/1, /*upcast_integer=*/false);
1329 if (!output_types.empty() && node->outputs().size() == 2) {
1330 output_types.push_back(
1331 output_types.back()->withScalarType(at::kLong));
1332 }
1333 return output_types;
1334 }};
1335
1336 // Requirements:
// dims : preserved if keepdim == true, 1 smaller otherwise
// scalar type : dtype if specified, else preserved if floating point,
//               otherwise long/int64
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
1341 // Additionally:
1342 // - First input should be the only tensor input
1343 // - has a bool keepdim argument
1344 static const register_formula_for dim_reduce_ops_with_integer_upcast{
1345 {
1346 "aten::prod(Tensor self, int dim, bool keepdim, *, int? dtype) -> Tensor",
1347 },
1348 [](Node* node) -> type_vec_t {
1349 auto maybe_keepdim = node->get<bool>(attr::keepdim);
1350 std::optional<IValue> opt_dtype = node->get(attr::dtype);
1351 return reduce_op_handler(
1352 node,
/*num_reduced_dim=*/*maybe_keepdim ? 0 : 1,
/*upcast_integer=*/true,
1355 std::move(opt_dtype));
1356 }};
1357
1358 // Requirements:
1359 // dims : preserved
1360 // scalar type : dtype if specified, preserved if floating point,
1361 // otherwise long/int64
1362 // device : preserved
1363 // tensor inputs : 1
1364 // tensor outputs : 1
1365 // Additionally:
1366 // - First input should be the only tensor input
1367 static const register_formula_for dim_reduce_ops_dtype{
1368 {"aten::cumprod(Tensor self, int dim, *, int? dtype) -> Tensor",
1369 "aten::cumsum(Tensor self, int dim, *, int? dtype) -> Tensor",
1370 "aten::log_softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1371 [](Node* node) -> type_vec_t {
1372 std::optional<IValue> opt_dtype = node->get(attr::dtype);
1373 return reduce_op_handler(
1374 node,
/*num_reduced_dim=*/0,
/*upcast_integer=*/true,
1377 std::move(opt_dtype));
1378 }};
1379
1380 // Requirements:
1381 // dims : preserved
1382 // scalar type : dtype if specified, otherwise preserved
1383 // device : preserved
1384 // tensor inputs : 1
1385 // tensor outputs : 1
1386 // Additionally:
1387 // - has bool keepdim and int[] dim arguments
1388 static const register_formula_for register_softmax{
1389 {"aten::softmax(Tensor self, int dim, int? dtype) -> Tensor"},
1390 [](Node* node) -> type_vec_t {
1391 std::optional<IValue> opt_dtype = node->get(attr::dtype);
1392 return reduce_op_handler(
1393 node,
1394 /*num_reduced_dim=*/0,
1395 /*upcast_integer=*/false,
1396 std::move(opt_dtype));
1397 }};
1398
1399 static const auto factory_with_ndim =
1400 [](Node* node, int dim, at::ScalarType default_dtype) -> type_vec_t {
1401 std::optional<IValue> maybe_layout_option = node->get(attr::layout);
1402 if (!maybe_layout_option)
1403 return {};
1404
1405 std::optional<IValue> maybe_device_option = node->get(attr::device);
1406 if (!maybe_device_option)
1407 return {};
1408 auto device =
1409 (maybe_device_option->isNone() ? at::kCPU
1410 : maybe_device_option->toDevice());
1411
1412 std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1413 if (!maybe_dtype_option)
1414 return {};
1415 auto dtype =
1416 (maybe_dtype_option->isNone() ? default_dtype
1417 : maybe_dtype_option->toScalarType());
1418
1419 return {TensorType::create(
1420 dtype, device, dim, /*requires_grad=*/std::nullopt)};
1421 };
1422
1423 static const auto factory_like_with_ndim = [](Node* node,
1424 int dim) -> type_vec_t {
1425 auto tt = node->input(0)->type()->expect<TensorType>();
1426 auto in_type = tt->scalarType();
1427 auto in_dev = tt->device();
1428
1429 std::optional<IValue> maybe_layout_option = node->get(attr::layout);
1430 if (!maybe_layout_option)
1431 return {};
1432
1433 std::optional<IValue> maybe_device_option = node->get(attr::device);
1434 if (!maybe_device_option)
1435 return {};
1436
1437 if (!maybe_device_option->isNone()) {
1438 in_dev = maybe_device_option->toDevice();
1439 }
1440
1441 std::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
1442 if (!maybe_dtype_option)
1443 return {};
1444
1445 if (!maybe_dtype_option->isNone()) {
1446 in_type = maybe_dtype_option->toScalarType();
1447 }
1448
1449 return {TensorType::create(
1450 in_type, in_dev, dim, /*requires_grad=*/std::nullopt)};
1451 };
1452
1453 // Requirements:
1454 // dims : preserved
1455 // scalar type : equal to value of dtype
1456 // device : equal to value of device
1457 // tensor inputs : 1
1458 // tensor outputs : 1
1459 // Additionally:
1460 // - has ScalarType dtype, Layout layout and Device device arguments
1461 static const register_formula_for like_factories_with_options{
1462 {
1463 "aten::empty_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1464 "aten::full_like(Tensor self, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1465 "aten::ones_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1466 "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1467 "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1468 "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1469 "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1470 "aten::zeros_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
1471 },
1472 [](Node* node) -> type_vec_t {
1473 if (auto type =
1474 node->namedInput(attr::self)->type()->cast<TensorType>()) {
1475 if (type->dim()) {
1476 return factory_like_with_ndim(node, (int)*type->dim());
1477 }
1478 }
1479 return {};
1480 }};
1481
1482 // Requirements:
1483 // dims : equal to number of elements in size
1484 // scalar type : equal to value of dtype
1485 // device : equal to value of device
1486 // tensor inputs : 1
1487 // tensor outputs : 1
1488 // Additionally:
1489 // - has int[] size, ScalarType dtype, Layout layout and Device device
1490 // arguments
1491 static const register_formula_for size_factories_with_options{
1492 {
1493 "aten::empty(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory, MemoryFormat? memory_format=contiguous_format) -> Tensor",
1494 "aten::full(int[] size, Scalar fill_value, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1495 "aten::ones(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1496 "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1497 "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1498 "aten::zeros(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1499 },
1500 [](Node* node) -> type_vec_t {
1501 if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1502 return factory_with_ndim(
1503 node, (int)maybe_size->size(), at::kDouble);
1504 }
1505 return {};
1506 }};
1507
1508 static const register_formula_for randint{
1509 {
1510 "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1511 "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
1512 },
1513 [](Node* node) -> type_vec_t {
1514 if (auto maybe_size = node->get<c10::List<int64_t>>(attr::size)) {
1515 return factory_with_ndim(node, (int)maybe_size->size(), at::kLong);
1516 }
1517 return {};
1518 }};
1519
1520 static const auto get_cast_scalar_type = [](Node* node) -> at::ScalarType {
1521 switch (node->kind()) {
1522 case aten::_cast_Byte:
1523 return at::kByte;
1524 case aten::_cast_Char:
1525 return at::kChar;
1526 case aten::_cast_Double:
1527 return at::kDouble;
1528 case aten::_cast_Float:
1529 return at::kFloat;
1530 case aten::_cast_Half:
1531 return at::kHalf;
1532 case aten::_cast_Int:
1533 return at::kInt;
1534 case aten::_cast_Long:
1535 return at::kLong;
1536 case aten::_cast_Short:
1537 return at::kShort;
1538 default:
1539 AT_ASSERTM(
1540 false,
1541 "unknown node kind in get_cast_scalar_type: ",
1542 node->kind().toQualString());
1543 }
1544 };
1545 static const register_formula_for cast_ops{
1546 {
1547 "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
1548 "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
1549 "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
1550 "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
1551 "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
1552 "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
1553 "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
1554 "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
1555 },
1556 [](Node* node) -> type_vec_t {
1557 if (auto type =
1558 node->namedInput(attr::self)->type()->cast<TensorType>()) {
1559 return {type->withScalarType(get_cast_scalar_type(node))};
1560 }
1561 return {};
1562 }};
1563
1564 // First, try to match one of the registered formulas to their operator
1565 // sets.
1566 for (auto& entry : shape_formulas) {
1567 if (node->isMemberOf(entry.first)) {
1568 auto types = entry.second(node);
1569 if (types.empty()) {
1570 return false;
1571 } else {
1572 auto outputs = node->outputs();
1573 AT_ASSERT(types.size() == outputs.size());
1574 for (const auto i : c10::irange(types.size())) {
1575 AT_ASSERT(outputs[i]->type()->isSubtypeOf(*TensorType::get()));
1576 outputs[i]->setType(types[i]);
1577 }
1578 return true;
1579 }
1580 }
1581 }
1582
1583 // This section implements shape prop for an assorted set of nodes that only
1584 // need partial information about their input types.
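  // input_type() uses dimensionedOnly() to drop concrete sizes and strides,
  // keeping just the rank (plus dtype and device), which is all the formulas
  // below rely on.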
1585 const auto input_type = [node](size_t index) {
1586 auto result = node->input(index)->type()->cast<TensorType>();
1587 if (result) {
1588 result = result->dimensionedOnly();
1589 }
1590 return result;
1591 };
1592 if (node->matches(
1593 "aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
1594 if (auto type = input_type(0)) {
1595 node->output()->setType(type->withDim(1));
1596 return true;
1597 }
1598 } else if (node->matches("aten::detach(Tensor(a) self) -> Tensor(a)")) {
1599 if (auto type = input_type(0)) {
1600 node->output()->setType(type->withRequiresGrad(false));
1601 return true;
1602 }
1603 } else if (
1604 node->matches(
1605 "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)")) {
1606 if (auto type = input_type(0)) {
1607 if (type->scalarType() == at::kHalf) {
1608 type = type->withScalarType(at::kFloat);
1609 }
1610 type = type->withDim(1);
1611 node->outputs()[0]->setType(type);
1612 node->outputs()[1]->setType(std::move(type));
1613 return true;
1614 }
1615 } else if (node->matches(
1616 "aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
1617 if (auto type = any_tensor_type(node)) {
1618 node->output()->setType(type->withDim(0));
1619 return true;
1620 }
1621 } else if (
1622 node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
1623 node->matches(
1624 "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
1625 if (auto type = any_tensor_type(node)) {
1626 node->output()->setType(type->withDim(1));
1627 return true;
1628 }
1629 } else if (
1630 node->matches(
1631 "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1632 node->matches(
1633 "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
1634 node->matches(
1635 "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1636 if (auto type = any_tensor_type(node)) {
1637 node->output()->setType(type->withDim(2));
1638 return true;
1639 }
1640 } else if (
1641 node->matches(
1642 "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
1643 if (auto type = any_tensor_type(node)) {
1644 node->output()->setType(type->withDim(3));
1645 return true;
1646 }
1647 } else if (
1648 node->matches(
1649 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
1650 auto type = input_type(0);
1651 auto index_type = input_type(1);
1652 // index_select behaves very weirdly when self.dim() == 0. It allows both
1653 // 0D and 1D indices, and returns a value that has as many dimensions as
1654 // index.
1655 if (type && index_type && type->dim()) {
1656 if (*type->dim() == 0) {
1657 node->output()->setType(type->withDim(index_type->dim()));
1658 } else {
1659 node->output()->setType(std::move(type));
1660 }
1661 return true;
1662 }
1663 } else if (
1664 node->matches(
1665 "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor")) {
1666 auto type = input_type(0);
1667 auto index_type = input_type(1);
1668     // Gather has an annoying edge case where index always needs to match
1669     // the number of dims of self, **except** when self is 1D and index is
1670     // 0D, in which case we return a 0D output.
1671 if (type && index_type && index_type->dim()) {
1672 if (*index_type->dim() == 0) {
1673 node->output()->setType(type->withDim(0));
1674 } else {
1675 node->output()->setType(std::move(type));
1676 }
1677 return true;
1678 }
1679 } else if (
1680 node->matches(
1681 "aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
1682 auto weight_type = input_type(0);
1683 auto indices_type = input_type(1);
1684 if (weight_type && indices_type && indices_type->dim()) {
1685 node->output()->setType(weight_type->withDim(*indices_type->dim() + 1));
1686 return true;
1687 }
1688 } else if (
1689 node->matches(
1690 "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor")) {
1691 if (auto type = input_type(0)) {
1692 node->output()->setType(std::move(type));
1693 return true;
1694 }
1695 if (auto type = input_type(1)) {
1696 node->output()->setType(std::move(type));
1697 return true;
1698 }
1699 } else if (
1700 node->matches(
1701 "aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
1702 if (auto type = any_tensor_type(node)) {
1703 node->output()->setType(type->withDim(0));
1704 return true;
1705 }
1706 }
1707
1708 // The code below implements formulas that need type information for all
1709 // their tensor inputs, and have exactly one output.
1710 std::vector<TensorTypePtr> tensor_types;
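  // reshape_prop: when the length of the size/shape list argument is
  // statically known, the output has that many dimensions even if the
  // concrete sizes are not.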
1711 static const auto reshape_prop =
1712 [](Node* node,
1713 Symbol shape_input,
1714 const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
1715 if (auto list_size = determineListSize(node->namedInput(shape_input))) {
1716 return tensor_types.at(0)->withDim(list_size);
1717 }
1718 return nullptr;
1719 };
1720 const auto getSingleOutputType = [&]() -> TypePtr {
1721 if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
1722 return tensor_types.at(0)->withScalarType(
1723 tensor_types.at(1)->scalarType());
1724 } else if (
1725 node->matches(
1726 "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1727 node->matches(
1728 "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)") ||
1729 node->matches(
1730 "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)")) {
1731 return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
1732 } else if (
1733 node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
1734 node->matches(
1735 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
1736 node->matches(
1737 "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
1738 return reshape_prop(node, attr::size, tensor_types);
1739 } else if (
1740 node->matches(
1741 "aten::as_tensor(Tensor data, *, ScalarType? dtype, Device? device) -> Tensor")) {
1742 TypePtr input_type = node->inputs().at(0)->type();
1743 if (auto type = input_type->cast<TensorType>()) {
1744 if (type->scalarType() && type->device()) {
1745 at::ScalarType default_type = *type->scalarType();
1746 c10::Device default_device = *type->device();
1747 if (auto dtype_index =
1748 node->schema().argumentIndexWithName("dtype")) {
1749 auto inp = toIValue(node->inputs().at(*dtype_index));
1750 if (inp == std::nullopt) {
1751 return nullptr;
1752 }
1753 if (!inp->isNone()) {
1754 default_type = inp->toScalarType();
1755 }
1756 }
1757 if (auto device_index =
1758 node->schema().argumentIndexWithName("device")) {
1759 auto inp = toIValue(node->inputs().at(*device_index));
1760 if (inp == std::nullopt) {
1761 return nullptr;
1762 }
1763 if (!inp->isNone()) {
1764 default_device = inp->toDevice();
1765 }
1766 }
1767 node->output()->setType(TensorType::create(
1768 default_type,
1769 default_device,
1770 type->dim(),
1771 /*requires_grad=*/std::nullopt));
1772 }
1773 }
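      // Note: this branch assigns the output type directly above and returns
      // nullptr rather than handing a type back to the caller.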
1774 return nullptr;
1775 } else if (
1776 node->matches(
1777 "aten::reshape(Tensor(a) self, int[] shape) -> Tensor(a)")) {
1778 return reshape_prop(node, attr::shape, tensor_types);
1779 } else if (node->matches(
1780 "aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
1781 return reshape_prop(node, attr::repeats, tensor_types);
1782 } else if (node->matches(
1783 "aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
1784 auto& t = tensor_types.at(0);
1785 if (!t->dim()) {
1786 return t;
1787 }
1788 return t->withDim(*t->dim() + 1);
1789 } else if (
1790 node->matches(
1791 "aten::select(Tensor self, int dim, int index) -> Tensor") ||
1792 node->matches(
1793 "aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
1794 auto& t = tensor_types.at(0);
1795 return t->dim() && *t->dim() > 0 ? t->withDim(*t->dim() - 1) : nullptr;
1796 } else if (node->matches(
1797 "aten::matmul(Tensor self, Tensor other) -> Tensor")) {
1798 if (!tensor_types.at(0)->dim() || !tensor_types.at(1)->dim()) {
1799 return nullptr;
1800 }
1801 auto dim1 = *tensor_types.at(0)->dim();
1802 auto dim2 = *tensor_types.at(1)->dim();
1803 if (dim1 == 1 && dim2 == 1) {
1804 // Dot product
1805 return tensor_types.at(0)->withDim(0);
1806 // NOLINTNEXTLINE(bugprone-branch-clone)
1807 } else if (dim1 == 2 && dim2 == 2) {
1808 // Matrix multiply
1809 return tensor_types.at(0);
1810 } else if (dim1 == 1 && dim2 == 2) {
1811 // Unsqueeze + matrix multiply + squeeze
1812 return tensor_types.at(0);
1813 } else if (dim1 == 2 && dim2 == 1) {
1814 // Matrix vector multiply
1815 return tensor_types.at(1);
1816 } else {
1817 // Batched matrix multiply (possibly with squeeze + unsqueeze if one
1818 // argument is 1D)
1819 auto type = broadcast(tensor_types, tensor_types[0]->scalarType());
1820 if (dim1 == 1 || dim2 == 1) {
1821 type = type->withDim(type->dim().value() - 1);
1822 }
1823 return type;
1824 }
1825 } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
1826 return tensor_types.at(0)->dimensionedOnly()->withScalarType(at::kLong);
1827 } else if (node->matches(
1828 "aten::take(Tensor self, Tensor index) -> Tensor")) {
1829 return tensor_types.at(1)->dimensionedOnly()->withScalarType(
1830 tensor_types.at(0)->scalarType());
1831 } else if (node->matches(
1832 "aten::diagflat(Tensor self, int offset) -> Tensor")) {
1833 return tensor_types.at(0)->withDim(2);
1834 } else if (node->matches(
1835 "aten::diag(Tensor self, int diagonal) -> Tensor")) {
1836 auto& t = tensor_types.at(0);
1837 if (t->dim() && *t->dim() == 1) {
1838 return t->withDim(2);
1839 } else if (t->dim() && *t->dim() == 2) {
1840 return t->withDim(1);
1841 } else {
1842 return nullptr;
1843 }
1844 } else if (
1845 node->matches(
1846 "aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
1847 auto& t = tensor_types.at(0);
1848 if (!t->dim()) {
1849 return nullptr;
1850 }
1851 return t->withDim(*t->dim() + 1);
1852 } else if (node->matches(
1853 "aten::polygamma(int n, Tensor self) -> Tensor")) {
1854 return tensor_types.at(0);
1855 }
1856 return nullptr;
1857 };
1858 if (auto maybe_tensor_types = gatherTensorTypes(node)) {
1859 tensor_types = std::move(*maybe_tensor_types);
1860 } else {
1861 return false;
1862 }
1863 if (node->outputs().size() == 1) {
1864 if (auto type = getSingleOutputType()) {
1865 node->output()->setType(std::move(type));
1866 return true;
1867 }
1868 }
1869 return false;
1870 }
1871
1872   bool PropagateCompleteShapeOnNode(
1873 Node* node,
1874 bool insert_expands,
1875 std::vector<TensorTypePtr> tensor_types) {
1876     // For expensive ops we can directly encode their shape propagation
1877     // here; otherwise we fall back to running a fake version of the op
1878     // to get a quick and dirty propagation.
1879 if (node->matches(
1880 "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1881 node->matches(
1882 "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
1883 node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
1884 // These nodes handle tensors of different shapes internally, so there's
1885 // no need to insert explicit expand nodes.
1886 return PropagateShapeOnNodeByRunningIt(node);
1887 } else if (node->matches(
1888 "aten::div(Tensor self, Tensor other) -> Tensor")) {
1889     // "div" handles tensors of different shapes internally, so there's no
1890     // need to insert explicit expand nodes.
1891     // Note that this case could be merged with the one above, but "div" is
1892     // not always safe to run by itself due to integer divide-by-zero.
1893     // We fake the execution by running the "mul" operation instead.
1894 auto op = getOperatorForLiteral(
1895 "aten::mul(Tensor self, Tensor other) -> Tensor")
1896 ->getOperation();
1897 return PropagateShapeOnNodeByRunningIt(node, std::move(op));
1898 } else if (node->matches(
1899 "aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
1900 node->output()->setType(tensor_types.at(0));
1901 return true;
1902 } else if (
1903 node->matches(
1904 "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1905 node->matches(
1906 "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
1907 node->matches("aten::div(Tensor self, Scalar other) -> Tensor") ||
1908 node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) {
1909     auto first_scalar_type = tensor_types[0]->scalarType();
1910 auto second_scalar_type =
1911 tryScalarTypeFromJitType(*node->inputs()[1]->type());
1912 if (!first_scalar_type || !second_scalar_type) {
1913 return false;
1914 }
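    // An integral tensor combined with a floating-point scalar promotes to
    // the default (floating-point) dtype.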
1915 if (isIntegralType(*first_scalar_type, false) &&
1916 isFloatingType(*second_scalar_type)) {
1917 auto default_dtype =
1918 at::typeMetaToScalarType(caffe2::get_default_dtype());
1919 auto type = tensor_types[0]->withScalarType(default_dtype);
1920 node->output()->setType(std::move(type));
1921 return true;
1922 }
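    // A bool tensor combined with a non-bool scalar follows the standard
    // promotion rules via c10::promoteTypes.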
1923 if (c10::ScalarType::Bool == *first_scalar_type &&
1924 c10::ScalarType::Bool != *second_scalar_type) {
1925 auto result_type =
1926 c10::promoteTypes(*first_scalar_type, *second_scalar_type);
1927 auto type = tensor_types[0]->withScalarType(result_type);
1928 node->output()->setType(std::move(type));
1929 return true;
1930 }
1931 auto type = tensor_types[0]->withScalarType(first_scalar_type);
1932 node->output()->setType(std::move(type));
1933 return true;
1934 } else if (
1935 insert_expands &&
1936 (node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
1937 node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
1938 node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
1939 node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
1940 node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
1941 node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
1942 node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
1943 node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
1944 node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
1945 // Binary broadcasting ops
1946 // NB: we don't handle the nodes in any other way (note the lack of
1947 // return!), because the type casting logic in scalar cases is
1948 // non-trivial. It's better to just run them.
1949 broadcastBinary(node, tensor_types, 0, 1);
1950 return PropagateShapeOnNodeByRunningIt(node);
1951 } else if (
1952 node->matches(
1953 "aten::logit(Tensor self, float? eps = None) -> Tensor") ||
1954 node->matches("aten::neg(Tensor self) -> Tensor") ||
1955 node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
1956 node->matches("aten::tanh(Tensor self) -> Tensor")) {
1957 node->output()->setType(tensor_types.at(0)->contiguous());
1958 return true;
1959 } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
1960 auto lhs_type = tensor_types.at(0);
1961 auto rhs_type = tensor_types.at(1);
1962 auto lhs_sizes = lhs_type->sizes().concrete_sizes().value();
1963 auto rhs_sizes = rhs_type->sizes().concrete_sizes().value();
1964 SHAPE_ASSERT(
1965 *lhs_type->sizes().size() == 2 && *rhs_type->sizes().size() == 2);
1966 node->output()->setType(TensorType::createContiguous(
1967 *lhs_type->scalarType(),
1968 *lhs_type->device(),
1969 at::IntArrayRef{lhs_sizes[0], rhs_sizes[1]}));
1970 return true;
1971 } else if (node->matches("aten::t(Tensor self) -> Tensor")) {
1972 auto tp = tensor_types.at(0);
1973 auto sizes = tp->sizes().concrete_sizes().value();
1974 auto strides = tp->strides().concrete_sizes().value();
1975 SHAPE_ASSERT(sizes.size() == 2);
1976 std::swap(sizes.at(0), sizes.at(1));
1977 std::swap(strides.at(0), strides.at(1));
1978 node->output()->setType(tp->withSizesStrides(sizes, strides));
1979 return true;
1980 } else if (
1981 node->matches(
1982 "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
1983 /*const_inputs=*/{attr::dim, attr::length})) {
1984 auto tp = tensor_types.at(0);
1985 auto sizes = tp->sizes().concrete_sizes().value();
1986 int64_t dim = node->get<int64_t>(attr::dim).value();
1987 int64_t length = node->get<int64_t>(attr::length).value();
1988 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
1989 sizes.at(dim) = length;
1990 node->output()->setType(
1991 tp->withSizesStrides(sizes, tp->strides().concrete_sizes().value()));
1992 return true;
1993 } else if (node->matches(
1994 "aten::sum(Tensor self, *, int? dtype) -> Tensor")) {
1995 node->output()->setType(tensor_types.at(0)->withSizes({}));
1996 return true;
1997 } else if (
1998 node->matches(
1999 "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
2000 /*const_inputs=*/{attr::dim, attr::keepdim})) {
2001 auto& tp = tensor_types.at(0);
2002 auto sizes = tp->sizes().concrete_sizes().value();
2003 auto dims = node->get<c10::List<int64_t>>(attr::dim).value();
2004 bool keepdim = node->get<bool>(attr::keepdim).value();
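    // Process the dims in reverse order so that (for the usual ascending dim
    // list) erasing a dimension does not shift the indices of dims handled
    // later.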
2005 std::reverse(dims.begin(), dims.end());
2006 for (int64_t dim : dims) {
2007 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
2008 if (keepdim) {
2009 sizes.at(dim) = 1;
2010 } else {
2011 sizes.erase(sizes.begin() + dim);
2012 }
2013 }
2014 node->output()->setType(tp->withSizes(sizes));
2015 return true;
2016 } else if (node->matches(
2017 "aten::squeeze(Tensor self, int dim) -> Tensor",
2018 /*const_inputs=*/attr::dim)) {
2019 auto& tp = tensor_types.at(0);
2020 auto sizes = tp->sizes().concrete_sizes().value();
2021 auto strides = tp->strides().concrete_sizes().value();
2022 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2023 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
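    // squeeze(dim) only removes the dimension when its size is exactly 1;
    // otherwise the shape is left unchanged.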
2024 if (sizes.at(dim) == 1) {
2025 sizes.erase(sizes.begin() + dim);
2026 strides.erase(strides.begin() + dim);
2027 }
2028 node->output()->setType(tp->withSizesStrides(sizes, strides));
2029 return true;
2030 } else if (node->matches(
2031 "aten::unsqueeze(Tensor self, int dim) -> Tensor",
2032 /*const_inputs=*/attr::dim)) {
2033 auto& tp = tensor_types.at(0);
2034 auto sizes = tp->sizes().concrete_sizes().value();
2035 auto strides = tp->strides().concrete_sizes().value();
2036 int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
2037 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
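    // The inserted dimension has size 1; its stride is the displaced
    // dimension's size * stride, or 1 when inserting at the end.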
2038 int64_t new_stride = dim >= static_cast<int64_t>(sizes.size())
2039 ? 1
2040 : sizes.at(dim) * strides.at(dim);
2041 sizes.insert(sizes.begin() + dim, 1);
2042 strides.insert(strides.begin() + dim, new_stride);
2043 node->output()->setType(tp->withSizesStrides(sizes, strides));
2044 return true;
2045 } else if (node->matches(
2046 "aten::view(Tensor self, int[] size) -> Tensor",
2047 /*const_inputs=*/attr::size)) {
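    // At most one -1 is allowed in the target size; it means "infer this
    // dimension from the remaining number of elements".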
2048 auto sizes = node->get<c10::List<int64_t>>(attr::size).value();
2049 bool inferred = false;
2050 size_t inferred_idx = 0;
2051 int64_t size_product = 1;
2052 for (const auto i : c10::irange(sizes.size())) {
2053 if (sizes.get(i) == -1) {
2054 if (inferred)
2055 throw propagation_error();
2056 inferred = true;
2057 inferred_idx = i;
2058 } else {
2059 size_product *= sizes.get(i);
2060 }
2061 }
2062
2063 if (inferred) {
2064 SHAPE_ASSERT(size_product != 0);
2065 int64_t numel = 1;
2066 auto concrete_sizes =
2067 tensor_types.at(0)->sizes().concrete_sizes().value();
2068 for (int64_t s : concrete_sizes)
2069 numel *= s;
2070 int64_t inferred_size = numel / size_product;
2071 sizes[inferred_idx] = inferred_size;
2072 }
2073 node->output()->setType(tensor_types.at(0)->withSizes(sizes.vec()));
2074 return true;
2075 } else if (node->matches(
2076 "aten::type_as(Tensor self, Tensor other) -> Tensor")) {
2077 if (tensor_types.at(0)->scalarType() ==
2078 tensor_types.at(1)->scalarType()) {
2079 node->output()->setType(node->namedInput(attr::self)->type());
2080 } else {
2081 // This will be a copy, so the result will be contiguous
2082 node->output()->setType(tensor_types.at(1)->withSizes(
2083 tensor_types.at(0)->sizes().concrete_sizes().value()));
2084 }
2085 return true;
2086 } else if (
2087 node->matches(
2088 "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
2089 /*const_inputs=*/attr::size)) {
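    // inferExpandGeometry_dimvector computes the sizes and strides that
    // expanding to the requested size would produce, without materializing a
    // tensor.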
2090 auto tp = tensor_types.at(0);
2091 auto sizesAndStrides = at::inferExpandGeometry_dimvector(
2092 tp->sizes().concrete_sizes().value(),
2093 tp->strides().concrete_sizes().value(),
2094 node->get<c10::List<int64_t>>(attr::size).value().vec());
2095 node->output()->setType(
2096 tp->withSizesStrides(sizesAndStrides.sizes, sizesAndStrides.strides));
2097 return true;
2098 } else if (
2099 node->matches(
2100 "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
2101 /*const_inputs=*/attr::dim)) {
2102 auto ten = tensor_types.at(0);
2103 auto index = tensor_types.at(1);
2104 int64_t dim = node->get<int64_t>(attr::dim).value();
2105 SHAPE_ASSERT(*index->sizes().size() == 1);
2106 SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
2107 std::vector<int64_t> sizes = ten->sizes().concrete_sizes().value();
2108 sizes[dim] = index->sizes()[0].value();
2109 node->output()->setType(ten->withSizes(sizes));
2110 return true;
2111 } else if (node->matches(
2112 "aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
2113 /*const_inputs=*/{attr::chunks, attr::dim})) {
2114 auto input_type = tensor_types.at(0);
2115 auto sizes = input_type->sizes().concrete_sizes().value();
2116 auto strides = input_type->strides().concrete_sizes().value();
2117 int64_t dim = node->get<int64_t>(attr::dim).value();
2118 int64_t chunks = node->get<int64_t>(attr::chunks).value();
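    // Each output gets size sizes[dim] / chunks along dim; when the split is
    // uneven, the last output is overwritten below with the remainder.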
2119 sizes[dim] /= chunks;
2120 for (Value* output : node->outputs()) {
2121 output->setType(input_type->withSizesStrides(sizes, strides));
2122 }
2123 if (*input_type->sizes()[dim] % chunks != 0) {
2124 sizes[dim] = *input_type->sizes()[dim] % chunks;
2125 node->outputs().back()->setType(
2126 input_type->withSizesStrides(sizes, strides));
2127 }
2128 return true;
2129 } else if (node->kind() == ::c10::onnx::Shape) {
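    // onnx::Shape yields a 1-D int64 CPU tensor whose length equals the rank
    // of the input.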
2130 SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
2131 std::vector<int64_t> dim_vec = {
2132 (int64_t)*tensor_types.at(0)->sizes().size()};
2133 at::IntArrayRef dims(dim_vec);
2134 node->output()->setType(
2135 TensorType::createContiguous(at::kLong, at::kCPU, dims));
2136 return true;
2137 } else if (node->kind() == ::c10::onnx::Reshape) {
2138 setUnshapedType(node);
2139 return true;
2140 }
2141 setUnshapedType(node);
2142 return false;
2143 }
2144 };
2145 } // anonymous namespace
2146
2147 void PropagateInputShapes(const std::shared_ptr<Graph>& graph) {
2148 ShapePropagator(graph).propagateBlock(graph->block());
2149 }
2150
2151 namespace {
2152
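// Cache from original type to its shape-erased counterpart, so shared or
// repeated (possibly nested) container types are only rewritten once.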
2153 using TypeCache = std::unordered_map<TypePtr, TypePtr>;
2154
2155 TypePtr getOrCreateUnshapedType(
2156 const TypePtr& type,
2157 TypeCache& unshaped_type_cache);
2158
2159 TypePtr unshapedTypeImpl(TypePtr type, TypeCache& unshaped_type_cache) {
2160 if (type->isSubtypeOf(*TensorType::get())) {
2161 return TensorType::get();
2162 }
2163 at::ArrayRef<TypePtr> contained = type->containedTypes();
2164 if (contained.empty()) {
2165 return type;
2166 }
2167 std::vector<TypePtr> unshaped_contained_types;
2168 for (const auto& contained_type : contained) {
2169 unshaped_contained_types.push_back(
2170 getOrCreateUnshapedType(contained_type, unshaped_type_cache));
2171 }
2172 return type->withContained(std::move(unshaped_contained_types));
2173 }
2174
2175 TypePtr getOrCreateUnshapedType(
2176 const TypePtr& type,
2177 TypeCache& unshaped_type_cache) {
2178 auto maybe_cached_type = unshaped_type_cache.find(type);
2179 if (maybe_cached_type != unshaped_type_cache.end()) {
2180 return maybe_cached_type->second;
2181 }
2182 auto unshaped_type = unshapedTypeImpl(type, unshaped_type_cache);
2183 unshaped_type_cache[type] = unshaped_type;
2184 return unshaped_type;
2185 }
2186
2187 void EraseShapeInformation(
2188 const std::shared_ptr<Graph>& graph,
2189 TypeCache& unshaped_type_cache);
2190
2191 void EraseShapeInformation(
2192 at::ArrayRef<Value*> vals,
2193 TypeCache& unshaped_type_cache) {
2194 for (Value* v : vals) {
2195 v->setType(getOrCreateUnshapedType(v->type(), unshaped_type_cache));
2196 }
2197 }
2198
2199 void EraseShapeInformation(Block* b, TypeCache& unshaped_type_cache) {
2200 EraseShapeInformation(b->inputs(), unshaped_type_cache);
2201 EraseShapeInformation(b->outputs(), unshaped_type_cache);
2202 for (Node* n : b->nodes()) {
2203 EraseShapeInformation(n->outputs(), unshaped_type_cache);
2204 for (Block* sb : n->blocks()) {
2205 EraseShapeInformation(sb, unshaped_type_cache);
2206 }
2207 if (n->hasAttribute(attr::Subgraph)) {
2208 EraseShapeInformation(n->g(attr::Subgraph), unshaped_type_cache);
2209 }
2210 }
2211 }
2212
2213 void EraseShapeInformation(
2214 const std::shared_ptr<Graph>& graph,
2215 TypeCache& unshaped_type_cache) {
2216 EraseShapeInformation(graph->block(), unshaped_type_cache);
2217 }
2218
2219 } // anonymous namespace
2220
2221 void EraseShapeInformation(const std::shared_ptr<Graph>& graph) {
2222 TypeCache unshaped_type_cache;
2223 EraseShapeInformation(graph->block(), unshaped_type_cache);
2224 }
2225 } // namespace torch::jit
2226