xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ir/utility.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ir/utility.h"
17 
18 #include "mlir/IR/Block.h"  // from @llvm-project
19 #include "mlir/IR/Operation.h"  // from @llvm-project
20 #include "mlir/IR/Types.h"  // from @llvm-project
21 #include "tensorflow/core/ir/dialect.h"
22 #include "tensorflow/core/ir/interfaces.h"
23 #include "tensorflow/core/ir/types/dialect.h"
24 
25 namespace mlir {
26 namespace tfg {
27 
28 // For region-based loop ops, the first N block arguments are data values, with
29 // N control tokens afterwards.
GetLoopRegionDataArgs(Region & region)30 Block::BlockArgListType GetLoopRegionDataArgs(Region &region) {
31   Block::BlockArgListType args = region.getArguments();
32   return args.drop_back(args.size() / 2);
33 }
GetLoopRegionControlTokens(Region & region)34 Block::BlockArgListType GetLoopRegionControlTokens(Region &region) {
35   Block::BlockArgListType args = region.getArguments();
36   return args.drop_front(args.size() / 2);
37 }
GetLoopRegionControlOf(BlockArgument data)38 BlockArgument GetLoopRegionControlOf(BlockArgument data) {
39   Block &block = *data.getOwner();
40   return block.getArgument(data.getArgNumber() + block.getNumArguments() / 2);
41 }
GetLoopRegionDataOf(BlockArgument ctl)42 BlockArgument GetLoopRegionDataOf(BlockArgument ctl) {
43   Block &block = *ctl.getOwner();
44   return block.getArgument(ctl.getArgNumber() - block.getNumArguments() / 2);
45 }
46 
LookupControlDependency(Value data)47 Value LookupControlDependency(Value data) {
48   assert(!data.getType().isa<ControlType>() && "expected a data type");
49   // If the value is defined by an op, then the last result is the control
50   // dependency.
51   Value control_dep;
52   if (auto result = data.dyn_cast<OpResult>()) {
53     control_dep = *std::prev(result.getOwner()->result_end());
54   } else {
55     auto arg = data.cast<BlockArgument>();
56     control_dep = cast<ControlArgumentInterface>(arg.getOwner()->getParentOp())
57                       .getControlTokenOf(arg);
58   }
59   assert(control_dep.getType().isa<ControlType>() && "expected a control type");
60   return control_dep;
61 }
62 
LookupDataValue(Value ctl)63 Optional<Value> LookupDataValue(Value ctl) {
64   assert(ctl.getType().isa<ControlType>() && "expected a control type");
65   // If the value is defined by an op, then return the first result.
66   Value data;
67   if (auto result = ctl.dyn_cast<OpResult>()) {
68     // If the op only has a control result, then there is no data value.
69     if (result.getOwner()->getNumResults() == 1) return {};
70     data = *result.getOwner()->result_begin();
71   } else {
72     auto arg = ctl.cast<BlockArgument>();
73     data = cast<ControlArgumentInterface>(arg.getOwner()->getParentOp())
74                .getDataValueOf(arg);
75   }
76   assert(!data.getType().isa<ControlType>() && "expected a data type");
77   return data;
78 }
79 
80 }  // namespace tfg
81 }  // namespace mlir
82