xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/spmd_expander_common.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_
17 #define TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_
18 
19 #include <string>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Support/Casting.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
31 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
32 #include "mlir/IR/Value.h"  // from @llvm-project
33 #include "mlir/IR/Visitors.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
37 #include "tensorflow/dtensor/cc/dstatus.h"
38 #include "tensorflow/dtensor/cc/tensor_layout.h"
39 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
40 
41 namespace tensorflow {
42 namespace dtensor {
43 
// String identifiers for the reduction combinator functions accepted by the
// DTensor reduce SPMD expansion (matched against op attributes by value).
constexpr absl::string_view kReduceOpAdd = "Add";
constexpr absl::string_view kReduceOpAll = "All";
constexpr absl::string_view kReduceOpAny = "Any";
constexpr absl::string_view kReduceOpMax = "Max";
constexpr absl::string_view kReduceOpMin = "Min";
constexpr absl::string_view kReduceOpMul = "Mul";
// Mean is not a valid combinator function on its own. It is handled specially
// by the reduce expansion.
constexpr absl::string_view kReduceOpMean = "Mean";
53 
// Takes a global type and converts it to a local (per-device) type according
// to the sharding described by `layout`. Fails if the number of shards does
// not divide the size of the dimension (if not dynamic).
StatusOr<mlir::TensorType> LocalTypeFromGlobalType(
    const Layout& layout, const mlir::TensorType& original_type);
58 
// Takes a local (per-device) type and converts it to the corresponding global
// type according to the sharding described by `layout`.
StatusOr<mlir::TensorType> GlobalTypeFromLocalType(
    const Layout& layout, const mlir::TensorType& original_type);
62 
// Creates a tf::SplitOp that splits 'src_input' 'num_split' ways
// in the 'split_dimension' dimension. The created op is returned through the
// 'split_op' output parameter.
Status CreateSplitOp(const int num_split, const int split_dimension,
                     const mlir::Location location, mlir::Value src_input,
                     mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op);
68 
// Given layouts + shapes, determines if the two are broadcast compatible and
// returns the broadcasted layout. 'to_split_a' and 'to_split_b' are output
// parameters filled in by the implementation.
// See source file for more documentation.
StatusOr<Layout> GetBroadcastLayoutForElementWise(
    const Layout& layout_a, const Layout& layout_b,
    mlir::ArrayRef<int64_t> shape_a, mlir::ArrayRef<int64_t> shape_b,
    int64_t dims_to_ignore, std::vector<std::string>& to_split_a,
    std::vector<std::string>& to_split_b);
76 
// Returns a merged layout using the `GetBroadcastLayoutForElementWise()`
// function given a list of operand layouts.
StatusOr<absl::optional<Layout>> GetMergedOperandLayout(
    const llvm::DenseMap<int, Layout>& operand_layouts, mlir::Operation* op);
81 
// Returns the forwarded input value of the DTensorLayout op for which `value`
// is the output. This must be used after layout propagation and before SPMD
// expansion, when all mlir::Value's of tf ops are followed by a DTensorLayout
// op that specifies the output layout.
// To make the implementation safe for the Layout Propagation V1 algorithm, if
// the defining op of `value` is not a DTensorLayout op (only the case for V1),
// returns `value` directly.
// TODO(b/172936130): Remove special casing for v1 Layout Propagation algorithm.
mlir::Value GetForwardedDTensorLayoutInput(mlir::Value value);
91 
// Goal of this function is to connect 'mlir::Value's (read 'mlir::OpResult's)
// to the 'mlir::OpOperand's which use them, crossing function call boundaries.
// The only keys in consumers which will not actually be 'mlir::OpResult's will
// be the 'mlir::Value's representing the inputs of the main function.
// The rest will be direct output of operations -- i.e. mlir::OpResult.
// Note that 'mlir::Value's that are not used by any op or are simply returned
// from the main function will not be in this list. In these cases, there are
// no conditions on the layouts for these 'mlir::Value's.
//
// A list of current assumptions in this code:
// * Functions are only called once.
// * Functions that are not reachable from main have been trimmed.
// * Input to CopyToMesh can always be traced back to function inputs.
mlir::LogicalResult PopulateConsumersFromModule(
    mlir::ModuleOp* module, mlir::Dialect* tf_dialect,
    llvm::DenseMap<mlir::Value, std::vector<mlir::OpOperand*>>& consumers);
108 
// From device id, return an mlir::Value for a tensor of shape [1, mesh.rank()]
// whose entries are the mesh coordinates of the device. The mesh used is the
// mesh for the given cluster.
StatusOr<mlir::Value> GetMeshCoordinatesFromCluster(
    mlir::tf_device::ClusterOp cluster);
114 
// Checks that optional metadata attributes of `op` are valid if they
// exist. More specifically, output layouts of tf.Shape op and layouts of
// resources inferred from AssignVariable op are added as metadata.
mlir::LogicalResult ValidateMetadataAttributes(mlir::Operation* op);
119 
// Creates a map from function name to the op which calls the function.
// Note that the map value is a single Operation*, so each function is assumed
// to have at most one call site (see the single-call assumption above).
mlir::LogicalResult GetFuncToCaller(
    mlir::ModuleOp module,
    llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller);
124 
// Takes an operand and traces its use across function call and
// tf_device.cluster boundaries. Note that this may turn one operand into many.
// If 'skipped_values' is non-null, it is populated by the implementation with
// the intermediate values traversed along the way.
llvm::SmallVector<mlir::OpOperand*, 4> TraceUseToNextTFOp(
    mlir::OpOperand* operand,
    const llvm::DenseMap<llvm::StringRef, mlir::Operation*>& func_to_caller,
    llvm::SmallVector<mlir::Value, 4>* skipped_values = nullptr);
131 
// Replaces `cluster` with a new tf_device.cluster without return values
// if the result values are not used by any other ops.
//
// For example:
//
//  %unused_value  = "tf_device.cluster"() ({
//      %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//      tf_device.return %2 : tensor<i32>
//  }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
// Will be transformed to:
//
//  "tf_device.cluster"() ({
//      %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//      %2 = "tf.Neg"(%1) : (tensor<i32>) -> tensor<i32>
//      tf_device.return
//  }) {_mesh="mesh:CPU,x=2,y=2"} : () -> ()
void RemoveUnusedClusterResults(mlir::tf_device::ClusterOp cluster);
151 
// Returns a StringAttr holding a function name beginning with `prefix` for
// use by generated control-flow functions. Presumably the name is unique
// within the builder's context -- see the source file to confirm.
mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix,
                                            mlir::OpBuilder& builder);
154 
// Sets the builder insertion point to after `value`. If `value` is a block
// argument, this checks that all users of the value are in the same cluster.
// If not it errors out. If they are then it sets the insertion point to the
// top of the cluster.
Status SetBuilderInsertionAfterValue(mlir::Value value,
                                     mlir::OpBuilder& builder);
161 
// Inserts a StringFormat and a Print op for `value` using `format_string`.
// Should only be used for debugging on CPU.
Status PrintTensor(mlir::Value value, const std::string& format_string);
165 
// Extracts a vector of strings from a constant mlir value into `out_vector`.
Status ExtractConstStringVectorFromValue(
    mlir::Value value, llvm::SmallVectorImpl<std::string>& out_vector);
169 
// Extracts a scalar string from a constant mlir value.
StatusOr<std::string> ExtractConstScalarStringFromValue(mlir::Value value);
171 
// A general Iterator that visits a FuncOp's body in topological order. Note
// that this does not visit the given FuncOp itself. Function ops are visited
// exactly once even if functions are used in multiple call sites.
//
// An example usage of this Iterator is for SPMD Expansion or Sparse Expansion,
// where we expand ops in topological order starting from the `main` FuncOp,
// only visiting function ops once so that we don't expand multiple times.
class TopologicalIterator {
 public:
  explicit TopologicalIterator(mlir::func::FuncOp main_func);

  // Returns whether there are any further ops to visit.
  bool hasNext();

  // Returns the next op to visit in the topological ordering. Returns
  // a nullptr if there is no next op to visit.
  mlir::Operation* next();

 private:
  // Stack to keep track of ops to visit.
  llvm::SmallVector<mlir::Operation*, 4> ops_to_visit_;

  // Keep track of functions we are currently walking; this is needed to avoid
  // descending into recursive function calls.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_in_call_stack_;

  // Keep track of all visited functions. This is to guarantee that
  // functions are visited exactly once if functions are used in multiple
  // callsites.
  llvm::SmallDenseSet<mlir::StringRef, 4> funcs_visited_;
};
203 }  // namespace dtensor
204 }  // namespace tensorflow
205 
206 #endif  // TENSORFLOW_DTENSOR_MLIR_SPMD_EXPANDER_COMMON_H_
207