1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec.compression;
17
18 import com.github.luben.zstd.BaseZstdBufferDecompressingStreamNoFinalizer;
19 import com.github.luben.zstd.ZstdBufferDecompressingStreamNoFinalizer;
20 import com.github.luben.zstd.ZstdDirectBufferDecompressingStreamNoFinalizer;
21 import io.netty.buffer.ByteBuf;
22 import io.netty.buffer.ByteBufAllocator;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.handler.codec.ByteToMessageDecoder;
25
26 import java.io.IOException;
27 import java.nio.ByteBuffer;
28 import java.util.List;
29
30
31
32
33
34 public final class ZstdDecoder extends ByteToMessageDecoder {
35
36 {
37 try {
38 Zstd.ensureAvailability();
39 outCapacity = ZstdBufferDecompressingStreamNoFinalizer.recommendedTargetBufferSize();
40 } catch (Throwable throwable) {
41 throw new ExceptionInInitializerError(throwable);
42 }
43 }
44 private final int outCapacity;
45
46 private State currentState = State.DECOMPRESS_DATA;
47 private ZstdStream stream;
48
49
50
51
52 private enum State {
53 DECOMPRESS_DATA,
54 CORRUPTED
55 }
56
57 @Override
58 protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
59 try {
60 if (currentState == State.CORRUPTED) {
61 in.skipBytes(in.readableBytes());
62 return;
63 }
64 final int compressedLength = in.readableBytes();
65 if (compressedLength == 0) {
66
67 return;
68 }
69 if (stream == null) {
70
71 stream = new ZstdStream(in.isDirect(), outCapacity);
72 }
73
74 do {
75 ByteBuf decompressed = stream.decompress(ctx.alloc(), in);
76 if (decompressed == null) {
77 return;
78 }
79 out.add(decompressed);
80 } while (in.isReadable());
81 } catch (DecompressionException e) {
82 currentState = State.CORRUPTED;
83 throw e;
84 }
85 }
86
87 @Override
88 protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
89 try {
90 if (stream != null) {
91 stream.close();
92 stream = null;
93 }
94 } finally {
95 super.handlerRemoved0(ctx);
96 }
97 }
98
99 private static final class ZstdStream {
100 private static final ByteBuffer EMPTY_HEAP_BUFFER = ByteBuffer.allocate(0);
101 private static final ByteBuffer EMPTY_DIRECT_BUFFER = ByteBuffer.allocateDirect(0);
102
103 private final boolean direct;
104 private final int outCapacity;
105 private final BaseZstdBufferDecompressingStreamNoFinalizer decompressingStream;
106 private ByteBuffer current;
107
108 ZstdStream(boolean direct, int outCapacity) {
109 this.direct = direct;
110 this.outCapacity = outCapacity;
111 if (direct) {
112 decompressingStream = new ZstdDirectBufferDecompressingStreamNoFinalizer(EMPTY_DIRECT_BUFFER) {
113 @Override
114 protected ByteBuffer refill(ByteBuffer toRefill) {
115 return ZstdStream.this.refill(toRefill);
116 }
117 };
118 } else {
119 decompressingStream = new ZstdBufferDecompressingStreamNoFinalizer(EMPTY_HEAP_BUFFER) {
120 @Override
121 protected ByteBuffer refill(ByteBuffer toRefill) {
122 return ZstdStream.this.refill(toRefill);
123 }
124 };
125 }
126 }
127
128 ByteBuf decompress(ByteBufAllocator alloc, ByteBuf in) throws DecompressionException {
129 final ByteBuf source;
130
131 if (direct && !in.isDirect()) {
132 source = alloc.directBuffer(in.readableBytes());
133 source.writeBytes(in, in.readerIndex(), in.readableBytes());
134 } else if (!direct && !in.hasArray()) {
135 source = alloc.heapBuffer(in.readableBytes());
136 source.writeBytes(in, in.readerIndex(), in.readableBytes());
137 } else {
138 source = in;
139 }
140 int inPosition = -1;
141 ByteBuf outBuffer = null;
142 try {
143 ByteBuffer inNioBuffer = CompressionUtil.safeNioBuffer(
144 source, source.readerIndex(), source.readableBytes());
145 inPosition = inNioBuffer.position();
146 assert inNioBuffer.hasRemaining();
147 current = inNioBuffer;
148
149
150 if (direct) {
151 outBuffer = alloc.directBuffer(outCapacity);
152 } else {
153 outBuffer = alloc.heapBuffer(outCapacity);
154 }
155 ByteBuffer target = outBuffer.internalNioBuffer(outBuffer.writerIndex(), outBuffer.writableBytes());
156 int position = target.position();
157 do {
158 do {
159 if (decompressingStream.read(target) == 0) {
160 break;
161 }
162 } while (decompressingStream.hasRemaining() && target.hasRemaining() && current.hasRemaining());
163 int written = target.position() - position;
164 if (written > 0) {
165 outBuffer.writerIndex(outBuffer.writerIndex() + written);
166 ByteBuf out = outBuffer;
167 outBuffer = null;
168 return out;
169 }
170 } while (decompressingStream.hasRemaining() && current.hasRemaining());
171 } catch (IOException e) {
172 throw new DecompressionException(e);
173 } finally {
174 if (outBuffer != null) {
175 outBuffer.release();
176 }
177
178 if (source != in) {
179 source.release();
180 }
181 ByteBuffer buffer = current;
182 current = null;
183 if (inPosition != -1) {
184 int read = buffer.position() - inPosition;
185 if (read > 0) {
186 in.skipBytes(read);
187 }
188 }
189 }
190 return null;
191 }
192
193 private ByteBuffer refill(@SuppressWarnings("unused") ByteBuffer toRefill) {
194 return current;
195 }
196
197 void close() {
198 decompressingStream.close();
199 }
200 }
201 }