1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <torch/csrc/lazy/backend/backend_data.h> 5 #include <torch/csrc/lazy/backend/backend_device.h> 6 #include <torch/csrc/lazy/backend/lowering_context.h> 7 #include <torch/csrc/lazy/core/lazy_graph_executor.h> 8 #include <torch/csrc/lazy/core/shape.h> 9 #include <torch/csrc/lazy/core/tensor.h> 10 #include <atomic> 11 12 namespace torch { 13 namespace lazy { 14 15 struct IrBuilder; 16 17 /** 18 * Work in progress- don't treat this as a stable interface yet! 19 */ 20 class TORCH_API BackendImplInterface { 21 public: 22 virtual ~BackendImplInterface() = default; 23 24 /** 25 * Initialization/Teardown 26 * */ 27 // No-op by default. Allows custom functionality to be exposed through 28 // extension bindings. InitializeAtenBindings()29 virtual void InitializeAtenBindings() const {} 30 31 virtual void PrepareToExit() const = 0; 32 33 /** 34 * Configuration 35 * */ 36 37 virtual void SetRngSeed(size_t seed) const = 0; 38 39 /** 40 * IR Tracing 41 * */ 42 43 virtual const IrBuilder* GetIrBuilder() const = 0; 44 45 /** 46 * Data Transfer 47 * */ 48 49 virtual BackendDataPtr MakeComputationDataFromTensor( 50 const at::Tensor& tensor, 51 const Shape& shape, 52 const BackendDevice& device) const = 0; 53 virtual BackendDataPtr MakeComputationDataFromScalar( 54 const at::Scalar& scalar, 55 const torch::lazy::BackendDevice& device) const = 0; 56 virtual BackendDataPtr CreateDataPlaceholder( 57 const BackendDevice& device, 58 const Shape& shape) const = 0; 59 60 // Gets backend data if the node is a device data node. Otherwise returns 61 // nullptr 62 virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0; 63 64 virtual at::Tensor MakeTensorFromComputationData( 65 const BackendDataPtr data, 66 std::optional<at::ScalarType> logical_scalar_type) const = 0; 67 68 /** 69 * Lowering, Compilation, Execution 70 * */ 71 72 virtual std::unique_ptr<LoweringContext> CreateLoweringContext( 73 const std::string& name, 74 BackendDevice device, 75 c10::ArrayRef<const torch::lazy::Node*> post_order, 76 Util::EmissionMap emit_status) const = 0; 77 78 virtual std::unique_ptr<LoweringContext> CreateLoweringContext( 79 const std::string& name, 80 BackendDevice device) const = 0; 81 82 // TODO(whc) need to keep this? 83 virtual std::vector<std::string> GetCompilationDevices( 84 const std::string& device, 85 c10::ArrayRef<std::string> devices) const = 0; 86 87 virtual std::vector<ComputationPtr> Compile( 88 std::vector<ComputationPtr> instances) const = 0; 89 90 virtual std::vector<BackendDataPtr> ExecuteComputation( 91 torch::lazy::ComputationPtr computation, 92 c10::ArrayRef<BackendDataPtr> arguments, 93 const BackendDevice& device) const = 0; 94 95 /** 96 * Device Configuration 97 * */ 98 99 // Set or get the default device type. 100 // For backends used with virtual c10::Devices, this configures what real 101 // device type the backend should use, and matters if the backend supports 102 // more than one type of real device. 103 virtual std::shared_ptr<BackendDeviceType> GetDefaultDeviceType() const = 0; 104 virtual void SetDefaultDeviceType(int8_t type) = 0; 105 106 // Set or get the default device ordinal. 107 // For backends that supports multi-device, this configures what the 108 // default device the backend should use. 109 virtual int64_t GetDefaultDeviceOrdinal() const = 0; 110 virtual void SetDefaultDeviceOrdinal(int64_t) = 0; 111 112 // Specify which aten device should be used for eager fallback 113 // may change depending on current 'Default' DeviceType 114 virtual at::DeviceType EagerFallbackDeviceType() const = 0; 115 116 // Query all available backend devices 117 virtual std::vector<BackendDevice> GetBackendDevices() const = 0; 118 CreateMetricReport()119 virtual std::string CreateMetricReport() const { 120 return ""; 121 } 122 123 // Map a particular c10:: device to a concrete backend device 124 // Note:: c10:: devices may be virtual or concrete. xla:: and lazy:: are 125 // virtual devices, meaning they may map to a gpu, tpu, etc. behind the 126 // scenes. In the future, non-virtual c10:: devices may also use lazy tensors 127 // through a mode, in which case these APIs should still work, but should be 128 // identity mappings. 129 virtual BackendDevice GetBackendDevice(c10::Device device) const = 0; 130 131 // TODO(whc) 132 // Additional APIs expected for supporting distributed training, to be 133 // designed 134 135 /** 136 * Debug/Metrics 137 * */ 138 139 // virtual std::map<std::string, Metric> GetMetrics() const = 0; 140 141 // virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; 142 143 virtual std::string GetComputationBackendText( 144 const ComputationPtr computation) const = 0; 145 }; 146 147 class TORCH_API BackendRegistrar { 148 public: 149 BackendRegistrar(const BackendImplInterface* backend_impl_interface); 150 }; 151 152 TORCH_API bool hasBackend(); 153 TORCH_API const BackendImplInterface* getBackend(); 154 155 TORCH_API const IrBuilder* getIrBuilder(); 156 157 } // namespace lazy 158 } // namespace torch 159