package fse

import (
	"errors"
	"fmt"
)

const (
	// tablelogAbsoluteMax is the largest table log readNCount accepts
	// from a block header.
	tablelogAbsoluteMax = 15
)

// Decompress a block of data.
// You can provide a scratch buffer to avoid allocations.
// If nil is provided a temporary one will be allocated.
// It is possible, but by no means guaranteed that corrupt data will
// return an error.
// It is up to the caller to verify integrity of the returned data.
// Use a predefined Scratch to set maximum acceptable output size.
func Decompress(b []byte, s *Scratch) ([]byte, error) {
	s, err := s.prepare(b)
	if err != nil {
		return nil, err
	}
	// Reuse the output buffer of the scratch, discarding previous content.
	s.Out = s.Out[:0]
	if err := s.readNCount(); err != nil {
		return nil, err
	}
	if err := s.buildDtable(); err != nil {
		return nil, err
	}
	if err := s.decompress(); err != nil {
		return nil, err
	}

	return s.Out, nil
}

// readNCount will read the symbol distribution so decoding tables can be constructed.
//
// On success s.actualTableLog, s.norm and s.symbolLen are populated and the
// byte reader is advanced past the header. An error is returned when the
// header is too short or fails one of the consistency checks below.
func (s *Scratch) readNCount() error {
	var (
		charnum   uint16
		previous0 bool
		b         = &s.br
	)
	iend := b.remain()
	if iend < 4 {
		return errors.New("input too small")
	}
	bitStream := b.Uint32()
	nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog
	if nbBits > tablelogAbsoluteMax {
		return errors.New("tableLog too large")
	}
	bitStream >>= 4
	bitCount := uint(4)

	s.actualTableLog = uint8(nbBits)
	// remaining starts at tableSize+1 and must be exactly 1 when all counts are read.
	remaining := int32((1 << nbBits) + 1)
	threshold := int32(1 << nbBits)
	gotTotal := int32(0)
	nbBits++

	for remaining > 1 {
		if previous0 {
			// The previous count was zero; a run of further zero-count
			// symbols follows, encoded in 2-bit groups where 3 means "continue".
			n0 := charnum
			for (bitStream & 0xFFFF) == 0xFFFF {
				// Sixteen set bits are eight full groups: 8*3 symbols.
				n0 += 24
				if b.off < iend-5 {
					b.advance(2)
					bitStream = b.Uint32() >> bitCount
				} else {
					bitStream >>= 16
					bitCount += 16
				}
			}
			for (bitStream & 3) == 3 {
				n0 += 3
				bitStream >>= 2
				bitCount += 2
			}
			n0 += uint16(bitStream & 3)
			bitCount += 2
			if n0 > maxSymbolValue {
				return errors.New("maxSymbolValue too small")
			}
			for charnum < n0 {
				s.norm[charnum&0xff] = 0
				charnum++
			}

			if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
				// Consume whole bytes and reload the bit window.
				b.advance(bitCount >> 3)
				bitCount &= 7
				bitStream = b.Uint32() >> bitCount
			} else {
				bitStream >>= 2
			}
		}

		// Values below maxCount are stored using one bit less than nbBits.
		maxCount := (2*(threshold) - 1) - (remaining)
		var count int32

		if (int32(bitStream) & (threshold - 1)) < maxCount {
			count = int32(bitStream) & (threshold - 1)
			bitCount += nbBits - 1
		} else {
			count = int32(bitStream) & (2*threshold - 1)
			if count >= threshold {
				count -= maxCount
			}
			bitCount += nbBits
		}

		count-- // extra accuracy
		if count < 0 {
			// -1 means +1
			remaining += count
			gotTotal -= count
		} else {
			remaining -= count
			gotTotal += count
		}
		s.norm[charnum&0xff] = int16(count)
		charnum++
		previous0 = count == 0
		// Fewer bits are needed per count as the remaining budget shrinks.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}
		if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
			b.advance(bitCount >> 3)
			bitCount &= 7
		} else {
			// Close to the end of input: pin the reader to the last 4 bytes
			// and compensate through bitCount.
			bitCount -= (uint)(8 * (len(b.b) - 4 - b.off))
			b.off = len(b.b) - 4
		}
		bitStream = b.Uint32() >> (bitCount & 31)
	}
	s.symbolLen = charnum

	if s.symbolLen <= 1 {
		return fmt.Errorf("symbolLen (%d) too small", s.symbolLen)
	}
	if s.symbolLen > maxSymbolValue+1 {
		return fmt.Errorf("symbolLen (%d) too big", s.symbolLen)
	}
	if remaining != 1 {
		return fmt.Errorf("corruption detected (remaining %d != 1)", remaining)
	}
	if bitCount > 32 {
		return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount)
	}
	// The counts (with -1 counted as one cell) must add up to the table size.
	wantTotal := int32(1) << s.actualTableLog
	if gotTotal != wantTotal {
		return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, wantTotal)
	}
	// Skip the header, rounding the consumed bit count up to whole bytes.
	b.advance((bitCount + 7) >> 3)
	return nil
}

// decSymbol contains information about a state entry,
// Including the state offset base, the output symbol and
// the number of bits to read for the low part of the destination state.
type decSymbol struct {
	newState uint16
	symbol   uint8
	nbBits   uint8
}

// allocDtable will allocate decoding tables if they are not big enough.
// The decoding table is sized to 1<<actualTableLog entries; the symbol and
// state tables borrowed from s.ct are always sized to 256 entries.
func (s *Scratch) allocDtable() {
	tableSize := 1 << s.actualTableLog
	if cap(s.decTable) < tableSize {
		s.decTable = make([]decSymbol, tableSize)
	}
	s.decTable = s.decTable[:tableSize]

	if cap(s.ct.tableSymbol) < 256 {
		s.ct.tableSymbol = make([]byte, 256)
	}
	s.ct.tableSymbol = s.ct.tableSymbol[:256]

	if cap(s.ct.stateTable) < 256 {
		s.ct.stateTable = make([]uint16, 256)
	}
	s.ct.stateTable = s.ct.stateTable[:256]
}

// buildDtable will build the decoding table.
func (s *Scratch) buildDtable() error { tableSize := uint32(1 << s.actualTableLog) highThreshold := tableSize - 1 s.allocDtable() symbolNext := s.ct.stateTable[:256] // Init, lay down lowprob symbols s.zeroBits = false { largeLimit := int16(1 << (s.actualTableLog - 1)) for i, v := range s.norm[:s.symbolLen] { if v == -1 { s.decTable[highThreshold].symbol = uint8(i) highThreshold-- symbolNext[i] = 1 } else { if v >= largeLimit { s.zeroBits = true } symbolNext[i] = uint16(v) } } } // Spread symbols { tableMask := tableSize - 1 step := tableStep(tableSize) position := uint32(0) for ss, v := range s.norm[:s.symbolLen] { for i := 0; i < int(v); i++ { s.decTable[position].symbol = uint8(ss) position = (position + step) & tableMask for position > highThreshold { // lowprob area position = (position + step) & tableMask } } } if position != 0 { // position must reach all cells once, otherwise normalizedCounter is incorrect return errors.New("corrupted input (position != 0)") } } // Build Decoding table { tableSize := uint16(1 << s.actualTableLog) for u, v := range s.decTable { symbol := v.symbol nextState := symbolNext[symbol] symbolNext[symbol] = nextState + 1 nBits := s.actualTableLog - byte(highBits(uint32(nextState))) s.decTable[u].nbBits = nBits newState := (nextState << nBits) - tableSize if newState > tableSize { return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) } if newState == uint16(u) && nBits == 0 { // Seems weird that this is possible with nbits > 0. return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) } s.decTable[u].newState = newState } } return nil } // decompress will decompress the bitstream. // If the buffer is over-read an error is returned. func (s *Scratch) decompress() error { br := &s.bits br.init(s.br.unread()) var s1, s2 decoder // Initialize and decode first state and symbol. 
s1.init(br, s.decTable, s.actualTableLog) s2.init(br, s.decTable, s.actualTableLog) // Use temp table to avoid bound checks/append penalty. var tmp = s.ct.tableSymbol[:256] var off uint8 // Main part if !s.zeroBits { for br.off >= 8 { br.fillFast() tmp[off+0] = s1.nextFast() tmp[off+1] = s2.nextFast() br.fillFast() tmp[off+2] = s1.nextFast() tmp[off+3] = s2.nextFast() off += 4 if off == 0 { s.Out = append(s.Out, tmp...) } } } else { for br.off >= 8 { br.fillFast() tmp[off+0] = s1.next() tmp[off+1] = s2.next() br.fillFast() tmp[off+2] = s1.next() tmp[off+3] = s2.next() off += 4 if off == 0 { s.Out = append(s.Out, tmp...) off = 0 if len(s.Out) >= s.DecompressLimit { return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) } } } } s.Out = append(s.Out, tmp[:off]...) // Final bits, a bit more expensive check for { if s1.finished() { s.Out = append(s.Out, s1.final(), s2.final()) break } br.fill() s.Out = append(s.Out, s1.next()) if s2.finished() { s.Out = append(s.Out, s2.final(), s1.final()) break } s.Out = append(s.Out, s2.next()) if len(s.Out) >= s.DecompressLimit { return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) } } return br.close() } // decoder keeps track of the current state and updates it from the bitstream. type decoder struct { state uint16 br *bitReader dt []decSymbol } // init will initialize the decoder and read the first state from the stream. func (d *decoder) init(in *bitReader, dt []decSymbol, tableLog uint8) { d.dt = dt d.br = in d.state = uint16(in.getBits(tableLog)) } // next returns the next symbol and sets the next state. // At least tablelog bits must be available in the bit reader. func (d *decoder) next() uint8 { n := &d.dt[d.state] lowBits := d.br.getBits(n.nbBits) d.state = n.newState + lowBits return n.symbol } // finished returns true if all bits have been read from the bitstream // and the next state would require reading bits from the input. 
func (d *decoder) finished() bool {
	// The current state is only inspected once the bit reader reports it is done.
	if !d.br.finished() {
		return false
	}
	return d.dt[d.state].nbBits > 0
}

// final returns the symbol of the current state; the state is left untouched.
func (d *decoder) final() uint8 {
	entry := d.dt[d.state]
	return entry.symbol
}

// nextFast returns the symbol of the current state and moves on to the next state.
// This can only be used if no symbols are 0 bits.
// At least tablelog bits must be available in the bit reader.
func (d *decoder) nextFast() uint8 {
	entry := d.dt[d.state]
	d.state = entry.newState + d.br.getBitsFast(entry.nbBits)
	return entry.symbol
}