1// Copyright 2024 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package exithook provides limited support for on-exit cleanup.
6//
// CAREFUL! The expectation is that Add should only be called
// from a safe context (e.g. not an error/panic path or signal
// handler, preemption enabled, allocation allowed, write barriers
// allowed, etc), and that the exit function F will be invoked under
// similar circumstances. That is to say, we are expecting that F
// uses normal / high-level Go code as opposed to one of the more
// restricted dialects used for the trickier parts of the runtime.
14package exithook
15
16import (
17	"internal/runtime/atomic"
18	_ "unsafe" // for linkname
19)
20
// A Hook is a function to be run at program termination
// (when someone invokes os.Exit, or when main.main returns).
// Hooks are run in reverse order of registration:
// the first hook added is the last one run.
//
// Hooks are registered with Add and executed by Run. Run calls F with no
// arguments, directly on the goroutine that invoked Run. When Run is given
// a non-zero exit code, a hook whose RunOnFailure is false is skipped; it
// is still removed from the pending list, so it will not run later either.
type Hook struct {
	F            func() // func to run
	RunOnFailure bool   // whether to run on non-zero exit code
}
29
var (
	// Package state shared by Add and Run:
	//   - locked is a spin lock guarding hooks (0 = free, 1 = held). It is
	//     acquired with CompareAndSwap(0, 1), yielding via Gosched while
	//     contended, and released with Store(0). Run holds it for the
	//     entire time it is executing hooks.
	//   - runGoid holds the ID (as returned by Goid) of the goroutine that
	//     is currently inside Run, or 0 when no Run is in progress. Run
	//     compares it with Goid() to detect a hook re-entering Run on the
	//     same goroutine.
	//   - hooks is the list of registered hooks: Add appends to it, and Run
	//     pops from its end so the most recently added hook runs first.
	//   - NOTE(review): running is neither read nor written in this file;
	//     possibly unused or used by another file of the package — confirm.
	locked  atomic.Int32
	runGoid atomic.Uint64
	hooks   []Hook
	running bool

	// runtime sets these for us
	//
	// Gosched is called to yield while spinning on locked, Goid is used to
	// identify the calling goroutine, and Throw is called with a message
	// when misuse is detected (presumably a fatal error that does not
	// return — TODO confirm against the runtime). All three are assumed to
	// be non-nil by the time Add or Run can hit those paths.
	Gosched func()
	Goid    func() uint64
	Throw   func(string)
)
41
42// Add adds a new exit hook.
43func Add(h Hook) {
44	for !locked.CompareAndSwap(0, 1) {
45		Gosched()
46	}
47	hooks = append(hooks, h)
48	locked.Store(0)
49}
50
51// Run runs the exit hooks.
52//
53// If an exit hook panics, Run will throw with the panic on the stack.
54// If an exit hook invokes exit in the same goroutine, the goroutine will throw.
55// If an exit hook invokes exit in another goroutine, that exit will block.
56func Run(code int) {
57	for !locked.CompareAndSwap(0, 1) {
58		if Goid() == runGoid.Load() {
59			Throw("exit hook invoked exit")
60		}
61		Gosched()
62	}
63	defer locked.Store(0)
64	runGoid.Store(Goid())
65	defer runGoid.Store(0)
66
67	defer func() {
68		if e := recover(); e != nil {
69			Throw("exit hook invoked panic")
70		}
71	}()
72
73	for len(hooks) > 0 {
74		h := hooks[len(hooks)-1]
75		hooks = hooks[:len(hooks)-1]
76		if code != 0 && !h.RunOnFailure {
77			continue
78		}
79		h.F()
80	}
81}
82
// exitError is a string that satisfies the error interface; its error
// message is the string value itself.
type exitError string

// Error returns the underlying string unchanged.
func (s exitError) Error() string {
	return string(s)
}
86