1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package zip
6
7import (
8	"compress/flate"
9	"errors"
10	"io"
11	"sync"
12)
13
// A Compressor returns a new compressing writer, writing to w.
// The [io.WriteCloser]'s Close method must be used to flush pending data to w.
// The Compressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned writer will be used only by
// one goroutine at a time.
type Compressor func(w io.Writer) (io.WriteCloser, error)
20
// A Decompressor returns a new decompressing reader, reading from r.
// The [io.ReadCloser]'s Close method must be used to release associated resources.
// The Decompressor itself must be safe to invoke from multiple goroutines
// simultaneously, but each returned reader will be used only by
// one goroutine at a time.
//
// Unlike a [Compressor], a Decompressor has no error result, so it cannot
// report a failure at construction time.
type Decompressor func(r io.Reader) io.ReadCloser
27
// flateWriterPool recycles *flate.Writer values. pooledFlateWriter.Close puts
// its writer here, and newFlateWriter takes one out (resetting it onto the
// new destination) before falling back to allocating a fresh writer.
var flateWriterPool sync.Pool
29
30func newFlateWriter(w io.Writer) io.WriteCloser {
31	fw, ok := flateWriterPool.Get().(*flate.Writer)
32	if ok {
33		fw.Reset(w)
34	} else {
35		fw, _ = flate.NewWriter(w, 5)
36	}
37	return &pooledFlateWriter{fw: fw}
38}
39
// pooledFlateWriter is the io.WriteCloser returned by newFlateWriter. It
// forwards writes to a *flate.Writer and, on Close, returns that writer to
// flateWriterPool.
type pooledFlateWriter struct {
	mu sync.Mutex // guards Close and Write
	// fw is set to nil by Close; Write reports an error once it is nil.
	fw *flate.Writer
}
44
45func (w *pooledFlateWriter) Write(p []byte) (n int, err error) {
46	w.mu.Lock()
47	defer w.mu.Unlock()
48	if w.fw == nil {
49		return 0, errors.New("Write after Close")
50	}
51	return w.fw.Write(p)
52}
53
54func (w *pooledFlateWriter) Close() error {
55	w.mu.Lock()
56	defer w.mu.Unlock()
57	var err error
58	if w.fw != nil {
59		err = w.fw.Close()
60		flateWriterPool.Put(w.fw)
61		w.fw = nil
62	}
63	return err
64}
65
// flateReaderPool recycles flate readers (stored as io.ReadCloser values).
// pooledFlateReader.Close puts its reader here, and newFlateReader takes one
// out (resetting it onto the new source) before falling back to allocating a
// fresh reader.
var flateReaderPool sync.Pool
67
68func newFlateReader(r io.Reader) io.ReadCloser {
69	fr, ok := flateReaderPool.Get().(io.ReadCloser)
70	if ok {
71		fr.(flate.Resetter).Reset(r, nil)
72	} else {
73		fr = flate.NewReader(r)
74	}
75	return &pooledFlateReader{fr: fr}
76}
77
// pooledFlateReader is the io.ReadCloser returned by newFlateReader. It
// forwards reads to a flate reader and, on Close, returns that reader to
// flateReaderPool.
type pooledFlateReader struct {
	mu sync.Mutex // guards Close and Read
	// fr is set to nil by Close; Read reports an error once it is nil.
	fr io.ReadCloser
}
82
83func (r *pooledFlateReader) Read(p []byte) (n int, err error) {
84	r.mu.Lock()
85	defer r.mu.Unlock()
86	if r.fr == nil {
87		return 0, errors.New("Read after Close")
88	}
89	return r.fr.Read(p)
90}
91
92func (r *pooledFlateReader) Close() error {
93	r.mu.Lock()
94	defer r.mu.Unlock()
95	var err error
96	if r.fr != nil {
97		err = r.fr.Close()
98		flateReaderPool.Put(r.fr)
99		r.fr = nil
100	}
101	return err
102}
103
// compressors and decompressors are the registries of compression methods,
// keyed by method ID. init seeds both with the Store and Deflate entries;
// RegisterCompressor and RegisterDecompressor add further entries, and the
// compressor and decompressor helpers look entries up.
var (
	compressors   sync.Map // map[uint16]Compressor
	decompressors sync.Map // map[uint16]Decompressor
)
108
109func init() {
110	compressors.Store(Store, Compressor(func(w io.Writer) (io.WriteCloser, error) { return &nopCloser{w}, nil }))
111	compressors.Store(Deflate, Compressor(func(w io.Writer) (io.WriteCloser, error) { return newFlateWriter(w), nil }))
112
113	decompressors.Store(Store, Decompressor(io.NopCloser))
114	decompressors.Store(Deflate, Decompressor(newFlateReader))
115}
116
117// RegisterDecompressor allows custom decompressors for a specified method ID.
118// The common methods [Store] and [Deflate] are built in.
119func RegisterDecompressor(method uint16, dcomp Decompressor) {
120	if _, dup := decompressors.LoadOrStore(method, dcomp); dup {
121		panic("decompressor already registered")
122	}
123}
124
125// RegisterCompressor registers custom compressors for a specified method ID.
126// The common methods [Store] and [Deflate] are built in.
127func RegisterCompressor(method uint16, comp Compressor) {
128	if _, dup := compressors.LoadOrStore(method, comp); dup {
129		panic("compressor already registered")
130	}
131}
132
133func compressor(method uint16) Compressor {
134	ci, ok := compressors.Load(method)
135	if !ok {
136		return nil
137	}
138	return ci.(Compressor)
139}
140
141func decompressor(method uint16) Decompressor {
142	di, ok := decompressors.Load(method)
143	if !ok {
144		return nil
145	}
146	return di.(Decompressor)
147}
148