// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"encoding/binary"
	"errors"
	"hash/crc32"
	"io"

	"github.com/klauspost/compress/huff0"
	"github.com/klauspost/compress/snappy"
)

// Snappy block element tags. The tag of an element is stored in the two
// lowest bits of its first byte.
const (
	snappyTagLiteral = 0x00
	snappyTagCopy1   = 0x01
	snappyTagCopy2   = 0x02
	snappyTagCopy4   = 0x03
)

const (
	// snappyChecksumSize is the size of the checksum that precedes the payload
	// of compressed and uncompressed chunks in the Snappy framing format.
	snappyChecksumSize = 4
	// snappyMagicBody is the payload of the stream identifier chunk.
	snappyMagicBody = "sNaPpY"

	// snappyMaxBlockSize is the maximum uncompressed size of a single chunk.
	// https://github.com/google/snappy/blob/master/framing_format.txt
	// says that "the uncompressed data in a chunk must be no longer than 65536
	// bytes". It is also used as the window size of the produced zstd frame.
	snappyMaxBlockSize = 65536

	// snappyMaxEncodedLenOfMaxBlockSize equals the Snappy MaxEncodedLen of
	// snappyMaxBlockSize. It is hard coded as a constant and, together with
	// snappyChecksumSize, sizes the converter's chunk read buffer.
	snappyMaxEncodedLenOfMaxBlockSize = 76490
)

// Chunk types of the Snappy framing format (framing_format.txt, section 4).
const (
	chunkTypeCompressedData   = 0x00
	chunkTypeUncompressedData = 0x01
	chunkTypePadding          = 0xfe
	chunkTypeStreamIdentifier = 0xff
)

var (
	// ErrSnappyCorrupt reports that the input is invalid.
	ErrSnappyCorrupt = errors.New("snappy: corrupt input")
	// ErrSnappyTooLarge reports that the uncompressed length is too large.
	ErrSnappyTooLarge = errors.New("snappy: decoded block is too large")
	// ErrSnappyUnsupported reports that the input isn't supported.
	ErrSnappyUnsupported = errors.New("snappy: unsupported input")

	// errUnsupportedLiteralLength reports a Snappy literal whose length cannot
	// be represented.
	errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)

// SnappyConverter can read Snappy-compressed streams and convert them to zstd.
// Conversion is done by converting the stream directly from Snappy without intermediate // full decoding. // Therefore the compression ratio is much less than what can be done by a full decompression // and compression, and a faulty Snappy stream may lead to a faulty Zstandard stream without // any errors being generated. // No CRC value is being generated and not all CRC values of the Snappy stream are checked. // However, it provides really fast recompression of Snappy streams. // The converter can be reused to avoid allocations, even after errors. type SnappyConverter struct { r io.Reader err error buf []byte block *blockEnc } // Convert the Snappy stream supplied in 'in' and write the zStandard stream to 'w'. // If any error is detected on the Snappy stream it is returned. // The number of bytes written is returned. func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) { initPredefined() r.err = nil r.r = in if r.block == nil { r.block = &blockEnc{} r.block.init() } r.block.initNewEncode() if len(r.buf) != snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize { r.buf = make([]byte, snappyMaxEncodedLenOfMaxBlockSize+snappyChecksumSize) } r.block.litEnc.Reuse = huff0.ReusePolicyNone var written int64 var readHeader bool { var header []byte var n int header, r.err = frameHeader{WindowSize: snappyMaxBlockSize}.appendTo(r.buf[:0]) n, r.err = w.Write(header) if r.err != nil { return written, r.err } written += int64(n) } for { if !r.readFull(r.buf[:4], true) { // Add empty last block r.block.reset(nil) r.block.last = true err := r.block.encodeLits(r.block.literals, false) if err != nil { return written, err } n, err := w.Write(r.block.output) if err != nil { return written, err } written += int64(n) return written, r.err } chunkType := r.buf[0] if !readHeader { if chunkType != chunkTypeStreamIdentifier { println("chunkType != chunkTypeStreamIdentifier", chunkType) r.err = ErrSnappyCorrupt return written, r.err } readHeader = true } chunkLen := 
int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16 if chunkLen > len(r.buf) { println("chunkLen > len(r.buf)", chunkType) r.err = ErrSnappyUnsupported return written, r.err } // The chunk types are specified at // https://github.com/google/snappy/blob/master/framing_format.txt switch chunkType { case chunkTypeCompressedData: // Section 4.2. Compressed data (chunk type 0x00). if chunkLen < snappyChecksumSize { println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize) r.err = ErrSnappyCorrupt return written, r.err } buf := r.buf[:chunkLen] if !r.readFull(buf, false) { return written, r.err } //checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 buf = buf[snappyChecksumSize:] n, hdr, err := snappyDecodedLen(buf) if err != nil { r.err = err return written, r.err } buf = buf[hdr:] if n > snappyMaxBlockSize { println("n > snappyMaxBlockSize", n, snappyMaxBlockSize) r.err = ErrSnappyCorrupt return written, r.err } r.block.reset(nil) r.block.pushOffsets() if err := decodeSnappy(r.block, buf); err != nil { r.err = err return written, r.err } if r.block.size+r.block.extraLits != n { printf("invalid size, want %d, got %d\n", n, r.block.size+r.block.extraLits) r.err = ErrSnappyCorrupt return written, r.err } err = r.block.encode(nil, false, false) switch err { case errIncompressible: r.block.popOffsets() r.block.reset(nil) r.block.literals, err = snappy.Decode(r.block.literals[:n], r.buf[snappyChecksumSize:chunkLen]) if err != nil { println("snappy.Decode:", err) return written, err } err = r.block.encodeLits(r.block.literals, false) if err != nil { return written, err } case nil: default: return written, err } n, r.err = w.Write(r.block.output) if r.err != nil { return written, err } written += int64(n) continue case chunkTypeUncompressedData: if debug { println("Uncompressed, chunklen", chunkLen) } // Section 4.3. Uncompressed data (chunk type 0x01). 
if chunkLen < snappyChecksumSize { println("chunkLen < snappyChecksumSize", chunkLen, snappyChecksumSize) r.err = ErrSnappyCorrupt return written, r.err } r.block.reset(nil) buf := r.buf[:snappyChecksumSize] if !r.readFull(buf, false) { return written, r.err } checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24 // Read directly into r.decoded instead of via r.buf. n := chunkLen - snappyChecksumSize if n > snappyMaxBlockSize { println("n > snappyMaxBlockSize", n, snappyMaxBlockSize) r.err = ErrSnappyCorrupt return written, r.err } r.block.literals = r.block.literals[:n] if !r.readFull(r.block.literals, false) { return written, r.err } if snappyCRC(r.block.literals) != checksum { println("literals crc mismatch") r.err = ErrSnappyCorrupt return written, r.err } err := r.block.encodeLits(r.block.literals, false) if err != nil { return written, err } n, r.err = w.Write(r.block.output) if r.err != nil { return written, err } written += int64(n) continue case chunkTypeStreamIdentifier: if debug { println("stream id", chunkLen, len(snappyMagicBody)) } // Section 4.1. Stream identifier (chunk type 0xff). if chunkLen != len(snappyMagicBody) { println("chunkLen != len(snappyMagicBody)", chunkLen, len(snappyMagicBody)) r.err = ErrSnappyCorrupt return written, r.err } if !r.readFull(r.buf[:len(snappyMagicBody)], false) { return written, r.err } for i := 0; i < len(snappyMagicBody); i++ { if r.buf[i] != snappyMagicBody[i] { println("r.buf[i] != snappyMagicBody[i]", r.buf[i], snappyMagicBody[i], i) r.err = ErrSnappyCorrupt return written, r.err } } continue } if chunkType <= 0x7f { // Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f). println("chunkType <= 0x7f") r.err = ErrSnappyUnsupported return written, r.err } // Section 4.4 Padding (chunk type 0xfe). // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). 
if !r.readFull(r.buf[:chunkLen], false) { return written, r.err } } } // decodeSnappy writes the decoding of src to dst. It assumes that the varint-encoded // length of the decompressed bytes has already been read. func decodeSnappy(blk *blockEnc, src []byte) error { //decodeRef(make([]byte, snappyMaxBlockSize), src) var s, length int lits := blk.extraLits var offset uint32 for s < len(src) { switch src[s] & 0x03 { case snappyTagLiteral: x := uint32(src[s] >> 2) switch { case x < 60: s++ case x == 60: s += 2 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, src) return ErrSnappyCorrupt } x = uint32(src[s-1]) case x == 61: s += 3 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, src) return ErrSnappyCorrupt } x = uint32(src[s-2]) | uint32(src[s-1])<<8 case x == 62: s += 4 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, src) return ErrSnappyCorrupt } x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 case x == 63: s += 5 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, src) return ErrSnappyCorrupt } x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 } if x > snappyMaxBlockSize { println("x > snappyMaxBlockSize", x, snappyMaxBlockSize) return ErrSnappyCorrupt } length = int(x) + 1 if length <= 0 { println("length <= 0 ", length) return errUnsupportedLiteralLength } //if length > snappyMaxBlockSize-d || uint32(length) > len(src)-s { // return ErrSnappyCorrupt //} blk.literals = append(blk.literals, src[s:s+length]...) 
//println(length, "litLen") lits += length s += length continue case snappyTagCopy1: s += 2 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, len(src)) return ErrSnappyCorrupt } length = 4 + int(src[s-2])>>2&0x7 offset = uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]) case snappyTagCopy2: s += 3 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, len(src)) return ErrSnappyCorrupt } length = 1 + int(src[s-3])>>2 offset = uint32(src[s-2]) | uint32(src[s-1])<<8 case snappyTagCopy4: s += 5 if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line. println("uint(s) > uint(len(src)", s, len(src)) return ErrSnappyCorrupt } length = 1 + int(src[s-5])>>2 offset = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 } if offset <= 0 || blk.size+lits < int(offset) /*|| length > len(blk)-d */ { println("offset <= 0 || blk.size+lits < int(offset)", offset, blk.size+lits, int(offset), blk.size, lits) return ErrSnappyCorrupt } // Check if offset is one of the recent offsets. // Adjusts the output offset accordingly. // Gives a tiny bit of compression, typically around 1%. 
if false { offset = blk.matchOffset(offset, uint32(lits)) } else { offset += 3 } blk.sequences = append(blk.sequences, seq{ litLen: uint32(lits), offset: offset, matchLen: uint32(length) - zstdMinMatch, }) blk.size += length + lits lits = 0 } blk.extraLits = lits return nil } func (r *SnappyConverter) readFull(p []byte, allowEOF bool) (ok bool) { if _, r.err = io.ReadFull(r.r, p); r.err != nil { if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) { r.err = ErrSnappyCorrupt } return false } return true } var crcTable = crc32.MakeTable(crc32.Castagnoli) // crc implements the checksum specified in section 3 of // https://github.com/google/snappy/blob/master/framing_format.txt func snappyCRC(b []byte) uint32 { c := crc32.Update(0, crcTable, b) return c>>15 | c<<17 + 0xa282ead8 } // snappyDecodedLen returns the length of the decoded block and the number of bytes // that the length header occupied. func snappyDecodedLen(src []byte) (blockLen, headerLen int, err error) { v, n := binary.Uvarint(src) if n <= 0 || v > 0xffffffff { return 0, 0, ErrSnappyCorrupt } const wordSize = 32 << (^uint(0) >> 32 & 1) if wordSize == 32 && v > 0x7fffffff { return 0, 0, ErrSnappyTooLarge } return int(v), n, nil }