xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/eager/attr_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
17 
18 #include "tensorflow/core/common_runtime/device_factory.h"
19 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
20 #include "tensorflow/core/framework/allocator.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/node_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/fingerprint.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/public/version.h"
28 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 mutex g_op_name_to_attr_type_map_lock(LINKER_INITIALIZED);
34 
OpNameToAttrTypeMap()35 tensorflow::gtl::FlatMap<string, const AttrTypeMap*>* OpNameToAttrTypeMap() {
36   static auto* const m =
37       new tensorflow::gtl::FlatMap<string, const AttrTypeMap*>;
38   return m;
39 }
40 
41 const uint32 kIsList = 1U << 31;
42 
DefaultFunctionAttrTypeMap()43 AttrTypeMap* DefaultFunctionAttrTypeMap() {
44   AttrTypeMap* map = new AttrTypeMap();
45   (*map)["executor_type"] = TF_ATTR_STRING;
46   (*map)["config_proto"] = TF_ATTR_STRING;
47   return map;
48 }
49 
GetDefaultFunctionAttrTypeMap()50 const AttrTypeMap* GetDefaultFunctionAttrTypeMap() {
51   static const AttrTypeMap* map = DefaultFunctionAttrTypeMap();
52   return map;
53 }
54 
55 }  // namespace
56 
OpDefForOp(const string & op_name,const OpDef ** op_def)57 Status OpDefForOp(const string& op_name, const OpDef** op_def) {
58   const OpRegistrationData* op_reg_data = nullptr;
59   Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
60   if (s.ok()) {
61     *op_def = &op_reg_data->op_def;
62   }
63   return s;
64 }
65 
AttrTypeMapForOp(const char * op_name,const AttrTypeMap ** out,bool * is_function)66 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out,
67                         bool* is_function) {
68   {
69     tf_shared_lock l(g_op_name_to_attr_type_map_lock);
70     *is_function = false;
71     *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
72     if (*out != nullptr) return OkStatus();
73   }
74 
75   mutex_lock l(g_op_name_to_attr_type_map_lock);
76 
77   // Check the existence of AttrTypeMap for op_name again because another thread
78   // may insert this map after the tf_shared_lock is released but before the
79   // mutex_lock is acquired.
80   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
81   if (*out != nullptr) return OkStatus();
82 
83   const OpDef* op_def = nullptr;
84   Status s = OpDefForOp(op_name, &op_def);
85   if (errors::IsNotFound(s)) {
86     // If we did not find the op def, we assume `op_name` is a function.
87     // If it is actually a misspelled op, user will get another error when
88     // trying to run it.
89     // TODO(iga): If we ever have a use case for different attribute specs
90     // in different functions, we will need to look at the OpDef in the
91     // function def to retrieve their types.
92     *out = GetDefaultFunctionAttrTypeMap();
93     *is_function = true;
94     return OkStatus();
95   } else if (!s.ok()) {
96     return s;
97   }
98   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
99   // TODO(agarwal): Avoid having to create this "registry" at runtime,
100   // perhaps can be done at op registration time?
101   for (const auto& attr : op_def->attr()) {
102     string type = attr.type();
103     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
104     if (is_list) {
105       type = type.substr(5, type.length() - 6);
106     }
107     uint32 t = is_list ? kIsList : 0;
108     if (type == "string") {
109       t |= TF_ATTR_STRING;
110     } else if (type == "int") {
111       t |= TF_ATTR_INT;
112     } else if (type == "float") {
113       t |= TF_ATTR_FLOAT;
114     } else if (type == "bool") {
115       t |= TF_ATTR_BOOL;
116     } else if (type == "type") {
117       t |= TF_ATTR_TYPE;
118     } else if (type == "shape") {
119       t |= TF_ATTR_SHAPE;
120     } else if (type == "tensor") {
121       t |= TF_ATTR_TENSOR;
122     } else if (type == "func") {
123       t |= TF_ATTR_FUNC;
124     } else {
125       return errors::Unimplemented(
126           "TODO(agarwal): Enable support for ops with attributes of type '",
127           type, "'");
128     }
129     gtl::InsertIfNotPresent(m.get(), attr.name(), t);
130   }
131   *out = m.get();
132   auto r = OpNameToAttrTypeMap()->emplace(op_name, m.release());
133   DCHECK(r.second) << "AttrTypeMap already exists for " << op_name;
134 
135   return OkStatus();
136 }
137 
138 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE)                         \
139   template <>                                                           \
140   Status AttrBuilder::Get(StringPiece attr_name, TYPE* value) const {   \
141     auto it = encoded_attrs_.find(string(attr_name));                   \
142     if (it == encoded_attrs_.end()) {                                   \
143       return errors::NotFound("No attr named '", attr_name,             \
144                               "' found in AttrBuilder for ", op_name_); \
145     }                                                                   \
146     attr_tmp_.ParseFromString(it->second);                              \
147     TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, ATTR_TYPE));         \
148     *value = attr_tmp_.FIELD();                                         \
149     return OkStatus();                                                  \
150   }
151 
152 DEFINE_GET_ATTR(float, f, "float");
153 DEFINE_GET_ATTR(int, i, "int");
154 DEFINE_GET_ATTR(int64_t, i, "int");
155 DEFINE_GET_ATTR(bool, b, "bool");
156 DEFINE_GET_ATTR(tensorflow::DataType, type, "type");
157 
158 #undef DEFINE_GET_ATTR
159 
160 template <>
Get(StringPiece attr_name,absl::InlinedVector<DataType,4> * value) const161 Status AttrBuilder::Get(StringPiece attr_name,
162                         absl::InlinedVector<DataType, 4>* value) const {
163   auto it = encoded_attrs_.find(string(attr_name));
164   if (it == encoded_attrs_.end()) {
165     return errors::NotFound("No attr named '", attr_name,
166                             "' found in AttrBuilder for ", op_name_);
167   }
168   attr_tmp_.ParseFromString(it->second);
169   TF_RETURN_IF_ERROR(AttrValueHasType(attr_tmp_, "list(type)"));
170   for (size_t i = 0; i < attr_tmp_.list().type_size(); i++) {
171     value->push_back(attr_tmp_.list().type(i));
172   }
173   return OkStatus();
174 }
175 
NumInputs(int n)176 AttrBuilder& AttrBuilder::NumInputs(int n) {
177   DCHECK(!node_def_finalized_) << "Calling NumInputs after BuildNodeDef.";
178   num_inputs_ = n;
179   return *this;
180 }
181 
FillAttrValueMap(AttrValueMap * m) const182 void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const {
183   for (auto& entry : encoded_attrs_) {
184     attr_tmp_.ParseFromString(entry.second);
185     m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
186   }
187   // For any attr-value pairs that exist in the op def (from op registry) but
188   // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
189   // specify all the default attr values (e.g. for matmul, the `transpose_a`
190   // attr defaults to false).
191   const OpDef* op_def = nullptr;
192   Status s = OpDefForOp(op_name().c_str(), &op_def);
193   // This is expected, if this op is a custom function, and is therefore not
194   // present in the op registry.
195   if (!s.ok()) return;
196 
197   DCHECK(op_def);
198   for (const auto& attr_def : op_def->attr()) {
199     if (attr_def.has_default_value() && !m->count(attr_def.name())) {
200       SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
201     }
202   }
203 }
204 
205 namespace {
206 
ValueMatchesDefault(const OpDef * op_def,const string & attr_name,const AttrValue & attr_value)207 bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name,
208                          const AttrValue& attr_value) {
209   // TODO(iga): It might make sense to augment OpRegistrationData with a
210   // {attr_name -> default_attr_value} FlatMap to avoid the loop here.
211   for (const OpDef::AttrDef& attr_def : op_def->attr()) {
212     if (attr_def.name() == attr_name && attr_def.has_default_value() &&
213         AreAttrValuesEqual(attr_def.default_value(), attr_value)) {
214       return true;
215     }
216   }
217   return false;
218 }
219 
220 }  // namespace
221 
FillAttrValueMapWithoutDefaults(AttrValueMap * m) const222 void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const {
223   const OpDef* op_def = nullptr;
224   Status s = OpDefForOp(op_name().c_str(), &op_def);
225 
226   for (auto& entry : encoded_attrs_) {
227     attr_tmp_.ParseFromString(entry.second);
228     // Insert the attr-value pair if we did not find the OpDef or if the value
229     // is different from default.
230     if (!s.ok() || !ValueMatchesDefault(op_def, entry.first, attr_tmp_)) {
231       m->insert(AttrValueMap::value_type(entry.first, attr_tmp_));
232     }
233   }
234 }
235 
AddAttrIfNotPresent(StringPiece attr_name,const AttrValue & value)236 void AttrBuilder::AddAttrIfNotPresent(StringPiece attr_name,
237                                       const AttrValue& value) {
238   encoded_attrs_.emplace(string(attr_name), value.SerializeAsString());
239 }
240 
BuildNodeDef()241 const NodeDef& AttrBuilder::BuildNodeDef() {
242   if (node_def_finalized_) return node_def_;
243   if (!node_def_initialized_) {
244     InitializeNodeDef();
245   }
246   for (int i = 0; i < num_inputs_; ++i) {
247     node_def_.add_input("dummy_input");
248   }
249   FillAttrValueMap(node_def_.mutable_attr());
250   node_def_finalized_ = true;
251   return node_def_;
252 }
253 
CopyAttributes(const AttrBuilder & other)254 void AttrBuilder::CopyAttributes(const AttrBuilder& other) {
255   encoded_attrs_.insert(other.encoded_attrs_.begin(),
256                         other.encoded_attrs_.end());
257 }
258 
AttrTypeByName(const AttrTypeMap & m,const string & attr_name,TF_AttrType * out,unsigned char * is_list)259 Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
260                       TF_AttrType* out, unsigned char* is_list) {
261   auto* t = gtl::FindOrNull(m, attr_name);
262   if (t == nullptr) {
263     return errors::InvalidArgument("Attribute '", attr_name,
264                                    "' does not exist for this operation");
265   }
266   *out = static_cast<TF_AttrType>(*t & ~kIsList);
267   if (*t & kIsList) {
268     *is_list = 1;
269   } else {
270     *is_list = 0;
271   }
272   return OkStatus();
273 }
274 
275 namespace {
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)276 inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
277                                                const tensorflow::Fprint128& b) {
278   return {tensorflow::FingerprintCat64(a.low64, b.low64),
279           tensorflow::FingerprintCat64(a.high64, b.high64)};
280 }
281 
CombineUnordered(const tensorflow::Fprint128 & a,tensorflow::Fprint128 * b)282 void CombineUnordered(const tensorflow::Fprint128& a,
283                       tensorflow::Fprint128* b) {
284   b->low64 += a.low64;
285   b->high64 += a.high64;
286 }
287 
CacheKeyHelper(StringPiece s,const tensorflow::Fprint128 & b)288 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
289                                             const tensorflow::Fprint128& b) {
290   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s);
291   return FingerprintCat128(a, b);
292 }
293 
CacheKeyHelper(StringPiece s,uint64 b)294 inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
295   return CacheKeyHelper(s, {b, b});
296 }
297 
298 }  // namespace
299 
CacheKey(const StringPiece device)300 tensorflow::Fprint128 AttrBuilder::CacheKey(const StringPiece device) {
301   if (!cached_cache_key_ || device != device_for_cached_cache_key_) {
302     cached_cache_key_ = BuildCacheKeyForDevice(device);
303     device_for_cached_cache_key_ = string(device);
304   }
305 
306   return *cached_cache_key_;
307 }
308 
BuildCacheKeyForDevice(const StringPiece device) const309 tensorflow::Fprint128 AttrBuilder::BuildCacheKeyForDevice(
310     const StringPiece device) const {
311   tensorflow::Fprint128 f = tensorflow::Fingerprint128(op_name());
312   f = tensorflow::FingerprintCat128(f, tensorflow::Fingerprint128(device));
313   for (const auto& p : encoded_attrs_) {
314     CombineUnordered(
315         CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
316   }
317   return f;
318 }
319 
InitializeNodeDef()320 void AttrBuilder::InitializeNodeDef() {
321   DCHECK(!node_def_initialized_);
322   node_def_.Clear();
323   node_def_.set_name(op_name_);
324   node_def_.set_op(op_name_);
325   node_def_initialized_ = true;
326 }
327 
GetNameAttrList(tensorflow::NameAttrList * name_and_attrs) const328 void AttrBuilder::GetNameAttrList(
329     tensorflow::NameAttrList* name_and_attrs) const {
330   FillAttrValueMap(name_and_attrs->mutable_attr());
331   name_and_attrs->set_name(op_name());
332 }
333 
GetTypeList(absl::string_view attr_name,absl::InlinedVector<DataType,4> * type_list) const334 Status AttrBuilder::GetTypeList(
335     absl::string_view attr_name,
336     absl::InlinedVector<DataType, 4>* type_list) const {
337   return Get(attr_name, type_list);
338 }
339 
GetInt(absl::string_view attr_name,int64_t * result) const340 bool AttrBuilder::GetInt(absl::string_view attr_name, int64_t* result) const {
341   Status s = Get(attr_name, result);
342   return s.ok();
343 }
GetFloat(absl::string_view attr_name,float * result) const344 bool AttrBuilder::GetFloat(absl::string_view attr_name, float* result) const {
345   Status s = Get(attr_name, result);
346   return s.ok();
347 }
GetBool(absl::string_view attr_name,bool * result) const348 bool AttrBuilder::GetBool(absl::string_view attr_name, bool* result) const {
349   Status s = Get(attr_name, result);
350   return s.ok();
351 }
352 
GetType(absl::string_view attr_name,tensorflow::DataType * result) const353 bool AttrBuilder::GetType(absl::string_view attr_name,
354                           tensorflow::DataType* result) const {
355   Status s = Get(attr_name, result);
356   return s.ok();
357 }
358 
359 }  // namespace tensorflow
360