1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build linux
6
7package syscall_test
8
9import (
10	"bytes"
11	"net"
12	"os"
13	"syscall"
14	"testing"
15)
16
// TestSCMCredentials tests the sending and receiving of credentials
// (PID, UID, GID) in an ancillary message between two UNIX
// sockets. The SO_PASSCRED socket option is enabled on the receiving
// socket (fds[0], wrapped as the server end) for this to work; without
// it ReadMsgUnix returns zero out-of-band bytes.
//
// Each socket type runs as its own subtest so that the sockets and files
// opened for one type are closed by their defers before the next type
// starts, instead of accumulating until the whole test returns.
func TestSCMCredentials(t *testing.T) {
	socketTypeTests := []struct {
		name       string
		socketType int
		dataLen    int // data bytes ReadMsgUnix is expected to report
	}{
		{
			// For SOCK_STREAM, WriteMsgUnix sends a dummy data byte along
			// with the control message, so one byte is read back.
			"SOCK_STREAM",
			syscall.SOCK_STREAM,
			1,
		}, {
			"SOCK_DGRAM",
			syscall.SOCK_DGRAM,
			0,
		},
	}

	for _, tt := range socketTypeTests {
		t.Run(tt.name, func(t *testing.T) {
			fds, err := syscall.Socketpair(syscall.AF_LOCAL, tt.socketType, 0)
			if err != nil {
				t.Fatalf("Socketpair: %v", err)
			}

			// Enable credential passing on the end that will read.
			err = syscall.SetsockoptInt(fds[0], syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
			if err != nil {
				syscall.Close(fds[0])
				syscall.Close(fds[1])
				t.Fatalf("SetsockoptInt: %v", err)
			}

			srvFile := os.NewFile(uintptr(fds[0]), "server")
			cliFile := os.NewFile(uintptr(fds[1]), "client")
			defer srvFile.Close()
			defer cliFile.Close()

			srv, err := net.FileConn(srvFile)
			if err != nil {
				t.Fatalf("FileConn: %v", err)
			}
			defer srv.Close()

			cli, err := net.FileConn(cliFile)
			if err != nil {
				t.Fatalf("FileConn: %v", err)
			}
			defer cli.Close()

			var ucred syscall.Ucred
			if os.Getuid() != 0 {
				// As a non-root user, claiming UID/GID 0 in the credentials
				// is expected to be rejected with EPERM or EINVAL.
				ucred.Pid = int32(os.Getpid())
				ucred.Uid = 0
				ucred.Gid = 0
				oob := syscall.UnixCredentials(&ucred)
				_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
				// Unwrap *net.OpError and *os.SyscallError down to the errno.
				if op, ok := err.(*net.OpError); ok {
					err = op.Err
				}
				if sys, ok := err.(*os.SyscallError); ok {
					err = sys.Err
				}
				switch err {
				case syscall.EPERM, syscall.EINVAL:
				default:
					t.Fatalf("WriteMsgUnix failed with %v, want EPERM or EINVAL", err)
				}
			}

			// Send this process's real credentials; these must round-trip.
			ucred.Pid = int32(os.Getpid())
			ucred.Uid = uint32(os.Getuid())
			ucred.Gid = uint32(os.Getgid())
			oob := syscall.UnixCredentials(&ucred)

			// On SOCK_STREAM, this is internally going to send a dummy byte.
			n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
			if err != nil {
				t.Fatalf("WriteMsgUnix: %v", err)
			}
			if n != 0 {
				t.Fatalf("WriteMsgUnix n = %d, want 0", n)
			}
			if oobn != len(oob) {
				t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
			}

			oob2 := make([]byte, 10*len(oob))
			n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
			if err != nil {
				t.Fatalf("ReadMsgUnix: %v", err)
			}
			if flags != syscall.MSG_CMSG_CLOEXEC {
				t.Fatalf("ReadMsgUnix flags = %#x, want %#x (MSG_CMSG_CLOEXEC)", flags, syscall.MSG_CMSG_CLOEXEC)
			}
			if n != tt.dataLen {
				t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
			}
			if oobn2 != oobn {
				// without SO_PASSCRED set on the socket, ReadMsgUnix will
				// return zero oob bytes
				t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
			}
			oob2 = oob2[:oobn2]
			if !bytes.Equal(oob, oob2) {
				t.Fatal("ReadMsgUnix oob bytes don't match")
			}

			scm, err := syscall.ParseSocketControlMessage(oob2)
			if err != nil {
				t.Fatalf("ParseSocketControlMessage: %v", err)
			}
			// Guard the scm[0] access below so a malformed result fails
			// the test instead of panicking.
			if len(scm) != 1 {
				t.Fatalf("ParseSocketControlMessage returned %d messages, want 1", len(scm))
			}
			newUcred, err := syscall.ParseUnixCredentials(&scm[0])
			if err != nil {
				t.Fatalf("ParseUnixCredentials: %v", err)
			}
			if *newUcred != ucred {
				t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
			}
		})
	}
}
138