/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_

#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// Changes `from_ty op(from_ty a, from_ty b)` into
// `from_ty convert(op(to_ty convert(a), to_ty convert(b)))`.
//
// One place where this pass is useful is for fp16 dots/convs in XLA:CPU.
// Although XLA:CPU supports fp16 dots/convs, they are significantly slower than
// fp32 convs.  This pass lets us run the fp16 dot/conv as "convert to fp32,
// run in fp32, then convert back to fp16".  (This is of course not
// mathematically the same, but it's close enough for our purposes.)
//
// This pass only considers ops that match `op_matcher` and where all operands
// have type `from_ty`.  It will not do the correct thing for ops like
// dynamic-slice where only some of the arguments should be converted; it's up
// to you to avoid matching such ops with `op_matcher`.
39 class ChangeOpDataType : public HloModulePass { 40 public: ChangeOpDataType(PrimitiveType from_ty,PrimitiveType to_ty,std::function<bool (const HloInstruction *)> op_matcher)41 ChangeOpDataType(PrimitiveType from_ty, PrimitiveType to_ty, 42 std::function<bool(const HloInstruction*)> op_matcher) 43 : from_ty_(from_ty), to_ty_(to_ty), op_matcher_(op_matcher) {} 44 name()45 absl::string_view name() const override { return "change-op-data-type"; } 46 StatusOr<bool> Run( 47 HloModule* module, 48 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 49 50 private: 51 PrimitiveType from_ty_; 52 PrimitiveType to_ty_; 53 std::function<bool(const HloInstruction*)> op_matcher_; 54 }; 55 56 } // namespace xla 57 58 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_ 59