xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/function_schema_inl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ostream>
3 #include <sstream>
4 
5 namespace c10 {
6 
7 template<typename T>
checkArg(const IValue & value,const Argument & argument,std::optional<size_t> pos)8 inline void FunctionSchema::checkArg(
9     const IValue& value,
10     const Argument& argument,
11     std::optional<size_t> pos) const {
12   if (value.isTensor() && argument.type() == TensorType::get()) {
13     // Fast-path for the common case
14     return;
15   }
16   if (!value.type<T>()->isSubtypeOf(*argument.type())) {
17     TORCH_CHECK(
18         false,
19         formatTypeMismatchMsg(
20             argument, value.type<T>()->repr_str(), pos));
21   }
22 }
23 
24 template <typename T>
checkAndNormalizeInputs(std::vector<IValue> & inputs,const std::unordered_map<std::string,IValue> & kwargs)25 inline void FunctionSchema::checkAndNormalizeInputs(
26     std::vector<IValue>& inputs,
27     const std::unordered_map<std::string, IValue>& kwargs) const {
28   // Do we have more inputs than the schema accepts?
29   TORCH_CHECK(
30       inputs.size() <= arguments().size(),
31       "Expected at most ",
32       arguments().size(),
33       " argument(s) for operator '",
34       name(),
35       "', but received ",
36       inputs.size(),
37       " argument(s). Declaration: ",
38       *this);
39 
40   size_t consumed_kwargs = 0;
41   for (const auto pos : c10::irange(arguments().size())) {
42     const auto& argument = arguments()[pos];
43     if (pos < inputs.size()) {
44       checkArg<T>(inputs[pos], argument, pos);
45       continue;
46     }
47     auto it = kwargs.find(argument.name());
48     if (it != kwargs.end()) {
49       checkArg<T>(it->second, argument, std::nullopt);
50       inputs.push_back(it->second);
51       consumed_kwargs++;
52       continue;
53     }
54     if (argument.default_value()) {
55       inputs.push_back(*argument.default_value());
56       continue;
57     }
58     AT_ERROR(
59         name(),
60         "() is missing value for argument '",
61         argument.name(),
62         "'. Declaration: ",
63         *this);
64   }
65   if (consumed_kwargs != kwargs.size()) {
66     std::vector<std::string> names;
67     names.reserve(kwargs.size());
68     for(const auto& k : kwargs) {
69       names.emplace_back(k.first);
70     }
71     throw std::runtime_error(findErrorInKwargs(names));
72   }
73 }
74 
75 } // namespace c10
76