1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/container/inlined_vector.h" 25 #include "tensorflow/core/framework/attr_value.pb.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/framework/function.pb.h" 28 #include "tensorflow/core/framework/node_def_util.h" 29 #include "tensorflow/core/framework/op_def.pb.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 #include "tensorflow/core/lib/gtl/flatset.h" 32 33 namespace tensorflow { 34 namespace grappler { 35 36 // Function input argument instantiated into an '_Arg' node in the function body 37 // graph, with an 'index' attribute corresponding to the input position. 38 struct InputArgInstantiation { InputArgInstantiationInputArgInstantiation39 InputArgInstantiation(string node_name, DataType data_type) 40 : node_name(std::move(node_name)), data_type(data_type) {} 41 string node_name; 42 DataType data_type; 43 }; 44 45 // Function output instantiated into a '_Retval' node in the function body 46 // graph, with an 'index' attribute corresponding to the output position. 47 struct OutputArgInstantiation { OutputArgInstantiationOutputArgInstantiation48 OutputArgInstantiation(string node_name, DataType data_type) 49 : node_name(std::move(node_name)), data_type(data_type) {} 50 string node_name; 51 DataType data_type; 52 }; 53 54 // A mapping from control output name to node name in function body graph. 55 struct ControlOutput { 56 string output_name; 57 string node_name; 58 bool operator<(const ControlOutput& a) const { 59 return output_name < a.output_name; 60 } 61 }; 62 63 // A special case of GrapplerItem, constructed from a TensorFlow Function. 64 class GrapplerFunctionItem : public GrapplerItem { 65 public: 66 GrapplerFunctionItem() = default; 67 68 const string& description() const; 69 70 const std::vector<InputArgInstantiation>& inputs() const; 71 const InputArgInstantiation& input(int i) const; 72 const std::size_t input_size() const; 73 74 const std::vector<OutputArgInstantiation>& outputs() const; 75 const OutputArgInstantiation& output(int i) const; 76 const std::size_t output_size() const; 77 78 const std::vector<ControlOutput>& control_outputs() const; 79 const std::size_t control_output_size() const; 80 81 const AttrSlice& func_attr() const; 82 const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const; 83 const GraphDef& function_body() const; 84 GraphDef& mutable_function_body(); 85 86 bool is_stateful() const; 87 88 GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); 89 90 private: 91 friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&, 92 const FunctionLibraryDefinition&, int, 93 GrapplerFunctionItem*); 94 friend Status ReplaceInputWithConst(const NodeDef&, int, 95 GrapplerFunctionItem*); 96 friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&, 97 GrapplerFunctionItem*, 98 std::vector<std::pair<int, int>>*); 99 100 GrapplerFunctionItem(string func_name, string description, 101 AttrSlice func_attr, 102 std::vector<const FunctionDef::ArgAttrs*> arg_attr, 103 std::vector<InputArgInstantiation> input_args, 104 std::vector<OutputArgInstantiation> output_args, 105 std::vector<ControlOutput> control_outputs, 106 int graph_def_version, bool is_stateful, 107 GraphDef&& function_body); 108 109 string description_; 110 AttrSlice func_attr_; // Attributes specific to function definition that 111 // produced this item (FuncDef.attr field). 112 113 // Attributes of function arguments 114 std::vector<const FunctionDef::ArgAttrs*> arg_attr_; 115 116 std::vector<InputArgInstantiation> input_args_; 117 std::vector<OutputArgInstantiation> output_args_; 118 std::vector<ControlOutput> control_outputs_; 119 120 bool is_stateful_ = false; 121 }; 122 123 // Check if function input/output types are fully defined only at instantiation 124 // time (parametrized by its instantiation node). 125 bool HasParametrizedType(const FunctionDef& func); 126 127 // Check if a function body is parametrized by its instantiation node. Function 128 // body is parametrized, if it has at least one node with a 'placeholder' 129 // attribute. 130 bool HasParametrizedBody(const FunctionDef& func); 131 132 // Check if function has parametrized type or body. 133 bool IsParametrized(const FunctionDef& func); 134 135 // Resolve function instantiation type parameters from the attributes of the 136 // caller node. Return error if type can't be resolved. 137 Status InstantiationTypeParameters( 138 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 139 absl::flat_hash_map<string, DataType>* type_parameters); 140 141 // Resolve function instantiation body parameters (values for the function body 142 // attr placeholders) from the attributes of the caller node. Return error if 143 // type can't be resolved. 144 Status InstantiationBodyParameters( 145 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 146 absl::flat_hash_map<string, AttrValue>* body_parameters); 147 148 // Replace one of the function inputs with a constant. 149 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, 150 GrapplerFunctionItem* item); 151 152 // Removes outputs from instantiated grappler function item. For all active 153 // function outputs that changed its output index, this function adds an output 154 // mapping (std::pair<old index, new index>). 155 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs, 156 GrapplerFunctionItem* item, 157 std::vector<std::pair<int, int>>* output_mapping); 158 159 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs. 160 161 // Make a GrapplerFunctionItem from the function definition and function 162 // instantiation attributes (caller node attributes). Returns error if the given 163 // function def cannot be converted (e.g. not all attributes are defined). 164 Status MakeGrapplerFunctionItem(const FunctionDef& func, 165 const AttrSlice& func_instantiation_attr, 166 const FunctionLibraryDefinition& flib, 167 int graph_def_version, 168 GrapplerFunctionItem* item); 169 170 // Make a GrapplerFunction item from the function definition. Function must be 171 // fully defined (no type or body parametrization). 172 // TODO(ezhulenev): Support parametrized functions without fully defined 173 // instantiation attributes? Do we ever want to optimize parametrized function 174 // without specializing it to its instantiation attributes (at least types)? 175 Status MakeGrapplerFunctionItem(const FunctionDef& func, 176 const FunctionLibraryDefinition& flib, 177 int graph_def_version, 178 GrapplerFunctionItem* item); 179 180 // Make a FunctionDef from the GrapplerFunctionItem. Use function library 181 // definition to lookup function body nodes output names and ranges. 182 Status MakeFunctionDef(const GrapplerFunctionItem& item, 183 const FunctionLibraryDefinition& flib, 184 FunctionDef* func); 185 186 } // end namespace grappler 187 } // end namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 190