xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/graph_transforms/backports.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/constant_folding.h"
17 #include "tensorflow/core/common_runtime/graph_constructor.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/graph/subgraph.h"
20 #include "tensorflow/core/platform/init_main.h"
21 #include "tensorflow/core/public/session.h"
22 #include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
23 #include "tensorflow/tools/graph_transforms/transform_utils.h"
24 
25 namespace tensorflow {
26 namespace graph_transforms {
27 
28 // Switch any ConcatV2 nodes to the v1 version, swapping the input order.
BackportConcatV2Transform(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)29 Status BackportConcatV2Transform(const GraphDef& input_graph_def,
30                                  const TransformFuncContext& context,
31                                  GraphDef* output_graph_def) {
32   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
33       input_graph_def, {"ConcatV2"},
34       [](const NodeMatch& match, const std::set<string>& input_nodes,
35          const std::set<string>& output_nodes,
36          std::vector<NodeDef>* new_nodes) {
37         const NodeDef& concat_v2_node = match.node;
38         NodeDef concat_node = concat_v2_node;
39         concat_node.set_op("Concat");
40         // The last input is inserted at the head of the inputs, because Concat
41         // expects the dimension as the first input (not the last as in
42         // ConcatV2).
43         concat_node.mutable_input()->Clear();
44         const string& dim_input =
45             concat_v2_node.input(concat_v2_node.input_size() - 1);
46         concat_node.add_input(dim_input);
47         for (int i = 0; i < (concat_v2_node.input_size() - 1); ++i) {
48           concat_node.add_input(concat_v2_node.input(i));
49         }
50         // Tidx attribute must be deleted because it's not used in Concat.
51         concat_node.mutable_attr()->erase("Tidx");
52         new_nodes->push_back(concat_node);
53         return OkStatus();
54       },
55       {true}, output_graph_def));
56 
57   return OkStatus();
58 }
59 
60 REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform);
61 
62 // Switch any TensorArrayV3 nodes to the v2 version, removing the second output.
BackportTensorArrayV3Transform(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)63 Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
64                                       const TransformFuncContext& context,
65                                       GraphDef* output_graph_def) {
66   std::map<string, string> inputs_to_rename;
67   GraphDef replaced_graph_def;
68   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
69       input_graph_def, {"TensorArrayV3|TensorArrayGradV3"},
70       [&inputs_to_rename](const NodeMatch& match,
71                           const std::set<string>& input_nodes,
72                           const std::set<string>& output_nodes,
73                           std::vector<NodeDef>* new_nodes) {
74         const NodeDef& tensor_array_v3_node = match.node;
75 
76         // All we need to do here is rename the op type, since the attributes
77         // remain the same.
78         NodeDef tensor_array_v2_node = tensor_array_v3_node;
79         if (tensor_array_v3_node.op() == "TensorArrayV3") {
80           tensor_array_v2_node.set_op("TensorArrayV2");
81         } else {
82           tensor_array_v2_node.set_op("TensorArrayGradV2");
83         }
84 
85         // The v3 version has a second 'flow' output that's not present in v2,
86         // so substitute a dummy constant instead in any places that use it.
87         NodeDef replacement_flow_node;
88         replacement_flow_node.set_op("Const");
89         SetNodeAttr("dtype", DT_FLOAT, &replacement_flow_node);
90         replacement_flow_node.set_name(tensor_array_v3_node.name() +
91                                        "/replacement_flow_node");
92         Tensor replacement_flow_tensor(DT_FLOAT, {});
93         // I'm picking an arbitrary value for the gradient flow here, for lack
94         // of a better alternative.
95         replacement_flow_tensor.flat<float>()(0) = 1.0f;
96         SetNodeTensorAttr<float>("value", replacement_flow_tensor,
97                                  &replacement_flow_node);
98         inputs_to_rename[tensor_array_v3_node.name() + ":1"] =
99             replacement_flow_node.name();
100 
101         new_nodes->push_back(tensor_array_v2_node);
102         new_nodes->push_back(replacement_flow_node);
103         return OkStatus();
104       },
105       {true}, &replaced_graph_def));
106   // Update the graph so that any nodes that referred to removed inputs now
107   // pull from the substitute constants we've added.
108   GraphDef renamed_graph_def;
109   TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
110                                       std::unordered_set<string>(),
111                                       &renamed_graph_def));
112   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
113       renamed_graph_def,
114       {"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|"
115        "TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|"
116        "TensorArraySizeV3|TensorArrayCloseV3"},
117       [](const NodeMatch& match, const std::set<string>& input_nodes,
118          const std::set<string>& output_nodes,
119          std::vector<NodeDef>* new_nodes) {
120         const NodeDef& v3_node = match.node;
121         NodeDef v2_node = v3_node;
122         v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2");
123         new_nodes->push_back(v2_node);
124         return OkStatus();
125       },
126       {true}, output_graph_def));
127   return OkStatus();
128 }
129 
130 REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3",
131                          BackportTensorArrayV3Transform);
132 
133 }  // namespace graph_transforms
134 }  // namespace tensorflow
135