xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_
17 #define TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_
18 
19 #include "mlir/IR/Builders.h"  // from @llvm-project
20 #include "tensorflow/dtensor/cc/dstatus.h"
21 #include "tensorflow/dtensor/mlir/spmd_expander.h"
22 
23 namespace tensorflow {
24 namespace dtensor {
25 
26 // Implement Layout propagation and SPMD expansion for Convolution ops.
27 //
28 // The extended class will be registered in spmd_expander.cc for Conv2D/3D and
29 // Conv2D/3D Backprop ops to enable proper DTensor behavior of them. This
30 // implementation is internal and specific to DTensor while upstream(python)
31 // users won't need to use this class directly in any fashion.
32 class ConvSPMDExpander : public SPMDExpanderBase {
33  public:
34   StatusOr<mlir::Operation*> ExpandOp(mlir::Operation* op) override;
35 
36   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
37       mlir::Operation* op,
38       const llvm::DenseMap<int, Layout>& input_layouts) override;
39 
40   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutBackward(
41       mlir::Operation* op,
42       const llvm::DenseMap<int, Layout>& output_layouts) override;
43 };
44 
45 }  // namespace dtensor
46 }  // namespace tensorflow
47 
48 #endif  // TENSORFLOW_DTENSOR_MLIR_EXPANSIONS_CONV_SPMD_EXPANDER_H_
49