1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package io.netty.channel.socket.nio;
17
18 import io.netty.buffer.ByteBuf;
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelException;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelOption;
24 import io.netty.channel.ChannelOutboundBuffer;
25 import io.netty.channel.ChannelPromise;
26 import io.netty.channel.EventLoop;
27 import io.netty.channel.FileRegion;
28 import io.netty.channel.RecvByteBufAllocator;
29 import io.netty.channel.nio.AbstractNioByteChannel;
30 import io.netty.channel.socket.DefaultSocketChannelConfig;
31 import io.netty.channel.socket.InternetProtocolFamily;
32 import io.netty.channel.socket.ServerSocketChannel;
33 import io.netty.channel.socket.SocketChannelConfig;
34 import io.netty.util.concurrent.GlobalEventExecutor;
35 import io.netty.util.internal.PlatformDependent;
36 import io.netty.util.internal.SocketUtils;
37 import io.netty.util.internal.SuppressJava6Requirement;
38 import io.netty.util.internal.UnstableApi;
39 import io.netty.util.internal.logging.InternalLogger;
40 import io.netty.util.internal.logging.InternalLoggerFactory;
41
42 import java.io.IOException;
43 import java.lang.reflect.Method;
44 import java.net.InetSocketAddress;
45 import java.net.Socket;
46 import java.net.SocketAddress;
47 import java.nio.ByteBuffer;
48 import java.nio.channels.SelectionKey;
49 import java.nio.channels.SocketChannel;
50 import java.nio.channels.spi.SelectorProvider;
51 import java.util.Map;
52 import java.util.concurrent.Executor;
53
54 import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
55
56
57
58
59 public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
60 private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
61 private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
62
63 private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
64 SelectorProviderUtil.findOpenMethod("openSocketChannel");
65
66 private final SocketChannelConfig config;
67
68 private static SocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) {
69 try {
70 SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
71 return channel == null ? provider.openSocketChannel() : channel;
72 } catch (IOException e) {
73 throw new ChannelException("Failed to open a socket.", e);
74 }
75 }
76
77
78
79
80 public NioSocketChannel() {
81 this(DEFAULT_SELECTOR_PROVIDER);
82 }
83
84
85
86
87 public NioSocketChannel(SelectorProvider provider) {
88 this(provider, null);
89 }
90
91
92
93
94 public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) {
95 this(newChannel(provider, family));
96 }
97
98
99
100
101 public NioSocketChannel(SocketChannel socket) {
102 this(null, socket);
103 }
104
105
106
107
108
109
110
111 public NioSocketChannel(Channel parent, SocketChannel socket) {
112 super(parent, socket);
113 config = new NioSocketChannelConfig(this, socket.socket());
114 }
115
116 @Override
117 public ServerSocketChannel parent() {
118 return (ServerSocketChannel) super.parent();
119 }
120
121 @Override
122 public SocketChannelConfig config() {
123 return config;
124 }
125
126 @Override
127 protected SocketChannel javaChannel() {
128 return (SocketChannel) super.javaChannel();
129 }
130
131 @Override
132 public boolean isActive() {
133 SocketChannel ch = javaChannel();
134 return ch.isOpen() && ch.isConnected();
135 }
136
137 @Override
138 public boolean isOutputShutdown() {
139 return javaChannel().socket().isOutputShutdown() || !isActive();
140 }
141
142 @Override
143 public boolean isInputShutdown() {
144 return javaChannel().socket().isInputShutdown() || !isActive();
145 }
146
147 @Override
148 public boolean isShutdown() {
149 Socket socket = javaChannel().socket();
150 return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
151 }
152
153 @Override
154 public InetSocketAddress localAddress() {
155 return (InetSocketAddress) super.localAddress();
156 }
157
158 @Override
159 public InetSocketAddress remoteAddress() {
160 return (InetSocketAddress) super.remoteAddress();
161 }
162
163 @SuppressJava6Requirement(reason = "Usage guarded by java version check")
164 @UnstableApi
165 @Override
166 protected final void doShutdownOutput() throws Exception {
167 if (PlatformDependent.javaVersion() >= 7) {
168 javaChannel().shutdownOutput();
169 } else {
170 javaChannel().socket().shutdownOutput();
171 }
172 }
173
174 @Override
175 public ChannelFuture shutdownOutput() {
176 return shutdownOutput(newPromise());
177 }
178
179 @Override
180 public ChannelFuture shutdownOutput(final ChannelPromise promise) {
181 final EventLoop loop = eventLoop();
182 if (loop.inEventLoop()) {
183 ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
184 } else {
185 loop.execute(new Runnable() {
186 @Override
187 public void run() {
188 ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
189 }
190 });
191 }
192 return promise;
193 }
194
195 @Override
196 public ChannelFuture shutdownInput() {
197 return shutdownInput(newPromise());
198 }
199
200 @Override
201 protected boolean isInputShutdown0() {
202 return isInputShutdown();
203 }
204
205 @Override
206 public ChannelFuture shutdownInput(final ChannelPromise promise) {
207 EventLoop loop = eventLoop();
208 if (loop.inEventLoop()) {
209 shutdownInput0(promise);
210 } else {
211 loop.execute(new Runnable() {
212 @Override
213 public void run() {
214 shutdownInput0(promise);
215 }
216 });
217 }
218 return promise;
219 }
220
221 @Override
222 public ChannelFuture shutdown() {
223 return shutdown(newPromise());
224 }
225
226 @Override
227 public ChannelFuture shutdown(final ChannelPromise promise) {
228 ChannelFuture shutdownOutputFuture = shutdownOutput();
229 if (shutdownOutputFuture.isDone()) {
230 shutdownOutputDone(shutdownOutputFuture, promise);
231 } else {
232 shutdownOutputFuture.addListener(new ChannelFutureListener() {
233 @Override
234 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
235 shutdownOutputDone(shutdownOutputFuture, promise);
236 }
237 });
238 }
239 return promise;
240 }
241
242 private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
243 ChannelFuture shutdownInputFuture = shutdownInput();
244 if (shutdownInputFuture.isDone()) {
245 shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
246 } else {
247 shutdownInputFuture.addListener(new ChannelFutureListener() {
248 @Override
249 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
250 shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
251 }
252 });
253 }
254 }
255
256 private static void shutdownDone(ChannelFuture shutdownOutputFuture,
257 ChannelFuture shutdownInputFuture,
258 ChannelPromise promise) {
259 Throwable shutdownOutputCause = shutdownOutputFuture.cause();
260 Throwable shutdownInputCause = shutdownInputFuture.cause();
261 if (shutdownOutputCause != null) {
262 if (shutdownInputCause != null) {
263 logger.debug("Exception suppressed because a previous exception occurred.",
264 shutdownInputCause);
265 }
266 promise.setFailure(shutdownOutputCause);
267 } else if (shutdownInputCause != null) {
268 promise.setFailure(shutdownInputCause);
269 } else {
270 promise.setSuccess();
271 }
272 }
273 private void shutdownInput0(final ChannelPromise promise) {
274 try {
275 shutdownInput0();
276 promise.setSuccess();
277 } catch (Throwable t) {
278 promise.setFailure(t);
279 }
280 }
281
282 @SuppressJava6Requirement(reason = "Usage guarded by java version check")
283 private void shutdownInput0() throws Exception {
284 if (PlatformDependent.javaVersion() >= 7) {
285 javaChannel().shutdownInput();
286 } else {
287 javaChannel().socket().shutdownInput();
288 }
289 }
290
291 @Override
292 protected SocketAddress localAddress0() {
293 return javaChannel().socket().getLocalSocketAddress();
294 }
295
296 @Override
297 protected SocketAddress remoteAddress0() {
298 return javaChannel().socket().getRemoteSocketAddress();
299 }
300
301 @Override
302 protected void doBind(SocketAddress localAddress) throws Exception {
303 doBind0(localAddress);
304 }
305
306 private void doBind0(SocketAddress localAddress) throws Exception {
307 if (PlatformDependent.javaVersion() >= 7) {
308 SocketUtils.bind(javaChannel(), localAddress);
309 } else {
310 SocketUtils.bind(javaChannel().socket(), localAddress);
311 }
312 }
313
314 @Override
315 protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
316 if (localAddress != null) {
317 doBind0(localAddress);
318 }
319
320 boolean success = false;
321 try {
322 boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
323 if (!connected) {
324 selectionKey().interestOps(SelectionKey.OP_CONNECT);
325 }
326 success = true;
327 return connected;
328 } finally {
329 if (!success) {
330 doClose();
331 }
332 }
333 }
334
335 @Override
336 protected void doFinishConnect() throws Exception {
337 if (!javaChannel().finishConnect()) {
338 throw new Error();
339 }
340 }
341
342 @Override
343 protected void doDisconnect() throws Exception {
344 doClose();
345 }
346
347 @Override
348 protected void doClose() throws Exception {
349 super.doClose();
350 javaChannel().close();
351 }
352
353 @Override
354 protected int doReadBytes(ByteBuf byteBuf) throws Exception {
355 final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
356 allocHandle.attemptedBytesRead(byteBuf.writableBytes());
357 return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
358 }
359
360 @Override
361 protected int doWriteBytes(ByteBuf buf) throws Exception {
362 final int expectedWrittenBytes = buf.readableBytes();
363 return buf.readBytes(javaChannel(), expectedWrittenBytes);
364 }
365
366 @Override
367 protected long doWriteFileRegion(FileRegion region) throws Exception {
368 final long position = region.transferred();
369 return region.transferTo(javaChannel(), position);
370 }
371
372 private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
373
374
375
376 if (attempted == written) {
377 if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
378 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
379 }
380 } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
381 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
382 }
383 }
384
385 @Override
386 protected void doWrite(ChannelOutboundBuffer in) throws Exception {
387 SocketChannel ch = javaChannel();
388 int writeSpinCount = config().getWriteSpinCount();
389 do {
390 if (in.isEmpty()) {
391
392 clearOpWrite();
393
394 return;
395 }
396
397
398 int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
399 ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
400 int nioBufferCnt = in.nioBufferCount();
401
402
403
404 switch (nioBufferCnt) {
405 case 0:
406
407 writeSpinCount -= doWrite0(in);
408 break;
409 case 1: {
410
411
412
413 ByteBuffer buffer = nioBuffers[0];
414 int attemptedBytes = buffer.remaining();
415 final int localWrittenBytes = ch.write(buffer);
416 if (localWrittenBytes <= 0) {
417 incompleteWrite(true);
418 return;
419 }
420 adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
421 in.removeBytes(localWrittenBytes);
422 --writeSpinCount;
423 break;
424 }
425 default: {
426
427
428
429 long attemptedBytes = in.nioBufferSize();
430 final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
431 if (localWrittenBytes <= 0) {
432 incompleteWrite(true);
433 return;
434 }
435
436 adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
437 maxBytesPerGatheringWrite);
438 in.removeBytes(localWrittenBytes);
439 --writeSpinCount;
440 break;
441 }
442 }
443 } while (writeSpinCount > 0);
444
445 incompleteWrite(writeSpinCount < 0);
446 }
447
448 @Override
449 protected AbstractNioUnsafe newUnsafe() {
450 return new NioSocketChannelUnsafe();
451 }
452
453 private final class NioSocketChannelUnsafe extends NioByteUnsafe {
454 @Override
455 protected Executor prepareToClose() {
456 try {
457 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
458
459
460
461
462 doDeregister();
463 return GlobalEventExecutor.INSTANCE;
464 }
465 } catch (Throwable ignore) {
466
467
468
469 }
470 return null;
471 }
472 }
473
474 private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
475 private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
476 private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
477 super(channel, javaSocket);
478 calculateMaxBytesPerGatheringWrite();
479 }
480
481 @Override
482 protected void autoReadCleared() {
483 clearReadPending();
484 }
485
486 @Override
487 public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
488 super.setSendBufferSize(sendBufferSize);
489 calculateMaxBytesPerGatheringWrite();
490 return this;
491 }
492
493 @Override
494 public <T> boolean setOption(ChannelOption<T> option, T value) {
495 if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
496 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
497 }
498 return super.setOption(option, value);
499 }
500
501 @Override
502 public <T> T getOption(ChannelOption<T> option) {
503 if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
504 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
505 }
506 return super.getOption(option);
507 }
508
509 @Override
510 public Map<ChannelOption<?>, Object> getOptions() {
511 if (PlatformDependent.javaVersion() >= 7) {
512 return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
513 }
514 return super.getOptions();
515 }
516
517 void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
518 this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
519 }
520
521 int getMaxBytesPerGatheringWrite() {
522 return maxBytesPerGatheringWrite;
523 }
524
525 private void calculateMaxBytesPerGatheringWrite() {
526
527 int newSendBufferSize = getSendBufferSize() << 1;
528 if (newSendBufferSize > 0) {
529 setMaxBytesPerGatheringWrite(newSendBufferSize);
530 }
531 }
532
533 private SocketChannel jdkChannel() {
534 return ((NioSocketChannel) channel).javaChannel();
535 }
536 }
537 }