xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/schema_info.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/dispatch/Dispatcher.h>
2 #include <torch/csrc/utils/schema_info.h>
3 
4 namespace torch::utils {
addArgumentValue(const std::string & name,const at::IValue & value)5 void SchemaInfo::addArgumentValue(
6     const std::string& name,
7     const at::IValue& value) {
8   std::optional<int> index = schema_.argumentIndexWithName(name);
9   TORCH_INTERNAL_ASSERT(
10       index != std::nullopt, "Schema has no argument named ", name);
11   value_map_[name] = value;
12   alias_maps_current_ = false;
13 }
14 
addArgumentValues(const std::vector<std::optional<at::IValue>> & value_list)15 void SchemaInfo::addArgumentValues(
16     const std::vector<std::optional<at::IValue>>& value_list) {
17   TORCH_INTERNAL_ASSERT(
18       value_list.size() <= schema_.arguments().size(),
19       "Schema does not have enough arguments for value list");
20 
21   for (size_t i = 0; i < value_list.size(); i++) {
22     if (value_list[i].has_value()) {
23       // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
24       value_map_[schema_.arguments()[i].name()] = *value_list[i];
25       alias_maps_current_ = false;
26     }
27   }
28 }
29 
addArgumentValues(const std::unordered_map<std::string,at::IValue> & values)30 void SchemaInfo::addArgumentValues(
31     const std::unordered_map<std::string, at::IValue>& values) {
32   for (const auto& key_pair : values) {
33     addArgumentValue(key_pair.first, key_pair.second);
34   }
35 }
36 
hasInputArgumentNamed(const std::string & name) const37 bool SchemaInfo::hasInputArgumentNamed(const std::string& name) const {
38   return std::any_of(
39       schema_.arguments().begin(),
40       schema_.arguments().end(),
41       [&name](const c10::Argument& arg) { return arg.name() == name; });
42 }
43 
is_mutable()44 bool SchemaInfo::is_mutable() {
45   for (size_t i = 0; i < schema_.arguments().size(); i++) {
46     if (is_mutable({c10::SchemaArgType::input, i})) {
47       return true;
48     }
49   }
50   return false;
51 }
52 
// Returns whether the given argument may be mutated by this operator.
// Beyond the schema's static mutability annotations, this handles the
// "training op" special case where the mutability of buffers such as
// running_mean/running_var/noise depends on the runtime value of a
// boolean flag argument ("training", "train" or "use_input_stats").
bool SchemaInfo::is_mutable(const c10::SchemaArgument& argument) {
  TORCH_INTERNAL_ASSERT(
      argument.index < schema_.getCorrectList(argument.type).size(),
      "Invalid index for schema.");
  if (!alias_maps_current_) {
    generateAliasMaps();
  }
  static const std::vector<SchemaSpecialCasePair> training_ops =
      getTrainingOps();
  const auto& correct_map = (argument.type == c10::SchemaArgType::input)
      ? input_alias_map_
      : output_alias_map_;
  // Note that the training_op checks depend on index because
  // of cases where either running_mean or running_var alias another input
  // argument causing its alias status to change.
  return std::any_of(
      correct_map[argument.index].begin(),
      correct_map[argument.index].end(),
      [this](size_t aliasing_index) {
        // Is this schema one of the special-cased training ops?
        const auto is_training_op = std::find_if(
            training_ops.begin(),
            training_ops.end(),
            [this](const auto& training_op) {
              return this->schema_ == training_op.first;
            });

        // Special-case only applies when the aliasing argument is one of
        // the flag-dependent buffers listed for this training op.
        bool special_case = (is_training_op != training_ops.end()) &&
            is_training_op->second.count(
                this->schema_.arguments()[aliasing_index].name());
        if (special_case) {
          // Each flag counts as "set" when the argument exists but its
          // value is unknown (conservative) or when it is known true.
          bool has_training = (hasInputArgumentNamed("training") &&
                               !value_map_.count("training")) ||
              (value_map_.count("training") &&
               value_map_.at("training").toBool());
          bool has_train =
              (hasInputArgumentNamed("train") && !value_map_.count("train")) ||
              (value_map_.count("train") && value_map_.at("train").toBool());
          bool has_use_input_stats =
              (hasInputArgumentNamed("use_input_stats") &&
               !value_map_.count("use_input_stats")) ||
              (value_map_.count("use_input_stats") &&
               value_map_.at("use_input_stats").toBool());
          return has_training || has_train || has_use_input_stats;
        } else {
          // Default path: defer to the schema's static mutability
          // annotation for the aliased input.
          return this->schema_.is_mutable(
              {c10::SchemaArgType::input, aliasing_index});
        }
      });
}
102 
has_argument(c10::string_view name)103 bool SchemaInfo::has_argument(c10::string_view name) {
104   return schema_.argumentIndexWithName(name) != std::nullopt;
105 }
106 
is_mutable(c10::string_view name)107 bool SchemaInfo::is_mutable(c10::string_view name) {
108   std::optional<int> index = schema_.argumentIndexWithName(name);
109   TORCH_INTERNAL_ASSERT(
110       index.has_value(), "Schema has no argument named ", name);
111 
112   return is_mutable({c10::SchemaArgType::input, static_cast<size_t>(*index)});
113 }
114 
is_nondeterministic() const115 bool SchemaInfo::is_nondeterministic() const {
116   static const c10::FunctionSchema dropout_schema = torch::jit::parseSchema(
117       "aten::dropout(Tensor input, float p, bool train) -> Tensor");
118   if (dropout_schema == schema_ && value_map_.count("train") &&
119       !value_map_.at("train").toBool()) {
120     return false;
121   }
122 
123 #if defined C10_MOBILE
124   static const std::vector<c10::FunctionSchema> nondeterministic_ops =
125       getNonDeterministicOps();
126   return std::any_of(
127       nondeterministic_ops.begin(),
128       nondeterministic_ops.end(),
129       [this](const c10 ::FunctionSchema& nondeterministic_op) {
130         return nondeterministic_op == this->schema_;
131       });
132 #else
133   const auto& op = c10::Dispatcher::singleton().findOp(
134       c10::OperatorName(schema_.name(), schema_.overload_name()));
135   return op && op->hasTag(at::Tag::nondeterministic_seeded);
136 #endif
137 }
138 
may_alias(const c10::SchemaArgument & lhs,const c10::SchemaArgument & rhs)139 bool SchemaInfo::may_alias(
140     const c10::SchemaArgument& lhs,
141     const c10::SchemaArgument& rhs) {
142   bool basic_check = schema_.may_alias(lhs, rhs);
143   if (basic_check) {
144     return true;
145   }
146   std::optional<c10::AliasTypeSet> lhsAliasTypeSet =
147       schema_.mapTypeToAliasTypeSet(
148           schema_.getCorrectList(lhs.type)[lhs.index].type());
149   std::optional<c10::AliasTypeSet> rhsAliasTypeSet =
150       schema_.mapTypeToAliasTypeSet(
151           schema_.getCorrectList(rhs.type)[rhs.index].type());
152   bool types_can_alias =
153       schema_.canAliasTypeSetsAlias(lhsAliasTypeSet, rhsAliasTypeSet);
154   if (!types_can_alias) {
155     return false;
156   }
157 
158   if (!alias_maps_current_) {
159     generateAliasMaps();
160   }
161   bool wildcard_alias_check =
162       wildcardSet().count(lhs) && wildcardSet().count(rhs);
163   if (wildcard_alias_check) {
164     return true;
165   }
166 
167   if (lhs.type == c10::SchemaArgType::input &&
168       rhs.type == c10::SchemaArgType::input) {
169     return input_alias_map_[lhs.index].count(rhs.index);
170   } else if (
171       lhs.type == c10::SchemaArgType::output &&
172       rhs.type == c10::SchemaArgType::output) {
173     for (size_t lhs_alias_input : output_alias_map_[lhs.index]) {
174       if (output_alias_map_[rhs.index].count(lhs_alias_input)) {
175         return true;
176       }
177     }
178     return false;
179   } else if (lhs.type == c10::SchemaArgType::output) {
180     return output_alias_map_[lhs.index].count(rhs.index);
181   } else {
182     return output_alias_map_[rhs.index].count(lhs.index);
183   }
184 }
185 
may_contain_alias(const c10::SchemaArgument & lhs,const c10::SchemaArgument & rhs,bool bidirectional)186 bool SchemaInfo::may_contain_alias(
187     const c10::SchemaArgument& lhs,
188     const c10::SchemaArgument& rhs,
189     bool bidirectional) {
190   bool basic_check = schema_.may_contain_alias(lhs, rhs) || may_alias(lhs, rhs);
191   if (basic_check) {
192     return true;
193   }
194   if (!alias_maps_current_) {
195     generateAliasMaps();
196   }
197   if (bidirectional) {
198     return mayContainAliasImpl(lhs, rhs) || mayContainAliasImpl(rhs, lhs);
199   } else {
200     return mayContainAliasImpl(lhs, rhs);
201   }
202 }
203 
mayContainAliasImpl(const c10::SchemaArgument & lhs,const c10::SchemaArgument & rhs)204 bool SchemaInfo::mayContainAliasImpl(
205     const c10::SchemaArgument& lhs,
206     const c10::SchemaArgument& rhs) {
207   std::optional<c10::AliasTypeSet> lhsContainedAliasTypeSet =
208       schema_.getAliasTypeSetContainedTypes(schema_.mapTypeToAliasTypeSet(
209           schema_.getCorrectList(lhs.type)[lhs.index].type()));
210   std::optional<c10::AliasTypeSet> rhsAliasTypeSet =
211       schema_.mapTypeToAliasTypeSet(
212           schema_.getCorrectList(rhs.type)[rhs.index].type());
213   bool types_can_alias =
214       schema_.canAliasTypeSetsAlias(lhsContainedAliasTypeSet, rhsAliasTypeSet);
215   return types_can_alias && containerSet().count(lhs) &&
216       wildcardSet().count(rhs);
217 }
218 
ensureConservativity(const std::unordered_set<at::Symbol> & duplicates,const std::vector<c10::Argument> & arguments_list,c10::SchemaArgType type)219 void SchemaInfo::ensureConservativity(
220     const std::unordered_set<at::Symbol>& duplicates,
221     const std::vector<c10::Argument>& arguments_list,
222     c10::SchemaArgType type) {
223   for (size_t i = 0; i < arguments_list.size(); i++) {
224     if (arguments_list[i].alias_info()) {
225       for (const auto& set : arguments_list[i].alias_info()->afterSets()) {
226         if (duplicates.count(set)) {
227           wildcard_set_.insert({type, i});
228         }
229       }
230     }
231   }
232 }
233 
// Builds (once per call) the list of schemas considered nondeterministic
// on mobile builds, where dispatcher tags are unavailable.
std::vector<c10::FunctionSchema> SchemaInfo::getNonDeterministicOps() {
  // This list of nondeterministic ops is copied from JIT ir.cpp.
  static const std::vector<std::string> nondeterministic_op_strings = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};

  // Parse each signature string into a FunctionSchema for equality
  // comparisons against this->schema_.
  std::vector<c10::FunctionSchema> nondeterministic_ops;
  nondeterministic_ops.reserve(nondeterministic_op_strings.size());
  for (const std::string& signature : nondeterministic_op_strings) {
    nondeterministic_ops.emplace_back(torch::jit::parseSchema(signature));
  }

  return nondeterministic_ops;
}
269 
getTrainingOps()270 std::vector<SchemaSpecialCasePair> SchemaInfo::getTrainingOps() {
271   // This is a list of pairs of ops to sets of strings
272   //  where the a boolean variable (either "training",
273   // "train" or "use_input_stats") affects the mutability
274   // of the unorderered set of strings.
275   static const std::vector<std::pair<std::string, std::unordered_set<std::string>>> training_op_pairs =
276       {{"aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
277         {"running_mean", "running_var"}},
278        {"aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor",
279         {"running_mean", "running_var"}},
280        {"aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)",
281         {"running_mean", "running_var"}},
282        {"aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)",
283         {"running_mean", "running_var"}},
284        {"aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)",
285         {"running_mean", "running_var"}},
286        {"aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
287         {"running_mean", "running_var"}},
288        {"aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))",
289         {"running_mean", "running_var"}},
290        {"aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor",
291         {"noise"}},
292        {"aten::rrelu_with_noise.out(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)",
293         {"noise"}},
294        {"rrelu_with_noise_(Tensor(a!) self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)",
295         {"noise"}}};
296 
297   std::vector<SchemaSpecialCasePair> training_ops;
298   training_ops.reserve(training_op_pairs.size());
299   for (const auto& signature : training_op_pairs) {
300     training_ops.emplace_back(
301         torch::jit::parseSchema(signature.first), signature.second);
302   }
303 
304   return training_ops;
305 }
306 
// One-time initialization of wildcard_set_ and container_set_ from the
// schema's alias annotations. Idempotent via the has_init_ guard.
void SchemaInfo::initSchemaInfo() {
  if (has_init_) {
    return;
  }
  has_init_ = true;

  // Alias-set symbols seen more than once within a single argument list;
  // these force conservative (wildcard) handling below.
  std::unordered_set<at::Symbol> duplicates;
  auto init_schema_arguments = [this, &duplicates](
                                   const std::vector<c10::Argument>&
                                       arguments_list,
                                   c10::SchemaArgType type) {
    std::unordered_set<at::Symbol> seen;
    for (size_t i = 0; i < arguments_list.size(); i++) {
      const c10::Argument& argument = arguments_list[i];
      if (argument.alias_info()) {
        if (argument.alias_info()->isWildcardAfter()) {
          // Explicit wildcard annotation (e.g. Tensor(*)).
          wildcard_set_.insert({type, i});
        } else {
          // This check is to ensure that the FunctionSchema will accurately
          // be represented when calling may_alias and may_contain_alias
          // on schemas with more than one argument within arguments_list that
          // shares an alias set.
          for (const auto& set : argument.alias_info()->afterSets()) {
            if (seen.count(set)) {
              TORCH_WARN(
                  set.toQualString(),
                  " appears twice in same argument list which will make aliasing checks more conservative.");
              duplicates.insert(set);
            } else {
              seen.insert(set);
            }
          }
        }
      }
      // Arguments whose type can contain aliasable elements (e.g. Tensor[])
      // are tracked in container_set_ for may_contain_alias checks.
      std::optional<c10::AliasTypeSet> contained_types =
          schema_.getAliasTypeSetContainedTypes(
              schema_.mapTypeToAliasTypeSet(argument.type()));
      if (contained_types && !contained_types->empty()) {
        container_set_.insert({type, i});
      }
    }
  };

  init_schema_arguments(schema_.arguments(), c10::SchemaArgType::input);
  init_schema_arguments(schema_.returns(), c10::SchemaArgType::output);
  // Any argument annotated with a duplicated alias set becomes a wildcard.
  ensureConservativity(
      duplicates, schema_.arguments(), c10::SchemaArgType::input);
  ensureConservativity(
      duplicates, schema_.returns(), c10::SchemaArgType::output);
}
357 
wildcardSet()358 const std::unordered_set<c10::SchemaArgument>& SchemaInfo::wildcardSet() {
359   initSchemaInfo();
360   return wildcard_set_;
361 }
362 
containerSet()363 const std::unordered_set<c10::SchemaArgument>& SchemaInfo::containerSet() {
364   initSchemaInfo();
365   return container_set_;
366 }
367 
// Rebuilds input_alias_map_ and output_alias_map_ from the currently
// known argument values (value_map_), and extends wildcard_set_ with
// wildcards induced by value-level aliasing and containment.
void SchemaInfo::generateAliasMaps() {
  initSchemaInfo();

  alias_maps_current_ = true;
  // input_alias_map_[i] = set of input indices whose values alias input i.
  input_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.arguments().size(), std::unordered_set<size_t>());
  // output_alias_map_[j] = set of input indices that may alias output j.
  output_alias_map_ = std::vector<std::unordered_set<size_t>>(
      schema_.returns().size(), std::unordered_set<size_t>());

  // Fills input_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = i; j < schema_.arguments().size(); j++) {
      if (i == j) {
        // Every input trivially aliases itself.
        input_alias_map_[i].insert(i);
      } else if (
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        if (value_map_[schema_.arguments()[i].name()].isAliasOf(
                value_map_[schema_.arguments()[j].name()])) {
          // Record the symmetric aliasing relation.
          input_alias_map_[i].insert(j);
          input_alias_map_[j].insert(i);
          // Wildcard status propagates across aliasing inputs.
          if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
            wildcard_set_.insert({c10::SchemaArgType::input, j});
          } else if (wildcard_set_.count({c10::SchemaArgType::input, j})) {
            wildcard_set_.insert({c10::SchemaArgType::input, i});
          }
        }
      }
    }
  }

  // Fills wildcard_set with container created wildcards.
  // For instance, given the schema:
  // test(Tensor a, Tensor(*) b, Tensor[] c) -> Tensor
  // where value(a) is contained in value(c), then a will be added to the
  // wildcard set where it can now alias b.
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.arguments().size(); j++) {
      // if they are already aliasing, there is no way one contains the other
      if (!input_alias_map_[i].count(j) &&
          value_map_.count(schema_.arguments()[i].name()) &&
          value_map_.count(schema_.arguments()[j].name())) {
        c10::IValue::HashAliasedIValues subValues;
        value_map_[schema_.arguments()[i].name()].getSubValues(subValues);
        if (subValues.count(value_map_[schema_.arguments()[j].name()])) {
          // Input j is contained inside input i's value, so j may alias
          // anything i's container can reach: treat it as a wildcard.
          wildcard_set_.insert({c10::SchemaArgType::input, j});
        }
      }
    }
  }

  // Fills output_alias_map_
  for (size_t i = 0; i < schema_.arguments().size(); i++) {
    for (size_t j = 0; j < schema_.returns().size(); j++) {
      if (schema_.may_alias(
              {c10::SchemaArgType::input, i},
              {c10::SchemaArgType::output, j})) {
        // Wildcard status flows from an input to any output it may alias.
        if (wildcard_set_.count({c10::SchemaArgType::input, i})) {
          wildcard_set_.insert({c10::SchemaArgType::output, j});
        }
        // The output aliases everything that aliases input i.
        output_alias_map_[j].insert(
            input_alias_map_[i].begin(), input_alias_map_[i].end());
      }
    }
  }
}
434 
435 } // namespace torch::utils
436