/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_
#define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

namespace tensorflow {
namespace dtensor {

// Names of the combinator functions accepted by the reduce expansion.
constexpr absl::string_view kReduceOpAdd = "Add";
constexpr absl::string_view kReduceOpAll = "All";
constexpr absl::string_view kReduceOpAny = "Any";
constexpr absl::string_view kReduceOpMax = "Max";
constexpr absl::string_view kReduceOpMin = "Min";
constexpr absl::string_view kReduceOpMul = "Mul";
// Mean is not a valid combinator function on its own. It is handled specially
// by the reduce expansion.
constexpr absl::string_view kReduceOpMean = "Mean";

// Takes a global type and converts it to a local type. Fails if the number of
// shards does not divide the size of the dimension (if not dynamic).
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type);

// Takes a local type and converts it to a global type.
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type);

// Creates a tf::SplitOp that splits 'src_input' into 'num_split' ways
// in 'split_dimension' dimension and returns the split values.
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op);

// Given layouts + shapes, determines if the two are broadcast compatible.
// See source file for more documentation.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b);

// Returns a merged layout using the `GetBroadcastLayoutForElementWise()`
// function given a list of operand layouts.
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op);

// Returns the forwarded input value of DTensorLayout op for which `value` is
// the output. This must be used after layout propagation and before SPMD
// expansion when all mlir::Value's of tf ops are followed by DTensorLayout op
// to specify output layout.
// To make the implementation safe for Layout Propagation V1 algorithm, if the
// defining op of `value` is not DTensorLayout op (only the case for V1),
// returns `value` directly.
// TODO(b/172936130): Remove special casing for v1 Layout Propagation algorithm.
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value);

// Goal of this function is to connect 'mlir::Value's (read 'mlir::OpResult's)
// to the 'mlir::OpOperand's which use them, crossing function call boundaries.
// The only keys in consumers which will not actually be 'mlir::OpResult's will
// be the 'mlir::Value's representing the inputs of the main function.
// The rest will be direct output of operations -- i.e. mlir::OpResult.
// Note that 'mlir::Value's that are not used by any op or are simply returned
// from the main function will not be in this list. In these cases, there are
// no conditions on the layouts for these 'mlir::Value's.
//
// A list of current assumptions in this code:
// * Functions are only called once.
// * Functions that are not reachable from main have been trimmed.
// * Input to CopyToMesh can always be traced back to function inputs.
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers);

// From device id, return an mlir::Value for a tensor of shape [1, mesh.rank()]
// whose entries are the mesh coordinates of the device. The mesh used, is the
// mesh for the given cluster.
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster);

// Checks that optional metadata attributes of `op` are valid if they
// exist. More specifically, output layouts of tf.Shape op and layouts of
// resources inferred from AssignVariable op are added as metadata.
mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op);

// Creates a map from function to ops which calls the function.
mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller);

// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into many.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values = nullptr);

// Replaces `cluster` with a new tf_device.cluster without return values
// if result values are not used by any other ops.
//
// For example:
//
//   %unused_value = "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %2 : tensor<i32>
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
// Will be transformed to:
//
//   "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//     tf_device.return
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> ()
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster);

// Returns a function name, prefixed with `prefix`, that is unique within
// the module `builder` is attached to. Used when outlining control-flow
// bodies into new functions.
mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder);

// Sets the builder insertion point to after value. If value is a block
// argument, this checks that all users of the value are in the same cluster.
// If not it errors out. If they are then it sets the insertion point to the
// top of the cluster.
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder);

// Inserts a StringFormat and Print op, should only be used for debugging
// on CPU.
Status PrintTensor(mlir::Value value, const std::string& format_string);

// Extract a vector of string from mlir value.
Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector);

// Extracts a single scalar string constant from `value`.
StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value);

// A general Iterator that visits a FuncOp's body in topological order. Note
// that this does not visit the given FuncOp itself. Function ops are visited
// exactly once if functions are used in multiple call sites.
//
// An example usage of this Iterator is for SPMD Expansion or Sparse Expansion,
// where we expand ops in topological order starting from the `main` FuncOp,
// only visiting function ops once so that we don't expand multiple times.
class TopologicalIterator {
 public:
  explicit TopologicalIterator(mlir::func::FuncOp main_func);

  // Returns whether there is any further ops to visit.
  bool hasNext();

  // Returns the next op to visit in the topological ordering. Returns
  // a nullptr if there is no next op to visit.
  mlir::Operation* next();

 private:
  // Stack to keep track of ops to visit.
  llvm::SmallVector<mlir::Operation*, 4> ops_to_visit_;

  // Keep track of functions we are walking, this is needed to avoid recursive
  // function calls.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_in_call_stack_;

  // Keep track of all visited functions. This is to guarantee that
  // functions are visited exactly once if functions are used in multiple
  // callsites.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_;
};
}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_