1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.ChannelHandlerContext;
20 import io.netty.handler.codec.ByteToMessageDecoder;
21
22 import java.util.List;
23
24 import static io.netty.handler.codec.compression.Snappy.validateChecksum;
25
26
27
28
29
30
31
32
33
34
35
36
37 public class SnappyFrameDecoder extends ByteToMessageDecoder {
38
39 private enum ChunkType {
40 STREAM_IDENTIFIER,
41 COMPRESSED_DATA,
42 UNCOMPRESSED_DATA,
43 RESERVED_UNSKIPPABLE,
44 RESERVED_SKIPPABLE
45 }
46
47 private static final int SNAPPY_IDENTIFIER_LEN = 6;
48
49 private static final int MAX_UNCOMPRESSED_DATA_SIZE = 65536 + 4;
50
51 private static final int MAX_DECOMPRESSED_DATA_SIZE = 65536;
52
53 private static final int MAX_COMPRESSED_CHUNK_SIZE = 16777216 - 1;
54
55 private final Snappy snappy = new Snappy();
56 private final boolean validateChecksums;
57
58 private boolean started;
59 private boolean corrupted;
60 private int numBytesToSkip;
61
62
63
64
65
66
67 public SnappyFrameDecoder() {
68 this(false);
69 }
70
71
72
73
74
75
76
77
78
79
80 public SnappyFrameDecoder(boolean validateChecksums) {
81 this.validateChecksums = validateChecksums;
82 }
83
84 @Override
85 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
86 if (corrupted) {
87 in.skipBytes(in.readableBytes());
88 return;
89 }
90
91 if (numBytesToSkip != 0) {
92
93 int skipBytes = Math.min(numBytesToSkip, in.readableBytes());
94 in.skipBytes(skipBytes);
95 numBytesToSkip -= skipBytes;
96
97
98 return;
99 }
100
101 try {
102 int idx = in.readerIndex();
103 final int inSize = in.readableBytes();
104 if (inSize < 4) {
105
106
107 return;
108 }
109
110 final int chunkTypeVal = in.getUnsignedByte(idx);
111 final ChunkType chunkType = mapChunkType((byte) chunkTypeVal);
112 final int chunkLength = in.getUnsignedMediumLE(idx + 1);
113
114 switch (chunkType) {
115 case STREAM_IDENTIFIER:
116 if (chunkLength != SNAPPY_IDENTIFIER_LEN) {
117 throw new DecompressionException("Unexpected length of stream identifier: " + chunkLength);
118 }
119
120 if (inSize < 4 + SNAPPY_IDENTIFIER_LEN) {
121 break;
122 }
123
124 in.skipBytes(4);
125 int offset = in.readerIndex();
126 in.skipBytes(SNAPPY_IDENTIFIER_LEN);
127
128 checkByte(in.getByte(offset++), (byte) 's');
129 checkByte(in.getByte(offset++), (byte) 'N');
130 checkByte(in.getByte(offset++), (byte) 'a');
131 checkByte(in.getByte(offset++), (byte) 'P');
132 checkByte(in.getByte(offset++), (byte) 'p');
133 checkByte(in.getByte(offset), (byte) 'Y');
134
135 started = true;
136 break;
137 case RESERVED_SKIPPABLE:
138 if (!started) {
139 throw new DecompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
140 }
141
142 in.skipBytes(4);
143
144 int skipBytes = Math.min(chunkLength, in.readableBytes());
145 in.skipBytes(skipBytes);
146 if (skipBytes != chunkLength) {
147
148
149 numBytesToSkip = chunkLength - skipBytes;
150 }
151 break;
152 case RESERVED_UNSKIPPABLE:
153
154
155
156 throw new DecompressionException(
157 "Found reserved unskippable chunk type: 0x" + Integer.toHexString(chunkTypeVal));
158 case UNCOMPRESSED_DATA:
159 if (!started) {
160 throw new DecompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
161 }
162 if (chunkLength > MAX_UNCOMPRESSED_DATA_SIZE) {
163 throw new DecompressionException("Received UNCOMPRESSED_DATA larger than " +
164 MAX_UNCOMPRESSED_DATA_SIZE + " bytes");
165 }
166
167 if (inSize < 4 + chunkLength) {
168 return;
169 }
170
171 in.skipBytes(4);
172 if (validateChecksums) {
173 int checksum = in.readIntLE();
174 validateChecksum(checksum, in, in.readerIndex(), chunkLength - 4);
175 } else {
176 in.skipBytes(4);
177 }
178 out.add(in.readRetainedSlice(chunkLength - 4));
179 break;
180 case COMPRESSED_DATA:
181 if (!started) {
182 throw new DecompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
183 }
184
185 if (chunkLength > MAX_COMPRESSED_CHUNK_SIZE) {
186 throw new DecompressionException("Received COMPRESSED_DATA that contains" +
187 " chunk that exceeds " + MAX_COMPRESSED_CHUNK_SIZE + " bytes");
188 }
189
190 if (inSize < 4 + chunkLength) {
191 return;
192 }
193
194 in.skipBytes(4);
195 int checksum = in.readIntLE();
196
197 int uncompressedSize = snappy.getPreamble(in);
198 if (uncompressedSize > MAX_DECOMPRESSED_DATA_SIZE) {
199 throw new DecompressionException("Received COMPRESSED_DATA that contains" +
200 " uncompressed data that exceeds " + MAX_DECOMPRESSED_DATA_SIZE + " bytes");
201 }
202
203 ByteBuf uncompressed = ctx.alloc().buffer(uncompressedSize, MAX_DECOMPRESSED_DATA_SIZE);
204 try {
205 if (validateChecksums) {
206 int oldWriterIndex = in.writerIndex();
207 try {
208 in.writerIndex(in.readerIndex() + chunkLength - 4);
209 snappy.decode(in, uncompressed);
210 } finally {
211 in.writerIndex(oldWriterIndex);
212 }
213 validateChecksum(checksum, uncompressed, 0, uncompressed.writerIndex());
214 } else {
215 snappy.decode(in.readSlice(chunkLength - 4), uncompressed);
216 }
217 out.add(uncompressed);
218 uncompressed = null;
219 } finally {
220 if (uncompressed != null) {
221 uncompressed.release();
222 }
223 }
224 snappy.reset();
225 break;
226 }
227 } catch (Exception e) {
228 corrupted = true;
229 throw e;
230 }
231 }
232
233 private static void checkByte(byte actual, byte expect) {
234 if (actual != expect) {
235 throw new DecompressionException("Unexpected stream identifier contents. Mismatched snappy " +
236 "protocol version?");
237 }
238 }
239
240
241
242
243
244
245
246 private static ChunkType mapChunkType(byte type) {
247 if (type == 0) {
248 return ChunkType.COMPRESSED_DATA;
249 } else if (type == 1) {
250 return ChunkType.UNCOMPRESSED_DATA;
251 } else if (type == (byte) 0xff) {
252 return ChunkType.STREAM_IDENTIFIER;
253 } else if ((type & 0x80) == 0x80) {
254 return ChunkType.RESERVED_SKIPPABLE;
255 } else {
256 return ChunkType.RESERVED_UNSKIPPABLE;
257 }
258 }
259 }