aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/klauspost/compress/zstd/bytebuf.go
blob: 2ad02070d740a528a1f10f6ca2b2a0eeaebe6a21 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"fmt"
	"io"
	"io/ioutil"
)

// byteBuffer is a consumable source of bytes.
// In this file it is implemented by byteBuf (in-memory) and
// readerWrapper (backed by an io.Reader).
type byteBuffer interface {
	// readByte consumes a single byte and returns it.
	readByte() (byte, error)

	// readSmall consumes n bytes, where n is at most 8.
	// io.ErrUnexpectedEOF is returned if n bytes cannot be supplied.
	readSmall(n int) ([]byte, error)

	// readBig consumes n bytes and is meant for reads larger than 8 bytes.
	// An implementation MAY place the result in dst.
	readBig(n int, dst []byte) ([]byte, error)

	// skipN discards n bytes.
	skipN(n int64) error
}

// byteBuf is an in-memory buffer.
// Its methods take a pointer receiver and re-slice the buffer, so bytes
// that have been read or skipped are dropped from the front.
type byteBuf []byte

// readSmall consumes the first n bytes of the buffer and returns them.
// The returned slice shares its backing array with the buffer; no copy is made.
// If fewer than n bytes remain, the buffer is left untouched and
// io.ErrUnexpectedEOF is returned.
// With debugAsserts enabled, a request for more than 8 bytes panics.
func (b *byteBuf) readSmall(n int) ([]byte, error) {
	if debugAsserts && n > 8 {
		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
	}
	if n > len(*b) {
		return nil, io.ErrUnexpectedEOF
	}
	head, tail := (*b)[:n], (*b)[n:]
	*b = tail
	return head, nil
}

// readBig consumes the first n bytes of the buffer and returns them.
// dst is ignored: the result is a sub-slice of the buffer itself.
// If fewer than n bytes remain, the buffer is left untouched and
// io.ErrUnexpectedEOF is returned.
func (b *byteBuf) readBig(n int, dst []byte) ([]byte, error) {
	src := *b
	if n > len(src) {
		return nil, io.ErrUnexpectedEOF
	}
	*b = src[n:]
	return src[:n], nil
}

// readByte consumes the first byte of the buffer and returns it.
// If the buffer is empty, io.ErrUnexpectedEOF is returned, matching the
// other read methods; an exhausted buffer must not be reported as a
// successfully read zero byte.
func (b *byteBuf) readByte() (byte, error) {
	bb := *b
	if len(bb) < 1 {
		return 0, io.ErrUnexpectedEOF
	}
	r := bb[0]
	*b = bb[1:]
	return r, nil
}

// skipN drops the first n bytes of the buffer.
// A negative n yields a descriptive error. If n exceeds the number of
// remaining bytes, io.ErrUnexpectedEOF is returned. The buffer is left
// untouched in both error cases.
func (b *byteBuf) skipN(n int64) error {
	switch avail := int64(len(*b)); {
	case n < 0:
		return fmt.Errorf("negative skip (%d) requested", n)
	case n > avail:
		return io.ErrUnexpectedEOF
	}
	*b = (*b)[n:]
	return nil
}

// readerWrapper is a wrapper around a reader.
//
// r is the underlying source that every method reads from.
// tmp is scratch space for readSmall and readByte; a slice returned by
// readSmall points into tmp, so it is overwritten by the next readSmall
// or readByte call.
type readerWrapper struct {
	r   io.Reader
	tmp [8]byte
}

// readSmall reads exactly n bytes from the wrapped reader into the
// internal scratch array and returns that part of the array. The result
// is therefore only valid until the scratch array is written again.
// An io.EOF from the reader is reported as io.ErrUnexpectedEOF; any
// other error is returned unchanged (and printed when debugDecoder is
// enabled). With debugAsserts enabled, a request for more than 8 bytes
// panics.
func (r *readerWrapper) readSmall(n int) ([]byte, error) {
	if debugAsserts && n > 8 {
		panic(fmt.Errorf("small read > 8 (%d). use readBig", n))
	}
	buf := r.tmp[:n]
	got, err := io.ReadFull(r.r, buf)
	switch {
	case err == nil:
		return buf, nil
	case err == io.EOF:
		// Nothing at all could be read, but the caller asked for data.
		return nil, io.ErrUnexpectedEOF
	}
	if debugDecoder {
		println("readSmall: got", got, "want", n, "err", err)
	}
	return nil, err
}

// readBig reads n bytes from the wrapped reader.
// dst is used as the destination when its capacity is at least n;
// otherwise a new slice of length n is allocated.
// The bytes actually read are returned together with the error, so on
// failure the result may be shorter than n. An io.EOF for a non-empty
// request is reported as io.ErrUnexpectedEOF.
func (r *readerWrapper) readBig(n int, dst []byte) ([]byte, error) {
	out := dst
	if n > cap(out) {
		out = make([]byte, n)
	}
	out = out[:n]
	read, err := io.ReadFull(r.r, out)
	if n > 0 && err == io.EOF {
		err = io.ErrUnexpectedEOF
	}
	return out[:read], err
}

// readByte reads a single byte from the wrapped reader.
// An io.EOF is reported as io.ErrUnexpectedEOF; any other error is
// returned unchanged.
//
// io.ReadFull is used rather than a single Read call because an
// io.Reader may legally return (0, nil), or deliver its last byte
// together with io.EOF. A bare Read would report a spurious EOF in the
// first case and drop a valid byte in the second.
func (r *readerWrapper) readByte() (byte, error) {
	if _, err := io.ReadFull(r.r, r.tmp[:1]); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, err
	}
	return r.tmp[0], nil
}

// skipN discards n bytes from the wrapped reader by copying them to
// ioutil.Discard.
// If the number of discarded bytes differs from n, io.ErrUnexpectedEOF
// is returned regardless of the underlying error; otherwise the result
// of the copy is returned.
func (r *readerWrapper) skipN(n int64) error {
	skipped, err := io.CopyN(ioutil.Discard, r.r, n)
	if skipped == n {
		return err
	}
	return io.ErrUnexpectedEOF
}