# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import logging import typing from typing import Callable, Union import numpy as np import torch # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[ [ # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. typing.Union[np.ndarray, torch._tensor.Tensor], # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. typing.Union[np.ndarray, torch._tensor.Tensor], ], float, ]: # A distance decorator that performs all the necessary checkes before calculating # the distance between two N-D tensors given a function. This can be a RMS # function, maximum abs diff, or any kind of distance function. def wrapper( # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. a: Union[np.ndarray, torch.Tensor], # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. b: Union[np.ndarray, torch.Tensor], ) -> float: # convert a and b to np.ndarray type fp64 a = to_np_arr_fp64(a) b = to_np_arr_fp64(b) # return NaN if shape mismatches if a.shape != b.shape: return np.nan # After we make sure shape matches, check if it's empty. If yes, return 0 if a.size == 0: return 0 # np.isinf and np.isnan returns a Boolean mask. Check if Inf or NaN occur at # the same places in a and b. If not, return NaN if np.any(np.isinf(a) != np.isinf(b)) or np.any(np.isnan(a) != np.isnan(b)): return np.nan # mask out all the values that are either Inf or NaN mask = np.isinf(a) | np.isnan(a) if np.any(mask): logging.warning("Found inf/nan in tensor when calculating the distance") a_masked = a[~mask] b_masked = b[~mask] # after masking, the resulting tensor might be empty. If yes, return 0 if a_masked.size == 0: return 0 # only compare the rest (those that are actually numbers) using the metric return fn(a_masked, b_masked) return wrapper @distance # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def rms(a: np.ndarray, b: np.ndarray) -> float: return ((a - b) ** 2).mean() ** 0.5 @distance # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float: return np.abs(a - b).max() @distance # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float: return np.abs((x - x_ref) / x_ref).max() # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray: if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if isinstance(x, np.ndarray): x = x.astype(np.float64) return x # pyre-fixme[3]: Return type must be annotated. def normalized_rms( # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. predicted: Union[np.ndarray, torch.Tensor], # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. ground_truth: Union[np.ndarray, torch.Tensor], ): num = rms(predicted, ground_truth) if num == 0: return 0 den = np.linalg.norm(to_np_arr_fp64(ground_truth)) return np.float64(num) / np.float64(den)