1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_ 18 19 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 20 #include "mlir/Support/LogicalResult.h" // from @llvm-project 21 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" 22 23 namespace tensorflow { 24 25 // This is a temporary control flag to gradually enable compilation for 26 // operations based on the correctness and performance confidence. For example 27 // Tier 1 operations are simple enough and well tested, so they can be safely 28 // enabled for all models. We'll be introducing new tiers based on the 29 // completeness of lowering and testing, and eventually will remove this flag. 30 enum class CpurtClusteringTier : uint8_t { 31 // All cwise operations (unary, binary, ternary) plus a tf.Transpose. 32 kTier1 = 0, 33 34 // TODO(ezhulenev): Include metadata (shape, reshape) and slicing into tier 2? 35 // TODO(ezhulenev): Include reductions into tier 3? 36 37 // All operations that do have clustering policy. 38 kAll = 1 39 }; 40 41 // Adds policies for clustering operations for TF->CPURT JIT compilation. 42 void populateTfCpurtClusteringPolicies( 43 mlir::TFDevice::ClusteringPolicySet& policies, 44 CpurtClusteringTier tier = CpurtClusteringTier::kAll); 45 46 // Adds policies for propagating constraints through Tensorflow operations. We 47 // do not add `tf.Const` operations to the clusters, however before compilation 48 // we sink some of them into the cluster body, and to properly verify compiled 49 // function body and infer operands constraints we need a policy for constants. 50 void populateTfCpurtConstraintsPolicies( 51 mlir::TFDevice::ClusteringPolicySet& policies, 52 CpurtClusteringTier tier = CpurtClusteringTier::kAll); 53 54 // Returns success if constant value can be sunk into the compiled function. We 55 // currently only support small integer constants that typically correspond to 56 // the reduction dimension, transpose permutation and other similar values that 57 // are required for successful compilation. 58 // 59 // We prefer to keep large constants as `tf.Const` operations outside of the 60 // compiled regions, and rely on the runtime to instantiate them as tensors. 61 mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value); 62 63 // Verifies that discovered operations cluster satisfies TF->CPURT JIT 64 // compilation constraints. 65 mlir::LogicalResult VerifyCluster(const mlir::TFDevice::Cluster& cluster); 66 67 } // namespace tensorflow 68 69 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_CPURT_CLUSTERING_H_ 70