xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/logical_buffer_analysis.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
22 #include "tensorflow/compiler/xla/service/logical_buffer.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace xla {
28 
29 namespace {
30 
31 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)32 void GatherFusionInstructions(
33     HloInstruction* instruction,
34     std::vector<HloInstruction*>* fusion_instructions) {
35   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
36   for (auto* fused : instruction->fused_instructions()) {
37     if (fused->opcode() == HloOpcode::kFusion) {
38       GatherFusionInstructions(fused, fusion_instructions);
39     }
40   }
41   fusion_instructions->push_back(instruction);
42 }
43 
44 }  // namespace
45 
46 /* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
Run(const HloModule * module)47 LogicalBufferAnalysis::Run(const HloModule* module) {
48   std::unique_ptr<LogicalBufferAnalysis> analysis(
49       new LogicalBufferAnalysis(module));
50   TF_RETURN_IF_ERROR(analysis->Analyze());
51   return std::move(analysis);
52 }
53 
Analyze()54 Status LogicalBufferAnalysis::Analyze() {
55   // Empirically we usually have a few more logical buffers than instructions,
56   // so reserve 10% more than the number of instructions to avoid frequent
57   // resizes.
58   logical_buffers_.clear();
59   logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
60 
61   // We filter out fusion computations, and get to them through fusion
62   // instructions. This is because it's possible to have orphaned (unreachable)
63   // fusion computations, and we don't want to try to assign buffers to those.
64   std::vector<HloInstruction*> fusion_instructions;
65   for (auto* computation : module_->MakeNonfusionComputations()) {
66     TF_RETURN_IF_ERROR(computation->Accept(this));
67     for (auto* instruction : computation->instructions()) {
68       if (instruction->opcode() != HloOpcode::kFusion) {
69         continue;
70       }
71       GatherFusionInstructions(instruction, &fusion_instructions);
72     }
73   }
74   for (auto* instruction : fusion_instructions) {
75     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
76   }
77   return OkStatus();
78 }
79 
GetBuffer(LogicalBuffer::Id id) const80 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const {
81   return *logical_buffers_.at(id);
82 }
83 
GetBuffer(HloInstruction * instruction,const ShapeIndex & index) const84 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction,
85                                                 const ShapeIndex& index) const {
86   return *output_buffers_.at(std::make_pair(instruction, index));
87 }
88 
NewLogicalBuffer(HloInstruction * instruction,const ShapeIndex & index)89 void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
90                                              const ShapeIndex& index) {
91   LogicalBuffer::Id id = logical_buffers_.size();
92   auto buffer = std::make_unique<LogicalBuffer>(instruction, index, id);
93   auto position = std::make_pair(instruction, index);
94   CHECK(output_buffers_.insert({position, buffer.get()}).second);
95   logical_buffers_.push_back(std::move(buffer));
96 }
97 
DefaultAction(HloInstruction * hlo_instruction)98 Status LogicalBufferAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
99   // Create a logical buffer for each output of the instruction.
100   ShapeUtil::ForEachSubshape(
101       hlo_instruction->shape(),
102       [this, hlo_instruction](const Shape& shape, const ShapeIndex& index) {
103         NewLogicalBuffer(hlo_instruction, index);
104       });
105 
106   return OkStatus();
107 }
108 
HandleGetTupleElement(HloInstruction *)109 Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) {
110   // GetTupleElement does not create buffers.
111   return OkStatus();
112 }
113 
HandleAddDependency(HloInstruction * add_dependency)114 Status LogicalBufferAnalysis::HandleAddDependency(
115     HloInstruction* add_dependency) {
116   // AddDependency just forwards the value of its zero-th operand and does not
117   // create buffers.
118   return OkStatus();
119 }
120 
HandleCopy(HloInstruction * copy)121 Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) {
122   // The top-level buffer (index={}) for kCopy is newly created, but all other
123   // buffers (in the case of a tuple shape) come from the operand
124   NewLogicalBuffer(copy, /*index=*/{});
125   return OkStatus();
126 }
127 
HandleBitcast(HloInstruction *)128 Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
129   // A kBitcast instruction aliases its operand. That is, the buffer of its
130   // result *is* the buffer of its operand.
131   return OkStatus();
132 }
133 
HandleDomain(HloInstruction *)134 Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
135   // A kDomain instruction aliases its operand. That is, the buffer of its
136   // result *is* the buffer of its operand.
137   return OkStatus();
138 }
139 
HandleRecvDone(HloInstruction * recv_done)140 Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
141   // RecvDone produces a two-element tuple containing the data value (which
142   // aliases part of its operand) and a token. Only the tuple index table and
143   // the token are defined by the RecvDone.
144   NewLogicalBuffer(recv_done, /*index=*/{});
145   NewLogicalBuffer(recv_done, /*index=*/{1});
146   return OkStatus();
147 }
148 
HandleSend(HloInstruction * send)149 Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
150   // Send creates new buffers for the top-level tuple, the context (tuple
151   // element at {1}), and the token (tuple element at {2}). Tuple element at {0}
152   // is an alias of the Send operand, so we don't need to create a new Logical
153   // Buffer for that.
154   NewLogicalBuffer(send, /*index=*/{});
155   NewLogicalBuffer(send, /*index=*/{1});
156   NewLogicalBuffer(send, /*index=*/{2});
157   return OkStatus();
158 }
159 
HandleCopyStart(HloInstruction * copy_start)160 Status LogicalBufferAnalysis::HandleCopyStart(HloInstruction* copy_start) {
161   // CopyStart defines the tuple, target buffer at index {0}, and context at
162   // index {2}.
163   NewLogicalBuffer(copy_start, /*index=*/{});
164   NewLogicalBuffer(copy_start, /*index=*/{0});
165   NewLogicalBuffer(copy_start, /*index=*/{2});
166   return OkStatus();
167 }
168 
HandleCopyDone(HloInstruction * copy_done)169 Status LogicalBufferAnalysis::HandleCopyDone(HloInstruction* copy_done) {
170   // The output of CopyDone aliases with operand {0}. CopyDone doesn't create
171   // any buffers.
172   return OkStatus();
173 }
174 
HandleTuple(HloInstruction * tuple)175 Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
176   // A Tuple instruction only creates the top-level buffer.
177   NewLogicalBuffer(tuple, /*index=*/{});
178   return OkStatus();
179 }
180 
HandleCustomCall(HloInstruction * custom_call)181 Status LogicalBufferAnalysis::HandleCustomCall(HloInstruction* custom_call) {
182   auto ccall = Cast<HloCustomCallInstruction>(custom_call);
183   absl::flat_hash_set<ShapeIndex> aliased_outputs;
184   for (const auto& pair : ccall->output_to_operand_aliasing()) {
185     aliased_outputs.insert(pair.first);
186   }
187   ShapeUtil::ForEachSubshape(ccall->shape(),
188                              [&](const Shape& shape, const ShapeIndex& index) {
189                                if (!aliased_outputs.contains(index)) {
190                                  NewLogicalBuffer(custom_call, index);
191                                }
192                              });
193   return OkStatus();
194 }
195 
196 }  // namespace xla
197