1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.bootstrap.ServerBootstrap;
20 import io.netty.buffer.ByteBuf;
21 import io.netty.channel.Channel;
22 import io.netty.channel.ChannelFuture;
23 import io.netty.channel.ChannelFutureListener;
24 import io.netty.channel.ChannelHandlerContext;
25 import io.netty.channel.ChannelInboundHandlerAdapter;
26 import io.netty.channel.ChannelInitializer;
27 import io.netty.channel.ChannelOption;
28 import io.netty.channel.socket.SocketChannel;
29 import io.netty.util.concurrent.ImmediateEventExecutor;
30 import io.netty.util.concurrent.Promise;
31 import org.junit.jupiter.api.Test;
32 import org.junit.jupiter.api.TestInfo;
33 import org.junit.jupiter.api.Timeout;
34
35 import java.io.ByteArrayOutputStream;
36 import java.net.InetSocketAddress;
37 import java.net.SocketAddress;
38 import java.util.concurrent.BlockingQueue;
39 import java.util.concurrent.LinkedBlockingQueue;
40 import java.util.concurrent.Semaphore;
41 import java.util.concurrent.TimeUnit;
42
43 import static io.netty.buffer.ByteBufUtil.writeAscii;
44 import static io.netty.buffer.UnpooledByteBufAllocator.DEFAULT;
45 import static io.netty.util.CharsetUtil.US_ASCII;
46 import static org.junit.jupiter.api.Assertions.assertEquals;
47 import static org.junit.jupiter.api.Assertions.assertFalse;
48 import static org.junit.jupiter.api.Assertions.assertNotNull;
49 import static org.junit.jupiter.api.Assertions.assertNull;
50 import static org.junit.jupiter.api.Assertions.assertTrue;
51
52 public class SocketConnectTest extends AbstractSocketTest {
53
54 @Test
55 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
56 public void testLocalAddressAfterConnect(TestInfo testInfo) throws Throwable {
57 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
58 @Override
59 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
60 testLocalAddressAfterConnect(serverBootstrap, bootstrap);
61 }
62 });
63 }
64
65 public void testLocalAddressAfterConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
66 Channel serverChannel = null;
67 Channel clientChannel = null;
68 try {
69 final Promise<InetSocketAddress> localAddressPromise = ImmediateEventExecutor.INSTANCE.newPromise();
70 serverChannel = sb.childHandler(new ChannelInboundHandlerAdapter() {
71 @Override
72 public void channelActive(ChannelHandlerContext ctx) throws Exception {
73 localAddressPromise.setSuccess((InetSocketAddress) ctx.channel().localAddress());
74 }
75 }).bind().syncUninterruptibly().channel();
76
77 clientChannel = cb.handler(new ChannelInboundHandlerAdapter()).register().syncUninterruptibly().channel();
78
79 assertNull(clientChannel.localAddress());
80 assertNull(clientChannel.remoteAddress());
81
82 clientChannel.connect(serverChannel.localAddress()).syncUninterruptibly().channel();
83 assertLocalAddress((InetSocketAddress) clientChannel.localAddress());
84 assertNotNull(clientChannel.remoteAddress());
85
86 assertLocalAddress(localAddressPromise.get());
87 } finally {
88 if (clientChannel != null) {
89 clientChannel.close().syncUninterruptibly();
90 }
91 if (serverChannel != null) {
92 serverChannel.close().syncUninterruptibly();
93 }
94 }
95 }
96
97 @Test
98 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
99 public void testChannelEventsFiredWhenClosedDirectly(TestInfo testInfo) throws Throwable {
100 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
101 @Override
102 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
103 testChannelEventsFiredWhenClosedDirectly(serverBootstrap, bootstrap);
104 }
105 });
106 }
107
108 public void testChannelEventsFiredWhenClosedDirectly(ServerBootstrap sb, Bootstrap cb) throws Throwable {
109 final BlockingQueue<Integer> events = new LinkedBlockingQueue<Integer>();
110
111 Channel sc = null;
112 Channel cc = null;
113 try {
114 sb.childHandler(new ChannelInboundHandlerAdapter());
115 sc = sb.bind().syncUninterruptibly().channel();
116
117 cb.handler(new ChannelInboundHandlerAdapter() {
118 @Override
119 public void channelActive(ChannelHandlerContext ctx) throws Exception {
120 events.add(0);
121 }
122
123 @Override
124 public void channelInactive(ChannelHandlerContext ctx) throws Exception {
125 events.add(1);
126 }
127 });
128
129 cc = cb.connect(sc.localAddress()).addListener(ChannelFutureListener.CLOSE).
130 syncUninterruptibly().channel();
131 assertEquals(0, events.take().intValue());
132 assertEquals(1, events.take().intValue());
133 } finally {
134 if (cc != null) {
135 cc.close();
136 }
137 if (sc != null) {
138 sc.close();
139 }
140 }
141 }
142
143 @Test
144 @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
145 public void testWriteWithFastOpenBeforeConnect(TestInfo testInfo) throws Throwable {
146 run(testInfo, new Runner<ServerBootstrap, Bootstrap>() {
147 @Override
148 public void run(ServerBootstrap serverBootstrap, Bootstrap bootstrap) throws Throwable {
149 testWriteWithFastOpenBeforeConnect(serverBootstrap, bootstrap);
150 }
151 });
152 }
153
154 public void testWriteWithFastOpenBeforeConnect(ServerBootstrap sb, Bootstrap cb) throws Throwable {
155 enableTcpFastOpen(sb, cb);
156 sb.childOption(ChannelOption.AUTO_READ, true);
157 cb.option(ChannelOption.AUTO_READ, true);
158
159 sb.childHandler(new ChannelInitializer<SocketChannel>() {
160 @Override
161 protected void initChannel(SocketChannel ch) throws Exception {
162 ch.pipeline().addLast(new EchoServerHandler());
163 }
164 });
165
166 Channel sc = sb.bind().sync().channel();
167 connectAndVerifyDataTransfer(cb, sc);
168 connectAndVerifyDataTransfer(cb, sc);
169 }
170
171 private static void connectAndVerifyDataTransfer(Bootstrap cb, Channel sc)
172 throws InterruptedException {
173 BufferingClientHandler handler = new BufferingClientHandler();
174 cb.handler(handler);
175 ChannelFuture register = cb.register();
176 Channel channel = register.sync().channel();
177 ChannelFuture write = channel.write(writeAscii(DEFAULT, "[fastopen]"));
178 SocketAddress remoteAddress = sc.localAddress();
179 ChannelFuture connectFuture = channel.connect(remoteAddress);
180 Channel cc = connectFuture.sync().channel();
181 cc.writeAndFlush(writeAscii(DEFAULT, "[normal data]")).sync();
182 write.sync();
183 String expectedString = "[fastopen][normal data]";
184 String result = handler.collectBuffer(expectedString.getBytes(US_ASCII).length);
185 cc.disconnect().sync();
186 assertEquals(expectedString, result);
187 }
188
189 protected void enableTcpFastOpen(ServerBootstrap sb, Bootstrap cb) {
190
191 sb.option(ChannelOption.TCP_FASTOPEN, 5);
192 cb.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
193 }
194
195 private static void assertLocalAddress(InetSocketAddress address) {
196 assertTrue(address.getPort() > 0);
197 assertFalse(address.getAddress().isAnyLocalAddress());
198 }
199
200 private static class BufferingClientHandler extends ChannelInboundHandlerAdapter {
201 private final Semaphore semaphore = new Semaphore(0);
202 private final ByteArrayOutputStream streamBuffer = new ByteArrayOutputStream();
203
204 @Override
205 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
206 if (msg instanceof ByteBuf) {
207 ByteBuf buf = (ByteBuf) msg;
208 int readableBytes = buf.readableBytes();
209 buf.readBytes(streamBuffer, readableBytes);
210 semaphore.release(readableBytes);
211 buf.release();
212 } else {
213 throw new IllegalArgumentException("Unexpected message type: " + msg);
214 }
215 }
216
217 String collectBuffer(int expectedBytes) throws InterruptedException {
218 semaphore.acquire(expectedBytes);
219 byte[] bytes = streamBuffer.toByteArray();
220 streamBuffer.reset();
221 return new String(bytes, US_ASCII);
222 }
223 }
224
225 private static final class EchoServerHandler extends ChannelInboundHandlerAdapter {
226 @Override
227 public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
228 if (msg instanceof ByteBuf) {
229 ByteBuf buffer = ctx.alloc().buffer();
230 ByteBuf buf = (ByteBuf) msg;
231 buffer.writeBytes(buf);
232 buf.release();
233 ctx.channel().writeAndFlush(buffer);
234 } else {
235 throw new IllegalArgumentException("Unexpected message type: " + msg);
236 }
237 }
238 }
239 }