xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_query.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_query.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 
24 namespace xla {
25 namespace hlo_query {
26 
IsCollectiveCommunicationOp(HloOpcode op)27 bool IsCollectiveCommunicationOp(HloOpcode op) {
28   return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather ||
29          op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute ||
30          op == HloOpcode::kReduceScatter;
31 }
32 
IsConstantR0F32(HloInstruction * instruction,float * out)33 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
34   if (instruction->opcode() == HloOpcode::kConstant &&
35       ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
36     *out = instruction->literal().Get<float>({});
37     return true;
38   }
39 
40   return false;
41 }
42 
AllOperandsAreParametersOrConstants(const HloInstruction & instruction)43 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
44   for (const auto& operand : instruction.operands()) {
45     if (operand->opcode() != HloOpcode::kParameter &&
46         operand->opcode() != HloOpcode::kConstant) {
47       return false;
48     }
49   }
50   return true;
51 }
52 
AllOperandsAreParameters(const HloInstruction & instruction)53 bool AllOperandsAreParameters(const HloInstruction& instruction) {
54   for (const auto& operand : instruction.operands()) {
55     if (operand->opcode() != HloOpcode::kParameter) {
56       return false;
57     }
58   }
59   return true;
60 }
61 
AllOperandsAreConstants(const HloInstruction & instruction)62 bool AllOperandsAreConstants(const HloInstruction& instruction) {
63   for (const auto& operand : instruction.operands()) {
64     if (operand->opcode() != HloOpcode::kConstant) {
65       return false;
66     }
67   }
68   return true;
69 }
70 
GetMatchingOperand(const HloPredicate & matcher,HloInstruction * instruction)71 HloInstruction* GetMatchingOperand(const HloPredicate& matcher,
72                                    HloInstruction* instruction) {
73   for (HloInstruction* op : instruction->operands()) {
74     if (matcher(op)) {
75       return op;
76     }
77   }
78   return nullptr;
79 }
80 
MatchBinaryInstructionOperand(const HloPredicate & matcher,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)81 bool MatchBinaryInstructionOperand(const HloPredicate& matcher,
82                                    HloInstruction* instruction,
83                                    HloInstruction** matching_operand,
84                                    HloInstruction** other_operand) {
85   CHECK_EQ(instruction->operand_count(), 2);
86   if (matcher(instruction->operand(0))) {
87     *matching_operand = instruction->mutable_operand(0);
88     *other_operand = instruction->mutable_operand(1);
89     return true;
90   }
91   if (matcher(instruction->operand(1))) {
92     *matching_operand = instruction->mutable_operand(1);
93     *other_operand = instruction->mutable_operand(0);
94     return true;
95   }
96   return false;
97 }
98 
MatchBinaryInstructionOperandOpcode(HloOpcode opcode,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)99 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
100                                          HloInstruction* instruction,
101                                          HloInstruction** matching_operand,
102                                          HloInstruction** other_operand) {
103   return MatchBinaryInstructionOperand(
104       [opcode](const HloInstruction* instruction) {
105         return instruction->opcode() == opcode;
106       },
107       instruction, matching_operand, other_operand);
108 }
109 
IsScalarConstant(const HloInstruction * instruction)110 bool IsScalarConstant(const HloInstruction* instruction) {
111   return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
112 }
113 
ContainsInstrWithOpcode(const HloComputation * comp,const absl::flat_hash_set<HloOpcode> & opcodes)114 bool ContainsInstrWithOpcode(const HloComputation* comp,
115                              const absl::flat_hash_set<HloOpcode>& opcodes) {
116   for (const auto* instr : comp->instructions()) {
117     if (opcodes.count(instr->opcode())) {
118       return true;
119     }
120     for (const HloComputation* subcomp : instr->called_computations()) {
121       if (ContainsInstrWithOpcode(subcomp, opcodes)) {
122         return true;
123       }
124     }
125   }
126   return false;
127 }
128 
ContainsLayoutConstrainedCollective(const HloModule & module,HloOpcode op)129 bool ContainsLayoutConstrainedCollective(const HloModule& module,
130                                          HloOpcode op) {
131   CHECK(IsCollectiveCommunicationOp(op));
132 
133   for (auto computation : module.computations()) {
134     for (auto hlo : computation->instructions()) {
135       if (hlo->opcode() == op &&
136           DynCast<HloCollectiveInstruction>(hlo)->constrain_layout()) {
137         return true;
138       }
139     }
140   }
141   return false;
142 }
143 
NextChannelId(const HloModule & module)144 int64_t NextChannelId(const HloModule& module) {
145   int64_t next_channel_id = 1;
146   for (const HloComputation* comp : module.computations()) {
147     for (const HloInstruction* hlo : comp->instructions()) {
148       const HloChannelInstruction* channel_instr =
149           DynCast<HloChannelInstruction>(hlo);
150       if (channel_instr && channel_instr->channel_id()) {
151         next_channel_id =
152             std::max(next_channel_id, *channel_instr->channel_id() + 1);
153       }
154     }
155   }
156   return next_channel_id;
157 }
158 
HasX64TransformedHostTransfer(const HloModule & module)159 bool HasX64TransformedHostTransfer(const HloModule& module) {
160   for (auto computation : module.computations()) {
161     for (auto hlo : computation->instructions()) {
162       if (hlo->opcode() == HloOpcode::kSend) {
163         auto send = DynCast<HloSendInstruction>(hlo);
164         if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
165           return true;
166         }
167       } else if (hlo->opcode() == HloOpcode::kRecv) {
168         auto recv = DynCast<HloRecvInstruction>(hlo);
169         if (recv->is_host_transfer() &&
170             recv->shape().tuple_shapes(0).IsTuple()) {
171           return true;
172         }
173       }
174     }
175   }
176   return false;
177 }
178 
179 }  // namespace hlo_query
180 }  // namespace xla
181