package huff0 import ( "errors" "fmt" "io" "github.com/klauspost/compress/fse" ) type dTable struct { single []dEntrySingle double []dEntryDouble } // single-symbols decoding type dEntrySingle struct { entry uint16 } // double-symbols decoding type dEntryDouble struct { seq uint16 nBits uint8 len uint8 } // ReadTable will read a table from the input. // The size of the input may be larger than the table definition. // Any content remaining after the table definition will be returned. // If no Scratch is provided a new one is allocated. // The returned Scratch can be used for decoding input using this table. func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) { s, err = s.prepare(in) if err != nil { return s, nil, err } if len(in) <= 1 { return s, nil, errors.New("input too small for table") } iSize := in[0] in = in[1:] if iSize >= 128 { // Uncompressed oSize := iSize - 127 iSize = (oSize + 1) / 2 if int(iSize) > len(in) { return s, nil, errors.New("input too small for table") } for n := uint8(0); n < oSize; n += 2 { v := in[n/2] s.huffWeight[n] = v >> 4 s.huffWeight[n+1] = v & 15 } s.symbolLen = uint16(oSize) in = in[iSize:] } else { if len(in) <= int(iSize) { return s, nil, errors.New("input too small for table") } // FSE compressed weights s.fse.DecompressLimit = 255 hw := s.huffWeight[:] s.fse.Out = hw b, err := fse.Decompress(in[:iSize], s.fse) s.fse.Out = nil if err != nil { return s, nil, err } if len(b) > 255 { return s, nil, errors.New("corrupt input: output table too large") } s.symbolLen = uint16(len(b)) in = in[iSize:] } // collect weight stats var rankStats [16]uint32 weightTotal := uint32(0) for _, v := range s.huffWeight[:s.symbolLen] { if v > tableLogMax { return s, nil, errors.New("corrupt input: weight too large") } v2 := v & 15 rankStats[v2]++ weightTotal += (1 << v2) >> 1 } if weightTotal == 0 { return s, nil, errors.New("corrupt input: weights zero") } // get last non-null symbol weight (implied, total must be 2^n) { tableLog := highBit32(weightTotal) + 1 if tableLog > tableLogMax { return s, nil, errors.New("corrupt input: tableLog too big") } s.actualTableLog = uint8(tableLog) // determine last weight { total := uint32(1) << tableLog rest := total - weightTotal verif := uint32(1) << highBit32(rest) lastWeight := highBit32(rest) + 1 if verif != rest { // last value must be a clean power of 2 return s, nil, errors.New("corrupt input: last value not power of two") } s.huffWeight[s.symbolLen] = uint8(lastWeight) s.symbolLen++ rankStats[lastWeight]++ } } if (rankStats[1] < 2) || (rankStats[1]&1 != 0) { // by construction : at least 2 elts of rank 1, must be even return s, nil, errors.New("corrupt input: min elt size, even check failed ") } // TODO: Choose between single/double symbol decoding // Calculate starting value for each rank { var nextRankStart uint32 for n := uint8(1); n < s.actualTableLog+1; n++ { current := nextRankStart nextRankStart += rankStats[n] << (n - 1) rankStats[n] = current } } // fill DTable (always full size) tSize := 1 << tableLogMax if len(s.dt.single) != tSize { s.dt.single = make([]dEntrySingle, tSize) } for n, w := range s.huffWeight[:s.symbolLen] { if w == 0 { continue } length := (uint32(1) << w) >> 1 d := dEntrySingle{ entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8), } single := s.dt.single[rankStats[w] : rankStats[w]+length] for i := range single { single[i] = d } rankStats[w] += length } return s, in, nil } // Decompress1X will decompress a 1X encoded stream. // The length of the supplied input must match the end of a block exactly. // Before this is called, the table must be initialized with ReadTable unless // the encoder re-used the table. // deprecated: Use the stateless Decoder() to get a concurrent version. func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) { if cap(s.Out) < s.MaxDecodedSize { s.Out = make([]byte, s.MaxDecodedSize) } s.Out = s.Out[:0:s.MaxDecodedSize] s.Out, err = s.Decoder().Decompress1X(s.Out, in) return s.Out, err } // Decompress4X will decompress a 4X encoded stream. // Before this is called, the table must be initialized with ReadTable unless // the encoder re-used the table. // The length of the supplied input must match the end of a block exactly. // The destination size of the uncompressed data must be known and provided. // deprecated: Use the stateless Decoder() to get a concurrent version. func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) { if dstSize > s.MaxDecodedSize { return nil, ErrMaxDecodedSizeExceeded } if cap(s.Out) < dstSize { s.Out = make([]byte, s.MaxDecodedSize) } s.Out = s.Out[:0:dstSize] s.Out, err = s.Decoder().Decompress4X(s.Out, in) return s.Out, err } // Decoder will return a stateless decoder that can be used by multiple // decompressors concurrently. // Before this is called, the table must be initialized with ReadTable. // The Decoder is still linked to the scratch buffer so that cannot be reused. // However, it is safe to discard the scratch. func (s *Scratch) Decoder() *Decoder { return &Decoder{ dt: s.dt, actualTableLog: s.actualTableLog, } } // Decoder provides stateless decoding. type Decoder struct { dt dTable actualTableLog uint8 } // Decompress1X will decompress a 1X encoded stream. // The cap of the output buffer will be the maximum decompressed size. // The length of the supplied input must match the end of a block exactly. func (d *Decoder) Decompress1X(dst, src []byte) ([]byte, error) { if len(d.dt.single) == 0 { return nil, errors.New("no table loaded") } var br bitReader err := br.init(src) if err != nil { return dst, err } maxDecodedSize := cap(dst) dst = dst[:0] decode := func() byte { val := br.peekBitsFast(d.actualTableLog) /* note : actualTableLog >= 1 */ v := d.dt.single[val] br.bitsRead += uint8(v.entry) return uint8(v.entry >> 8) } hasDec := func(v dEntrySingle) byte { br.bitsRead += uint8(v.entry) return uint8(v.entry >> 8) } // Avoid bounds check by always having full sized table. const tlSize = 1 << tableLogMax const tlMask = tlSize - 1 dt := d.dt.single[:tlSize] // Use temp table to avoid bound checks/append penalty. var buf [256]byte var off uint8 for br.off >= 8 { br.fillFast() buf[off+0] = hasDec(dt[br.peekBitsFast(d.actualTableLog)&tlMask]) buf[off+1] = hasDec(dt[br.peekBitsFast(d.actualTableLog)&tlMask]) br.fillFast() buf[off+2] = hasDec(dt[br.peekBitsFast(d.actualTableLog)&tlMask]) buf[off+3] = hasDec(dt[br.peekBitsFast(d.actualTableLog)&tlMask]) off += 4 if off == 0 { if len(dst)+256 > maxDecodedSize { br.close() return nil, ErrMaxDecodedSizeExceeded } dst = append(dst, buf[:]...) } } if len(dst)+int(off) > maxDecodedSize { br.close() return nil, ErrMaxDecodedSizeExceeded } dst = append(dst, buf[:off]...) for !br.finished() { br.fill() if len(dst) >= maxDecodedSize { br.close() return nil, ErrMaxDecodedSizeExceeded } dst = append(dst, decode()) } return dst, br.close() } // Decompress4X will decompress a 4X encoded stream. // The length of the supplied input must match the end of a block exactly. // The *capacity* of the dst slice must match the destination size of // the uncompressed data exactly. func (s *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { if len(s.dt.single) == 0 { return nil, errors.New("no table loaded") } if len(src) < 6+(4*1) { return nil, errors.New("input too small") } var br [4]bitReader start := 6 for i := 0; i < 3; i++ { length := int(src[i*2]) | (int(src[i*2+1]) << 8) if start+length >= len(src) { return nil, errors.New("truncated input (or invalid offset)") } err := br[i].init(src[start : start+length]) if err != nil { return nil, err } start += length } err := br[3].init(src[start:]) if err != nil { return nil, err } // destination, offset to match first output dstSize := cap(dst) dst = dst[:dstSize] out := dst dstEvery := (dstSize + 3) / 4 const tlSize = 1 << tableLogMax const tlMask = tlSize - 1 single := s.dt.single[:tlSize] decode := func(br *bitReader) byte { val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */ v := single[val&tlMask] br.bitsRead += uint8(v.entry) return uint8(v.entry >> 8) } // Use temp table to avoid bound checks/append penalty. var buf [256]byte var off uint8 var decoded int // Decode 2 values from each decoder/loop. const bufoff = 256 / 4 bigloop: for { for i := range br { br := &br[i] if br.off < 4 { break bigloop } br.fillFast() } { const stream = 0 val := br[stream].peekBitsFast(s.actualTableLog) v := single[val&tlMask] br[stream].bitsRead += uint8(v.entry) val2 := br[stream].peekBitsFast(s.actualTableLog) v2 := single[val2&tlMask] buf[off+bufoff*stream+1] = uint8(v2.entry >> 8) buf[off+bufoff*stream] = uint8(v.entry >> 8) br[stream].bitsRead += uint8(v2.entry) } { const stream = 1 val := br[stream].peekBitsFast(s.actualTableLog) v := single[val&tlMask] br[stream].bitsRead += uint8(v.entry) val2 := br[stream].peekBitsFast(s.actualTableLog) v2 := single[val2&tlMask] buf[off+bufoff*stream+1] = uint8(v2.entry >> 8) buf[off+bufoff*stream] = uint8(v.entry >> 8) br[stream].bitsRead += uint8(v2.entry) } { const stream = 2 val := br[stream].peekBitsFast(s.actualTableLog) v := single[val&tlMask] br[stream].bitsRead += uint8(v.entry) val2 := br[stream].peekBitsFast(s.actualTableLog) v2 := single[val2&tlMask] buf[off+bufoff*stream+1] = uint8(v2.entry >> 8) buf[off+bufoff*stream] = uint8(v.entry >> 8) br[stream].bitsRead += uint8(v2.entry) } { const stream = 3 val := br[stream].peekBitsFast(s.actualTableLog) v := single[val&tlMask] br[stream].bitsRead += uint8(v.entry) val2 := br[stream].peekBitsFast(s.actualTableLog) v2 := single[val2&tlMask] buf[off+bufoff*stream+1] = uint8(v2.entry >> 8) buf[off+bufoff*stream] = uint8(v.entry >> 8) br[stream].bitsRead += uint8(v2.entry) } off += 2 if off == bufoff { if bufoff > dstEvery { return nil, errors.New("corruption detected: stream overrun 1") } copy(out, buf[:bufoff]) copy(out[dstEvery:], buf[bufoff:bufoff*2]) copy(out[dstEvery*2:], buf[bufoff*2:bufoff*3]) copy(out[dstEvery*3:], buf[bufoff*3:bufoff*4]) off = 0 out = out[bufoff:] decoded += 256 // There must at least be 3 buffers left. if len(out) < dstEvery*3 { return nil, errors.New("corruption detected: stream overrun 2") } } } if off > 0 { ioff := int(off) if len(out) < dstEvery*3+ioff { return nil, errors.New("corruption detected: stream overrun 3") } copy(out, buf[:off]) copy(out[dstEvery:dstEvery+ioff], buf[bufoff:bufoff*2]) copy(out[dstEvery*2:dstEvery*2+ioff], buf[bufoff*2:bufoff*3]) copy(out[dstEvery*3:dstEvery*3+ioff], buf[bufoff*3:bufoff*4]) decoded += int(off) * 4 out = out[off:] } // Decode remaining. for i := range br { offset := dstEvery * i br := &br[i] for !br.finished() { br.fill() if offset >= len(out) { return nil, errors.New("corruption detected: stream overrun 4") } out[offset] = decode(br) offset++ } decoded += offset - dstEvery*i err = br.close() if err != nil { return nil, err } } if dstSize != decoded { return nil, errors.New("corruption detected: short output block") } return dst, nil } // matches will compare a decoding table to a coding table. // Errors are written to the writer. // Nothing will be written if table is ok. func (s *Scratch) matches(ct cTable, w io.Writer) { if s == nil || len(s.dt.single) == 0 { return } dt := s.dt.single[:1<>8) == byte(sym) { fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym) errs++ break } } if errs == 0 { broken-- } continue } // Unused bits in input ub := tablelog - enc.nBits top := enc.val << ub // decoder looks at top bits. dec := dt[top] if uint8(dec.entry) != enc.nBits { fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry)) errs++ } if uint8(dec.entry>>8) != uint8(sym) { fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8)) errs++ } if errs > 0 { fmt.Fprintf(w, "%d errros in base, stopping\n", errs) continue } // Ensure that all combinations are covered. for i := uint16(0); i < (1 << ub); i++ { vval := top | i dec := dt[vval] if uint8(dec.entry) != enc.nBits { fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry)) errs++ } if uint8(dec.entry>>8) != uint8(sym) { fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8)) errs++ } if errs > 20 { fmt.Fprintf(w, "%d errros, stopping\n", errs) break } } if errs == 0 { ok++ broken-- } } if broken > 0 { fmt.Fprintf(w, "%d broken, %d ok\n", broken, ok) } }