查看本类的 API 文档 | 回源码主页 | 即时通讯网 - 即时通讯开发者社区!
1   /*
2    * Copyright 2012 The Netty Project
3    *
4    * The Netty Project licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package io.netty.channel.socket.nio;
17  
18  import io.netty.buffer.ByteBuf;
19  import io.netty.channel.Channel;
20  import io.netty.channel.ChannelException;
21  import io.netty.channel.ChannelFuture;
22  import io.netty.channel.ChannelFutureListener;
23  import io.netty.channel.ChannelOption;
24  import io.netty.channel.ChannelOutboundBuffer;
25  import io.netty.channel.ChannelPromise;
26  import io.netty.channel.EventLoop;
27  import io.netty.channel.FileRegion;
28  import io.netty.channel.RecvByteBufAllocator;
29  import io.netty.channel.nio.AbstractNioByteChannel;
30  import io.netty.channel.socket.DefaultSocketChannelConfig;
31  import io.netty.channel.socket.InternetProtocolFamily;
32  import io.netty.channel.socket.ServerSocketChannel;
33  import io.netty.channel.socket.SocketChannelConfig;
34  import io.netty.util.concurrent.GlobalEventExecutor;
35  import io.netty.util.internal.PlatformDependent;
36  import io.netty.util.internal.SocketUtils;
37  import io.netty.util.internal.SuppressJava6Requirement;
38  import io.netty.util.internal.UnstableApi;
39  import io.netty.util.internal.logging.InternalLogger;
40  import io.netty.util.internal.logging.InternalLoggerFactory;
41  
42  import java.io.IOException;
43  import java.lang.reflect.Method;
44  import java.net.InetSocketAddress;
45  import java.net.Socket;
46  import java.net.SocketAddress;
47  import java.nio.ByteBuffer;
48  import java.nio.channels.SelectionKey;
49  import java.nio.channels.SocketChannel;
50  import java.nio.channels.spi.SelectorProvider;
51  import java.util.Map;
52  import java.util.concurrent.Executor;
53  
54  import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD;
55  
56  /**
57   * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation.
58   */
59  public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel {
60      private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class);
61      private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider();
62  
63      private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY =
64              SelectorProviderUtil.findOpenMethod("openSocketChannel");
65  
66      private final SocketChannelConfig config;
67  
68      private static SocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) {
69          try {
70              SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family);
71              return channel == null ? provider.openSocketChannel() : channel;
72          } catch (IOException e) {
73              throw new ChannelException("Failed to open a socket.", e);
74          }
75      }
76  
77      /**
78       * Create a new instance
79       */
80      public NioSocketChannel() {
81          this(DEFAULT_SELECTOR_PROVIDER);
82      }
83  
84      /**
85       * Create a new instance using the given {@link SelectorProvider}.
86       */
87      public NioSocketChannel(SelectorProvider provider) {
88          this(provider, null);
89      }
90  
91      /**
92       * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15).
93       */
94      public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) {
95          this(newChannel(provider, family));
96      }
97  
98      /**
99       * Create a new instance using the given {@link SocketChannel}.
100      */
101     public NioSocketChannel(SocketChannel socket) {
102         this(null, socket);
103     }
104 
105     /**
106      * Create a new instance
107      *
108      * @param parent    the {@link Channel} which created this instance or {@code null} if it was created by the user
109      * @param socket    the {@link SocketChannel} which will be used
110      */
111     public NioSocketChannel(Channel parent, SocketChannel socket) {
112         super(parent, socket);
113         config = new NioSocketChannelConfig(this, socket.socket());
114     }
115 
116     @Override
117     public ServerSocketChannel parent() {
118         return (ServerSocketChannel) super.parent();
119     }
120 
121     @Override
122     public SocketChannelConfig config() {
123         return config;
124     }
125 
126     @Override
127     protected SocketChannel javaChannel() {
128         return (SocketChannel) super.javaChannel();
129     }
130 
131     @Override
132     public boolean isActive() {
133         SocketChannel ch = javaChannel();
134         return ch.isOpen() && ch.isConnected();
135     }
136 
137     @Override
138     public boolean isOutputShutdown() {
139         return javaChannel().socket().isOutputShutdown() || !isActive();
140     }
141 
142     @Override
143     public boolean isInputShutdown() {
144         return javaChannel().socket().isInputShutdown() || !isActive();
145     }
146 
147     @Override
148     public boolean isShutdown() {
149         Socket socket = javaChannel().socket();
150         return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive();
151     }
152 
153     @Override
154     public InetSocketAddress localAddress() {
155         return (InetSocketAddress) super.localAddress();
156     }
157 
158     @Override
159     public InetSocketAddress remoteAddress() {
160         return (InetSocketAddress) super.remoteAddress();
161     }
162 
163     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
164     @UnstableApi
165     @Override
166     protected final void doShutdownOutput() throws Exception {
167         if (PlatformDependent.javaVersion() >= 7) {
168             javaChannel().shutdownOutput();
169         } else {
170             javaChannel().socket().shutdownOutput();
171         }
172     }
173 
174     @Override
175     public ChannelFuture shutdownOutput() {
176         return shutdownOutput(newPromise());
177     }
178 
179     @Override
180     public ChannelFuture shutdownOutput(final ChannelPromise promise) {
181         final EventLoop loop = eventLoop();
182         if (loop.inEventLoop()) {
183             ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
184         } else {
185             loop.execute(new Runnable() {
186                 @Override
187                 public void run() {
188                     ((AbstractUnsafe) unsafe()).shutdownOutput(promise);
189                 }
190             });
191         }
192         return promise;
193     }
194 
195     @Override
196     public ChannelFuture shutdownInput() {
197         return shutdownInput(newPromise());
198     }
199 
200     @Override
201     protected boolean isInputShutdown0() {
202         return isInputShutdown();
203     }
204 
205     @Override
206     public ChannelFuture shutdownInput(final ChannelPromise promise) {
207         EventLoop loop = eventLoop();
208         if (loop.inEventLoop()) {
209             shutdownInput0(promise);
210         } else {
211             loop.execute(new Runnable() {
212                 @Override
213                 public void run() {
214                     shutdownInput0(promise);
215                 }
216             });
217         }
218         return promise;
219     }
220 
221     @Override
222     public ChannelFuture shutdown() {
223         return shutdown(newPromise());
224     }
225 
226     @Override
227     public ChannelFuture shutdown(final ChannelPromise promise) {
228         ChannelFuture shutdownOutputFuture = shutdownOutput();
229         if (shutdownOutputFuture.isDone()) {
230             shutdownOutputDone(shutdownOutputFuture, promise);
231         } else {
232             shutdownOutputFuture.addListener(new ChannelFutureListener() {
233                 @Override
234                 public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception {
235                     shutdownOutputDone(shutdownOutputFuture, promise);
236                 }
237             });
238         }
239         return promise;
240     }
241 
242     private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) {
243         ChannelFuture shutdownInputFuture = shutdownInput();
244         if (shutdownInputFuture.isDone()) {
245             shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
246         } else {
247             shutdownInputFuture.addListener(new ChannelFutureListener() {
248                 @Override
249                 public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception {
250                     shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise);
251                 }
252             });
253         }
254     }
255 
256     private static void shutdownDone(ChannelFuture shutdownOutputFuture,
257                                      ChannelFuture shutdownInputFuture,
258                                      ChannelPromise promise) {
259         Throwable shutdownOutputCause = shutdownOutputFuture.cause();
260         Throwable shutdownInputCause = shutdownInputFuture.cause();
261         if (shutdownOutputCause != null) {
262             if (shutdownInputCause != null) {
263                 logger.debug("Exception suppressed because a previous exception occurred.",
264                         shutdownInputCause);
265             }
266             promise.setFailure(shutdownOutputCause);
267         } else if (shutdownInputCause != null) {
268             promise.setFailure(shutdownInputCause);
269         } else {
270             promise.setSuccess();
271         }
272     }
273     private void shutdownInput0(final ChannelPromise promise) {
274         try {
275             shutdownInput0();
276             promise.setSuccess();
277         } catch (Throwable t) {
278             promise.setFailure(t);
279         }
280     }
281 
282     @SuppressJava6Requirement(reason = "Usage guarded by java version check")
283     private void shutdownInput0() throws Exception {
284         if (PlatformDependent.javaVersion() >= 7) {
285             javaChannel().shutdownInput();
286         } else {
287             javaChannel().socket().shutdownInput();
288         }
289     }
290 
291     @Override
292     protected SocketAddress localAddress0() {
293         return javaChannel().socket().getLocalSocketAddress();
294     }
295 
296     @Override
297     protected SocketAddress remoteAddress0() {
298         return javaChannel().socket().getRemoteSocketAddress();
299     }
300 
301     @Override
302     protected void doBind(SocketAddress localAddress) throws Exception {
303         doBind0(localAddress);
304     }
305 
306     private void doBind0(SocketAddress localAddress) throws Exception {
307         if (PlatformDependent.javaVersion() >= 7) {
308             SocketUtils.bind(javaChannel(), localAddress);
309         } else {
310             SocketUtils.bind(javaChannel().socket(), localAddress);
311         }
312     }
313 
314     @Override
315     protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
316         if (localAddress != null) {
317             doBind0(localAddress);
318         }
319 
320         boolean success = false;
321         try {
322             boolean connected = SocketUtils.connect(javaChannel(), remoteAddress);
323             if (!connected) {
324                 selectionKey().interestOps(SelectionKey.OP_CONNECT);
325             }
326             success = true;
327             return connected;
328         } finally {
329             if (!success) {
330                 doClose();
331             }
332         }
333     }
334 
335     @Override
336     protected void doFinishConnect() throws Exception {
337         if (!javaChannel().finishConnect()) {
338             throw new Error();
339         }
340     }
341 
342     @Override
343     protected void doDisconnect() throws Exception {
344         doClose();
345     }
346 
347     @Override
348     protected void doClose() throws Exception {
349         super.doClose();
350         javaChannel().close();
351     }
352 
353     @Override
354     protected int doReadBytes(ByteBuf byteBuf) throws Exception {
355         final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
356         allocHandle.attemptedBytesRead(byteBuf.writableBytes());
357         return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead());
358     }
359 
360     @Override
361     protected int doWriteBytes(ByteBuf buf) throws Exception {
362         final int expectedWrittenBytes = buf.readableBytes();
363         return buf.readBytes(javaChannel(), expectedWrittenBytes);
364     }
365 
366     @Override
367     protected long doWriteFileRegion(FileRegion region) throws Exception {
368         final long position = region.transferred();
369         return region.transferTo(javaChannel(), position);
370     }
371 
372     private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) {
373         // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change
374         // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try
375         // make a best effort to adjust as OS behavior changes.
376         if (attempted == written) {
377             if (attempted << 1 > oldMaxBytesPerGatheringWrite) {
378                 ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1);
379             }
380         } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) {
381             ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1);
382         }
383     }
384 
385     @Override
386     protected void doWrite(ChannelOutboundBuffer in) throws Exception {
387         SocketChannel ch = javaChannel();
388         int writeSpinCount = config().getWriteSpinCount();
389         do {
390             if (in.isEmpty()) {
391                 // All written so clear OP_WRITE
392                 clearOpWrite();
393                 // Directly return here so incompleteWrite(...) is not called.
394                 return;
395             }
396 
397             // Ensure the pending writes are made of ByteBufs only.
398             int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite();
399             ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite);
400             int nioBufferCnt = in.nioBufferCount();
401 
402             // Always use nioBuffers() to workaround data-corruption.
403             // See https://github.com/netty/netty/issues/2761
404             switch (nioBufferCnt) {
405                 case 0:
406                     // We have something else beside ByteBuffers to write so fallback to normal writes.
407                     writeSpinCount -= doWrite0(in);
408                     break;
409                 case 1: {
410                     // Only one ByteBuf so use non-gathering write
411                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
412                     // to check if the total size of all the buffers is non-zero.
413                     ByteBuffer buffer = nioBuffers[0];
414                     int attemptedBytes = buffer.remaining();
415                     final int localWrittenBytes = ch.write(buffer);
416                     if (localWrittenBytes <= 0) {
417                         incompleteWrite(true);
418                         return;
419                     }
420                     adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite);
421                     in.removeBytes(localWrittenBytes);
422                     --writeSpinCount;
423                     break;
424                 }
425                 default: {
426                     // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need
427                     // to check if the total size of all the buffers is non-zero.
428                     // We limit the max amount to int above so cast is safe
429                     long attemptedBytes = in.nioBufferSize();
430                     final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt);
431                     if (localWrittenBytes <= 0) {
432                         incompleteWrite(true);
433                         return;
434                     }
435                     // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above.
436                     adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes,
437                             maxBytesPerGatheringWrite);
438                     in.removeBytes(localWrittenBytes);
439                     --writeSpinCount;
440                     break;
441                 }
442             }
443         } while (writeSpinCount > 0);
444 
445         incompleteWrite(writeSpinCount < 0);
446     }
447 
448     @Override
449     protected AbstractNioUnsafe newUnsafe() {
450         return new NioSocketChannelUnsafe();
451     }
452 
453     private final class NioSocketChannelUnsafe extends NioByteUnsafe {
454         @Override
455         protected Executor prepareToClose() {
456             try {
457                 if (javaChannel().isOpen() && config().getSoLinger() > 0) {
458                     // We need to cancel this key of the channel so we may not end up in a eventloop spin
459                     // because we try to read or write until the actual close happens which may be later due
460                     // SO_LINGER handling.
461                     // See https://github.com/netty/netty/issues/4449
462                     doDeregister();
463                     return GlobalEventExecutor.INSTANCE;
464                 }
465             } catch (Throwable ignore) {
466                 // Ignore the error as the underlying channel may be closed in the meantime and so
467                 // getSoLinger() may produce an exception. In this case we just return null.
468                 // See https://github.com/netty/netty/issues/4449
469             }
470             return null;
471         }
472     }
473 
474     private final class NioSocketChannelConfig extends DefaultSocketChannelConfig {
475         private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE;
476         private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) {
477             super(channel, javaSocket);
478             calculateMaxBytesPerGatheringWrite();
479         }
480 
481         @Override
482         protected void autoReadCleared() {
483             clearReadPending();
484         }
485 
486         @Override
487         public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
488             super.setSendBufferSize(sendBufferSize);
489             calculateMaxBytesPerGatheringWrite();
490             return this;
491         }
492 
493         @Override
494         public <T> boolean setOption(ChannelOption<T> option, T value) {
495             if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
496                 return NioChannelOption.setOption(jdkChannel(), (NioChannelOption<T>) option, value);
497             }
498             return super.setOption(option, value);
499         }
500 
501         @Override
502         public <T> T getOption(ChannelOption<T> option) {
503             if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
504                 return NioChannelOption.getOption(jdkChannel(), (NioChannelOption<T>) option);
505             }
506             return super.getOption(option);
507         }
508 
509         @Override
510         public Map<ChannelOption<?>, Object> getOptions() {
511             if (PlatformDependent.javaVersion() >= 7) {
512                 return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
513             }
514             return super.getOptions();
515         }
516 
517         void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) {
518             this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite;
519         }
520 
521         int getMaxBytesPerGatheringWrite() {
522             return maxBytesPerGatheringWrite;
523         }
524 
525         private void calculateMaxBytesPerGatheringWrite() {
526             // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide.
527             int newSendBufferSize = getSendBufferSize() << 1;
528             if (newSendBufferSize > 0) {
529                 setMaxBytesPerGatheringWrite(newSendBufferSize);
530             }
531         }
532 
533         private SocketChannel jdkChannel() {
534             return ((NioSocketChannel) channel).javaChannel();
535         }
536     }
537 }