1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 // Runs constant propagation on all objects unless ignore_custom_classes is 8 // specified as true, in which case user defined classes are skipped. This is 9 // useful to prevent early fusion of packing operations, which end up lowering 10 // away information about their constructors (e.g. packed::linear_clamp_prepack 11 // and prepacked::conv2d_clamp_prepack) 12 // Returns True if the pass made a change to the graph 13 TORCH_API bool ConstantPropagation( 14 std::shared_ptr<Graph>& graph, 15 bool ignore_custom_classes = false); 16 17 // runs constant propagation only on ops that have non-aliasing inputs & outputs 18 // Returns True if the pass made a change to the graph 19 TORCH_API bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph); 20 21 // Runs the node if its inputs are constants. Callers of this function must 22 // make their own determination if constant prop is appropriate - for example 23 // non-deterministic ops or ops with side effects. If ignore_custom_classes is 24 // specified, nodes that output user defined classes are not run. 25 TORCH_API std::optional<Stack> runNodeIfInputsAreConstant( 26 const Node* node, 27 bool ignore_custom_classes = false, 28 AliasDb* db = nullptr); 29 30 } // namespace torch::jit 31