1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ 18 19 #include <functional> 20 #include <string> 21 #include <vector> 22 23 #include "tensorflow/core/framework/partial_tensor_shape.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/tensor_shape.h" 26 #include "tensorflow/core/framework/types.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/core/stringpiece.h" 29 #include "tensorflow/core/lib/gtl/array_slice.h" 30 31 namespace tensorflow { 32 33 // Forward declare protos so their symbols can be removed from .so exports 34 class AttrValue; 35 class NameAttrList; 36 37 // A human-readable rendering of attr_value, that is more concise than a 38 // text-format proto. 39 std::string SummarizeAttrValue(const AttrValue& attr_value); 40 41 // Generates an error if attr_value doesn't have the indicated attr type. 42 Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); 43 44 // Converts a text proto value from "text" into the field of *out 45 // indicated by "type" (e.g. from the type field of an AttrDef). 46 // Examples: 47 // * If type:"int" and text:"-14", then *out is set to "i: -14" 48 // * If type:"list(string)" and text:"['foo', 'bar']", 49 // then *out is set to "list { s: ['foo', 'bar'] }" 50 // Returns true on success. 51 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out); 52 53 // Sets *out based on the type of value. 54 void SetAttrValue(const std::string& value, AttrValue* out); 55 void SetAttrValue(const tstring& value, AttrValue* out); 56 void SetAttrValue(const char* value, AttrValue* out); 57 void SetAttrValue(StringPiece value, AttrValue* out); 58 void SetAttrValue(int64_t value, AttrValue* out); 59 void SetAttrValue(int32_t value, AttrValue* out); 60 void SetAttrValue(float value, AttrValue* out); 61 void SetAttrValue(double value, AttrValue* out); 62 void SetAttrValue(bool value, AttrValue* out); 63 void SetAttrValue(DataType value, AttrValue* out); 64 void SetAttrValue(const TensorShape& value, AttrValue* out); 65 void SetAttrValue(const TensorShapeProto& value, AttrValue* out); 66 void SetAttrValue(const PartialTensorShape& value, AttrValue* out); 67 void SetAttrValue(const Tensor& value, AttrValue* out); 68 void SetAttrValue(const TensorProto& value, AttrValue* out); 69 void SetAttrValue(const NameAttrList& value, AttrValue* out); 70 71 void SetAttrValue(gtl::ArraySlice<string> value, AttrValue* out); 72 void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out); 73 void SetAttrValue(gtl::ArraySlice<const char*> value, AttrValue* out); 74 void SetAttrValue(gtl::ArraySlice<StringPiece> value, AttrValue* out); 75 void SetAttrValue(gtl::ArraySlice<int64_t> value, AttrValue* out); 76 void SetAttrValue(gtl::ArraySlice<int32> value, AttrValue* out); 77 void SetAttrValue(gtl::ArraySlice<float> value, AttrValue* out); 78 void SetAttrValue(gtl::ArraySlice<double> value, AttrValue* out); 79 void SetAttrValue(gtl::ArraySlice<bool> value, AttrValue* out); 80 void SetAttrValue(const std::vector<bool>& value, AttrValue* out); 81 void SetAttrValue(std::initializer_list<bool> value, AttrValue* out); 82 void SetAttrValue(DataTypeSlice value, AttrValue* out); 83 void SetAttrValue(gtl::ArraySlice<TensorShape> value, AttrValue* out); 84 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out); 85 void SetAttrValue(gtl::ArraySlice<PartialTensorShape> value, AttrValue* out); 86 void SetAttrValue(gtl::ArraySlice<Tensor> value, AttrValue* out); 87 void SetAttrValue(gtl::ArraySlice<TensorProto> value, AttrValue* out); 88 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out); 89 90 void SetAttrValue(const AttrValue& value, AttrValue* out); 91 92 void MoveAttrValue(std::vector<string>&& value, AttrValue* out); 93 94 // Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other 95 // words, if two AttrValues compare equal according to AreAttrValuesEqual, 96 // they will have the same hash value. 97 // Similarly to protobuf deterministic serialization, hash value is 98 // guaranteed to be stable only for a given binary. In particular, one should 99 // probably not persist the returned value. 100 uint64 AttrValueHash(const AttrValue& a); 101 102 // WARNING: Equality check might return false-negative for large (> 32mb) 103 // tensors defined with different TensorProto representations. 104 // 105 // A pair of consistent hash and equals functions that are guaranteed to be fast 106 // with AttrValues that potentially can have very large Tensors (larger than 107 // 32mb) defined by TensorProto. If large identical Tensors are defined using 108 // different representations (e.g. one with tensor content, and second with 109 // bool_val), they will have different hash code and equals will return false. 110 // Small (less than 32mb) tensors with different TensorProto representations 111 // hashed/compared by their tensor content. 112 uint64 FastAttrValueHash(const AttrValue& a); 113 // Returns true if a and b have the same value. If false negatives are allowed, 114 // then compares proto representation to avoid construction of large (> 32mb) 115 // tensors. 116 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, 117 bool allow_false_negatives = false); 118 119 // Returns true if "val" has a placeholder. 120 bool HasPlaceHolder(const AttrValue& val); 121 122 // SubstitutePlaceholders recursively replaces placeholders in 'value' 123 // with an attr value by calling SubstituteFunc. Returns true iff all 124 // placeholders in "value" are replaced with a value. 125 // 126 // SubstituteFunc is given a placeholder string. If the placeholder is 127 // unknown, SubstituteFunc returns false. Otherwise, overwrites the 128 // attr value and returns true. 129 using SubstituteFunc = std::function<bool(const string&, AttrValue*)>; 130 bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value); 131 132 } // namespace tensorflow 133 134 #endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ 135