1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package io.netty.handler.proxy;
18
19 import io.netty.channel.Channel;
20 import io.netty.channel.ChannelDuplexHandler;
21 import io.netty.channel.ChannelFuture;
22 import io.netty.channel.ChannelFutureListener;
23 import io.netty.channel.ChannelHandlerContext;
24 import io.netty.channel.ChannelPromise;
25 import io.netty.channel.PendingWriteQueue;
26 import io.netty.util.ReferenceCountUtil;
27 import io.netty.util.concurrent.DefaultPromise;
28 import io.netty.util.concurrent.EventExecutor;
29 import io.netty.util.concurrent.Future;
30 import io.netty.util.internal.ObjectUtil;
31 import io.netty.util.internal.logging.InternalLogger;
32 import io.netty.util.internal.logging.InternalLoggerFactory;
33
34 import java.net.SocketAddress;
35 import java.nio.channels.ConnectionPendingException;
36 import java.util.concurrent.TimeUnit;
37
38
39
40
41 public abstract class ProxyHandler extends ChannelDuplexHandler {
42
43 private static final InternalLogger logger = InternalLoggerFactory.getInstance(ProxyHandler.class);
44
45
46
47
48 private static final long DEFAULT_CONNECT_TIMEOUT_MILLIS = 10000;
49
50
51
52
53 static final String AUTH_NONE = "none";
54
55 private final SocketAddress proxyAddress;
56 private volatile SocketAddress destinationAddress;
57 private volatile long connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT_MILLIS;
58
59 private volatile ChannelHandlerContext ctx;
60 private PendingWriteQueue pendingWrites;
61 private boolean finished;
62 private boolean suppressChannelReadComplete;
63 private boolean flushedPrematurely;
64 private final LazyChannelPromise connectPromise = new LazyChannelPromise();
65 private Future<?> connectTimeoutFuture;
66 private final ChannelFutureListener writeListener = new ChannelFutureListener() {
67 @Override
68 public void operationComplete(ChannelFuture future) throws Exception {
69 if (!future.isSuccess()) {
70 setConnectFailure(future.cause());
71 }
72 }
73 };
74
75 protected ProxyHandler(SocketAddress proxyAddress) {
76 this.proxyAddress = ObjectUtil.checkNotNull(proxyAddress, "proxyAddress");
77 }
78
79
80
81
82 public abstract String protocol();
83
84
85
86
87 public abstract String authScheme();
88
89
90
91
92 @SuppressWarnings("unchecked")
93 public final <T extends SocketAddress> T proxyAddress() {
94 return (T) proxyAddress;
95 }
96
97
98
99
100 @SuppressWarnings("unchecked")
101 public final <T extends SocketAddress> T destinationAddress() {
102 return (T) destinationAddress;
103 }
104
105
106
107
108 public final boolean isConnected() {
109 return connectPromise.isSuccess();
110 }
111
112
113
114
115
116 public final Future<Channel> connectFuture() {
117 return connectPromise;
118 }
119
120
121
122
123
124 public final long connectTimeoutMillis() {
125 return connectTimeoutMillis;
126 }
127
128
129
130
131
132 public final void setConnectTimeoutMillis(long connectTimeoutMillis) {
133 if (connectTimeoutMillis <= 0) {
134 connectTimeoutMillis = 0;
135 }
136
137 this.connectTimeoutMillis = connectTimeoutMillis;
138 }
139
140 @Override
141 public final void handlerAdded(ChannelHandlerContext ctx) throws Exception {
142 this.ctx = ctx;
143 addCodec(ctx);
144
145 if (ctx.channel().isActive()) {
146
147
148 sendInitialMessage(ctx);
149 } else {
150
151
152 }
153 }
154
155
156
157
158 protected abstract void addCodec(ChannelHandlerContext ctx) throws Exception;
159
160
161
162
163 protected abstract void removeEncoder(ChannelHandlerContext ctx) throws Exception;
164
165
166
167
168 protected abstract void removeDecoder(ChannelHandlerContext ctx) throws Exception;
169
170 @Override
171 public final void connect(
172 ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
173 ChannelPromise promise) throws Exception {
174
175 if (destinationAddress != null) {
176 promise.setFailure(new ConnectionPendingException());
177 return;
178 }
179
180 destinationAddress = remoteAddress;
181 ctx.connect(proxyAddress, localAddress, promise);
182 }
183
184 @Override
185 public final void channelActive(ChannelHandlerContext ctx) throws Exception {
186 sendInitialMessage(ctx);
187 ctx.fireChannelActive();
188 }
189
190
191
192
193
194 private void sendInitialMessage(final ChannelHandlerContext ctx) throws Exception {
195 final long connectTimeoutMillis = this.connectTimeoutMillis;
196 if (connectTimeoutMillis > 0) {
197 connectTimeoutFuture = ctx.executor().schedule(new Runnable() {
198 @Override
199 public void run() {
200 if (!connectPromise.isDone()) {
201 setConnectFailure(new ProxyConnectException(exceptionMessage("timeout")));
202 }
203 }
204 }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
205 }
206
207 final Object initialMessage = newInitialMessage(ctx);
208 if (initialMessage != null) {
209 sendToProxyServer(initialMessage);
210 }
211
212 readIfNeeded(ctx);
213 }
214
215
216
217
218
219
220 protected abstract Object newInitialMessage(ChannelHandlerContext ctx) throws Exception;
221
222
223
224
225
226 protected final void sendToProxyServer(Object msg) {
227 ctx.writeAndFlush(msg).addListener(writeListener);
228 }
229
230 @Override
231 public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
232 if (finished) {
233 ctx.fireChannelInactive();
234 } else {
235
236 setConnectFailure(new ProxyConnectException(exceptionMessage("disconnected")));
237 }
238 }
239
240 @Override
241 public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
242 if (finished) {
243 ctx.fireExceptionCaught(cause);
244 } else {
245
246 setConnectFailure(cause);
247 }
248 }
249
250 @Override
251 public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
252 if (finished) {
253
254 suppressChannelReadComplete = false;
255 ctx.fireChannelRead(msg);
256 } else {
257 suppressChannelReadComplete = true;
258 Throwable cause = null;
259 try {
260 boolean done = handleResponse(ctx, msg);
261 if (done) {
262 setConnectSuccess();
263 }
264 } catch (Throwable t) {
265 cause = t;
266 } finally {
267 ReferenceCountUtil.release(msg);
268 if (cause != null) {
269 setConnectFailure(cause);
270 }
271 }
272 }
273 }
274
275
276
277
278
279
280
281
282 protected abstract boolean handleResponse(ChannelHandlerContext ctx, Object response) throws Exception;
283
284 private void setConnectSuccess() {
285 finished = true;
286 cancelConnectTimeoutFuture();
287
288 if (!connectPromise.isDone()) {
289 boolean removedCodec = true;
290
291 removedCodec &= safeRemoveEncoder();
292
293 ctx.fireUserEventTriggered(
294 new ProxyConnectionEvent(protocol(), authScheme(), proxyAddress, destinationAddress));
295
296 removedCodec &= safeRemoveDecoder();
297
298 if (removedCodec) {
299 writePendingWrites();
300
301 if (flushedPrematurely) {
302 ctx.flush();
303 }
304 connectPromise.trySuccess(ctx.channel());
305 } else {
306
307 Exception cause = new ProxyConnectException(
308 "failed to remove all codec handlers added by the proxy handler; bug?");
309 failPendingWritesAndClose(cause);
310 }
311 }
312 }
313
314 private boolean safeRemoveDecoder() {
315 try {
316 removeDecoder(ctx);
317 return true;
318 } catch (Exception e) {
319 logger.warn("Failed to remove proxy decoders:", e);
320 }
321
322 return false;
323 }
324
325 private boolean safeRemoveEncoder() {
326 try {
327 removeEncoder(ctx);
328 return true;
329 } catch (Exception e) {
330 logger.warn("Failed to remove proxy encoders:", e);
331 }
332
333 return false;
334 }
335
336 private void setConnectFailure(Throwable cause) {
337 finished = true;
338 cancelConnectTimeoutFuture();
339
340 if (!connectPromise.isDone()) {
341
342 if (!(cause instanceof ProxyConnectException)) {
343 cause = new ProxyConnectException(
344 exceptionMessage(cause.toString()), cause);
345 }
346
347 safeRemoveDecoder();
348 safeRemoveEncoder();
349 failPendingWritesAndClose(cause);
350 }
351 }
352
353 private void failPendingWritesAndClose(Throwable cause) {
354 failPendingWrites(cause);
355 connectPromise.tryFailure(cause);
356 ctx.fireExceptionCaught(cause);
357 ctx.close();
358 }
359
360 private void cancelConnectTimeoutFuture() {
361 if (connectTimeoutFuture != null) {
362 connectTimeoutFuture.cancel(false);
363 connectTimeoutFuture = null;
364 }
365 }
366
367
368
369
370
371 protected final String exceptionMessage(String msg) {
372 if (msg == null) {
373 msg = "";
374 }
375
376 StringBuilder buf = new StringBuilder(128 + msg.length())
377 .append(protocol())
378 .append(", ")
379 .append(authScheme())
380 .append(", ")
381 .append(proxyAddress)
382 .append(" => ")
383 .append(destinationAddress);
384 if (!msg.isEmpty()) {
385 buf.append(", ").append(msg);
386 }
387
388 return buf.toString();
389 }
390
391 @Override
392 public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
393 if (suppressChannelReadComplete) {
394 suppressChannelReadComplete = false;
395
396 readIfNeeded(ctx);
397 } else {
398 ctx.fireChannelReadComplete();
399 }
400 }
401
402 @Override
403 public final void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
404 if (finished) {
405 writePendingWrites();
406 ctx.write(msg, promise);
407 } else {
408 addPendingWrite(ctx, msg, promise);
409 }
410 }
411
412 @Override
413 public final void flush(ChannelHandlerContext ctx) throws Exception {
414 if (finished) {
415 writePendingWrites();
416 ctx.flush();
417 } else {
418 flushedPrematurely = true;
419 }
420 }
421
422 private static void readIfNeeded(ChannelHandlerContext ctx) {
423 if (!ctx.channel().config().isAutoRead()) {
424 ctx.read();
425 }
426 }
427
428 private void writePendingWrites() {
429 if (pendingWrites != null) {
430 pendingWrites.removeAndWriteAll();
431 pendingWrites = null;
432 }
433 }
434
435 private void failPendingWrites(Throwable cause) {
436 if (pendingWrites != null) {
437 pendingWrites.removeAndFailAll(cause);
438 pendingWrites = null;
439 }
440 }
441
442 private void addPendingWrite(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
443 PendingWriteQueue pendingWrites = this.pendingWrites;
444 if (pendingWrites == null) {
445 this.pendingWrites = pendingWrites = new PendingWriteQueue(ctx);
446 }
447 pendingWrites.add(msg, promise);
448 }
449
450 private final class LazyChannelPromise extends DefaultPromise<Channel> {
451 @Override
452 protected EventExecutor executor() {
453 if (ctx == null) {
454 throw new IllegalStateException();
455 }
456 return ctx.executor();
457 }
458 }
459 }