Lines Matching full:argmax

483         ops = [torch.norm, torch.argmax, torch.argmin]
1804 self.compare_with_numpy(torch.argmax, np.argmax, t)
1810 self.compare_with_numpy(torch.argmax, np.argmax, t)
1837 self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None)
1844 … lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None)
1849 # Argmax
1850 torch_fn = partial(torch.argmax, dim=1)
1851 np_fn = partial(np.argmax, axis=1)
2112 self.assertEqual(x.argmax().item(), 0)
2113 self.assertEqual(x.argmax(dim=None).item(), 0)
2114 self.assertEqual(x.argmax(dim=0).item(), 0)
2115 self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor(0, dtype=torch.int64))
2215 self.assertEqual(x.argmax().item(), 5)
2220 self.assertEqual(x.argmax(dim=0), torch.tensor([1, 1, 1], dtype=torch.int64))
2225 … self.assertEqual(x.argmax(dim=0, keepdim=True), torch.tensor([[1, 1, 1]], dtype=torch.int64))
2230 self.assertEqual(x.argmax(dim=1), torch.tensor([1, 2], dtype=torch.int64))
2235 self.assertEqual(x.argmax(dim=1, keepdim=True), torch.tensor([[1], [2]], dtype=torch.int64))
2240 self.assertEqual(x[:, :2].argmax().item(), 2)
2335 # 1D case: argmax
2344 self.assertEqual(x.argmax().item(), i)
2346 self.assertEqual(y.argmax().item(), i - shift)
2350 self.assertEqual(x.argmax().item(), size - i)
2351 self.assertEqual(y.argmax().item(), ysize - i)
2369 # 2D case: max/argmax
2376 xs1 = x.argmax(dim=-1)
2385 xs1 = x.argmax(dim=-1)
2420 output1 = input_.argmax(dim=0)
2441 self.assertEqual(x.argmax(0), x.shape[0] - 1)
2451 self.assertEqual(x.argmax(dim=0), torch.zeros(n, dtype=torch.int64))
2454 self.assertEqual(x.argmax(dim=-2), torch.zeros(n, dtype=torch.int64))
2457 self.assertEqual(x.argmax(dim=0, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
2460 self.assertEqual(x.argmax(dim=-2, keepdim=True), torch.zeros(1, n, dtype=torch.int64))
3419 … # not throw errors. Also, checking the return type of argmax requires supplying a different dtype
3426 ('argmax', torch.argmax, {'dtype': torch.int64}, np.argmax),