1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h"
17
18 #include <algorithm>
19
20 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
21 #include "mlir/IR/OpDefinition.h" // from @llvm-project
22 #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
23 #include "mlir/Support/LogicalResult.h" // from @llvm-project
24 #include "tensorflow/c/tf_status.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28 #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"
29 #include "tensorflow/core/platform/mutex.h"
30
31 namespace mlir {
32 namespace TF {
33
34 // Implements a TF specific policy on when constant folding is allowed.
35 // Policy:
36 //
37 // Disable constant folding if operands size is greater than a certain
38 // threshold (`kOperandsSizeThreshold`).
39 //
40 // Otherwise, allow folding if we do not know the shape of an operand or
41 // result i.e., one of these values has non-static shape. If we know all the
42 // shapes, find the total size of the operands and results. Folding of the op is
43 // allowed if one of the following conditions are met:
44 // 1. size of results is less than a certain threshold
45 // (`kResultsSizeThreshold`), or
46 // 2. size of results is within a factor (`kSizeFactor`) of size of operands, or
47 // TODO(b/157226221): Look into other heuristics for constant fold policy.
ShouldBeFolded(Operation * inst)48 static bool ShouldBeFolded(Operation* inst) {
49 bool has_unknown_shape = false;
50 auto get_size = [&](TypeRange types) {
51 int64_t size = 0;
52 for (auto t : types) {
53 auto tensor_type = t.cast<TensorType>();
54 // Ignore types with undefined bit widths.
55 if (!tensor_type.getElementType().isIntOrFloat()) continue;
56 if (!tensor_type.hasStaticShape()) {
57 has_unknown_shape = true;
58 return size;
59 }
60 size += tensor_type.getNumElements() *
61 tensor_type.getElementType().getIntOrFloatBitWidth();
62 }
63 return size;
64 };
65
66 int64_t results_size = get_size(inst->getResultTypes());
67 int64_t operands_size = get_size(inst->getOperandTypes());
68
69 constexpr int kSizeFactor = 2;
70 // TODO(b/233827625): Remove TF_DISABLE_CONSTANT_FOLDING macro.
71 #ifdef TF_DISABLE_CONSTANT_FOLDING
72 constexpr int64_t kResultsSizeThreshold = 0;
73 #else
74 constexpr int64_t kResultsSizeThreshold = (1 << 23); // 1 MB
75 #endif
76 constexpr int64_t kOperandsSizeThreshold = (1 << 30); // 1 GB
77
78 return (operands_size <= kOperandsSizeThreshold) &&
79 (has_unknown_shape || (results_size <= kResultsSizeThreshold) ||
80 (results_size <= kSizeFactor * operands_size));
81 }
82
ConstantFoldFallbackHook(Operation * inst,ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)83 LogicalResult ConstantFoldFallbackHook(
84 Operation* inst, ArrayRef<Attribute> operands,
85 SmallVectorImpl<OpFoldResult>& results) { // NOLINT
86 // Instructions with side effects should not be constant folded to preserve
87 // the original semantics. Ops that have no side effect and zero results but
88 // could be folded should have a custom folder instead of relying on the
89 // TensorFlow folding hook.
90 if (inst->getNumResults() == 0 ||
91 inst->hasTrait<OpTrait::TF::NoConstantFold>() ||
92 inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst))
93 return failure();
94
95 // If any of the result types are variants, don't try to constant fold them.
96 // This creates opaque variant constants which lose information and would
97 // require "raising" later.
98 for (auto type : inst->getResultTypes()) {
99 if (auto tensor_type = type.dyn_cast<TensorType>()) {
100 if (tensor_type.getElementType().isa<VariantType>()) {
101 return failure();
102 }
103 }
104 }
105
106 // If all the results are empty and has numerical element types, set results
107 // to empty elements attribute. This is restricted to the numerical element
108 // types as the DenseElementsAttr only supports numerical and string types.
109 // TODO(hinsu): Handle ops that have one of the results empty for constant
110 // propagation.
111 bool has_empty_numerical_results =
112 llvm::all_of(inst->getResultTypes(), [](Type ty) {
113 ShapedType shaped_ty = ty.cast<ShapedType>();
114 Type element_ty = shaped_ty.getElementType();
115 return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 &&
116 element_ty.isIntOrFloat();
117 });
118 if (has_empty_numerical_results &&
119 // TODO(jpienaar): Remove this once some unmodeled op behavior is
120 // addressed.
121 inst->isRegistered()) {
122 for (Type ty : inst->getResultTypes()) {
123 auto shaped_ty = ty.cast<ShapedType>();
124 results.push_back(
125 DenseElementsAttr::get(shaped_ty, llvm::ArrayRef<Attribute>()));
126 }
127 return success();
128 }
129
130 // Do not execute function calls.
131 if (llvm::isa<TF::WhileOp, TF::CaseOp, TF::IfOp, CallOpInterface>(inst)) {
132 return failure();
133 }
134
135 // Determine if we should attempt to fold this operation by considering the
136 // size/size increase due to folding.
137 if (!ShouldBeFolded(inst)) return failure();
138
139 // TODO(jpienaar): Currently this persists the entire program execution. This
140 // should instead be per module/set from the Graph being executed in TF (if
141 // any) so that the value of variables in the context could be read.
142 // Note: Sharing the context is fine as ops are side-effect free.
143 auto initialize = []() -> TFE_Context* {
144 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
145 TF_NewStatus(), TF_DeleteStatus);
146 std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>
147 opts(TFE_NewContextOptions(), TFE_DeleteContextOptions);
148 // Only initialize single CPU.
149 tensorflow::ConfigProto config_proto;
150 // This is conceptually equal to what we do in python/eager/context.py but
151 // with all GPU devices ignored and CPU only set to 1.
152 (*config_proto.mutable_device_count())["CPU"] = 1;
153 (*config_proto.mutable_device_count())["GPU"] = 0;
154 std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
155 TF_NewBuffer(), TF_DeleteBuffer);
156 DCHECK(config->data == nullptr);
157
158 // Copy config_proto into config.
159 {
160 const size_t proto_size = config_proto.ByteSizeLong();
161 void* buf = tensorflow::port::Malloc(proto_size);
162 if (buf == nullptr) {
163 LOG(ERROR) << "Failed to allocate memory to serialize ConfigProto "
164 "while creating context options for constant folding";
165 return nullptr;
166 }
167 if (!config_proto.SerializeWithCachedSizesToArray(
168 static_cast<uint8_t*>(buf))) {
169 tensorflow::port::Free(buf);
170 LOG(ERROR) << "Unable to serialize ConfigProto while creating context "
171 "options for constant folding";
172 return nullptr;
173 }
174 config->data = buf;
175 config->length = proto_size;
176 config->data_deallocator = [](void* data, size_t length) {
177 tensorflow::port::Free(data);
178 };
179 }
180
181 TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
182 status.get());
183 if (TF_GetCode(status.get()) != TF_OK) {
184 LOG(ERROR) << "Failed to set context options for constant folding: "
185 << status.get();
186 return nullptr;
187 }
188
189 // Input tensors are placed on the host CPU so use the explicit device
190 // policy to fail if no CPU kernels are available for the op.
191 TFE_ContextOptionsSetDevicePlacementPolicy(opts.get(),
192 TFE_DEVICE_PLACEMENT_EXPLICIT);
193 auto ctx = TFE_NewContext(opts.get(), status.get());
194 if (TF_GetCode(status.get()) != TF_OK) {
195 LOG(ERROR) << "Failed to create context for constant folding: "
196 << status.get();
197 return nullptr;
198 }
199 return ctx;
200 };
201 static TFE_Context* ctx = initialize();
202 if (!ctx) return failure();
203
204 // Returns directly if any of the operands is not an elements attributes.
205 if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) {
206 return !attr || !attr.isa<ElementsAttr>();
207 }))
208 return failure();
209
210 SmallVector<ElementsAttr, 4> inputs;
211 inputs.reserve(operands.size());
212 for (auto input : operands) {
213 inputs.push_back(input.cast<ElementsAttr>());
214 }
215
216 // Avoid overlapping folds with the same context.
217 // TODO(jpienaar): Avoid using global context & mutex here.
218 static auto* mu = new tensorflow::mutex();
219 tensorflow::mutex_lock l(*mu);
220 SmallVector<Attribute, 8> constants;
221 LogicalResult status =
222 tensorflow::EvaluateOperation(inst, inputs, ctx, &constants);
223 results.assign(constants.begin(), constants.end());
224 return status;
225 }
226
__anona6891cba0602() 227 static bool init_hooks = ([] () {
228 TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook);
229 }(), true);
230
231 } // namespace TF
232 } // namespace mlir
233