1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket; 18 19 import java.io.EOFException; 20 import java.io.IOException; 21 import java.nio.ByteBuffer; 22 import java.nio.channels.AsynchronousSocketChannel; 23 import java.nio.channels.CompletionHandler; 24 import java.util.concurrent.CountDownLatch; 25 import java.util.concurrent.ExecutionException; 26 import java.util.concurrent.ExecutorService; 27 import java.util.concurrent.Executors; 28 import java.util.concurrent.Future; 29 import java.util.concurrent.ThreadFactory; 30 import java.util.concurrent.TimeUnit; 31 import java.util.concurrent.TimeoutException; 32 import java.util.concurrent.atomic.AtomicBoolean; 33 import java.util.concurrent.atomic.AtomicInteger; 34 35 import javax.net.ssl.SSLEngine; 36 import javax.net.ssl.SSLEngineResult; 37 import javax.net.ssl.SSLEngineResult.HandshakeStatus; 38 import javax.net.ssl.SSLEngineResult.Status; 39 import javax.net.ssl.SSLException; 40 41 import org.apache.juli.logging.Log; 42 import org.apache.juli.logging.LogFactory; 43 import org.apache.tomcat.util.res.StringManager; 44 45 /** 46 * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot 47 * more testing before it can be considered robust. 48 */ 49 public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { 50 51 private final Log log = 52 LogFactory.getLog(AsyncChannelWrapperSecure.class); 53 private static final StringManager sm = 54 StringManager.getManager(AsyncChannelWrapperSecure.class); 55 56 private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921); 57 private final AsynchronousSocketChannel socketChannel; 58 private final SSLEngine sslEngine; 59 private final ByteBuffer socketReadBuffer; 60 private final ByteBuffer socketWriteBuffer; 61 // One thread for read, one for write 62 private final ExecutorService executor = 63 Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); 64 private AtomicBoolean writing = new AtomicBoolean(false); 65 private AtomicBoolean reading = new AtomicBoolean(false); 66 AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, SSLEngine sslEngine)67 public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, 68 SSLEngine sslEngine) { 69 this.socketChannel = socketChannel; 70 this.sslEngine = sslEngine; 71 72 int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); 73 socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); 74 socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize); 75 } 76 77 @Override read(ByteBuffer dst)78 public Future<Integer> read(ByteBuffer dst) { 79 WrapperFuture<Integer,Void> future = new WrapperFuture<>(); 80 81 if (!reading.compareAndSet(false, true)) { 82 throw new IllegalStateException(sm.getString( 83 "asyncChannelWrapperSecure.concurrentRead")); 84 } 85 86 ReadTask readTask = new ReadTask(dst, future); 87 88 executor.execute(readTask); 89 90 return future; 91 } 92 93 @Override read(ByteBuffer dst, A attachment, CompletionHandler<Integer,B> handler)94 public <B,A extends B> void read(ByteBuffer dst, A attachment, 95 CompletionHandler<Integer,B> handler) { 96 97 WrapperFuture<Integer,B> future = 98 new WrapperFuture<>(handler, attachment); 99 100 if (!reading.compareAndSet(false, true)) { 101 throw new IllegalStateException(sm.getString( 102 "asyncChannelWrapperSecure.concurrentRead")); 103 } 104 105 ReadTask readTask = new ReadTask(dst, future); 106 107 executor.execute(readTask); 108 } 109 110 @Override write(ByteBuffer src)111 public Future<Integer> write(ByteBuffer src) { 112 113 WrapperFuture<Long,Void> inner = new WrapperFuture<>(); 114 115 if (!writing.compareAndSet(false, true)) { 116 throw new IllegalStateException(sm.getString( 117 "asyncChannelWrapperSecure.concurrentWrite")); 118 } 119 120 WriteTask writeTask = 121 new WriteTask(new ByteBuffer[] {src}, 0, 1, inner); 122 123 executor.execute(writeTask); 124 125 Future<Integer> future = new LongToIntegerFuture(inner); 126 return future; 127 } 128 129 @Override write(ByteBuffer[] srcs, int offset, int length, long timeout, TimeUnit unit, A attachment, CompletionHandler<Long,B> handler)130 public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, 131 long timeout, TimeUnit unit, A attachment, 132 CompletionHandler<Long,B> handler) { 133 134 WrapperFuture<Long,B> future = 135 new WrapperFuture<>(handler, attachment); 136 137 if (!writing.compareAndSet(false, true)) { 138 throw new IllegalStateException(sm.getString( 139 "asyncChannelWrapperSecure.concurrentWrite")); 140 } 141 142 WriteTask writeTask = new WriteTask(srcs, offset, length, future); 143 144 executor.execute(writeTask); 145 } 146 147 @Override close()148 public void close() { 149 try { 150 socketChannel.close(); 151 } catch (IOException e) { 152 log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); 153 } 154 executor.shutdownNow(); 155 } 156 157 @Override handshake()158 public Future<Void> handshake() throws SSLException { 159 160 WrapperFuture<Void,Void> wFuture = new WrapperFuture<>(); 161 162 Thread t = new WebSocketSslHandshakeThread(wFuture); 163 t.start(); 164 165 return wFuture; 166 } 167 168 169 private class WriteTask implements Runnable { 170 171 private final ByteBuffer[] srcs; 172 private final int offset; 173 private final int length; 174 private final WrapperFuture<Long,?> future; 175 WriteTask(ByteBuffer[] srcs, int offset, int length, WrapperFuture<Long,?> future)176 public WriteTask(ByteBuffer[] srcs, int offset, int length, 177 WrapperFuture<Long,?> future) { 178 this.srcs = srcs; 179 this.future = future; 180 this.offset = offset; 181 this.length = length; 182 } 183 184 @Override run()185 public void run() { 186 long written = 0; 187 188 try { 189 for (int i = offset; i < offset + length; i++) { 190 ByteBuffer src = srcs[i]; 191 while (src.hasRemaining()) { 192 socketWriteBuffer.clear(); 193 194 // Encrypt the data 195 SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer); 196 written += r.bytesConsumed(); 197 Status s = r.getStatus(); 198 199 if (s == Status.OK || s == Status.BUFFER_OVERFLOW) { 200 // Need to write out the bytes and may need to read from 201 // the source again to empty it 202 } else { 203 // Status.BUFFER_UNDERFLOW - only happens on unwrap 204 // Status.CLOSED - unexpected 205 throw new IllegalStateException(sm.getString( 206 "asyncChannelWrapperSecure.statusWrap")); 207 } 208 209 // Check for tasks 210 if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { 211 Runnable runnable = sslEngine.getDelegatedTask(); 212 while (runnable != null) { 213 runnable.run(); 214 runnable = sslEngine.getDelegatedTask(); 215 } 216 } 217 218 socketWriteBuffer.flip(); 219 220 // Do the write 221 int toWrite = r.bytesProduced(); 222 while (toWrite > 0) { 223 Future<Integer> f = 224 socketChannel.write(socketWriteBuffer); 225 Integer socketWrite = f.get(); 226 toWrite -= socketWrite.intValue(); 227 } 228 } 229 } 230 231 232 if (writing.compareAndSet(true, false)) { 233 future.complete(Long.valueOf(written)); 234 } else { 235 future.fail(new IllegalStateException(sm.getString( 236 "asyncChannelWrapperSecure.wrongStateWrite"))); 237 } 238 } catch (Exception e) { 239 writing.set(false); 240 future.fail(e); 241 } 242 } 243 } 244 245 246 private class ReadTask implements Runnable { 247 248 private final ByteBuffer dest; 249 private final WrapperFuture<Integer,?> future; 250 ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future)251 public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) { 252 this.dest = dest; 253 this.future = future; 254 } 255 256 @Override run()257 public void run() { 258 int read = 0; 259 260 boolean forceRead = false; 261 262 try { 263 while (read == 0) { 264 socketReadBuffer.compact(); 265 266 if (forceRead) { 267 forceRead = false; 268 Future<Integer> f = socketChannel.read(socketReadBuffer); 269 Integer socketRead = f.get(); 270 if (socketRead.intValue() == -1) { 271 throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof")); 272 } 273 } 274 275 socketReadBuffer.flip(); 276 277 if (socketReadBuffer.hasRemaining()) { 278 // Decrypt the data in the buffer 279 SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest); 280 read += r.bytesProduced(); 281 Status s = r.getStatus(); 282 283 if (s == Status.OK) { 284 // Bytes available for reading and there may be 285 // sufficient data in the socketReadBuffer to 286 // support further reads without reading from the 287 // socket 288 } else if (s == Status.BUFFER_UNDERFLOW) { 289 // There is partial data in the socketReadBuffer 290 if (read == 0) { 291 // Need more data before the partial data can be 292 // processed and some output generated 293 forceRead = true; 294 } 295 // else return the data we have and deal with the 296 // partial data on the next read 297 } else if (s == Status.BUFFER_OVERFLOW) { 298 // Not enough space in the destination buffer to 299 // store all of the data. We could use a bytes read 300 // value of -bufferSizeRequired to signal the new 301 // buffer size required but an explicit exception is 302 // clearer. 303 if (reading.compareAndSet(true, false)) { 304 throw new ReadBufferOverflowException(sslEngine. 305 getSession().getApplicationBufferSize()); 306 } else { 307 future.fail(new IllegalStateException(sm.getString( 308 "asyncChannelWrapperSecure.wrongStateRead"))); 309 } 310 } else { 311 // Status.CLOSED - unexpected 312 throw new IllegalStateException(sm.getString( 313 "asyncChannelWrapperSecure.statusUnwrap")); 314 } 315 316 // Check for tasks 317 if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { 318 Runnable runnable = sslEngine.getDelegatedTask(); 319 while (runnable != null) { 320 runnable.run(); 321 runnable = sslEngine.getDelegatedTask(); 322 } 323 } 324 } else { 325 forceRead = true; 326 } 327 } 328 329 330 if (reading.compareAndSet(true, false)) { 331 future.complete(Integer.valueOf(read)); 332 } else { 333 future.fail(new IllegalStateException(sm.getString( 334 "asyncChannelWrapperSecure.wrongStateRead"))); 335 } 336 } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | 337 ExecutionException | InterruptedException e) { 338 reading.set(false); 339 future.fail(e); 340 } 341 } 342 } 343 344 345 private class WebSocketSslHandshakeThread extends Thread { 346 347 private final WrapperFuture<Void,Void> hFuture; 348 349 private HandshakeStatus handshakeStatus; 350 private Status resultStatus; 351 WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture)352 public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) { 353 this.hFuture = hFuture; 354 } 355 356 @Override run()357 public void run() { 358 try { 359 sslEngine.beginHandshake(); 360 // So the first compact does the right thing 361 socketReadBuffer.position(socketReadBuffer.limit()); 362 363 handshakeStatus = sslEngine.getHandshakeStatus(); 364 resultStatus = Status.OK; 365 366 boolean handshaking = true; 367 368 while(handshaking) { 369 switch (handshakeStatus) { 370 case NEED_WRAP: { 371 socketWriteBuffer.clear(); 372 SSLEngineResult r = 373 sslEngine.wrap(DUMMY, socketWriteBuffer); 374 checkResult(r, true); 375 socketWriteBuffer.flip(); 376 Future<Integer> fWrite = 377 socketChannel.write(socketWriteBuffer); 378 fWrite.get(); 379 break; 380 } 381 case NEED_UNWRAP: { 382 socketReadBuffer.compact(); 383 if (socketReadBuffer.position() == 0 || 384 resultStatus == Status.BUFFER_UNDERFLOW) { 385 Future<Integer> fRead = 386 socketChannel.read(socketReadBuffer); 387 fRead.get(); 388 } 389 socketReadBuffer.flip(); 390 SSLEngineResult r = 391 sslEngine.unwrap(socketReadBuffer, DUMMY); 392 checkResult(r, false); 393 break; 394 } 395 case NEED_TASK: { 396 Runnable r = null; 397 while ((r = sslEngine.getDelegatedTask()) != null) { 398 r.run(); 399 } 400 handshakeStatus = sslEngine.getHandshakeStatus(); 401 break; 402 } 403 case FINISHED: { 404 handshaking = false; 405 break; 406 } 407 case NOT_HANDSHAKING: { 408 throw new SSLException( 409 sm.getString("asyncChannelWrapperSecure.notHandshaking")); 410 } 411 } 412 } 413 } catch (Exception e) { 414 hFuture.fail(e); 415 return; 416 } 417 418 hFuture.complete(null); 419 } 420 checkResult(SSLEngineResult result, boolean wrap)421 private void checkResult(SSLEngineResult result, boolean wrap) 422 throws SSLException { 423 424 handshakeStatus = result.getHandshakeStatus(); 425 resultStatus = result.getStatus(); 426 427 if (resultStatus != Status.OK && 428 (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) { 429 throw new SSLException( 430 sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus)); 431 } 432 if (wrap && result.bytesConsumed() != 0) { 433 throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap")); 434 } 435 if (!wrap && result.bytesProduced() != 0) { 436 throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap")); 437 } 438 } 439 } 440 441 442 private static class WrapperFuture<T,A> implements Future<T> { 443 444 private final CompletionHandler<T,A> handler; 445 private final A attachment; 446 447 private volatile T result = null; 448 private volatile Throwable throwable = null; 449 private CountDownLatch completionLatch = new CountDownLatch(1); 450 WrapperFuture()451 public WrapperFuture() { 452 this(null, null); 453 } 454 WrapperFuture(CompletionHandler<T,A> handler, A attachment)455 public WrapperFuture(CompletionHandler<T,A> handler, A attachment) { 456 this.handler = handler; 457 this.attachment = attachment; 458 } 459 complete(T result)460 public void complete(T result) { 461 this.result = result; 462 completionLatch.countDown(); 463 if (handler != null) { 464 handler.completed(result, attachment); 465 } 466 } 467 fail(Throwable t)468 public void fail(Throwable t) { 469 throwable = t; 470 completionLatch.countDown(); 471 if (handler != null) { 472 handler.failed(throwable, attachment); 473 } 474 } 475 476 @Override cancel(boolean mayInterruptIfRunning)477 public final boolean cancel(boolean mayInterruptIfRunning) { 478 // Could support cancellation by closing the connection 479 return false; 480 } 481 482 @Override isCancelled()483 public final boolean isCancelled() { 484 // Could support cancellation by closing the connection 485 return false; 486 } 487 488 @Override isDone()489 public final boolean isDone() { 490 return completionLatch.getCount() > 0; 491 } 492 493 @Override get()494 public T get() throws InterruptedException, ExecutionException { 495 completionLatch.await(); 496 if (throwable != null) { 497 throw new ExecutionException(throwable); 498 } 499 return result; 500 } 501 502 @Override get(long timeout, TimeUnit unit)503 public T get(long timeout, TimeUnit unit) 504 throws InterruptedException, ExecutionException, 505 TimeoutException { 506 boolean latchResult = completionLatch.await(timeout, unit); 507 if (latchResult == false) { 508 throw new TimeoutException(); 509 } 510 if (throwable != null) { 511 throw new ExecutionException(throwable); 512 } 513 return result; 514 } 515 } 516 517 private static final class LongToIntegerFuture implements Future<Integer> { 518 519 private final Future<Long> wrapped; 520 LongToIntegerFuture(Future<Long> wrapped)521 public LongToIntegerFuture(Future<Long> wrapped) { 522 this.wrapped = wrapped; 523 } 524 525 @Override cancel(boolean mayInterruptIfRunning)526 public boolean cancel(boolean mayInterruptIfRunning) { 527 return wrapped.cancel(mayInterruptIfRunning); 528 } 529 530 @Override isCancelled()531 public boolean isCancelled() { 532 return wrapped.isCancelled(); 533 } 534 535 @Override isDone()536 public boolean isDone() { 537 return wrapped.isDone(); 538 } 539 540 @Override get()541 public Integer get() throws InterruptedException, ExecutionException { 542 Long result = wrapped.get(); 543 if (result.longValue() > Integer.MAX_VALUE) { 544 throw new ExecutionException(sm.getString( 545 "asyncChannelWrapperSecure.tooBig", result), null); 546 } 547 return Integer.valueOf(result.intValue()); 548 } 549 550 @Override get(long timeout, TimeUnit unit)551 public Integer get(long timeout, TimeUnit unit) 552 throws InterruptedException, ExecutionException, 553 TimeoutException { 554 Long result = wrapped.get(timeout, unit); 555 if (result.longValue() > Integer.MAX_VALUE) { 556 throw new ExecutionException(sm.getString( 557 "asyncChannelWrapperSecure.tooBig", result), null); 558 } 559 return Integer.valueOf(result.intValue()); 560 } 561 } 562 563 564 private static class SecureIOThreadFactory implements ThreadFactory { 565 566 private AtomicInteger count = new AtomicInteger(0); 567 568 @Override newThread(Runnable r)569 public Thread newThread(Runnable r) { 570 Thread t = new Thread(r); 571 t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet()); 572 // No need to set the context class loader. The threads will be 573 // cleaned up when the connection is closed. 574 t.setDaemon(true); 575 return t; 576 } 577 } 578 } 579