1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.testsuite.transport.socket;
17
18 import io.netty.bootstrap.Bootstrap;
19 import io.netty.buffer.ByteBuf;
20 import io.netty.buffer.Unpooled;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelHandlerContext;
23 import io.netty.channel.ChannelInboundHandlerAdapter;
24 import io.netty.channel.ChannelOption;
25 import io.netty.channel.SimpleChannelInboundHandler;
26 import io.netty.channel.WriteBufferWaterMark;
27 import io.netty.channel.socket.SocketChannel;
28 import io.netty.channel.socket.oio.OioSocketChannel;
29 import org.junit.jupiter.api.Disabled;
30 import org.junit.jupiter.api.Test;
31 import org.junit.jupiter.api.TestInfo;
32 import org.junit.jupiter.api.Timeout;
33
34 import java.net.ServerSocket;
35 import java.net.Socket;
36 import java.net.SocketException;
37 import java.nio.channels.ClosedChannelException;
38 import java.util.concurrent.BlockingDeque;
39 import java.util.concurrent.BlockingQueue;
40 import java.util.concurrent.LinkedBlockingDeque;
41 import java.util.concurrent.LinkedBlockingQueue;
42 import java.util.concurrent.TimeUnit;
43
44 import static org.junit.jupiter.api.Assertions.assertEquals;
45 import static org.junit.jupiter.api.Assertions.assertFalse;
46 import static org.junit.jupiter.api.Assertions.assertNull;
47 import static org.junit.jupiter.api.Assertions.assertTrue;
48 import static org.junit.jupiter.api.Assertions.fail;
49 import static org.junit.jupiter.api.Assumptions.assumeFalse;
50
51 public class SocketShutdownOutputBySelfTest extends AbstractClientSocketTest {
52
53 @Test
54 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
55 public void testShutdownOutput(TestInfo testInfo) throws Throwable {
56 run(testInfo, new Runner<Bootstrap>() {
57 @Override
58 public void run(Bootstrap bootstrap) throws Throwable {
59 testShutdownOutput(bootstrap);
60 }
61 });
62 }
63
64 public void testShutdownOutput(Bootstrap cb) throws Throwable {
65 TestHandler h = new TestHandler();
66 ServerSocket ss = new ServerSocket();
67 Socket s = null;
68 SocketChannel ch = null;
69 try {
70 ss.bind(newSocketAddress());
71 ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
72 assertTrue(ch.isActive());
73 assertFalse(ch.isOutputShutdown());
74
75 s = ss.accept();
76 ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 1 })).sync();
77 assertEquals(1, s.getInputStream().read());
78
79 assertTrue(h.ch.isOpen());
80 assertTrue(h.ch.isActive());
81 assertFalse(h.ch.isInputShutdown());
82 assertFalse(h.ch.isOutputShutdown());
83
84
85 ch.shutdownOutput().sync();
86 assertEquals(-1, s.getInputStream().read());
87
88 assertTrue(h.ch.isOpen());
89 assertTrue(h.ch.isActive());
90 assertFalse(h.ch.isInputShutdown());
91 assertTrue(h.ch.isOutputShutdown());
92
93
94 s.getOutputStream().write(new byte[] { 1 });
95 assertEquals(1, (int) h.queue.take());
96 } finally {
97 if (s != null) {
98 s.close();
99 }
100 if (ch != null) {
101 ch.close();
102 }
103 ss.close();
104 }
105 }
106
107 @Test
108 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
109 public void testShutdownOutputAfterClosed(TestInfo testInfo) throws Throwable {
110 run(testInfo, new Runner<Bootstrap>() {
111 @Override
112 public void run(Bootstrap bootstrap) throws Throwable {
113 testShutdownOutputAfterClosed(bootstrap);
114 }
115 });
116 }
117
118 public void testShutdownOutputAfterClosed(Bootstrap cb) throws Throwable {
119 TestHandler h = new TestHandler();
120 ServerSocket ss = new ServerSocket();
121 Socket s = null;
122 try {
123 ss.bind(newSocketAddress());
124 SocketChannel ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
125 assertTrue(ch.isActive());
126 s = ss.accept();
127
128 ch.close().syncUninterruptibly();
129 try {
130 ch.shutdownInput().syncUninterruptibly();
131 fail();
132 } catch (Throwable cause) {
133 checkThrowable(cause);
134 }
135 try {
136 ch.shutdownOutput().syncUninterruptibly();
137 fail();
138 } catch (Throwable cause) {
139 checkThrowable(cause);
140 }
141 } finally {
142 if (s != null) {
143 s.close();
144 }
145 ss.close();
146 }
147 }
148
149 @Disabled
150 @Test
151 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
152 public void testWriteAfterShutdownOutputNoWritabilityChange(TestInfo testInfo) throws Throwable {
153 run(testInfo, new Runner<Bootstrap>() {
154 @Override
155 public void run(Bootstrap bootstrap) throws Throwable {
156 testWriteAfterShutdownOutputNoWritabilityChange(bootstrap);
157 }
158 });
159 }
160
161 public void testWriteAfterShutdownOutputNoWritabilityChange(Bootstrap cb) throws Throwable {
162 final TestHandler h = new TestHandler();
163 ServerSocket ss = new ServerSocket();
164 Socket s = null;
165 SocketChannel ch = null;
166 try {
167 ss.bind(newSocketAddress());
168 cb.option(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(2, 4));
169 ch = (SocketChannel) cb.handler(h).connect(ss.getLocalSocketAddress()).sync().channel();
170 assumeFalse(ch instanceof OioSocketChannel);
171 assertTrue(ch.isActive());
172 assertFalse(ch.isOutputShutdown());
173
174 s = ss.accept();
175
176 byte[] expectedBytes = new byte[]{ 1, 2, 3, 4, 5, 6 };
177 ChannelFuture writeFuture = ch.write(Unpooled.wrappedBuffer(expectedBytes));
178 h.assertWritability(false);
179 ch.flush();
180 writeFuture.sync();
181 h.assertWritability(true);
182 for (int i = 0; i < expectedBytes.length; ++i) {
183 assertEquals(expectedBytes[i], s.getInputStream().read());
184 }
185
186 assertTrue(h.ch.isOpen());
187 assertTrue(h.ch.isActive());
188 assertFalse(h.ch.isInputShutdown());
189 assertFalse(h.ch.isOutputShutdown());
190
191
192 ch.shutdownOutput().sync();
193 assertEquals(-1, s.getInputStream().read());
194
195 assertTrue(h.ch.isOpen());
196 assertTrue(h.ch.isActive());
197 assertFalse(h.ch.isInputShutdown());
198 assertTrue(h.ch.isOutputShutdown());
199
200 try {
201
202 ch.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{ 2 })).sync();
203 fail();
204 } catch (Throwable cause) {
205 checkThrowable(cause);
206 }
207 assertNull(h.writabilityQueue.poll());
208 } finally {
209 if (s != null) {
210 s.close();
211 }
212 if (ch != null) {
213 ch.close();
214 }
215 ss.close();
216 }
217 }
218
219 @Test
220 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
221 public void testShutdownOutputSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
222 run(testInfo, new Runner<Bootstrap>() {
223 @Override
224 public void run(Bootstrap bootstrap) throws Throwable {
225 testShutdownOutputSoLingerNoAssertError(bootstrap);
226 }
227 });
228 }
229
230 public void testShutdownOutputSoLingerNoAssertError(Bootstrap cb) throws Throwable {
231 testShutdownSoLingerNoAssertError0(cb, true);
232 }
233
234 @Test
235 @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
236 public void testShutdownSoLingerNoAssertError(TestInfo testInfo) throws Throwable {
237 run(testInfo, new Runner<Bootstrap>() {
238 @Override
239 public void run(Bootstrap bootstrap) throws Throwable {
240 testShutdownSoLingerNoAssertError(bootstrap);
241 }
242 });
243 }
244
245 public void testShutdownSoLingerNoAssertError(Bootstrap cb) throws Throwable {
246 testShutdownSoLingerNoAssertError0(cb, false);
247 }
248
249 private void testShutdownSoLingerNoAssertError0(Bootstrap cb, boolean output) throws Throwable {
250 ServerSocket ss = new ServerSocket();
251 Socket s = null;
252
253 ChannelFuture cf = null;
254 try {
255 ss.bind(newSocketAddress());
256 cf = cb.option(ChannelOption.SO_LINGER, 1).handler(new ChannelInboundHandlerAdapter())
257 .connect(ss.getLocalSocketAddress()).sync();
258 s = ss.accept();
259
260 cf.sync();
261
262 if (output) {
263 ((SocketChannel) cf.channel()).shutdownOutput().sync();
264 } else {
265 ((SocketChannel) cf.channel()).shutdown().sync();
266 }
267 } finally {
268 if (s != null) {
269 s.close();
270 }
271 if (cf != null) {
272 cf.channel().close();
273 }
274 ss.close();
275 }
276 }
277 private static void checkThrowable(Throwable cause) throws Throwable {
278
279 if (!(cause instanceof ClosedChannelException) && !(cause instanceof SocketException)) {
280 throw cause;
281 }
282 }
283
284 private static final class TestHandler extends SimpleChannelInboundHandler<ByteBuf> {
285 volatile SocketChannel ch;
286 final BlockingQueue<Byte> queue = new LinkedBlockingQueue<Byte>();
287 final BlockingDeque<Boolean> writabilityQueue = new LinkedBlockingDeque<Boolean>();
288
289 @Override
290 public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
291 writabilityQueue.add(ctx.channel().isWritable());
292 }
293
294 @Override
295 public void channelActive(ChannelHandlerContext ctx) throws Exception {
296 ch = (SocketChannel) ctx.channel();
297 }
298
299 @Override
300 public void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception {
301 queue.offer(msg.readByte());
302 }
303
304 private void drainWritabilityQueue() throws InterruptedException {
305 while ((writabilityQueue.poll(100, TimeUnit.MILLISECONDS)) != null) {
306
307 }
308 }
309
310 void assertWritability(boolean isWritable) throws InterruptedException {
311 try {
312 Boolean writability = writabilityQueue.takeLast();
313 assertEquals(isWritable, writability);
314
315 drainWritabilityQueue();
316 } catch (Throwable c) {
317 c.printStackTrace();
318 }
319 }
320 }
321 }