1 #pragma once 2 3 #include <torch/csrc/lazy/ts_backend/ts_node.h> 4 5 #include <torch/csrc/lazy/core/ir_builder.h> 6 7 namespace torch { 8 namespace lazy { 9 10 // Generic IR Node implementation for nodes which can simply be described by a 11 // specific OpKind and a lowering function. IR nodes carrying 12 // metadata should not be using this class TORCH_API (and have the metadata 13 // captured by the LowerFn), but they should instead create a dedicated IR node. 14 // Doing the former would limit IR introspection. 15 class TORCH_API Generic : public TsNode { 16 public: 17 Generic( 18 OpKind op, 19 OpList operands, 20 Shape shape, 21 size_t num_outputs = 1, 22 hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9)); 23 24 Generic( 25 OpKind op, 26 OpList operands, 27 const std::function<Shape()>& shape_fn, 28 size_t num_outputs = 1, 29 hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9)); 30 31 Generic( 32 OpKind op, 33 OpList operands, 34 size_t num_outputs = 1, 35 hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9)); 36 37 Generic(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed); 38 39 private: 40 hash_t hash_seed_; 41 }; 42 43 inline NodePtr GenericOp( 44 OpKind op, 45 OpList operands, 46 Shape shape, 47 size_t num_outputs = 1, 48 hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9)) { 49 return MakeNode<Generic>( 50 op, operands, std::move(shape), num_outputs, hash_seed); 51 } 52 53 } // namespace lazy 54 } // namespace torch 55