1 #pragma once
2 #include <ostream>
3 #include <sstream>
4
5 namespace c10 {
6
7 template<typename T>
checkArg(const IValue & value,const Argument & argument,std::optional<size_t> pos)8 inline void FunctionSchema::checkArg(
9 const IValue& value,
10 const Argument& argument,
11 std::optional<size_t> pos) const {
12 if (value.isTensor() && argument.type() == TensorType::get()) {
13 // Fast-path for the common case
14 return;
15 }
16 if (!value.type<T>()->isSubtypeOf(*argument.type())) {
17 TORCH_CHECK(
18 false,
19 formatTypeMismatchMsg(
20 argument, value.type<T>()->repr_str(), pos));
21 }
22 }
23
24 template <typename T>
checkAndNormalizeInputs(std::vector<IValue> & inputs,const std::unordered_map<std::string,IValue> & kwargs)25 inline void FunctionSchema::checkAndNormalizeInputs(
26 std::vector<IValue>& inputs,
27 const std::unordered_map<std::string, IValue>& kwargs) const {
28 // Do we have more inputs than the schema accepts?
29 TORCH_CHECK(
30 inputs.size() <= arguments().size(),
31 "Expected at most ",
32 arguments().size(),
33 " argument(s) for operator '",
34 name(),
35 "', but received ",
36 inputs.size(),
37 " argument(s). Declaration: ",
38 *this);
39
40 size_t consumed_kwargs = 0;
41 for (const auto pos : c10::irange(arguments().size())) {
42 const auto& argument = arguments()[pos];
43 if (pos < inputs.size()) {
44 checkArg<T>(inputs[pos], argument, pos);
45 continue;
46 }
47 auto it = kwargs.find(argument.name());
48 if (it != kwargs.end()) {
49 checkArg<T>(it->second, argument, std::nullopt);
50 inputs.push_back(it->second);
51 consumed_kwargs++;
52 continue;
53 }
54 if (argument.default_value()) {
55 inputs.push_back(*argument.default_value());
56 continue;
57 }
58 AT_ERROR(
59 name(),
60 "() is missing value for argument '",
61 argument.name(),
62 "'. Declaration: ",
63 *this);
64 }
65 if (consumed_kwargs != kwargs.size()) {
66 std::vector<std::string> names;
67 names.reserve(kwargs.size());
68 for(const auto& k : kwargs) {
69 names.emplace_back(k.first);
70 }
71 throw std::runtime_error(findErrorInKwargs(names));
72 }
73 }
74
75 } // namespace c10
76