1 #pragma once 2 3 #include <memory> 4 #include <string> 5 #include <unordered_map> 6 #include <utility> 7 #include <vector> 8 9 #include <torch/csrc/lazy/backend/backend_data.h> 10 #include <torch/csrc/lazy/backend/backend_device.h> 11 #include <torch/csrc/lazy/core/ir.h> 12 #include <torch/csrc/lazy/core/ir_util.h> 13 14 namespace torch { 15 namespace lazy { 16 17 class TORCH_API Computation { 18 public: 19 virtual int parameters_size() const = 0; 20 21 virtual const std::vector<Shape>& parameter_shapes() const = 0; 22 23 virtual const std::vector<std::string>& parameter_names() const = 0; 24 25 virtual const Shape& result_shape() const = 0; 26 27 virtual const std::string to_string() const = 0; 28 29 virtual ~Computation() = default; 30 31 // Indicates whether this computation is being executed inside a mark step 32 // Assume false unless set otherwise 33 bool in_mark_step = false; 34 }; 35 36 using ComputationPtr = std::shared_ptr<Computation>; 37 38 // Keeps track of the code generation state. 39 class TORCH_API LoweringContext { 40 public: 41 LoweringContext(const std::string& name, BackendDevice device); 42 LoweringContext( 43 const std::string& name, 44 BackendDevice device, 45 c10::ArrayRef<const torch::lazy::Node*> post_order, 46 Util::EmissionMap emit_status); 47 48 virtual ~LoweringContext() = default; 49 50 static std::unique_ptr<LoweringContext> Create( 51 const std::string& name, 52 BackendDevice device, 53 c10::ArrayRef<const torch::lazy::Node*> post_order, 54 Util::EmissionMap emit_status); 55 56 static std::unique_ptr<LoweringContext> Create( 57 const std::string& name, 58 BackendDevice device); 59 device()60 const BackendDevice& device() const { 61 return device_; 62 }; 63 64 // Retrieves the vector holding all the tensors associated with the parameter 65 // instructions which have been created. 66 const std::vector<BackendDataPtr>& GetParametersData() const; 67 68 // Adds a new input/output alias. 69 virtual void SetUpAlias( 70 const std::vector<int64_t>& output_index, 71 int64_t param_number, 72 const std::vector<int64_t>& param_index, 73 bool must_alias = false) { 74 // Dummy default implementation to do nothing. 75 } 76 77 // Check if parameter shape matches result at index. CheckResultShape(const BackendDataPtr & parameter_data,size_t result_idx)78 virtual bool CheckResultShape( 79 const BackendDataPtr& parameter_data, 80 size_t result_idx) { 81 // Dummy default implementation to do nothing. 82 return false; 83 } 84 85 // Adds the given output as a component of the result tuple and returns its 86 // assigned position within the tuple. 87 virtual size_t AddResult(const torch::lazy::Output& output) = 0; 88 89 // Associates the given output with the input parameter of the given index and 90 // shape. Only used for the operator-by-operator execution, mostly for 91 // debugging purposes. 92 virtual void AddParameter( 93 const torch::lazy::Output& output, 94 size_t index, 95 const Shape& shape, 96 const std::string& name) = 0; 97 98 // Build the computation capturing all the operations created with the 99 // embedded builder (returned by the builder() API). 100 virtual ComputationPtr Build() = 0; 101 GetEmittedNodeCount()102 size_t GetEmittedNodeCount() const { 103 return emit_status_.size(); 104 } 105 106 protected: 107 BackendDevice device_; 108 std::vector<BackendDataPtr> parameters_; 109 std::vector<size_t> parameter_sequence_; 110 Util::EmissionMap emit_status_; 111 }; 112 113 } // namespace lazy 114 } // namespace torch 115