#include #include using namespace at; static int test_int; Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { auto tensor_impl = c10::make_intrusive( Storage( Storage::use_byte_size_t(), 0, at::DataPtr(nullptr, Device(DeviceType::MAIA, 0)), nullptr, false), DispatchKey::MAIA, dtype); // This is a hack to workaround the shape checks in _convolution. tensor_impl->set_sizes_contiguous(size); return Tensor(std::move(tensor_impl)); } Tensor empty_override(IntArrayRef size, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory, std::optional optional_memory_format) { test_int = 0; return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size); } Tensor& add_out_override(const Tensor & a, const Tensor & b , const Scalar& c, Tensor & out) { test_int = 1; return out; } Tensor fake_convolution( const Tensor& input, const Tensor& weight, const std::optional& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups) { test_int = 2; // Only the first 2 dimension of output shape is correct. return get_tensor(input.dtype(), {input.size(0), weight.size(0), input.size(2), input.size(3)}); } std::tuple fake_convolution_backward( const Tensor & grad_output, const Tensor & input, const Tensor & weight, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups, std::array output_mask) { test_int = 3; return std::tuple( get_tensor(input.dtype(), input.sizes()), get_tensor(weight.dtype(), weight.sizes()), get_tensor(input.dtype(), {})); } TORCH_LIBRARY_IMPL(aten, MAIA, m) { m.impl("empty.memory_format", empty_override); m.impl("add.out", add_out_override); m.impl("convolution_overrideable", fake_convolution); m.impl("convolution_backward_overrideable", fake_convolution_backward); } // TODO: Extend this to exercise multi-device setting. In that case, // we need to add a thread local variable to track the current device. struct MAIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { static constexpr DeviceType static_type = DeviceType::MAIA; MAIAGuardImpl() {} MAIAGuardImpl(DeviceType t) { AT_ASSERT(t == DeviceType::MAIA); } DeviceType type() const override { return DeviceType::MAIA; } Device exchangeDevice(Device d) const override { AT_ASSERT(d.type() == DeviceType::MAIA); AT_ASSERT(d.index() == 0); return d; } Device getDevice() const override { return Device(DeviceType::MAIA, 0); } void setDevice(Device d) const override { AT_ASSERT(d.type() == DeviceType::MAIA); AT_ASSERT(d.index() == 0); } void uncheckedSetDevice(Device d) const noexcept override { } Stream getStream(Device d) const noexcept override { return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0)); } Stream exchangeStream(Stream s) const noexcept override { return Stream(Stream::DEFAULT, Device(DeviceType::MAIA, 0)); } DeviceIndex deviceCount() const noexcept override { return 1; } // Event-related functions void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override { TORCH_CHECK(false, "MAIA backend doesn't support events."); } void block( void* event, const Stream& stream) const override { TORCH_CHECK(false, "MAIA backend doesn't support events."); } bool queryEvent(void* event) const override { TORCH_CHECK(false, "MAIA backend doesn't support events."); } void destroyEvent( void* event, const DeviceIndex device_index) const noexcept override { } }; constexpr DeviceType MAIAGuardImpl::static_type; C10_REGISTER_GUARD_IMPL(MAIA, MAIAGuardImpl); int get_test_int() { return test_int; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_test_int", &get_test_int); }