xref: /unit/src/java/nginx/unit/websocket/WsFrameBase.java (revision 1157:7ae152bda303)
1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */
17 package nginx.unit.websocket;
18 
19 import java.io.IOException;
20 import java.nio.ByteBuffer;
21 import java.nio.CharBuffer;
22 import java.nio.charset.CharsetDecoder;
23 import java.nio.charset.CoderResult;
24 import java.nio.charset.CodingErrorAction;
25 import java.util.List;
26 import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
27 
28 import javax.websocket.CloseReason;
29 import javax.websocket.CloseReason.CloseCodes;
30 import javax.websocket.Extension;
31 import javax.websocket.MessageHandler;
32 import javax.websocket.PongMessage;
33 
34 import org.apache.juli.logging.Log;
35 import org.apache.tomcat.util.ExceptionUtils;
36 import org.apache.tomcat.util.buf.Utf8Decoder;
37 import org.apache.tomcat.util.res.StringManager;
38 
39 /**
40  * Takes the ServletInputStream, processes the WebSocket frames it contains and
41  * extracts the messages. WebSocket Pings received will be responded to
42  * automatically without any action required by the application.
43  */
44 public abstract class WsFrameBase {
45 
46     private static final StringManager sm = StringManager.getManager(WsFrameBase.class);
47 
48     // Connection level attributes
49     protected final WsSession wsSession;
50     protected final ByteBuffer inputBuffer;
51     private final Transformation transformation;
52 
53     // Attributes for control messages
54     // Control messages can appear in the middle of other messages so need
55     // separate attributes
56     private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125);
57     private final CharBuffer controlBufferText = CharBuffer.allocate(125);
58 
59     // Attributes of the current message
60     private final CharsetDecoder utf8DecoderControl = new Utf8Decoder().
61             onMalformedInput(CodingErrorAction.REPORT).
62             onUnmappableCharacter(CodingErrorAction.REPORT);
63     private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder().
64             onMalformedInput(CodingErrorAction.REPORT).
65             onUnmappableCharacter(CodingErrorAction.REPORT);
66     private boolean continuationExpected = false;
67     private boolean textMessage = false;
68     private ByteBuffer messageBufferBinary;
69     private CharBuffer messageBufferText;
70     // Cache the message handler in force when the message starts so it is used
71     // consistently for the entire message
72     private MessageHandler binaryMsgHandler = null;
73     private MessageHandler textMsgHandler = null;
74 
75     // Attributes of the current frame
76     private boolean fin = false;
77     private int rsv = 0;
78     private byte opCode = 0;
79     private final byte[] mask = new byte[4];
80     private int maskIndex = 0;
81     private long payloadLength = 0;
82     private volatile long payloadWritten = 0;
83 
84     // Attributes tracking state
85     private volatile State state = State.NEW_FRAME;
86     private volatile boolean open = true;
87 
88     private static final AtomicReferenceFieldUpdater<WsFrameBase, ReadState> READ_STATE_UPDATER =
89             AtomicReferenceFieldUpdater.newUpdater(WsFrameBase.class, ReadState.class, "readState");
90     private volatile ReadState readState = ReadState.WAITING;
91 
WsFrameBase(WsSession wsSession, Transformation transformation)92     public WsFrameBase(WsSession wsSession, Transformation transformation) {
93         inputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
94         inputBuffer.position(0).limit(0);
95         messageBufferBinary = ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize());
96         messageBufferText = CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize());
97         this.wsSession = wsSession;
98         Transformation finalTransformation;
99         if (isMasked()) {
100             finalTransformation = new UnmaskTransformation();
101         } else {
102             finalTransformation = new NoopTransformation();
103         }
104         if (transformation == null) {
105             this.transformation = finalTransformation;
106         } else {
107             transformation.setNext(finalTransformation);
108             this.transformation = transformation;
109         }
110     }
111 
112 
processInputBuffer()113     protected void processInputBuffer() throws IOException {
114         while (!isSuspended()) {
115             wsSession.updateLastActive();
116             if (state == State.NEW_FRAME) {
117                 if (!processInitialHeader()) {
118                     break;
119                 }
120                 // If a close frame has been received, no further data should
121                 // have seen
122                 if (!open) {
123                     throw new IOException(sm.getString("wsFrame.closed"));
124                 }
125             }
126             if (state == State.PARTIAL_HEADER) {
127                 if (!processRemainingHeader()) {
128                     break;
129                 }
130             }
131             if (state == State.DATA) {
132                 if (!processData()) {
133                     break;
134                 }
135             }
136         }
137     }
138 
139 
140     /**
141      * @return <code>true</code> if sufficient data was present to process all
142      *         of the initial header
143      */
processInitialHeader()144     private boolean processInitialHeader() throws IOException {
145         // Need at least two bytes of data to do this
146         if (inputBuffer.remaining() < 2) {
147             return false;
148         }
149         int b = inputBuffer.get();
150         fin = (b & 0x80) != 0;
151         rsv = (b & 0x70) >>> 4;
152         opCode = (byte) (b & 0x0F);
153         if (!transformation.validateRsv(rsv, opCode)) {
154             throw new WsIOException(new CloseReason(
155                     CloseCodes.PROTOCOL_ERROR,
156                     sm.getString("wsFrame.wrongRsv", Integer.valueOf(rsv), Integer.valueOf(opCode))));
157         }
158 
159         if (Util.isControl(opCode)) {
160             if (!fin) {
161                 throw new WsIOException(new CloseReason(
162                         CloseCodes.PROTOCOL_ERROR,
163                         sm.getString("wsFrame.controlFragmented")));
164             }
165             if (opCode != Constants.OPCODE_PING &&
166                     opCode != Constants.OPCODE_PONG &&
167                     opCode != Constants.OPCODE_CLOSE) {
168                 throw new WsIOException(new CloseReason(
169                         CloseCodes.PROTOCOL_ERROR,
170                         sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
171             }
172         } else {
173             if (continuationExpected) {
174                 if (!Util.isContinuation(opCode)) {
175                     throw new WsIOException(new CloseReason(
176                             CloseCodes.PROTOCOL_ERROR,
177                             sm.getString("wsFrame.noContinuation")));
178                 }
179             } else {
180                 try {
181                     if (opCode == Constants.OPCODE_BINARY) {
182                         // New binary message
183                         textMessage = false;
184                         int size = wsSession.getMaxBinaryMessageBufferSize();
185                         if (size != messageBufferBinary.capacity()) {
186                             messageBufferBinary = ByteBuffer.allocate(size);
187                         }
188                         binaryMsgHandler = wsSession.getBinaryMessageHandler();
189                         textMsgHandler = null;
190                     } else if (opCode == Constants.OPCODE_TEXT) {
191                         // New text message
192                         textMessage = true;
193                         int size = wsSession.getMaxTextMessageBufferSize();
194                         if (size != messageBufferText.capacity()) {
195                             messageBufferText = CharBuffer.allocate(size);
196                         }
197                         binaryMsgHandler = null;
198                         textMsgHandler = wsSession.getTextMessageHandler();
199                     } else {
200                         throw new WsIOException(new CloseReason(
201                                 CloseCodes.PROTOCOL_ERROR,
202                                 sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
203                     }
204                 } catch (IllegalStateException ise) {
205                     // Thrown if the session is already closed
206                     throw new WsIOException(new CloseReason(
207                             CloseCodes.PROTOCOL_ERROR,
208                             sm.getString("wsFrame.sessionClosed")));
209                 }
210             }
211             continuationExpected = !fin;
212         }
213         b = inputBuffer.get();
214         // Client data must be masked
215         if ((b & 0x80) == 0 && isMasked()) {
216             throw new WsIOException(new CloseReason(
217                     CloseCodes.PROTOCOL_ERROR,
218                     sm.getString("wsFrame.notMasked")));
219         }
220         payloadLength = b & 0x7F;
221         state = State.PARTIAL_HEADER;
222         if (getLog().isDebugEnabled()) {
223             getLog().debug(sm.getString("wsFrame.partialHeaderComplete", Boolean.toString(fin),
224                     Integer.toString(rsv), Integer.toString(opCode), Long.toString(payloadLength)));
225         }
226         return true;
227     }
228 
229 
isMasked()230     protected abstract boolean isMasked();
getLog()231     protected abstract Log getLog();
232 
233 
234     /**
235      * @return <code>true</code> if sufficient data was present to complete the
236      *         processing of the header
237      */
processRemainingHeader()238     private boolean processRemainingHeader() throws IOException {
239         // Ignore the 2 bytes already read. 4 for the mask
240         int headerLength;
241         if (isMasked()) {
242             headerLength = 4;
243         } else {
244             headerLength = 0;
245         }
246         // Add additional bytes depending on length
247         if (payloadLength == 126) {
248             headerLength += 2;
249         } else if (payloadLength == 127) {
250             headerLength += 8;
251         }
252         if (inputBuffer.remaining() < headerLength) {
253             return false;
254         }
255         // Calculate new payload length if necessary
256         if (payloadLength == 126) {
257             payloadLength = byteArrayToLong(inputBuffer.array(),
258                     inputBuffer.arrayOffset() + inputBuffer.position(), 2);
259             inputBuffer.position(inputBuffer.position() + 2);
260         } else if (payloadLength == 127) {
261             payloadLength = byteArrayToLong(inputBuffer.array(),
262                     inputBuffer.arrayOffset() + inputBuffer.position(), 8);
263             inputBuffer.position(inputBuffer.position() + 8);
264         }
265         if (Util.isControl(opCode)) {
266             if (payloadLength > 125) {
267                 throw new WsIOException(new CloseReason(
268                         CloseCodes.PROTOCOL_ERROR,
269                         sm.getString("wsFrame.controlPayloadTooBig", Long.valueOf(payloadLength))));
270             }
271             if (!fin) {
272                 throw new WsIOException(new CloseReason(
273                         CloseCodes.PROTOCOL_ERROR,
274                         sm.getString("wsFrame.controlNoFin")));
275             }
276         }
277         if (isMasked()) {
278             inputBuffer.get(mask, 0, 4);
279         }
280         state = State.DATA;
281         return true;
282     }
283 
284 
processData()285     private boolean processData() throws IOException {
286         boolean result;
287         if (Util.isControl(opCode)) {
288             result = processDataControl();
289         } else if (textMessage) {
290             if (textMsgHandler == null) {
291                 result = swallowInput();
292             } else {
293                 result = processDataText();
294             }
295         } else {
296             if (binaryMsgHandler == null) {
297                 result = swallowInput();
298             } else {
299                 result = processDataBinary();
300             }
301         }
302         checkRoomPayload();
303         return result;
304     }
305 
306 
processDataControl()307     private boolean processDataControl() throws IOException {
308         TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, controlBufferBinary);
309         if (TransformationResult.UNDERFLOW.equals(tr)) {
310             return false;
311         }
312         // Control messages have fixed message size so
313         // TransformationResult.OVERFLOW is not possible here
314 
315         controlBufferBinary.flip();
316         if (opCode == Constants.OPCODE_CLOSE) {
317             open = false;
318             String reason = null;
319             int code = CloseCodes.NORMAL_CLOSURE.getCode();
320             if (controlBufferBinary.remaining() == 1) {
321                 controlBufferBinary.clear();
322                 // Payload must be zero or 2+ bytes long
323                 throw new WsIOException(new CloseReason(
324                         CloseCodes.PROTOCOL_ERROR,
325                         sm.getString("wsFrame.oneByteCloseCode")));
326             }
327             if (controlBufferBinary.remaining() > 1) {
328                 code = controlBufferBinary.getShort();
329                 if (controlBufferBinary.remaining() > 0) {
330                     CoderResult cr = utf8DecoderControl.decode(controlBufferBinary,
331                             controlBufferText, true);
332                     if (cr.isError()) {
333                         controlBufferBinary.clear();
334                         controlBufferText.clear();
335                         throw new WsIOException(new CloseReason(
336                                 CloseCodes.PROTOCOL_ERROR,
337                                 sm.getString("wsFrame.invalidUtf8Close")));
338                     }
339                     // There will be no overflow as the output buffer is big
340                     // enough. There will be no underflow as all the data is
341                     // passed to the decoder in a single call.
342                     controlBufferText.flip();
343                     reason = controlBufferText.toString();
344                 }
345             }
346             wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason));
347         } else if (opCode == Constants.OPCODE_PING) {
348             if (wsSession.isOpen()) {
349                 wsSession.getBasicRemote().sendPong(controlBufferBinary);
350             }
351         } else if (opCode == Constants.OPCODE_PONG) {
352             MessageHandler.Whole<PongMessage> mhPong = wsSession.getPongMessageHandler();
353             if (mhPong != null) {
354                 try {
355                     mhPong.onMessage(new WsPongMessage(controlBufferBinary));
356                 } catch (Throwable t) {
357                     handleThrowableOnSend(t);
358                 } finally {
359                     controlBufferBinary.clear();
360                 }
361             }
362         } else {
363             // Should have caught this earlier but just in case...
364             controlBufferBinary.clear();
365             throw new WsIOException(new CloseReason(
366                     CloseCodes.PROTOCOL_ERROR,
367                     sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
368         }
369         controlBufferBinary.clear();
370         newFrame();
371         return true;
372     }
373 
374 
375     @SuppressWarnings("unchecked")
sendMessageText(boolean last)376     protected void sendMessageText(boolean last) throws WsIOException {
377         if (textMsgHandler instanceof WrappedMessageHandler) {
378             long maxMessageSize = ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize();
379             if (maxMessageSize > -1 && messageBufferText.remaining() > maxMessageSize) {
380                 throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
381                         sm.getString("wsFrame.messageTooBig",
382                                 Long.valueOf(messageBufferText.remaining()),
383                                 Long.valueOf(maxMessageSize))));
384             }
385         }
386 
387         try {
388             if (textMsgHandler instanceof MessageHandler.Partial<?>) {
389                 ((MessageHandler.Partial<String>) textMsgHandler)
390                         .onMessage(messageBufferText.toString(), last);
391             } else {
392                 // Caller ensures last == true if this branch is used
393                 ((MessageHandler.Whole<String>) textMsgHandler)
394                         .onMessage(messageBufferText.toString());
395             }
396         } catch (Throwable t) {
397             handleThrowableOnSend(t);
398         } finally {
399             messageBufferText.clear();
400         }
401     }
402 
403 
processDataText()404     private boolean processDataText() throws IOException {
405         // Copy the available data to the buffer
406         TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
407         while (!TransformationResult.END_OF_FRAME.equals(tr)) {
408             // Frame not complete - we ran out of something
409             // Convert bytes to UTF-8
410             messageBufferBinary.flip();
411             while (true) {
412                 CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText,
413                         false);
414                 if (cr.isError()) {
415                     throw new WsIOException(new CloseReason(
416                             CloseCodes.NOT_CONSISTENT,
417                             sm.getString("wsFrame.invalidUtf8")));
418                 } else if (cr.isOverflow()) {
419                     // Ran out of space in text buffer - flush it
420                     if (usePartial()) {
421                         messageBufferText.flip();
422                         sendMessageText(false);
423                         messageBufferText.clear();
424                     } else {
425                         throw new WsIOException(new CloseReason(
426                                 CloseCodes.TOO_BIG,
427                                 sm.getString("wsFrame.textMessageTooBig")));
428                     }
429                 } else if (cr.isUnderflow()) {
430                     // Compact what we have to create as much space as possible
431                     messageBufferBinary.compact();
432 
433                     // Need more input
434                     // What did we run out of?
435                     if (TransformationResult.OVERFLOW.equals(tr)) {
436                         // Ran out of message buffer - exit inner loop and
437                         // refill
438                         break;
439                     } else {
440                         // TransformationResult.UNDERFLOW
441                         // Ran out of input data - get some more
442                         return false;
443                     }
444                 }
445             }
446             // Read more input data
447             tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
448         }
449 
450         messageBufferBinary.flip();
451         boolean last = false;
452         // Frame is fully received
453         // Convert bytes to UTF-8
454         while (true) {
455             CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText,
456                     last);
457             if (cr.isError()) {
458                 throw new WsIOException(new CloseReason(
459                         CloseCodes.NOT_CONSISTENT,
460                         sm.getString("wsFrame.invalidUtf8")));
461             } else if (cr.isOverflow()) {
462                 // Ran out of space in text buffer - flush it
463                 if (usePartial()) {
464                     messageBufferText.flip();
465                     sendMessageText(false);
466                     messageBufferText.clear();
467                 } else {
468                     throw new WsIOException(new CloseReason(
469                             CloseCodes.TOO_BIG,
470                             sm.getString("wsFrame.textMessageTooBig")));
471                 }
472             } else if (cr.isUnderflow() && !last) {
473                 // End of frame and possible message as well.
474 
475                 if (continuationExpected) {
476                     // If partial messages are supported, send what we have
477                     // managed to decode
478                     if (usePartial()) {
479                         messageBufferText.flip();
480                         sendMessageText(false);
481                         messageBufferText.clear();
482                     }
483                     messageBufferBinary.compact();
484                     newFrame();
485                     // Process next frame
486                     return true;
487                 } else {
488                     // Make sure coder has flushed all output
489                     last = true;
490                 }
491             } else {
492                 // End of message
493                 messageBufferText.flip();
494                 sendMessageText(true);
495 
496                 newMessage();
497                 return true;
498             }
499         }
500     }
501 
502 
processDataBinary()503     private boolean processDataBinary() throws IOException {
504         // Copy the available data to the buffer
505         TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
506         while (!TransformationResult.END_OF_FRAME.equals(tr)) {
507             // Frame not complete - what did we run out of?
508             if (TransformationResult.UNDERFLOW.equals(tr)) {
509                 // Ran out of input data - get some more
510                 return false;
511             }
512 
513             // Ran out of message buffer - flush it
514             if (!usePartial()) {
515                 CloseReason cr = new CloseReason(CloseCodes.TOO_BIG,
516                         sm.getString("wsFrame.bufferTooSmall",
517                                 Integer.valueOf(messageBufferBinary.capacity()),
518                                 Long.valueOf(payloadLength)));
519                 throw new WsIOException(cr);
520             }
521             messageBufferBinary.flip();
522             ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit());
523             copy.put(messageBufferBinary);
524             copy.flip();
525             sendMessageBinary(copy, false);
526             messageBufferBinary.clear();
527             // Read more data
528             tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
529         }
530 
531         // Frame is fully received
532         // Send the message if either:
533         // - partial messages are supported
534         // - the message is complete
535         if (usePartial() || !continuationExpected) {
536             messageBufferBinary.flip();
537             ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit());
538             copy.put(messageBufferBinary);
539             copy.flip();
540             sendMessageBinary(copy, !continuationExpected);
541             messageBufferBinary.clear();
542         }
543 
544         if (continuationExpected) {
545             // More data for this message expected, start a new frame
546             newFrame();
547         } else {
548             // Message is complete, start a new message
549             newMessage();
550         }
551 
552         return true;
553     }
554 
555 
handleThrowableOnSend(Throwable t)556     private void handleThrowableOnSend(Throwable t) throws WsIOException {
557         ExceptionUtils.handleThrowable(t);
558         wsSession.getLocal().onError(wsSession, t);
559         CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY,
560                 sm.getString("wsFrame.ioeTriggeredClose"));
561         throw new WsIOException(cr);
562     }
563 
564 
565     @SuppressWarnings("unchecked")
sendMessageBinary(ByteBuffer msg, boolean last)566     protected void sendMessageBinary(ByteBuffer msg, boolean last) throws WsIOException {
567         if (binaryMsgHandler instanceof WrappedMessageHandler) {
568             long maxMessageSize = ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize();
569             if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) {
570                 throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
571                         sm.getString("wsFrame.messageTooBig",
572                                 Long.valueOf(msg.remaining()),
573                                 Long.valueOf(maxMessageSize))));
574             }
575         }
576         try {
577             if (binaryMsgHandler instanceof MessageHandler.Partial<?>) {
578                 ((MessageHandler.Partial<ByteBuffer>) binaryMsgHandler).onMessage(msg, last);
579             } else {
580                 // Caller ensures last == true if this branch is used
581                 ((MessageHandler.Whole<ByteBuffer>) binaryMsgHandler).onMessage(msg);
582             }
583         } catch (Throwable t) {
584             handleThrowableOnSend(t);
585         }
586     }
587 
588 
newMessage()589     private void newMessage() {
590         messageBufferBinary.clear();
591         messageBufferText.clear();
592         utf8DecoderMessage.reset();
593         continuationExpected = false;
594         newFrame();
595     }
596 
597 
newFrame()598     private void newFrame() {
599         if (inputBuffer.remaining() == 0) {
600             inputBuffer.position(0).limit(0);
601         }
602 
603         maskIndex = 0;
604         payloadWritten = 0;
605         state = State.NEW_FRAME;
606 
607         // These get reset in processInitialHeader()
608         // fin, rsv, opCode, payloadLength, mask
609 
610         checkRoomHeaders();
611     }
612 
613 
checkRoomHeaders()614     private void checkRoomHeaders() {
615         // Is the start of the current frame too near the end of the input
616         // buffer?
617         if (inputBuffer.capacity() - inputBuffer.position() < 131) {
618             // Limit based on a control frame with a full payload
619             makeRoom();
620         }
621     }
622 
623 
checkRoomPayload()624     private void checkRoomPayload() {
625         if (inputBuffer.capacity() - inputBuffer.position() - payloadLength + payloadWritten < 0) {
626             makeRoom();
627         }
628     }
629 
630 
makeRoom()631     private void makeRoom() {
632         inputBuffer.compact();
633         inputBuffer.flip();
634     }
635 
636 
usePartial()637     private boolean usePartial() {
638         if (Util.isControl(opCode)) {
639             return false;
640         } else if (textMessage) {
641             return textMsgHandler instanceof MessageHandler.Partial;
642         } else {
643             // Must be binary
644             return binaryMsgHandler instanceof MessageHandler.Partial;
645         }
646     }
647 
648 
swallowInput()649     private boolean swallowInput() {
650         long toSkip = Math.min(payloadLength - payloadWritten, inputBuffer.remaining());
651         inputBuffer.position(inputBuffer.position() + (int) toSkip);
652         payloadWritten += toSkip;
653         if (payloadWritten == payloadLength) {
654             if (continuationExpected) {
655                 newFrame();
656             } else {
657                 newMessage();
658             }
659             return true;
660         } else {
661             return false;
662         }
663     }
664 
665 
byteArrayToLong(byte[] b, int start, int len)666     protected static long byteArrayToLong(byte[] b, int start, int len) throws IOException {
667         if (len > 8) {
668             throw new IOException(sm.getString("wsFrame.byteToLongFail", Long.valueOf(len)));
669         }
670         int shift = 0;
671         long result = 0;
672         for (int i = start + len - 1; i >= start; i--) {
673             result = result + ((b[i] & 0xFF) << shift);
674             shift += 8;
675         }
676         return result;
677     }
678 
679 
isOpen()680     protected boolean isOpen() {
681         return open;
682     }
683 
684 
getTransformation()685     protected Transformation getTransformation() {
686         return transformation;
687     }
688 
689 
690     private enum State {
691         NEW_FRAME, PARTIAL_HEADER, DATA
692     }
693 
694 
695     /**
696      * WAITING            - not suspended
697      *                      Server case: waiting for a notification that data
698      *                      is ready to be read from the socket, the socket is
699      *                      registered to the poller
700      *                      Client case: data has been read from the socket and
701      *                      is waiting for data to be processed
702      * PROCESSING         - not suspended
703      *                      Server case: reading from the socket and processing
704      *                      the data
705      *                      Client case: processing the data if such has
706      *                      already been read and more data will be read from
707      *                      the socket
708      * SUSPENDING_WAIT    - suspended, a call to suspend() was made while in
709      *                      WAITING state. A call to resume() will do nothing
710      *                      and will transition to WAITING state
711      * SUSPENDING_PROCESS - suspended, a call to suspend() was made while in
712      *                      PROCESSING state. A call to resume() will do
713      *                      nothing and will transition to PROCESSING state
714      * SUSPENDED          - suspended
715      *                      Server case: processing data finished
716      *                      (SUSPENDING_PROCESS) / a notification was received
717      *                      that data is ready to be read from the socket
718      *                      (SUSPENDING_WAIT), socket is not registered to the
719      *                      poller
720      *                      Client case: processing data finished
721      *                      (SUSPENDING_PROCESS) / data has been read from the
722      *                      socket and is available for processing
723      *                      (SUSPENDING_WAIT)
724      *                      A call to resume() will:
725      *                      Server case: register the socket to the poller
726      *                      Client case: resume data processing
727      * CLOSING            - not suspended, a close will be send
728      *
729      * <pre>
730      *     resume           data to be        resume
731      *     no action        processed         no action
732      *  |---------------| |---------------| |----------|
733      *  |               v |               v v          |
734      *  |  |----------WAITING --------PROCESSING----|  |
735      *  |  |             ^   processing             |  |
736      *  |  |             |   finished               |  |
737      *  |  |             |                          |  |
738      *  | suspend        |                     suspend |
739      *  |  |             |                          |  |
740      *  |  |          resume                        |  |
741      *  |  |    register socket to poller (server)  |  |
742      *  |  |    resume data processing (client)     |  |
743      *  |  |             |                          |  |
744      *  |  v             |                          v  |
745      * SUSPENDING_WAIT   |                  SUSPENDING_PROCESS
746      *  |                |                             |
747      *  | data available |        processing finished  |
748      *  |------------- SUSPENDED ----------------------|
749      * </pre>
750      */
751     protected enum ReadState {
752         WAITING           (false),
753         PROCESSING        (false),
754         SUSPENDING_WAIT   (true),
755         SUSPENDING_PROCESS(true),
756         SUSPENDED         (true),
757         CLOSING           (false);
758 
759         private final boolean isSuspended;
760 
ReadState(boolean isSuspended)761         ReadState(boolean isSuspended) {
762             this.isSuspended = isSuspended;
763         }
764 
isSuspended()765         public boolean isSuspended() {
766             return isSuspended;
767         }
768     }
769 
suspend()770     public void suspend() {
771         while (true) {
772             switch (readState) {
773             case WAITING:
774                 if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.WAITING,
775                         ReadState.SUSPENDING_WAIT)) {
776                     continue;
777                 }
778                 return;
779             case PROCESSING:
780                 if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.PROCESSING,
781                         ReadState.SUSPENDING_PROCESS)) {
782                     continue;
783                 }
784                 return;
785             case SUSPENDING_WAIT:
786                 if (readState != ReadState.SUSPENDING_WAIT) {
787                     continue;
788                 } else {
789                     if (getLog().isWarnEnabled()) {
790                         getLog().warn(sm.getString("wsFrame.suspendRequested"));
791                     }
792                 }
793                 return;
794             case SUSPENDING_PROCESS:
795                 if (readState != ReadState.SUSPENDING_PROCESS) {
796                     continue;
797                 } else {
798                     if (getLog().isWarnEnabled()) {
799                         getLog().warn(sm.getString("wsFrame.suspendRequested"));
800                     }
801                 }
802                 return;
803             case SUSPENDED:
804                 if (readState != ReadState.SUSPENDED) {
805                     continue;
806                 } else {
807                     if (getLog().isWarnEnabled()) {
808                         getLog().warn(sm.getString("wsFrame.alreadySuspended"));
809                     }
810                 }
811                 return;
812             case CLOSING:
813                 return;
814             default:
815                 throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state));
816             }
817         }
818     }
819 
resume()820     public void resume() {
821         while (true) {
822             switch (readState) {
823             case WAITING:
824                 if (readState != ReadState.WAITING) {
825                     continue;
826                 } else {
827                     if (getLog().isWarnEnabled()) {
828                         getLog().warn(sm.getString("wsFrame.alreadyResumed"));
829                     }
830                 }
831                 return;
832             case PROCESSING:
833                 if (readState != ReadState.PROCESSING) {
834                     continue;
835                 } else {
836                     if (getLog().isWarnEnabled()) {
837                         getLog().warn(sm.getString("wsFrame.alreadyResumed"));
838                     }
839                 }
840                 return;
841             case SUSPENDING_WAIT:
842                 if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_WAIT,
843                         ReadState.WAITING)) {
844                     continue;
845                 }
846                 return;
847             case SUSPENDING_PROCESS:
848                 if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_PROCESS,
849                         ReadState.PROCESSING)) {
850                     continue;
851                 }
852                 return;
853             case SUSPENDED:
854                 if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDED,
855                         ReadState.WAITING)) {
856                     continue;
857                 }
858                 resumeProcessing();
859                 return;
860             case CLOSING:
861                 return;
862             default:
863                 throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state));
864             }
865         }
866     }
867 
isSuspended()868     protected boolean isSuspended() {
869         return readState.isSuspended();
870     }
871 
getReadState()872     protected ReadState getReadState() {
873         return readState;
874     }
875 
changeReadState(ReadState newState)876     protected void changeReadState(ReadState newState) {
877         READ_STATE_UPDATER.set(this, newState);
878     }
879 
changeReadState(ReadState oldState, ReadState newState)880     protected boolean changeReadState(ReadState oldState, ReadState newState) {
881         return READ_STATE_UPDATER.compareAndSet(this, oldState, newState);
882     }
883 
884     /**
885      * This method will be invoked when the read operation is resumed.
886      * As the suspend of the read operation can be invoked at any time, when
887      * implementing this method one should consider that there might still be
888      * data remaining into the internal buffers that needs to be processed
889      * before reading again from the socket.
890      */
resumeProcessing()891     protected abstract void resumeProcessing();
892 
893 
894     private abstract class TerminalTransformation implements Transformation {
895 
896         @Override
validateRsvBits(int i)897         public boolean validateRsvBits(int i) {
898             // Terminal transformations don't use RSV bits and there is no next
899             // transformation so always return true.
900             return true;
901         }
902 
903         @Override
getExtensionResponse()904         public Extension getExtensionResponse() {
905             // Return null since terminal transformations are not extensions
906             return null;
907         }
908 
909         @Override
setNext(Transformation t)910         public void setNext(Transformation t) {
911             // NO-OP since this is the terminal transformation
912         }
913 
914         /**
915          * {@inheritDoc}
916          * <p>
917          * Anything other than a value of zero for rsv is invalid.
918          */
919         @Override
validateRsv(int rsv, byte opCode)920         public boolean validateRsv(int rsv, byte opCode) {
921             return rsv == 0;
922         }
923 
924         @Override
close()925         public void close() {
926             // NO-OP for the terminal transformations
927         }
928     }
929 
930 
931     /**
932      * For use by the client implementation that needs to obtain payload data
933      * without the need for unmasking.
934      */
935     private final class NoopTransformation extends TerminalTransformation {
936 
937         @Override
getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)938         public TransformationResult getMoreData(byte opCode, boolean fin, int rsv,
939                 ByteBuffer dest) {
940             // opCode is ignored as the transformation is the same for all
941             // opCodes
942             // rsv is ignored as it known to be zero at this point
943             long toWrite = Math.min(payloadLength - payloadWritten, inputBuffer.remaining());
944             toWrite = Math.min(toWrite, dest.remaining());
945 
946             int orgLimit = inputBuffer.limit();
947             inputBuffer.limit(inputBuffer.position() + (int) toWrite);
948             dest.put(inputBuffer);
949             inputBuffer.limit(orgLimit);
950             payloadWritten += toWrite;
951 
952             if (payloadWritten == payloadLength) {
953                 return TransformationResult.END_OF_FRAME;
954             } else if (inputBuffer.remaining() == 0) {
955                 return TransformationResult.UNDERFLOW;
956             } else {
957                 // !dest.hasRemaining()
958                 return TransformationResult.OVERFLOW;
959             }
960         }
961 
962 
963         @Override
sendMessagePart(List<MessagePart> messageParts)964         public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) {
965             // TODO Masking should move to this method
966             // NO-OP send so simply return the message unchanged.
967             return messageParts;
968         }
969     }
970 
971 
972     /**
973      * For use by the server implementation that needs to obtain payload data
974      * and unmask it before any further processing.
975      */
976     private final class UnmaskTransformation extends TerminalTransformation {
977 
978         @Override
getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)979         public TransformationResult getMoreData(byte opCode, boolean fin, int rsv,
980                 ByteBuffer dest) {
981             // opCode is ignored as the transformation is the same for all
982             // opCodes
983             // rsv is ignored as it known to be zero at this point
984             while (payloadWritten < payloadLength && inputBuffer.remaining() > 0 &&
985                     dest.hasRemaining()) {
986                 byte b = (byte) ((inputBuffer.get() ^ mask[maskIndex]) & 0xFF);
987                 maskIndex++;
988                 if (maskIndex == 4) {
989                     maskIndex = 0;
990                 }
991                 payloadWritten++;
992                 dest.put(b);
993             }
994             if (payloadWritten == payloadLength) {
995                 return TransformationResult.END_OF_FRAME;
996             } else if (inputBuffer.remaining() == 0) {
997                 return TransformationResult.UNDERFLOW;
998             } else {
999                 // !dest.hasRemaining()
1000                 return TransformationResult.OVERFLOW;
1001             }
1002         }
1003 
1004         @Override
sendMessagePart(List<MessagePart> messageParts)1005         public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) {
1006             // NO-OP send so simply return the message unchanged.
1007             return messageParts;
1008         }
1009     }
1010 }
1011