Home
last modified time | relevance | path

Searched defs:primals (Results 1 – 7 of 7) sorted by relevance

/aosp_15_r20/external/pytorch/test/functorch/
H A Dtest_ops.py143 def wrapped(*primals):
174 def wrapped(*primals):
201 def ref_vjp(f, *primals):
210 def simulate_jvp(f, primals, tangents): argument
215 def ref_jvp(f, primals, tangents): argument
267 def _get_vjpfull_variant(fn, primals): argument
327 def _get_jvp_variant(fn, primals, tangents): argument
1742 def get_vjp(cotangents, *primals):
1867 def push_vjp(primals, cotangents): argument
1883 def reference(primals, cotangents, primals_tangents, cotangents_tangents): argument
[all …]
/aosp_15_r20/external/pytorch/torch/_functorch/_aot_autograd/
H A Dtraced_function_transforms.py315 primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset argument
643 def joint_helper(primals, tangents): argument
784 def joint_fn(primals, tangents): argument
790 def fw_fn(*primals):
794 def metadata_fn(*primals):
/aosp_15_r20/external/tensorflow/tensorflow/python/eager/
H A Dforwardprop.py328 def __init__(self, primals, tangents): argument
382 def _watch(self, primals, tangents): argument
413 def jvp(self, primals, unconnected_gradients=UnconnectedGradients.NONE): argument
449 def _batch_accumulator(cls, primals, tangents): argument
H A Dforwardprop_test.py56 def _jvp(f, primals, tangents): argument
64 def _jacfwd(f, primals): argument
98 def _jvp_batch_matmul(f, primals, tangent_batch): argument
145 def _hvp(f, primals, tangents): argument
201 primals, argument
/aosp_15_r20/external/pytorch/torch/_functorch/
H A Deager_transforms.py80 def _jvp_treespec_compare(primals, tangents): argument
93 def _linearize_treespec_compare(primals, tangents): argument
242 def vjp(func: Callable, *primals, has_aux: bool = False):
373 func: Callable, *primals, argnums: Optional[argnums_t] = None, has_aux: bool = False
1723 def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
H A Ddeprecated.py82 def vjp(func: Callable, *primals, has_aux: bool = False):
/aosp_15_r20/external/pytorch/test/
H A Dtest_decomp.py145 def ref_vjp_no_create(f, *primals):
323 def wrapped(*primals):