#pragma once #include #include #include #include #include namespace torch { using at::IterArgs; struct CountTensors : IterArgs { size_t out = 0; void operator()(const at::Tensor& x) { out += 1; } void operator()(const std::optional& x) { out += x.has_value(); } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template size_t count_tensors(Args&&... args) { return CountTensors().apply(std::forward(args)...).out; } struct CountVariables : IterArgs { size_t out = 0; void operator()(const autograd::Variable& x) { out += 1; } void operator()(at::ArrayRef xs) { out += xs.size(); } }; template inline size_t count_variables(Args&&... args) { return CountVariables().apply(std::forward(args)...).out; } //===----------------------------------------------------------------------===// // std::index_sequence shim for C++11 //===----------------------------------------------------------------------===// // A container of type-template parameter indices. template struct Indices {}; // Decrements the index N, adds N-1 to the list of indices and forwards // whatever we already have. template struct MakeIndices : MakeIndices {}; // Partial specialization that forms our base case. When N is zero, we stop // and define a typedef that will be visible to earlier classes due to // inheritance. The typedef we define is an index list containing the numbers // 0 through N-1. template struct MakeIndices<0, Is...> { using indices = Indices; }; //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// template void apply(Function function, Ts&&... ts) { // https://stackoverflow.com/questions/13978916/inserting-a-variadic-argument-list-into-a-vector // Creates a dummy array, so that each function call is evaluated in order. // `(function(), 0)` is because `function` should (!) return `void`, so // according to the comma operator, it is evaluated and its result (`void`) // is discarded. Then the zero is evaluated and used as an element in the // array. The first zero ensures the array is not empty. // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) int _[]{0, (function(std::forward(ts)), 0)...}; (void)_; } template < typename ReturnType, typename... Ts, typename Function, typename Accessor> ReturnType unpack(Function function, Accessor accessor) { return ReturnType(unpack( std::move(function), std::move(accessor), typename MakeIndices::indices())); } template < typename ReturnType, typename... Ts, typename Function, typename Accessor, size_t... Is> ReturnType unpack(Function function, Accessor accessor, Indices) { return ReturnType(function(accessor.template operator()(Is)...)); } } // namespace torch