1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_ 18 19 #include "mlir/IR/Builders.h" // from @llvm-project 20 #include "tensorflow/dtensor/cc/dstatus.h" 21 #include "tensorflow/dtensor/mlir/spmd_expander.h" 22 23 namespace tensorflow { 24 namespace dtensor { 25 26 // Implement Layout propagation and SPMD expansion for Convolution ops. 27 // 28 // The extended class will be registered in spmd_expander.cc for Conv2D/3D and 29 // Conv2D/3D Backprop ops to enable proper DTensor behavior of them. This 30 // implementation is internal and specific to DTensor while upstream(python) 31 // users won't need to use this class directly in any fashion. 32 class ConvSPMDExpander : public SPMDExpanderBase { 33 public: 34 StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override; 35 36 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward( 37 mlir::Operation* op, 38 const llvm::DenseMap<int, Layout>& input_layouts) override; 39 40 StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward( 41 mlir::Operation* op, 42 const llvm::DenseMap<int, Layout>& output_layouts) override; 43 }; 44 45 } // namespace dtensor 46 } // namespace tensorflow 47 48 #endif // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_ 49