#pragma once
#include <c10/core/ScalarType.h>
#include <memory>

namespace torch {
namespace jit {
namespace tensorexpr {

template <typename Node>
using NodePtr = std::shared_ptr<Node>;

// Checked downcast between IR node pointers; returns nullptr if the dynamic
// type of the node does not match To.
template <typename To, typename From>
NodePtr<To> to(const NodePtr<From>& x) {
  return std::dynamic_pointer_cast<To>(x);
}

// Unchecked cast; the caller is responsible for ensuring the node really has
// the target type.
template <typename To, typename From>
NodePtr<To> static_to(NodePtr<From> x) {
  return std::static_pointer_cast<To>(x);
}

// Allocate a new IR node and wrap it in a NodePtr.
template <typename Node, typename... Args>
NodePtr<Node> alloc(Args&&... args) {
  return std::make_shared<Node>(std::forward<Args>(args)...);
}

class Buf;
class Expr;
class Stmt;
class Var;

using BufPtr = NodePtr<Buf>;
using ExprPtr = NodePtr<Expr>;
using StmtPtr = NodePtr<Stmt>;
using VarPtr = NodePtr<Var>;

class ExprHandle;
class VarHandle;
class BufHandle;

class Add;
class And;
class BitCast;
class Broadcast;
class Cast;
class CompareSelect;
class Div;
class IfThenElse;
class Intrinsics;
class Let;
class Load;
class Lshift;
class Max;
class MaxTerm;
class Min;
class MinTerm;
class Mod;
class Mul;
class Or;
class Polynomial;
class Ramp;
class ReduceOp;
class RoundOff;
class Rshift;
class Store;
class Sub;
class Term;
class Xor;
using AddPtr = NodePtr<Add>;
using AndPtr = NodePtr<And>;
using BitCastPtr = NodePtr<BitCast>;
using BroadcastPtr = NodePtr<Broadcast>;
using CastPtr = NodePtr<Cast>;
using CompareSelectPtr = NodePtr<CompareSelect>;
using DivPtr = NodePtr<Div>;
using IfThenElsePtr = NodePtr<IfThenElse>;
using IntrinsicsPtr = NodePtr<Intrinsics>;
using LetPtr = NodePtr<Let>;
using LoadPtr = NodePtr<Load>;
using LshiftPtr = NodePtr<Lshift>;
using MaxPtr = NodePtr<Max>;
using MaxTermPtr = NodePtr<MaxTerm>;
using MinPtr = NodePtr<Min>;
using MinTermPtr = NodePtr<MinTerm>;
using ModPtr = NodePtr<Mod>;
using MulPtr = NodePtr<Mul>;
using OrPtr = NodePtr<Or>;
using PolynomialPtr = NodePtr<Polynomial>;
using RampPtr = NodePtr<Ramp>;
using ReduceOpPtr = NodePtr<ReduceOp>;
using RoundOffPtr = NodePtr<RoundOff>;
using RshiftPtr = NodePtr<Rshift>;
using StorePtr = NodePtr<Store>;
using SubPtr = NodePtr<Sub>;
using TermPtr = NodePtr<Term>;
using XorPtr = NodePtr<Xor>;

class Allocate;
class AtomicAdd;
class Block;
class Cond;
class ExternalCall;
class ExternalCallWithAlloc;
class For;
class Free;
class FreeExt;
class PlacementAllocate;
class SyncThreads;
using AllocatePtr = NodePtr<Allocate>;
using AtomicAddPtr = NodePtr<AtomicAdd>;
using BlockPtr = NodePtr<Block>;
using CondPtr = NodePtr<Cond>;
using ExternalCallPtr = NodePtr<ExternalCall>;
using ExternalCallWithAllocPtr = NodePtr<ExternalCallWithAlloc>;
using ForPtr = NodePtr<For>;
using FreePtr = NodePtr<Free>;
using FreeExtPtr = NodePtr<FreeExt>;
using PlacementAllocatePtr = NodePtr<PlacementAllocate>;
using SyncThreadsPtr = NodePtr<SyncThreads>;

// Declare an immediate (constant) node class and its pointer alias for every
// scalar type, plus Bool, Half, and BFloat16 (e.g. IntImm / IntImmPtr).
#define IMM_DECLARE(Type, Name) \
  class Name##Imm;              \
  using Name##ImmPtr = NodePtr<Name##Imm>;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_DECLARE);
#undef IMM_DECLARE

} // namespace tensorexpr
} // namespace jit
} // namespace torch
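
// Illustrative usage sketch for the helpers above, assuming the full IR
// headers define Add with a two-ExprPtr constructor (this snippet is an
// example only, not part of the declarations in this file):
//
//   ExprPtr sum = alloc<Add>(lhs, rhs);   // construct a node behind a NodePtr
//   AddPtr add = to<Add>(sum);            // checked downcast; nullptr on mismatch
//   ExprPtr back = static_to<Expr>(add);  // unchecked cast when the type is known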