//===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Make changes to isl's schedule tree data structure. // //===----------------------------------------------------------------------===// #ifndef POLLY_SCHEDULETREETRANSFORM_H #define POLLY_SCHEDULETREETRANSFORM_H #include "polly/Support/ISLTools.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/ErrorHandling.h" #include "isl/isl-noexceptions.h" #include namespace polly { struct BandAttr; /// This class defines a simple visitor class that may be used for /// various schedule tree analysis purposes. template struct ScheduleTreeVisitor { Derived &getDerived() { return *static_cast(this); } const Derived &getDerived() const { return *static_cast(this); } RetTy visit(isl::schedule_node Node, Args... args) { assert(!Node.is_null()); switch (isl_schedule_node_get_type(Node.get())) { case isl_schedule_node_domain: assert(isl_schedule_node_n_children(Node.get()) == 1); return getDerived().visitDomain(Node.as(), std::forward(args)...); case isl_schedule_node_band: assert(isl_schedule_node_n_children(Node.get()) == 1); return getDerived().visitBand(Node.as(), std::forward(args)...); case isl_schedule_node_sequence: assert(isl_schedule_node_n_children(Node.get()) >= 2); return getDerived().visitSequence(Node.as(), std::forward(args)...); case isl_schedule_node_set: return getDerived().visitSet(Node.as(), std::forward(args)...); assert(isl_schedule_node_n_children(Node.get()) >= 2); case isl_schedule_node_leaf: assert(isl_schedule_node_n_children(Node.get()) == 0); return getDerived().visitLeaf(Node.as(), std::forward(args)...); case isl_schedule_node_mark: assert(isl_schedule_node_n_children(Node.get()) == 1); return getDerived().visitMark(Node.as(), std::forward(args)...); case isl_schedule_node_extension: assert(isl_schedule_node_n_children(Node.get()) == 1); return getDerived().visitExtension( Node.as(), std::forward(args)...); case isl_schedule_node_filter: assert(isl_schedule_node_n_children(Node.get()) == 1); return getDerived().visitFilter(Node.as(), std::forward(args)...); default: llvm_unreachable("unimplemented schedule node type"); } } RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) { return getDerived().visitSingleChild(std::move(Domain), std::forward(args)...); } RetTy visitBand(isl::schedule_node_band Band, Args... args) { return getDerived().visitSingleChild(std::move(Band), std::forward(args)...); } RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) { return getDerived().visitMultiChild(std::move(Sequence), std::forward(args)...); } RetTy visitSet(isl::schedule_node_set Set, Args... args) { return getDerived().visitMultiChild(std::move(Set), std::forward(args)...); } RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { return getDerived().visitNode(std::move(Leaf), std::forward(args)...); } RetTy visitMark(isl::schedule_node_mark Mark, Args... args) { return getDerived().visitSingleChild(std::move(Mark), std::forward(args)...); } RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) { return getDerived().visitSingleChild(std::move(Extension), std::forward(args)...); } RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) { return getDerived().visitSingleChild(std::move(Filter), std::forward(args)...); } RetTy visitSingleChild(isl::schedule_node Node, Args... args) { return getDerived().visitNode(std::move(Node), std::forward(args)...); } RetTy visitMultiChild(isl::schedule_node Node, Args... args) { return getDerived().visitNode(std::move(Node), std::forward(args)...); } RetTy visitNode(isl::schedule_node Node, Args... args) { llvm_unreachable("Unimplemented other"); } }; /// Recursively visit all nodes of a schedule tree. template struct RecursiveScheduleTreeVisitor : ScheduleTreeVisitor { using BaseTy = ScheduleTreeVisitor; BaseTy &getBase() { return *this; } const BaseTy &getBase() const { return *this; } Derived &getDerived() { return *static_cast(this); } const Derived &getDerived() const { return *static_cast(this); } /// When visiting an entire schedule tree, start at its root node. RetTy visit(isl::schedule Schedule, Args... args) { return getDerived().visit(Schedule.get_root(), std::forward(args)...); } // Necessary to allow overload resolution with the added visit(isl::schedule) // overload. RetTy visit(isl::schedule_node Node, Args... args) { return getBase().visit(Node, std::forward(args)...); } /// By default, recursively visit the child nodes. RetTy visitNode(isl::schedule_node Node, Args... args) { for (unsigned i : rangeIslSize(0, Node.n_children())) getDerived().visit(Node.child(i), std::forward(args)...); return RetTy(); } }; /// Recursively visit all nodes of a schedule tree while allowing changes. /// /// The visit methods return an isl::schedule_node that is used to continue /// visiting the tree. Structural changes such as returning a different node /// will confuse the visitor. template struct ScheduleNodeRewriter : public RecursiveScheduleTreeVisitor { Derived &getDerived() { return *static_cast(this); } const Derived &getDerived() const { return *static_cast(this); } isl::schedule_node visitNode(isl::schedule_node Node, Args... args) { return getDerived().visitChildren(Node); } isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) { if (!Node.has_children()) return Node; isl::schedule_node It = Node.first_child(); while (true) { It = getDerived().visit(It, std::forward(args)...); if (!It.has_next_sibling()) break; It = It.next_sibling(); } return It.parent(); } }; /// Is this node the marker for its parent band? bool isBandMark(const isl::schedule_node &Node); /// Extract the BandAttr from a band's wrapping marker. Can also pass the band /// itself and this methods will try to find its wrapping mark. Returns nullptr /// if the band has not BandAttr. BandAttr *getBandAttr(isl::schedule_node MarkOrBand); /// Hoist all domains from extension into the root domain node, such that there /// are no more extension nodes (which isl does not support for some /// operations). This assumes that domains added by to extension nodes do not /// overlap. isl::schedule hoistExtensionNodes(isl::schedule Sched); /// Replace the AST band @p BandToUnroll by a sequence of all its iterations. /// /// The implementation enumerates all points in the partial schedule and creates /// an ISL sequence node for each point. The number of iterations must be a /// constant. isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll); /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent. isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor); /// Loop-distribute the band @p BandToFission as much as possible. isl::schedule applyMaxFission(isl::schedule_node BandToFission); /// Build the desired set of partial tile prefixes. /// /// We build a set of partial tile prefixes, which are prefixes of the vector /// loop that have exactly VectorWidth iterations. /// /// 1. Drop all constraints involving the dimension that represents the /// vector loop. /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth /// iterations. /// 3. Subtract loop domain from it, project out the vector loop dimension and /// get a set that contains prefixes, which do not have exactly VectorWidth /// iterations. /// 4. Project out the vector loop dimension of the set that was build on the /// first step and subtract the set built on the previous step to get the /// desired set of prefixes. /// /// @param ScheduleRange A range of a map, which describes a prefix schedule /// relation. isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth); /// Create an isl::union_set, which describes the isolate option based on /// IsolateDomain. /// /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should /// belong to the current band node. /// @param OutDimsNum A number of dimensions that should belong to /// the current band node. isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum); /// Create an isl::union_set, which describes the specified option for the /// dimension of the current node. /// /// @param Ctx An isl::ctx, which is used to create the isl::union_set. /// @param Option The name of the option. isl::union_set getDimOptions(isl::ctx Ctx, const char *Option); /// Tile a schedule node. /// /// @param Node The node to tile. /// @param Identifier An name that identifies this kind of tiling and /// that is used to mark the tiled loops in the /// generated AST. /// @param TileSizes A vector of tile sizes that should be used for /// tiling. /// @param DefaultTileSize A default tile size that is used for dimensions /// that are not covered by the TileSizes vector. isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, llvm::ArrayRef TileSizes, int DefaultTileSize); /// Tile a schedule node and unroll point loops. /// /// @param Node The node to register tile. /// @param TileSizes A vector of tile sizes that should be used for /// tiling. /// @param DefaultTileSize A default tile size that is used for dimensions isl::schedule_node applyRegisterTiling(isl::schedule_node Node, llvm::ArrayRef TileSizes, int DefaultTileSize); /// Apply greedy fusion. That is, fuse any loop that is possible to be fused /// top-down. /// /// @param Sched Sched tree to fuse all the loops in. /// @param Deps Validity constraints that must be preserved. isl::schedule applyGreedyFusion(isl::schedule Sched, const isl::union_map &Deps); } // namespace polly #endif // POLLY_SCHEDULETREETRANSFORM_H