xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/utils/functions.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
18 
19 #include <memory>
20 #include <string>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/container/inlined_vector.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/function.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 #include "tensorflow/core/lib/gtl/flatset.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 
36 // Function input argument instantiated into an '_Arg' node in the function body
37 // graph, with an 'index' attribute corresponding to the input position.
38 struct InputArgInstantiation {
InputArgInstantiationInputArgInstantiation39   InputArgInstantiation(string node_name, DataType data_type)
40       : node_name(std::move(node_name)), data_type(data_type) {}
41   string node_name;
42   DataType data_type;
43 };
44 
45 // Function output instantiated into a '_Retval' node in the function body
46 // graph, with an 'index' attribute corresponding to the output position.
47 struct OutputArgInstantiation {
OutputArgInstantiationOutputArgInstantiation48   OutputArgInstantiation(string node_name, DataType data_type)
49       : node_name(std::move(node_name)), data_type(data_type) {}
50   string node_name;
51   DataType data_type;
52 };
53 
// Associates a function control output name with the name of a node in the
// function body graph.
//
// Ordering compares `output_name` only; `node_name` plays no part in it, so two
// entries with equal output names compare as equivalent.
struct ControlOutput {
  std::string output_name;
  std::string node_name;

  bool operator<(const ControlOutput& other) const {
    return output_name.compare(other.output_name) < 0;
  }
};
62 
63 // A special case of GrapplerItem, constructed from a TensorFlow Function.
64 class GrapplerFunctionItem : public GrapplerItem {
65  public:
66   GrapplerFunctionItem() = default;
67 
68   const string& description() const;
69 
70   const std::vector<InputArgInstantiation>& inputs() const;
71   const InputArgInstantiation& input(int i) const;
72   const std::size_t input_size() const;
73 
74   const std::vector<OutputArgInstantiation>& outputs() const;
75   const OutputArgInstantiation& output(int i) const;
76   const std::size_t output_size() const;
77 
78   const std::vector<ControlOutput>& control_outputs() const;
79   const std::size_t control_output_size() const;
80 
81   const AttrSlice& func_attr() const;
82   const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const;
83   const GraphDef& function_body() const;
84   GraphDef& mutable_function_body();
85 
86   bool is_stateful() const;
87 
88   GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
89 
90  private:
91   friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&,
92                                          const FunctionLibraryDefinition&, int,
93                                          GrapplerFunctionItem*);
94   friend Status ReplaceInputWithConst(const NodeDef&, int,
95                                       GrapplerFunctionItem*);
96   friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&,
97                                       GrapplerFunctionItem*,
98                                       std::vector<std::pair<int, int>>*);
99 
100   GrapplerFunctionItem(string func_name, string description,
101                        AttrSlice func_attr,
102                        std::vector<const FunctionDef::ArgAttrs*> arg_attr,
103                        std::vector<InputArgInstantiation> input_args,
104                        std::vector<OutputArgInstantiation> output_args,
105                        std::vector<ControlOutput> control_outputs,
106                        int graph_def_version, bool is_stateful,
107                        GraphDef&& function_body);
108 
109   string description_;
110   AttrSlice func_attr_;  // Attributes specific to function definition that
111                          // produced this item (FuncDef.attr field).
112 
113   // Attributes of function arguments
114   std::vector<const FunctionDef::ArgAttrs*> arg_attr_;
115 
116   std::vector<InputArgInstantiation> input_args_;
117   std::vector<OutputArgInstantiation> output_args_;
118   std::vector<ControlOutput> control_outputs_;
119 
120   bool is_stateful_ = false;
121 };
122 
123 // Check if function input/output types are fully defined only at instantiation
124 // time (parametrized by its instantiation node).
125 bool HasParametrizedType(const FunctionDef& func);
126 
127 // Check if a function body is parametrized by its instantiation node. Function
128 // body is parametrized, if it has at least one node with a 'placeholder'
129 // attribute.
130 bool HasParametrizedBody(const FunctionDef& func);
131 
132 // Check if function has parametrized type or body.
133 bool IsParametrized(const FunctionDef& func);
134 
135 // Resolve function instantiation type parameters from the attributes of the
136 // caller node. Return error if type can't be resolved.
137 Status InstantiationTypeParameters(
138     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
139     absl::flat_hash_map<string, DataType>* type_parameters);
140 
141 // Resolve function instantiation body parameters (values for the function body
142 // attr placeholders) from the attributes of the caller node. Return error if
143 // type can't be resolved.
144 Status InstantiationBodyParameters(
145     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
146     absl::flat_hash_map<string, AttrValue>* body_parameters);
147 
148 // Replace one of the function inputs with a constant.
149 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
150                              GrapplerFunctionItem* item);
151 
152 // Removes outputs from instantiated grappler function item. For all active
153 // function outputs that changed its output index, this function adds an output
154 // mapping (std::pair<old index, new index>).
155 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
156                              GrapplerFunctionItem* item,
157                              std::vector<std::pair<int, int>>* output_mapping);
158 
159 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs.
160 
161 // Make a GrapplerFunctionItem from the function definition and function
162 // instantiation attributes (caller node attributes). Returns error if the given
163 // function def cannot be converted (e.g. not all attributes are defined).
164 Status MakeGrapplerFunctionItem(const FunctionDef& func,
165                                 const AttrSlice& func_instantiation_attr,
166                                 const FunctionLibraryDefinition& flib,
167                                 int graph_def_version,
168                                 GrapplerFunctionItem* item);
169 
170 // Make a GrapplerFunction item from the function definition. Function must be
171 // fully defined (no type or body parametrization).
172 // TODO(ezhulenev): Support parametrized functions without fully defined
173 // instantiation attributes? Do we ever want to optimize parametrized function
174 // without specializing it to its instantiation attributes (at least types)?
175 Status MakeGrapplerFunctionItem(const FunctionDef& func,
176                                 const FunctionLibraryDefinition& flib,
177                                 int graph_def_version,
178                                 GrapplerFunctionItem* item);
179 
180 // Make a FunctionDef from the GrapplerFunctionItem. Use function library
181 // definition to lookup function body nodes output names and ranges.
182 Status MakeFunctionDef(const GrapplerFunctionItem& item,
183                        const FunctionLibraryDefinition& flib,
184                        FunctionDef* func);
185 
186 }  // end namespace grappler
187 }  // end namespace tensorflow
188 
189 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
190