1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ 18 19 #include "tensorflow/compiler/xla/service/bfloat16_support.h" 20 #include "tensorflow/compiler/xla/service/hlo_module.h" 21 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 22 23 namespace xla { 24 25 // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not 26 // support BF16 input/output or mixed precision, according to the passed-in 27 // backend-specific BF16 support rules. 28 class BFloat16Normalization : public HloModulePass { 29 public: BFloat16Normalization(const BFloat16Support * bfloat16_support)30 explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) 31 : bfloat16_support_(bfloat16_support) {} 32 33 ~BFloat16Normalization() override = default; name()34 absl::string_view name() const override { return "bf16-normalization"; } 35 36 // Run BF16 normalization on the given computation. Returns whether the 37 // computation was changed. 38 using HloPassInterface::Run; 39 StatusOr<bool> Run( 40 HloModule* module, 41 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 42 43 private: 44 const BFloat16Support* bfloat16_support_; 45 }; 46 47 // A pass that unconditionally removes the mixed F32/BF16 uses in HLO 48 // instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike 49 // BFloat16Normalization, this pass does not use a backend-specific 50 // BFloat16Support, and does not change HLOs that have BF16 data if they do not 51 // use mixed precision; it removes mixed precision even if the backend supports 52 // it. This pass is used to make the HLO module valid for other HLO passes which 53 // do not support mixed precision. Currently, this pass is only used by the 54 // Despecializer, not by our normal compilation flow on TPU. 55 class BFloat16MixedPrecisionRemoval : public HloModulePass { 56 public: BFloat16MixedPrecisionRemoval()57 BFloat16MixedPrecisionRemoval() {} 58 59 ~BFloat16MixedPrecisionRemoval() override = default; 60 name()61 absl::string_view name() const override { 62 return "bf16-mixed-precision-removal"; 63 } 64 65 // Run mixed precision removal on the given computation. Returns whether the 66 // computation was changed. 67 using HloPassInterface::Run; Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)68 StatusOr<bool> Run(HloModule* module, 69 const absl::flat_hash_set<absl::string_view>& 70 execution_threads) override { 71 BFloat16Normalization normalization(&no_mixed_precision_support_); 72 return normalization.Run(module, execution_threads); 73 } 74 75 private: 76 class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support { 77 public: BFloat16SupportForMixedPrecisionRemoval()78 BFloat16SupportForMixedPrecisionRemoval() {} 79 80 ~BFloat16SupportForMixedPrecisionRemoval() override = default; 81 SupportsBF16Operand(const HloInstruction & hlo,int64_t operand_index)82 bool SupportsBF16Operand(const HloInstruction& hlo, 83 int64_t operand_index) const override { 84 return true; 85 } 86 SupportsBF16Output(const HloInstruction & hlo)87 bool SupportsBF16Output(const HloInstruction& hlo) const override { 88 return true; 89 } 90 SupportsMixedPrecisions(const HloInstruction & hlo)91 bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { 92 return false; 93 } 94 } no_mixed_precision_support_; 95 }; 96 97 } // namespace xla 98 99 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ 100