1 //! Core task module.
2 //!
3 //! # Safety
4 //!
5 //! The functions in this module are private to the `task` module. All of them
6 //! should be considered `unsafe` to use, but are not marked as such since it
7 //! would be too noisy.
8 //!
9 //! Make sure to consult the relevant safety section of each function before
10 //! use.
11 
12 use crate::future::Future;
13 use crate::loom::cell::UnsafeCell;
14 use crate::runtime::context;
15 use crate::runtime::task::raw::{self, Vtable};
16 use crate::runtime::task::state::State;
17 use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks};
18 use crate::util::linked_list;
19 
20 use std::num::NonZeroU64;
21 use std::pin::Pin;
22 use std::ptr::NonNull;
23 use std::task::{Context, Poll, Waker};
24 
25 /// The task cell. Contains the components of the task.
26 ///
27 /// It is critical for `Header` to be the first field as the task structure will
28 /// be referenced by both *mut Cell and *mut Header.
29 ///
30 /// Any changes to the layout of this struct _must_ also be reflected in the
31 /// `const` fns in raw.rs.
32 ///
33 // # This struct should be cache padded to avoid false sharing. The cache padding rules are copied
34 // from crossbeam-utils/src/cache_padded.rs
35 //
36 // Starting from Intel's Sandy Bridge, spatial prefetcher is now pulling pairs of 64-byte cache
37 // lines at a time, so we have to align to 128 bytes rather than 64.
38 //
39 // Sources:
40 // - https://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-optimization-manual.pdf
41 // - https://github.com/facebook/folly/blob/1b5288e6eea6df074758f877c849b6e73bbb9fbb/folly/lang/Align.h#L107
42 //
43 // ARM's big.LITTLE architecture has asymmetric cores and "big" cores have 128-byte cache line size.
44 //
45 // Sources:
46 // - https://www.mono-project.com/news/2016/09/12/arm64-icache/
47 //
48 // powerpc64 has 128-byte cache line size.
49 //
50 // Sources:
51 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_ppc64x.go#L9
52 #[cfg_attr(
53     any(
54         target_arch = "x86_64",
55         target_arch = "aarch64",
56         target_arch = "powerpc64",
57     ),
58     repr(align(128))
59 )]
60 // arm, mips, mips64, sparc, and hexagon have 32-byte cache line size.
61 //
62 // Sources:
63 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_arm.go#L7
64 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips.go#L7
65 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mipsle.go#L7
66 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_mips64x.go#L9
67 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L17
68 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/hexagon/include/asm/cache.h#L12
69 #[cfg_attr(
70     any(
71         target_arch = "arm",
72         target_arch = "mips",
73         target_arch = "mips64",
74         target_arch = "sparc",
75         target_arch = "hexagon",
76     ),
77     repr(align(32))
78 )]
79 // m68k has 16-byte cache line size.
80 //
81 // Sources:
82 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/m68k/include/asm/cache.h#L9
83 #[cfg_attr(target_arch = "m68k", repr(align(16)))]
84 // s390x has 256-byte cache line size.
85 //
86 // Sources:
87 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_s390x.go#L7
88 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/s390/include/asm/cache.h#L13
89 #[cfg_attr(target_arch = "s390x", repr(align(256)))]
90 // x86, riscv, wasm, and sparc64 have 64-byte cache line size.
91 //
92 // Sources:
93 // - https://github.com/golang/go/blob/dda2991c2ea0c5914714469c4defc2562a907230/src/internal/cpu/cpu_x86.go#L9
94 // - https://github.com/golang/go/blob/3dd58676054223962cd915bb0934d1f9f489d4d2/src/internal/cpu/cpu_wasm.go#L7
95 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/sparc/include/asm/cache.h#L19
96 // - https://github.com/torvalds/linux/blob/3516bd729358a2a9b090c1905bd2a3fa926e24c6/arch/riscv/include/asm/cache.h#L10
97 //
98 // All others are assumed to have 64-byte cache line size.
99 #[cfg_attr(
100     not(any(
101         target_arch = "x86_64",
102         target_arch = "aarch64",
103         target_arch = "powerpc64",
104         target_arch = "arm",
105         target_arch = "mips",
106         target_arch = "mips64",
107         target_arch = "sparc",
108         target_arch = "hexagon",
109         target_arch = "m68k",
110         target_arch = "s390x",
111     )),
112     repr(align(64))
113 )]
114 #[repr(C)]
115 pub(super) struct Cell<T: Future, S> {
116     /// Hot task state data
117     pub(super) header: Header,
118 
119     /// Either the future or output, depending on the execution stage.
120     pub(super) core: Core<T, S>,
121 
122     /// Cold data
123     pub(super) trailer: Trailer,
124 }
125 
126 pub(super) struct CoreStage<T: Future> {
127     stage: UnsafeCell<Stage<T>>,
128 }
129 
130 /// The core of the task.
131 ///
132 /// Holds the future or output, depending on the stage of execution.
133 ///
134 /// Any changes to the layout of this struct _must_ also be reflected in the
135 /// `const` fns in raw.rs.
136 #[repr(C)]
137 pub(super) struct Core<T: Future, S> {
138     /// Scheduler used to drive this future.
139     pub(super) scheduler: S,
140 
141     /// The task's ID, used for populating `JoinError`s.
142     pub(super) task_id: Id,
143 
144     /// Either the future or the output.
145     pub(super) stage: CoreStage<T>,
146 }
147 
148 /// Crate public as this is also needed by the pool.
149 #[repr(C)]
150 pub(crate) struct Header {
151     /// Task state.
152     pub(super) state: State,
153 
154     /// Pointer to next task, used with the injection queue.
155     pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>,
156 
157     /// Table of function pointers for executing actions on the task.
158     pub(super) vtable: &'static Vtable,
159 
160     /// This integer contains the id of the `OwnedTasks` or `LocalOwnedTasks`
161     /// that this task is stored in. If the task is not in any list, should be
162     /// the id of the list that it was previously in, or `None` if it has never
163     /// been in any list.
164     ///
165     /// Once a task has been bound to a list, it can never be bound to another
166     /// list, even if removed from the first list.
167     ///
168     /// The id is not unset when removed from a list because we want to be able
169     /// to read the id without synchronization, even if it is concurrently being
170     /// removed from the list.
171     pub(super) owner_id: UnsafeCell<Option<NonZeroU64>>,
172 
173     /// The tracing ID for this instrumented task.
174     #[cfg(all(tokio_unstable, feature = "tracing"))]
175     pub(super) tracing_id: Option<tracing::Id>,
176 }
177 
178 unsafe impl Send for Header {}
179 unsafe impl Sync for Header {}
180 
/// Cold data is stored after the future. Data is considered cold if it is only
/// used during creation or shutdown of the task.
pub(super) struct Trailer {
    /// Pointers for the linked list in the `OwnedTasks` that owns this task.
    pub(super) owned: linked_list::Pointers<Header>,
    /// Consumer task waiting on completion of this task.
    ///
    /// Written by `set_waker`; read by `will_wake` and `wake_join`.
    pub(super) waker: UnsafeCell<Option<Waker>>,
    /// Optional hooks needed in the harness.
    ///
    /// Obtained from the scheduler (`scheduler.hooks()`) when the task cell
    /// is created.
    pub(super) hooks: TaskHarnessScheduleHooks,
}
191 
// Generates `Trailer::addr_of_owned`, which turns a pointer to a `Trailer`
// into a pointer to its `owned` linked-list pointers.
//
// NOTE(review): presumably the macro computes the field address without
// materializing a reference to the whole `Trailer` — confirm against the
// definition of `generate_addr_of_methods!`.
generate_addr_of_methods! {
    impl<> Trailer {
        pub(super) unsafe fn addr_of_owned(self: NonNull<Self>) -> NonNull<linked_list::Pointers<Header>> {
            &self.owned
        }
    }
}
199 
200 /// Either the future or the output.
201 #[repr(C)] // https://github.com/rust-lang/miri/issues/3780
202 pub(super) enum Stage<T: Future> {
203     Running(T),
204     Finished(super::Result<T::Output>),
205     Consumed,
206 }
207 
208 impl<T: Future, S: Schedule> Cell<T, S> {
209     /// Allocates a new task cell, containing the header, trailer, and core
210     /// structures.
new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>>211     pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
212         // Separated into a non-generic function to reduce LLVM codegen
213         fn new_header(
214             state: State,
215             vtable: &'static Vtable,
216             #[cfg(all(tokio_unstable, feature = "tracing"))] tracing_id: Option<tracing::Id>,
217         ) -> Header {
218             Header {
219                 state,
220                 queue_next: UnsafeCell::new(None),
221                 vtable,
222                 owner_id: UnsafeCell::new(None),
223                 #[cfg(all(tokio_unstable, feature = "tracing"))]
224                 tracing_id,
225             }
226         }
227 
228         #[cfg(all(tokio_unstable, feature = "tracing"))]
229         let tracing_id = future.id();
230         let vtable = raw::vtable::<T, S>();
231         let result = Box::new(Cell {
232             trailer: Trailer::new(scheduler.hooks()),
233             header: new_header(
234                 state,
235                 vtable,
236                 #[cfg(all(tokio_unstable, feature = "tracing"))]
237                 tracing_id,
238             ),
239             core: Core {
240                 scheduler,
241                 stage: CoreStage {
242                     stage: UnsafeCell::new(Stage::Running(future)),
243                 },
244                 task_id,
245             },
246         });
247 
248         #[cfg(debug_assertions)]
249         {
250             // Using a separate function for this code avoids instantiating it separately for every `T`.
251             unsafe fn check<S>(header: &Header, trailer: &Trailer, scheduler: &S, task_id: &Id) {
252                 let trailer_addr = trailer as *const Trailer as usize;
253                 let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(header)) };
254                 assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);
255 
256                 let scheduler_addr = scheduler as *const S as usize;
257                 let scheduler_ptr = unsafe { Header::get_scheduler::<S>(NonNull::from(header)) };
258                 assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);
259 
260                 let id_addr = task_id as *const Id as usize;
261                 let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(header)) };
262                 assert_eq!(id_addr, id_ptr.as_ptr() as usize);
263             }
264             unsafe {
265                 check(
266                     &result.header,
267                     &result.trailer,
268                     &result.core.scheduler,
269                     &result.core.task_id,
270                 );
271             }
272         }
273 
274         result
275     }
276 }
277 
278 impl<T: Future> CoreStage<T> {
with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R279     pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R {
280         self.stage.with_mut(f)
281     }
282 }
283 
284 /// Set and clear the task id in the context when the future is executed or
285 /// dropped, or when the output produced by the future is dropped.
286 pub(crate) struct TaskIdGuard {
287     parent_task_id: Option<Id>,
288 }
289 
290 impl TaskIdGuard {
enter(id: Id) -> Self291     fn enter(id: Id) -> Self {
292         TaskIdGuard {
293             parent_task_id: context::set_current_task_id(Some(id)),
294         }
295     }
296 }
297 
298 impl Drop for TaskIdGuard {
drop(&mut self)299     fn drop(&mut self) {
300         context::set_current_task_id(self.parent_task_id);
301     }
302 }
303 
304 impl<T: Future, S: Schedule> Core<T, S> {
305     /// Polls the future.
306     ///
307     /// # Safety
308     ///
309     /// The caller must ensure it is safe to mutate the `state` field. This
310     /// requires ensuring mutual exclusion between any concurrent thread that
311     /// might modify the future or output field.
312     ///
313     /// The mutual exclusion is implemented by `Harness` and the `Lifecycle`
314     /// component of the task state.
315     ///
316     /// `self` must also be pinned. This is handled by storing the task on the
317     /// heap.
poll(&self, mut cx: Context<'_>) -> Poll<T::Output>318     pub(super) fn poll(&self, mut cx: Context<'_>) -> Poll<T::Output> {
319         let res = {
320             self.stage.stage.with_mut(|ptr| {
321                 // Safety: The caller ensures mutual exclusion to the field.
322                 let future = match unsafe { &mut *ptr } {
323                     Stage::Running(future) => future,
324                     _ => unreachable!("unexpected stage"),
325                 };
326 
327                 // Safety: The caller ensures the future is pinned.
328                 let future = unsafe { Pin::new_unchecked(future) };
329 
330                 let _guard = TaskIdGuard::enter(self.task_id);
331                 future.poll(&mut cx)
332             })
333         };
334 
335         if res.is_ready() {
336             self.drop_future_or_output();
337         }
338 
339         res
340     }
341 
342     /// Drops the future.
343     ///
344     /// # Safety
345     ///
346     /// The caller must ensure it is safe to mutate the `stage` field.
drop_future_or_output(&self)347     pub(super) fn drop_future_or_output(&self) {
348         // Safety: the caller ensures mutual exclusion to the field.
349         unsafe {
350             self.set_stage(Stage::Consumed);
351         }
352     }
353 
354     /// Stores the task output.
355     ///
356     /// # Safety
357     ///
358     /// The caller must ensure it is safe to mutate the `stage` field.
store_output(&self, output: super::Result<T::Output>)359     pub(super) fn store_output(&self, output: super::Result<T::Output>) {
360         // Safety: the caller ensures mutual exclusion to the field.
361         unsafe {
362             self.set_stage(Stage::Finished(output));
363         }
364     }
365 
366     /// Takes the task output.
367     ///
368     /// # Safety
369     ///
370     /// The caller must ensure it is safe to mutate the `stage` field.
take_output(&self) -> super::Result<T::Output>371     pub(super) fn take_output(&self) -> super::Result<T::Output> {
372         use std::mem;
373 
374         self.stage.stage.with_mut(|ptr| {
375             // Safety:: the caller ensures mutual exclusion to the field.
376             match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) {
377                 Stage::Finished(output) => output,
378                 _ => panic!("JoinHandle polled after completion"),
379             }
380         })
381     }
382 
set_stage(&self, stage: Stage<T>)383     unsafe fn set_stage(&self, stage: Stage<T>) {
384         let _guard = TaskIdGuard::enter(self.task_id);
385         self.stage.stage.with_mut(|ptr| *ptr = stage);
386     }
387 }
388 
389 impl Header {
set_next(&self, next: Option<NonNull<Header>>)390     pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) {
391         self.queue_next.with_mut(|ptr| *ptr = next);
392     }
393 
394     // safety: The caller must guarantee exclusive access to this field, and
395     // must ensure that the id is either `None` or the id of the OwnedTasks
396     // containing this task.
set_owner_id(&self, owner: NonZeroU64)397     pub(super) unsafe fn set_owner_id(&self, owner: NonZeroU64) {
398         self.owner_id.with_mut(|ptr| *ptr = Some(owner));
399     }
400 
get_owner_id(&self) -> Option<NonZeroU64>401     pub(super) fn get_owner_id(&self) -> Option<NonZeroU64> {
402         // safety: If there are concurrent writes, then that write has violated
403         // the safety requirements on `set_owner_id`.
404         unsafe { self.owner_id.with(|ptr| *ptr) }
405     }
406 
407     /// Gets a pointer to the `Trailer` of the task containing this `Header`.
408     ///
409     /// # Safety
410     ///
411     /// The provided raw pointer must point at the header of a task.
get_trailer(me: NonNull<Header>) -> NonNull<Trailer>412     pub(super) unsafe fn get_trailer(me: NonNull<Header>) -> NonNull<Trailer> {
413         let offset = me.as_ref().vtable.trailer_offset;
414         let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
415         NonNull::new_unchecked(trailer)
416     }
417 
418     /// Gets a pointer to the scheduler of the task containing this `Header`.
419     ///
420     /// # Safety
421     ///
422     /// The provided raw pointer must point at the header of a task.
423     ///
424     /// The generic type S must be set to the correct scheduler type for this
425     /// task.
get_scheduler<S>(me: NonNull<Header>) -> NonNull<S>426     pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
427         let offset = me.as_ref().vtable.scheduler_offset;
428         let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
429         NonNull::new_unchecked(scheduler)
430     }
431 
432     /// Gets a pointer to the id of the task containing this `Header`.
433     ///
434     /// # Safety
435     ///
436     /// The provided raw pointer must point at the header of a task.
get_id_ptr(me: NonNull<Header>) -> NonNull<Id>437     pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
438         let offset = me.as_ref().vtable.id_offset;
439         let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
440         NonNull::new_unchecked(id)
441     }
442 
443     /// Gets the id of the task containing this `Header`.
444     ///
445     /// # Safety
446     ///
447     /// The provided raw pointer must point at the header of a task.
get_id(me: NonNull<Header>) -> Id448     pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
449         let ptr = Header::get_id_ptr(me).as_ptr();
450         *ptr
451     }
452 
453     /// Gets the tracing id of the task containing this `Header`.
454     ///
455     /// # Safety
456     ///
457     /// The provided raw pointer must point at the header of a task.
458     #[cfg(all(tokio_unstable, feature = "tracing"))]
get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id>459     pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
460         me.as_ref().tracing_id.as_ref()
461     }
462 }
463 
464 impl Trailer {
new(hooks: TaskHarnessScheduleHooks) -> Self465     fn new(hooks: TaskHarnessScheduleHooks) -> Self {
466         Trailer {
467             waker: UnsafeCell::new(None),
468             owned: linked_list::Pointers::new(),
469             hooks,
470         }
471     }
472 
set_waker(&self, waker: Option<Waker>)473     pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) {
474         self.waker.with_mut(|ptr| {
475             *ptr = waker;
476         });
477     }
478 
will_wake(&self, waker: &Waker) -> bool479     pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool {
480         self.waker
481             .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker))
482     }
483 
wake_join(&self)484     pub(super) fn wake_join(&self) {
485         self.waker.with(|ptr| match unsafe { &*ptr } {
486             Some(waker) => waker.wake_by_ref(),
487             None => panic!("waker missing"),
488         });
489     }
490 }
491 
// Guards the size of `Header`: it must not exceed eight pointer widths
// (64 bytes on a target with 8-byte pointers). Not compiled under loom.
#[test]
#[cfg(not(loom))]
fn header_lte_cache_line() {
    assert!(std::mem::size_of::<Header>() <= 8 * std::mem::size_of::<*const ()>());
}
497