1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package saferio provides I/O functions that avoid allocating large
6// amounts of memory unnecessarily. This is intended for packages that
7// read data from an [io.Reader] where the size is part of the input
8// data but the input may be corrupt, or may be provided by an
9// untrustworthy attacker.
10package saferio
11
12import (
13	"io"
14	"unsafe"
15)
16
// chunk is the largest number of bytes the functions in this package
// will allocate in a single step before any matching input has been
// seen. The exact value is arbitrary.
const chunk = 10 * (1 << 20) // 10 MiB
20
21// ReadData reads n bytes from the input stream, but avoids allocating
22// all n bytes if n is large. This avoids crashing the program by
23// allocating all n bytes in cases where n is incorrect.
24//
25// The error is io.EOF only if no bytes were read.
26// If an io.EOF happens after reading some but not all the bytes,
27// ReadData returns io.ErrUnexpectedEOF.
28func ReadData(r io.Reader, n uint64) ([]byte, error) {
29	if int64(n) < 0 || n != uint64(int(n)) {
30		// n is too large to fit in int, so we can't allocate
31		// a buffer large enough. Treat this as a read failure.
32		return nil, io.ErrUnexpectedEOF
33	}
34
35	if n < chunk {
36		buf := make([]byte, n)
37		_, err := io.ReadFull(r, buf)
38		if err != nil {
39			return nil, err
40		}
41		return buf, nil
42	}
43
44	var buf []byte
45	buf1 := make([]byte, chunk)
46	for n > 0 {
47		next := n
48		if next > chunk {
49			next = chunk
50		}
51		_, err := io.ReadFull(r, buf1[:next])
52		if err != nil {
53			if len(buf) > 0 && err == io.EOF {
54				err = io.ErrUnexpectedEOF
55			}
56			return nil, err
57		}
58		buf = append(buf, buf1[:next]...)
59		n -= next
60	}
61	return buf, nil
62}
63
64// ReadDataAt reads n bytes from the input stream at off, but avoids
65// allocating all n bytes if n is large. This avoids crashing the program
66// by allocating all n bytes in cases where n is incorrect.
67func ReadDataAt(r io.ReaderAt, n uint64, off int64) ([]byte, error) {
68	if int64(n) < 0 || n != uint64(int(n)) {
69		// n is too large to fit in int, so we can't allocate
70		// a buffer large enough. Treat this as a read failure.
71		return nil, io.ErrUnexpectedEOF
72	}
73
74	if n < chunk {
75		buf := make([]byte, n)
76		_, err := r.ReadAt(buf, off)
77		if err != nil {
78			// io.SectionReader can return EOF for n == 0,
79			// but for our purposes that is a success.
80			if err != io.EOF || n > 0 {
81				return nil, err
82			}
83		}
84		return buf, nil
85	}
86
87	var buf []byte
88	buf1 := make([]byte, chunk)
89	for n > 0 {
90		next := n
91		if next > chunk {
92			next = chunk
93		}
94		_, err := r.ReadAt(buf1[:next], off)
95		if err != nil {
96			return nil, err
97		}
98		buf = append(buf, buf1[:next]...)
99		n -= next
100		off += int64(next)
101	}
102	return buf, nil
103}
104
105// SliceCapWithSize returns the capacity to use when allocating a slice.
106// After the slice is allocated with the capacity, it should be
107// built using append. This will avoid allocating too much memory
108// if the capacity is large and incorrect.
109//
110// A negative result means that the value is always too big.
111func SliceCapWithSize(size, c uint64) int {
112	if int64(c) < 0 || c != uint64(int(c)) {
113		return -1
114	}
115	if size > 0 && c > (1<<64-1)/size {
116		return -1
117	}
118	if c*size > chunk {
119		c = chunk / size
120		if c == 0 {
121			c = 1
122		}
123	}
124	return int(c)
125}
126
127// SliceCap is like SliceCapWithSize but using generics.
128func SliceCap[E any](c uint64) int {
129	var v E
130	size := uint64(unsafe.Sizeof(v))
131	return SliceCapWithSize(size, c)
132}
133