1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.buffer.ByteBufAllocator;
22 import io.netty.buffer.Unpooled;
23 import io.netty.channel.Channel;
24 import io.netty.channel.ChannelConfig;
25 import io.netty.channel.ChannelHandlerContext;
26 import io.netty.channel.ChannelInboundHandlerAdapter;
27 import io.netty.channel.ChannelInitializer;
28 import io.netty.channel.ChannelOption;
29 import io.netty.channel.RecvByteBufAllocator;
30 import io.netty.util.ReferenceCountUtil;
31 import io.netty.util.UncheckedBooleanSupplier;
32 import org.junit.jupiter.api.Test;
33 import org.junit.jupiter.api.TestInfo;
34 import org.junit.jupiter.api.Timeout;
35
36 import java.util.concurrent.CountDownLatch;
37 import java.util.concurrent.TimeUnit;
38 import java.util.concurrent.atomic.AtomicInteger;
39
40 import static org.junit.jupiter.api.Assertions.assertEquals;
41 import static org.junit.jupiter.api.Assertions.assertFalse;
42 import static org.junit.jupiter.api.Assertions.assertTrue;
43
44 public class SocketReadPendingTest extends AbstractSocketTest {
45 @Test
46 @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
47 public void testReadPendingIsResetAfterEachRead(TestInfo testInfo) throws Throwable {
48 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
49 @Override
50 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
51 testReadPendingIsResetAfterEachRead(serverBootstrap, bootstrap);
52 }
53 });
54 }
55
56 public void testReadPendingIsResetAfterEachRead(ServerBootstrap sb, Bootstrap cb) throws Throwable {
57 Channel serverChannel = null;
58 Channel clientChannel = null;
59 try {
60 ReadPendingInitializer serverInitializer = new ReadPendingInitializer();
61 ReadPendingInitializer clientInitializer = new ReadPendingInitializer();
62 sb.option(ChannelOption.SO_BACKLOG, 1024)
63 .option(ChannelOption.AUTO_READ, true)
64 .childOption(ChannelOption.AUTO_READ, false)
65
66 .childOption(ChannelOption.RCVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
67 .childHandler(serverInitializer);
68
69 serverChannel = sb.bind().syncUninterruptibly().channel();
70
71 cb.option(ChannelOption.AUTO_READ, false)
72
73 .option(ChannelOption.RCVBUF_ALLOCATOR, new TestNumReadsRecvByteBufAllocator(2))
74 .handler(clientInitializer);
75 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
76
77
78 clientChannel.writeAndFlush(Unpooled.wrappedBuffer(new byte[4]));
79
80
81 assertTrue(serverInitializer.channelInitLatch.await(5, TimeUnit.SECONDS));
82 serverInitializer.channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[4]));
83
84 serverInitializer.channel.read();
85 serverInitializer.readPendingHandler.assertAllRead();
86
87 clientChannel.read();
88 clientInitializer.readPendingHandler.assertAllRead();
89 } finally {
90 if (serverChannel != null) {
91 serverChannel.close().syncUninterruptibly();
92 }
93 if (clientChannel != null) {
94 clientChannel.close().syncUninterruptibly();
95 }
96 }
97 }
98
99 private static class ReadPendingInitializer extends ChannelInitializer<Channel> {
100 final ReadPendingReadHandler readPendingHandler = new ReadPendingReadHandler();
101 final CountDownLatch channelInitLatch = new CountDownLatch(1);
102 volatile Channel channel;
103
104 @Override
105 protected void initChannel(Channel ch) throws Exception {
106 channel = ch;
107 ch.pipeline().addLast(readPendingHandler);
108 channelInitLatch.countDown();
109 }
110 }
111
112 private static final class ReadPendingReadHandler extends ChannelInboundHandlerAdapter {
113 private final AtomicInteger count = new AtomicInteger();
114 private final CountDownLatch latch = new CountDownLatch(1);
115 private final CountDownLatch latch2 = new CountDownLatch(2);
116
117 @Override
118 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
119 ReferenceCountUtil.release(msg);
120 if (count.incrementAndGet() == 1) {
121
122 ctx.read();
123 }
124 }
125
126 @Override
127 public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
128 latch.countDown();
129 latch2.countDown();
130 }
131
132 void assertAllRead() throws InterruptedException {
133 assertTrue(latch.await(5, TimeUnit.SECONDS));
134
135 assertFalse(latch2.await(1, TimeUnit.SECONDS));
136 assertEquals(2, count.get());
137 }
138 }
139
140
141
142
143 private static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator {
144 private final int numReads;
145 TestNumReadsRecvByteBufAllocator(int numReads) {
146 this.numReads = numReads;
147 }
148
149 @Override
150 public ExtendedHandle newHandle() {
151 return new ExtendedHandle() {
152 private int attemptedBytesRead;
153 private int lastBytesRead;
154 private int numMessagesRead;
155 @Override
156 public ByteBuf allocate(ByteBufAllocator alloc) {
157 return alloc.ioBuffer(guess(), guess());
158 }
159
160 @Override
161 public int guess() {
162 return 1;
163 }
164
165 @Override
166 public void reset(ChannelConfig config) {
167 numMessagesRead = 0;
168 }
169
170 @Override
171 public void incMessagesRead(int numMessages) {
172 numMessagesRead += numMessages;
173 }
174
175 @Override
176 public void lastBytesRead(int bytes) {
177 lastBytesRead = bytes;
178 }
179
180 @Override
181 public int lastBytesRead() {
182 return lastBytesRead;
183 }
184
185 @Override
186 public void attemptedBytesRead(int bytes) {
187 attemptedBytesRead = bytes;
188 }
189
190 @Override
191 public int attemptedBytesRead() {
192 return attemptedBytesRead;
193 }
194
195 @Override
196 public boolean continueReading() {
197 return numMessagesRead < numReads;
198 }
199
200 @Override
201 public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) {
202 return continueReading();
203 }
204
205 @Override
206 public void readComplete() {
207
208 }
209 };
210 }
211 }
212 }