diff options
Diffstat (limited to 'vendor/github.com/DataDog/zstd/zstd_stream.go')
-rw-r--r-- | vendor/github.com/DataDog/zstd/zstd_stream.go | 294 |
1 files changed, 294 insertions, 0 deletions
diff --git a/vendor/github.com/DataDog/zstd/zstd_stream.go b/vendor/github.com/DataDog/zstd/zstd_stream.go new file mode 100644 index 000000000..233035352 --- /dev/null +++ b/vendor/github.com/DataDog/zstd/zstd_stream.go @@ -0,0 +1,294 @@ +package zstd + +/* +#define ZSTD_STATIC_LINKING_ONLY +#define ZBUFF_DISABLE_DEPRECATE_WARNINGS +#include "zstd.h" +#include "zbuff.h" +*/ +import "C" +import ( + "errors" + "fmt" + "io" + "runtime" + "unsafe" +) + +var errShortRead = errors.New("short read") + +// Writer is an io.WriteCloser that zstd-compresses its input. +type Writer struct { + CompressionLevel int + + ctx *C.ZSTD_CCtx + dict []byte + dstBuffer []byte + firstError error + underlyingWriter io.Writer +} + +func resize(in []byte, newSize int) []byte { + if in == nil { + return make([]byte, newSize) + } + if newSize <= cap(in) { + return in[:newSize] + } + toAdd := newSize - len(in) + return append(in, make([]byte, toAdd)...) +} + +// NewWriter creates a new Writer with default compression options. Writes to +// the writer will be written in compressed form to w. +func NewWriter(w io.Writer) *Writer { + return NewWriterLevelDict(w, DefaultCompression, nil) +} + +// NewWriterLevel is like NewWriter but specifies the compression level instead +// of assuming default compression. +// +// The level can be DefaultCompression or any integer value between BestSpeed +// and BestCompression inclusive. +func NewWriterLevel(w io.Writer, level int) *Writer { + return NewWriterLevelDict(w, level, nil) + +} + +// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to +// compress with. If the dictionary is empty or nil it is ignored. The dictionary +// should not be modified until the writer is closed. +func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer { + var err error + ctx := C.ZSTD_createCCtx() + + if dict == nil { + err = getError(int(C.ZSTD_compressBegin(ctx, + C.int(level)))) + } else { + err = getError(int(C.ZSTD_compressBegin_usingDict( + ctx, + unsafe.Pointer(&dict[0]), + C.size_t(len(dict)), + C.int(level)))) + } + + return &Writer{ + CompressionLevel: level, + ctx: ctx, + dict: dict, + dstBuffer: make([]byte, CompressBound(1024)), + firstError: err, + underlyingWriter: w, + } +} + +// Write writes a compressed form of p to the underlying io.Writer. +func (w *Writer) Write(p []byte) (int, error) { + if w.firstError != nil { + return 0, w.firstError + } + if len(p) == 0 { + return 0, nil + } + // Check if dstBuffer is enough + if len(w.dstBuffer) < CompressBound(len(p)) { + w.dstBuffer = make([]byte, CompressBound(len(p))) + } + + retCode := C.ZSTD_compressContinue( + w.ctx, + unsafe.Pointer(&w.dstBuffer[0]), + C.size_t(len(w.dstBuffer)), + unsafe.Pointer(&p[0]), + C.size_t(len(p))) + + if err := getError(int(retCode)); err != nil { + return 0, err + } + written := int(retCode) + + // Write to underlying buffer + _, err := w.underlyingWriter.Write(w.dstBuffer[:written]) + + // Same behaviour as zlib, we can't know how much data we wrote, only + // if there was an error + if err != nil { + return 0, err + } + return len(p), err +} + +// Close closes the Writer, flushing any unwritten data to the underlying +// io.Writer and freeing objects, but does not close the underlying io.Writer. +func (w *Writer) Close() error { + retCode := C.ZSTD_compressEnd( + w.ctx, + unsafe.Pointer(&w.dstBuffer[0]), + C.size_t(len(w.dstBuffer)), + unsafe.Pointer(nil), + C.size_t(0)) + + if err := getError(int(retCode)); err != nil { + return err + } + written := int(retCode) + retCode = C.ZSTD_freeCCtx(w.ctx) // Safely close buffer before writing the end + + if err := getError(int(retCode)); err != nil { + return err + } + + _, err := w.underlyingWriter.Write(w.dstBuffer[:written]) + if err != nil { + return err + } + return nil +} + +// reader is an io.ReadCloser that decompresses when read from. +type reader struct { + ctx *C.ZBUFF_DCtx + compressionBuffer []byte + compressionLeft int + decompressionBuffer []byte + decompOff int + decompSize int + dict []byte + firstError error + recommendedSrcSize int + underlyingReader io.Reader +} + +// NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser +// read and decompress data from r. It is the caller's responsibility to call +// Close on the ReadCloser when done. If this is not done, underlying objects +// in the zstd library will not be freed. +func NewReader(r io.Reader) io.ReadCloser { + return NewReaderDict(r, nil) +} + +// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict +// ignores the dictionary if it is nil. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + var err error + ctx := C.ZBUFF_createDCtx() + if len(dict) == 0 { + err = getError(int(C.ZBUFF_decompressInit(ctx))) + } else { + err = getError(int(C.ZBUFF_decompressInitDictionary( + ctx, + unsafe.Pointer(&dict[0]), + C.size_t(len(dict))))) + } + cSize := int(C.ZBUFF_recommendedDInSize()) + dSize := int(C.ZBUFF_recommendedDOutSize()) + if cSize <= 0 { + panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", cSize)) + } + if dSize <= 0 { + panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", dSize)) + } + + compressionBuffer := make([]byte, cSize) + decompressionBuffer := make([]byte, dSize) + return &reader{ + ctx: ctx, + dict: dict, + compressionBuffer: compressionBuffer, + decompressionBuffer: decompressionBuffer, + firstError: err, + recommendedSrcSize: cSize, + underlyingReader: r, + } +} + +// Close frees the allocated C objects +func (r *reader) Close() error { + return getError(int(C.ZBUFF_freeDCtx(r.ctx))) +} + +func (r *reader) Read(p []byte) (int, error) { + + // If we already have enough bytes, return + if r.decompSize-r.decompOff >= len(p) { + copy(p, r.decompressionBuffer[r.decompOff:]) + r.decompOff += len(p) + return len(p), nil + } + + copy(p, r.decompressionBuffer[r.decompOff:r.decompSize]) + got := r.decompSize - r.decompOff + r.decompSize = 0 + r.decompOff = 0 + + for got < len(p) { + // Populate src + src := r.compressionBuffer + reader := r.underlyingReader + n, err := TryReadFull(reader, src[r.compressionLeft:]) + if err != nil && err != errShortRead { // Handle underlying reader errors first + return 0, fmt.Errorf("failed to read from underlying reader: %s", err) + } else if n == 0 && r.compressionLeft == 0 { + return got, io.EOF + } + src = src[:r.compressionLeft+n] + + // C code + cSrcSize := C.size_t(len(src)) + cDstSize := C.size_t(len(r.decompressionBuffer)) + retCode := int(C.ZBUFF_decompressContinue( + r.ctx, + unsafe.Pointer(&r.decompressionBuffer[0]), + &cDstSize, + unsafe.Pointer(&src[0]), + &cSrcSize)) + + // Keep src here eventhough, we reuse later, the code might be deleted at some point + runtime.KeepAlive(src) + if err = getError(retCode); err != nil { + return 0, fmt.Errorf("failed to decompress: %s", err) + } + + // Put everything in buffer + if int(cSrcSize) < len(src) { + left := src[int(cSrcSize):] + copy(r.compressionBuffer, left) + } + r.compressionLeft = len(src) - int(cSrcSize) + r.decompSize = int(cDstSize) + r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize]) + got += r.decompOff + + // Resize buffers + nsize := retCode // Hint for next src buffer size + if nsize <= 0 { + // Reset to recommended size + nsize = r.recommendedSrcSize + } + if nsize < r.compressionLeft { + nsize = r.compressionLeft + } + r.compressionBuffer = resize(r.compressionBuffer, nsize) + } + return got, nil +} + +// TryReadFull reads buffer just as ReadFull does +// Here we expect that buffer may end and we do not return ErrUnexpectedEOF as ReadAtLeast does. +// We return errShortRead instead to distinguish short reads and failures. +// We cannot use ReadFull/ReadAtLeast because it masks Reader errors, such as network failures +// and causes panic instead of error. +func TryReadFull(r io.Reader, buf []byte) (n int, err error) { + for n < len(buf) && err == nil { + var nn int + nn, err = r.Read(buf[n:]) + n += nn + } + if n == len(buf) && err == io.EOF { + err = nil // EOF at the end is somewhat expected + } else if err == io.EOF { + err = errShortRead + } + return +} |