Home
last modified time | relevance | path

Searched defs:OpSchema (Results 1 – 18 of 18) sorted by relevance

/aosp_15_r20/external/pytorch/torch/distributed/tensor/_ops/
H A D_tensor_ops.py42 def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
83 def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
141 def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
182 def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
226 def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
244 def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
314 def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
359 def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
369 def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
392 def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
[all …]
H A D_matrix_ops.py33 def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
58 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
88 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
142 def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
147 def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
152 def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
157 def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
165 mesh: DeviceMesh, op_schema: OpSchema
248 mesh: DeviceMesh, op_schema: OpSchema
320 def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
[all …]
H A D_math_ops.py332 def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
358 def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
377 def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
400 def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy:
437 def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
469 def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
510 def softmax_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
551 def nll_loss_forward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
672 def nll_loss_backward_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
784 def layer_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
[all …]
H A D_pointwise_ops.py420 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
519 def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
601 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
656 mesh: DeviceMesh, op_schema: OpSchema
H A D_common_rules.py21 op_schema: OpSchema,
48 op_schema: OpSchema,
229 def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding:
H A D_conv_ops.py16 def convolution_rules(op_schema: OpSchema) -> OutputSharding:
69 def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding:
H A D_embedding_ops.py190 def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
239 mesh: DeviceMesh, op_schema: OpSchema
H A D_random_ops.py26 def random_op_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
H A D_experimental_ops.py22 def slice_backward_rules(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
H A Dutils.py221 op_schema: OpSchema,
H A D_view_ops.py595 def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
/aosp_15_r20/external/pytorch/torch/distributed/tensor/
H A D_sharding_prop.py84 rule_func: Callable[[OpSchema], OutputSharding],
97 strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType], argument
109 self, op_schema: OpSchema
205 def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
473 schema: OpSchema,
H A D_op_schema.py214 class OpSchema: class
405 def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
H A D_dispatch.py301 suggested_input_schema: OpSchema,
/aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/
H A D_register_sharding.py74 op_schema: OpSchema,
H A D_tp_transform.py310 op_schema: OpSchema,
/aosp_15_r20/external/pytorch/torch/onnx/_internal/exporter/
H A D_schemas.py379 def from_opschema(cls, opschema: onnx.defs.OpSchema) -> OpSignature:
H A D_building.py549 schema: onnx.defs.OpSchema,