xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/op_registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <memory>
6 
7 namespace torch {
8 namespace jit {
9 // The declarations below were moved here from shape_analysis.cpp.
10 
11 // Returns a std::shared_ptr<OperatorSet>. Requirements stated for its ops:
12 //   dims           : preserved from the first argument
13 //   scalar type    : preserved from the first argument (doesn't have to
14 //                    match other arguments)
15 //   device         : always matching and preserved
16 //   tensor inputs  : * (NOTE(review): presumably "any number" - confirm)
17 //   tensor outputs : 1
18 // NB: these ops (with slight adjustments) are good candidates for restarts
19 //     (NOTE(review): "restarts" is not defined in this header - confirm).
20 //     Knowing the type and device of weights or biases is usually enough to
20 //     infer the output type.
21 std::shared_ptr<OperatorSet> nn_ops_first_input_preserving();
22 
23 // Returns a std::shared_ptr<OperatorSet>. Requirements stated for its ops:
24 //   dims           : changed from the first argument
25 //   scalar type    : preserved from the first argument
26 //   device         : always matching and preserved
27 //   tensor inputs  : 1
28 //   tensor outputs : 1
29 std::shared_ptr<OperatorSet> ops_one_tensor_in_shape_transform();
30 } // namespace jit
31 } // namespace torch
32