#include #include #include #include #include #include #include using namespace at; static size_t instance_count = 0; struct TestCPUGenerator : public c10::GeneratorImpl { TestCPUGenerator(uint64_t value) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) { ++instance_count; } ~TestCPUGenerator() { --instance_count; } uint32_t random() { return static_cast(value_); } uint64_t random64() { return value_; } void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); } void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); } uint64_t get_offset() const override { throw std::runtime_error("not implemented"); } uint64_t current_seed() const override { throw std::runtime_error("not implemented"); } uint64_t seed() override { throw std::runtime_error("not implemented"); } void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); } c10::intrusive_ptr get_state() const override { throw std::runtime_error("not implemented"); } TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); } static DeviceType device_type() { return DeviceType::CPU; } uint64_t value_; }; Tensor& random_(Tensor& self, std::optional generator) { return at::native::templates::random_impl(self, generator); } Tensor& random_from_to(Tensor& self, int64_t from, optional to, std::optional generator) { return at::native::templates::random_from_to_impl(self, from, to, generator); } Tensor& random_to(Tensor& self, int64_t to, std::optional generator) { return random_from_to(self, 0, to, generator); } Generator createTestCPUGenerator(uint64_t value) { return at::make_generator(value); } Generator identity(Generator g) { return g; } size_t getInstanceCount() { return instance_count; } TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) { m.impl("aten::random_.from", random_from_to); m.impl("aten::random_.to", random_to); m.impl("aten::random_", random_); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("createTestCPUGenerator", &createTestCPUGenerator); m.def("getInstanceCount", &getInstanceCount); m.def("identity", &identity); }