xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/custom_call_sharding_helper.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_SHARDING_HELPER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_SHARDING_HELPER_H_

#include <optional>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
21 
22 namespace xla {
23 
24 // Helper class that helps implement sharding propagation policies for
25 // CustomCalls. It is called and used by the ShardingPropagation pass. Meant to
26 // be overridden by targets.
27 class CustomCallShardingHelper {
28  public:
29   // Function that manipulates an instruction sharding based on a user wanting
30   // to update the sharding of an instruction.
31   virtual HloSharding PropagateUserSharding(const HloInstruction* instruction,
32                                             const HloInstruction* user,
33                                             const HloSharding& sharding) const;
34   // Infer sharding from the operands of an instruction.
35   virtual std::optional<HloSharding> InferShardingFromOperands(
36       const HloInstruction* instruction) const;
37   // Returns if the instruction passed as parameter is a supported custom-call
38   // for which the functions of this class are implemented.
39   virtual bool IsCustomCallShardable(const HloInstruction* instruction) const;
40   virtual ~CustomCallShardingHelper() = default;
41 };
42 
43 }  // namespace xla
44 
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CUSTOM_CALL_SHARDING_HELPER_H_
46