/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item?  Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kCopy) {
    if (LayoutUtil::Equal(instruction->shape().layout(),
                          instruction->operand(0)->shape().layout())) {
      // Don't rematerialize copies added by copy insertion (layout doesn't
      // change).
      return false;
    }
  }

  if (auto collective = DynCast<HloCollectiveInstruction>(instruction)) {
    return !collective->constrain_layout();
  }

  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}

// Checks whether an instruction can be rematerialized, by first consulting
// the cache and, if there is no entry, calling IsRematerializable().
bool CanBeRematerialized(
    const HloInstruction* instruction,
    absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
  auto it = rematerializable_map->find(instruction);
  if (it != rematerializable_map->end()) {
    return it->second;
  }
  bool rematerializable = IsRematerializable(instruction);
  (*rematerializable_map)[instruction] = rematerializable;
  return rematerializable;
}

// Returns true if this instruction relays the buffers it uses to its own
// users and is one of the kinds of instructions through which we support
// rematerialization.
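// For example (illustrative): a kGetTupleElement or kBitcast user does not
// define a new buffer of its own; it forwards the producer's buffer to its
// own users.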
bool IsSupportedIndirectUser(const HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kBitcast ||
         instruction->opcode() == HloOpcode::kGetTupleElement;
}

// Type holding a unique identifier for each Buffer object.
using BufferId = int64_t;
using BufferIdList = absl::InlinedVector<BufferId, 3>;

struct RematStrategy {
  enum {
    // Recompute the node at a later program point.
    kRecompute,
    // Change the layout into a compact form and uncompress it back at a later
    // program point.
    kCompress,
  } kind;
  Shape compact_shape;
};

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid an infinite loop rematerializing the same set of
  // instructions ad infinitum, keep a denylist of instructions
  // which should not be rematerialized.
  bool denylisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // Output buffers of this instruction. This is used to track outputs by GTE
  // instructions (where the instruction doesn't define a buffer).
  BufferIdList buffers_output;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

  bool is_skip_node = false;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next = nullptr;
  Item* prev = nullptr;

  Item* prev_skip_node = nullptr;
  Item* next_skip_node = nullptr;

  // List is ordered by position, which can however be duplicated as
  // new instructions are inserted.  See InsertBeforeInstructions
  // comment for details.
  int64_t position;
};

// Data structure that records a user of the buffer defined by an Item. It
// also records the operand_number from which such a use derives, so that
// indirect uses (for example, a buffer used through a bitcast) can be
// identified.
struct ItemUse {
  Item* user;
  int64_t operand_number;
  std::optional<int64_t> index;

  ItemUse(Item* user, int64_t op_num, std::optional<int64_t> index)
      : user(user), operand_number(op_num), index(index) {}
  bool operator==(const ItemUse& other) const {
    return user == other.user && operand_number == other.operand_number &&
           index == other.index;
  }
};

using ItemList = absl::InlinedVector<Item*, 3>;
using UsesList = absl::InlinedVector<ItemUse, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
//
// This is a skip list structure that has two lanes: express lane and slow lane.
// All nodes are presented on the slow lane but a node can be promoted into
// express lane for fast iteration.
//
// In the following case, node 2 and node n+1 are connected via an express lane.
//                    +--------------------------+----------->: Express lane
//                    |                          |
//       node1<-> node 2 <-> .. <-> node n <-> node n+1 <->...: Slow lane
//
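// Illustrative iteration over the express lane only (a sketch using the
// accessors declared below):
//   for (Item* it = list.first_skip_node(); it != nullptr;
//        it = list.next_skip_node(it)) { /* visit promoted items */ }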
class InstructionList {
 public:
  explicit InstructionList(const HloInstructionSequence& order) {
    int64_t position = 0;
    Item* last = nullptr;
    last_skip_node_ = nullptr;
    first_skip_node_ = nullptr;
    for (HloInstruction* inst : order.instructions()) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later as
      // instructions are added with InsertBefore* methods, some instructions
      // may have duplicate position numbers, but the values will be guaranteed
      // to be monotonically increasing through the list, and so is still useful
      // for quickly(-ish) determining the order of arbitrary instructions in
      // the list.
      item->instruction = inst;
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  Item* first_skip_node() const { return first_skip_node_; }
  Item* next_skip_node(Item* item) const { return item->next_skip_node; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second)
        << "inserting inst twice " << inst->name();
    return item;
  }

  // Return the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Insert instruction 'to_insert' immediately before the earliest instruction
  // in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal number. We use this to let
  // InsertBeforeInstructions quickly insert an instruction before the earliest
  // instruction in a set of instructions.  If position_number_[a] <
  // position_number_[b] then 'a' comes before 'b' in the list. If the position
  // numbers are the same then nothing can be said about their order without
  // examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved.
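  // For example, if items a, b, c initially have positions 0, 1, 2 and a new
  // item x is inserted before c, x also receives position 2; positions stay
  // monotonic but ties must be broken by scanning the list.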
  void InsertBeforeInstructions(Item* to_insert,
                                absl::Span<Item* const> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << absl::StrJoin(before_instructions, ", ",
                             [](std::string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // Because more than one instruction in 'before_instructions' may have a
    // position number of 'min_position_number', find the first such instruction
    // with position number 'min_position_number'.

    // First find first instruction with the min position.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions.
    while (!absl::c_linear_search(before_instructions, min_position_item)) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }

  // Scan the list and promote nodes to the express lane if should_promote(Item)
  // returns true.
  void PromoteNodesToSkip(std::function<bool(Item*)> should_promote) {
    int64_t count = 0;
    for (auto* item = first(); item != nullptr; item = next(item)) {
      if (should_promote(item)) {
        count += 1;
        if (first_skip_node_ == nullptr) {
          first_skip_node_ = item;
        }
        item->is_skip_node = true;
        item->prev_skip_node = last_skip_node_;
        if (last_skip_node_ != nullptr) {
          last_skip_node_->next_skip_node = item;
        }
        last_skip_node_ = item;
      }
    }
    VLOG(1) << " Rematerialization has " << count << " items in express lane";
  }

  void InsertAfterInstructions(Item* to_insert,
                               absl::Span<Item* const> after_instructions) {
    VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
            << " after {"
            << absl::StrJoin(after_instructions, ", ",
                             [](std::string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the max position number of any instruction in
    // 'after_instructions'.
    CHECK(!after_instructions.empty());
    Item* max_position_item = nullptr;
    for (Item* item : after_instructions) {
      if (max_position_item == nullptr ||
          item->position > max_position_item->position) {
        max_position_item = item;
      }
    }
    // No rematerializable instruction should be inserted at the end of the
    // computation.
    CHECK(max_position_item->next != nullptr);
    InsertBeforeInstructions(to_insert, {max_position_item->next});
  }

  void Denylist(const HloInstruction* inst) {
    GetItem(inst)->denylisted = true;
  }

 private:
  // Insert instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Always place new nodes on express lane for the ease of implementation.
    item->is_skip_node = true;
    // Find the next express node starting from 'before'. Set up the node's
    // express pointers.
    Item* cursor = before;
    while (cursor != nullptr && !cursor->is_skip_node) {
      cursor = cursor->next;
    }
    CHECK(cursor == nullptr || cursor->is_skip_node);
    if (cursor == nullptr) {
      //
      // last_skip_node_<---+                              : express lane
      //                    |
      //           ...<->`item`<-> .. <-> `cursor`(null)   : slow lane
      //
      // Reached the end. Set the prev_express to last_skip_node, and reset
      // last_skip.
      item->prev_skip_node = last_skip_node_;
      item->next_skip_node = nullptr;
      last_skip_node_ = item;
    } else {
      //
      //     <-+------------+----------------+--------->   : express lane
      //       |            |                |
      // prev_express..<->`item`<-> .. <-> `cursor` <-> ...: slow lane
      //
      // Reached the next skip node, sets up express pointers accordingly.
      CHECK(cursor->is_skip_node);
      item->prev_skip_node = cursor->prev_skip_node;
      if (item->prev_skip_node != nullptr) {
        item->prev_skip_node->next_skip_node = item;
      }
      item->next_skip_node = cursor;
      cursor->prev_skip_node = item;
    }
    if (first_skip_node_ == cursor) {
      first_skip_node_ = item;
    }
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // First skip node of this list.
  Item* first_skip_node_;

  // Last skip node of this list.
  Item* last_skip_node_;

  // Item for each instruction.
  absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};

// Return the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
UsesList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  UsesList users;
  // To identify uses iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (points_to_analysis.DoesNotUseOperandBuffer(
              buffer_alias.instruction(), buffer_alias.index(), user)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a GetTupleElement
        // instruction (the GTE instruction only uses the pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction() &&
          !IsSupportedIndirectUser(buffer_alias.instruction())) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      std::optional<int64_t> user_index =
          logical_buffer->index().size() != 1
              ? std::nullopt
              : std::make_optional(logical_buffer->index().back());
      for (int64_t op_idx : user->OperandIndices(buffer_alias.instruction())) {
        if (!absl::c_linear_search(
                users,
                ItemUse{user_item, static_cast<int>(op_idx), user_index})) {
          users.push_back(
              ItemUse{user_item, static_cast<int>(op_idx), user_index});
        }
      }
    }
  }
  return users;
}

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
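// Illustrative placement loop (a sketch of how the tracker below is used;
// the names are for exposition only):
//   for (auto* item = instruction_list.first(); item != nullptr;
//        item = instruction_list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // ...consider rematerializing or compressing values at this point...
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());
//   }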
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const HloRematerialization::CompactShapeFunction& compact_shape_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list,
      HloRematerialization::RematerializationMode mode);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
  Status BeginInstruction(Item* item);

  int64_t RematerializationCost(const std::vector<Item*>& items,
                                int64_t memory_reduced,
                                int64_t memory_limit_bytes) {
    // If none of the users of any 'item' have been placed in the
    // sequence (as tracked by memory_tracker), then rematerialization of
    // 'item' is a zero-cost move of 'item->instruction' in the sequence.
    bool zero_cost_move = true;
    for (auto* item : items) {
      auto* instruction = item->instruction;
      if (absl::c_any_of(
              instruction->users(),
              [this](const HloInstruction* inst) { return IsPlaced(inst); })) {
        zero_cost_move = false;
        break;
      }
    }
    if (zero_cost_move) {
      return 0;
    }

    CHECK_GT(memory_reduced, 0);
    // Return the inverse of the benefit of rematerialization.
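    // For example (illustrative): with a 1 GiB memory_limit_bytes, a candidate
    // that reduces memory by 256 MiB gets cost 4, while one that reduces it by
    // 512 MiB gets cost 2, so larger reductions are preferred.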
    return memory_limit_bytes / memory_reduced;
  }

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes by which the current memory usage would be
  // reduced if the output of the given instruction were compressed to
  // compact_shape.
  int64_t MemoryReducedIfCompressed(Item* item,
                                    const Shape& compact_shape) const;

  // Returns the number of bytes that the current memory usage will be reduced
  // by if the given sequence of instructions is rematerialized.
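  // For example (illustrative): rematerializing an instruction whose defined
  // buffer of 100 bytes is currently live reduces usage by 100 bytes, but if
  // one of its operand buffers of 40 bytes is no longer live, that operand's
  // live range is extended across this point, for a net reduction of 60 bytes.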
  int64_t MemoryReducedIfRematerialized(
      absl::Span<const Item* const> items) const;

  Status AddCompressInstructions(Item* original_item, Item* compressed_item,
                                 Item* uncompressed_item);

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialization
  // is remat_item. This method should be called after the HLO graph has
  // been transformed (rematerialization instruction created and connected
  // to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
                                      absl::Span<Item*> indirect_users);

  // Selects and returns the best candidate instructions for rematerialization.
  // A sequence of candidate instructions of length between min_block_size and
  // max_block_size (both inclusive) with the lowest rematerialization cost is
  // selected among those candidates which reduce memory use at the program
  // point of the current instruction as indicated by memory_tracker. Returns an
  // empty vector if no candidates are found. Also returns an integer that
  // represents the amount of "effort" expended to find the candidate
  // instructions.
  std::tuple<std::vector<Item*>, RematStrategy, int>
  PickRematerializationCandidates(
      const InstructionList& instruction_list, int64_t memory_limit_bytes,
      absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
      int min_block_size, int max_block_size, int64_t peak_memory_bytes);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns whether 'item' has any unplaced users.
  bool HasUnplacedUsers(Item* item) const;

  // Returns the list of uses for a specific 'item'.
  const UsesList GetItemUses(Item* item) const;

  // Returns whether 'item' is currently in progress.
  bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64_t memory_usage() const { return memory_usage_; }

  // Returns the total allocated size of all buffers defined by 'item'.
  int64_t AllocatedSize(Item* item) const {
    int64_t size = 0;
    for (auto buffer_id : item->buffers_defined) {
      size += AllocatedSize(buffer_id);
    }
    return size;
  }

  const HloComputation* computation() const { return computation_; }

  // Check invariants of the data structure. This is expensive to call.
  bool Check() const;

  std::string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A LogicalBuffer
  // is not used directly because the HLO graph is transformed and
  // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
  // HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64_t size;

    // Shape of the buffer.
    Shape shape;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not a
    // user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // Position in the tuple this buffer definition lives in.
    ShapeIndex index;

    // The instructions which use this buffer.
    UsesList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64_t unfinished_user_count;

    std::string ToString() const {
      return absl::StrCat("Buffer ", id, " (defined by ",
                          defining_instruction->instruction->name(), ", size ",
                          size, " bytes)");
    }
  };

  // Get the compact shape of given hlo instruction. An internal cache is used
  // to avoid computing the shape multiple times.
  StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);

  // Creates a Buffer representing the given logical buffer. The buffer is added
  // to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
    bool has_indirect_uses = false;
    UsesList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     logical_buffer->shape(), logical_buffer->index(),
                     std::move(users), live_out, has_indirect_uses);
  }

  // Create a new buffer representing a rematerialization of given buffer for
  // the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              UsesList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (ItemUse& use : rematerialized_uses) {
      CHECK(!use.user->placed) << use.user->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.shape, original_buffer.index,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Return number of bytes allocated for the buffer with the given id. Buffers
  // allocated by the calling computation (eg, parameter and output buffers) are
  // considered to have zero bytes because the memory is accounted for in a
  // different computation.
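  // For example (illustrative): a buffer defined by a kParameter instruction,
  // or one that is live-out of the computation, contributes 0 bytes here.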
  int64_t AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloInstruction* inst = buffer.defining_instruction->instruction;
    HloOpcode def_opcode = inst->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Returns whether the given instruction is live at the current program
  // point.
  bool IsInstructionCurrentlyLive(Item* instruction) const {
    // If the instruction has not started yet, it is not alive.
    if (!IsPlaced(instruction->instruction)) {
      return false;
    }
    for (const HloInstruction* user : instruction->instruction->users()) {
      if (!IsPlaced(user)) {
        // If there is an unplaced user, consider this instruction currently
        // live.
        return true;
      }
    }
    return false;
  }

  // Create a new buffer, add it to buffers_, and return a reference.
  Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
                    const ShapeIndex& index, UsesList&& uses, bool live_out,
                    bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    auto get_num_of_unique_users = [](const UsesList& uses) -> int64_t {
      absl::flat_hash_set<Item*> users_set;
      for (const ItemUse& use : uses) {
        users_set.insert(use.user);
      }
      return users_set.size();
    };
    buffers_.push_back(Buffer{
        buffer_id, defining_instruction, size_function_(shape), shape, live_out,
        has_indirect_uses, index, uses, get_num_of_unique_users(uses)});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Size function returns the bytes of a given buffer.
  const HloRematerialization::ShapeSizeFunction& size_function_;

  // Converts a shape into compact form, returns the same shape if a shape is
  // already considered compact.
  const HloRematerialization::CompactShapeFunction& compact_shape_function_;

  // A map that caches existing known compact shape for each instruction.
  absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;

  // Memory usage at the currently placed instruction.
  int64_t memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

  HloRematerialization::RematerializationMode mode_;
  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};

MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const HloRematerialization::CompactShapeFunction& compact_shape_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list,
    HloRematerialization::RematerializationMode mode)
    : computation_(computation),
      instruction_list_(instruction_list),
      size_function_(size_function),
      compact_shape_function_(compact_shape_function),
      mode_(mode) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer =
            &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark buffer as has indirect use and live out.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of while to Buffer users.
        bool unused;
        for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
                                           points_to_analysis, &unused)) {
          auto existing_user_it = absl::c_find_if(
              buffer->users,
              [&](const ItemUse& use) { return user_item.user == use.user; });
          if (existing_user_it == buffer->users.end()) {
            buffer->unfinished_user_count++;
            user_item.user->buffers_used.push_back(buffer->id);
            buffer->users.push_back(user_item);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (ItemUse& user : buffer->users) {
          if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
            user.user->buffers_used.push_back(buffer->id);
          }
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }

    // Trace the output of each instruction. This is so that we can properly
    // track which outputs GTE instructions have.
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
      item->buffers_output.push_back(
          logical_buffer_to_buffer_id[logical_buffer]);
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
  // operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return OkStatus();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    TF_RET_CHECK(buffer.unfinished_user_count >= 0)
        << buffer.ToString() << " has negative unfinished user count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      // The memory usage can become negative inside the computation as we can
      // free up the parameter space and reuse it for other tensors.
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return OkStatus();
}

int64_t MemoryUsageTracker::MemoryReducedIfCompressed(
    Item* item, const Shape& compact_shape) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  int64_t memory_reduced = 0;

  // We only compress a single piece of an output at one time.
  CHECK_EQ(item->buffers_output.size(), 1);
  BufferId buffer_id = item->buffers_output[0];
  if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) &&
      IsInstructionCurrentlyLive(item)) {
    const Buffer& buffer = buffers_.at(buffer_id);
    memory_reduced += buffer.size;

    int64_t compact_shape_size = size_function_(compact_shape);
    // Account for buffers that are compressed after instruction.
    memory_reduced -= compact_shape_size;
  }
  return memory_reduced;
}

int64_t MemoryUsageTracker::MemoryReducedIfRematerialized(
    absl::Span<const Item* const> items) const {
  CHECK_NE(in_progress_item_, nullptr);
  int64_t memory_reduced = 0;
  absl::flat_hash_set<const Item*> remat_candidates;

  for (const Item* item : items) {
    if (!item->placed || item == in_progress_item_) {
      LOG(WARNING) << "Unplaced item or in progress item being checked for "
                      "rematerialization.";
      return 0;
    }

    // Compute the amount of memory reduced (if any) by rematerializing
    // 'item->instruction'. The LogicalBuffers defined by 'item->instruction'
    // will no longer be live at this program point, so initially set
    // memory_reduced to the size of its defined values.
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_.at(buffer_id);
      // Avoid rematerializing instructions with indirect uses as it is
      // difficult to reason about liveness after rematerializing the
      // instruction.
      // Avoid rematerializing instructions with live out buffers.
      // Avoid rematerializing buffers that are in nested tuples.
      // TODO(mpurohit): Check why live_out buffers are an issue here.
      if (buffer.has_indirect_uses || buffer.live_out ||
          buffer.index.size() > 1) {
        return 0;
      }
      if (IsInUse(buffer_id)) {
        return 0;
      }
      if (IsCurrentlyLive(buffer_id)) {
        memory_reduced += AllocatedSize(buffer_id);
      }
    }

    // Account for any logical buffers whose live range must be extended across
    // this program point.
    for (BufferId buffer_id : item->buffers_used) {
      if (!IsCurrentlyLive(buffer_id)) {
        // This logical buffer is used by 'item->instruction' but is not live at
        // this program point. Rematerializing 'item->instruction' will extend
        // the buffer's live range across this program point unless it is
        // defined by an instruction that is also being rematerialized.
        Item* defining_instruction =
            buffers_.at(buffer_id).defining_instruction;
        if (!remat_candidates.contains(defining_instruction)) {
          memory_reduced -= AllocatedSize(buffer_id);
        }
      }
    }
    remat_candidates.insert(item);
  }

  return memory_reduced;
}

Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
                                                   Item* compressed_item,
                                                   Item* uncompressed_item) {
  // Original buffer is now dead.
  memory_usage_ -= size_function_(original_item->instruction->shape());
  // Compressed buffer is now alive.
  memory_usage_ += size_function_(compressed_item->instruction->shape());

  UsesList placed_users;
  UsesList unplaced_users;
  CHECK_EQ(original_item->buffers_output.size(), 1);
  BufferId original_buffer_id = original_item->buffers_output[0];
  Buffer& original_buffer = buffers_.at(original_buffer_id);
  for (ItemUse& user : original_buffer.users) {
    if (user.user->placed) {
      CHECK(IsFinished(user.user)) << user.user->instruction->name();
      placed_users.push_back(user);
    } else {
      unplaced_users.push_back(user);
    }
  }
  original_buffer.users = std::move(placed_users);
  original_buffer.unfinished_user_count = 0;
  original_buffer.users.push_back(ItemUse{compressed_item, 0, std::nullopt});
  // We are potentially reallocating the vector containing the buffers,
  // invalidating the original_buffer reference, so copy the index that we need
  // across NewBuffer calls.
  ShapeIndex copied_index = original_buffer.index;
  Buffer& compressed_buffer =
      NewBuffer(compressed_item, compressed_item->instruction->shape(),
                copied_index, {ItemUse{uncompressed_item, 0, std::nullopt}},
                /*live_out=*/false,
                /*has_indirect_uses=*/false);
  compressed_item->buffers_used = original_item->buffers_output;
  compressed_item->buffers_output = {compressed_buffer.id};
  compressed_item->buffers_defined.push_back(compressed_buffer.id);

  Buffer& uncompressed_buffer =
      NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
                copied_index, std::move(unplaced_users), /*live_out=*/false,
                /*has_indirect_uses=*/false);

  uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
  uncompressed_item->buffers_output = {uncompressed_buffer.id};
  uncompressed_item->buffers_defined = {uncompressed_buffer.id};

  for (ItemUse& user : uncompressed_buffer.users) {
    BufferIdList& buffers_used = user.user->buffers_used;
    std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
                 uncompressed_buffer.id);
  }

  return OkStatus();
}

Status MemoryUsageTracker::AddRematerializedInstruction(
    Item* original_item, Item* remat_item, absl::Span<Item*> indirect_users) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new rematerialization
  // instruction. Update memory usage if the rematerialization makes a dead
  // buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }
    buffer.unfinished_user_count++;
    absl::InlinedVector<ItemUse, 2> filtered_users;
    std::copy_if(buffer.users.begin(), buffer.users.end(),
                 std::back_inserter(filtered_users),
                 [&](const ItemUse& iu) { return iu.user == original_item; });
    for (ItemUse& u : filtered_users) {
      buffer.users.push_back(ItemUse{remat_item, u.operand_number, u.index});
    }
  }

  const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
                                                      indirect_users.end());
  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    UsesList placed_users;
    UsesList unplaced_users;
    for (ItemUse& user : old_buffer.users) {
      if (user.user->placed) {
        placed_users.push_back(user);
      } else {
        // We keep only the indirect users that are in the provided list.
        // We consider all the others dead, and remove any buffer uses they
        // might perform from the buffer user list.
        if (!IsSupportedIndirectUser(user.user->instruction) ||
            indirect_users_set.contains(user.user)) {
          unplaced_users.push_back(user);
        } else {
          CHECK(user.user->buffers_defined.empty())
              << "Buffers defined expected to be empty for use passthrough "
                 "instructions";
          user.user->buffers_output.clear();
          user.user->buffers_used.clear();
        }
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    auto update_buffers = [old_buffer_id, new_buffer_id = new_buffer.id](
                              BufferIdList& to_update) {
      std::replace(to_update.begin(), to_update.end(), old_buffer_id,
                   new_buffer_id);
    };
    // Update users with the id of the new buffer.
    for (ItemUse& user : new_buffer.users) {
      update_buffers(user.user->buffers_used);
      update_buffers(user.user->buffers_output);
    }
  }

  // Update the indirect users with the id of the new buffers.
  for (Item* indirect_user : indirect_users) {
    // Source of the buffers that are going to be passed through.
    const Item* source_item =
        instruction_list_.GetItem(indirect_user->instruction->operand(0));
    switch (indirect_user->instruction->opcode()) {
      case HloOpcode::kBitcast: {
        // If the source is another indirect user then copy its output into the
        // used and output lists of the bitcast, as bitcasts don't define any
        // buffers.
        if (IsSupportedIndirectUser(source_item->instruction)) {
          indirect_user->buffers_used = source_item->buffers_output;
          indirect_user->buffers_output = source_item->buffers_output;
        } else {
          // If it's a real instruction producing a buffer then copy the defined
          // buffers into used and output.
          indirect_user->buffers_used = source_item->buffers_defined;
          indirect_user->buffers_output = source_item->buffers_defined;
        }
        break;
      }
      case HloOpcode::kGetTupleElement: {
        // GTEs just use the tuple buffer and output the buffer they actually
        // extract from the tuple.
        const HloGetTupleElementInstruction* gte =
            Cast<HloGetTupleElementInstruction>(indirect_user->instruction);
        for (BufferId buffer_id : source_item->buffers_defined) {
          const Buffer& def_buffer = buffers_.at(buffer_id);
          if (def_buffer.index == ShapeIndex{gte->tuple_index()}) {
            indirect_user->buffers_output.push_back(buffer_id);
          }
          // This is the tuple buffer.
          if (def_buffer.index.empty()) {
            indirect_user->buffers_used.push_back(buffer_id);
          }
        }
        break;
      }
      default: {
        LOG(FATAL) << "Unsupported indirect instruction with opcode "
                   << HloOpcodeString(indirect_user->instruction->opcode());
        break;
      }
    }
    // Fix up buffer users for the indirect instructions. For GTEs this is only
    // the tuple buffer, while for bitcasts it is the buffer they pass through.
1206     for (BufferId buffer_id : indirect_user->buffers_used) {
1207       Buffer& buffer = buffers_.at(buffer_id);
1208       buffer.unfinished_user_count++;
1209       buffer.users.push_back(ItemUse{indirect_user, 0, std::nullopt});
1210     }
1211   }
1212 
1213   VLOG(3) << "  memory usage = " << memory_usage_;
1214   XLA_VLOG_LINES(10, ToString());
1215 
1216   DCHECK(Check());
1217 
1218   return OkStatus();
1219 }
1220 
ToString() const1221 std::string MemoryUsageTracker::ToString() const {
1222   std::string output =
1223       absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
1224   absl::StrAppend(&output,
1225                   "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
1226                   memory_usage(), " bytes)");
1227   for (auto* item = instruction_list_.first(); item != nullptr;
1228        item = instruction_list_.next(item)) {
1229     const HloInstruction* instruction = item->instruction;
1230     std::string inprogress = item == in_progress_item_ ? " in-progress" : "";
1231     std::string placed = item->placed ? " placed" : "";
1232     absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
1233                     "\n    Defines:\n");
1234     for (BufferId buffer_id : item->buffers_defined) {
1235       const Buffer& buffer = buffers_[buffer_id];
1236       std::string live = IsCurrentlyLive(buffer_id) ? " live" : "";
1237       absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
1238                       buffer.unfinished_user_count, " unfinished uses\n");
1239     }
1240     absl::StrAppend(&output, "    Outputs:\n");
1241     for (BufferId buffer_id : item->buffers_output) {
1242       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
1243     }
1244     absl::StrAppend(&output, "    Uses:\n");
1245     for (BufferId buffer_id : item->buffers_used) {
1246       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
1247     }
1248   }
1249   return output;
1250 }
1251 
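// Returns the compact shape for 'hlo' as computed by compact_shape_function_,
// memoizing the result per instruction since the same candidate may be
// evaluated repeatedly.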
1252 StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
1253   auto it = compact_shape_.find(hlo);
1254   if (it != compact_shape_.end()) {
1255     return it->second;
1256   }
1257   const Shape& original_shape = hlo->shape();
1258   TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
1259   compact_shape_[hlo] = min_shape;
1260   return min_shape;
1261 }
1262 
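// Sanity-checks the tracker's bookkeeping: the buffers_defined and buffers_used
// lists are duplicate-free and consistent with the Buffer records, and each
// buffer's unfinished_user_count matches its not-yet-finished users. Violations
// CHECK-fail, so the function only ever returns true.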
1263 bool MemoryUsageTracker::Check() const {
1264   auto elements_are_unique = [](const BufferIdList& vec) {
1265     return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
1266   };
1267 
1268   // Verify buffers_defined per instruction.
1269   for (auto* instruction : computation_->instructions()) {
1270     const BufferIdList& defined_buffers =
1271         instruction_list_.GetItem(instruction)->buffers_defined;
1272     CHECK(elements_are_unique(defined_buffers))
1273         << "Instruction " << instruction->name()
1274         << " does not have unique defined buffers: "
1275         << absl::StrJoin(defined_buffers, ", ",
1276                          [this](std::string* out, BufferId buffer_id) {
1277                            absl::StrAppend(out,
1278                                            buffers_.at(buffer_id).ToString());
1279                          });
1280 
1281     for (const Buffer& buffer : buffers_) {
1282       if (buffer.defining_instruction->instruction == instruction) {
1283         CHECK(absl::c_linear_search(defined_buffers, buffer.id))
1284             << "Instruction " << instruction->name()
1285             << " defined buffers is missing: " << buffer.ToString();
1286       }
1287     }
1288   }
1289 
1290   // Verify buffers_used per instruction.
1291   for (auto* instruction : computation_->instructions()) {
1292     const BufferIdList& used_buffers =
1293         instruction_list_.GetItem(instruction)->buffers_used;
1294     CHECK(elements_are_unique(used_buffers))
1295         << "Instruction " << instruction->name()
1296         << " does not have unique used buffers: "
1297         << absl::StrJoin(used_buffers, ", ",
1298                          [this](std::string* out, BufferId buffer_id) {
1299                            absl::StrAppend(out,
1300                                            buffers_.at(buffer_id).ToString());
1301                          });
1302   }
1303   for (const Buffer& buffer : buffers_) {
1304     int64_t unfinished_uses = 0;
1305     absl::flat_hash_set<Item*> already_counted_user;
1306     for (const ItemUse& user : buffer.users) {
1307       const BufferIdList& used_buffers = user.user->buffers_used;
1308       CHECK(absl::c_linear_search(used_buffers, buffer.id))
1309           << "Instruction " << user.user->instruction->name()
1310           << " used buffers is missing " << buffer.ToString();
1311       if (!IsFinished(user.user) &&
1312           already_counted_user.insert(user.user).second) {
1313         unfinished_uses++;
1314       }
1315     }
1316     CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
1317         << "Incorrect unfinished user count for " << buffer.ToString();
1318   }
1319   return true;
1320 }
1321 
1322 // Computes and returns the cost of rematerializing the given instruction.
1323 // Cost per rematerialized instruction is defined as:
1324 //
1325 // memory_limit_bytes / memory_reduced
1326 //
1327 // The idea is to choose the operation that will save the most memory for
1328 // rematerialization and not worry about how much the recomputation costs, since
1329 // running out of memory is more harmful than taking longer to get the answer.
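// Illustrative example (made-up numbers): with a 1 GiB memory limit, a
// candidate that would free 256 MiB gets cost 4, while one that would free
// only 64 MiB gets cost 16, so the 256 MiB candidate is preferred.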
1330 int64_t RematerializationCost(const HloInstruction* instruction,
1331                               const MemoryUsageTracker& memory_tracker,
1332                               int64_t memory_reduced,
1333                               int64_t memory_limit_bytes) {
1334   // If none of the users of 'instruction' have been placed in the sequence (as
1335   // tracked by memory_tracker), then rematerialization of 'instruction' is a
1336   // zero-cost move of 'instruction' in the sequence.
1337   if (!absl::c_any_of(instruction->users(),
1338                       [&memory_tracker](const HloInstruction* inst) {
1339                         return memory_tracker.IsPlaced(inst);
1340                       })) {
1341     return 0;
1342   }
1343 
1344   CHECK_GT(memory_reduced, 0);
1345   // Return the inverse of the benefit of rematerialization.
1346   return memory_limit_bytes / memory_reduced;
1347 }
1348 
1349 // Returns a block of up to min_block_size consecutive candidate instructions
1350 // from instruction_list starting from start_item. Returns fewer than
1351 // min_block_size instructions if the block of placed instructions starting
1352 // from start_item is smaller than min_block_size.
1353 std::vector<Item*> GetInitialBlock(const InstructionList& instruction_list,
1354                                    const MemoryUsageTracker& tracker,
1355                                    Item* start_item, int min_block_size) {
1356   std::vector<Item*> item_block;
1357   Item* curr_item = start_item;
1358   for (int i = 0; i < min_block_size; ++i) {
1359     if (curr_item == nullptr || !curr_item->placed ||
1360         tracker.IsInProgressItem(curr_item)) {
1361       break;
1362     }
1363     item_block.push_back(curr_item);
1364     curr_item = instruction_list.next(curr_item);
1365   }
1366   return item_block;
1367 }
1368 
1369 // Returns whether any instruction in 'block' is denylisted or
1370 // non-rematerializable.
1371 bool AnyDenylistedOrNonRematerializable(
1372     const std::vector<Item*>& block,
1373     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map) {
1374   for (auto* item : block) {
1375     if (item->denylisted) {
1376       return true;
1377     }
1378     if (!CanBeRematerialized(item->instruction, rematerializable_map)) {
1379       return true;
1380     }
1381   }
1382   return false;
1383 }
1384 
1385 std::tuple<std::vector<Item*>, RematStrategy, int>
1386 MemoryUsageTracker::PickRematerializationCandidates(
1387     const InstructionList& instruction_list, int64_t memory_limit_bytes,
1388     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1389     int min_block_size, int max_block_size, int64_t peak_memory_bytes) {
1390   std::vector<Item*> best_items;
1391   int64_t best_cost = 0;
1392   RematStrategy best_strategy;
1393 
1394   int effort = 0;
1395   VLOG(5) << "Picking candidate block with size in [" << min_block_size << ", "
1396           << max_block_size << "]";
1397 
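  // Search outline: for each skip-node start_item, grow a block of placed
  // instructions from min_block_size up to max_block_size, evaluate compression
  // (single-instruction blocks only) and recomputation, and remember the block
  // with the lowest cost per byte of memory saved.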
1398   for (auto* start_item = instruction_list.first_skip_node();
1399        start_item != nullptr;
1400        start_item = instruction_list.next_skip_node(start_item)) {
1401     std::vector<Item*> block =
1402         GetInitialBlock(instruction_list, *this, start_item, min_block_size);
1403     if (block.size() < min_block_size) {
1404       // There are no more blocks of at least min_block_size placed
1405       // instructions, so stop searching.
1406       break;
1407     }
1408     // If any item in the starting block is denylisted or non-rematerializable,
1409     // skip this block and move on to the next start_item (we could actually
1410     // skip to the last invalid item in the block, but ignore that for now).
1411     if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) {
1412       continue;
1413     }
1414     while (block.size() <= max_block_size) {
1415       // block size = 1 is treated separately since we consider compression in
1416       // this case only.
1417       if (block.size() == 1) {
1418         auto* item = block[0];
1419         auto* candidate = item->instruction;
1420         if (item->buffers_output.size() == 1 &&
1421             (mode_ ==
1422                  HloRematerialization::RematerializationMode::kCompressOnly ||
1423              mode_ == HloRematerialization::RematerializationMode::
1424                           kRecomputeAndCompress)) {
1425           // Only consider compressing single-output instructions.
1426           const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
1427 
1428           if (item->placed && item != in_progress_item_ &&
1429               !output_buffer.live_out) {
1430             const Shape& original_shape = item->instruction->shape();
1431             if (original_shape.IsArray()) {
1432               Shape compact_shape =
1433                   GetCompactShape(item->instruction).ValueOrDie();
1434               const int64_t memory_reduced =
1435                   MemoryReducedIfCompressed(item, compact_shape);
1436               // Since the compressed and uncompressed buffers need to be alive
1437               // while performing the compression/uncompression, only perform
1438               // the compression if the sum of the two sizes is less than the
1439               // peak memory.
1440               const int64_t size = size_function_(item->instruction->shape());
1441               const int64_t reduced_size = size_function_(compact_shape);
1442               effort++;
1443               if (memory_reduced > 0 &&
1444                   size + reduced_size < peak_memory_bytes) {
1445                 const int64_t cost = memory_limit_bytes / memory_reduced;
1446                 if (best_items.empty() || cost < best_cost) {
1447                   VLOG(3) << "candidate " << candidate->name() << "("
1448                           << candidate->ToShortString() << ")"
1449                           << " now best when compressed into "
1450                           << compact_shape.ToString(true);
1451                   RematStrategy strategy;
1452                   strategy.kind = RematStrategy::kCompress;
1453                   best_strategy = strategy;
1454                   best_strategy.compact_shape = compact_shape;
1455                   best_items = block;
1456                   best_cost = cost;
1457                 }
1458               }
1459             }
1460           }
1461         }
1462       }
1463       // Do not consider recomputation in compress-only mode.
1464       if (mode_ == HloRematerialization::RematerializationMode::kCompressOnly) {
1465         // break out of this loop. Move on to the next start_item.
1466         break;
1467       }
1468       // If any of the candidate's control successors has been placed, we need
1469       // to skip this candidate; otherwise we would violate a control dependency.
1470       bool control_successor_placed = false;
1471       for (auto* item : block) {
1472         HloInstruction* candidate = item->instruction;
1473         if (std::any_of(candidate->control_successors().begin(),
1474                         candidate->control_successors().end(),
1475                         [this](const HloInstruction* inst) {
1476                           return IsPlaced(inst);
1477                         })) {
1478           control_successor_placed = true;
1479           break;
1480         }
1481       }
1482       if (control_successor_placed) {
1483         // break out of this loop. Move on to the next start_item.
1484         break;
1485       }
1486       VLOG(5) << "Block contains:";
1487       for (auto* hlo : block) {
1488         VLOG(5) << hlo->instruction->name();
1489       }
1490       const int64_t memory_reduced = MemoryReducedIfRematerialized(block);
1491       effort++;
1492       if (memory_reduced > 0) {
1493         const int cost =
1494             RematerializationCost(block, memory_reduced, memory_limit_bytes);
1495 
1496         VLOG(5) << "Candidate block of size " << block.size()
1497                 << " starting from " << block[0]->instruction->name()
1498                 << ", memory reduced " << memory_reduced << ", cost per byte "
1499                 << cost;
1500 
1501         if (best_items.empty() || cost < best_cost) {
1502           VLOG(5) << "Candidate block of size " << block.size()
1503                   << " starting from " << block[0]->instruction->name()
1504                   << " now best";
1505           best_strategy.kind = RematStrategy::kRecompute;
1506           best_items = block;
1507           best_cost = cost;
1508         }
1509       }
1510 
1511       // Time to update the block to include the next instruction.
1512       auto* last_item = block[block.size() - 1];
1513       auto* next_item = instruction_list.next(last_item);
1514       if (next_item == nullptr || next_item->denylisted || !next_item->placed ||
1515           next_item == in_progress_item_ ||
1516           !CanBeRematerialized(next_item->instruction, rematerializable_map)) {
1517         break;
1518       }
1519       block.push_back(next_item);
1520     }
1521   }
1522   return {best_items, best_strategy, effort};
1523 }
1524 
1525 bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
1526   for (BufferId buffer_id : item->buffers_defined) {
1527     const Buffer& buffer = buffers_.at(buffer_id);
1528     for (const ItemUse& user : buffer.users) {
1529       if (!user.user->placed) {
1530         return true;
1531       }
1532     }
1533   }
1534   return false;
1535 }
1536 
1537 const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
1538   UsesList combined_users;
1539   for (BufferId buffer_id : item->buffers_defined) {
1540     const Buffer& buffer = buffers_.at(buffer_id);
1541     for (const ItemUse& user : buffer.users) {
1542       combined_users.push_back(user);
1543     }
1544   }
1545   return combined_users;
1546 }
1547 
1548 StatusOr<int64_t> RematerializeInstructions(
1549     MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
1550     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1551     InstructionList* instruction_list,
1552     HloRematerialization* rematerialization) {
1553   int64_t net_instructions_added = 0;
1554   int64_t total_memory_saved =
1555       memory_tracker->MemoryReducedIfRematerialized(*best_items);
1556   std::vector<std::string> instruction_names(best_items->size());
1557   // Rematerialize the block of instructions in reverse order to account for
1558   // dependencies between instructions in best_items.
1559   for (int i = best_items->size() - 1; i >= 0; --i) {
1560     Item* best_item = (*best_items)[i];
1561     HloInstruction* best = best_item->instruction;
1562     instruction_names[i] = best->name();
1563     HloComputation* computation = best->parent();
1564 
1565     // If the item to remat has no unplaced users, then skip the
1566     // rematerialization. Such an instruction can appear in best_items because
1567     // it is part of a good block, but does not itself add any benefit.
1568     if (!memory_tracker->HasUnplacedUsers(best_item)) {
1569       continue;
1570     }
1571 
1572     HloInstruction* remat =
1573         computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
1574     // Increment channel_id on channel instructions with a channel id.
1575     if (DynCast<HloChannelInstruction>(best) &&
1576         DynCast<HloChannelInstruction>(best)->channel_id()) {
1577       remat->set_channel_id(rematerialization->NextChannelId());
1578     }
1579 
1580     // Add control dependencies to the new operation.
1581     for (auto successor : best->control_successors()) {
1582       TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
1583     }
1584     for (auto predecessor : best->control_predecessors()) {
1585       TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
1586     }
1587 
1588     Item* remat_item = instruction_list->CreateItem(remat);
1589 
1590     // Replace each remaining use of 'best' with the rematerialization.
1591     absl::InlinedVector<Item*, 4> indirect_users;
1592     absl::flat_hash_map<int64_t, HloInstruction*> gte_cache;
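    // 'indirect_users' collects the GetTupleElement/bitcast instructions
    // created below so they can be tracked and placed alongside 'remat';
    // 'gte_cache' de-duplicates GTEs created for the same tuple index.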
1593     for (auto& user : memory_tracker->GetItemUses(best_item)) {
1594       if (!memory_tracker->IsPlaced(user.user->instruction)) {
1595         VLOG(2) << "  Replacing use of " << best->name() << " in "
1596                 << user.user->instruction->name() << " with " << remat->name();
1597         HloInstruction* remat_use = remat;
1598         HloInstruction* const user_operand =
1599             user.user->instruction->mutable_operand(user.operand_number);
1600         if (remat_use == user_operand) {
1601           continue;
1602         }
1603         // If the output of a multi-output fusion node is forwarded to one of
1604         // its users as-is, all the element buffers are also treated as uses
1605         // by that user; those per-element uses need to be skipped here.
1606         if (user.index && remat_use->shape() != user_operand->shape()) {
1607           auto cached_gte = gte_cache.find(*user.index);
1608           if (cached_gte == gte_cache.end()) {
1609             remat_use = computation->AddInstruction(
1610                 HloInstruction::CreateGetTupleElement(
1611                     ShapeUtil::GetTupleElementShape(remat_use->shape(),
1612                                                     *user.index),
1613                     remat_use, *user.index));
1614             indirect_users.push_back(instruction_list->CreateItem(remat_use));
1615             gte_cache[*user.index] = remat_use;
1616           } else {
1617             remat_use = cached_gte->second;
1618           }
1619         }
1620         if (user_operand->shape() != remat_use->shape()) {
1621           remat_use = computation->AddInstruction(
1622               HloInstruction::CreateBitcast(user_operand->shape(), remat_use));
1623           indirect_users.push_back(instruction_list->CreateItem(remat_use));
1624         }
1625         TF_RETURN_IF_ERROR(user.user->instruction->ReplaceOperandWith(
1626             user.operand_number, remat_use));
1627       }
1628     }
1629 
1630     // Account for the rematerialization in the memory tracker.
1631     TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
1632         best_item, remat_item, absl::MakeSpan(indirect_users)));
1633 
1634     // Insert rematerialized instruction right before the earliest unplaced
1635     // use of the instruction *and* the earliest unplaced last use of any
1636     // operands of remat. Unplaced uses of the remat's operands are included
1637     // because we don't want to extend the live range of remat's operands as
1638     // this could increase memory usage.
1639     ItemList place_before;
1640     const absl::flat_hash_set<Item*> indirect_users_set(indirect_users.begin(),
1641                                                         indirect_users.end());
1642     for (auto user : remat->users()) {
1643       if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1644         place_before.push_back(instruction_list->GetItem(user));
1645       }
1646     }
1647     for (auto* indirect_user : indirect_users) {
1648       for (auto user : indirect_user->instruction->users()) {
1649         if (!indirect_users_set.contains(instruction_list->GetItem(user))) {
1650           place_before.push_back(instruction_list->GetItem(user));
1651         }
1652       }
1653     }
1654     for (auto* operand : remat->operands()) {
1655       for (auto* operand_user : operand->users()) {
1656         if (operand_user != remat) {
1657           Item* operand_user_item = instruction_list->GetItem(operand_user);
1658           if (!operand_user_item->placed) {
1659             place_before.push_back(operand_user_item);
1660           }
1661         }
1662       }
1663     }
1664     // Insert rematerialized instruction before any of its successors to
1665     // preserve ordering with respect to control dependencies.
1666     for (auto successor : remat->control_successors()) {
1667       Item* successor_item = instruction_list->GetItem(successor);
1668       // Assert to make sure we never remat an operation with control
1669       // successor already placed.
1670       CHECK(!successor_item->placed) << successor_item->instruction->name();
1671       place_before.push_back(successor_item);
1672     }
1673     instruction_list->InsertBeforeInstructions(remat_item, place_before);
1674 
1675     for (auto* bitcast : indirect_users) {
1676       instruction_list->InsertBeforeInstructions(bitcast, place_before);
1677     }
1678     // Helper function that looks through indirect users when determining if
1679     // there is an active user for an HloInstruction.
1680     std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
1681       for (auto* u : i->users()) {
1682         if (!IsSupportedIndirectUser(u) || !uses_empty(u)) {
1683           return false;
1684         }
1685       }
1686       return true;
1687     };
1688     // If the rematerialized instruction is dead then rematerialization is
1689     // essentially a move. Don't delete the instruction now, because we keep
1690     // maps with HloInstruction* values as keys and don't want duplicate
1691     // HloInstruction* values to appear during the course of the
1692     // transformation.
1693     if (uses_empty(best)) {
1694       VLOG(2) << best->name() << " is now dead";
1695       if (ContainsKey(*remat_move_instructions, best)) {
1696         // Previously, 'best' was a rematerialization which killed the
1697         // instruction it was a copy of. Now 'remat' is a rematerialization
1698         // of 'best' and kills 'best'. Stop rematerializing this instruction
1699         // to avoid an infinite loop.
1700         instruction_list->Denylist(remat);
1701       }
1702       remat_move_instructions->insert(remat);
1703       net_instructions_added += indirect_users.size();
1704     } else {
1705       net_instructions_added += indirect_users.size() + 1;
1706     }
1707     for (auto* indirect_user : indirect_users) {
1708       instruction_list->Denylist(indirect_user->instruction);
1709     }
1710     if (HloDataflowAnalysis::IsAsynchronousOperationStart(best->opcode()) ||
1711         HloDataflowAnalysis::IsAsynchronousOperationDone(best->opcode())) {
1712       VLOG(2) << "The old instruction " << best->name()
1713               << " is an async op. Removing it to maintain the one-start-to-one-done "
1714                  "invariant and keep the HLO valid.";
1715       TF_RETURN_IF_ERROR(computation->RemoveInstruction(best));
1716     }
1717   }
1718   VLOG(1) << "Rematerializing instructions ["
1719           << absl::StrJoin(instruction_names, ", ") << "] (saving "
1720           << HumanReadableNumBytes(total_memory_saved) << ")";
1721   return net_instructions_added;
1722 }
1723 
1724 StatusOr<int64_t> CompressInstruction(MemoryUsageTracker* memory_tracker,
1725                                       Item* best_item,
1726                                       const Shape& compact_shape,
1727                                       InstructionList* instruction_list) {
1728   HloInstruction* best = best_item->instruction;
1729   VLOG(5) << "Transposing instruction " << best->name() << " (saving "
1730           << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1731                  best_item, compact_shape))
1732           << ") to " << compact_shape.ToString(true);
1733 
1734   HloComputation* computation = best->parent();
1735   HloInstruction* compressed = computation->AddInstruction(
1736       HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best),
1737       /*new_name=*/best->name() + ".remat_compressed");
1738 
1739   HloInstruction* uncompressed = computation->AddInstruction(
1740       HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed),
1741       /*new_name=*/best->name() + ".remat_uncompressed");
1742 
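  // The compressed copy is scheduled immediately after 'best' (see
  // InsertAfterInstructions below), so it can be marked placed right away; the
  // uncompressed copy is placed later, before its earliest unplaced user.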
1743   Item* compressed_item = instruction_list->CreateItem(compressed);
1744   compressed_item->placed = true;
1745 
1746   Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
1747 
1748   // Replace each remaining use of 'best' with the uncompressed.
1749   std::vector<HloInstruction*> best_users_copy = best->users();
1750   for (HloInstruction* user : best_users_copy) {
1751     if (!memory_tracker->IsPlaced(user)) {
1752       VLOG(5) << "  Replacing use of " << best->name() << " in " << user->name()
1753               << " with " << uncompressed->name();
1754       TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
1755     }
1756   }
1757 
1758   // Account for the rematerialization in the memory tracker.
1759   TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
1760       best_item, compressed_item, uncompressed_item));
1761 
1762   // Insert rematerialized instruction right before the earliest unplaced
1763   // use of the instruction.
1764   ItemList place_before;
1765   for (auto user : uncompressed->users()) {
1766     place_before.push_back(instruction_list->GetItem(user));
1767   }
1768 
1769   instruction_list->Denylist(compressed_item->instruction);
1770   instruction_list->Denylist(uncompressed_item->instruction);
1771 
1772   instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
1773 
1774   instruction_list->InsertAfterInstructions(compressed_item, {best_item});
1775 
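  // Two instructions were added: the compressed and uncompressed copies.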
1776   return 2;
1777 }
1778 
1779 // A simple struct to encapsulate the number of instructions added during
1780 // rematerialization.
1781 struct InstructionsAdded {
1782   // Total count of instructions rematerialized.
1783   int remat_count;
1784   // Total count of instructions rematerialized minus number of original
1785   // instructions that are now dead.
1786   int net_instructions_added;
1787   // Amount of effort expended to find the instructions to rematerialize.
1788   int effort;
1789 };
1790 
1791 // Rematerializes the best block of instructions of size between min_block_size
1792 // and max_block_size (both inclusive) if at least one candidate block of
1793 // instructions can be found. Returns number of instructions rematerialized.
1794 StatusOr<InstructionsAdded> RematerializeBestBlock(
1795     int min_block_size, int max_block_size, MemoryUsageTracker* memory_tracker,
1796     InstructionList* instruction_list, int64_t memory_limit_bytes,
1797     absl::flat_hash_map<const HloInstruction*, bool>* rematerializable_map,
1798     absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
1799     HloRematerialization* rematerialization) {
1800   CHECK(min_block_size > 0) << "Non-positive block size.";
1801 
1802   std::vector<Item*> best_items;
1803   RematStrategy best_strategy;
1804   int effort;
1805   std::tie(best_items, best_strategy, effort) =
1806       memory_tracker->PickRematerializationCandidates(
1807           *instruction_list, memory_limit_bytes, rematerializable_map,
1808           min_block_size, max_block_size,
1809           rematerialization->ComputationPeakMemory(
1810               memory_tracker->computation()));
1811   InstructionsAdded num_instructions_added;
1812   num_instructions_added.remat_count = best_items.size();
1813   num_instructions_added.effort = effort;
1814   if (best_items.empty()) {
1815     num_instructions_added.net_instructions_added = 0;
1816     return num_instructions_added;
1817   }
1818 
1819   if (best_strategy.kind == RematStrategy::kCompress) {
1820     CHECK(best_items.size() == 1)
1821         << "More than one instruction compressed simultaneously.";
1822     HloInstruction* best = best_items[0]->instruction;
1823     VLOG(1) << "Compressing instruction " << best->name() << " (saving "
1824             << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
1825                    best_items[0], best_strategy.compact_shape))
1826             << ")";
1827 
1828     TF_ASSIGN_OR_RETURN(
1829         num_instructions_added.net_instructions_added,
1830         CompressInstruction(memory_tracker, best_items[0],
1831                             best_strategy.compact_shape, instruction_list));
1832   } else {
1833     TF_ASSIGN_OR_RETURN(
1834         num_instructions_added.net_instructions_added,
1835         RematerializeInstructions(memory_tracker, &best_items,
1836                                   remat_move_instructions, instruction_list,
1837                                   rematerialization));
1838   }
1839   return num_instructions_added;
1840 }
1841 }  // namespace
1842 
1843 StatusOr<int64_t> HloRematerialization::ComputePeakMemory(
1844     const HloComputation* computation, const HloInstructionSequence& order,
1845     const absl::flat_hash_set<absl::string_view>& execution_threads) const {
1846   InstructionList instruction_list(order);
1847   MemoryUsageTracker tracker(computation, size_function_,
1848                              compact_shape_function_, *points_to_analysis_,
1849                              instruction_list, mode_);
1850   int64_t peak_memory = tracker.memory_usage();
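  // Walk the schedule, tracking live memory at each program point and folding
  // in the peak usage of any computations called in a sequential context.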
1851   for (auto* item = instruction_list.first(); item != nullptr;
1852        item = instruction_list.next(item)) {
1853     const HloInstruction* instruction = item->instruction;
1854     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
1855     TF_ASSIGN_OR_RETURN(
1856         int64_t callee_usage,
1857         CalledComputationsMemoryUsage(instruction, execution_threads));
1858     peak_memory =
1859         std::max<int64_t>(peak_memory, tracker.memory_usage() + callee_usage);
1860     TF_RETURN_IF_ERROR(tracker.EndInstruction());
1861   }
1862   VLOG(1) << "Peak memory for " << computation->name() << ": "
1863           << HumanReadableNumBytes(peak_memory);
1864   return peak_memory;
1865 }
1866 
1867 StatusOr<int64_t> HloRematerialization::CalledComputationsMemoryUsage(
1868     const HloInstruction* instruction,
1869     const absl::flat_hash_set<absl::string_view>& execution_threads) const {
1870   const CallSite* callsite =
1871       call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
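  // Calls in an embedded context (e.g. the body of a fusion or reduce) are not
  // accounted for separately here; only sequentially called computations
  // contribute their tracked peak memory.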
1872   if (callsite == nullptr || callsite->context() == CallContext::kEmbedded) {
1873     return 0;
1874   }
1875   int64_t callee_usage = 0;
1876   for (const HloComputation* computation : callsite->called_computations()) {
1877     if (!IsExecutionThreadIncluded(execution_threads,
1878                                    computation->execution_thread())) {
1879       continue;
1880     }
1881     TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
1882     callee_usage += computation_peak_memory_.at(computation);
1883   }
1884   return callee_usage;
1885 }
1886 
1887 bool HloRematerialization::IsExecutionThreadIncluded(
1888     const absl::flat_hash_set<absl::string_view>& execution_threads,
1889     absl::string_view thread) const {
1890   return execution_threads.empty() || execution_threads.contains(thread);
1891 }
1892 
1893 StatusOr<bool> HloRematerialization::RematerializeComputation(
1894     HloComputation* computation, HloSchedule* schedule,
1895     int64_t memory_limit_bytes, int64_t min_remat_size,
1896     const absl::flat_hash_set<absl::string_view>& execution_threads) {
1897   VLOG(1) << "Rematerializing computation " << computation->name()
1898           << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
1899   VLOG(1) << "peak memory usage is "
1900           << HumanReadableNumBytes(computation_peak_memory_.at(computation));
1901   CHECK(!ContainsKey(rematerialized_computations_, computation));
1902 
1903   InstructionList instruction_list(schedule->sequence(computation));
1904   MemoryUsageTracker memory_tracker(
1905       computation, size_function_, compact_shape_function_,
1906       *points_to_analysis_, instruction_list, mode_);
1907 
1908   instruction_list.PromoteNodesToSkip([&](Item* item) {
1909     return memory_tracker.AllocatedSize(item) >= min_remat_size;
1910   });
1911   bool changed = false;
1912 
1913   // If the rematerialization makes the source instruction dead, then the
1914   // rematerialization is added to 'remat_move_instructions' (the
1915   // rematerialization is essentially a move). If the next rematerialization of
1916   // the instruction is also a move then the rematerialization is added to the
1917   // denylist.
1918   absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
1919 
1920   // The map from instructions to their rematerializable status.
1921   absl::flat_hash_map<const HloInstruction*, bool> rematerializable_map;
1922 
1923   // The peak memory of the computation at any point in the instruction
1924   // sequence.
1925   int64_t peak_memory = memory_tracker.memory_usage();
1926 
1927   // Total count of instructions rematerialized.
1928   int64_t remat_count = 0;
1929   // Total count of clones created minus number of original rematerialized
1930   // instructions which are dead.
1931   int64_t net_instructions_added = 0;
1932 
1933   const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
1934 
1935   // Iterate through all instructions in the sequence. At each instruction
1936   // (program point) if memory_usage exceeds the specified limit then
1937   // rematerialize HLO instructions until memory_usage is reduced.
1938   int64_t instruction_index = 0;
1939   for (auto* item = instruction_list.first(); item != nullptr;
1940        item = instruction_list.next(item)) {
1941     const HloInstruction* instruction = item->instruction;
1942     TF_ASSIGN_OR_RETURN(
1943         int64_t callee_usage,
1944         CalledComputationsMemoryUsage(instruction, execution_threads));
1945     TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));
1946 
1947     VLOG(2) << "Program point at " << instruction->name()
1948             << ", memory usage = " << memory_tracker.memory_usage()
1949             << ", callee usage = " << callee_usage << ", [" << instruction_index
1950             << "/" << instruction_list.size() << "]";
1951     instruction_index++;
1952 
1953     // Initialize both min_block_size and max_block_size to 1 so that only
1954     // single instruction rematerialization is considered first.
1955     int min_block_size = 1;
1956     int max_block_size = 1;
1957     // Only trigger rematerialization when the memory usage changes.
1958     if (memory_tracker.AllocatedSize(item) + callee_usage > 0) {
1959       // Finding larger blocks of instructions to rematerialize can be
1960       // time-consuming. To limit the time spent attempting to find such
1961       // large blocks, count the amount of effort expended to find single
1962       // instructions to rematerialize and then limit the total amount of effort
1963       // to at most a factor of block_rematerialization_factor_ more.
1964       bool is_first_phase = true;
1965       int64_t first_phase_effort = 0;
1966       int64_t second_phase_effort = 0;
1967       while (memory_tracker.memory_usage() + callee_usage >
1968              memory_limit_bytes) {
1969         VLOG(2) << "Over memory limit at instruction " << instruction->name()
1970                 << ", using "
1971                 << HumanReadableNumBytes(memory_tracker.memory_usage() +
1972                                          callee_usage)
1973                 << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
1974 
1975         TF_ASSIGN_OR_RETURN(
1976             InstructionsAdded instructions_added,
1977             RematerializeBestBlock(min_block_size, max_block_size,
1978                                    &memory_tracker, &instruction_list,
1979                                    memory_limit_bytes, &rematerializable_map,
1980                                    &remat_move_instructions, this));
1981         net_instructions_added += instructions_added.net_instructions_added;
1982         remat_count += instructions_added.remat_count;
1983         if (is_first_phase) {
1984           first_phase_effort += instructions_added.effort;
1985         } else {
1986           second_phase_effort += instructions_added.effort;
1987         }
1988         VLOG(1) << "memory_usage after rematerialization = "
1989                 << HumanReadableNumBytes(memory_tracker.memory_usage());
1990         if (instructions_added.remat_count == 0) {
1991           // Unable to find a block to rematerialize.
1992           // Consider doubling the block size.
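          // e.g. the (min, max) window grows 1,1 -> 2,2 -> 3,4 -> 5,8 -> 9,16, ...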
1993           min_block_size = max_block_size + 1;
1994           max_block_size = 2 * max_block_size;
1995           is_first_phase = false;
1996         } else {
1997           // Found a valid block. Reset to start looking for single instructions
1998           // again.
1999           max_rematerialized_block_size_ =
2000               std::max(max_rematerialized_block_size_, max_block_size);
2001           changed = true;
2002           min_block_size = 1;
2003           max_block_size = 1;
2004         }
2005         if (max_block_size > block_size_limit_ ||
2006             second_phase_effort >
2007                 block_rematerialization_factor_ * first_phase_effort) {
2008           break;
2009         }
2010       }
2011     }
2012     const CallSite* callsite = call_graph_node.GetCallSite(instruction);
2013     if (callsite != nullptr &&
2014         callsite->context() == CallContext::kControlFlow &&
2015         memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
2016       // Memory usage exceeds the limit. Try to rematerialize any
2017       // subcomputation(s) that this instruction calls.
2018       VLOG(1) << "Memory usage still over the limit ("
2019               << (memory_tracker.memory_usage() + callee_usage) << " > "
2020               << memory_limit_bytes
2021               << "). Rematerializing computations called by "
2022               << instruction->name();
2023 
2024       // Recompute callee usage to account for any rematerialization performed
2025       // in the callee computations.
2026       for (HloComputation* called_computation :
2027            callsite->called_computations()) {
2028         if (!ContainsKey(rematerialized_computations_, called_computation)) {
2029           // Memory limit for the subcomputation is the memory limit less the
2030           // amount of memory used at this point in the computation.
2031           int64_t subcomputation_memory_limit_bytes = std::max<int64_t>(
2032               0, memory_limit_bytes - memory_tracker.memory_usage());
2033           TF_ASSIGN_OR_RETURN(
2034               bool subcomputation_changed,
2035               RematerializeComputation(called_computation, schedule,
2036                                        subcomputation_memory_limit_bytes,
2037                                        min_remat_size, execution_threads));
2038           changed |= subcomputation_changed;
2039         }
2040       }
2041 
2042       TF_ASSIGN_OR_RETURN(callee_usage, CalledComputationsMemoryUsage(
2043                                             instruction, execution_threads));
2044     }
2045 
2046     peak_memory = std::max<int64_t>(
2047         peak_memory, memory_tracker.memory_usage() + callee_usage);
2048     VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);
2049 
2050     TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
2051   }
2052 
2053   // Verify some invariants on the memory tracker.
2054   for (auto* instruction : computation->instructions()) {
2055     CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
2056   }
2057 
2058   VLOG(1) << "In computation " << computation->name() << " rematerialized "
2059           << remat_count << " instructions; " << net_instructions_added
2060           << " net instructions added";
2061   VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
2062           << " (was "
2063           << HumanReadableNumBytes(computation_peak_memory_.at(computation))
2064           << ")";
2065 
2066   // Update peak memory used by computation.
2067   computation_peak_memory_.at(computation) = peak_memory;
2068 
2069   // Update order to include rematerialized instructions.
2070   HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
2071   sequence.clear();
2072   for (auto* item = instruction_list.first(); item != nullptr;
2073        item = instruction_list.next(item)) {
2074     HloInstruction* instruction = item->instruction;
2075     sequence.push_back(instruction);
2076   }
2077   rematerialized_computations_.insert(computation);
2078 
2079   instructions_rematerialized_ += remat_count;
2080   net_instructions_added_ += net_instructions_added;
2081 
2082   return changed;
2083 }
2084 
2085 StatusOr<bool> HloRematerialization::Run(
2086     HloModule* module,
2087     const absl::flat_hash_set<absl::string_view>& execution_threads) {
2088   VLOG(1) << "HloRematerialization() with memory limit of "
2089           << HumanReadableNumBytes(memory_limit_bytes_);
2090   XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
2091 
2092   // Initialize pass object state.
2093   computation_peak_memory_.clear();
2094   rematerialized_computations_.clear();
2095   instructions_rematerialized_ = 0;
2096   net_instructions_added_ = 0;
2097 
2098   TF_RET_CHECK(module->has_schedule());
2099   TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
2100   next_channel_id_ = hlo_query::NextChannelId(*module);
2101 
2102   // Adjust memory limit to account for the output of the entry
2103   // computation. This is necessary because the per-computation accounting in
2104   // MemoryUsageTracker does not include the output, as it is typically
2105   // allocated by the caller.
2106   int64_t module_output_size = 0;
2107   ShapeUtil::ForEachSubshape(
2108       module->result_shape(),
2109       [&module_output_size, module, this](const Shape& subshape,
2110                                           const ShapeIndex& output_index) {
2111         module_output_size += size_function_(subshape);
2112       });
2113 
2114   const int64_t adjusted_memory_limit_bytes =
2115       memory_limit_bytes_ - module_output_size;
2116   VLOG(1) << "Adjusted memory limit accounting for output ("
2117           << HumanReadableNumBytes(module_output_size)
2118           << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);
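  // Illustrative example (made-up numbers): with a 16 GiB limit and 512 MiB of
  // entry-computation outputs, computations are rematerialized against a
  // 15.5 GiB budget.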
2119 
2120   // Compute peak memory usage of all computations in the module called in a
2121   // sequential context.
2122   call_graph_ = CallGraph::Build(module);
2123   TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
2124       [this, module, &execution_threads](const CallGraphNode& node) -> Status {
2125         if (node.context() == CallContext::kControlFlow &&
2126             IsExecutionThreadIncluded(execution_threads,
2127                                       node.computation()->execution_thread())) {
2128           TF_ASSIGN_OR_RETURN(
2129               computation_peak_memory_[node.computation()],
2130               ComputePeakMemory(node.computation(),
2131                                 module->schedule().sequence(node.computation()),
2132                                 execution_threads));
2133         }
2134         return OkStatus();
2135       },
2136       /*visit_unreachable_nodes=*/false));
2137 
2138   // The peak memory usage of the module equals the peak memory use of the entry
2139   // computation plus the output size of the computation. This is because the
2140   // peak memory for a computation does not include the output as this is
2141   // typically accounted for in the caller.
2142   const int64_t before_peak_memory =
2143       computation_peak_memory_.at(module->entry_computation()) +
2144       module_output_size;
2145   VLOG(1) << "Peak memory usage of module (before): "
2146           << HumanReadableNumBytes(before_peak_memory);
2147   // Subcomputations called by the entry computation will also be
2148   // rematerialized.
2149   TF_ASSIGN_OR_RETURN(
2150       bool changed,
2151       RematerializeComputation(module->entry_computation(), &module->schedule(),
2152                                adjusted_memory_limit_bytes, min_remat_size_,
2153                                execution_threads));
2154   // Rematerialization can introduce dead code. This occurs if all uses of an
2155   // instruction are replaced with rematerializations of the instruction.
2156 
2157   // Stash away the schedule while running DCE, to avoid validation failures
2158   // while the module is in flux.
2159   HloSchedule saved_schedule = module->schedule();
2160   module->clear_schedule();
2161   TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
2162   changed |= dead_code_removed;
2163 
2164   // After DCE, the module sequence may include instructions which no longer
2165   // exist. Update the schedule and restore it.
2166   TF_RETURN_IF_ERROR(saved_schedule.Update(execution_threads));
2167   TF_RETURN_IF_ERROR(module->set_schedule(std::move(saved_schedule)));
2168   VLOG(1) << "Rematerialized " << instructions_rematerialized_
2169           << " instructions in module " << module->name() << "; "
2170           << net_instructions_added_ << " net instructions added";
2171   const int64_t current_peak_memory =
2172       computation_peak_memory_.at(module->entry_computation()) +
2173       module_output_size;
2174   VLOG(1) << "Peak memory usage of module now "
2175           << HumanReadableNumBytes(current_peak_memory) << " ("
2176           << current_peak_memory << " bytes), was "
2177           << HumanReadableNumBytes(before_peak_memory) << " ("
2178           << before_peak_memory << " bytes)";
2179   const int64_t reduced_peak_memory = before_peak_memory - current_peak_memory;
2180   VLOG(1) << "Reduced peak memory by "
2181           << HumanReadableNumBytes(reduced_peak_memory) << " ("
2182           << reduced_peak_memory << " bytes)";
2183 
2184   if (sizes_ != nullptr) {
2185     sizes_->before_bytes = before_peak_memory;
2186     sizes_->after_bytes = current_peak_memory;
2187   }
2188 
2189   XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
2190 
2191   if (current_peak_memory > memory_limit_bytes_) {
2192     LOG(WARNING) << absl::StrFormat(
2193         "Can't reduce memory use below %s (%d bytes) by rematerialization; "
2194         "only reduced to %s (%d bytes)",
2195         HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
2196         HumanReadableNumBytes(current_peak_memory), current_peak_memory);
2197   }
2198   return changed;
2199 }
2200 
2201 }  // namespace xla
2202