1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket; 18 19 import java.io.IOException; 20 import java.io.OutputStream; 21 import java.io.Writer; 22 import java.net.SocketTimeoutException; 23 import java.nio.ByteBuffer; 24 import java.nio.CharBuffer; 25 import java.nio.charset.CharsetEncoder; 26 import java.nio.charset.CoderResult; 27 import java.util.ArrayDeque; 28 import java.util.ArrayList; 29 import java.util.List; 30 import java.util.Queue; 31 import java.util.concurrent.Future; 32 import java.util.concurrent.Semaphore; 33 import java.util.concurrent.TimeUnit; 34 import java.util.concurrent.atomic.AtomicBoolean; 35 36 import javax.websocket.CloseReason; 37 import javax.websocket.CloseReason.CloseCodes; 38 import javax.websocket.DeploymentException; 39 import javax.websocket.EncodeException; 40 import javax.websocket.Encoder; 41 import javax.websocket.EndpointConfig; 42 import javax.websocket.RemoteEndpoint; 43 import javax.websocket.SendHandler; 44 import javax.websocket.SendResult; 45 46 import org.apache.juli.logging.Log; 47 import org.apache.juli.logging.LogFactory; 48 import org.apache.tomcat.util.buf.Utf8Encoder; 49 import org.apache.tomcat.util.res.StringManager; 50 51 import nginx.unit.Request; 52 53 public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint { 54 55 private static final StringManager sm = 56 StringManager.getManager(WsRemoteEndpointImplBase.class); 57 58 protected static final SendResult SENDRESULT_OK = new SendResult(); 59 60 private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static 61 62 private final StateMachine stateMachine = new StateMachine(); 63 64 private final IntermediateMessageHandler intermediateMessageHandler = 65 new IntermediateMessageHandler(this); 66 67 private Transformation transformation = null; 68 private final Semaphore messagePartInProgress = new Semaphore(1); 69 private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>(); 70 private final Object messagePartLock = new Object(); 71 72 // State 73 private volatile boolean closed = false; 74 private boolean fragmented = false; 75 private boolean nextFragmented = false; 76 private boolean text = false; 77 private boolean nextText = false; 78 79 // Max size of WebSocket header is 14 bytes 80 private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); 81 private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 82 private final CharsetEncoder encoder = new Utf8Encoder(); 83 private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 84 private final AtomicBoolean batchingAllowed = new AtomicBoolean(false); 85 private volatile long sendTimeout = -1; 86 private WsSession wsSession; 87 private List<EncoderEntry> encoderEntries = new ArrayList<>(); 88 89 private Request request; 90 91 setTransformation(Transformation transformation)92 protected void setTransformation(Transformation transformation) { 93 this.transformation = transformation; 94 } 95 96 getSendTimeout()97 public long getSendTimeout() { 98 return sendTimeout; 99 } 100 101 setSendTimeout(long timeout)102 public void setSendTimeout(long timeout) { 103 this.sendTimeout = timeout; 104 } 105 106 107 @Override setBatchingAllowed(boolean batchingAllowed)108 public void setBatchingAllowed(boolean batchingAllowed) throws IOException { 109 boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); 110 111 if (oldValue && !batchingAllowed) { 112 flushBatch(); 113 } 114 } 115 116 117 @Override getBatchingAllowed()118 public boolean getBatchingAllowed() { 119 return batchingAllowed.get(); 120 } 121 122 123 @Override flushBatch()124 public void flushBatch() throws IOException { 125 sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true); 126 } 127 128 sendBytes(ByteBuffer data)129 public void sendBytes(ByteBuffer data) throws IOException { 130 if (data == null) { 131 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 132 } 133 stateMachine.binaryStart(); 134 sendMessageBlock(Constants.OPCODE_BINARY, data, true); 135 stateMachine.complete(true); 136 } 137 138 sendBytesByFuture(ByteBuffer data)139 public Future<Void> sendBytesByFuture(ByteBuffer data) { 140 FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); 141 sendBytesByCompletion(data, f2sh); 142 return f2sh; 143 } 144 145 sendBytesByCompletion(ByteBuffer data, SendHandler handler)146 public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) { 147 if (data == null) { 148 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 149 } 150 if (handler == null) { 151 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); 152 } 153 StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine); 154 stateMachine.binaryStart(); 155 startMessage(Constants.OPCODE_BINARY, data, true, sush); 156 } 157 158 sendPartialBytes(ByteBuffer partialByte, boolean last)159 public void sendPartialBytes(ByteBuffer partialByte, boolean last) 160 throws IOException { 161 if (partialByte == null) { 162 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 163 } 164 stateMachine.binaryPartialStart(); 165 sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last); 166 stateMachine.complete(last); 167 } 168 169 170 @Override sendPing(ByteBuffer applicationData)171 public void sendPing(ByteBuffer applicationData) throws IOException, 172 IllegalArgumentException { 173 if (applicationData.remaining() > 125) { 174 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); 175 } 176 sendMessageBlock(Constants.OPCODE_PING, applicationData, true); 177 } 178 179 180 @Override sendPong(ByteBuffer applicationData)181 public void sendPong(ByteBuffer applicationData) throws IOException, 182 IllegalArgumentException { 183 if (applicationData.remaining() > 125) { 184 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); 185 } 186 sendMessageBlock(Constants.OPCODE_PONG, applicationData, true); 187 } 188 189 sendString(String text)190 public void sendString(String text) throws IOException { 191 if (text == null) { 192 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 193 } 194 stateMachine.textStart(); 195 sendMessageBlock(CharBuffer.wrap(text), true); 196 } 197 198 sendStringByFuture(String text)199 public Future<Void> sendStringByFuture(String text) { 200 FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); 201 sendStringByCompletion(text, f2sh); 202 return f2sh; 203 } 204 205 sendStringByCompletion(String text, SendHandler handler)206 public void sendStringByCompletion(String text, SendHandler handler) { 207 if (text == null) { 208 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 209 } 210 if (handler == null) { 211 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); 212 } 213 stateMachine.textStart(); 214 TextMessageSendHandler tmsh = new TextMessageSendHandler(handler, 215 CharBuffer.wrap(text), true, encoder, encoderBuffer, this); 216 tmsh.write(); 217 // TextMessageSendHandler will update stateMachine when it completes 218 } 219 220 sendPartialString(String fragment, boolean isLast)221 public void sendPartialString(String fragment, boolean isLast) 222 throws IOException { 223 if (fragment == null) { 224 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 225 } 226 stateMachine.textPartialStart(); 227 sendMessageBlock(CharBuffer.wrap(fragment), isLast); 228 } 229 230 getSendStream()231 public OutputStream getSendStream() { 232 stateMachine.streamStart(); 233 return new WsOutputStream(this); 234 } 235 236 getSendWriter()237 public Writer getSendWriter() { 238 stateMachine.writeStart(); 239 return new WsWriter(this); 240 } 241 242 sendMessageBlock(CharBuffer part, boolean last)243 void sendMessageBlock(CharBuffer part, boolean last) throws IOException { 244 long timeoutExpiry = getTimeoutExpiry(); 245 boolean isDone = false; 246 while (!isDone) { 247 encoderBuffer.clear(); 248 CoderResult cr = encoder.encode(part, encoderBuffer, true); 249 if (cr.isError()) { 250 throw new IllegalArgumentException(cr.toString()); 251 } 252 isDone = !cr.isOverflow(); 253 encoderBuffer.flip(); 254 sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry); 255 } 256 stateMachine.complete(last); 257 } 258 259 sendMessageBlock(byte opCode, ByteBuffer payload, boolean last)260 void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last) 261 throws IOException { 262 sendMessageBlock(opCode, payload, last, getTimeoutExpiry()); 263 } 264 265 getTimeoutExpiry()266 private long getTimeoutExpiry() { 267 // Get the timeout before we send the message. The message may 268 // trigger a session close and depending on timing the client 269 // session may close before we can read the timeout. 270 long timeout = getBlockingSendTimeout(); 271 if (timeout < 0) { 272 return Long.MAX_VALUE; 273 } else { 274 return System.currentTimeMillis() + timeout; 275 } 276 } 277 278 private byte currentOpCode = Constants.OPCODE_CONTINUATION; 279 sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, long timeoutExpiry)280 private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, 281 long timeoutExpiry) throws IOException { 282 wsSession.updateLastActive(); 283 284 if (opCode == currentOpCode) { 285 opCode = Constants.OPCODE_CONTINUATION; 286 } 287 288 request.sendWsFrame(payload, opCode, last, timeoutExpiry); 289 290 if (!last && opCode != Constants.OPCODE_CONTINUATION) { 291 currentOpCode = opCode; 292 } 293 294 if (last && opCode == Constants.OPCODE_CONTINUATION) { 295 currentOpCode = Constants.OPCODE_CONTINUATION; 296 } 297 } 298 299 startMessage(byte opCode, ByteBuffer payload, boolean last, SendHandler handler)300 void startMessage(byte opCode, ByteBuffer payload, boolean last, 301 SendHandler handler) { 302 303 wsSession.updateLastActive(); 304 305 List<MessagePart> messageParts = new ArrayList<>(); 306 messageParts.add(new MessagePart(last, 0, opCode, payload, 307 intermediateMessageHandler, 308 new EndMessageHandler(this, handler), -1)); 309 310 messageParts = transformation.sendMessagePart(messageParts); 311 312 // Some extensions/transformations may buffer messages so it is possible 313 // that no message parts will be returned. If this is the case the 314 // trigger the supplied SendHandler 315 if (messageParts.size() == 0) { 316 handler.onResult(new SendResult()); 317 return; 318 } 319 320 MessagePart mp = messageParts.remove(0); 321 322 boolean doWrite = false; 323 synchronized (messagePartLock) { 324 if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) { 325 // Should not happen. To late to send batched messages now since 326 // the session has been closed. Complain loudly. 327 log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed")); 328 } 329 if (messagePartInProgress.tryAcquire()) { 330 doWrite = true; 331 } else { 332 // When a control message is sent while another message is being 333 // sent, the control message is queued. Chances are the 334 // subsequent data message part will end up queued while the 335 // control message is sent. The logic in this class (state 336 // machine, EndMessageHandler, TextMessageSendHandler) ensures 337 // that there will only ever be one data message part in the 338 // queue. There could be multiple control messages in the queue. 339 340 // Add it to the queue 341 messagePartQueue.add(mp); 342 } 343 // Add any remaining messages to the queue 344 messagePartQueue.addAll(messageParts); 345 } 346 if (doWrite) { 347 // Actual write has to be outside sync block to avoid possible 348 // deadlock between messagePartLock and writeLock in 349 // o.a.coyote.http11.upgrade.AbstractServletOutputStream 350 writeMessagePart(mp); 351 } 352 } 353 354 endMessage(SendHandler handler, SendResult result)355 void endMessage(SendHandler handler, SendResult result) { 356 boolean doWrite = false; 357 MessagePart mpNext = null; 358 synchronized (messagePartLock) { 359 360 fragmented = nextFragmented; 361 text = nextText; 362 363 mpNext = messagePartQueue.poll(); 364 if (mpNext == null) { 365 messagePartInProgress.release(); 366 } else if (!closed){ 367 // Session may have been closed unexpectedly in the middle of 368 // sending a fragmented message closing the endpoint. If this 369 // happens, clearly there is no point trying to send the rest of 370 // the message. 371 doWrite = true; 372 } 373 } 374 if (doWrite) { 375 // Actual write has to be outside sync block to avoid possible 376 // deadlock between messagePartLock and writeLock in 377 // o.a.coyote.http11.upgrade.AbstractServletOutputStream 378 writeMessagePart(mpNext); 379 } 380 381 wsSession.updateLastActive(); 382 383 // Some handlers, such as the IntermediateMessageHandler, do not have a 384 // nested handler so handler may be null. 385 if (handler != null) { 386 handler.onResult(result); 387 } 388 } 389 390 writeMessagePart(MessagePart mp)391 void writeMessagePart(MessagePart mp) { 392 if (closed) { 393 throw new IllegalStateException( 394 sm.getString("wsRemoteEndpoint.closed")); 395 } 396 397 if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) { 398 nextFragmented = fragmented; 399 nextText = text; 400 outputBuffer.flip(); 401 SendHandler flushHandler = new OutputBufferFlushSendHandler( 402 outputBuffer, mp.getEndHandler()); 403 doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer); 404 return; 405 } 406 407 // Control messages may be sent in the middle of fragmented message 408 // so they have no effect on the fragmented or text flags 409 boolean first; 410 if (Util.isControl(mp.getOpCode())) { 411 nextFragmented = fragmented; 412 nextText = text; 413 if (mp.getOpCode() == Constants.OPCODE_CLOSE) { 414 closed = true; 415 } 416 first = true; 417 } else { 418 boolean isText = Util.isText(mp.getOpCode()); 419 420 if (fragmented) { 421 // Currently fragmented 422 if (text != isText) { 423 throw new IllegalStateException( 424 sm.getString("wsRemoteEndpoint.changeType")); 425 } 426 nextText = text; 427 nextFragmented = !mp.isFin(); 428 first = false; 429 } else { 430 // Wasn't fragmented. Might be now 431 if (mp.isFin()) { 432 nextFragmented = false; 433 } else { 434 nextFragmented = true; 435 nextText = isText; 436 } 437 first = true; 438 } 439 } 440 441 byte[] mask; 442 443 if (isMasked()) { 444 mask = Util.generateMask(); 445 } else { 446 mask = null; 447 } 448 449 headerBuffer.clear(); 450 writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(), 451 isMasked(), mp.getPayload(), mask, first); 452 headerBuffer.flip(); 453 454 if (getBatchingAllowed() || isMasked()) { 455 // Need to write via output buffer 456 OutputBufferSendHandler obsh = new OutputBufferSendHandler( 457 mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), 458 headerBuffer, mp.getPayload(), mask, 459 outputBuffer, !getBatchingAllowed(), this); 460 obsh.write(); 461 } else { 462 // Can write directly 463 doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), 464 headerBuffer, mp.getPayload()); 465 } 466 } 467 468 getBlockingSendTimeout()469 private long getBlockingSendTimeout() { 470 Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY); 471 Long userTimeout = null; 472 if (obj instanceof Long) { 473 userTimeout = (Long) obj; 474 } 475 if (userTimeout == null) { 476 return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT; 477 } else { 478 return userTimeout.longValue(); 479 } 480 } 481 482 483 /** 484 * Wraps the user provided handler so that the end point is notified when 485 * the message is complete. 486 */ 487 private static class EndMessageHandler implements SendHandler { 488 489 private final WsRemoteEndpointImplBase endpoint; 490 private final SendHandler handler; 491 EndMessageHandler(WsRemoteEndpointImplBase endpoint, SendHandler handler)492 public EndMessageHandler(WsRemoteEndpointImplBase endpoint, 493 SendHandler handler) { 494 this.endpoint = endpoint; 495 this.handler = handler; 496 } 497 498 499 @Override onResult(SendResult result)500 public void onResult(SendResult result) { 501 endpoint.endMessage(handler, result); 502 } 503 } 504 505 506 /** 507 * If a transformation needs to split a {@link MessagePart} into multiple 508 * {@link MessagePart}s, it uses this handler as the end handler for each of 509 * the additional {@link MessagePart}s. This handler notifies this this 510 * class that the {@link MessagePart} has been processed and that the next 511 * {@link MessagePart} in the queue should be started. The final 512 * {@link MessagePart} will use the {@link EndMessageHandler} provided with 513 * the original {@link MessagePart}. 514 */ 515 private static class IntermediateMessageHandler implements SendHandler { 516 517 private final WsRemoteEndpointImplBase endpoint; 518 IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint)519 public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) { 520 this.endpoint = endpoint; 521 } 522 523 524 @Override onResult(SendResult result)525 public void onResult(SendResult result) { 526 endpoint.endMessage(null, result); 527 } 528 } 529 530 531 @SuppressWarnings({"unchecked", "rawtypes"}) sendObject(Object obj)532 public void sendObject(Object obj) throws IOException, EncodeException { 533 if (obj == null) { 534 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 535 } 536 /* 537 * Note that the implementation will convert primitives and their object 538 * equivalents by default but that users are free to specify their own 539 * encoders and decoders for this if they wish. 540 */ 541 Encoder encoder = findEncoder(obj); 542 if (encoder == null && Util.isPrimitive(obj.getClass())) { 543 String msg = obj.toString(); 544 sendString(msg); 545 return; 546 } 547 if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { 548 ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); 549 sendBytes(msg); 550 return; 551 } 552 553 if (encoder instanceof Encoder.Text) { 554 String msg = ((Encoder.Text) encoder).encode(obj); 555 sendString(msg); 556 } else if (encoder instanceof Encoder.TextStream) { 557 try (Writer w = getSendWriter()) { 558 ((Encoder.TextStream) encoder).encode(obj, w); 559 } 560 } else if (encoder instanceof Encoder.Binary) { 561 ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); 562 sendBytes(msg); 563 } else if (encoder instanceof Encoder.BinaryStream) { 564 try (OutputStream os = getSendStream()) { 565 ((Encoder.BinaryStream) encoder).encode(obj, os); 566 } 567 } else { 568 throw new EncodeException(obj, sm.getString( 569 "wsRemoteEndpoint.noEncoder", obj.getClass())); 570 } 571 } 572 573 sendObjectByFuture(Object obj)574 public Future<Void> sendObjectByFuture(Object obj) { 575 FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); 576 sendObjectByCompletion(obj, f2sh); 577 return f2sh; 578 } 579 580 581 @SuppressWarnings({"unchecked", "rawtypes"}) sendObjectByCompletion(Object obj, SendHandler completion)582 public void sendObjectByCompletion(Object obj, SendHandler completion) { 583 584 if (obj == null) { 585 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); 586 } 587 if (completion == null) { 588 throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); 589 } 590 591 /* 592 * Note that the implementation will convert primitives and their object 593 * equivalents by default but that users are free to specify their own 594 * encoders and decoders for this if they wish. 595 */ 596 Encoder encoder = findEncoder(obj); 597 if (encoder == null && Util.isPrimitive(obj.getClass())) { 598 String msg = obj.toString(); 599 sendStringByCompletion(msg, completion); 600 return; 601 } 602 if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { 603 ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); 604 sendBytesByCompletion(msg, completion); 605 return; 606 } 607 608 try { 609 if (encoder instanceof Encoder.Text) { 610 String msg = ((Encoder.Text) encoder).encode(obj); 611 sendStringByCompletion(msg, completion); 612 } else if (encoder instanceof Encoder.TextStream) { 613 try (Writer w = getSendWriter()) { 614 ((Encoder.TextStream) encoder).encode(obj, w); 615 } 616 completion.onResult(new SendResult()); 617 } else if (encoder instanceof Encoder.Binary) { 618 ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); 619 sendBytesByCompletion(msg, completion); 620 } else if (encoder instanceof Encoder.BinaryStream) { 621 try (OutputStream os = getSendStream()) { 622 ((Encoder.BinaryStream) encoder).encode(obj, os); 623 } 624 completion.onResult(new SendResult()); 625 } else { 626 throw new EncodeException(obj, sm.getString( 627 "wsRemoteEndpoint.noEncoder", obj.getClass())); 628 } 629 } catch (Exception e) { 630 SendResult sr = new SendResult(e); 631 completion.onResult(sr); 632 } 633 } 634 635 setSession(WsSession wsSession)636 protected void setSession(WsSession wsSession) { 637 this.wsSession = wsSession; 638 } 639 640 setRequest(Request request)641 protected void setRequest(Request request) { 642 this.request = request; 643 } 644 setEncoders(EndpointConfig endpointConfig)645 protected void setEncoders(EndpointConfig endpointConfig) 646 throws DeploymentException { 647 encoderEntries.clear(); 648 for (Class<? extends Encoder> encoderClazz : 649 endpointConfig.getEncoders()) { 650 Encoder instance; 651 try { 652 instance = encoderClazz.getConstructor().newInstance(); 653 instance.init(endpointConfig); 654 } catch (ReflectiveOperationException e) { 655 throw new DeploymentException( 656 sm.getString("wsRemoteEndpoint.invalidEncoder", 657 encoderClazz.getName()), e); 658 } 659 EncoderEntry entry = new EncoderEntry( 660 Util.getEncoderType(encoderClazz), instance); 661 encoderEntries.add(entry); 662 } 663 } 664 665 findEncoder(Object obj)666 private Encoder findEncoder(Object obj) { 667 for (EncoderEntry entry : encoderEntries) { 668 if (entry.getClazz().isAssignableFrom(obj.getClass())) { 669 return entry.getEncoder(); 670 } 671 } 672 return null; 673 } 674 675 close()676 public final void close() { 677 for (EncoderEntry entry : encoderEntries) { 678 entry.getEncoder().destroy(); 679 } 680 681 request.closeWs(); 682 } 683 684 doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, ByteBuffer... data)685 protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, 686 ByteBuffer... data); isMasked()687 protected abstract boolean isMasked(); doClose()688 protected abstract void doClose(); 689 writeHeader(ByteBuffer headerBuffer, boolean fin, int rsv, byte opCode, boolean masked, ByteBuffer payload, byte[] mask, boolean first)690 private static void writeHeader(ByteBuffer headerBuffer, boolean fin, 691 int rsv, byte opCode, boolean masked, ByteBuffer payload, 692 byte[] mask, boolean first) { 693 694 byte b = 0; 695 696 if (fin) { 697 // Set the fin bit 698 b -= 128; 699 } 700 701 b += (rsv << 4); 702 703 if (first) { 704 // This is the first fragment of this message 705 b += opCode; 706 } 707 // If not the first fragment, it is a continuation with opCode of zero 708 709 headerBuffer.put(b); 710 711 if (masked) { 712 b = (byte) 0x80; 713 } else { 714 b = 0; 715 } 716 717 // Next write the mask && length length 718 if (payload.limit() < 126) { 719 headerBuffer.put((byte) (payload.limit() | b)); 720 } else if (payload.limit() < 65536) { 721 headerBuffer.put((byte) (126 | b)); 722 headerBuffer.put((byte) (payload.limit() >>> 8)); 723 headerBuffer.put((byte) (payload.limit() & 0xFF)); 724 } else { 725 // Will never be more than 2^31-1 726 headerBuffer.put((byte) (127 | b)); 727 headerBuffer.put((byte) 0); 728 headerBuffer.put((byte) 0); 729 headerBuffer.put((byte) 0); 730 headerBuffer.put((byte) 0); 731 headerBuffer.put((byte) (payload.limit() >>> 24)); 732 headerBuffer.put((byte) (payload.limit() >>> 16)); 733 headerBuffer.put((byte) (payload.limit() >>> 8)); 734 headerBuffer.put((byte) (payload.limit() & 0xFF)); 735 } 736 if (masked) { 737 headerBuffer.put(mask[0]); 738 headerBuffer.put(mask[1]); 739 headerBuffer.put(mask[2]); 740 headerBuffer.put(mask[3]); 741 } 742 } 743 744 745 private class TextMessageSendHandler implements SendHandler { 746 747 private final SendHandler handler; 748 private final CharBuffer message; 749 private final boolean isLast; 750 private final CharsetEncoder encoder; 751 private final ByteBuffer buffer; 752 private final WsRemoteEndpointImplBase endpoint; 753 private volatile boolean isDone = false; 754 TextMessageSendHandler(SendHandler handler, CharBuffer message, boolean isLast, CharsetEncoder encoder, ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint)755 public TextMessageSendHandler(SendHandler handler, CharBuffer message, 756 boolean isLast, CharsetEncoder encoder, 757 ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) { 758 this.handler = handler; 759 this.message = message; 760 this.isLast = isLast; 761 this.encoder = encoder.reset(); 762 this.buffer = encoderBuffer; 763 this.endpoint = endpoint; 764 } 765 write()766 public void write() { 767 buffer.clear(); 768 CoderResult cr = encoder.encode(message, buffer, true); 769 if (cr.isError()) { 770 throw new IllegalArgumentException(cr.toString()); 771 } 772 isDone = !cr.isOverflow(); 773 buffer.flip(); 774 endpoint.startMessage(Constants.OPCODE_TEXT, buffer, 775 isDone && isLast, this); 776 } 777 778 @Override onResult(SendResult result)779 public void onResult(SendResult result) { 780 if (isDone) { 781 endpoint.stateMachine.complete(isLast); 782 handler.onResult(result); 783 } else if(!result.isOK()) { 784 handler.onResult(result); 785 } else if (closed){ 786 SendResult sr = new SendResult(new IOException( 787 sm.getString("wsRemoteEndpoint.closedDuringMessage"))); 788 handler.onResult(sr); 789 } else { 790 write(); 791 } 792 } 793 } 794 795 796 /** 797 * Used to write data to the output buffer, flushing the buffer if it fills 798 * up. 799 */ 800 private static class OutputBufferSendHandler implements SendHandler { 801 802 private final SendHandler handler; 803 private final long blockingWriteTimeoutExpiry; 804 private final ByteBuffer headerBuffer; 805 private final ByteBuffer payload; 806 private final byte[] mask; 807 private final ByteBuffer outputBuffer; 808 private final boolean flushRequired; 809 private final WsRemoteEndpointImplBase endpoint; 810 private int maskIndex = 0; 811 OutputBufferSendHandler(SendHandler completion, long blockingWriteTimeoutExpiry, ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, ByteBuffer outputBuffer, boolean flushRequired, WsRemoteEndpointImplBase endpoint)812 public OutputBufferSendHandler(SendHandler completion, 813 long blockingWriteTimeoutExpiry, 814 ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, 815 ByteBuffer outputBuffer, boolean flushRequired, 816 WsRemoteEndpointImplBase endpoint) { 817 this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; 818 this.handler = completion; 819 this.headerBuffer = headerBuffer; 820 this.payload = payload; 821 this.mask = mask; 822 this.outputBuffer = outputBuffer; 823 this.flushRequired = flushRequired; 824 this.endpoint = endpoint; 825 } 826 write()827 public void write() { 828 // Write the header 829 while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { 830 outputBuffer.put(headerBuffer.get()); 831 } 832 if (headerBuffer.hasRemaining()) { 833 // Still more headers to write, need to flush 834 outputBuffer.flip(); 835 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); 836 return; 837 } 838 839 // Write the payload 840 int payloadLeft = payload.remaining(); 841 int payloadLimit = payload.limit(); 842 int outputSpace = outputBuffer.remaining(); 843 int toWrite = payloadLeft; 844 845 if (payloadLeft > outputSpace) { 846 toWrite = outputSpace; 847 // Temporarily reduce the limit 848 payload.limit(payload.position() + toWrite); 849 } 850 851 if (mask == null) { 852 // Use a bulk copy 853 outputBuffer.put(payload); 854 } else { 855 for (int i = 0; i < toWrite; i++) { 856 outputBuffer.put( 857 (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); 858 if (maskIndex > 3) { 859 maskIndex = 0; 860 } 861 } 862 } 863 864 if (payloadLeft > outputSpace) { 865 // Restore the original limit 866 payload.limit(payloadLimit); 867 // Still more data to write, need to flush 868 outputBuffer.flip(); 869 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); 870 return; 871 } 872 873 if (flushRequired) { 874 outputBuffer.flip(); 875 if (outputBuffer.remaining() == 0) { 876 handler.onResult(SENDRESULT_OK); 877 } else { 878 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); 879 } 880 } else { 881 handler.onResult(SENDRESULT_OK); 882 } 883 } 884 885 // ------------------------------------------------- SendHandler methods 886 @Override onResult(SendResult result)887 public void onResult(SendResult result) { 888 if (result.isOK()) { 889 if (outputBuffer.hasRemaining()) { 890 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); 891 } else { 892 outputBuffer.clear(); 893 write(); 894 } 895 } else { 896 handler.onResult(result); 897 } 898 } 899 } 900 901 902 /** 903 * Ensures that the output buffer is cleared after it has been flushed. 904 */ 905 private static class OutputBufferFlushSendHandler implements SendHandler { 906 907 private final ByteBuffer outputBuffer; 908 private final SendHandler handler; 909 OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler)910 public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) { 911 this.outputBuffer = outputBuffer; 912 this.handler = handler; 913 } 914 915 @Override onResult(SendResult result)916 public void onResult(SendResult result) { 917 if (result.isOK()) { 918 outputBuffer.clear(); 919 } 920 handler.onResult(result); 921 } 922 } 923 924 925 private static class WsOutputStream extends OutputStream { 926 927 private final WsRemoteEndpointImplBase endpoint; 928 private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 929 private final Object closeLock = new Object(); 930 private volatile boolean closed = false; 931 private volatile boolean used = false; 932 WsOutputStream(WsRemoteEndpointImplBase endpoint)933 public WsOutputStream(WsRemoteEndpointImplBase endpoint) { 934 this.endpoint = endpoint; 935 } 936 937 @Override write(int b)938 public void write(int b) throws IOException { 939 if (closed) { 940 throw new IllegalStateException( 941 sm.getString("wsRemoteEndpoint.closedOutputStream")); 942 } 943 944 used = true; 945 if (buffer.remaining() == 0) { 946 flush(); 947 } 948 buffer.put((byte) b); 949 } 950 951 @Override write(byte[] b, int off, int len)952 public void write(byte[] b, int off, int len) throws IOException { 953 if (closed) { 954 throw new IllegalStateException( 955 sm.getString("wsRemoteEndpoint.closedOutputStream")); 956 } 957 if (len == 0) { 958 return; 959 } 960 if ((off < 0) || (off > b.length) || (len < 0) || 961 ((off + len) > b.length) || ((off + len) < 0)) { 962 throw new IndexOutOfBoundsException(); 963 } 964 965 used = true; 966 if (buffer.remaining() == 0) { 967 flush(); 968 } 969 int remaining = buffer.remaining(); 970 int written = 0; 971 972 while (remaining < len - written) { 973 buffer.put(b, off + written, remaining); 974 written += remaining; 975 flush(); 976 remaining = buffer.remaining(); 977 } 978 buffer.put(b, off + written, len - written); 979 } 980 981 @Override flush()982 public void flush() throws IOException { 983 if (closed) { 984 throw new IllegalStateException( 985 sm.getString("wsRemoteEndpoint.closedOutputStream")); 986 } 987 988 // Optimisation. If there is no data to flush then do not send an 989 // empty message. 990 if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { 991 doWrite(false); 992 } 993 } 994 995 @Override close()996 public void close() throws IOException { 997 synchronized (closeLock) { 998 if (closed) { 999 return; 1000 } 1001 closed = true; 1002 } 1003 1004 doWrite(true); 1005 } 1006 doWrite(boolean last)1007 private void doWrite(boolean last) throws IOException { 1008 if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { 1009 buffer.flip(); 1010 endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last); 1011 } 1012 endpoint.stateMachine.complete(last); 1013 buffer.clear(); 1014 } 1015 } 1016 1017 1018 private static class WsWriter extends Writer { 1019 1020 private final WsRemoteEndpointImplBase endpoint; 1021 private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 1022 private final Object closeLock = new Object(); 1023 private volatile boolean closed = false; 1024 private volatile boolean used = false; 1025 WsWriter(WsRemoteEndpointImplBase endpoint)1026 public WsWriter(WsRemoteEndpointImplBase endpoint) { 1027 this.endpoint = endpoint; 1028 } 1029 1030 @Override write(char[] cbuf, int off, int len)1031 public void write(char[] cbuf, int off, int len) throws IOException { 1032 if (closed) { 1033 throw new IllegalStateException( 1034 sm.getString("wsRemoteEndpoint.closedWriter")); 1035 } 1036 if (len == 0) { 1037 return; 1038 } 1039 if ((off < 0) || (off > cbuf.length) || (len < 0) || 1040 ((off + len) > cbuf.length) || ((off + len) < 0)) { 1041 throw new IndexOutOfBoundsException(); 1042 } 1043 1044 used = true; 1045 if (buffer.remaining() == 0) { 1046 flush(); 1047 } 1048 int remaining = buffer.remaining(); 1049 int written = 0; 1050 1051 while (remaining < len - written) { 1052 buffer.put(cbuf, off + written, remaining); 1053 written += remaining; 1054 flush(); 1055 remaining = buffer.remaining(); 1056 } 1057 buffer.put(cbuf, off + written, len - written); 1058 } 1059 1060 @Override flush()1061 public void flush() throws IOException { 1062 if (closed) { 1063 throw new IllegalStateException( 1064 sm.getString("wsRemoteEndpoint.closedWriter")); 1065 } 1066 1067 if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { 1068 doWrite(false); 1069 } 1070 } 1071 1072 @Override close()1073 public void close() throws IOException { 1074 synchronized (closeLock) { 1075 if (closed) { 1076 return; 1077 } 1078 closed = true; 1079 } 1080 1081 doWrite(true); 1082 } 1083 doWrite(boolean last)1084 private void doWrite(boolean last) throws IOException { 1085 if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { 1086 buffer.flip(); 1087 endpoint.sendMessageBlock(buffer, last); 1088 buffer.clear(); 1089 } else { 1090 endpoint.stateMachine.complete(last); 1091 } 1092 } 1093 } 1094 1095 1096 private static class EncoderEntry { 1097 1098 private final Class<?> clazz; 1099 private final Encoder encoder; 1100 EncoderEntry(Class<?> clazz, Encoder encoder)1101 public EncoderEntry(Class<?> clazz, Encoder encoder) { 1102 this.clazz = clazz; 1103 this.encoder = encoder; 1104 } 1105 getClazz()1106 public Class<?> getClazz() { 1107 return clazz; 1108 } 1109 getEncoder()1110 public Encoder getEncoder() { 1111 return encoder; 1112 } 1113 } 1114 1115 1116 private enum State { 1117 OPEN, 1118 STREAM_WRITING, 1119 WRITER_WRITING, 1120 BINARY_PARTIAL_WRITING, 1121 BINARY_PARTIAL_READY, 1122 BINARY_FULL_WRITING, 1123 TEXT_PARTIAL_WRITING, 1124 TEXT_PARTIAL_READY, 1125 TEXT_FULL_WRITING 1126 } 1127 1128 1129 private static class StateMachine { 1130 private State state = State.OPEN; 1131 streamStart()1132 public synchronized void streamStart() { 1133 checkState(State.OPEN); 1134 state = State.STREAM_WRITING; 1135 } 1136 writeStart()1137 public synchronized void writeStart() { 1138 checkState(State.OPEN); 1139 state = State.WRITER_WRITING; 1140 } 1141 binaryPartialStart()1142 public synchronized void binaryPartialStart() { 1143 checkState(State.OPEN, State.BINARY_PARTIAL_READY); 1144 state = State.BINARY_PARTIAL_WRITING; 1145 } 1146 binaryStart()1147 public synchronized void binaryStart() { 1148 checkState(State.OPEN); 1149 state = State.BINARY_FULL_WRITING; 1150 } 1151 textPartialStart()1152 public synchronized void textPartialStart() { 1153 checkState(State.OPEN, State.TEXT_PARTIAL_READY); 1154 state = State.TEXT_PARTIAL_WRITING; 1155 } 1156 textStart()1157 public synchronized void textStart() { 1158 checkState(State.OPEN); 1159 state = State.TEXT_FULL_WRITING; 1160 } 1161 complete(boolean last)1162 public synchronized void complete(boolean last) { 1163 if (last) { 1164 checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING, 1165 State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING, 1166 State.STREAM_WRITING, State.WRITER_WRITING); 1167 state = State.OPEN; 1168 } else { 1169 checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING, 1170 State.STREAM_WRITING, State.WRITER_WRITING); 1171 if (state == State.TEXT_PARTIAL_WRITING) { 1172 state = State.TEXT_PARTIAL_READY; 1173 } else if (state == State.BINARY_PARTIAL_WRITING){ 1174 state = State.BINARY_PARTIAL_READY; 1175 } else if (state == State.WRITER_WRITING) { 1176 // NO-OP. Leave state as is. 1177 } else if (state == State.STREAM_WRITING) { 1178 // NO-OP. Leave state as is. 1179 } else { 1180 // Should never happen 1181 // The if ... else ... blocks above should cover all states 1182 // permitted by the preceding checkState() call 1183 throw new IllegalStateException( 1184 "BUG: This code should never be called"); 1185 } 1186 } 1187 } 1188 checkState(State... required)1189 private void checkState(State... required) { 1190 for (State state : required) { 1191 if (this.state == state) { 1192 return; 1193 } 1194 } 1195 throw new IllegalStateException( 1196 sm.getString("wsRemoteEndpoint.wrongState", this.state)); 1197 } 1198 } 1199 1200 1201 private static class StateUpdateSendHandler implements SendHandler { 1202 1203 private final SendHandler handler; 1204 private final StateMachine stateMachine; 1205 StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine)1206 public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) { 1207 this.handler = handler; 1208 this.stateMachine = stateMachine; 1209 } 1210 1211 @Override onResult(SendResult result)1212 public void onResult(SendResult result) { 1213 if (result.isOK()) { 1214 stateMachine.complete(true); 1215 } 1216 handler.onResult(result); 1217 } 1218 } 1219 1220 1221 private static class BlockingSendHandler implements SendHandler { 1222 1223 private SendResult sendResult = null; 1224 1225 @Override onResult(SendResult result)1226 public void onResult(SendResult result) { 1227 sendResult = result; 1228 } 1229 getSendResult()1230 public SendResult getSendResult() { 1231 return sendResult; 1232 } 1233 } 1234 } 1235