/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

// Provides a way to construct mhlo dialect ops in MLIR using XlaBuilder
// interface.
//
// Requires that all XlaOp arguments are either returned by any of the builder
// method or constructed using MakeXlaOp method in this builder.
44 // 45 // TODO(hinsu): Support more ops and utility functions to set special attributes 46 // like OpMetadata and Sharding. 47 class MlirHloBuilder : public XlaBuilder { 48 public: 49 // Constructs builder for the given function. New operations are added to the 50 // beginning of the function, if it is non empty and has a block. MlirHloBuilder(mlir::func::FuncOp func)51 explicit MlirHloBuilder(mlir::func::FuncOp func) 52 : XlaBuilder(func.getName().str()), 53 builder_(&func.getBody()), 54 loc_(builder_.getUnknownLoc()), 55 build_functions_(false) {} 56 57 // TODO(hinsu): Add a constructor to build a new MLIR function from scratch 58 // and override Build methods. 59 MlirHloBuilder(std::string name,mlir::OpBuilder builder,mlir::Location loc,bool build_functions)60 MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc, 61 bool build_functions) 62 : XlaBuilder(name), 63 builder_(builder), 64 loc_(loc), 65 build_functions_(build_functions) {} 66 67 MlirHloBuilder(const MlirHloBuilder&) = delete; 68 MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; 69 70 ~MlirHloBuilder() override; 71 72 // Wraps the given MLIR value under an XlaOp instance. Note that all HLO 73 // operations returns exactly one result therefore each op has an XlaOp 74 // wrapping result of the op. 75 // 76 // Returns an error if the HLO dialect doesn't support type of the given 77 // value. 78 StatusOr<XlaOp> MakeXlaOp(mlir::Value val); 79 80 // Returns value corresponding to the given op. 81 // 82 // Requires that the op was created by this builder. GetValue(XlaOp op)83 mlir::Value GetValue(XlaOp op) { 84 void* ptr = reinterpret_cast<void*>(op.handle()); 85 return mlir::Value::getFromOpaquePointer(ptr); 86 } 87 88 // Returns MLIR values corresponding to the given XLA ops. 89 // 90 // Requires that the ops were created by this builder. 
GetValues(absl::Span<const XlaOp> ops)91 std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) { 92 std::vector<mlir::Value> values; 93 for (auto xla_op : ops) { 94 values.push_back(GetValue(xla_op)); 95 } 96 return values; 97 } 98 99 // Sets location for newly built ops, until reset. SetLocation(mlir::Location loc)100 void SetLocation(mlir::Location loc) { loc_ = loc; } 101 102 // Update insertion point so that newly built ops are inserted before the 103 // given op in order, until reset. setInsertionPoint(mlir::Operation * op)104 void setInsertionPoint(mlir::Operation* op) { 105 builder_.setInsertionPoint(op); 106 } 107 108 // Returns the shape of the given op. 109 StatusOr<const Shape*> GetShapePtr(XlaOp op) const override; 110 111 // Creates the given op at the current location. 112 template <typename OpTy, typename... Args> create(Args &&...args)113 OpTy create(Args&&... args) { 114 return builder_.create<OpTy>(loc_, std::forward<Args>(args)...); 115 } 116 117 private: 118 XlaOp ConstantLiteral(const LiteralSlice& literal) override; 119 120 StatusOr<XlaOp> ConvGeneralDilatedInternal( 121 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, 122 absl::Span<const int64_t> window_strides, 123 absl::Span<const std::pair<int64_t, int64_t>> padding, 124 absl::Span<const int64_t> lhs_dilation, 125 absl::Span<const int64_t> rhs_dilation, 126 const ConvolutionDimensionNumbers& dimension_numbers, 127 int64_t feature_group_count, int64_t batch_group_count, 128 const PrecisionConfig* precision_config) override; 129 130 StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand, 131 FftType fft_type, 132 absl::Span<const int64_t> fft_length) override; 133 134 StatusOr<XlaOp> TriangularSolveInternal( 135 const Shape& shape, XlaOp a, XlaOp b, 136 TriangularSolveOptions options) override; 137 138 StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a, 139 bool lower) override; 140 141 StatusOr<XlaOp> CustomCallInternal( 142 const std::string& 
call_target_name, absl::Span<const XlaOp> operands, 143 const XlaComputation* computation, const Shape& shape, 144 const std::string& opaque, 145 std::optional<absl::Span<const Shape>> operand_shapes_with_layout, 146 bool has_side_effect, 147 absl::Span<const std::pair<ShapeIndex, std::pair<int64_t, ShapeIndex>>> 148 output_operand_aliasing, 149 const Literal* literal, std::optional<Window> window, 150 std::optional<ConvolutionDimensionNumbers> dnums, 151 CustomCallSchedule schedule, CustomCallApiVersion api_version) override; 152 153 StatusOr<XlaOp> ReduceInternal( 154 const Shape& shape, absl::Span<const XlaOp> all_operands, 155 const XlaComputation& computation, 156 absl::Span<const int64_t> dimensions_to_reduce) override; 157 158 StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand, 159 XlaOp init_value, 160 const XlaComputation& computation, 161 Window window) override; 162 163 XlaOp Iota(const Shape& shape, int64_t iota_dimension) override; 164 165 StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape, 166 XlaOp operand) override; 167 168 StatusOr<XlaOp> TransposeInternal( 169 const Shape& shape, XlaOp operand, 170 absl::Span<const int64_t> permutation) override; 171 172 StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand, 173 absl::Span<const int64_t> dimensions) override; 174 175 StatusOr<XlaOp> SortInternal(const Shape& shape, 176 absl::Span<const XlaOp> operands, 177 const XlaComputation& comparator, 178 int64_t dimension, bool is_stable) override; 179 180 StatusOr<XlaOp> WhileInternal(const Shape& shape, 181 const XlaComputation& condition, 182 const XlaComputation& body, 183 XlaOp init) override; 184 185 StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape, XlaOp operand, 186 const int exponent_bits, 187 const int mantissa_bits) override; 188 189 StatusOr<XlaOp> GatherInternal( 190 const Shape& shape, XlaOp input, XlaOp start_indices, 191 const GatherDimensionNumbers& dimension_numbers, 192 absl::Span<const 
int64_t> slice_sizes, bool indices_are_sorted) override; 193 194 StatusOr<XlaOp> ScatterInternal( 195 const Shape& shape, absl::Span<const XlaOp> inputs, XlaOp scatter_indices, 196 absl::Span<const XlaOp> updates, const XlaComputation& update_computation, 197 const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, 198 bool unique_indices) override; 199 200 StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape, XlaOp operand, 201 XlaOp val, 202 int64_t dimension) override; 203 204 StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution, 205 absl::Span<const XlaOp> parameters, 206 const Shape& shape) override; 207 StatusOr<XlaOp> RngBitGeneratorInternal(const Shape& full_result_shape, 208 RandomAlgorithm algorithm, 209 XlaOp initial_state) override; 210 211 StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand, 212 int64_t inferred_dimension) override; 213 214 StatusOr<XlaOp> DotGeneralInternal( 215 const Shape& shape, XlaOp lhs, XlaOp rhs, 216 const DotDimensionNumbers& dimension_number, 217 const PrecisionConfig* precision_config) override; 218 219 StatusOr<XlaOp> InDimBroadcast( 220 const Shape& shape, XlaOp operand, 221 absl::Span<const int64_t> broadcast_dimensions) override; 222 223 StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, 224 absl::Span<const XlaOp> operands) override; 225 226 StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, 227 ComparisonDirection direction, 228 Comparison::Type type) override; 229 230 XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, XlaOp lhs, 231 XlaOp rhs) override; 232 233 StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape, 234 absl::Span<const XlaOp> operands) override; 235 236 XlaOp CreateToken() override; 237 238 StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape, 239 XlaOp token, 240 const std::string& config) override; 241 StatusOr<XlaOp> OutfeedWithTokenInternal( 242 XlaOp 
operand, XlaOp token, const Shape& shape_with_layout, 243 const std::string& outfeed_config) override; 244 245 StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape, 246 absl::Span<const XlaOp> operands, 247 int64_t dimension) override; 248 249 StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data, 250 int64_t index) override; 251 252 StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand, 253 absl::Span<const int64_t> start_indices, 254 absl::Span<const int64_t> limit_indices, 255 absl::Span<const int64_t> strides) override; 256 257 StatusOr<XlaOp> DynamicSliceInternal( 258 const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices, 259 absl::Span<const int64_t> slice_sizes) override; 260 261 StatusOr<XlaOp> DynamicUpdateSliceInternal( 262 const Shape& shape, XlaOp operand, XlaOp update, 263 absl::Span<const XlaOp> start_indices) override; 264 265 StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand, 266 XlaOp padding_value, 267 const PaddingConfig& padding_config) override; 268 269 StatusOr<XlaOp> TupleInternal(const Shape& shape, 270 absl::Span<const XlaOp> elements) override; 271 272 // Creates HLO dialect op and returns the result as an XlaOp. 273 StatusOr<XlaOp> CreateOp( 274 const std::string& op_name, const Shape& shape, 275 llvm::ArrayRef<XlaOp> operands, 276 llvm::ArrayRef<mlir::NamedAttribute> attributes = {}); 277 278 Status ImportComputation(const HloModuleProto& computation, 279 mlir::Region* region, 280 bool flatten_region_arg_tuple = false); 281 282 Status ImportComputation(const HloModuleProto& computation, 283 mlir::ModuleOp module); 284 285 mlir::OpBuilder builder_; 286 mlir::Location loc_; 287 bool build_functions_; 288 289 absl::flat_hash_map<int64_t, std::unique_ptr<Shape>> handle_to_shape_; 290 }; 291 292 } // namespace xla 293 294 #endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_MLIR_HLO_BUILDER_H_ 295