xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_
18 
#include <cstdint>

#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
22 
23 namespace tensorflow {
24 
// This is a temporary control flag to gradually enable compilation for
// operations based on the correctness and performance confidence. For example,
// Tier 1 operations are simple enough and well tested, so they can be safely
// enabled for all models. We'll be introducing new tiers based on the
// completeness of lowering and testing, and eventually will remove this flag.
//
// The underlying type is pinned to a single byte (`std::uint8_t`), and every
// enumerator is given an explicit value.
enum class CpurtClusteringTier : std::uint8_t {
  // All cwise operations (unary, binary, ternary) plus a tf.Transpose.
  kTier1 = 0,

  // TODO(ezhulenev): Include metadata (shape, reshape) and slicing into tier 2?
  // TODO(ezhulenev): Include reductions into tier 3?

  // All operations that have a clustering policy.
  kAll = 1
};
40 
41 // Adds policies for clustering operations for TF->CPURT JIT compilation.
42 void populateTfCpurtClusteringPolicies(
43     mlir::TFDevice::ClusteringPolicySet& policies,
44     CpurtClusteringTier tier = CpurtClusteringTier::kAll);
45 
46 // Adds policies for propagating constraints through Tensorflow operations. We
47 // do not add `tf.Const` operations to the clusters, however before compilation
48 // we sink some of them into the cluster body, and to properly verify compiled
49 // function body and infer operands constraints we need a policy for constants.
50 void populateTfCpurtConstraintsPolicies(
51     mlir::TFDevice::ClusteringPolicySet& policies,
52     CpurtClusteringTier tier = CpurtClusteringTier::kAll);
53 
54 // Returns success if constant value can be sunk into the compiled function. We
55 // currently only support small integer constants that typically correspond to
56 // the reduction dimension, transpose permutation and other similar values that
57 // are required for successful compilation.
58 //
59 // We prefer to keep large constants as `tf.Const` operations outside of the
60 // compiled regions, and rely on the runtime to instantiate them as tensors.
61 mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value);
62 
63 // Verifies that discovered operations cluster satisfies TF->CPURT JIT
64 // compilation constraints.
65 mlir::LogicalResult VerifyCluster(const mlir::TFDevice::Cluster& cluster);
66 
67 }  // namespace tensorflow
68 
69 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_
70