xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/compat/op_compatibility_lib.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/ops/compat/op_compatibility_lib.h"
17 
18 #include <stdio.h>
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/op_def_util.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/io/path.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/protobuf.h"
27 
28 namespace tensorflow {
29 
// Builds the path of the per-op history directory for `history_version`:
// the relative path "compat/ops_history_<history_version>" joined onto
// `ops_prefix`.
static string OpsHistoryDirectory(const string& ops_prefix,
                                  const string& history_version) {
  const string relative_dir =
      strings::StrCat("compat/ops_history_", history_version);
  return io::JoinPath(ops_prefix, relative_dir);
}
35 
// Builds the path of the single-file op history for `history_version`:
// the relative path "compat/ops_history.<history_version>.pbtxt" joined onto
// `ops_prefix`.
static string OpsHistoryFile(const string& ops_prefix,
                             const string& history_version) {
  const string relative_file =
      strings::StrCat("compat/ops_history.", history_version, ".pbtxt");
  return io::JoinPath(ops_prefix, relative_file);
}
41 
// Returns the history file name used for the op called `op_name`, i.e. the op
// name with a ".pbtxt" suffix appended.
static string FileNameFromOpName(const string& op_name) {
  static constexpr char kHistoryFileSuffix[] = ".pbtxt";
  return strings::StrCat(op_name, kHistoryFileSuffix);
}
45 
AddNewOpToHistory(const OpDef & op,OpCompatibilityLib::OpHistory * out_op_history)46 static void AddNewOpToHistory(const OpDef& op,
47                               OpCompatibilityLib::OpHistory* out_op_history) {
48   if (out_op_history != nullptr) {
49     out_op_history->emplace_back(FileNameFromOpName(op.name()), OpList());
50     *out_op_history->back().second.add_op() = op;
51   }
52 }
53 
ReadOpHistory(Env * env,const string & file,const string & directory,OpCompatibilityLib::OpHistory * out)54 static Status ReadOpHistory(Env* env, const string& file,
55                             const string& directory,
56                             OpCompatibilityLib::OpHistory* out) {
57   // Read op history form `directory` if it exists there.
58   std::vector<string> matching_files;
59   Status status = env->GetMatchingPaths(io::JoinPath(directory, "*.pbtxt"),
60                                         &matching_files);
61   if (status.ok() && !matching_files.empty()) {
62     printf("Reading op history from %s/*.pbtxt...\n", directory.c_str());
63     std::sort(matching_files.begin(), matching_files.end());
64     for (const string& full_file : matching_files) {
65       string op_history_str;
66       TF_RETURN_IF_ERROR(ReadFileToString(env, full_file, &op_history_str));
67       OpList in_op_history;
68       protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
69       const string file_tail = FileNameFromOpName(in_op_history.op(0).name());
70       const string expected = io::JoinPath(directory, file_tail);
71       if (full_file != expected) {
72         return errors::Internal("Expected file paths to match but '", full_file,
73                                 "' != '", expected, "'");
74       }
75       out->emplace_back(file_tail, in_op_history);
76     }
77   } else {  // Otherwise, fall back to reading op history from `file`.
78     printf("Reading op history from %s...\n", file.c_str());
79     string op_history_str;
80     TF_RETURN_IF_ERROR(ReadFileToString(env, file, &op_history_str));
81     OpList in_op_history;
82     protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
83     // Convert from a linear OpList to OpHistory format with one OpList per
84     // unique op name.
85     int start = 0;
86     while (start < in_op_history.op_size()) {
87       int end = start + 1;
88       while (end < in_op_history.op_size() &&
89              in_op_history.op(start).name() == in_op_history.op(end).name()) {
90         ++end;
91       }
92       AddNewOpToHistory(in_op_history.op(start), out);
93       for (++start; start < end; ++start) {
94         *out->back().second.add_op() = in_op_history.op(start);
95       }
96     }
97   }
98   return OkStatus();
99 }
100 
OpCompatibilityLib(const string & ops_prefix,const string & history_version,const std::set<string> * stable_ops)101 OpCompatibilityLib::OpCompatibilityLib(const string& ops_prefix,
102                                        const string& history_version,
103                                        const std::set<string>* stable_ops)
104     : ops_file_(io::JoinPath(ops_prefix, "ops.pbtxt")),
105       op_history_file_(OpsHistoryFile(ops_prefix, history_version)),
106       op_history_directory_(OpsHistoryDirectory(ops_prefix, history_version)),
107       stable_ops_(stable_ops) {
108   // Get the sorted list of all registered OpDefs.
109   printf("Getting all registered ops...\n");
110   OpRegistry::Global()->Export(false, &op_list_);
111 }
112 
ValidateCompatible(Env * env,int * changed_ops,int * added_ops,OpHistory * out_op_history)113 Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
114                                               int* added_ops,
115                                               OpHistory* out_op_history) {
116   *changed_ops = 0;
117   *added_ops = 0;
118 
119   // Strip docs out of op_list_.
120   RemoveDescriptionsFromOpList(&op_list_);
121 
122   if (stable_ops_ != nullptr) {
123     printf("Verifying no stable ops have been removed...\n");
124     std::vector<string> removed;
125     // We rely on stable_ops_ and op_list_ being in sorted order.
126     auto iter = stable_ops_->begin();
127     for (int cur = 0; iter != stable_ops_->end() && cur < op_list_.op_size();
128          ++cur) {
129       const string& op_name = op_list_.op(cur).name();
130       while (op_name > *iter) {
131         removed.push_back(*iter);
132         ++iter;
133       }
134       if (op_name == *iter) {
135         ++iter;
136       }
137     }
138     for (; iter != stable_ops_->end(); ++iter) {
139       removed.push_back(*iter);
140     }
141     if (!removed.empty()) {
142       return errors::InvalidArgument("Error, stable op(s) removed: ",
143                                      absl::StrJoin(removed, ", "));
144     }
145   }
146 
147   OpHistory in_op_history;
148   TF_RETURN_IF_ERROR(ReadOpHistory(env, op_history_file_, op_history_directory_,
149                                    &in_op_history));
150 
151   int cur = 0;
152   int hist = 0;
153 
154   printf("Verifying updates are compatible...\n");
155   // Note: Op history is one OpList per unique op name in alphabetical order.
156   // Within the OplList it has versions in oldest-first order.
157   while (cur < op_list_.op_size() && hist < in_op_history.size()) {
158     const OpDef& cur_op = op_list_.op(cur);
159     const string& cur_op_name = cur_op.name();
160     const OpList& history_op_list = in_op_history[hist].second;
161     const string& history_op_name = history_op_list.op(0).name();
162     if (stable_ops_ != nullptr && stable_ops_->count(cur_op_name) == 0) {
163       // Ignore unstable op.
164       for (++cur; cur < op_list_.op_size(); ++cur) {
165         if (op_list_.op(cur).name() != cur_op_name) break;
166       }
167     } else if (cur_op_name < history_op_name) {
168       // New op: add it.
169       AddNewOpToHistory(cur_op, out_op_history);
170       ++*added_ops;
171       ++cur;
172     } else if (cur_op_name > history_op_name) {
173       if (stable_ops_ != nullptr) {
174         // Okay to remove ops from the history that have been made unstable.
175         ++hist;
176       } else {
177         // Op removed: error.
178         return errors::InvalidArgument("Error, removed op: ",
179                                        SummarizeOpDef(history_op_list.op(0)));
180       }
181     } else {
182       // Op match.
183       if (out_op_history != nullptr) {
184         // Copy from in_op_history to *out_op_history.
185         out_op_history->push_back(in_op_history[hist]);
186       }
187 
188       const int end = history_op_list.op_size();
189       // Is the last op in the history the same as the current op?
190       // Compare using their serialized representations.
191       string history_str, cur_str;
192       history_op_list.op(end - 1).SerializeToString(&history_str);
193       cur_op.SerializeToString(&cur_str);
194 
195       if (history_str != cur_str) {
196         // Op changed, verify the change is compatible.
197         for (int i = 0; i < end; ++i) {
198           TF_RETURN_IF_ERROR(OpDefCompatible(history_op_list.op(i), cur_op));
199         }
200 
201         // Verify default value of attrs has not been removed or modified
202         // as compared to only the last historical version.
203         TF_RETURN_IF_ERROR(
204             OpDefAttrDefaultsUnchanged(history_op_list.op(end - 1), cur_op));
205 
206         // Check that attrs missing from history_op_list.op(0) don't change
207         // their defaults.
208         if (end > 1) {
209           TF_RETURN_IF_ERROR(OpDefAddedDefaultsUnchanged(
210               history_op_list.op(0), history_op_list.op(end - 1), cur_op));
211         }
212 
213         // Compatible! Add changed op to the end of the history.
214         if (out_op_history != nullptr) {
215           *out_op_history->back().second.add_op() = cur_op;
216         }
217         ++*changed_ops;
218       }
219 
220       // Advance past this op.
221       ++hist;
222       ++cur;
223     }
224   }
225 
226   // Error if missing ops.
227   if (stable_ops_ == nullptr && hist < in_op_history.size()) {
228     return errors::InvalidArgument(
229         "Error, removed op: ",
230         SummarizeOpDef(in_op_history[hist].second.op(0)));
231   }
232 
233   // Add remaining new ops.
234   for (; cur < op_list_.op_size(); ++cur) {
235     const string& op_name = op_list_.op(cur).name();
236     if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
237       // Ignore unstable op.
238     } else {
239       AddNewOpToHistory(op_list_.op(cur), out_op_history);
240       ++*added_ops;
241     }
242   }
243 
244   return OkStatus();
245 }
246 
247 }  // namespace tensorflow
248