// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/builtin_function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/function.h>
4 #include <ATen/core/ivalue.h>
5 #include <c10/util/Exception.h>
6 #include <c10/util/intrusive_ptr.h>
7 #include <functional>
8 #include <utility>
9 
10 namespace torch::jit {
11 
12 struct BuiltinOpFunction : public Function {
13   BuiltinOpFunction(
14       c10::QualifiedName qualname,
15       c10::FunctionSchema schema,
16       std::function<void(Stack&)> callable,
17       std::string doc_string = "")
name_BuiltinOpFunction18       : name_(std::move(qualname)),
19         callable_(std::move(callable)),
20         schema_(std::move(schema)),
21         doc_string_(std::move(doc_string)) {
22     TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
23   }
24 
doc_stringBuiltinOpFunction25   c10::string_view doc_string() const override {
26     return doc_string_;
27   }
28 
runBuiltinOpFunction29   void run(Stack& stack) override {
30     callable_(stack);
31   }
32 
runAsyncBuiltinOpFunction33   c10::intrusive_ptr<c10::ivalue::Future> runAsync(
34       Stack& stack,
35       TaskLauncher /* not used */) override {
36     run(stack);
37     auto res = c10::make_intrusive<c10::ivalue::Future>(stack.front().type());
38     res->markCompleted(std::move(stack.front()));
39     return res;
40   }
41 
qualnameBuiltinOpFunction42   const c10::QualifiedName& qualname() const override {
43     return name_;
44   }
45 
46   // if this isn't yet defined, run its method_creator function
ensure_definedBuiltinOpFunction47   void ensure_defined() override {
48     // nop
49   }
50 
getSchemaBuiltinOpFunction51   const c10::FunctionSchema& getSchema() const override {
52     return schema_;
53   }
54 
num_inputsBuiltinOpFunction55   size_t num_inputs() const override {
56     return schema_.arguments().size();
57   }
58 
setSchemaBuiltinOpFunction59   Function& setSchema(c10::FunctionSchema schema) override {
60     schema_ = std::move(schema);
61     return *this;
62   }
63 
callBuiltinOpFunction64   bool call(
65       Stack& stack,
66       std::optional<size_t>,
67       c10::function_ref<void(const Code&)>) override {
68     run(stack);
69     return false;
70   }
71 
callBuiltinOpFunction72   bool call(Stack& stack, c10::function_ref<void(const mobile::Code&)>)
73       override {
74     run(stack);
75     return false;
76   }
77 
78   ~BuiltinOpFunction() override = default;
79 
80  private:
81   c10::QualifiedName name_;
82 
83   std::function<void(Stack&)> callable_;
84 
85   c10::FunctionSchema schema_;
86 
87   std::string doc_string_;
88 };
89 
90 } // namespace torch::jit
91