1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */
17 package nginx.unit.websocket;
18 
19 import java.io.EOFException;
20 import java.io.IOException;
21 import java.nio.ByteBuffer;
22 import java.nio.channels.AsynchronousSocketChannel;
23 import java.nio.channels.CompletionHandler;
24 import java.util.concurrent.CountDownLatch;
25 import java.util.concurrent.ExecutionException;
26 import java.util.concurrent.ExecutorService;
27 import java.util.concurrent.Executors;
28 import java.util.concurrent.Future;
29 import java.util.concurrent.ThreadFactory;
30 import java.util.concurrent.TimeUnit;
31 import java.util.concurrent.TimeoutException;
32 import java.util.concurrent.atomic.AtomicBoolean;
33 import java.util.concurrent.atomic.AtomicInteger;
34 
35 import javax.net.ssl.SSLEngine;
36 import javax.net.ssl.SSLEngineResult;
37 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
38 import javax.net.ssl.SSLEngineResult.Status;
39 import javax.net.ssl.SSLException;
40 
41 import org.apache.juli.logging.Log;
42 import org.apache.juli.logging.LogFactory;
43 import org.apache.tomcat.util.res.StringManager;
44 
45 /**
46  * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot
47  * more testing before it can be considered robust.
48  */
49 public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
50 
51     private final Log log =
52             LogFactory.getLog(AsyncChannelWrapperSecure.class);
53     private static final StringManager sm =
54             StringManager.getManager(AsyncChannelWrapperSecure.class);
55 
56     private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921);
57     private final AsynchronousSocketChannel socketChannel;
58     private final SSLEngine sslEngine;
59     private final ByteBuffer socketReadBuffer;
60     private final ByteBuffer socketWriteBuffer;
61     // One thread for read, one for write
62     private final ExecutorService executor =
63             Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
64     private AtomicBoolean writing = new AtomicBoolean(false);
65     private AtomicBoolean reading = new AtomicBoolean(false);
66 
AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine)67     public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel,
68             SSLEngine sslEngine) {
69         this.socketChannel = socketChannel;
70         this.sslEngine = sslEngine;
71 
72         int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
73         socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize);
74         socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize);
75     }
76 
77     @Override
read(ByteBuffer dst)78     public Future<Integer> read(ByteBuffer dst) {
79         WrapperFuture<Integer,Void> future = new WrapperFuture<>();
80 
81         if (!reading.compareAndSet(false, true)) {
82             throw new IllegalStateException(sm.getString(
83                     "asyncChannelWrapperSecure.concurrentRead"));
84         }
85 
86         ReadTask readTask = new ReadTask(dst, future);
87 
88         executor.execute(readTask);
89 
90         return future;
91     }
92 
93     @Override
read(ByteBuffer dst, A attachment, CompletionHandler<Integer,B> handler)94     public <B,A extends B> void read(ByteBuffer dst, A attachment,
95             CompletionHandler<Integer,B> handler) {
96 
97         WrapperFuture<Integer,B> future =
98                 new WrapperFuture<>(handler, attachment);
99 
100         if (!reading.compareAndSet(false, true)) {
101             throw new IllegalStateException(sm.getString(
102                     "asyncChannelWrapperSecure.concurrentRead"));
103         }
104 
105         ReadTask readTask = new ReadTask(dst, future);
106 
107         executor.execute(readTask);
108     }
109 
110     @Override
write(ByteBuffer src)111     public Future<Integer> write(ByteBuffer src) {
112 
113         WrapperFuture<Long,Void> inner = new WrapperFuture<>();
114 
115         if (!writing.compareAndSet(false, true)) {
116             throw new IllegalStateException(sm.getString(
117                     "asyncChannelWrapperSecure.concurrentWrite"));
118         }
119 
120         WriteTask writeTask =
121                 new WriteTask(new ByteBuffer[] {src}, 0, 1, inner);
122 
123         executor.execute(writeTask);
124 
125         Future<Integer> future = new LongToIntegerFuture(inner);
126         return future;
127     }
128 
129     @Override
write(ByteBuffer[] srcs, int offset, int length, long timeout, TimeUnit unit, A attachment, CompletionHandler<Long,B> handler)130     public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
131             long timeout, TimeUnit unit, A attachment,
132             CompletionHandler<Long,B> handler) {
133 
134         WrapperFuture<Long,B> future =
135                 new WrapperFuture<>(handler, attachment);
136 
137         if (!writing.compareAndSet(false, true)) {
138             throw new IllegalStateException(sm.getString(
139                     "asyncChannelWrapperSecure.concurrentWrite"));
140         }
141 
142         WriteTask writeTask = new WriteTask(srcs, offset, length, future);
143 
144         executor.execute(writeTask);
145     }
146 
147     @Override
close()148     public void close() {
149         try {
150             socketChannel.close();
151         } catch (IOException e) {
152             log.info(sm.getString("asyncChannelWrapperSecure.closeFail"));
153         }
154         executor.shutdownNow();
155     }
156 
157     @Override
handshake()158     public Future<Void> handshake() throws SSLException {
159 
160         WrapperFuture<Void,Void> wFuture = new WrapperFuture<>();
161 
162         Thread t = new WebSocketSslHandshakeThread(wFuture);
163         t.start();
164 
165         return wFuture;
166     }
167 
168 
169     private class WriteTask implements Runnable {
170 
171         private final ByteBuffer[] srcs;
172         private final int offset;
173         private final int length;
174         private final WrapperFuture<Long,?> future;
175 
WriteTask(ByteBuffer[] srcs, int offset, int length, WrapperFuture<Long,?> future)176         public WriteTask(ByteBuffer[] srcs, int offset, int length,
177                 WrapperFuture<Long,?> future) {
178             this.srcs = srcs;
179             this.future = future;
180             this.offset = offset;
181             this.length = length;
182         }
183 
184         @Override
run()185         public void run() {
186             long written = 0;
187 
188             try {
189                 for (int i = offset; i < offset + length; i++) {
190                     ByteBuffer src = srcs[i];
191                     while (src.hasRemaining()) {
192                         socketWriteBuffer.clear();
193 
194                         // Encrypt the data
195                         SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer);
196                         written += r.bytesConsumed();
197                         Status s = r.getStatus();
198 
199                         if (s == Status.OK || s == Status.BUFFER_OVERFLOW) {
200                             // Need to write out the bytes and may need to read from
201                             // the source again to empty it
202                         } else {
203                             // Status.BUFFER_UNDERFLOW - only happens on unwrap
204                             // Status.CLOSED - unexpected
205                             throw new IllegalStateException(sm.getString(
206                                     "asyncChannelWrapperSecure.statusWrap"));
207                         }
208 
209                         // Check for tasks
210                         if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
211                             Runnable runnable = sslEngine.getDelegatedTask();
212                             while (runnable != null) {
213                                 runnable.run();
214                                 runnable = sslEngine.getDelegatedTask();
215                             }
216                         }
217 
218                         socketWriteBuffer.flip();
219 
220                         // Do the write
221                         int toWrite = r.bytesProduced();
222                         while (toWrite > 0) {
223                             Future<Integer> f =
224                                     socketChannel.write(socketWriteBuffer);
225                             Integer socketWrite = f.get();
226                             toWrite -= socketWrite.intValue();
227                         }
228                     }
229                 }
230 
231 
232                 if (writing.compareAndSet(true, false)) {
233                     future.complete(Long.valueOf(written));
234                 } else {
235                     future.fail(new IllegalStateException(sm.getString(
236                             "asyncChannelWrapperSecure.wrongStateWrite")));
237                 }
238             } catch (Exception e) {
239                 writing.set(false);
240                 future.fail(e);
241             }
242         }
243     }
244 
245 
246     private class ReadTask implements Runnable {
247 
248         private final ByteBuffer dest;
249         private final WrapperFuture<Integer,?> future;
250 
ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future)251         public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) {
252             this.dest = dest;
253             this.future = future;
254         }
255 
256         @Override
run()257         public void run() {
258             int read = 0;
259 
260             boolean forceRead = false;
261 
262             try {
263                 while (read == 0) {
264                     socketReadBuffer.compact();
265 
266                     if (forceRead) {
267                         forceRead = false;
268                         Future<Integer> f = socketChannel.read(socketReadBuffer);
269                         Integer socketRead = f.get();
270                         if (socketRead.intValue() == -1) {
271                             throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
272                         }
273                     }
274 
275                     socketReadBuffer.flip();
276 
277                     if (socketReadBuffer.hasRemaining()) {
278                         // Decrypt the data in the buffer
279                         SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest);
280                         read += r.bytesProduced();
281                         Status s = r.getStatus();
282 
283                         if (s == Status.OK) {
284                             // Bytes available for reading and there may be
285                             // sufficient data in the socketReadBuffer to
286                             // support further reads without reading from the
287                             // socket
288                         } else if (s == Status.BUFFER_UNDERFLOW) {
289                             // There is partial data in the socketReadBuffer
290                             if (read == 0) {
291                                 // Need more data before the partial data can be
292                                 // processed and some output generated
293                                 forceRead = true;
294                             }
295                             // else return the data we have and deal with the
296                             // partial data on the next read
297                         } else if (s == Status.BUFFER_OVERFLOW) {
298                             // Not enough space in the destination buffer to
299                             // store all of the data. We could use a bytes read
300                             // value of -bufferSizeRequired to signal the new
301                             // buffer size required but an explicit exception is
302                             // clearer.
303                             if (reading.compareAndSet(true, false)) {
304                                 throw new ReadBufferOverflowException(sslEngine.
305                                         getSession().getApplicationBufferSize());
306                             } else {
307                                 future.fail(new IllegalStateException(sm.getString(
308                                         "asyncChannelWrapperSecure.wrongStateRead")));
309                             }
310                         } else {
311                             // Status.CLOSED - unexpected
312                             throw new IllegalStateException(sm.getString(
313                                     "asyncChannelWrapperSecure.statusUnwrap"));
314                         }
315 
316                         // Check for tasks
317                         if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
318                             Runnable runnable = sslEngine.getDelegatedTask();
319                             while (runnable != null) {
320                                 runnable.run();
321                                 runnable = sslEngine.getDelegatedTask();
322                             }
323                         }
324                     } else {
325                         forceRead = true;
326                     }
327                 }
328 
329 
330                 if (reading.compareAndSet(true, false)) {
331                     future.complete(Integer.valueOf(read));
332                 } else {
333                     future.fail(new IllegalStateException(sm.getString(
334                             "asyncChannelWrapperSecure.wrongStateRead")));
335                 }
336             } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException |
337                     ExecutionException | InterruptedException e) {
338                 reading.set(false);
339                 future.fail(e);
340             }
341         }
342     }
343 
344 
345     private class WebSocketSslHandshakeThread extends Thread {
346 
347         private final WrapperFuture<Void,Void> hFuture;
348 
349         private HandshakeStatus handshakeStatus;
350         private Status resultStatus;
351 
WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture)352         public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) {
353             this.hFuture = hFuture;
354         }
355 
356         @Override
run()357         public void run() {
358             try {
359                 sslEngine.beginHandshake();
360                 // So the first compact does the right thing
361                 socketReadBuffer.position(socketReadBuffer.limit());
362 
363                 handshakeStatus = sslEngine.getHandshakeStatus();
364                 resultStatus = Status.OK;
365 
366                 boolean handshaking = true;
367 
368                 while(handshaking) {
369                     switch (handshakeStatus) {
370                         case NEED_WRAP: {
371                             socketWriteBuffer.clear();
372                             SSLEngineResult r =
373                                     sslEngine.wrap(DUMMY, socketWriteBuffer);
374                             checkResult(r, true);
375                             socketWriteBuffer.flip();
376                             Future<Integer> fWrite =
377                                     socketChannel.write(socketWriteBuffer);
378                             fWrite.get();
379                             break;
380                         }
381                         case NEED_UNWRAP: {
382                             socketReadBuffer.compact();
383                             if (socketReadBuffer.position() == 0 ||
384                                     resultStatus == Status.BUFFER_UNDERFLOW) {
385                                 Future<Integer> fRead =
386                                         socketChannel.read(socketReadBuffer);
387                                 fRead.get();
388                             }
389                             socketReadBuffer.flip();
390                             SSLEngineResult r =
391                                     sslEngine.unwrap(socketReadBuffer, DUMMY);
392                             checkResult(r, false);
393                             break;
394                         }
395                         case NEED_TASK: {
396                             Runnable r = null;
397                             while ((r = sslEngine.getDelegatedTask()) != null) {
398                                 r.run();
399                             }
400                             handshakeStatus = sslEngine.getHandshakeStatus();
401                             break;
402                         }
403                         case FINISHED: {
404                             handshaking = false;
405                             break;
406                         }
407                         case NOT_HANDSHAKING: {
408                             throw new SSLException(
409                                     sm.getString("asyncChannelWrapperSecure.notHandshaking"));
410                         }
411                     }
412                 }
413             } catch (Exception e) {
414                 hFuture.fail(e);
415                 return;
416             }
417 
418             hFuture.complete(null);
419         }
420 
checkResult(SSLEngineResult result, boolean wrap)421         private void checkResult(SSLEngineResult result, boolean wrap)
422                 throws SSLException {
423 
424             handshakeStatus = result.getHandshakeStatus();
425             resultStatus = result.getStatus();
426 
427             if (resultStatus != Status.OK &&
428                     (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) {
429                 throw new SSLException(
430                         sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus));
431             }
432             if (wrap && result.bytesConsumed() != 0) {
433                 throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap"));
434             }
435             if (!wrap && result.bytesProduced() != 0) {
436                 throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap"));
437             }
438         }
439     }
440 
441 
442     private static class WrapperFuture<T,A> implements Future<T> {
443 
444         private final CompletionHandler<T,A> handler;
445         private final A attachment;
446 
447         private volatile T result = null;
448         private volatile Throwable throwable = null;
449         private CountDownLatch completionLatch = new CountDownLatch(1);
450 
WrapperFuture()451         public WrapperFuture() {
452             this(null, null);
453         }
454 
WrapperFuture(CompletionHandler<T,A> handler, A attachment)455         public WrapperFuture(CompletionHandler<T,A> handler, A attachment) {
456             this.handler = handler;
457             this.attachment = attachment;
458         }
459 
complete(T result)460         public void complete(T result) {
461             this.result = result;
462             completionLatch.countDown();
463             if (handler != null) {
464                 handler.completed(result, attachment);
465             }
466         }
467 
fail(Throwable t)468         public void fail(Throwable t) {
469             throwable = t;
470             completionLatch.countDown();
471             if (handler != null) {
472                 handler.failed(throwable, attachment);
473             }
474         }
475 
476         @Override
cancel(boolean mayInterruptIfRunning)477         public final boolean cancel(boolean mayInterruptIfRunning) {
478             // Could support cancellation by closing the connection
479             return false;
480         }
481 
482         @Override
isCancelled()483         public final boolean isCancelled() {
484             // Could support cancellation by closing the connection
485             return false;
486         }
487 
488         @Override
isDone()489         public final boolean isDone() {
490             return completionLatch.getCount() > 0;
491         }
492 
493         @Override
get()494         public T get() throws InterruptedException, ExecutionException {
495             completionLatch.await();
496             if (throwable != null) {
497                 throw new ExecutionException(throwable);
498             }
499             return result;
500         }
501 
502         @Override
get(long timeout, TimeUnit unit)503         public T get(long timeout, TimeUnit unit)
504                 throws InterruptedException, ExecutionException,
505                 TimeoutException {
506             boolean latchResult = completionLatch.await(timeout, unit);
507             if (latchResult == false) {
508                 throw new TimeoutException();
509             }
510             if (throwable != null) {
511                 throw new ExecutionException(throwable);
512             }
513             return result;
514         }
515     }
516 
517     private static final class LongToIntegerFuture implements Future<Integer> {
518 
519         private final Future<Long> wrapped;
520 
LongToIntegerFuture(Future<Long> wrapped)521         public LongToIntegerFuture(Future<Long> wrapped) {
522             this.wrapped = wrapped;
523         }
524 
525         @Override
cancel(boolean mayInterruptIfRunning)526         public boolean cancel(boolean mayInterruptIfRunning) {
527             return wrapped.cancel(mayInterruptIfRunning);
528         }
529 
530         @Override
isCancelled()531         public boolean isCancelled() {
532             return wrapped.isCancelled();
533         }
534 
535         @Override
isDone()536         public boolean isDone() {
537             return wrapped.isDone();
538         }
539 
540         @Override
get()541         public Integer get() throws InterruptedException, ExecutionException {
542             Long result = wrapped.get();
543             if (result.longValue() > Integer.MAX_VALUE) {
544                 throw new ExecutionException(sm.getString(
545                         "asyncChannelWrapperSecure.tooBig", result), null);
546             }
547             return Integer.valueOf(result.intValue());
548         }
549 
550         @Override
get(long timeout, TimeUnit unit)551         public Integer get(long timeout, TimeUnit unit)
552                 throws InterruptedException, ExecutionException,
553                 TimeoutException {
554             Long result = wrapped.get(timeout, unit);
555             if (result.longValue() > Integer.MAX_VALUE) {
556                 throw new ExecutionException(sm.getString(
557                         "asyncChannelWrapperSecure.tooBig", result), null);
558             }
559             return Integer.valueOf(result.intValue());
560         }
561     }
562 
563 
564     private static class SecureIOThreadFactory implements ThreadFactory {
565 
566         private AtomicInteger count = new AtomicInteger(0);
567 
568         @Override
newThread(Runnable r)569         public Thread newThread(Runnable r) {
570             Thread t = new Thread(r);
571             t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet());
572             // No need to set the context class loader. The threads will be
573             // cleaned up when the connection is closed.
574             t.setDaemon(true);
575             return t;
576         }
577     }
578 }
579