1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "tensorflow/core/transforms/utils/eval_utils.h"
19
20 #include <cassert>
21 #include <utility>
22
23 #include "llvm/ADT/STLExtras.h"
24 #include "mlir/IR/Builders.h" // from @llvm-project
25 #include "mlir/Support/LLVM.h" // from @llvm-project
26 #include "tensorflow/core/framework/allocator.h"
27 #include "tensorflow/core/framework/control_flow.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/ir/importexport/convert_tensor.h"
30 #include "tensorflow/core/ir/importexport/graphdef_export.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/threadpool.h"
33 #include "tensorflow/core/public/version.h"
34
35 namespace mlir {
36 namespace tfg {
37 namespace util {
38
// SimpleDevice is meant for evaluating one operation at a time. To avoid the
// overhead of thread creation, use a small, conservative pool size as the
// default.
static constexpr int kThreads = 2;
43
SimpleDevice()44 SimpleDevice::SimpleDevice() : DeviceBase(tensorflow::Env::Default()) {
45 eigen_worker_ = std::make_unique<tensorflow::thread::ThreadPool>(
46 tensorflow::Env::Default(), "eval_utils", kThreads);
47
48 eigen_worker_threads_.num_threads = kThreads;
49 eigen_worker_threads_.workers = eigen_worker_.get();
50
51 eigen_device_ = std::make_unique<Eigen::ThreadPoolDevice>(
52 eigen_worker_threads_.workers->AsEigenThreadPool(),
53 eigen_worker_threads_.num_threads);
54 set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
55 set_eigen_cpu_device(eigen_device_.get());
56 }
57
~SimpleDevice()58 SimpleDevice::~SimpleDevice() {}
59
GetAllocator(tensorflow::AllocatorAttributes attr)60 tensorflow::Allocator *SimpleDevice::GetAllocator(
61 tensorflow::AllocatorAttributes attr) {
62 return tensorflow::cpu_allocator();
63 }
64
// Deserializes `tensor_proto` into `*tensor` using the host CPU allocator.
// Returns an InvalidArgument status when the proto cannot be parsed;
// `alloc_attrs` is unused since all allocation happens on the host.
tensorflow::Status SimpleDevice::MakeTensorFromProto(
    const tensorflow::TensorProto &tensor_proto,
    const tensorflow::AllocatorAttributes alloc_attrs,
    tensorflow::Tensor *tensor) {
  tensorflow::Tensor parsed(tensor_proto.dtype());
  if (parsed.FromProto(tensorflow::cpu_allocator(), tensor_proto)) {
    *tensor = std::move(parsed);
    return ::tensorflow::OkStatus();
  }
  return tensorflow::errors::InvalidArgument(
      "Cannot parse tensor from tensor_proto.");
}
77
// Evaluates a single TFG operation by converting it to a NodeDef,
// instantiating the corresponding CPU kernel, and running it via the TF
// OpKernel machinery on `cpu_device`. On success, appends one attribute per
// op result to `results` (null for "dead" outputs) and returns success.
// Returns failure when any operand is null, or when conversion, kernel
// creation, or kernel execution fails.
LogicalResult EvaluateOperation(tensorflow::DeviceBase *cpu_device,
                                tensorflow::ResourceMgr *resource_mgr, TFOp op,
                                ArrayRef<ElementsAttr> operands,
                                SmallVectorImpl<TypedAttr> &results) {
  assert(cpu_device && "cpu device can't be null");
  assert(resource_mgr && "ResourceMgr can't be null");

  // Evaluation requires every operand value to be known.
  if (llvm::any_of(operands, [](Attribute operand) { return !operand; })) {
    VLOG(3) << "cannot be evaluated with null operands";
    return failure();
  }

  // Export the MLIR operation to a NodeDef so a TF kernel can be built for it.
  tensorflow::NodeDef node_def;
  if (!ConvertToNodeDef(&*op, &node_def, op.getDialect(), [&](Value value) {
        return GetValueName(value, op.getDialect());
      }).ok()) {
    VLOG(3) << "failed to convert operation to NodeDef";
    return failure();
  }

  absl::InlinedVector<tensorflow::Tensor, 4> input_tensors(operands.size());
  absl::InlinedVector<tensorflow::TensorValue, 4> input_tensor_values(
      operands.size());
  // For each operand, convert its ElementsAttr to a Tensor, and the Tensor
  // will be referenced by a TensorValue. To ensure the Tensor/TensorValue
  // lifetimes span the later evaluation, they are stored in
  // `input_tensors`/`input_tensor_values` respectively. The following loop
  // zips them together so that the bundled values are related. Note that the
  // accessor index associates with the order of arguments in llvm::zip.
  for (auto it : llvm::zip(operands, input_tensors, input_tensor_values)) {
    auto &[operand, input_tensor, input_tensor_value] = it;
    if (!ConvertToTensor(operand, &input_tensor).ok()) return failure();
    input_tensor_value.tensor = &input_tensor;
  }

  // Instantiate the CPU kernel for this NodeDef.
  tensorflow::Status status;
  std::unique_ptr<tensorflow::OpKernel> op_kernel = tensorflow::CreateOpKernel(
      "CPU", cpu_device, cpu_device->GetAllocator({}), node_def,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) {
    VLOG(3) << status.error_message();
    return failure();
  }

  // Minimal context parameters for a single-op invocation; the
  // `output_attr_array` set below forces all outputs onto the host.
  tensorflow::OpKernelContext::Params params;
  params.device = cpu_device;
  params.frame_iter = tensorflow::FrameAndIter(0, 0);
  params.inputs = input_tensor_values;
  params.op_kernel = op_kernel.get();
  params.resource_manager = resource_mgr;

  absl::InlinedVector<tensorflow::AllocatorAttributes, 4> output_attrs(
      op_kernel->num_outputs());
  for (auto &attr : output_attrs) attr.set_on_host(true);
  params.output_attr_array = output_attrs.data();

  // Evaluate the operation.
  tensorflow::OpKernelContext op_context(&params);
  op_kernel->Compute(&op_context);
  if (!op_context.status().ok()) {
    VLOG(3) << op_context.status().error_message();
    return failure();
  }

  // Converts the outputs to MLIR attributes.
  Builder builder(op->getContext());
  for (int i = 0; i < op_kernel->num_outputs(); ++i) {
    // The output is invalidated, returns a `dead` value here.
    if (op_context.mutable_output(i) == nullptr) {
      results.push_back(nullptr);
      continue;
    }

    tensorflow::StatusOr<ElementsAttr> attr_or =
        ConvertTensor(*(op_context.mutable_output(i)), builder);
    if (!attr_or.status().ok()) {
      VLOG(3) << attr_or.status().error_message();
      return failure();
    }
    results.push_back(attr_or.ValueOrDie());
  }

  return success();
}
162
163 } // namespace util
164 } // namespace tfg
165 } // namespace mlir
166