diff --git a/internal/zstd/zstd.go b/internal/zstd/zstd.go index 66e5091..68de426 100644 --- a/internal/zstd/zstd.go +++ b/internal/zstd/zstd.go @@ -17,7 +17,6 @@ package zstd import ( - "bytes" "errors" "io" "runtime" @@ -48,9 +47,14 @@ type decoderWrapper struct { *zstd.Decoder } +type encoderWrapper struct { + *zstd.Encoder + pool *sync.Pool +} + type compressor struct { - encoder *zstd.Encoder - decoderPool sync.Pool // To hold *zstd.Decoder's. + encoderPool sync.Pool + decoderPool sync.Pool } func PretendInit(clobbering bool) { @@ -58,11 +62,7 @@ func PretendInit(clobbering bool) { return } - enc, _ := zstd.NewWriter(nil, encoderOptions...) - c := &compressor{ - encoder: enc, - } - encoding.RegisterCompressor(c) + encoding.RegisterCompressor(&compressor{}) } var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compressor has been registered") @@ -71,40 +71,42 @@ var ErrNotInUse = errors.New("SetLevel ineffective because another zstd compress // level. NOTE: this function must only be called from an init function, and // is not threadsafe. func SetLevel(level zstd.EncoderLevel) error { - c, ok := encoding.GetCompressor(Name).(*compressor) + _, ok := encoding.GetCompressor(Name).(*compressor) if !ok { return ErrNotInUse } - enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level)) - if err != nil { - return err - } - - c.encoder = enc + encoderOptions = append(encoderOptions, zstd.WithEncoderLevel(level)) return nil } func (c *compressor) Compress(w io.Writer) (io.WriteCloser, error) { - return &zstdWriteCloser{ - enc: c.encoder, - writer: w, - }, nil -} + var err error + var found bool + var encoder *zstd.Encoder -type zstdWriteCloser struct { - enc *zstd.Encoder - writer io.Writer // Compressed data will be written here. - buf bytes.Buffer // Buffer uncompressed data here, compress on Close. -} + encoder, found = c.encoderPool.Get().(*zstd.Encoder) + if !found { + encoder, err = zstd.NewWriter(w, encoderOptions...) + if err != nil { + return nil, err + } + } else { + encoder.Reset(w) + } + + wrapper := &encoderWrapper{Encoder: encoder, pool: &c.encoderPool} + runtime.SetFinalizer(wrapper, func(ew *encoderWrapper) { + ew.Reset(nil) + c.encoderPool.Put(ew.Encoder) + }) -func (z *zstdWriteCloser) Write(p []byte) (int, error) { - return z.buf.Write(p) + return wrapper, nil } -func (z *zstdWriteCloser) Close() error { - compressed := z.enc.EncodeAll(z.buf.Bytes(), nil) - _, err := io.Copy(z.writer, bytes.NewReader(compressed)) +func (w *encoderWrapper) Close() error { + err := w.Encoder.Close() + w.pool.Put(w.Encoder) return err }