/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h"

#include <algorithm>

#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"
#include "tensorflow/core/platform/mutex.h"

namespace mlir {
namespace TF {

// Implements a TF specific policy on when constant folding is allowed.
// Policy:
//
// Disable constant folding if the total size of the operands is greater than a
// certain threshold (`kOperandsSizeThreshold`).
//
// Otherwise, allow folding if we do not know the shape of an operand or
// result, i.e., one of these values has a non-static shape. If we know all the
// shapes, find the total size of the operands and results. Folding of the op
// is allowed if one of the following conditions is met:
// 1. size of results is less than a certain threshold
//    (`kResultsSizeThreshold`), or
// 2. size of results is within a factor (`kSizeFactor`) of size of operands.
// TODO(b/157226221): Look into other heuristics for constant fold policy.
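//
// Illustrative example: folding an op whose only result is
// tensor<1024x1024xf32> produces 1024 * 1024 * 32 = 2^25 bits (~4 MB) of
// results. That exceeds kResultsSizeThreshold (2^23 bits), so the fold is kept
// only if a shape is unknown or the operands total at least 2^24 bits
// (kSizeFactor == 2); in all cases the operands must not exceed
// kOperandsSizeThreshold.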
static bool ShouldBeFolded(Operation* inst) {
  bool has_unknown_shape = false;
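  // Computes the total size of `types` in bits (num elements * element bit
  // width). Sets `has_unknown_shape` and returns the partial sum early if any
  // shape is not static; types whose elements are not int or float are
  // skipped.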
  auto get_size = [&](TypeRange types) {
    int64_t size = 0;
    for (auto t : types) {
      auto tensor_type = t.cast<TensorType>();
      // Ignore types with undefined bit widths.
      if (!tensor_type.getElementType().isIntOrFloat()) continue;
      if (!tensor_type.hasStaticShape()) {
        has_unknown_shape = true;
        return size;
      }
      size += tensor_type.getNumElements() *
              tensor_type.getElementType().getIntOrFloatBitWidth();
    }
    return size;
  };

  int64_t results_size = get_size(inst->getResultTypes());
  int64_t operands_size = get_size(inst->getOperandTypes());

  constexpr int kSizeFactor = 2;
// TODO(b/233827625): Remove TF_DISABLE_CONSTANT_FOLDING macro.
#ifdef TF_DISABLE_CONSTANT_FOLDING
  constexpr int64_t kResultsSizeThreshold = 0;
#else
  constexpr int64_t kResultsSizeThreshold = (1 << 23);   // 2^23 bits = 1 MB
#endif
  constexpr int64_t kOperandsSizeThreshold = (1 << 30);  // 2^30 bits = 128 MB

  return (operands_size <= kOperandsSizeThreshold) &&
         (has_unknown_shape || (results_size <= kResultsSizeThreshold) ||
          (results_size <= kSizeFactor * operands_size));
}

LogicalResult ConstantFoldFallbackHook(
    Operation* inst, ArrayRef<Attribute> operands,
    SmallVectorImpl<OpFoldResult>& results) {  // NOLINT
  // Instructions with side effects should not be constant folded to preserve
  // the original semantics. Ops that have no side effect and zero results but
  // could be folded should have a custom folder instead of relying on the
  // TensorFlow folding hook.
  if (inst->getNumResults() == 0 ||
      inst->hasTrait<OpTrait::TF::NoConstantFold>() ||
      inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst))
    return failure();

  // If any of the result types are variants, don't try to constant fold them.
  // This creates opaque variant constants which lose information and would
  // require "raising" later.
  for (auto type : inst->getResultTypes()) {
    if (auto tensor_type = type.dyn_cast<TensorType>()) {
      if (tensor_type.getElementType().isa<VariantType>()) {
        return failure();
      }
    }
  }

  // If all the results are empty and have numerical element types, set the
  // results to an empty elements attribute. This is restricted to numerical
  // element types as DenseElementsAttr only supports numerical and string
  // types.
  // TODO(hinsu): Handle ops that have one of the results empty for constant
  // propagation.
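  // For example, a result of type tensor<0xf32> or tensor<2x0xi32> qualifies:
  // it has a static shape with zero elements and an int/float element type.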
  bool has_empty_numerical_results =
      llvm::all_of(inst->getResultTypes(), [](Type ty) {
        ShapedType shaped_ty = ty.cast<ShapedType>();
        Type element_ty = shaped_ty.getElementType();
        return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 &&
               element_ty.isIntOrFloat();
      });
  if (has_empty_numerical_results &&
      // TODO(jpienaar): Remove this once some unmodeled op behavior is
      // addressed.
      inst->isRegistered()) {
    for (Type ty : inst->getResultTypes()) {
      auto shaped_ty = ty.cast<ShapedType>();
      results.push_back(
          DenseElementsAttr::get(shaped_ty, llvm::ArrayRef<Attribute>()));
    }
    return success();
  }

  // Do not execute function calls.
  if (llvm::isa<TF::WhileOp, TF::CaseOp, TF::IfOp, CallOpInterface>(inst)) {
    return failure();
  }

  // Determine if we should attempt to fold this operation by considering the
  // size/size increase due to folding.
  if (!ShouldBeFolded(inst)) return failure();

  // TODO(jpienaar): Currently this persists for the entire program execution.
  // This should instead be per module / set from the Graph being executed in
  // TF (if any) so that the values of variables in the context can be read.
  // Note: Sharing the context is fine as ops are side-effect free.
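  // Lazily creates a CPU-only eager context that is shared for the lifetime of
  // the process and used below to evaluate the op.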
  auto initialize = []() -> TFE_Context* {
    std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
        TF_NewStatus(), TF_DeleteStatus);
    std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)>
        opts(TFE_NewContextOptions(), TFE_DeleteContextOptions);
    // Only initialize a single CPU.
    tensorflow::ConfigProto config_proto;
    // This is conceptually equivalent to what we do in python/eager/context.py
    // but with all GPU devices ignored and the CPU count set to 1.
    (*config_proto.mutable_device_count())["CPU"] = 1;
    (*config_proto.mutable_device_count())["GPU"] = 0;
    std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
        TF_NewBuffer(), TF_DeleteBuffer);
    DCHECK(config->data == nullptr);

    // Copy config_proto into config.
    {
      const size_t proto_size = config_proto.ByteSizeLong();
      void* buf = tensorflow::port::Malloc(proto_size);
      if (buf == nullptr) {
        LOG(ERROR) << "Failed to allocate memory to serialize ConfigProto "
                      "while creating context options for constant folding";
        return nullptr;
      }
      if (!config_proto.SerializeWithCachedSizesToArray(
              static_cast<uint8_t*>(buf))) {
        tensorflow::port::Free(buf);
        LOG(ERROR) << "Unable to serialize ConfigProto while creating context "
                      "options for constant folding";
        return nullptr;
      }
      config->data = buf;
      config->length = proto_size;
      config->data_deallocator = [](void* data, size_t length) {
        tensorflow::port::Free(data);
      };
    }

    TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                                status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      LOG(ERROR) << "Failed to set context options for constant folding: "
                 << TF_Message(status.get());
      return nullptr;
    }

    // Input tensors are placed on the host CPU so use the explicit device
    // policy to fail if no CPU kernels are available for the op.
    TFE_ContextOptionsSetDevicePlacementPolicy(opts.get(),
                                               TFE_DEVICE_PLACEMENT_EXPLICIT);
    auto ctx = TFE_NewContext(opts.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) {
      LOG(ERROR) << "Failed to create context for constant folding: "
                 << TF_Message(status.get());
      return nullptr;
    }
    return ctx;
  };
  static TFE_Context* ctx = initialize();
  if (!ctx) return failure();

  // Return directly if any of the operands is not an elements attribute.
  if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) {
        return !attr || !attr.isa<ElementsAttr>();
      }))
    return failure();

  SmallVector<ElementsAttr, 4> inputs;
  inputs.reserve(operands.size());
  for (auto input : operands) {
    inputs.push_back(input.cast<ElementsAttr>());
  }

  // Avoid overlapping folds with the same context.
  // TODO(jpienaar): Avoid using global context & mutex here.
  static auto* mu = new tensorflow::mutex();
  tensorflow::mutex_lock l(*mu);
  SmallVector<Attribute, 8> constants;
  LogicalResult status =
      tensorflow::EvaluateOperation(inst, inputs, ctx, &constants);
  results.assign(constants.begin(), constants.end());
  return status;
}

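// Registers the constant folding fallback hook with the TF dialect at static
// initialization time, when this translation unit is linked into a binary.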
static bool init_hooks = ([] () {
  TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook);
}(), true);

}  // namespace TF
}  // namespace mlir