1 #pragma once 2 3 #include <ATen/core/function.h> 4 #include <ATen/core/ivalue.h> 5 #include <c10/util/Exception.h> 6 #include <c10/util/intrusive_ptr.h> 7 #include <functional> 8 #include <utility> 9 10 namespace torch::jit { 11 12 struct BuiltinOpFunction : public Function { 13 BuiltinOpFunction( 14 c10::QualifiedName qualname, 15 c10::FunctionSchema schema, 16 std::function<void(Stack&)> callable, 17 std::string doc_string = "") name_BuiltinOpFunction18 : name_(std::move(qualname)), 19 callable_(std::move(callable)), 20 schema_(std::move(schema)), 21 doc_string_(std::move(doc_string)) { 22 TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); 23 } 24 doc_stringBuiltinOpFunction25 c10::string_view doc_string() const override { 26 return doc_string_; 27 } 28 runBuiltinOpFunction29 void run(Stack& stack) override { 30 callable_(stack); 31 } 32 runAsyncBuiltinOpFunction33 c10::intrusive_ptr<c10::ivalue::Future> runAsync( 34 Stack& stack, 35 TaskLauncher /* not used */) override { 36 run(stack); 37 auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type()); 38 res->markCompleted(std::move(stack.front())); 39 return res; 40 } 41 qualnameBuiltinOpFunction42 const c10::QualifiedName& qualname() const override { 43 return name_; 44 } 45 46 // if this isn't yet defined, run its method_creator function ensure_definedBuiltinOpFunction47 void ensure_defined() override { 48 // nop 49 } 50 getSchemaBuiltinOpFunction51 const c10::FunctionSchema& getSchema() const override { 52 return schema_; 53 } 54 num_inputsBuiltinOpFunction55 size_t num_inputs() const override { 56 return schema_.arguments().size(); 57 } 58 setSchemaBuiltinOpFunction59 Function& setSchema(c10::FunctionSchema schema) override { 60 schema_ = std::move(schema); 61 return *this; 62 } 63 callBuiltinOpFunction64 bool call( 65 Stack& stack, 66 std::optional<size_t>, 67 c10::function_ref<void(const Code&)>) override { 68 run(stack); 69 return false; 70 } 71 callBuiltinOpFunction72 bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)>) 73 override { 74 run(stack); 75 return false; 76 } 77 78 ~BuiltinOpFunction() override = default; 79 80 private: 81 c10::QualifiedName name_; 82 83 std::function<void(Stack&)> callable_; 84 85 c10::FunctionSchema schema_; 86 87 std::string doc_string_; 88 }; 89 90 } // namespace torch::jit 91