/aosp_15_r20/external/pytorch/test/functorch/ |
H A D | test_ops.py | 143 def wrapped(*primals): 174 def wrapped(*primals): 201 def ref_vjp(f, *primals): 210 def simulate_jvp(f, primals, tangents): argument 215 def ref_jvp(f, primals, tangents): argument 267 def _get_vjpfull_variant(fn, primals): argument 327 def _get_jvp_variant(fn, primals, tangents): argument 1742 def get_vjp(cotangents, *primals): 1867 def push_vjp(primals, cotangents): argument 1883 def reference(primals, cotangents, primals_tangents, cotangents_tangents): argument [all …]
|
/aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/ |
H A D | traced_function_transforms.py | 315 primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset argument 643 def joint_helper(primals, tangents): argument 784 def joint_fn(primals, tangents): argument 790 def fw_fn(*primals): 794 def metadata_fn(*primals):
|
/aosp_15_r20/external/tensorflow/tensorflow/python/eager/ |
H A D | forwardprop.py | 328 def __init__(self, primals, tangents): argument 382 def _watch(self, primals, tangents): argument 413 def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE): argument 449 def _batch_accumulator(cls, primals, tangents): argument
|
H A D | forwardprop_test.py | 56 def _jvp(f, primals, tangents): argument 64 def _jacfwd(f, primals): argument 98 def _jvp_batch_matmul(f, primals, tangent_batch): argument 145 def _hvp(f, primals, tangents): argument 201 primals, argument
|
/aosp_15_r20/external/pytorch/torch/_functorch/ |
H A D | eager_transforms.py | 80 def _jvp_treespec_compare(primals, tangents): argument 93 def _linearize_treespec_compare(primals, tangents): argument 242 def vjp(func: Callable, *primals, has_aux: bool = False): 373 func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False 1723 def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
|
H A D | deprecated.py | 82 def vjp(func: Callable, *primals, has_aux: bool = False):
|
/aosp_15_r20/external/pytorch/test/ |
H A D | test_decomp.py | 145 def ref_vjp_no_create(f, *primals): 323 def wrapped(*primals):
|