1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.Unpooled;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInboundHandlerAdapter;
24 import io.netty.channel.ChannelInitializer;
25 import io.netty.channel.ChannelOption;
26 import io.netty.channel.ChannelPipeline;
27 import io.netty.util.ReferenceCountUtil;
28 import org.junit.jupiter.api.Test;
29
30 import java.util.concurrent.CountDownLatch;
31 import java.util.concurrent.TimeUnit;
32 import java.util.concurrent.atomic.AtomicLong;
33 import org.junit.jupiter.api.TestInfo;
34
35 import static org.junit.jupiter.api.Assertions.assertFalse;
36 import static org.junit.jupiter.api.Assertions.assertTrue;
37
38 public class SocketExceptionHandlingTest extends AbstractSocketTest {
39 @Test
40 public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
41 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
42 @Override
43 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
44 testReadPendingIsResetAfterEachRead(serverBootstrap, bootstrap);
45 }
46 });
47 }
48
49 public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
50 Channel serverChannel = null;
51 Channel clientChannel = null;
52 try {
53 MyInitializer serverInitializer = new MyInitializer();
54 sb.option(ChannelOption.SO_BACKLOG, 1024);
55 sb.childHandler(serverInitializer);
56
57 serverChannel = sb.bind().syncUninterruptibly().channel();
58
59 cb.handler(new MyInitializer());
60 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
61
62 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(new byte[1024]));
63
64
65 assertTrue(serverInitializer.exceptionHandler.latch1.await(5, TimeUnit.SECONDS));
66
67
68 assertFalse(serverInitializer.exceptionHandler.latch2.await(1, TimeUnit.SECONDS),
69 "Encountered " + serverInitializer.exceptionHandler.count.get() +
70 " exceptions when 1 was expected");
71 } finally {
72 if (serverChannel != null) {
73 serverChannel.close().syncUninterruptibly();
74 }
75 if (clientChannel != null) {
76 clientChannel.close().syncUninterruptibly();
77 }
78 }
79 }
80
81 private static class MyInitializer extends ChannelInitializer<Channel> {
82 final ExceptionHandler exceptionHandler = new ExceptionHandler();
83 @Override
84 protected void initChannel(Channel ch) throws Exception {
85 ChannelPipeline pipeline = ch.pipeline();
86
87 pipeline.addLast(new BuggyChannelHandler());
88 pipeline.addLast(exceptionHandler);
89 }
90 }
91
92 private static class BuggyChannelHandler extends ChannelInboundHandlerAdapter {
93 @Override
94 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
95 ReferenceCountUtil.release(msg);
96 throw new NullPointerException("I am a bug!");
97 }
98 }
99
100 private static class ExceptionHandler extends ChannelInboundHandlerAdapter {
101 final AtomicLong count = new AtomicLong();
102
103
104
105 final CountDownLatch latch1 = new CountDownLatch(1);
106 final CountDownLatch latch2 = new CountDownLatch(1);
107
108 @Override
109 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
110 if (count.incrementAndGet() <= 2) {
111 latch1.countDown();
112 } else {
113 latch2.countDown();
114 }
115
116 ctx.close();
117 }
118 }
119 }