summaryrefslogtreecommitdiff
path: root/vendor/github.com/DataDog/zstd/zstd.go
blob: b6af4eb19df67466f73cc354f0b8a7559431cdfc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY
#include "zstd.h"
#include "stdint.h"  // for uintptr_t

// The following *_wrapper functions are used to remove superfluous
// memory allocations when calling the wrapped functions from Go code.
// See https://github.com/golang/go/issues/24450 for details.

// ZSTD_compress_wrapper forwards to ZSTD_compress. dst and src are received
// as integer-typed addresses (uintptr_t) and cast back to pointers here; the
// remaining arguments and the return value are passed through unchanged.
static size_t ZSTD_compress_wrapper(uintptr_t dst, size_t maxDstSize, const uintptr_t src, size_t srcSize, int compressionLevel) {
	return ZSTD_compress((void*)dst, maxDstSize, (const void*)src, srcSize, compressionLevel);
}

// ZSTD_decompress_wrapper forwards to ZSTD_decompress. dst and src are
// received as integer-typed addresses (uintptr_t) and cast back to pointers
// here; the sizes and the return value are passed through unchanged.
static size_t ZSTD_decompress_wrapper(uintptr_t dst, size_t maxDstSize, uintptr_t src, size_t srcSize) {
	return ZSTD_decompress((void*)dst, maxDstSize, (const void *)src, srcSize);
}

*/
import "C"
import (
	"bytes"
	"errors"
	"io/ioutil"
	"runtime"
	"unsafe"
)

// Named compression levels: the best and standard values used by the zstd cli.
const (
	// BestSpeed is the lowest level named here (presumably the fastest
	// setting, going by its name).
	BestSpeed          = 1
	// BestCompression is the highest level named here (presumably the
	// strongest setting, going by its name).
	BestCompression    = 20
	// DefaultCompression is the level Compress passes to CompressLevel.
	DefaultCompression = 5
)

var (
	// ErrEmptySlice is returned by Decompress when the input slice is empty,
	// i.e. there are no bytes to work on. (It may also be used by code
	// outside this file.)
	ErrEmptySlice = errors.New("Bytes slice is empty")
)

// CompressBound reports the worst-case compressed size for an input of
// srcSize bytes. Use it to preallocate a destination buffer or to pick a
// sufficiently large one from a pool.
// It is a Go port of the ZSTD_COMPRESSBOUND macro in zstd.h.
func CompressBound(srcSize int) int {
	// Inputs below this size get an additional safety margin.
	const smallInputLimit = 128 << 10 // 128 kB

	bound := srcSize + (srcSize >> 8)
	if srcSize < smallInputLimit {
		bound += (smallInputLimit - srcSize) >> 11
	}
	return bound
}

// cCompressBound asks the C library (ZSTD_compressBound) for the bound of an
// input of srcSize bytes, so the pure-Go CompressBound can be checked against
// the C implementation.
func cCompressBound(srcSize int) int {
	bound := C.ZSTD_compressBound(C.size_t(srcSize))
	return int(bound)
}

// Compress src into dst.  If you have a buffer to use, you can pass it to
// prevent allocation.  If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
func Compress(dst, src []byte) ([]byte, error) {
	return CompressLevel(dst, src, DefaultCompression)
}

// CompressLevel behaves like Compress but lets the caller choose the
// compression level. dst is reused as the output buffer when its capacity is
// at least CompressBound(len(src)); otherwise a new buffer of that size is
// allocated. If getError reports an error for the C return value, that error
// is returned together with a nil slice.
func CompressLevel(dst, src []byte, level int) ([]byte, error) {
	// Size the output for the worst case reported by CompressBound.
	need := CompressBound(len(src))
	if cap(dst) < need {
		dst = make([]byte, need)
	} else {
		dst = dst[:need] // the caller's buffer is big enough, reuse it
	}

	// An empty src has no first element to take the address of, so its
	// address is left at zero.
	var srcAddr C.uintptr_t
	if len(src) != 0 {
		srcAddr = C.uintptr_t(uintptr(unsafe.Pointer(&src[0])))
	}

	rc := C.ZSTD_compress_wrapper(
		C.uintptr_t(uintptr(unsafe.Pointer(&dst[0]))),
		C.size_t(len(dst)),
		srcAddr,
		C.size_t(len(src)),
		C.int(level))

	// src was handed to C as a plain integer, so keep it reachable until the
	// call has returned.
	runtime.KeepAlive(src)

	// The return value is either the number of bytes written or an error code.
	n := int(rc)
	if err := getError(n); err != nil {
		return nil, err
	}
	return dst[:n], nil
}

// Decompress src into dst.  If you have a buffer to use, you can pass it to
// prevent allocation.  If it is too small, or if nil is passed, a new buffer
// will be allocated and returned.
//
// An empty src yields an empty slice and ErrEmptySlice.
//
// When dst is empty, the output buffer is sized from the decompressed size
// declared in the frame header. That value comes from the (possibly
// untrusted) input, so it is only honoured up to the larger of 1 MiB and 50x
// len(src); anything bigger is decoded through the streaming reader, whose
// memory use follows the data actually produced. When no size is declared the
// buffer starts at 3x len(src).
//
// A buffer that turns out to be too small is doubled and the single-shot call
// retried, up to three attempts in total, before falling back to the
// streaming reader.
func Decompress(dst, src []byte) ([]byte, error) {
	if len(src) == 0 {
		return []byte{}, ErrEmptySlice
	}

	const (
		maxAttempts   = 3       // single-shot tries before falling back to streaming
		sizeHintFloor = 1 << 20 // declared sizes up to 1 MiB are always honoured
		sizeHintRatio = 50      // ...as are declared sizes up to 50x the input size
		maxInt        = uint64(^uint(0) >> 1)
	)

	// stream decodes src through the streaming reader and collects the output.
	stream := func() ([]byte, error) {
		r := NewReader(bytes.NewReader(src))
		defer r.Close()
		return ioutil.ReadAll(r)
	}

	// decompress makes one single-shot attempt to decode in into out and
	// returns the written prefix of out.
	decompress := func(out, in []byte) ([]byte, error) {
		cWritten := C.ZSTD_decompress_wrapper(
			C.uintptr_t(uintptr(unsafe.Pointer(&out[0]))),
			C.size_t(len(out)),
			C.uintptr_t(uintptr(unsafe.Pointer(&in[0]))),
			C.size_t(len(in)))

		// in was handed to C as a plain integer, so keep it reachable until
		// the call has returned.
		runtime.KeepAlive(in)
		written := int(cWritten)
		// The return value is either the number of bytes written or an error code.
		if err := getError(written); err != nil {
			return nil, err
		}
		return out[:written], nil
	}

	if len(dst) == 0 {
		// Ask zstd for the size declared in the frame header. The result is a
		// size, not a return code: 0 means "unknown, empty or error", so it
		// must not be passed through getError.
		declared := uint64(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))

		// Largest declared size we are willing to allocate up front.
		limit := uint64(len(src)) * sizeHintRatio
		if limit < sizeHintFloor {
			limit = sizeHintFloor
		}
		if limit > maxInt {
			limit = maxInt // must stay representable as an int for make
		}

		switch {
		case declared == 0:
			dst = make([]byte, len(src)*3) // starting guess
		case declared <= limit:
			dst = make([]byte, int(declared))
		default:
			// Implausibly large claim: do not allocate on the header's word.
			return stream()
		}
	}

	for attempt := 1; ; attempt++ {
		result, err := decompress(dst, src)
		if !IsDstSizeTooSmallError(err) {
			return result, err
		}
		if attempt == maxAttempts {
			// Out of retries; do not allocate a buffer that would go unused.
			break
		}
		dst = make([]byte, len(dst)*2) // Grow buffer by 2
	}

	// We failed getting a dst buffer of correct size, use stream API
	return stream()
}