xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/change_op_data_type.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_
18 
#include <functional>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
23 
24 namespace xla {
25 
26 // Changes `from_ty op(from_ty a, from_ty b)` into
27 // `from_ty convert(op(to_ty convert(a), to_ty convert(b)))`.
28 //
// One place where this pass is useful is for fp16 dots/convs in XLA:CPU.
// Although XLA:CPU supports fp16 dots/convs, they are significantly slower than
// fp32 dots/convs. This pass lets us run an fp16 dot/conv as "convert to fp32,
// run in fp32, then convert back to fp16". (This is not mathematically
// identical to computing in fp16, but it is close enough for our purposes.)
34 //
35 // This pass only considers ops that match `op_matcher` and where all operands
36 // have type `from_ty`.  It will not do the correct thing for ops like
37 // dynamic-slice where only some of the arguments should be converted; it's up
38 // to you to avoid matching such ops with `op_matcher`.
39 class ChangeOpDataType : public HloModulePass {
40  public:
ChangeOpDataType(PrimitiveType from_ty,PrimitiveType to_ty,std::function<bool (const HloInstruction *)> op_matcher)41   ChangeOpDataType(PrimitiveType from_ty, PrimitiveType to_ty,
42                    std::function<bool(const HloInstruction*)> op_matcher)
43       : from_ty_(from_ty), to_ty_(to_ty), op_matcher_(op_matcher) {}
44 
name()45   absl::string_view name() const override { return "change-op-data-type"; }
46   StatusOr<bool> Run(
47       HloModule* module,
48       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
49 
50  private:
51   PrimitiveType from_ty_;
52   PrimitiveType to_ty_;
53   std::function<bool(const HloInstruction*)> op_matcher_;
54 };
55 
56 }  // namespace xla
57 
58 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CHANGE_OP_DATA_TYPE_H_
59