# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates

from .core import MaskedTensor


__all__ = [
    "as_masked_tensor",
    "masked_tensor",
]


# The two factory functions below are intended to mirror their plain-tensor
# counterparts:
#   torch.tensor    - guaranteed to be a leaf node
#   torch.as_tensor - differentiable constructor that preserves the autograd history


def masked_tensor(data, mask, requires_grad=False):
    """Create a ``MaskedTensor`` through the class constructor.

    Intended as the masked counterpart of ``torch.tensor`` (see the module
    comment above the factories).

    Args:
        data: Forwarded unchanged as the first constructor argument.
            Presumably the underlying values tensor -- confirm against
            ``MaskedTensor`` in ``.core``.
        mask: Forwarded unchanged as the second constructor argument.
            Presumably a boolean validity mask -- confirm against
            ``MaskedTensor`` in ``.core``.
        requires_grad: Forwarded positionally as the third constructor
            argument. Defaults to ``False``.

    Returns:
        Whatever ``MaskedTensor(data, mask, requires_grad)`` produces.
    """
    result = MaskedTensor(data, mask, requires_grad)
    return result


def as_masked_tensor(data, mask):
    """Create a ``MaskedTensor`` via ``MaskedTensor._from_values``.

    Intended as the masked counterpart of ``torch.as_tensor`` (see the module
    comment above the factories).

    Args:
        data: Forwarded unchanged as the first argument of ``_from_values``.
        mask: Forwarded unchanged as the second argument of ``_from_values``.

    Returns:
        Whatever ``MaskedTensor._from_values(data, mask)`` produces.
    """
    from_values = MaskedTensor._from_values
    return from_values(data, mask)