// torch/csrc/lazy/backend/lowering_context.h
#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_util.h>

namespace torch {
namespace lazy {

// Abstract handle to a lowered computation; implemented by each backend.
class TORCH_API Computation {
 public:
  virtual int parameters_size() const = 0;

  virtual const std::vector<Shape>& parameter_shapes() const = 0;

  virtual const std::vector<std::string>& parameter_names() const = 0;

  virtual const Shape& result_shape() const = 0;

  virtual const std::string to_string() const = 0;

  virtual ~Computation() = default;

  // Indicates whether this computation is being executed inside a mark step.
  // Assumed to be false unless set otherwise.
  bool in_mark_step = false;
};

using ComputationPtr = std::shared_ptr<Computation>;
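
// A minimal sketch (illustrative only, not part of this API) of how a backend
// might implement Computation; `ExampleComputation` and its members are
// hypothetical names, assuming the backend records parameter shapes and names
// when the computation is built:
//
//   class ExampleComputation : public Computation {
//    public:
//     int parameters_size() const override {
//       return static_cast<int>(parameter_shapes_.size());
//     }
//     const std::vector<Shape>& parameter_shapes() const override {
//       return parameter_shapes_;
//     }
//     const std::vector<std::string>& parameter_names() const override {
//       return parameter_names_;
//     }
//     const Shape& result_shape() const override {
//       return result_shape_;
//     }
//     const std::string to_string() const override {
//       return "ExampleComputation";
//     }
//
//    private:
//     std::vector<Shape> parameter_shapes_;
//     std::vector<std::string> parameter_names_;
//     Shape result_shape_;
//   };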

// Keeps track of the code generation state.
class TORCH_API LoweringContext {
 public:
  LoweringContext(const std::string& name, BackendDevice device);
  LoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const torch::lazy::Node*> post_order,
      Util::EmissionMap emit_status);

  virtual ~LoweringContext() = default;

  static std::unique_ptr<LoweringContext> Create(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const torch::lazy::Node*> post_order,
      Util::EmissionMap emit_status);

  static std::unique_ptr<LoweringContext> Create(
      const std::string& name,
      BackendDevice device);

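  // A typical construction flow (a sketch under assumptions, not the only
  // pattern): the caller computes a post-order over the root nodes it wants
  // lowered and hands it to Create(). `roots`, `device` and the graph name
  // below are hypothetical:
  //
  //   std::vector<const torch::lazy::Node*> roots;  // nodes whose values are needed
  //   auto post_order = torch::lazy::Util::ComputePostOrder(roots);
  //   auto lowering_ctx = torch::lazy::LoweringContext::Create(
  //       "ExampleGraph", device, post_order, /*emit_status=*/{});
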
  const BackendDevice& device() const {
    return device_;
  }

  // Retrieves the vector holding all the tensors associated with the parameter
  // instructions that have been created so far.
  const std::vector<BackendDataPtr>& GetParametersData() const;

  // Adds a new input/output alias.
  virtual void SetUpAlias(
      const std::vector<int64_t>& output_index,
      int64_t param_number,
      const std::vector<int64_t>& param_index,
      bool must_alias = false) {
    // Default implementation does nothing.
  }
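
  // Illustrative only: a backend supporting input/output aliasing (buffer
  // donation) might map result index {0} onto parameter 2 so the parameter's
  // storage can be reused for the output; the indices below are hypothetical:
  //
  //   ctx->SetUpAlias(/*output_index=*/{0}, /*param_number=*/2,
  //                   /*param_index=*/{}, /*must_alias=*/false);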

  // Checks whether the parameter shape matches the result at the given index.
  virtual bool CheckResultShape(
      const BackendDataPtr& parameter_data,
      size_t result_idx) {
    // Default implementation does nothing and returns false.
    return false;
  }

  // Adds the given output as a component of the result tuple and returns its
  // assigned position within the tuple.
  virtual size_t AddResult(const torch::lazy::Output& output) = 0;

  // Associates the given output with the input parameter of the given index
  // and shape. Only used for operator-by-operator execution, mostly for
  // debugging purposes.
  virtual void AddParameter(
      const torch::lazy::Output& output,
      size_t index,
      const Shape& shape,
      const std::string& name) = 0;

  // Builds the computation capturing all the operations created with the
  // embedded builder (returned by the builder() API).
  virtual ComputationPtr Build() = 0;

  // Returns the number of IR nodes that have been emitted so far.
  size_t GetEmittedNodeCount() const {
    return emit_status_.size();
  }

 protected:
  BackendDevice device_;
  std::vector<BackendDataPtr> parameters_;
  std::vector<size_t> parameter_sequence_;
  Util::EmissionMap emit_status_;
};
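
// A rough end-to-end sketch of how a LoweringContext is typically driven
// (illustrative only; `device`, `post_order` and `root_outputs` are assumed to
// be provided by the caller, and error handling is omitted):
//
//   auto ctx = torch::lazy::LoweringContext::Create(
//       "ExampleGraph", device, post_order, /*emit_status=*/{});
//   for (const torch::lazy::Output& output : root_outputs) {
//     ctx->AddResult(output);  // result positions follow call order
//   }
//   torch::lazy::ComputationPtr computation = ctx->Build();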

} // namespace lazy
} // namespace torch