xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/utils/eval_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/transforms/utils/eval_utils.h"
19 
20 #include <cassert>
21 #include <utility>
22 
23 #include "llvm/ADT/STLExtras.h"
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "tensorflow/core/framework/allocator.h"
27 #include "tensorflow/core/framework/control_flow.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/ir/importexport/convert_tensor.h"
30 #include "tensorflow/core/ir/importexport/graphdef_export.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/threadpool.h"
33 #include "tensorflow/core/public/version.h"
34 
35 namespace mlir {
36 namespace tfg {
37 namespace util {
38 
// SimpleDevice is intended for evaluating a single operation at a time. To
// avoid the overhead of thread creation, keep a small and conservative
// default thread-pool size.
static constexpr int kThreads = 2;
SimpleDevice()44 SimpleDevice::SimpleDevice() : DeviceBase(tensorflow::Env::Default()) {
45   eigen_worker_ = std::make_unique<tensorflow::thread::ThreadPool>(
46       tensorflow::Env::Default(), "eval_utils", kThreads);
47 
48   eigen_worker_threads_.num_threads = kThreads;
49   eigen_worker_threads_.workers = eigen_worker_.get();
50 
51   eigen_device_ = std::make_unique<Eigen::ThreadPoolDevice>(
52       eigen_worker_threads_.workers->AsEigenThreadPool(),
53       eigen_worker_threads_.num_threads);
54   set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
55   set_eigen_cpu_device(eigen_device_.get());
56 }
57 
~SimpleDevice()58 SimpleDevice::~SimpleDevice() {}
59 
GetAllocator(tensorflow::AllocatorAttributes attr)60 tensorflow::Allocator *SimpleDevice::GetAllocator(
61     tensorflow::AllocatorAttributes attr) {
62   return tensorflow::cpu_allocator();
63 }
64 
// Decodes `tensor_proto` into a host tensor using the CPU allocator and moves
// the result into `*tensor`. Returns InvalidArgument if the proto cannot be
// parsed; `alloc_attrs` is unused on this host-only device.
tensorflow::Status SimpleDevice::MakeTensorFromProto(
    const tensorflow::TensorProto &tensor_proto,
    const tensorflow::AllocatorAttributes alloc_attrs,
    tensorflow::Tensor *tensor) {
  tensorflow::Tensor decoded(tensor_proto.dtype());
  if (decoded.FromProto(tensorflow::cpu_allocator(), tensor_proto)) {
    *tensor = std::move(decoded);
    return ::tensorflow::OkStatus();
  }
  return tensorflow::errors::InvalidArgument(
      "Cannot parse tensor from tensor_proto.");
}
77 
// Evaluates a single TFG operation by exporting it to a NodeDef, running the
// matching CPU kernel on `cpu_device`, and converting every produced tensor
// back into an MLIR ElementsAttr appended to `results`. A null entry is pushed
// for outputs the kernel left dead. Returns failure if any operand is null,
// any conversion fails, kernel construction fails, or the kernel reports an
// error status.
LogicalResult EvaluateOperation(tensorflow::DeviceBase *cpu_device,
                                tensorflow::ResourceMgr *resource_mgr, TFOp op,
                                ArrayRef<ElementsAttr> operands,
                                SmallVectorImpl<TypedAttr> &results) {
  assert(cpu_device && "cpu device can't be null");
  assert(resource_mgr && "ResourceMgr can't be null");

  // Evaluation requires every operand to be a known constant attribute.
  if (llvm::any_of(operands, [](Attribute operand) { return !operand; })) {
    VLOG(3) << "cannot be evaluated with null operands";
    return failure();
  }

  // Export the MLIR op to a NodeDef so the TensorFlow kernel machinery can
  // instantiate the matching kernel below.
  tensorflow::NodeDef node_def;
  if (!ConvertToNodeDef(&*op, &node_def, op.getDialect(), [&](Value value) {
         return GetValueName(value, op.getDialect());
       }).ok()) {
    VLOG(3) << "failed to convert operation to NodeDef";
    return failure();
  }

  absl::InlinedVector<tensorflow::Tensor, 4> input_tensors(operands.size());
  absl::InlinedVector<tensorflow::TensorValue, 4> input_tensor_values(
      operands.size());
  // For each operand, convert its ElementsAttr to a Tensor, and the Tensor
  // will be referenced by a TensorValue. To ensure the Tensor/TensorValue
  // lifetimes span the later evaluation, they are stored in
  // `input_tensors`/`input_tensor_values` respectively. The following loop
  // zips them together so that the bundled values are related. Note that the
  // accessor index associates with the order of arguments in llvm::zip.
  for (auto it : llvm::zip(operands, input_tensors, input_tensor_values)) {
    auto &[operand, input_tensor, input_tensor_value] = it;
    if (!ConvertToTensor(operand, &input_tensor).ok()) return failure();
    input_tensor_value.tensor = &input_tensor;
  }

  // Instantiate the CPU kernel for this node.
  tensorflow::Status status;
  std::unique_ptr<tensorflow::OpKernel> op_kernel = tensorflow::CreateOpKernel(
      "CPU", cpu_device, cpu_device->GetAllocator({}), node_def,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) {
    VLOG(3) << status.error_message();
    return failure();
  }

  // Minimal kernel-context parameters. NOTE: `params` holds raw views into
  // `input_tensor_values` and (below) `output_attrs`, so both must outlive
  // `op_context`.
  tensorflow::OpKernelContext::Params params;
  params.device = cpu_device;
  params.frame_iter = tensorflow::FrameAndIter(0, 0);
  params.inputs = input_tensor_values;
  params.op_kernel = op_kernel.get();
  params.resource_manager = resource_mgr;

  // Force all outputs onto the host so they can be converted back into
  // attributes afterwards.
  absl::InlinedVector<tensorflow::AllocatorAttributes, 4> output_attrs(
      op_kernel->num_outputs());
  for (auto &attr : output_attrs) attr.set_on_host(true);
  params.output_attr_array = output_attrs.data();

  // Evaluate the operation.
  tensorflow::OpKernelContext op_context(&params);
  op_kernel->Compute(&op_context);
  if (!op_context.status().ok()) {
    VLOG(3) << op_context.status().error_message();
    return failure();
  }

  // Converts the outputs to MLIR attributes.
  Builder builder(op->getContext());
  for (int i = 0; i < op_kernel->num_outputs(); ++i) {
    // The output is invalidated (dead): record a null placeholder so result
    // indices still line up with the op's outputs.
    if (op_context.mutable_output(i) == nullptr) {
      results.push_back(nullptr);
      continue;
    }

    tensorflow::StatusOr<ElementsAttr> attr_or =
        ConvertTensor(*(op_context.mutable_output(i)), builder);
    if (!attr_or.status().ok()) {
      VLOG(3) << attr_or.status().error_message();
      return failure();
    }
    results.push_back(attr_or.ValueOrDie());
  }

  return success();
}
162 
163 }  // namespace util
164 }  // namespace tfg
165 }  // namespace mlir
166