/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item? Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kCopy) {
    if (LayoutUtil::Equal(instruction->shape().layout(),
                          instruction->operand(0)->shape().layout())) {
      // Don't rematerialize copies added by copy insertion (layout doesn't
      // change).
      return false;
    }
  }

  if (auto collective = DynCast<HloCollectiveInstruction>(instruction)) {
    return !collective->constrain_layout();
  }

  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}

// Checks whether an instruction can be rematerialized, by looking up the
// cache first and, if there is no cached entry, calling IsRematerializable().
bool CanBeRematerialized(
    const HloInstruction* instruction,
    absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
  auto it = rematerializable_map->find(instruction);
  if (it != rematerializable_map->end()) {
    return it->second;
  }
  bool rematerializable = IsRematerializable(instruction);
  (*rematerializable_map)[instruction] = rematerializable;
  return rematerializable;
}

// Returns true if this instruction relays the buffers it uses to its own
// users and it is one of the instructions we support rematerializing through.
bool IsSupportedIndirectUser(const HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kBitcast ||
         instruction->opcode() == HloOpcode::kGetTupleElement;
}

// Type holding a unique identifier for each Buffer object.
using BufferId = int64_t;
using BufferIdList = absl::InlinedVector<BufferId, 3>;

struct RematStrategy {
  enum {
    // Recompute the node at a later program point.
    kRecompute,
    // Change the layout into a compact form and uncompress it back at a later
    // program point.
    kCompress,
  } kind;
  Shape compact_shape;
};

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid an infinite loop rematerializing the same set of
  // instructions ad infinitum, keep a denylist of instructions
  // which should not be rematerialized.
  bool denylisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // Output buffers of this instruction. This is used to track outputs by GTE
  // instructions (where the instruction doesn't define a buffer).
  BufferIdList buffers_output;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

  bool is_skip_node = false;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next = nullptr;
  Item* prev = nullptr;

  Item* prev_skip_node = nullptr;
  Item* next_skip_node = nullptr;

  // List is ordered by position, which can however be duplicated as
  // new instructions are inserted. See InsertBeforeInstructions
  // comment for details.
  int64_t position;
};

// Data structure meant to record the user of the buffer defined from an Item.
// It also records the operand_number from which such a use derives, so that
// indirect uses can be better identified (like for example a buffer used
// through a bitcast).
struct ItemUse {
  Item* user;
  int64_t operand_number;
  std::optional<int64_t> index;

  ItemUse(Item* user, int64_t op_num, std::optional<int64_t> index)
      : user(user), operand_number(op_num), index(index) {}
  bool operator==(const ItemUse& other) const {
    return user == other.user && operand_number == other.operand_number &&
           index == other.index;
  }
};

using ItemList = absl::InlinedVector<Item*, 3>;
using UsesList = absl::InlinedVector<ItemUse, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
//
// This is a skip list structure that has two lanes: an express lane and a
// slow lane. All nodes are present on the slow lane, but a node can be
// promoted onto the express lane for fast iteration.
//
// In the following case, node 2 and node n+1 are connected via an express
// lane.
//                  +--------------------------+----------->: Express lane
//                  |                          |
//   node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...  : Slow lane
//
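// For example, after PromoteNodesToSkip() below promotes node 2 and node n+1,
// iterating with first_skip_node()/next_skip_node() visits only those two
// nodes instead of walking the entire slow lane.
//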
class InstructionList {
 public:
  explicit InstructionList(const HloInstructionSequence& order) {
    int64_t position = 0;
    Item* last = nullptr;
    last_skip_node_ = nullptr;
    first_skip_node_ = nullptr;
    for (HloInstruction* inst : order.instructions()) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later as
      // instructions are added with InsertBefore* methods, some instructions
      // may have duplicate position numbers, but the values will be
      // guaranteed to be monotonically increasing through the list, and so is
      // still useful for quickly(-ish) determining the order of arbitrary
      // instructions in the list.
      item->instruction = inst;
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  Item* first_skip_node() const { return first_skip_node_; }
  Item* next_skip_node(Item* item) const { return item->next_skip_node; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second)
        << "inserting inst twice " << inst->name();
    return item;
  }

  // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Insert instruction 'to_insert' immediately before the earliest instruction
  // in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal number. We use this to let
  // InsertBeforeInstructions quickly insert an instruction before the earliest
  // instruction in a set of instructions. If position_number_[a] <
  // position_number_[b], then 'a' comes before 'b' in the list. If the
  // position numbers are the same then nothing can be said about their order
  // without examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved.
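  //
  // For example, if the initial positions are {a: 0, b: 1, c: 2}, inserting a
  // new instruction x immediately before 'b' gives x the same position number
  // as 'b' (InsertBefore copies it), so the positions along the list become
  // 0, 1, 1, 2: duplicated, but still monotonically non-decreasing.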
  void InsertBeforeInstructions(Item* to_insert,
                                absl::Span<Item* const> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << absl::StrJoin(before_instructions, ", ",
                             [](std::string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // Because more than one instruction in 'before_instructions' may have a
    // position number of 'min_position_number', find the first such
    // instruction with position number 'min_position_number'.

    // First find first instruction with the min position.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions.
    while (!absl::c_linear_search(before_instructions, min_position_item)) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }

  // Scan the list and promote nodes to express lane if should_promote(Item)
  // returns true;
  void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
    int64_t count = 0;
    for (auto* item = first(); item != nullptr; item = next(item)) {
      if (should_promote(item)) {
        count += 1;
        if (first_skip_node_ == nullptr) {
          first_skip_node_ = item;
        }
        item->is_skip_node = true;
        item->prev_skip_node = last_skip_node_;
        if (last_skip_node_ != nullptr) {
          last_skip_node_->next_skip_node = item;
        }
        last_skip_node_ = item;
      }
    }
    VLOG(1) << " Rematerialization has " << count << " items in express lane";
  }

  void InsertAfterInstructions(Item* to_insert,
                               absl::Span<Item* const> after_instructions) {
    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
            << " after {"
            << absl::StrJoin(after_instructions, ", ",
                             [](std::string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the max position number of any instruction in
    // 'after_instructions'.
    CHECK(!after_instructions.empty());
    Item* max_position_item = nullptr;
    for (Item* item : after_instructions) {
      if (max_position_item == nullptr ||
          item->position > max_position_item->position) {
        max_position_item = item;
      }
    }
    // No rematerializable instruction should be inserted at the end of the
    // computation.
    CHECK(max_position_item->next != nullptr);
    InsertBeforeInstructions(to_insert, {max_position_item->next});
  }

  void Denylist(const HloInstruction* inst) {
    GetItem(inst)->denylisted = true;
  }

 private:
  // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Always place new nodes on express lane for the ease of implementation.
    item->is_skip_node = true;
    // Find the next express node starting from 'before'. Set up the node's
    // express pointers.
    Item* cursor = before;
    while (cursor != nullptr && !cursor->is_skip_node) {
      cursor = cursor->next;
    }
    CHECK(cursor == nullptr || cursor->is_skip_node);
    if (cursor == nullptr) {
      //
      //   last_skip_node_<---+                           : express lane
      //                      |
      //           ...<->`item`<-> .. <-> `cursor`(null)  : slow lane
      //
      // Reached the end. Set the prev_express to last_skip_node, and reset
      // last_skip.
      item->prev_skip_node = last_skip_node_;
      item->next_skip_node = nullptr;
      last_skip_node_ = item;
    } else {
      //
      //        <-+------------+----------------+--------->     : express lane
      //          |            |                |
      //   prev_express..<->`item`<-> .. <-> `cursor` <-> ...    : slow lane
      //
      // Reached the next skip node, sets up express pointers accordingly.
      CHECK(cursor->is_skip_node);
      item->prev_skip_node = cursor->prev_skip_node;
      if (item->prev_skip_node != nullptr) {
        item->prev_skip_node->next_skip_node = item;
      }
      item->next_skip_node = cursor;
      cursor->prev_skip_node = item;
    }
    if (first_skip_node_ == cursor) {
      first_skip_node_ = item;
    }
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // First skip node of this list.
  Item* first_skip_node_;

  // Last skip node of this list.
  Item* last_skip_node_;

  // Item for each instruction.
  absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};
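
// Illustrative use of InstructionList (the schedule/computation variable
// names are hypothetical; the iteration pattern mirrors the comment on
// first()/next() above):
//
//   InstructionList instruction_list(schedule.sequence(computation));
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     VLOG(3) << item->instruction->name();
//   }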

// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is
// indirect if the instruction defining logical_buffer is not an operand of
// the use. This can happen via buffer aliasing (eg, tuples).
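//
// For example (illustrative), with
//   %b = bitcast(%p)
//   %u = negate(%b)
// the buffer defined by %p is reached by %u through the alias in %b; the use
// is recorded for %u, and since bitcast is a supported indirect user it does
// not set has_indirect_users. An alias reached through an unsupported
// instruction (e.g. a tuple) would set it.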
UsesList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  UsesList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (points_to_analysis.DoesNotUseOperandBuffer(
              buffer_alias.instruction(), buffer_alias.index(), user)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a
        // GetTupleElement instruction (the GTE instruction only uses the
        // pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction() &&
          !IsSupportedIndirectUser(buffer_alias.instruction())) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      std::optional<int64_t> user_index =
          logical_buffer->index().size() != 1
              ? std::nullopt
              : std::make_optional(logical_buffer->index().back());
      for (int64_t op_idx : user->OperandIndices(buffer_alias.instruction())) {
        if (!absl::c_linear_search(
                users,
                ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
          users.push_back(
              ItemUse{user_item, static_cast<int>(op_idx), user_index});
        }
      }
    }
  }
  return users;
}

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const HloRematerialization::CompactShapeFunction& compact_shape_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list,
      HloRematerialization::RematerializationMode mode);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
  Status BeginInstruction(Item* item);

  int64_t RematerializationCost(const std::vector<Item*>& items,
                                int64_t memory_reduced,
                                int64_t memory_limit_bytes) {
    // If none of the users of any 'item' have been placed in the
    // sequence (as tracked by memory_tracker), then rematerialization of
    // 'item' is a zero-cost move of 'item->instruction' in the sequence.
    bool zero_cost_move = true;
    for (auto* item : items) {
      auto* instruction = item->instruction;
      if (absl::c_any_of(
              instruction->users(),
              [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
        zero_cost_move = false;
        break;
      }
    }
    if (zero_cost_move) {
      return 0;
    }

    CHECK_GT(memory_reduced, 0);
    // Return the inverse of the benefit of rematerialization.
    return memory_limit_bytes / memory_reduced;
  }

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes that the current memory usage will be reduced
  // if the given instruction is compact.
  int64_t MemoryReducedIfCompressed(Item* item,
                                    const Shape& compact_shape) const;

  // Returns the number of bytes that the current memory usage will be reduced
  // by if the given sequence of instructions is rematerialized.
  int64_t MemoryReducedIfRematerialized(
      absl::Span<const Item* const> items) const;

  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
                                 Item* uncompressed_item);

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialization
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected
  // to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
                                      absl::Span<Item*> indirect_users);

  // Selects and returns the best candidate instructions for rematerialization.
  // A sequence of candidate instructions of length between min_block_size and
  // max_block_size (both inclusive) with the lowest rematerialization cost is
  // selected among those candidates which reduce memory use at the program
  // point of the current instruction as indicated by memory_tracker. Returns
  // an empty vector if no candidates are found. Also returns an integer that
  // represents the amount of "effort" expended to find the candidate
  // instructions.
  std::tuple<std::vector<Item*>, RematStrategy, int>
  PickRematerializationCandidates(
      const InstructionList& instruction_list, int64_t memory_limit_bytes,
      absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
      int min_block_size, int max_block_size, int64_t peak_memory_bytes);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns whether 'item' has any unplaced users.
  bool HasUnplacedUsers(Item* item) const;

  // Returns the list of uses for a specific 'item'.
  const UsesList GetItemUses(Item* item) const;

  // Returns whether 'item' is currently in progress.
  bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64_t memory_usage() const { return memory_usage_; }

  //
  int64_t AllocatedSize(Item* item) const {
    int64_t size = 0;
    for (auto buffer_id : item->buffers_defined) {
      size += AllocatedSize(buffer_id);
    }
    return size;
  }

  const HloComputation* computation() const { return computation_; }

  // Check invariants of the data structure. This is expensive to call.
  bool Check() const;

  std::string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A
  // LogicalBuffer is not used directly because the HLO graph is transformed
  // and TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated
  // after HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64_t size;

    // Shape of the buffer.
    Shape shape;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not
    // a user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // Position in the tuple this buffer definition lives in.
    ShapeIndex index;

    // The instructions which use this buffer.
    UsesList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64_t unfinished_user_count;

    std::string ToString() const {
      return absl::StrCat("Buffer ", id, " (defined by ",
                          defining_instruction->instruction->name(), ", size ",
                          size, " bytes)");
    }
  };

  // Get the compact shape of given hlo instruction. An internal cache is used
  // to avoid computing the shape multiple times.
  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);

  // Creates a Buffer representing the given logical buffer. The buffer is
  // added to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
    bool has_indirect_uses = false;
    UsesList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     logical_buffer->shape(), logical_buffer->index(),
                     std::move(users), live_out, has_indirect_uses);
  }

  // Create a new buffer representing a rematerialization of given buffer for
  // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              UsesList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (ItemUse& use : rematerialized_uses) {
      CHECK(!use.user->placed) << use.user->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Return number of bytes allocated for the buffer with the given id.
  // Buffers allocated by the calling computation (eg, parameter and output
  // buffers) are considered to have zero bytes because the memory is
  // accounted for in a different computation.
  int64_t AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloInstruction* inst = buffer.defining_instruction->instruction;
    HloOpcode def_opcode = inst->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Returns whether the given instruction is live at the current program
  // point.
  bool IsInstructionCurrentlyLive(Item* instruction) const {
    // If the instruction has not started yet, it is not alive.
    if (!IsPlaced(instruction->instruction)) {
      return false;
    }
    for (const HloInstruction* user : instruction->instruction->users()) {
      if (!IsPlaced(user)) {
        // If there is an unplaced user, consider this instruction currently
        // live.
        return true;
      }
    }
    return false;
  }

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
                    const ShapeIndex& index, UsesList&& uses, bool live_out,
                    bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    auto get_num_of_unique_users = [](const UsesList& uses) -> int64_t {
      absl::flat_hash_set<Item*> users_set;
      for (const ItemUse& use : uses) {
        users_set.insert(use.user);
      }
      return users_set.size();
    };
    buffers_.push_back(Buffer{buffer_id, defining_instruction,
                              size_function_(shape), shape, live_out,
                              has_indirect_uses, index, uses,
                              get_num_of_unique_users(uses)});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Size function returns the bytes of a given buffer.
  const HloRematerialization::ShapeSizeFunction& size_function_;

  // Converts a shape into compact form, returns the same shape if a shape is
  // already considered compact.
  const HloRematerialization::CompactShapeFunction& compact_shape_function_;

  // A map that caches existing known compact shape for each instruction.
  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;

  // Memory usage at the currently placed instruction.
  int64_t memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

  HloRematerialization::RematerializationMode mode_;
  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};
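
// Illustrative use of MemoryUsageTracker (variable names are hypothetical;
// the BeginInstruction/EndInstruction pairing is the contract documented on
// those methods above):
//
//   MemoryUsageTracker tracker(computation, size_function,
//                              compact_shape_function, *points_to_analysis,
//                              instruction_list, mode);
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // ... decide whether to rematerialize while memory_usage() is high ...
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());
//   }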

MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const HloRematerialization::CompactShapeFunction& compact_shape_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list,
    HloRematerialization::RematerializationMode mode)
    : computation_(computation),
      instruction_list_(instruction_list),
      size_function_(size_function),
      compact_shape_function_(compact_shape_function),
      mode_(mode) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer = &buffers_.at(
            logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark buffer as has indirect use and live out.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of while to Buffer users.
        bool unused;
        for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
                                           points_to_analysis, &unused)) {
          auto existing_user_it = absl::c_find_if(
              buffer->users,
              [&](const ItemUse& use) { return user_item.user == use.user; });
          if (existing_user_it == buffer->users.end()) {
            buffer->unfinished_user_count++;
            user_item.user->buffers_used.push_back(buffer->id);
            buffer->users.push_back(user_item);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (ItemUse& user : buffer->users) {
          if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
            user.user->buffers_used.push_back(buffer->id);
          }
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }

    // Trace the output of each instruction. This is so that we can properly
    // track which outputs GTE instructions have.
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
      item->buffers_output.push_back(
          logical_buffer_to_buffer_id[logical_buffer]);
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a
  // (dead) operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return OkStatus();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    TF_RET_CHECK(buffer.unfinished_user_count >= 0)
        << buffer.ToString() << " has negative unfinished user count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return OkStatus();
}

int64_t MemoryUsageTracker::MemoryReducedIfCompressed(
    Item* item, const Shape& compact_shape) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  int64_t memory_reduced = 0;

  // We only compress a single piece of an output at one time.
  CHECK_EQ(item->buffers_output.size(), 1);
  BufferId buffer_id = item->buffers_output[0];
  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
      IsInstructionCurrentlyLive(item)) {
    const Buffer& buffer = buffers_.at(buffer_id);
    memory_reduced += buffer.size;

    int64_t compact_shape_size = size_function_(compact_shape);
    // Account for buffers that are compressed after instruction.
    memory_reduced -= compact_shape_size;
  }
  return memory_reduced;
}

int64_t MemoryUsageTracker::MemoryReducedIfRematerialized(
    absl::Span<const Item* const> items) const {
  CHECK_NE(in_progress_item_, nullptr);
  int64_t memory_reduced = 0;
  absl::flat_hash_set<const Item*> remat_candidates;

  for (const Item* item : items) {
    if (!item->placed || item == in_progress_item_) {
      LOG(WARNING) << "Unplaced item or in progress item being checked for "
                      "rematerialization.";
      return 0;
    }

    // Compute the amount of memory reduced (if any) by rematerializing
    // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
    // will no longer be live at this program point, so initially set
    // memory_reduced to the size of its defined values.
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_.at(buffer_id);
      // Avoid rematerializing instructions with indirect uses as it is
      // difficult to reason about liveness after rematerializing the
      // instruction.
      // Avoid rematerializing instructions with live out buffers.
      // Avoid rematerializing buffers that are in nested tuples.
      // TODO(mpurohit): Check why live_out buffers are an issue here.
      if (buffer.has_indirect_uses || buffer.live_out ||
          buffer.index.size() > 1) {
        return 0;
      }
      if (IsInUse(buffer_id)) {
        return 0;
      }
      if (IsCurrentlyLive(buffer_id)) {
        memory_reduced += AllocatedSize(buffer_id);
      }
    }

    // Account for any logical buffers whose live range must be extended across
    // this program point.
    for (BufferId buffer_id : item->buffers_used) {
      if (!IsCurrentlyLive(buffer_id)) {
        // This logical buffer is used by 'item->instruction' but is not live
        // at this program point. Rematerializing 'item->instruction' will
        // extend the buffer's live range across this program point unless it
        // is defined by an instruction that is also being rematerialized.
        Item* defining_instruction =
            buffers_.at(buffer_id).defining_instruction;
        if (!remat_candidates.contains(defining_instruction)) {
          memory_reduced -= AllocatedSize(buffer_id);
        }
      }
    }
    remat_candidates.insert(item);
  }

  return memory_reduced;
}

Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
                                                   Item* compressed_item,
                                                   Item* uncompressed_item) {
  // Original buffer is now dead.
  memory_usage_ -= size_function_(original_item->instruction->shape());
  // Compressed buffer is now alive.
  memory_usage_ += size_function_(compressed_item->instruction->shape());

  UsesList placed_users;
  UsesList unplaced_users;
  CHECK_EQ(original_item->buffers_output.size(), 1);
  BufferId original_buffer_id = original_item->buffers_output[0];
  Buffer& original_buffer = buffers_.at(original_buffer_id);
  for (ItemUse& user : original_buffer.users) {
    if (user.user->placed) {
      CHECK(IsFinished(user.user)) << user.user->instruction->name();
      placed_users.push_back(user);
    } else {
      unplaced_users.push_back(user);
    }
  }
  original_buffer.users = std::move(placed_users);
  original_buffer.unfinished_user_count = 0;
  original_buffer.users.push_back(ItemUse{compressed_item, 0, std::nullopt});
  // We are potentially reallocating the vector containing the buffers,
  // invalidating the original_buffer reference, so copy the index that we
  // need across NewBuffer calls.
  ShapeIndex copied_index = original_buffer.index;
  Buffer& compressed_buffer =
      NewBuffer(compressed_item, compressed_item->instruction->shape(),
                copied_index, {ItemUse{uncompressed_item, 0, std::nullopt}},
                /*live_out=*/false,
                /*has_indirect_uses=*/false);
  compressed_item->buffers_used = original_item->buffers_output;
  compressed_item->buffers_output = {compressed_buffer.id};
  compressed_item->buffers_defined.push_back(compressed_buffer.id);

  Buffer& uncompressed_buffer =
      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
                copied_index, std::move(unplaced_users), /*live_out=*/false,
                /*has_indirect_uses=*/false);

  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
  uncompressed_item->buffers_output = {uncompressed_buffer.id};
  uncompressed_item->buffers_defined = {uncompressed_buffer.id};

  for (ItemUse& user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user.user->buffers_used;
    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                 uncompressed_buffer.id);
  }

  return OkStatus();
}

Status MemoryUsageTracker::AddRematerializedInstruction(
    Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new
  // rematerialization instruction. Update memory usage if the
  // rematerialization makes a dead buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }
    buffer.unfinished_user_count++;
    absl::InlinedVector<ItemUse, 2> filtered_users;
    std::copy_if(buffer.users.begin(), buffer.users.end(),
                 std::back_inserter(filtered_users),
                 [&](const ItemUse& iu) { return iu.user == original_item; });
    for (ItemUse& u : filtered_users) {
      buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
    }
  }

  const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
                                                      indirect_users.end());
  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to
  // account for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    UsesList placed_users;
    UsesList unplaced_users;
    for (ItemUse& user : old_buffer.users) {
      if (user.user->placed) {
        placed_users.push_back(user);
      } else {
        // We keep only the indirect users that are in the provided list.
        // We consider all the others dead, remove any buffer uses they might
        // have, and drop them from the buffer's user list.
        if (!IsSupportedIndirectUser(user.user->instruction) ||
            indirect_users_set.contains(user.user)) {
          unplaced_users.push_back(user);
        } else {
          CHECK(user.user->buffers_defined.empty())
              << "Buffers defined expected to be empty for use passthrough "
                 "instructions";
          user.user->buffers_output.clear();
          user.user->buffers_used.clear();
        }
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
                              BufferIdList& to_update) {
      std::replace(to_update.begin(), to_update.end(), old_buffer_id,
                   new_buffer_id);
    };
    // Update users with the id of the new buffer.
    for (ItemUse& user : new_buffer.users) {
      update_buffers(user.user->buffers_used);
      update_buffers(user.user->buffers_output);
    }
  }

  // Update the indirect users with the id of the new buffers.
  for (Item* indirect_user : indirect_users) {
    // Source of the buffers that are going to be passed through.
    const Item* source_item =
        instruction_list_.GetItem(indirect_user->instruction->operand(0));
    switch (indirect_user->instruction->opcode()) {
      case HloOpcode::kBitcast: {
        // If the source is another indirect user then copy the output
        // into the used and output lists of the bitcast as they don't define
        // any buffer.
        if (IsSupportedIndirectUser(source_item->instruction)) {
          indirect_user->buffers_used = source_item->buffers_output;
          indirect_user->buffers_output = source_item->buffers_output;
        } else {
          // If it's a real instruction producing a buffer then copy the
          // defined buffers into used and output.
          indirect_user->buffers_used = source_item->buffers_defined;
          indirect_user->buffers_output = source_item->buffers_defined;
        }
        break;
      }
      case HloOpcode::kGetTupleElement: {
        // GTEs just use the tuple buffer and output the buffer they actually
        // extract from the tuple.
        const HloGetTupleElementInstruction* gte =
            Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
        for (BufferId buffer_id : source_item->buffers_defined) {
          const Buffer& def_buffer = buffers_.at(buffer_id);
          if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
            indirect_user->buffers_output.push_back(buffer_id);
          }
          // This is the tuple buffer.
          if (def_buffer.index.empty()) {
            indirect_user->buffers_used.push_back(buffer_id);
          }
        }
        break;
      }
      default: {
        LOG(FATAL) << "Unsupported indirect instruction with opcode "
                   << HloOpcodeString(indirect_user->instruction->opcode());
        break;
      }
    }
    // Fixup buffer users for the indirect instructions. For GTEs it is only
    // the tuple buffer, while for bitcasts it is the buffer they pass
    // through.
    for (BufferId buffer_id : indirect_user->buffers_used) {
      Buffer& buffer = buffers_.at(buffer_id);
      buffer.unfinished_user_count++;
      buffer.users.push_back(ItemUse{indirect_user, 0, std::nullopt});
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return OkStatus();
}

std::string MemoryUsageTracker::ToString() const {
  std::string output =
      absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
  absl::StrAppend(&output,
                  "Memory usage: ", HumanReadableNumBytes(memory_usage()),
                  " (", memory_usage(), " bytes)");
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* instruction = item->instruction;
    std::string inprogress = item == in_progress_item_ ? " in-progress" : "";
    std::string placed = item->placed ? " placed" : "";
    absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
                    "\n    Defines:\n");
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_[buffer_id];
      std::string live = IsCurrentlyLive(buffer_id) ? " live" : "";
      absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
                      buffer.unfinished_user_count, " unfinished uses\n");
    }
    absl::StrAppend(&output, "    Outputs:\n");
    for (BufferId buffer_id : item->buffers_output) {
      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
    }
    absl::StrAppend(&output, "    Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
    }
  }
  return output;
}

StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
  auto it = compact_shape_.find(hlo);
  if (it != compact_shape_.end()) {
    return it->second;
  }
  const Shape& original_shape = hlo->shape();
  TF_ASSIGN_OR_RETURN(Shape min_shape,
                      compact_shape_function_(original_shape));
  compact_shape_[hlo] = min_shape;
  return min_shape;
}

bool MemoryUsageTracker::Check() const {
  auto elements_are_unique = [](const BufferIdList& vec) {
    return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
  };

  // Verify buffers_defined per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& defined_buffers =
        instruction_list_.GetItem(instruction)->buffers_defined;
    CHECK(elements_are_unique(defined_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique defined buffers: "
        << absl::StrJoin(defined_buffers, ", ",
                         [this](std::string* out, BufferId buffer_id) {
                           absl::StrAppend(out,
                                           buffers_.at(buffer_id).ToString());
                         });

    for (const Buffer& buffer : buffers_) {
      if (buffer.defining_instruction->instruction == instruction) {
        CHECK(absl::c_linear_search(defined_buffers, buffer.id))
            << "Instruction " << instruction->name()
            << " defined buffers is missing: " << buffer.ToString();
      }
    }
  }

  // Verify buffers_used per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& used_buffers =
        instruction_list_.GetItem(instruction)->buffers_used;
    CHECK(elements_are_unique(used_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique used buffers: "
        << absl::StrJoin(used_buffers, ", ",
                         [this](std::string* out, BufferId buffer_id) {
                           absl::StrAppend(out,
                                           buffers_.at(buffer_id).ToString());
                         });
  }
  for (const Buffer& buffer : buffers_) {
    int64_t unfinished_uses = 0;
    absl::flat_hash_set<Item*> already_counted_user;
    for (const ItemUse& user : buffer.users) {
      const BufferIdList& used_buffers = user.user->buffers_used;
      CHECK(absl::c_linear_search(used_buffers, buffer.id))
          << "Instruction " << user.user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user.user) &&
          already_counted_user.insert(user.user).second) {
        unfinished_uses++;
      }
    }
    CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
        << "Incorrect unplaced use count for " << buffer.ToString();
  }
  return true;
}

// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
// memory_limit_bytes / memory_reduced
//
// The idea is to choose the operation that will save the most memory for
// rematerialization and not worry about how much the compute costs, since
// running out of memory is more harmful than taking longer to get the answer.
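//
// For example, with memory_limit_bytes = 1 GiB, an instruction whose
// rematerialization frees 256 MiB gets cost 4, while one that frees only
// 64 MiB gets cost 16, so the former is preferred.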
RematerializationCost(const HloInstruction * instruction,const MemoryUsageTracker & memory_tracker,int64_t memory_reduced,int64_t memory_limit_bytes)1330 int64_t RematerializationCost(const HloInstruction* instruction,
1331 const MemoryUsageTracker& memory_tracker,
1332 int64_t memory_reduced,
1333 int64_t memory_limit_bytes) {
1334 // If none of the users of 'instruction' have been placed in the sequence (as
1335 // tracked by memory_tracker), then rematerialization of 'instruction' is a
1336 // zero-cost move of 'instruction' in the sequence.
1337 if (!absl::c_any_of(instruction->users(),
1338 [&memory_tracker](const HloInstruction* inst) {
1339 return memory_tracker.IsPlaced(inst);
1340 })) {
1341 return 0;
1342 }
1343
1344 CHECK_GT(memory_reduced, 0);
1345 // Return the inverse of the benefit of rematerialization.
1346 return memory_limit_bytes / memory_reduced;
1347 }
1348
1349 // Returns a block of up to min_block_size consecutive candidate instructions
1350 // from instruction_list starting from start_item. Returns fewer than
1351 // min_block_size instructions if the block of unplaced instructions starting
1352 // from start_item is smaller than min_block_size.
GetInitialBlock(const InstructionList & instruction_list,const MemoryUsageTracker & tracker,Item * start_item,int min_block_size)1353 std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
1354 const MemoryUsageTracker& tracker,
1355 Item* start_item, int min_block_size) {
1356 std::vector<Item*> item_block;
1357 Item* curr_item = start_item;
1358 for (int i = 0; i < min_block_size; ++i) {
1359 if (curr_item == nullptr || !curr_item->placed ||
1360 tracker.IsInProgressItem(curr_item)) {
1361 break;
1362 }
1363 item_block.push_back(curr_item);
1364 curr_item = instruction_list.next(curr_item);
1365 }
1366 return item_block;
1367 }
1368
1369 // Returns whether any instruction in 'block' is denylisted or
1370 // non-rematerializable.
AnyDenylistedOrNonRematerializable(const std::vector<Item * > & block,absl::flat_hash_map<const HloInstruction *,bool> * rematerializable_map)1371 bool AnyDenylistedOrNonRematerializable(
1372 const std::vector<Item*>& block,
1373 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
1374 for (auto* item : block) {
1375 if (item->denylisted) {
1376 return true;
1377 }
1378 if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
1379 return true;
1380 }
1381 }
1382 return false;
1383 }
1384
1385 std::tuple<std::vector<Item*>, RematStrategy, int>
1386 MemoryUsageTracker::PickRematerializationCandidates(
1387 const InstructionList& instruction_list, int64_t memory_limit_bytes,
1388 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1389 int min_block_size, int max_block_size, int64_t peak_memory_bytes) {
1390 std::vector<Item*> best_items;
1391 int64_t best_cost = 0;
1392 RematStrategy best_strategy;
1393
1394 int effort = 0;
1395 VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
1396 << max_block_size << "]";
1397
1398 for (auto* start_item = instruction_list.first_skip_node();
1399 start_item != nullptr;
1400 start_item = instruction_list.next_skip_node(start_item)) {
1401 std::vector<Item*> block =
1402 GetInitialBlock(instruction_list, *this, start_item, min_block_size);
1403 if (block.size() < min_block_size) {
1404       // There are no more candidate blocks of at least min_block_size placed
1405       // instructions, so stop searching.
1406 break;
1407 }
1408     // If any item in the starting block is denylisted or non-rematerializable,
1409     // skip it and move on to the next start_item (we could actually jump to
1410     // the last invalid item in this block, but let's ignore that optimization).
1411 if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
1412 continue;
1413 }
1414 while (block.size() <= max_block_size) {
1415 // block size = 1 is treated separately since we consider compression in
1416 // this case only.
1417 if (block.size() == 1) {
1418 auto* item = block[0];
1419 auto* candidate = item->instruction;
1420 if (item->buffers_output.size() == 1 &&
1421 (mode_ ==
1422 HloRematerialization::RematerializationMode::kCompressOnly ||
1423 mode_ == HloRematerialization::RematerializationMode::
1424 kRecomputeAndCompress)) {
1425           // Only consider compressing a single-output instruction.
1426 const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
1427
1428 if (item->placed && item != in_progress_item_ &&
1429 !output_buffer.live_out) {
1430 const Shape& original_shape = item->instruction->shape();
1431 if (original_shape.IsArray()) {
1432 Shape compact_shape =
1433 GetCompactShape(item->instruction).ValueOrDie();
1434 const int64_t memory_reduced =
1435 MemoryReducedIfCompressed(item, compact_shape);
1436 // Since the compressed and uncompressed buffers need to be alive
1437 // while performing the compression/uncompression, only perform
1438 // the compression if the sum of the two sizes is less than the
1439 // peak memory.
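            // (Illustrative example, not taken from any particular model: a
            // 1 GiB buffer whose hypothetical compact layout is 512 MiB is
            // only compressed if 1 GiB + 512 MiB stays below
            // peak_memory_bytes, since both copies coexist while copying into
            // and out of the compact shape.)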
1440 const int64_t size = size_function_(item->instruction->shape());
1441 const int64_t reduced_size = size_function_(compact_shape);
1442 effort++;
1443 if (memory_reduced > 0 &&
1444 size + reduced_size < peak_memory_bytes) {
1445 const int64_t cost = memory_limit_bytes / memory_reduced;
1446 if (best_items.empty() || cost < best_cost) {
1447 VLOG(3) << "candidate " << candidate->name() << "("
1448 << candidate->ToShortString() << ")"
1449 << " now best when compressed into "
1450 << compact_shape.ToString(true);
1451 RematStrategy strategy;
1452 strategy.kind = RematStrategy::kCompress;
1453 best_strategy = strategy;
1454 best_strategy.compact_shape = compact_shape;
1455 best_items = block;
1456 best_cost = cost;
1457 }
1458 }
1459 }
1460 }
1461 }
1462 }
1463 // Do not consider recomputation in compress-only mode.
1464 if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
1465 // break out of this loop. Move on to the next start_item.
1466 break;
1467 }
1468       // If any of the candidate's control successors has been placed, we need
1469       // to skip this candidate; otherwise we would violate a control dependency.
1470 bool control_successor_placed = false;
1471 for (auto* item : block) {
1472 HloInstruction* candidate = item->instruction;
1473 if (std::any_of(candidate->control_successors().begin(),
1474 candidate->control_successors().end(),
1475 [this](const HloInstruction* inst) {
1476 return IsPlaced(inst);
1477 })) {
1478 control_successor_placed = true;
1479 break;
1480 }
1481 }
1482 if (control_successor_placed) {
1483 // break out of this loop. Move on to the next start_item.
1484 break;
1485 }
1486 VLOG(5) << "Block contains:";
1487 for (auto* hlo : block) {
1488 VLOG(5) << hlo->instruction->name();
1489 }
1490 const int64_t memory_reduced = MemoryReducedIfRematerialized(block);
1491 effort++;
1492 if (memory_reduced > 0) {
1493 const int cost =
1494 RematerializationCost(block, memory_reduced, memory_limit_bytes);
1495
1496 VLOG(5) << "Candidate block of size " << block.size()
1497 << " starting from " << block[0]->instruction->name()
1498 << ", memory reduced " << memory_reduced << ", cost per byte "
1499 << cost;
1500
1501 if (best_items.empty() || cost < best_cost) {
1502 VLOG(5) << "Candidate block of size " << block.size()
1503 << " starting from " << block[0]->instruction->name()
1504 << " now best";
1505 best_strategy.kind = RematStrategy::kRecompute;
1506 best_items = block;
1507 best_cost = cost;
1508 }
1509 }
1510
1511 // Time to update the block to include the next instruction.
1512 auto* last_item = block[block.size() - 1];
1513 auto* next_item = instruction_list.next(last_item);
1514 if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
1515 next_item == in_progress_item_ ||
1516 !CanBeRematerialized(next_item->instruction, rematerializable_map)) {
1517 break;
1518 }
1519 block.push_back(next_item);
1520 }
1521 }
1522 return {best_items, best_strategy, effort};
1523 }
1524
1525 bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
1526 for (BufferId buffer_id : item->buffers_defined) {
1527 const Buffer& buffer = buffers_.at(buffer_id);
1528 for (const ItemUse& user : buffer.users) {
1529 if (!user.user->placed) {
1530 return true;
1531 }
1532 }
1533 }
1534 return false;
1535 }
1536
1537 const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
1538 UsesList combined_users;
1539 for (BufferId buffer_id : item->buffers_defined) {
1540 const Buffer& buffer = buffers_.at(buffer_id);
1541 for (const ItemUse& user : buffer.users) {
1542 combined_users.push_back(user);
1543 }
1544 }
1545 return combined_users;
1546 }
1547
1548 StatusOr<int64_t> RematerializeInstructions(
1549 MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
1550 absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1551 InstructionList* instruction_list,
1552 HloRematerialization* rematerialization) {
1553 int64_t net_instructions_added = 0;
1554 int64_t total_memory_saved =
1555 memory_tracker->MemoryReducedIfRematerialized(*best_items);
1556 std::vector<std::string> instruction_names(best_items->size());
1557 // Rematerialize the block of instructions in the reverse order to account for
1558 // dependencies between instructions in best_items.
1559 for (int i = best_items->size() - 1; i >= 0; --i) {
1560 Item* best_item = (*best_items)[i];
1561 HloInstruction* best = best_item->instruction;
1562 instruction_names[i] = best->name();
1563 HloComputation* computation = best->parent();
1564
1565 // If the item to remat has no unplaced users, then skip the
1566 // rematerialization. Such an instruction can appear in best_items because
1567 // it is part of a good block, but does not itself add any benefit.
1568 if (!memory_tracker->HasUnplacedUsers(best_item)) {
1569 continue;
1570 }
1571
1572 HloInstruction* remat =
1573 computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
1574 // Increment channel_id on channel instructions with a channel id.
1575 if (DynCast<HloChannelInstruction>(best) &&
1576 DynCast<HloChannelInstruction>(best)->channel_id()) {
1577 remat->set_channel_id(rematerialization->NextChannelId());
1578 }
1579
1580 // Add control dependencies to the new operation.
1581 for (auto successor : best->control_successors()) {
1582 TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
1583 }
1584 for (auto predecessor : best->control_predecessors()) {
1585 TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1586 }
1587
1588 Item* remat_item = instruction_list->CreateItem(remat);
1589
1590 // Replace each remaining use of 'best' with the rematerialization.
1591 absl::InlinedVector<Item*, 4> indirect_users;
1592 absl::flat_hash_map<int64_t, HloInstruction*> gte_cache;
1593 for (auto& user : memory_tracker->GetItemUses(best_item)) {
1594 if (!memory_tracker->IsPlaced(user.user->instruction)) {
1595 VLOG(2) << " Replacing use of " << best->name() << " in "
1596 << user.user->instruction->name() << " with " << remat->name();
1597 HloInstruction* remat_use = remat;
1598 HloInstruction* const user_operand =
1599 user.user->instruction->mutable_operand(user.operand_number);
1600 if (remat_use == user_operand) {
1601 continue;
1602 }
1603 // If the output of a multi-output fusion node is forwarded to one of
1604 // its users as is, all the element buffers are also treated as uses
1605 // by that user, which need to be skipped.
1606 if (user.index && remat_use->shape() != user_operand->shape()) {
1607 auto cached_gte = gte_cache.find(*user.index);
1608 if (cached_gte == gte_cache.end()) {
1609 remat_use = computation->AddInstruction(
1610 HloInstruction::CreateGetTupleElement(
1611 ShapeUtil::GetTupleElementShape(remat_use->shape(),
1612 *user.index),
1613 remat_use, *user.index));
1614 indirect_users.push_back(instruction_list->CreateItem(remat_use));
1615 gte_cache[*user.index] = remat_use;
1616 } else {
1617 remat_use = cached_gte->second;
1618 }
1619 }
1620 if (user_operand->shape() != remat_use->shape()) {
1621 remat_use = computation->AddInstruction(
1622 HloInstruction::CreateBitcast(user_operand->shape(), remat_use));
1623 indirect_users.push_back(instruction_list->CreateItem(remat_use));
1624 }
1625 TF_RETURN_IF_ERROR(user.user->instruction->ReplaceOperandWith(
1626 user.operand_number, remat_use));
1627 }
1628 }
1629
1630 // Account for the rematerialization in the memory tracker.
1631 TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
1632 best_item, remat_item, absl::MakeSpan(indirect_users)));
1633
1634 // Insert rematerialized instruction right before the earliest unplaced
1635 // use of the instruction *and* the earliest unplaced last use of any
1636 // operands of remat. Unplaced uses of the remat's operands are included
1637 // because we don't want to extend the live range of remat's operands as
1638 // this could increase memory usage.
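    // (For example, if an operand %op of remat has an unplaced user %u that is
    // %op's last use, placing remat after %u would keep %op alive longer than
    // the original schedule did; inserting remat before %u avoids that. The
    // names %op and %u are purely illustrative.)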
1639 ItemList place_before;
1640 const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1641 indirect_users.end());
1642 for (auto user : remat->users()) {
1643 if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1644 place_before.push_back(instruction_list->GetItem(user));
1645 }
1646 }
1647 for (auto* indirect_user : indirect_users) {
1648 for (auto user : indirect_user->instruction->users()) {
1649 if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1650 place_before.push_back(instruction_list->GetItem(user));
1651 }
1652 }
1653 }
1654 for (auto* operand : remat->operands()) {
1655 for (auto* operand_user : operand->users()) {
1656 if (operand_user != remat) {
1657 Item* operand_user_item = instruction_list->GetItem(operand_user);
1658 if (!operand_user_item->placed) {
1659 place_before.push_back(operand_user_item);
1660 }
1661 }
1662 }
1663 }
1664 // Insert rematerialized instruction before any of its successors to
1665 // preserve ordering regarding control dependency.
1666 for (auto successor : remat->control_successors()) {
1667 Item* successor_item = instruction_list->GetItem(successor);
1668 // Assert to make sure we never remat an operation with control
1669 // successor already placed.
1670 CHECK(!successor_item->placed) << successor_item->instruction->name();
1671 place_before.push_back(successor_item);
1672 }
1673 instruction_list->InsertBeforeInstructions(remat_item, place_before);
1674
1675 for (auto* bitcast : indirect_users) {
1676 instruction_list->InsertBeforeInstructions(bitcast, place_before);
1677 }
1678 // Helper function that looks through indirect users when determining if
1679 // there is an active user for an HloInstruction.
1680 std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
1681 for (auto* u : i->users()) {
1682 if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
1683 return false;
1684 }
1685 }
1686 return true;
1687 };
1688 // If the rematerialized instruction is dead then rematerialization is
1689 // essentially a move. Don't delete the instruction now because we don't
1690 // want duplicate HloInstruction* values during the course of the
1691 // transformation because we keep maps with HloInstruction* values as
1692 // keys.
1693 if (uses_empty(best)) {
1694 VLOG(2) << best->name() << " is now dead";
1695 if (ContainsKey(*remat_move_instructions, best)) {
1696         // Previously, 'best' was a rematerialization which killed the
1697         // instruction it was a copy of. Now 'remat' is a rematerialization
1698 // of 'best' and kills 'best'. Stop rematerializing this instruction
1699 // to avoid an infinite loop.
1700 instruction_list->Denylist(remat);
1701 }
1702 remat_move_instructions->insert(remat);
1703 net_instructions_added += indirect_users.size();
1704 } else {
1705 net_instructions_added += indirect_users.size() + 1;
1706 }
1707 for (auto* indirect_user : indirect_users) {
1708 instruction_list->Denylist(indirect_user->instruction);
1709 }
1710 if (HloDataflowAnalysis::IsAsynchronousOperationStart(best->opcode()) ||
1711 HloDataflowAnalysis::IsAsynchronousOperationDone(best->opcode())) {
1712       VLOG(2) << "The old instruction " << best->name()
1713               << " is an async op. Removing it to maintain the one-start-"
1714                  "to-one-done invariant and keep the HLO valid.";
1715 TF_RETURN_IF_ERROR(computation->RemoveInstruction(best));
1716 }
1717 }
1718 VLOG(1) << "Rematerializing instructions ["
1719 << absl::StrJoin(instruction_names, ", ") << "] (saving "
1720 << HumanReadableNumBytes(total_memory_saved) << ")";
1721 return net_instructions_added;
1722 }
1723
1724 StatusOr<int64_t> CompressInstruction(MemoryUsageTracker* memory_tracker,
1725 Item* best_item,
1726 const Shape& compact_shape,
1727 InstructionList* instruction_list) {
1728 HloInstruction* best = best_item->instruction;
1729 VLOG(5) << "Transposing instruction " << best->name() << " (saving "
1730 << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1731 best_item, compact_shape))
1732           << ") to " << compact_shape.ToString(true);
1733
1734 HloComputation* computation = best->parent();
1735 HloInstruction* compressed = computation->AddInstruction(
1736 HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best),
1737 /*new_name=*/best->name() + ".remat_compressed");
1738
1739 HloInstruction* uncompressed = computation->AddInstruction(
1740 HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed),
1741 /*new_name=*/best->name() + ".remat_uncompressed");
1742
1743 Item* compressed_item = instruction_list->CreateItem(compressed);
1744 compressed_item->placed = true;
1745
1746 Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
1747
1748 // Replace each remaining use of 'best' with the uncompressed.
1749 std::vector<HloInstruction*> best_users_copy = best->users();
1750 for (HloInstruction* user : best_users_copy) {
1751 if (!memory_tracker->IsPlaced(user)) {
1752 VLOG(5) << " Replacing use of " << best->name() << " in " << user->name()
1753 << " with " << uncompressed->name();
1754 TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
1755 }
1756 }
1757
1758 // Account for the rematerialization in the memory tracker.
1759 TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
1760 best_item, compressed_item, uncompressed_item));
1761
1762 // Insert rematerialized instruction right before the earliest unplaced
1763 // use of the instruction.
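  // (The resulting order is: best ... compressed ... uncompressed ... first
  // unplaced user, so once 'compressed' has run, the full-size buffer of
  // 'best' is typically no longer needed and only the compact buffer stays
  // live until 'uncompressed' recreates the original shape.)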
1764 ItemList place_before;
1765 for (auto user : uncompressed->users()) {
1766 place_before.push_back(instruction_list->GetItem(user));
1767 }
1768
1769 instruction_list->Denylist(compressed_item->instruction);
1770 instruction_list->Denylist(uncompressed_item->instruction);
1771
1772 instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
1773
1774 instruction_list->InsertAfterInstructions(compressed_item, {best_item});
1775
1776 return 2;
1777 }
1778
1779 // A simple struct to encapsulate the number of instructions added during
1780 // rematerialization.
1781 struct InstructionsAdded {
1782 // Total count of instructions rematerialized.
1783 int remat_count;
1784 // Total count of instructions rematerialized minus number of original
1785 // instructions that are now dead.
1786 int net_instructions_added;
1787 // Amount of effort expended to find the instructions to rematerialize.
1788 int effort;
1789 };
1790
1791 // Rematerializes the best block of instructions of size between min_block_size
1792 // and max_block_size (both inclusive) if at least one candidate block of
1793 // instructions can be found. Returns number of instructions rematerialized.
1794 StatusOr<InstructionsAdded> RematerializeBestBlock(
1795 int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
1796 InstructionList* instruction_list, int64_t memory_limit_bytes,
1797 absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1798 absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1799 HloRematerialization* rematerialization) {
1800   CHECK(min_block_size > 0) << "Block size must be positive.";
1801
1802 std::vector<Item*> best_items;
1803 RematStrategy best_strategy;
1804 int effort;
1805 std::tie(best_items, best_strategy, effort) =
1806 memory_tracker->PickRematerializationCandidates(
1807 *instruction_list, memory_limit_bytes, rematerializable_map,
1808 min_block_size, max_block_size,
1809 rematerialization->ComputationPeakMemory(
1810 memory_tracker->computation()));
1811 InstructionsAdded num_instructions_added;
1812 num_instructions_added.remat_count = best_items.size();
1813 num_instructions_added.effort = effort;
1814 if (best_items.empty()) {
1815 num_instructions_added.net_instructions_added = 0;
1816 return num_instructions_added;
1817 }
1818
1819 if (best_strategy.kind == RematStrategy::kCompress) {
1820 CHECK(best_items.size() == 1)
1821 << "More than one instruction compressed simultaneously.";
1822 HloInstruction* best = best_items[0]->instruction;
1823 VLOG(1) << "Compressing instruction " << best->name() << " (saving "
1824 << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1825 best_items[0], best_strategy.compact_shape))
1826 << ")";
1827
1828 TF_ASSIGN_OR_RETURN(
1829 num_instructions_added.net_instructions_added,
1830 CompressInstruction(memory_tracker, best_items[0],
1831 best_strategy.compact_shape, instruction_list));
1832 } else {
1833 TF_ASSIGN_OR_RETURN(
1834 num_instructions_added.net_instructions_added,
1835 RematerializeInstructions(memory_tracker, &best_items,
1836 remat_move_instructions, instruction_list,
1837 rematerialization));
1838 }
1839 return num_instructions_added;
1840 }
1841 } // namespace
1842
1843 StatusOr<int64_t> HloRematerialization::ComputePeakMemory(
1844 const HloComputation* computation, const HloInstructionSequence& order,
1845 const absl::flat_hash_set<absl::string_view>& execution_threads) const {
1846 InstructionList instruction_list(order);
1847 MemoryUsageTracker tracker(computation, size_function_,
1848 compact_shape_function_, *points_to_analysis_,
1849 instruction_list, mode_);
1850 int64_t peak_memory = tracker.memory_usage();
1851 for (auto* item = instruction_list.first(); item != nullptr;
1852 item = instruction_list.next(item)) {
1853 const HloInstruction* instruction = item->instruction;
1854 TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1855 TF_ASSIGN_OR_RETURN(
1856 int64_t callee_usage,
1857 CalledComputationsMemoryUsage(instruction, execution_threads));
1858 peak_memory =
1859 std::max<int64_t>(peak_memory, tracker.memory_usage() + callee_usage);
1860 TF_RETURN_IF_ERROR(tracker.EndInstruction());
1861 }
1862 VLOG(1) << "Peak memory for " << computation->name() << ": "
1863 << HumanReadableNumBytes(peak_memory);
1864 return peak_memory;
1865 }
1866
1867 StatusOr<int64_t> HloRematerialization::CalledComputationsMemoryUsage(
1868 const HloInstruction* instruction,
1869 const absl::flat_hash_set<absl::string_view>& execution_threads) const {
1870 const CallSite* callsite =
1871 call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
1872 if (callsite == nullptr || callsite->context() == CallContext::kEmbedded) {
1873 return 0;
1874 }
1875 int64_t callee_usage = 0;
1876 for (const HloComputation* computation : callsite->called_computations()) {
1877 if (!IsExecutionThreadIncluded(execution_threads,
1878 computation->execution_thread())) {
1879 continue;
1880 }
1881 TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
1882 callee_usage += computation_peak_memory_.at(computation);
1883 }
1884 return callee_usage;
1885 }
1886
1887 bool HloRematerialization::IsExecutionThreadIncluded(
1888 const absl::flat_hash_set<absl::string_view>& execution_threads,
1889 absl::string_view thread) const {
1890 return execution_threads.empty() || execution_threads.contains(thread);
1891 }
1892
1893 StatusOr<bool> HloRematerialization::RematerializeComputation(
1894 HloComputation* computation, HloSchedule* schedule,
1895 int64_t memory_limit_bytes, int64_t min_remat_size,
1896 const absl::flat_hash_set<absl::string_view>& execution_threads) {
1897 VLOG(1) << "Rematerializing computation " << computation->name()
1898 << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
1899 VLOG(1) << "peak memory usage is "
1900 << HumanReadableNumBytes(computation_peak_memory_.at(computation));
1901 CHECK(!ContainsKey(rematerialized_computations_, computation));
1902
1903 InstructionList instruction_list(schedule->sequence(computation));
1904 MemoryUsageTracker memory_tracker(
1905 computation, size_function_, compact_shape_function_,
1906 *points_to_analysis_, instruction_list, mode_);
1907
1908 instruction_list.PromoteNodesToSkip([&](Item* item) {
1909 return memory_tracker.AllocatedSize(item) >= min_remat_size;
1910 });
1911 bool changed = false;
1912
1913 // If the rematerialization makes the source instruction dead, then the
1914 // rematerialization is added to 'remat_move_instructions' (the
1915 // rematerialization is essentially a move). If the next rematerialization of
1916 // the instruction is also a move then the rematerialization is added to the
1917 // denylist.
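  // (Illustrative scenario with a hypothetical instruction %x: %x is
  // rematerialized as %x.remat and %x becomes dead, so %x.remat is recorded
  // as a move; if %x.remat is later rematerialized and would also die, the
  // new clone is denylisted so the pass does not keep shuttling the same
  // value around indefinitely.)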
1918 absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
1919
1920 // The map from instructions to their rematerializable status.
1921 absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
1922
1923 // The peak memory of the computation at any point in the instruction
1924 // sequence.
1925 int64_t peak_memory = memory_tracker.memory_usage();
1926
1927 // Total count of instructions rematerialized.
1928 int64_t remat_count = 0;
1929 // Total count of clones created minus number of original rematerialized
1930 // instructions which are dead.
1931 int64_t net_instructions_added = 0;
1932
1933 const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1934
1935 // Iterate through all instructions in the sequence. At each instruction
1936 // (program point) if memory_usage exceeds the specified limit then
1937 // rematerialize HLO instructions until memory_usage is reduced.
1938 int64_t instruction_index = 0;
1939 for (auto* item = instruction_list.first(); item != nullptr;
1940 item = instruction_list.next(item)) {
1941 const HloInstruction* instruction = item->instruction;
1942 TF_ASSIGN_OR_RETURN(
1943 int64_t callee_usage,
1944 CalledComputationsMemoryUsage(instruction, execution_threads));
1945 TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
1946
1947 VLOG(2) << "Program point at " << instruction->name()
1948 << ", memory usage = " << memory_tracker.memory_usage()
1949 << ", callee usage = " << callee_usage << ", [" << instruction_index
1950 << "/" << instruction_list.size() << "]";
1951 instruction_index++;
1952
1953 // Initialize both min_block_size and max_block_size to 1 so that only
1954 // single instruction rematerialization is considered first.
1955 int min_block_size = 1;
1956 int max_block_size = 1;
1957 // Only trigger rematerialization when the memory usage changes.
1958 if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
1959 // Finding larger blocks of instructions to rematerialize can be time
1960 // consuming. To limit the amount of time spent attempting to find such
1961 // large blocks, count the amount of effort expended to find single
1962 // instructions to rematerialize and then limit the total amount of effort
1963 // to at most a factor of block_rematerialization_factor_ more.
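      // As a concrete illustration of the loop below: the search window starts
      // at [1, 1]; each time no block is found it grows to [2, 2], then [3, 4],
      // [5, 8], and so on (min = previous max + 1, max doubled), while finding
      // a block resets it to [1, 1]. The loop exits once memory usage drops
      // below the limit, the window exceeds block_size_limit_, or second-phase
      // effort exceeds block_rematerialization_factor_ times first-phase effort.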
1964 bool is_first_phase = true;
1965 int64_t first_phase_effort = 0;
1966 int64_t second_phase_effort = 0;
1967 while (memory_tracker.memory_usage() + callee_usage >
1968 memory_limit_bytes) {
1969 VLOG(2) << "Over memory limit at instruction " << instruction->name()
1970 << ", using "
1971 << HumanReadableNumBytes(memory_tracker.memory_usage() +
1972 callee_usage)
1973 << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
1974
1975 TF_ASSIGN_OR_RETURN(
1976 InstructionsAdded instructions_added,
1977 RematerializeBestBlock(min_block_size, max_block_size,
1978 &memory_tracker, &instruction_list,
1979 memory_limit_bytes, &rematerializable_map,
1980 &remat_move_instructions, this));
1981 net_instructions_added += instructions_added.net_instructions_added;
1982 remat_count += instructions_added.remat_count;
1983 if (is_first_phase) {
1984 first_phase_effort += instructions_added.effort;
1985 } else {
1986 second_phase_effort += instructions_added.effort;
1987 }
1988 VLOG(1) << "memory_usage after rematerialization = "
1989 << HumanReadableNumBytes(memory_tracker.memory_usage());
1990 if (instructions_added.remat_count == 0) {
1991 // Unable to find a block to rematerialize.
1992 // Consider doubling the block size.
1993 min_block_size = max_block_size + 1;
1994 max_block_size = 2 * max_block_size;
1995 is_first_phase = false;
1996 } else {
1997 // Found a valid block. Reset to start looking for single instructions
1998 // again.
1999 max_rematerialized_block_size_ =
2000 std::max(max_rematerialized_block_size_, max_block_size);
2001 changed = true;
2002 min_block_size = 1;
2003 max_block_size = 1;
2004 }
2005 if (max_block_size > block_size_limit_ ||
2006 second_phase_effort >
2007 block_rematerialization_factor_ * first_phase_effort) {
2008 break;
2009 }
2010 }
2011 }
2012 const CallSite* callsite = call_graph_node.GetCallSite(instruction);
2013 if (callsite != nullptr &&
2014 callsite->context() == CallContext::kControlFlow &&
2015 memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
2016 // Memory usage exceeds the limit. Try to rematerialize any
2017 // subcomputation(s) that this instruction calls.
2018 VLOG(1) << "Memory usage still over the limit ("
2019 << (memory_tracker.memory_usage() + callee_usage) << " > "
2020 << memory_limit_bytes
2021 << "). Rematerializing computations called by "
2022 << instruction->name();
2023
2024 // Recompute callee usage to account for any rematerialization performed
2025 // in the callee computations.
2026 for (HloComputation* called_computation :
2027 callsite->called_computations()) {
2028 if (!ContainsKey(rematerialized_computations_, called_computation)) {
2029 // Memory limit for the subcomputation is the memory limit less the
2030 // amount of memory used at this point in the computation.
2031 int64_t subcomputation_memory_limit_bytes = std::max<int64_t>(
2032 0, memory_limit_bytes - memory_tracker.memory_usage());
2033 TF_ASSIGN_OR_RETURN(
2034 bool subcomputation_changed,
2035 RematerializeComputation(called_computation, schedule,
2036 subcomputation_memory_limit_bytes,
2037 min_remat_size, execution_threads));
2038 changed |= subcomputation_changed;
2039 }
2040 }
2041
2042 TF_ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage(
2043 instruction, execution_threads));
2044 }
2045
2046 peak_memory = std::max<int64_t>(
2047 peak_memory, memory_tracker.memory_usage() + callee_usage);
2048 VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);
2049
2050 TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
2051 }
2052
2053 // Verify some invariants on the memory tracker.
2054 for (auto* instruction : computation->instructions()) {
2055 CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
2056 }
2057
2058 VLOG(1) << "In computation " << computation->name() << " rematerialized "
2059 << remat_count << " instructions; " << net_instructions_added
2060 << " net instructions added";
2061 VLOG(1) << " peak memory usage now " << HumanReadableNumBytes(peak_memory)
2062 << " (was "
2063 << HumanReadableNumBytes(computation_peak_memory_.at(computation))
2064 << ")";
2065
2066 // Update peak memory used by computation.
2067 computation_peak_memory_.at(computation) = peak_memory;
2068
2069 // Update order to include rematerialized instructions.
2070 HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
2071 sequence.clear();
2072 for (auto* item = instruction_list.first(); item != nullptr;
2073 item = instruction_list.next(item)) {
2074 HloInstruction* instruction = item->instruction;
2075 sequence.push_back(instruction);
2076 }
2077 rematerialized_computations_.insert(computation);
2078
2079 instructions_rematerialized_ += remat_count;
2080 net_instructions_added_ += net_instructions_added;
2081
2082 return changed;
2083 }
2084
2085 StatusOr<bool> HloRematerialization::Run(
2086 HloModule* module,
2087 const absl::flat_hash_set<absl::string_view>& execution_threads) {
2088 VLOG(1) << "HloRematerialization() with memory limit of "
2089 << HumanReadableNumBytes(memory_limit_bytes_);
2090 XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
2091
2092 // Initialize pass object state.
2093 computation_peak_memory_.clear();
2094 rematerialized_computations_.clear();
2095 instructions_rematerialized_ = 0;
2096 net_instructions_added_ = 0;
2097
2098 TF_RET_CHECK(module->has_schedule());
2099 TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
2100 next_channel_id_ = hlo_query::NextChannelId(*module);
2101
2102   // Adjust the memory limit to account for the output of the entry
2103   // computation. This is necessary because the per-computation accounting in
2104   // MemoryUsageTracker does not include the output, as it is typically
2105   // allocated by the caller.
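  // (For instance, with a 16 GiB memory_limit_bytes_ and a 1 GiB module
  // output, computations are rematerialized against a 15 GiB budget; the
  // numbers are purely illustrative.)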
2106 int64_t module_output_size = 0;
2107 ShapeUtil::ForEachSubshape(
2108 module->result_shape(),
2109 [&module_output_size, module, this](const Shape& subshape,
2110 const ShapeIndex& output_index) {
2111 module_output_size += size_function_(subshape);
2112 });
2113
2114 const int64_t adjusted_memory_limit_bytes =
2115 memory_limit_bytes_ - module_output_size;
2116 VLOG(1) << "Adjusted memory limit accounting for output ("
2117 << HumanReadableNumBytes(module_output_size)
2118 << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
2119
2120 // Compute peak memory usage of all computations in the module called in a
2121 // sequential context.
2122 call_graph_ = CallGraph::Build(module);
2123 TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
2124 [this, module, &execution_threads](const CallGraphNode& node) -> Status {
2125 if (node.context() == CallContext::kControlFlow &&
2126 IsExecutionThreadIncluded(execution_threads,
2127 node.computation()->execution_thread())) {
2128 TF_ASSIGN_OR_RETURN(
2129 computation_peak_memory_[node.computation()],
2130 ComputePeakMemory(node.computation(),
2131 module->schedule().sequence(node.computation()),
2132 execution_threads));
2133 }
2134 return OkStatus();
2135 },
2136 /*visit_unreachable_nodes=*/false));
2137
2138 // The peak memory usage of the module equals the peak memory use of the entry
2139 // computation plus the output size of the computation. This is because the
2140 // peak memory for a computation does not include the output as this is
2141 // typically accounted for in the caller.
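  // (E.g., an entry computation peaking at 10 GiB with a 1 GiB module output
  // gives a module-level peak of 11 GiB; illustrative numbers only.)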
2142 const int64_t before_peak_memory =
2143 computation_peak_memory_.at(module->entry_computation()) +
2144 module_output_size;
2145 VLOG(1) << "Peak memory usage of module (before): "
2146 << HumanReadableNumBytes(before_peak_memory);
2147 // Subcomputations called by the entry computation will also be
2148 // rematerialized.
2149 TF_ASSIGN_OR_RETURN(
2150 bool changed,
2151 RematerializeComputation(module->entry_computation(), &module->schedule(),
2152 adjusted_memory_limit_bytes, min_remat_size_,
2153 execution_threads));
2154 // Rematerialization can introduce dead code. This occurs if all uses of an
2155 // instruction are replaced with rematerializations of the instruction.
2156
2157   // Stash away the schedule while running dead code elimination, to avoid
2158   // validation failures while the module is in flux.
2159 HloSchedule saved_schedule = module->schedule();
2160 module->clear_schedule();
2161 TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
2162 changed |= dead_code_removed;
2163
2164 // After DCE, the module sequence may include instructions which no longer
2165 // exist. Update the schedule and restore it.
2166 TF_RETURN_IF_ERROR(saved_schedule.Update(execution_threads));
2167 TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
2168 VLOG(1) << "Rematerialized " << instructions_rematerialized_
2169 << " instructions in module " << module->name() << "; "
2170 << net_instructions_added_ << " net instructions added";
2171 const int64_t current_peak_memory =
2172 computation_peak_memory_.at(module->entry_computation()) +
2173 module_output_size;
2174 VLOG(1) << "Peak memory usage of module now "
2175 << HumanReadableNumBytes(current_peak_memory) << " ("
2176 << current_peak_memory << " bytes), was "
2177 << HumanReadableNumBytes(before_peak_memory) << " ("
2178 << before_peak_memory << " bytes)";
2179 const int64_t reduced_peak_memory = before_peak_memory - current_peak_memory;
2180 VLOG(1) << "Reduced peak memory by "
2181 << HumanReadableNumBytes(reduced_peak_memory) << " ("
2182 << reduced_peak_memory << " bytes)";
2183
2184 if (sizes_ != nullptr) {
2185 sizes_->before_bytes = before_peak_memory;
2186 sizes_->after_bytes = current_peak_memory;
2187 }
2188
2189 XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
2190
2191 if (current_peak_memory > memory_limit_bytes_) {
2192 LOG(WARNING) << absl::StrFormat(
2193 "Can't reduce memory use below %s (%d bytes) by rematerialization; "
2194 "only reduced to %s (%d bytes)",
2195 HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
2196 HumanReadableNumBytes(current_peak_memory), current_peak_memory);
2197 }
2198 return changed;
2199 }
2200
2201 } // namespace xla
2202