diff options
Diffstat (limited to 'vendor/github.com/klauspost/compress/huff0')
-rw-r--r-- | vendor/github.com/klauspost/compress/huff0/decompress.go | 119 |
1 files changed, 84 insertions, 35 deletions
diff --git a/vendor/github.com/klauspost/compress/huff0/decompress.go b/vendor/github.com/klauspost/compress/huff0/decompress.go index 7e68a4eb4..97ae66a4a 100644 --- a/vendor/github.com/klauspost/compress/huff0/decompress.go +++ b/vendor/github.com/klauspost/compress/huff0/decompress.go @@ -15,8 +15,7 @@ type dTable struct { // single-symbols decoding type dEntrySingle struct { - byte uint8 - nBits uint8 + entry uint16 } // double-symbols decoding @@ -76,14 +75,15 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { } // collect weight stats - var rankStats [tableLogMax + 1]uint32 + var rankStats [16]uint32 weightTotal := uint32(0) for _, v := range s.huffWeight[:s.symbolLen] { if v > tableLogMax { return s, nil, errors.New("corrupt input: weight too large") } - rankStats[v]++ - weightTotal += (1 << (v & 15)) >> 1 + v2 := v & 15 + rankStats[v2]++ + weightTotal += (1 << v2) >> 1 } if weightTotal == 0 { return s, nil, errors.New("corrupt input: weights zero") @@ -134,15 +134,17 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { if len(s.dt.single) != tSize { s.dt.single = make([]dEntrySingle, tSize) } - for n, w := range s.huffWeight[:s.symbolLen] { + if w == 0 { + continue + } length := (uint32(1) << w) >> 1 d := dEntrySingle{ - byte: uint8(n), - nBits: s.actualTableLog + 1 - w, + entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8), } - for u := rankStats[w]; u < rankStats[w]+length; u++ { - s.dt.single[u] = d + single := s.dt.single[rankStats[w] : rankStats[w]+length] + for i := range single { + single[i] = d } rankStats[w] += length } @@ -167,12 +169,12 @@ func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) { decode := func() byte { val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ v := s.dt.single[val] - br.bitsRead += v.nBits - return v.byte + br.bitsRead += uint8(v.entry) + return uint8(v.entry >> 8) } hasDec := func(v dEntrySingle) byte { - br.bitsRead += v.nBits - return v.byte + br.bitsRead += uint8(v.entry) + return uint8(v.entry >> 8) } // Avoid bounds check by always having full sized table. @@ -269,8 +271,8 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { decode := func(br *bitReader) byte { val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ v := single[val&tlMask] - br.bitsRead += v.nBits - return v.byte + br.bitsRead += uint8(v.entry) + return uint8(v.entry >> 8) } // Use temp table to avoid bound checks/append penalty. @@ -283,20 +285,67 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { bigloop: for { for i := range br { - if br[i].off < 4 { + br := &br[i] + if br.off < 4 { break bigloop } - br[i].fillFast() + br.fillFast() + } + + { + const stream = 0 + val := br[stream].peekBitsFast(s.actualTableLog) + v := single[val&tlMask] + br[stream].bitsRead += uint8(v.entry) + + val2 := br[stream].peekBitsFast(s.actualTableLog) + v2 := single[val2&tlMask] + tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) + tmp[off+bufoff*stream] = uint8(v.entry >> 8) + br[stream].bitsRead += uint8(v2.entry) + } + + { + const stream = 1 + val := br[stream].peekBitsFast(s.actualTableLog) + v := single[val&tlMask] + br[stream].bitsRead += uint8(v.entry) + + val2 := br[stream].peekBitsFast(s.actualTableLog) + v2 := single[val2&tlMask] + tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) + tmp[off+bufoff*stream] = uint8(v.entry >> 8) + br[stream].bitsRead += uint8(v2.entry) + } + + { + const stream = 2 + val := br[stream].peekBitsFast(s.actualTableLog) + v := single[val&tlMask] + br[stream].bitsRead += uint8(v.entry) + + val2 := br[stream].peekBitsFast(s.actualTableLog) + v2 := single[val2&tlMask] + tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) + tmp[off+bufoff*stream] = uint8(v.entry >> 8) + br[stream].bitsRead += uint8(v2.entry) } - tmp[off] = decode(&br[0]) - tmp[off+bufoff] = decode(&br[1]) - tmp[off+bufoff*2] = decode(&br[2]) - tmp[off+bufoff*3] = decode(&br[3]) - tmp[off+1] = decode(&br[0]) - tmp[off+1+bufoff] = decode(&br[1]) - tmp[off+1+bufoff*2] = decode(&br[2]) - tmp[off+1+bufoff*3] = decode(&br[3]) + + { + const stream = 3 + val := br[stream].peekBitsFast(s.actualTableLog) + v := single[val&tlMask] + br[stream].bitsRead += uint8(v.entry) + + val2 := br[stream].peekBitsFast(s.actualTableLog) + v2 := single[val2&tlMask] + tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8) + tmp[off+bufoff*stream] = uint8(v.entry >> 8) + br[stream].bitsRead += uint8(v2.entry) + } + off += 2 + if off == bufoff { if bufoff > dstEvery { return nil, errors.New("corruption detected: stream overrun 1") @@ -367,7 +416,7 @@ func (s *Scratch) matches(ct cTable, w io.Writer) { broken++ if enc.nBits == 0 { for _, dec := range dt { - if dec.byte == byte(sym) { + if uint8(dec.entry>>8) == byte(sym) { fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym) errs++ break @@ -383,12 +432,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) { top := enc.val << ub // decoder looks at top bits. dec := dt[top] - if dec.nBits != enc.nBits { - fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits) + if uint8(dec.entry) != enc.nBits { + fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry)) errs++ } - if dec.byte != uint8(sym) { - fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte) + if uint8(dec.entry>>8) != uint8(sym) { + fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8)) errs++ } if errs > 0 { @@ -399,12 +448,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) { for i := uint16(0); i < (1 << ub); i++ { vval := top | i dec := dt[vval] - if dec.nBits != enc.nBits { - fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits) + if uint8(dec.entry) != enc.nBits { + fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry)) errs++ } - if dec.byte != uint8(sym) { - fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte) + if uint8(dec.entry>>8) != uint8(sym) { + fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8)) errs++ } if errs > 20 { |