/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_
#define TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_

#include "llvm/Support/Casting.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

namespace tensorflow {
namespace dtensor {

// Given a DTensorSend or DTensorRecv op, returns the corresponding DTensorRecv
// or DTensorSend op with the same key.
template <typename DTensorOp>
StatusOr<mlir::Operation*> GetCorrespondingDTensorSendRecvOp(
    mlir::ModuleOp module, DTensorOp dtensor_op) {
  mlir::Operation* corresponding_op = nullptr;
  if (std::is_same<DTensorOp, mlir::TF::DTensorSend>::value) {
    module.walk([&](mlir::Operation* op) {
      if (auto xla_recv_tpu = llvm::dyn_cast<mlir::TF::XlaRecvFromHostOp>(op)) {
        if (dtensor_op.key() == xla_recv_tpu.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto xla_recv_cpu =
                     llvm::dyn_cast<mlir::TF::_XlaRecvAtHostV2Op>(op)) {
        if (dtensor_op.key() == xla_recv_cpu.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto dtensor_recv =
                     llvm::dyn_cast<mlir::TF::DTensorRecv>(op)) {
        if (dtensor_op.key() == dtensor_recv.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto host_recv = llvm::dyn_cast<mlir::TF::_HostRecvOp>(op)) {
        if (dtensor_op.key() == host_recv.tensor_name()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      }
      return mlir::WalkResult::advance();
    });
  } else {
    const bool is_recv = std::is_same<DTensorOp, mlir::TF::DTensorRecv>::value;
    if (!is_recv) {
      return errors::Internal(
          "DTensorOp must be either DTensorSend or DTensorRecv.");
    }
    module.walk([&](mlir::Operation* op) {
      if (auto xla_send_tpu = llvm::dyn_cast<mlir::TF::XlaSendToHostOp>(op)) {
        if (dtensor_op.key() == xla_send_tpu.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto xla_send_cpu =
                     llvm::dyn_cast<mlir::TF::_XlaSendFromHostV2Op>(op)) {
        if (dtensor_op.key() == xla_send_cpu.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto dtensor_send =
                     llvm::dyn_cast<mlir::TF::DTensorSend>(op)) {
        if (dtensor_op.key() == dtensor_send.key()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      } else if (auto host_send = llvm::dyn_cast<mlir::TF::_HostSendOp>(op)) {
        if (dtensor_op.key() == host_send.tensor_name()) {
          corresponding_op = op;
          return mlir::WalkResult::interrupt();
        }
      }
      return mlir::WalkResult::advance();
    });
  }

  if (!corresponding_op)
    return errors::InvalidArgument(
        "DTensorSend/DTensorRecv op must have a corresponding "
        "DTensorRecv/DTensorSend op.");

  return corresponding_op;
}
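
// Usage sketch (illustrative only; `module` and `send_op` are assumed to come
// from a surrounding DTensor lowering pass, not from this header):
//
//   mlir::TF::DTensorSend send_op = ...;
//   StatusOr<mlir::Operation*> recv_or =
//       GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(module,
//                                                                send_op);
//   if (recv_or.ok()) {
//     // The paired recv-side op that carries the same key attribute.
//     mlir::Operation* recv_op = recv_or.value();
//   }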

// Lowers DTensorRecv op to either XlaRecvAtHost or XlaRecvFromHost, depending
// on the src mesh cluster configuration.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv);

// Lowers DTensorRecv op to either XlaRecvAtHost or XlaRecvFromHost, depending
// on the src mesh cluster configuration. `output_type` can be set to the
// specific local tensor type needed, if different from the Recv op output
// type.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type);
// Lowers DTensorSend op to either XlaSendFromHost or XlaSendToHost, depending
// on the src mesh cluster. `send_from_device_zero` should be set when control
// flow needs to be inserted so that data is gathered onto, and sent only from,
// the zeroth device.
StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero);

// Lowers DTensorSend op to a TF HostSend op.
StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send);

// Lowers DTensorRecv op to a TF HostRecv op.
StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp(
    const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv);
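
// Illustrative choice of lowering path (hypothetical pass logic; the mesh
// check and variable names are assumptions, not part of this header's API):
//
//   // If the sender runs on a CPU mesh, use the TF HostSend/HostRecv
//   // lowerings; otherwise use the XLA host-transfer lowerings above.
//   if (send_mesh.is_cpu_mesh()) {
//     TF_ASSIGN_OR_RETURN(mlir::Operation* lowered_send,
//                         LowerDTensorSendFromCPUToTFOp(
//                             send_input_layout, send_input, dtensor_send));
//     TF_ASSIGN_OR_RETURN(mlir::Operation* lowered_recv,
//                         LowerDTensorRecvFromCPUToTFOp(send_mesh,
//                                                       dtensor_recv));
//   } else {
//     TF_ASSIGN_OR_RETURN(
//         mlir::Operation* lowered_send,
//         LowerDTensorSendToXlaOp(send_input_layout, send_input, dtensor_send,
//                                 /*send_from_device_zero=*/true));
//     TF_ASSIGN_OR_RETURN(mlir::Operation* lowered_recv,
//                         LowerDTensorRecvToXlaOp(dtensor_recv));
//   }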

}  // namespace dtensor
}  // namespace tensorflow

#endif  // TENSORFLOW_DTENSOR_MLIR_DTENSOR_SEND_RECV_H_