#pragma once
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at {

// If an operator doesn't have a batching rule implemented then we fallback
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory. (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
//
// Parameters:
//   op    - handle identifying the operator being dispatched; used to invoke
//           the underlying (non-batched) kernel on each slice.
//   stack - the boxed argument stack for `op`; arguments are consumed from it
//           and the results are pushed back onto it, per the boxed-fallback
//           calling convention.
void batchedTensorForLoopFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

} // namespace at