1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.handler.codec;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.buffer.ByteBufHolder;
20 import io.netty.buffer.CompositeByteBuf;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandler;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelPipeline;
26 import io.netty.util.ReferenceCountUtil;
27
28 import java.util.List;
29
30 import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
31 import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52 public abstract class MessageAggregator<I, S, C extends ByteBufHolder, O extends ByteBufHolder>
53 extends MessageToMessageDecoder<I> {
54
55 private static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024;
56
57 private final int maxContentLength;
58 private O currentMessage;
59 private boolean handlingOversizedMessage;
60
61 private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS;
62 private ChannelHandlerContext ctx;
63 private ChannelFutureListener continueResponseWriteListener;
64
65 private boolean aggregating;
66 private boolean handleIncompleteAggregateDuringClose = true;
67
68
69
70
71
72
73
74
75
76 protected MessageAggregator(int maxContentLength) {
77 validateMaxContentLength(maxContentLength);
78 this.maxContentLength = maxContentLength;
79 }
80
81 protected MessageAggregator(int maxContentLength, Class<? extends I> inboundMessageType) {
82 super(inboundMessageType);
83 validateMaxContentLength(maxContentLength);
84 this.maxContentLength = maxContentLength;
85 }
86
87 private static void validateMaxContentLength(int maxContentLength) {
88 checkPositiveOrZero(maxContentLength, "maxContentLength");
89 }
90
91 @Override
92 public boolean acceptInboundMessage(Object msg) throws Exception {
93
94 if (!super.acceptInboundMessage(msg)) {
95 return false;
96 }
97
98 @SuppressWarnings("unchecked")
99 I in = (I) msg;
100
101 if (isAggregated(in)) {
102 return false;
103 }
104
105
106
107 if (isStartMessage(in)) {
108 return true;
109 } else {
110 return aggregating && isContentMessage(in);
111 }
112 }
113
114
115
116
117
118
119
120
121 protected abstract boolean isStartMessage(I msg) throws Exception;
122
123
124
125
126
127
128
129
130 protected abstract boolean isContentMessage(I msg) throws Exception;
131
132
133
134
135
136
137
138
139
140
141
142
143 protected abstract boolean isLastContentMessage(C msg) throws Exception;
144
145
146
147
148
149 protected abstract boolean isAggregated(I msg) throws Exception;
150
151
152
153
154 public final int maxContentLength() {
155 return maxContentLength;
156 }
157
158
159
160
161
162
163
164 public final int maxCumulationBufferComponents() {
165 return maxCumulationBufferComponents;
166 }
167
168
169
170
171
172
173
174
175 public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) {
176 if (maxCumulationBufferComponents < 2) {
177 throw new IllegalArgumentException(
178 "maxCumulationBufferComponents: " + maxCumulationBufferComponents +
179 " (expected: >= 2)");
180 }
181
182 if (ctx == null) {
183 this.maxCumulationBufferComponents = maxCumulationBufferComponents;
184 } else {
185 throw new IllegalStateException(
186 "decoder properties cannot be changed once the decoder is added to a pipeline.");
187 }
188 }
189
190
191
192
193 @Deprecated
194 public final boolean isHandlingOversizedMessage() {
195 return handlingOversizedMessage;
196 }
197
198 protected final ChannelHandlerContext ctx() {
199 if (ctx == null) {
200 throw new IllegalStateException("not added to a pipeline yet");
201 }
202 return ctx;
203 }
204
205 @Override
206 protected void decode(final ChannelHandlerContext ctx, I msg, List<Object> out) throws Exception {
207 if (isStartMessage(msg)) {
208 aggregating = true;
209 handlingOversizedMessage = false;
210 if (currentMessage != null) {
211 currentMessage.release();
212 currentMessage = null;
213 throw new MessageAggregationException();
214 }
215
216 @SuppressWarnings("unchecked")
217 S m = (S) msg;
218
219
220
221 Object continueResponse = newContinueResponse(m, maxContentLength, ctx.pipeline());
222 if (continueResponse != null) {
223
224 ChannelFutureListener listener = continueResponseWriteListener;
225 if (listener == null) {
226 continueResponseWriteListener = listener = new ChannelFutureListener() {
227 @Override
228 public void operationComplete(ChannelFuture future) throws Exception {
229 if (!future.isSuccess()) {
230 ctx.fireExceptionCaught(future.cause());
231 }
232 }
233 };
234 }
235
236
237 boolean closeAfterWrite = closeAfterContinueResponse(continueResponse);
238 handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse);
239
240 final ChannelFuture future = ctx.writeAndFlush(continueResponse).addListener(listener);
241
242 if (closeAfterWrite) {
243 handleIncompleteAggregateDuringClose = false;
244 future.addListener(ChannelFutureListener.CLOSE);
245 return;
246 }
247 if (handlingOversizedMessage) {
248 return;
249 }
250 } else if (isContentLengthInvalid(m, maxContentLength)) {
251
252 invokeHandleOversizedMessage(ctx, m);
253 return;
254 }
255
256 if (m instanceof DecoderResultProvider && !((DecoderResultProvider) m).decoderResult().isSuccess()) {
257 O aggregated;
258 if (m instanceof ByteBufHolder) {
259 aggregated = beginAggregation(m, ((ByteBufHolder) m).content().retain());
260 } else {
261 aggregated = beginAggregation(m, EMPTY_BUFFER);
262 }
263 finishAggregation0(aggregated);
264 out.add(aggregated);
265 return;
266 }
267
268
269 CompositeByteBuf content = ctx.alloc().compositeBuffer(maxCumulationBufferComponents);
270 if (m instanceof ByteBufHolder) {
271 appendPartialContent(content, ((ByteBufHolder) m).content());
272 }
273 currentMessage = beginAggregation(m, content);
274 } else if (isContentMessage(msg)) {
275 if (currentMessage == null) {
276
277
278 return;
279 }
280
281
282 CompositeByteBuf content = (CompositeByteBuf) currentMessage.content();
283
284 @SuppressWarnings("unchecked")
285 final C m = (C) msg;
286
287 if (content.readableBytes() > maxContentLength - m.content().readableBytes()) {
288
289 @SuppressWarnings("unchecked")
290 S s = (S) currentMessage;
291 invokeHandleOversizedMessage(ctx, s);
292 return;
293 }
294
295
296 appendPartialContent(content, m.content());
297
298
299 aggregate(currentMessage, m);
300
301 final boolean last;
302 if (m instanceof DecoderResultProvider) {
303 DecoderResult decoderResult = ((DecoderResultProvider) m).decoderResult();
304 if (!decoderResult.isSuccess()) {
305 if (currentMessage instanceof DecoderResultProvider) {
306 ((DecoderResultProvider) currentMessage).setDecoderResult(
307 DecoderResult.failure(decoderResult.cause()));
308 }
309 last = true;
310 } else {
311 last = isLastContentMessage(m);
312 }
313 } else {
314 last = isLastContentMessage(m);
315 }
316
317 if (last) {
318 finishAggregation0(currentMessage);
319
320
321 out.add(currentMessage);
322 currentMessage = null;
323 }
324 } else {
325 throw new MessageAggregationException();
326 }
327 }
328
329 private static void appendPartialContent(CompositeByteBuf content, ByteBuf partialContent) {
330 if (partialContent.isReadable()) {
331 content.addComponent(true, partialContent.retain());
332 }
333 }
334
335
336
337
338
339
340
341
342
343 protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception;
344
345
346
347
348
349
350
351 protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline)
352 throws Exception;
353
354
355
356
357
358
359
360
361 protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception;
362
363
364
365
366
367
368
369
370
371 protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception;
372
373
374
375
376
377
378 protected abstract O beginAggregation(S start, ByteBuf content) throws Exception;
379
380
381
382
383
384
385
386 protected void aggregate(O aggregated, C content) throws Exception { }
387
388 private void finishAggregation0(O aggregated) throws Exception {
389 aggregating = false;
390 finishAggregation(aggregated);
391 }
392
393
394
395
396 protected void finishAggregation(O aggregated) throws Exception { }
397
398 private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
399 handlingOversizedMessage = true;
400 currentMessage = null;
401 handleIncompleteAggregateDuringClose = false;
402 try {
403 handleOversizedMessage(ctx, oversized);
404 } finally {
405
406 ReferenceCountUtil.release(oversized);
407 }
408 }
409
410
411
412
413
414
415
416
417 protected void handleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception {
418 ctx.fireExceptionCaught(
419 new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes."));
420 }
421
422 @Override
423 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
424
425
426
427 if (currentMessage != null && !ctx.channel().config().isAutoRead()) {
428 ctx.read();
429 }
430 ctx.fireChannelReadComplete();
431 }
432
433 @Override
434 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
435 if (aggregating && handleIncompleteAggregateDuringClose) {
436 ctx.fireExceptionCaught(
437 new PrematureChannelClosureException("Channel closed while still aggregating message"));
438 }
439 try {
440
441 super.channelInactive(ctx);
442 } finally {
443 releaseCurrentMessage();
444 }
445 }
446
447 @Override
448 public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
449 this.ctx = ctx;
450 }
451
452 @Override
453 public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
454 try {
455 super.handlerRemoved(ctx);
456 } finally {
457
458
459 releaseCurrentMessage();
460 }
461 }
462
463 private void releaseCurrentMessage() {
464 if (currentMessage != null) {
465 currentMessage.release();
466 currentMessage = null;
467 handlingOversizedMessage = false;
468 aggregating = false;
469 }
470 }
471 }