1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29
30 namespace tensorflow {
31 namespace {
32
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36 static auto* const m =
37 new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38 return m;
39 }
40
41 const uint32 kIsList = 1U << 31;
42
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44 AttrTypeMap* map = new AttrTypeMap();
45 (*map)["executor_type"] = TF_ATTR_STRING;
46 (*map)["config_proto"] = TF_ATTR_STRING;
47 return map;
48 }
49
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51 static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52 return map;
53 }
54
55 } // namespace
56
OpDefForOp(const string & op_name,const OpDef ** op_def)57 Status OpDefForOp(const string& op_name, const OpDef** op_def) {
58 const OpRegistrationData* op_reg_data = nullptr;
59 Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60 if (s.ok()) {
61 *op_def = &op_reg_data->op_def;
62 }
63 return s;
64 }
65
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67 bool* is_function) {
68 {
69 tf_shared_lock l(g_op_name_to_attr_type_map_lock);
70 *is_function = false;
71 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
72 if (*out != nullptr) return OkStatus();
73 }
74
75 mutex_lock l(g_op_name_to_attr_type_map_lock);
76
77 // Check the existence of AttrTypeMap for op_name again because another thread
78 // may insert this map after the tf_shared_lock is released but before the
79 // mutex_lock is acquired.
80 *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
81 if (*out != nullptr) return OkStatus();
82
83 const OpDef* op_def = nullptr;
84 Status s = OpDefForOp(op_name, &op_def);
85 if (errors::IsNotFound(s)) {
86 // If we did not find the op def, we assume `op_name` is a function.
87 // If it is actually a misspelled op, user will get another error when
88 // trying to run it.
89 // TODO(iga): If we ever have a use case for different attribute specs
90 // in different functions, we will need to look at the OpDef in the
91 // function def to retrieve their types.
92 *out = GetDefaultFunctionAttrTypeMap();
93 *is_function = true;
94 return OkStatus();
95 } else if (!s.ok()) {
96 return s;
97 }
98 std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
99 // TODO(agarwal): Avoid having to create this "registry" at runtime,
100 // perhaps can be done at op registration time?
101 for (const auto& attr : op_def->attr()) {
102 string type = attr.type();
103 const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
104 if (is_list) {
105 type = type.substr(5, type.length() - 6);
106 }
107 uint32 t = is_list ? kIsList : 0;
108 if (type == "string") {
109 t |= TF_ATTR_STRING;
110 } else if (type == "int") {
111 t |= TF_ATTR_INT;
112 } else if (type == "float") {
113 t |= TF_ATTR_FLOAT;
114 } else if (type == "bool") {
115 t |= TF_ATTR_BOOL;
116 } else if (type == "type") {
117 t |= TF_ATTR_TYPE;
118 } else if (type == "shape") {
119 t |= TF_ATTR_SHAPE;
120 } else if (type == "tensor") {
121 t |= TF_ATTR_TENSOR;
122 } else if (type == "func") {
123 t |= TF_ATTR_FUNC;
124 } else {
125 return errors::Unimplemented(
126 "TODO(agarwal): Enable support for ops with attributes of type '",
127 type, "'");
128 }
129 gtl::InsertIfNotPresent(m.get(), attr.name(), t);
130 }
131 *out = m.get();
132 auto r = OpNameToAttrTypeMap()->emplace(op_name, m.release());
133 DCHECK(r.second) << "AttrTypeMap already exists for " << op_name;
134
135 return OkStatus();
136 }
137
138 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE) \
139 template <> \
140 Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const { \
141 auto it = encoded_attrs_.find(string(attr_name)); \
142 if (it == encoded_attrs_.end()) { \
143 return errors::NotFound("No attr named '", attr_name, \
144 "' found in AttrBuilder for ", op_name_); \
145 } \
146 attr_tmp_.ParseFromString(it->second); \
147 TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE)); \
148 *value = attr_tmp_.FIELD(); \
149 return OkStatus(); \
150 }
151
152 DEFINE_GET_ATTR(float, f, "float");
153 DEFINE_GET_ATTR(int, i, "int");
154 DEFINE_GET_ATTR(int64_t, i, "int");
155 DEFINE_GET_ATTR(bool, b, "bool");
156 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
157
158 #undef DEFINE_GET_ATTR
159
160 template <>
Get(StringPiece attr_name,absl::InlinedVector<DataType,4> * value) const161 Status AttrBuilder::Get(StringPiece attr_name,
162 absl::InlinedVector<DataType, 4>* value) const {
163 auto it = encoded_attrs_.find(string(attr_name));
164 if (it == encoded_attrs_.end()) {
165 return errors::NotFound("No attr named '", attr_name,
166 "' found in AttrBuilder for ", op_name_);
167 }
168 attr_tmp_.ParseFromString(it->second);
169 TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, "list(type)"));
170 for (size_t i = 0; i < attr_tmp_.list().type_size(); i++) {
171 value->push_back(attr_tmp_.list().type(i));
172 }
173 return OkStatus();
174 }
175
NumInputs(int n)176 AttrBuilder& AttrBuilder::NumInputs(int n) {
177 DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
178 num_inputs_ = n;
179 return *this;
180 }
181
FillAttrValueMap(AttrValueMap * m) const182 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
183 for (auto& entry : encoded_attrs_) {
184 attr_tmp_.ParseFromString(entry.second);
185 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
186 }
187 // For any attr-value pairs that exist in the op def (from op registry) but
188 // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
189 // specify all the default attr values (e.g. for matmul, the `transpose_a`
190 // attr defaults to false).
191 const OpDef* op_def = nullptr;
192 Status s = OpDefForOp(op_name().c_str(), &op_def);
193 // This is expected, if this op is a custom function, and is therefore not
194 // present in the op registry.
195 if (!s.ok()) return;
196
197 DCHECK(op_def);
198 for (const auto& attr_def : op_def->attr()) {
199 if (attr_def.has_default_value() && !m->count(attr_def.name())) {
200 SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
201 }
202 }
203 }
204
205 namespace {
206
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)207 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
208 const AttrValue& attr_value) {
209 // TODO(iga): It might make sense to augment OpRegistrationData with a
210 // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
211 for (const OpDef::AttrDef& attr_def : op_def->attr()) {
212 if (attr_def.name() == attr_name && attr_def.has_default_value() &&
213 AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
214 return true;
215 }
216 }
217 return false;
218 }
219
220 } // namespace
221
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const222 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
223 const OpDef* op_def = nullptr;
224 Status s = OpDefForOp(op_name().c_str(), &op_def);
225
226 for (auto& entry : encoded_attrs_) {
227 attr_tmp_.ParseFromString(entry.second);
228 // Insert the attr-value pair if we did not find the OpDef or if the value
229 // is different from default.
230 if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
231 m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
232 }
233 }
234 }
235
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)236 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
237 const AttrValue& value) {
238 encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
239 }
240
BuildNodeDef()241 const NodeDef& AttrBuilder::BuildNodeDef() {
242 if (node_def_finalized_) return node_def_;
243 if (!node_def_initialized_) {
244 InitializeNodeDef();
245 }
246 for (int i = 0; i < num_inputs_; ++i) {
247 node_def_.add_input("dummy_input");
248 }
249 FillAttrValueMap(node_def_.mutable_attr());
250 node_def_finalized_ = true;
251 return node_def_;
252 }
253
CopyAttributes(const AttrBuilder & other)254 void AttrBuilder::CopyAttributes(const AttrBuilder& other) {
255 encoded_attrs_.insert(other.encoded_attrs_.begin(),
256 other.encoded_attrs_.end());
257 }
258
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)259 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
260 TF_AttrType* out, unsigned char* is_list) {
261 auto* t = gtl::FindOrNull(m, attr_name);
262 if (t == nullptr) {
263 return errors::InvalidArgument("Attribute '", attr_name,
264 "' does not exist for this operation");
265 }
266 *out = static_cast<TF_AttrType>(*t & ~kIsList);
267 if (*t & kIsList) {
268 *is_list = 1;
269 } else {
270 *is_list = 0;
271 }
272 return OkStatus();
273 }
274
275 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)276 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
277 const tensorflow::Fprint128& b) {
278 return {tensorflow::FingerprintCat64(a.low64, b.low64),
279 tensorflow::FingerprintCat64(a.high64, b.high64)};
280 }
281
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)282 void CombineUnordered(const tensorflow::Fprint128& a,
283 tensorflow::Fprint128* b) {
284 b->low64 += a.low64;
285 b->high64 += a.high64;
286 }
287
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)288 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
289 const tensorflow::Fprint128& b) {
290 tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
291 return FingerprintCat128(a, b);
292 }
293
CacheKeyHelper(StringPiece s,uint64 b)294 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
295 return CacheKeyHelper(s, {b, b});
296 }
297
298 } // namespace
299
CacheKey(const StringPiece device)300 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
301 if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
302 cached_cache_key_ = BuildCacheKeyForDevice(device);
303 device_for_cached_cache_key_ = string(device);
304 }
305
306 return *cached_cache_key_;
307 }
308
BuildCacheKeyForDevice(const StringPiece device) const309 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
310 const StringPiece device) const {
311 tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
312 f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
313 for (const auto& p : encoded_attrs_) {
314 CombineUnordered(
315 CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
316 }
317 return f;
318 }
319
InitializeNodeDef()320 void AttrBuilder::InitializeNodeDef() {
321 DCHECK(!node_def_initialized_);
322 node_def_.Clear();
323 node_def_.set_name(op_name_);
324 node_def_.set_op(op_name_);
325 node_def_initialized_ = true;
326 }
327
GetNameAttrList(tensorflow::NameAttrList * name_and_attrs) const328 void AttrBuilder::GetNameAttrList(
329 tensorflow::NameAttrList* name_and_attrs) const {
330 FillAttrValueMap(name_and_attrs->mutable_attr());
331 name_and_attrs->set_name(op_name());
332 }
333
GetTypeList(absl::string_view attr_name,absl::InlinedVector<DataType,4> * type_list) const334 Status AttrBuilder::GetTypeList(
335 absl::string_view attr_name,
336 absl::InlinedVector<DataType, 4>* type_list) const {
337 return Get(attr_name, type_list);
338 }
339
GetInt(absl::string_view attr_name,int64_t * result) const340 bool AttrBuilder::GetInt(absl::string_view attr_name, int64_t* result) const {
341 Status s = Get(attr_name, result);
342 return s.ok();
343 }
GetFloat(absl::string_view attr_name,float * result) const344 bool AttrBuilder::GetFloat(absl::string_view attr_name, float* result) const {
345 Status s = Get(attr_name, result);
346 return s.ok();
347 }
GetBool(absl::string_view attr_name,bool * result) const348 bool AttrBuilder::GetBool(absl::string_view attr_name, bool* result) const {
349 Status s = Get(attr_name, result);
350 return s.ok();
351 }
352
GetType(absl::string_view attr_name,tensorflow::DataType * result) const353 bool AttrBuilder::GetType(absl::string_view attr_name,
354 tensorflow::DataType* result) const {
355 Status s = Get(attr_name, result);
356 return s.ok();
357 }
358
359 } // namespace tensorflow
360