1 #pragma once 2 #include <ATen/core/Tensor.h> 3 #include <ATen/AccumulateType.h> 4 #include <ATen/Dispatch.h> 5 #include <ATen/TensorUtils.h> 6 7 namespace at::native { multilabel_margin_loss_shape_check(int64_t & nframe,int64_t & dim,const int64_t & ndims,const Tensor & input,const Tensor & target)8 inline void multilabel_margin_loss_shape_check( 9 int64_t& nframe, 10 int64_t& dim, 11 const int64_t& ndims, 12 const Tensor& input, 13 const Tensor& target) { 14 TORCH_CHECK( 15 (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, 16 "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", 17 input.sizes()); 18 19 if (ndims <= 1) { 20 nframe = 1; 21 dim = ndims == 0 ? 1 : input.size(0); 22 TORCH_CHECK( 23 target.dim() <= 1 && target.numel() == dim, 24 "inconsistent target size: ", target.sizes(), " for input of size: ", 25 input.sizes()); 26 } else { 27 nframe = input.size(0); 28 dim = input.size(1); 29 TORCH_CHECK( 30 target.dim() == 2 && target.size(0) == nframe && 31 target.size(1) == dim, 32 "inconsistent target size: ", target.sizes(), " for input of size: ", 33 input.sizes()); 34 } 35 } 36 multi_margin_loss_shape_check(int64_t & nframe,int64_t & dim,const int64_t & ndims,const Tensor & input,const Tensor & target,const std::optional<Tensor> & weight)37 inline void multi_margin_loss_shape_check( 38 int64_t& nframe, 39 int64_t& dim, 40 const int64_t& ndims, 41 const Tensor& input, 42 const Tensor& target, 43 const std::optional<Tensor>& weight) { 44 TORCH_CHECK( 45 (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, 46 "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", 47 input.sizes()); 48 49 if (ndims <= 1) { 50 nframe = 1; 51 dim = ndims == 0 ? 1 : input.size(0); 52 } else { 53 nframe = input.size(0); 54 dim = input.size(1); 55 } 56 57 TORCH_CHECK( 58 target.dim() <= 1 && target.numel() == nframe, 59 "inconsistent target size, expected ", nframe, " but got ", 60 target.sizes()); 61 if (weight && weight->defined()) { 62 TORCH_CHECK( 63 weight->dim() <= 1 && weight->numel() == dim, 64 "inconsistent weight size, expected ", dim, " but got ", 65 weight->sizes()); 66 } 67 } 68 69 } // namespace at::native 70