xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/worker.proto (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1syntax = "proto3";
2
3package tensorflow.data;
4
5import "tensorflow/core/data/dataset.proto";
6import "tensorflow/core/data/service/common.proto";
7
// Request sent to WorkerService.ProcessTask, asking the worker to process a
// dataset task.
message ProcessTaskRequest {
  // Definition of the task the worker should process.
  TaskDef task = 1;
}
11
// Response to WorkerService.ProcessTask. Intentionally empty. It is a
// dedicated message (rather than google.protobuf.Empty) so that fields can be
// added later without changing the RPC signature.
message ProcessTaskResponse {}
13
// Request sent to WorkerService.GetElement to fetch the next element of a task.
message GetElementRequest {
  // ID of the task to read an element from.
  int64 task_id = 1;

  // Optional index identifying the consumer that issues this request. The
  // single-field oneof gives explicit presence, so "not set" can be told
  // apart from an index of 0.
  oneof optional_consumer_index {
    int64 consumer_index = 2;
  }

  // Optional round-robin round the consumer wants to read from; used to keep
  // consumers in sync. As above, the single-field oneof distinguishes
  // "not set" from round 0.
  oneof optional_round_index {
    int64 round_index = 3;
  }

  // True when the previous round was skipped. The worker needs this to
  // recover after restarts.
  bool skipped_previous_round = 4;

  // True to allow the round to be skipped when data isn't ready fast enough.
  bool allow_skip = 5;

  // Trainer ID used when reading from a multi-trainer cache, which shares data
  // across concurrent training iterations. When set, the request reads data
  // already requested by other trainers, if available.
  string trainer_id = 6;
}
36
// Response to WorkerService.GetElement: the produced element (if any) along
// with status flags describing the read.
message GetElementResponse {
  // Field number 1 is not used by any current field of this message. Reserve
  // it so it can never be reassigned with a different meaning, which would
  // silently misinterpret data from older peers.
  // NOTE(review): presumably a removed field — confirm against history.
  reserved 1;

  // The produced element, in either compressed or uncompressed form. At most
  // one of the two is set.
  oneof element {
    CompressedElement compressed = 3;
    UncompressedElement uncompressed = 5;
  }
  // The element's index within the task it came from.
  int64 element_index = 6;
  // Whether the iterator has been exhausted.
  bool end_of_sequence = 2;
  // Whether the round was skipped.
  // NOTE(review): presumably the counterpart of GetElementRequest.allow_skip;
  // despite the field name, it refers to the round — confirm.
  bool skip_task = 4;
}
50
// Request sent to WorkerService.GetWorkerTasks. Named GetWorkerTasks instead
// of GetTasks to avoid conflicting with GetTasks in dispatcher.proto.
message GetWorkerTasksRequest {}
53
// Response to WorkerService.GetWorkerTasks.
message GetWorkerTasksResponse {
  // Information about the tasks currently being executed by the worker.
  repeated TaskInfo tasks = 1;
}
57
// RPC interface exposed by a tf.data service worker.
service WorkerService {
  // Processes a dataset task so that its elements become available to clients.
  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

  // Fetches the next element of a dataset task.
  rpc GetElement(GetElementRequest) returns (GetElementResponse);

  // Lists the tasks this worker is currently executing.
  rpc GetWorkerTasks(GetWorkerTasksRequest) returns (GetWorkerTasksResponse);
}
68