1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelDuplexHandler;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelInboundHandlerAdapter;
25 import io.netty.channel.ChannelInitializer;
26 import io.netty.channel.ChannelOption;
27 import io.netty.channel.WriteBufferWaterMark;
28 import io.netty.util.ReferenceCountUtil;
29 import org.junit.jupiter.api.Test;
30 import org.junit.jupiter.api.TestInfo;
31 import org.junit.jupiter.api.Timeout;
32
33 import java.util.concurrent.CountDownLatch;
34 import java.util.concurrent.TimeUnit;
35
36 public class SocketConditionalWritabilityTest extends AbstractSocketTest {
37 @Test
38 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
39 public void testConditionalWritability(TestInfo testInfo) throws Throwable {
40 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
41 @Override
42 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
43 testConditionalWritability(serverBootstrap, bootstrap);
44 }
45 });
46 }
47
48 public void testConditionalWritability(ServerBootstrap sb, Bootstrap cb) throws Throwable {
49 Channel serverChannel = null;
50 Channel clientChannel = null;
51 try {
52 final int expectedBytes = 100 * 1024 * 1024;
53 final int maxWriteChunkSize = 16 * 1024;
54 final CountDownLatch latch = new CountDownLatch(1);
55 sb.childOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(8 * 1024, 16 * 1024));
56 sb.childHandler(new ChannelInitializer<Channel>() {
57 @Override
58 protected void initChannel(Channel ch) {
59 ch.pipeline().addLast(new ChannelDuplexHandler() {
60 private int bytesWritten;
61
62 @Override
63 public void channelRead(ChannelHandlerContext ctx, Object msg) {
64 ReferenceCountUtil.release(msg);
65 writeRemainingBytes(ctx);
66 }
67
68 @Override
69 public void flush(ChannelHandlerContext ctx) {
70 if (ctx.channel().isWritable()) {
71 writeRemainingBytes(ctx);
72 } else {
73 ctx.flush();
74 }
75 }
76
77 @Override
78 public void channelWritabilityChanged(ChannelHandlerContext ctx) {
79 if (ctx.channel().isWritable()) {
80 writeRemainingBytes(ctx);
81 }
82 ctx.fireChannelWritabilityChanged();
83 }
84
85 private void writeRemainingBytes(ChannelHandlerContext ctx) {
86 while (ctx.channel().isWritable() && bytesWritten < expectedBytes) {
87 int chunkSize = Math.min(expectedBytes - bytesWritten, maxWriteChunkSize);
88 bytesWritten += chunkSize;
89 ctx.write(ctx.alloc().buffer(chunkSize).writeZero(chunkSize));
90 }
91 ctx.flush();
92 }
93 });
94 }
95 });
96
97 serverChannel = sb.bind().syncUninterruptibly().channel();
98
99 cb.handler(new ChannelInitializer<Channel>() {
100 @Override
101 protected void initChannel(Channel ch) {
102 ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
103 private int totalRead;
104 @Override
105 public void channelActive(ChannelHandlerContext ctx) {
106 ctx.writeAndFlush(ctx.alloc().buffer(1).writeByte(0));
107 }
108
109 @Override
110 public void channelRead(ChannelHandlerContext ctx, Object msg) {
111 if (msg instanceof ByteBuf) {
112 totalRead += ((ByteBuf) msg).readableBytes();
113 if (totalRead == expectedBytes) {
114 latch.countDown();
115 }
116 }
117 ReferenceCountUtil.release(msg);
118 }
119 });
120 }
121 });
122 clientChannel = cb.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
123 latch.await();
124 } finally {
125 if (serverChannel != null) {
126 serverChannel.close();
127 }
128 if (clientChannel != null) {
129 clientChannel.close();
130 }
131 }
132 }
133 }