Lines Matching full:dim
61 def _rand_shape(dim, min_size, max_size):
63 for i in range(dim):
67 def _reduced_shape(shape, dim=None, keepdim=False):
68 """Computes the expected reduced shape given dim and keepdim
72 dim : The dimensions to reduce
79 if dim is None:
83 dim = dim if isinstance(dim, Sequence) else [dim]
84 dim = {i if i >= 0 else len(shape) + i for i in dim}
88 if i not in dim:
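The fragments above only sketch the shape helper. A minimal runnable reconstruction, assuming the unmatched lines follow the usual keepdim convention (reduced dimensions are dropped, or kept with size 1 when keepdim=True):

from collections.abc import Sequence

def _reduced_shape(shape, dim=None, keepdim=False):
    """Computes the expected reduced shape given dim and keepdim."""
    if dim is None:
        # Full reduction: every dimension collapses (to size 1 if keepdim).
        return [1] * len(shape) if keepdim else []
    dim = dim if isinstance(dim, Sequence) else [dim]
    dim = {i if i >= 0 else len(shape) + i for i in dim}  # wrap negative dims
    result = []
    for i, size in enumerate(shape):
        if i not in dim:
            result.append(size)  # unreduced dimensions keep their size
        elif keepdim:
            result.append(1)     # reduced dimensions are kept with size 1
    return result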
102 """Tests output shape for input with ndim and dim and keepdim kwargs"""
118 """Tests that the default dim reduces all dimensions."""
124 """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1."""
130 """Tests that dim=None reduces all dimensions."""
132 self._test_dim_keepdim(op, device, ndim=ndim, dim=None)
136 """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1."""
138 self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True)
142 """Tests that dim=i reduces dimension i."""
143 self._test_dim_keepdim(op, device, ndim=0, dim=0)
144 self._test_dim_keepdim(op, device, ndim=1, dim=0)
145 self._test_dim_keepdim(op, device, ndim=2, dim=-1)
146 self._test_dim_keepdim(op, device, ndim=3, dim=1)
150 """Tests that dim=i, when keepdim=True, reduces dimension i to size 1."""
151 self._test_dim_keepdim(op, device, ndim=0, dim=0, keepdim=True)
152 self._test_dim_keepdim(op, device, ndim=1, dim=0, keepdim=True)
153 self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True)
154 self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True)
158 """Tests that dim=[] is a no-op"""
159 self._test_dim_keepdim(op, device, ndim=0, dim=[])
160 self._test_dim_keepdim(op, device, ndim=2, dim=[])
164 """Tests that dim=[], when keepdim=True, is a no-op"""
165 self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True)
166 self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True)
170 """Tests that dim=[i, j, ...] reduces dimensions i, j, ...."""
171 self._test_dim_keepdim(op, device, ndim=1, dim=[0])
172 self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])
176 … """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... to size 1."""
177 self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True)
178 self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True)
182 """Tests that operator correctly handles unsorted dim list."""
183 self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2])
187 """Tests that operator correctly handles unsorted dim list when keepdim=True."""
188 self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True)
192 """Tests that an error is raised if dim has duplicate entries."""
194 self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2])
198 """Tests that ops claiming to not support multi dim actually don't."""
200 self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2])
204 """Tests that passing an off-bounds dim throws"""
206 self._test_dim_keepdim(op, device, ndim=2, dim=2)
211 than 64 dims along some specific dimensions. dim=None is OK"""
214 op(t, dim=0)
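Taken together, these cases pin down the dim/keepdim shape contract. A compact illustration using torch.sum as a stand-in for the parametrized op:

import torch

t = torch.randn(2, 3, 4)
assert t.sum().shape == ()                             # default reduces all dims
assert t.sum(dim=None).shape == ()                     # dim=None: same as default
assert t.sum(dim=1).shape == (2, 4)                    # dim=i drops dimension i
assert t.sum(dim=-1, keepdim=True).shape == (2, 3, 1)  # keepdim keeps a size-1 dim
assert t.sum(dim=[0, 2]).shape == (3,)                 # dim=[i, j] reduces both
try:
    t.sum(dim=3)                                       # out-of-bounds dim throws
except IndexError:
    pass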
288 for dim in [0] + ([[0, 2]] if op.supports_multiple_dims else []):
289 args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
292 result = op(t, *args, dim=dim, **kwargs)
296 result = op(t, *args, dim=dim, **kwargs)
303 op(t, *args, dim=dim, **kwargs)
306 op(t, *args, dim=dim, **kwargs)
313 for dim in [1] + ([[1, 2]] if op.supports_multiple_dims else []):
314 args, kwargs = next(op.generate_args_kwargs(t, dim=dim))
315 result = op(t, *args, dim=dim, **kwargs)
316 self.assertEqual(result.shape, _reduced_shape(t.shape, dim))
333 self._test_noncontiguous(op, t[:, ::2], dim=1)
339 self._test_noncontiguous(op, t[::2, :], dim=0)
385 for dim in [0, 1, 3] + ([[0, 2], [1, 3]] if op.supports_multiple_dims else []):
386 self._test_ref(op, t, dim=dim)
399 self._test_ref(op, t, dim=1)
415 self._test_ref(op, t, dim=0)
416 self._test_ref(op, t, dim=1)
453 self.assertEqual(tensor.var(dim=0), 0.03125)
473 op(x, dim=64)
475 op(x, dim=-1)
485 y = op(x, dim=-1)
486 y2 = op(x2, dim=-1)
542 def logcumsumexp_slow(a, dim):
544 for i in range(a.size(dim)):
546 index[dim] = slice(None, i + 1, None)
548 res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
549 res = np.concatenate(res_lst, axis=dim)
554 actual = torch.logcumsumexp(a, dim=i)
557 expected2 = logcumsumexp_slow(a, dim=i)
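Only part of the slow reference is matched above; a self-contained reconstruction, where the index setup, the a_inp slicing, and the return conversion are filled in as assumptions:

import numpy as np
import torch
from scipy.special import logsumexp

def logcumsumexp_slow(a, dim):
    res_lst = []
    for i in range(a.size(dim)):
        index = [slice(None)] * a.ndim
        index[dim] = slice(None, i + 1, None)  # prefix a[..., :i + 1, ...]
        a_inp = a[tuple(index)]
        # The i-th cumulative entry is a log-sum-exp over the prefix.
        res_lst.append(logsumexp(a_inp.cpu().numpy(), axis=dim, keepdims=True))
    res = np.concatenate(res_lst, axis=dim)
    return torch.from_numpy(res)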
638 for dim in range(len(size) + 1):
644 if dim == len(size):
648 nvs = nv.sum(dim)
649 tvs = tv.sum(dim)
709 r1 = x.prod(dim=0, keepdim=False).byte()
710 r2 = x.all(dim=0, keepdim=False)
714 r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte()
715 r4 = x.any(dim=1, keepdim=True)
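The four lines above test Boolean identities on a 0/1 uint8 tensor: all equals prod, and any equals the sum clamped to [0, 1] (all/any on uint8 inputs historically return uint8, so the byte() casts line up). In compact form:

import torch

x = torch.tensor([[1, 0, 1], [1, 1, 1]], dtype=torch.uint8)
assert torch.equal(x.prod(dim=0, keepdim=False).byte(), x.all(dim=0))
assert torch.equal(x.sum(dim=1, keepdim=True).clamp(0, 1).byte(),
                   x.any(dim=1, keepdim=True))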
732 res1 = x1.sum(dim=(0, 2), keepdim=True)
799 def do_one(tensors_dict, dim):
802 dim = 0
807 expected = numpy_op(tensor.cpu().numpy(), dim)
808 actual = pytorch_op(tensor, dim)
811 self._assert_matches_numpy(pytorch_op(tensor.cuda(), dim).cpu(), expected)
978 # Test non-default dim
1107 for dim in dim_list:
1110 var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
1111 var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
1112 mean2 = x.mean(dim=dim, keepdim=keepdim)
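The invariant being checked: the fused torch.var_mean must agree with separate var and mean calls for every dim/unbiased/keepdim combination. For a single combination:

import torch

x = torch.randn(3, 4, 5)
var1, mean1 = torch.var_mean(x, dim=1, unbiased=True, keepdim=True)
assert torch.allclose(var1, x.var(dim=1, unbiased=True, keepdim=True))
assert torch.allclose(mean1, x.mean(dim=1, keepdim=True))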
1129 b = torch.all(a, dim=0)
1130 c = a.to(torch.bool).all(dim=0)
1137 self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item())
1138 self.assertTrue(torch.all(torch.amax(a, dim=1) == inf).item())
1146 self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item())
1147 self.assertTrue(torch.all(torch.amin(a, dim=1) == (-inf)).item())
1229 def _amin_wrapper(x, dim=None, keepdims=False):
1230 return torch.aminmax(x, dim=dim, keepdim=keepdims)[0]
1232 def _amax_wrapper(x, dim=None, keepdims=False):
1233 return torch.aminmax(x, dim=dim, keepdim=keepdims)[1]
1242 torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0)
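torch.aminmax computes both extrema in one pass and returns a (min, max) pair, which the wrappers above adapt to an amin/amax-style signature (keepdims mirrors the NumPy spelling of the keyword). A quick consistency check:

import torch

x = torch.tensor([[1., 5.], [3., 2.]])
mn, mx = torch.aminmax(x, dim=1)
assert torch.equal(mn, x.amin(dim=1))
assert torch.equal(mx, x.amax(dim=1))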
1265 # input and weights dim mismatch
1342 # Stability for inner dim
1674 (0, 1, None), # dim
1678 for noncontiguous, dim in args:
1691 dim_kwargs = {} if dim is None else {"dim": dim}
1715 torch_func_partial = partial(torch_func, keepdim=True, dim=count_dim)
1718 torch_func_partial = partial(torch_func, dim=count_dim)
1843 self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1],
1845 self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1],
1850 torch_fn = partial(torch.argmax, dim=1)
1858 self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None)
1859 self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)
1862 torch_fn = partial(torch.argmin, dim=1)
1870 self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None)
1871 self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None)
1896 def _test_all_any_with_dim(x, dim):
1897 torch_fn = partial(torch.all, dim=dim)
1898 np_fn = partial(np.all, axis=dim)
1901 torch_fn = partial(torch.any, dim=dim)
1902 np_fn = partial(np.any, axis=dim)
1905 def _test_out_variant(x, dim):
1908 expected = torch.all(x, dim)
1909 torch.all(x, dim, out=out)
1912 expected = torch.any(x, dim)
1913 torch.any(x, dim, out=out)
1917 torch.all(x, dim, out=out)
1920 torch.any(x, dim, out=out)
1922 def _test_all_any_with_dim_keepdim(x, dim, keepdim):
1923 torch_fn = partial(torch.all, dim=dim, keepdim=keepdim)
1924 np_fn = partial(np.all, axis=dim, keepdims=keepdim)
1927 torch_fn = partial(torch.any, dim=dim, keepdim=keepdim)
1928 np_fn = partial(np.any, axis=dim, keepdims=keepdim)
1938 self.assertEqual(torch.all(x, dim=0).dtype, expected_dtype)
1939 self.assertEqual(torch.any(x, dim=0).dtype, expected_dtype)
1963 for dim in range(ndim):
1965 _test_all_any_with_dim(x, dim)
1966 _test_all_any_with_dim(x.T, dim)
1967 _test_all_any_with_dim(x[..., ::2], dim)
1968 _test_out_variant(x, dim)
1969 _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1970 _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1973 _test_all_any_with_dim(x, dim)
1974 _test_all_any_with_dim(x.T, dim)
1975 _test_all_any_with_dim(x[..., ::2], dim)
1976 _test_out_variant(x, dim)
1977 _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1978 _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1981 _test_all_any_with_dim(x, dim)
1982 _test_all_any_with_dim(x.T, dim)
1983 _test_all_any_with_dim(x[..., ::2], dim)
1984 _test_out_variant(x, dim)
1985 _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1986 _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
1989 _test_all_any_with_dim(x, dim)
1990 _test_all_any_with_dim(x.T, dim)
1991 _test_all_any_with_dim(x[..., ::2], dim)
1992 _test_out_variant(x, dim)
1993 _test_all_any_with_dim_keepdim(x, dim, keepdim=True)
1994 _test_all_any_with_dim_keepdim(x, dim, keepdim=False)
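The loops above sweep contiguous, transposed, and strided inputs. The out-variant contract they exercise, shown in isolation: all/any write into a preallocated tensor that must match both the functional result and NumPy:

import numpy as np
import torch

x = torch.randint(0, 2, (3, 4), dtype=torch.bool)
out = torch.empty(4, dtype=torch.bool)
torch.all(x, 0, out=out)
assert torch.equal(out, torch.all(x, 0))
np.testing.assert_array_equal(out.numpy(), np.all(x.numpy(), axis=0))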
2005 for dim in [(0, 0), (0, -4)]:
2007 op(x, dim=dim)
2041 self.assertEqual(x.sum(dim=(-1, -2)).cpu(), y.sum(dim=(-1, -2)))
2042 self.assertEqual(x.sum(dim=(1, 3)).cpu(), y.sum(dim=(1, 3)))
2073 torch.sum(x, dim=[0], dtype=torch.float32, out=y)
2079 torch.sum(x, dim=[0], dtype=torch.float32, out=y)
2089 torch.max(x, dim=0, out=(valid_values, valid_indices))
2090 torch.min(x, dim=0, out=(valid_values, valid_indices))
2091 torch.amax(x, dim=0, out=valid_values)
2092 torch.amin(x, dim=0, out=valid_values)
2095 torch.max(x, dim=0, out=(illegal_values, valid_indices))
2097 torch.min(x, dim=0, out=(illegal_values, valid_indices))
2099 torch.max(x, dim=0, out=(valid_values, illegal_indices))
2101 torch.min(x, dim=0, out=(valid_values, illegal_indices))
2103 torch.max(x, dim=0, out=(illegal_values, illegal_indices))
2105 torch.min(x, dim=0, out=(illegal_values, illegal_indices))
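What the illegal_values/illegal_indices cases assert: for dim reductions of max/min, the out= values tensor must match the input dtype and the indices tensor must be int64; any other dtype raises. A sketch with hypothetical tensor names:

import torch

x = torch.randn(3, 4)
values = torch.empty(4, dtype=x.dtype)      # legal: matches input dtype
indices = torch.empty(4, dtype=torch.long)  # legal: int64 indices
torch.max(x, dim=0, out=(values, indices))
bad_values = torch.empty(4, dtype=torch.int32)
try:
    torch.max(x, dim=0, out=(bad_values, indices))  # wrong values dtype
except RuntimeError:
    pass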
2113 self.assertEqual(x.argmax(dim=None).item(), 0)
2114 self.assertEqual(x.argmax(dim=0).item(), 0)
2115 self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2119 self.assertEqual(x.argmin(dim=None).item(), 0)
2120 self.assertEqual(x.argmin(dim=0).item(), 0)
2121 self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2184 self.assertEqual(x.argmin(dim=None).item(), 0)
2191 self.assertEqual(x.min(dim=0, keepdim=True), (torch.tensor([[-1, 2, 1]], dtype=dtype),
2193 self.assertEqual(x.amin(dim=0, keepdim=True), torch.tensor([[-1, 2, 1]], dtype=dtype))
2194 self.assertEqual(x.argmin(dim=0, keepdim=True), torch.tensor([[0, 0, 0]], dtype=torch.int64))
2201 self.assertEqual(x.min(dim=1, keepdim=True), (torch.tensor([[-1], [3]], dtype=dtype),
2203 self.assertEqual(x.amin(dim=1, keepdim=True), torch.tensor([[-1], [3]], dtype=dtype))
2204 self.assertEqual(x.argmin(dim=1, keepdim=True), torch.tensor([[0], [1]], dtype=torch.int64))
2220 self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64))
2222 self.assertEqual(x.max(dim=0, keepdim=True), (torch.tensor([[5, 3, 6]], dtype=dtype),
2224 self.assertEqual(x.amax(dim=0, keepdim=True), torch.tensor([[5, 3, 6]], dtype=dtype))
2225 self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64))
2230 self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64))
2235 self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64))
2249 def normfn_attr(t, dim, keepdim=False, out=None):
2251 return attr(t, 2, dim, keepdim, out=out)
2255 def fn(x, dim, keepdim=False, out=None):
2256 ans = fn_attr(x, dim, keepdim=keepdim, out=out)
2259 def fn_tuple(x, dim, keepdim=False, out=None):
2260 return fn_attr(x, dim, keepdim=keepdim, out=out)
2262 def test_multidim(x, dim):
2263 self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
2264 self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
2265 self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension())
2269 dim = random.randint(0, 2)
2270 test_multidim(x, dim)
2274 dim = 0
2275 self.assertEqual(fn(x, dim).shape, ())
2276 self.assertEqual(fn(x, dim, keepdim=True).shape, (1,))
2310 result = input_.sum(dim=0)
2359 xs = x.sum(dim=-1)
2366 xs = x.sum(dim=-1)
2376 xs1 = x.argmax(dim=-1)
2377 xs2 = x.max(dim=-1).indices
2385 xs1 = x.argmax(dim=-1)
2386 xs2 = x.max(dim=-1).indices
2397 xs1 = x.argmin(dim=-1)
2398 xs2 = x.min(dim=-1).indices
2406 xs1 = x.argmin(dim=-1)
2407 xs2 = x.min(dim=-1).indices
2420 output1 = input_.argmax(dim=0)
2421 output2 = input_.sum(dim=0)
2451 self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64))
2452 self.assertEqual(x.argmin(dim=0), torch.zeros(n, dtype=torch.int64))
2454 self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64))
2455 self.assertEqual(x.argmin(dim=-2), torch.zeros(n, dtype=torch.int64))
2457 self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2458 self.assertEqual(x.argmin(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2460 self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2461 self.assertEqual(x.argmin(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2480 for dim in range(t.ndim):
2481 res = t.median(dim, True)
2482 self.assertEqual(res, t.nanmedian(dim, True))
2483 size = t.size(dim) if t.ndim > 0 else 1
2485 self.assertEqual(res[0], (t.sort(dim)[0]).select(dim, k).unsqueeze_(dim))
2486 self.assertEqual(res[0], t.gather(dim, res[1]))
2490 self.assertEqual(res[0].cpu().numpy(), np.median(t_numpy, dim, keepdims=True), exact_dtype=False)
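Two invariants are verified here: gathering the returned indices reproduces the median values, and those values equal the k-th order statistic along dim (torch.median returns the lower of two middle values for even sizes). Compactly:

import torch

t = torch.randn(4, 5)
values, indices = t.median(dim=1, keepdim=True)
assert torch.equal(values, t.gather(1, indices))
k = (t.size(1) - 1) // 2  # lower-median position
assert torch.equal(values, t.sort(dim=1)[0].select(1, k).unsqueeze(1))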
2515 for dim in range(t.ndim):
2516 res = op(t, dim, True)
2517 size = t.size(dim) if t.ndim > 0 else 1
2518 num_nan = t.isnan().sum(dim, True)
2523 self.assertEqual(res[0], (t.sort(dim)[0]).gather(dim, k))
2524 self.assertEqual(res[0], t.gather(dim, res[1]))
2529 ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()]
2612 for interpolation, dim in product(interpolations,
2614 result = torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation)
2615 expected = numpy_op(a.cpu().numpy(), q.cpu().numpy(), dim,
2621 torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out)
2625 def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)):
2628 op(t, torch.tensor(q, device=device), dim).sum().backward()
2671 for dim in range(x.dim()):
2674 std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2675 std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2676 mean2 = x.mean(dim=dim, keepdim=keepdim)
2691 for dim in range(x.dim()):
2694 var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2695 var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim)
2696 mean2 = x.mean(dim=dim, keepdim=keepdim)
2715 for dim in dim_list:
2718 std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
2719 std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim)
2720 mean2 = x.mean(dim=dim, keepdim=keepdim)
2724 def _compare_std_var_with_numpy(self, op, device, dtype, input, dim,
2728 'axis' : dim,
2733 if dim is None:
2748 if dim is None and use_out is False:
2750 elif dim is not None and use_out is False:
2751 torch_result = torch_op(input, dim, unbiased, keepdim)
2752 elif dim is not None and use_out is True:
2754 torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
2757 torch_result = torch_op(input, dim, unbiased, keepdim, out=out)
2789 # dim
2802 for dim, correction, keepdim in test_args:
2803 numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
2809 torch_res = torch.var(tensor, dim=dim, correction=correction, keepdim=keepdim)
2823 # dim
2836 for dim, correction, keepdim in test_args:
2837 numpy_kwargs = dict(axis=dim, ddof=correction, keepdims=keepdim)
2843 torch_res = torch.std(tensor, dim=dim, correction=correction, keepdim=keepdim)
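correction generalizes the older unbiased flag (correction=0 gives the population estimator, correction=1 Bessel's correction) and maps directly onto NumPy's ddof, which is what the comparison above relies on:

import numpy as np
import torch

tensor = torch.randn(3, 5)
torch_res = torch.std(tensor, dim=1, correction=2, keepdim=False)
numpy_res = np.std(tensor.numpy(), axis=1, ddof=2, keepdims=False)
assert np.allclose(torch_res.numpy(), numpy_res)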
2857 # dim
2869 for dim, correction, keepdim in test_args:
2870 kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
2872 if dim is not None:
2873 mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
2888 # dim
2900 for dim, correction, keepdim in test_args:
2901 kwargs = dict(dim=dim, correction=correction, keepdim=keepdim)
2903 if dim is not None:
2904 mean1 = torch.mean(tensor, dim=dim, keepdim=keepdim)
2918 _func(_tensor, dim=-1, correction=_correction)
2933 for dim in dim_list:
2935 amin1 = torch.amin(x, dim=dim, keepdim=keepdim)
2936 amax1 = torch.amax(x, dim=dim, keepdim=keepdim)
2939 for i, d in enumerate(dim):
2942 amin2 = torch.amin(amin2, dim=d, keepdim=keepdim)
2943 amax2 = torch.amax(amax2, dim=d, keepdim=keepdim)
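The multi-dim path is validated against iterated single-dim reductions; one way to see the equivalence is to reduce the highest dim first, which keeps the remaining dim indices stable when keepdim=False:

import torch

x = torch.randn(2, 3, 4)
multi = torch.amax(x, dim=(0, 2))
iterated = torch.amax(torch.amax(x, dim=2), dim=0)  # highest dim first
assert torch.equal(multi, iterated)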
3245 for dim in range(D):
3246 … self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim], atol=1e-5, rtol=1e-5)
3253 for dim in range(D):
3254 self.assertEqual(actual_bin_edges[dim], expected_bin_edges[dim])
3286 bin_ct = [random.randint(1, 5) for dim in range(D)]
3290 bin_range_tuples = [sorted((random.uniform(-9, 9), random.uniform(-9, 9))) for dim in range(D)]
3295 for dim in range(D):
3296 bin_range[2 * dim + 1] = bin_range[2 * dim]
3307 for dim in range(D):
3308 bin_edges_noncontig[dim].copy_(bin_edges[dim])
3310 for dim in range(D):
3311 self.assertEqual(bin_edges[dim].is_contiguous(), bins_contig)
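For context: torch.histogramdd returns one bin-edges tensor per input dimension, which is why these checks loop over dim. A minimal call:

import torch

x = torch.randn(100, 3)
hist, bin_edges = torch.histogramdd(x, bins=[4, 5, 6])
assert hist.shape == (4, 5, 6)
assert len(bin_edges) == 3            # one edges tensor per dimension
assert bin_edges[0].numel() == 4 + 1  # bins + 1 edges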
3378 # cater to functions where specifying the `dim` parameter is necessary.
3392 # Check if reduction happens along the specified dim with and without keepdim. Check with
3395 self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
3397 fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)
3399 self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
3401 fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)
3403 self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
3406 fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3408 self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
3411 fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3413 # Check if function raises error on specified zero'd dimension as reduction dim.
3414 self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
3416 # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using comparison operators
3417 # raises an error if no `dim` parameter is specified. This exists separately from tests in
3418 # test_tensor_compare_ops_empty because not specifying a `dim` parameter in the former tests does
3434 self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=2), msg=error_msg)
3436 self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)
3439 self.assertEqual(torch.empty((2, 0), device=device, **dtype), fn(master_input, dim=-1), msg=error_msg)
3441 self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)
3445 self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=2, keepdim=True), msg=error_msg)
3447 self.assertEqual(torch.empty((2, 0, 1), device=device, **dtype), fn(master_input, dim=-1, keepdim=True), msg=error_msg)
3450 # Check if function raises error on specified zero'd dimension as reduction dim.
3451 self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input, dim=1))
3453 self.assertRaisesRegex(IndexError, "Expected reduction dim", lambda: fn(master_input))
3455 # Tests to ensure that reduction of zero-dim tensors (i.e. empty tensors) using math operators works when a
3456 # non-zero dim is specified for the reduction and throws an error when the dim specified is 0. Alt…
3479 self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=2), msg=error_msg)
3480 self.assertEqual(np_function(np_input, axis=2), fn(master_input, dim=2).cpu().numpy(), msg=error_msg, exact_dtype=False)
3483 self.assertEqual(torch.empty((2, 0), device=device), fn(master_input, dim=-1), msg=error_msg)
3484 self.assertEqual(np_function(np_input, axis=-1), fn(master_input, dim=-1).cpu().numpy(), msg=error_msg, exact_dtype=False)
3487 self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=2, keepdim=True),
3489 self.assertEqual(np_function(np_input, axis=2, keepdims=True), fn(master_input, dim=2, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3492 self.assertEqual(torch.empty((2, 0, 1), device=device), fn(master_input, dim=-1, keepdim=True),
3494 self.assertEqual(np_function(np_input, axis=-1, keepdims=True), fn(master_input, dim=-1, keepdim=True).cpu().numpy(), msg=error_msg, exact_dtype=False)
3497 self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=1), msg=error_msg)
3498 self.assertEqual(torch.full((2, 4), return_value, device=device), fn(master_input, dim=-2), msg=error_msg)
3499 self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=1, keepdim=True),
3501 self.assertEqual(torch.full((2, 1, 4), return_value, device=device), fn(master_input, dim=-2, keepdim=True),
3506 self.assertEqual(np.float32(np_function(np_input, axis=1)), fn(master_input, dim=1).cpu().numpy(),
3508 self.assertEqual(np.float32(np_function(np_input, axis=-2)), fn(master_input, dim=-2).cpu().numpy(),
3511 fn(master_input, dim=1, keepdim=True).cpu().numpy(),
3514 fn(master_input, dim=-2, keepdim=True).cpu().numpy(),
3517 # logsumexp throws a type error when not specifying dim, so test separately.
3522 # Tests to ensure that any() and all() functions work with zero-dim tensors. Kept separate from
3523 … # other tests for checking reduction with zero-dim tensors because these tests have significantly
3563 gi = torch.randn(op(x, dim=0).shape, dtype=torch.float, device=device)
3564 grad1, = torch.autograd.grad([op(x, dim=0)], [x], gi)
3566 grad2, = torch.autograd.grad([op(x, dim=0, dtype=torch.double)], [x], gi.double())
3568 grad2, = torch.autograd.grad([op(x.double(), dim=0)], [x], gi.double())
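The final check's contract, shown stand-alone with torch.sum for op: reducing with an explicit dtype must produce the same gradient as reducing after casting the input.

import torch

x = torch.randn(3, 4, requires_grad=True)
gi = torch.randn(x.sum(dim=0).shape, dtype=torch.float)
grad1, = torch.autograd.grad([x.sum(dim=0)], [x], gi)
grad2, = torch.autograd.grad([x.sum(dim=0, dtype=torch.double)], [x], gi.double())
assert torch.allclose(grad1, grad2)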