1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h" 17 18 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" 19 #include "tensorflow/compiler/xla/service/hlo_instructions.h" 20 21 namespace xla { 22 IsValueAllowedInAlternateMemory(const HloValue * value)23bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( 24 const HloValue* value) { 25 // If the buffer is a tuple, don't use this algorithm for now. The buffers 26 // that are pointed to by the tuple will still use this algorithm. Because 27 // tuples are cheap to place in the alternate memory (they are just pointers) 28 // we don't need to use prefetch/evict logic. 29 if (value->shape().IsTuple()) { 30 VLOG(4) << "Keeping value " << value->ToShortString() 31 << " in default mem because it is a tuple."; 32 return false; 33 } 34 35 // Don't place scalars in the alternate memory. 36 if (ShapeUtil::IsEffectiveScalar(value->shape())) { 37 VLOG(4) << "Keeping value " << value->ToShortString() 38 << " in default mem because it is a scalar."; 39 return false; 40 } 41 42 // TODO(berkin): Not allocating add-dependencies either since they need to be 43 // treated specially. We should revisit this later. 44 for (const HloPosition& position : value->positions()) { 45 if (position.instruction->opcode() == HloOpcode::kAddDependency) { 46 VLOG(4) << "Keeping value " << value->ToShortString() 47 << " in default mem because it has a " 48 << "add-dependency position."; 49 return false; 50 } 51 } 52 53 // Send and Recv HLOs return a request identifier. These should not be 54 // allocated in the alternate memory. 55 for (const HloPosition& position : value->positions()) { 56 if ((position.instruction->opcode() == HloOpcode::kSend || 57 position.instruction->opcode() == HloOpcode::kRecv) && 58 DynCast<HloSendRecvInstruction>(position.instruction) 59 ->is_host_transfer()) { 60 // TODO(berkin): Host transfers using alternate memory space doesn't seem 61 // to work at the moment. 62 VLOG(4) << "Keeping value " << value->ToShortString() 63 << " in default mem because it is a send/recv buffer used for " 64 "host transfer."; 65 return false; 66 } 67 68 if (auto* custom_call = 69 DynCast<HloCustomCallInstruction>(position.instruction)) { 70 for (const auto& pair : custom_call->output_to_operand_aliasing()) { 71 if (position.index == pair.first) { 72 VLOG(4) << "Keeping value " << value->ToShortString() 73 << " in default mem because it is a custom-call output that " 74 "aliases an operand buffer."; 75 return false; 76 } 77 } 78 } 79 } 80 81 return true; 82 } 83 IsIntervalAllowedInAlternateMemory(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval)84bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( 85 const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) { 86 return IsValueAllowedInAlternateMemory(interval.buffer) && 87 absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); 88 } 89 HoistConstantOperations(HloModule & module)90/*static*/ void MemorySpaceAssignmentUtils::HoistConstantOperations( 91 HloModule& module) { 92 CHECK(module.has_schedule()); 93 HloSchedule& schedule = module.schedule(); 94 for (const HloComputation* computation : module.MakeNonfusionComputations()) { 95 CHECK(schedule.is_computation_scheduled(computation)); 96 const HloInstructionSequence& sequence = schedule.sequence(computation); 97 HloInstructionSequence new_sequence; 98 99 for (HloInstruction* instruction : sequence.instructions()) { 100 if (instruction->opcode() == HloOpcode::kConstant) { 101 new_sequence.push_back(instruction); 102 } 103 } 104 for (HloInstruction* instruction : sequence.instructions()) { 105 if (instruction->opcode() != HloOpcode::kConstant) { 106 new_sequence.push_back(instruction); 107 } 108 } 109 CHECK_EQ(new_sequence.size(), sequence.size()); 110 schedule.set_sequence(computation, new_sequence); 111 } 112 } 113 114 } // namespace xla 115