// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"

#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"

namespace xla {

IsValueAllowedInAlternateMemory(const HloValue * value)23 bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory(
24     const HloValue* value) {
25   // If the buffer is a tuple, don't use this algorithm for now. The buffers
26   // that are pointed to by the tuple will still use this algorithm.  Because
27   // tuples are cheap to place in the alternate memory (they are just pointers)
28   // we don't need to use prefetch/evict logic.
29   if (value->shape().IsTuple()) {
30     VLOG(4) << "Keeping value " << value->ToShortString()
31             << " in default mem because it is a tuple.";
32     return false;
33   }
34 
35   // Don't place scalars in the alternate memory.
36   if (ShapeUtil::IsEffectiveScalar(value->shape())) {
37     VLOG(4) << "Keeping value " << value->ToShortString()
38             << " in default mem because it is a scalar.";
39     return false;
40   }
41 
42   // TODO(berkin): Not allocating add-dependencies either since they need to be
43   // treated specially. We should revisit this later.
44   for (const HloPosition& position : value->positions()) {
45     if (position.instruction->opcode() == HloOpcode::kAddDependency) {
46       VLOG(4) << "Keeping value " << value->ToShortString()
47               << " in default mem because it has a "
48               << "add-dependency position.";
49       return false;
50     }
51   }
52 
53   // Send and Recv HLOs return a request identifier. These should not be
54   // allocated in the alternate memory.
55   for (const HloPosition& position : value->positions()) {
56     if ((position.instruction->opcode() == HloOpcode::kSend ||
57          position.instruction->opcode() == HloOpcode::kRecv) &&
58         DynCast<HloSendRecvInstruction>(position.instruction)
59             ->is_host_transfer()) {
60       // TODO(berkin): Host transfers using alternate memory space doesn't seem
61       // to work at the moment.
62       VLOG(4) << "Keeping value " << value->ToShortString()
63               << " in default mem because it is a send/recv buffer used for "
64                  "host transfer.";
65       return false;
66     }
67 
68     if (auto* custom_call =
69             DynCast<HloCustomCallInstruction>(position.instruction)) {
70       for (const auto& pair : custom_call->output_to_operand_aliasing()) {
71         if (position.index == pair.first) {
72           VLOG(4) << "Keeping value " << value->ToShortString()
73                   << " in default mem because it is a custom-call output that "
74                      "aliases an operand buffer.";
75           return false;
76         }
77       }
78     }
79   }
80 
81   return true;
82 }
IsIntervalAllowedInAlternateMemory(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)84 bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
85     const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
86   return IsValueAllowedInAlternateMemory(interval.buffer) &&
87          absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory);
88 }
HoistConstantOperations(HloModule & module)90 /*static*/ void MemorySpaceAssignmentUtils::HoistConstantOperations(
91     HloModule& module) {
92   CHECK(module.has_schedule());
93   HloSchedule& schedule = module.schedule();
94   for (const HloComputation* computation : module.MakeNonfusionComputations()) {
95     CHECK(schedule.is_computation_scheduled(computation));
96     const HloInstructionSequence& sequence = schedule.sequence(computation);
97     HloInstructionSequence new_sequence;
98 
99     for (HloInstruction* instruction : sequence.instructions()) {
100       if (instruction->opcode() == HloOpcode::kConstant) {
101         new_sequence.push_back(instruction);
102       }
103     }
104     for (HloInstruction* instruction : sequence.instructions()) {
105       if (instruction->opcode() != HloOpcode::kConstant) {
106         new_sequence.push_back(instruction);
107       }
108     }
109     CHECK_EQ(new_sequence.size(), sequence.size());
110     schedule.set_sequence(computation, new_sequence);
111   }
112 }

}  // namespace xla