1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */
17 package nginx.unit.websocket;
18 
19 import java.io.IOException;
20 import java.io.OutputStream;
21 import java.io.Writer;
22 import java.net.SocketTimeoutException;
23 import java.nio.ByteBuffer;
24 import java.nio.CharBuffer;
25 import java.nio.charset.CharsetEncoder;
26 import java.nio.charset.CoderResult;
27 import java.util.ArrayDeque;
28 import java.util.ArrayList;
29 import java.util.List;
30 import java.util.Queue;
31 import java.util.concurrent.Future;
32 import java.util.concurrent.Semaphore;
33 import java.util.concurrent.TimeUnit;
34 import java.util.concurrent.atomic.AtomicBoolean;
35 
36 import javax.websocket.CloseReason;
37 import javax.websocket.CloseReason.CloseCodes;
38 import javax.websocket.DeploymentException;
39 import javax.websocket.EncodeException;
40 import javax.websocket.Encoder;
41 import javax.websocket.EndpointConfig;
42 import javax.websocket.RemoteEndpoint;
43 import javax.websocket.SendHandler;
44 import javax.websocket.SendResult;
45 
46 import org.apache.juli.logging.Log;
47 import org.apache.juli.logging.LogFactory;
48 import org.apache.tomcat.util.buf.Utf8Encoder;
49 import org.apache.tomcat.util.res.StringManager;
50 
51 import nginx.unit.Request;
52 
53 public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
54 
    // Source of the localized exception and log messages used throughout this class.
    private static final StringManager sm =
            StringManager.getManager(WsRemoteEndpointImplBase.class);

    // Shared "success" result instance. NOTE(review): not referenced in the
    // visible part of this file; presumably used by subclasses/handlers.
    protected static final SendResult SENDRESULT_OK = new SendResult();

    private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static

    // Validates the sequence of whole / partial / stream sends (textStart,
    // binaryStart, complete, ...).
    private final StateMachine stateMachine = new StateMachine();

    // End handler for every message part that startMessage() creates; see
    // IntermediateMessageHandler.
    private final IntermediateMessageHandler intermediateMessageHandler =
            new IntermediateMessageHandler(this);

    // Pipeline applied to outgoing parts in startMessage(). startMessage()
    // dereferences it without a null check, so setTransformation() must have
    // been called before any asynchronous send.
    private Transformation transformation = null;
    // Single permit held while a message part is being written: acquired with
    // tryAcquire() in startMessage(), released in endMessage() once the queue
    // is empty.
    private final Semaphore messagePartInProgress = new Semaphore(1);
    // Parts waiting for the current write to finish; guarded by messagePartLock.
    private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>();
    private final Object messagePartLock = new Object();

    // State
    // Set once a close frame has been handed to writeMessagePart().
    private volatile boolean closed = false;
    // fragmented/text describe the message currently being sent; the next*
    // values are computed in writeMessagePart() and committed in endMessage()
    // when the corresponding write completes.
    private boolean fragmented = false;
    private boolean nextFragmented = false;
    private boolean text = false;
    private boolean nextText = false;

    // Max size of WebSocket header is 14 bytes
    private final ByteBuffer headerBuffer = ByteBuffer.allocate(14);
    // Used for writes when batching is allowed or frames are masked, and for
    // the internal flush operation.
    private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
    // UTF-8 encoder and its destination buffer, shared by the blocking text
    // path (sendMessageBlock(CharBuffer, boolean)) and TextMessageSendHandler.
    private final CharsetEncoder encoder = new Utf8Encoder();
    private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
    private final AtomicBoolean batchingAllowed = new AtomicBoolean(false);
    // NOTE(review): only exposed via get/setSendTimeout() in the visible code;
    // blocking sends take their timeout from getBlockingSendTimeout() instead.
    private volatile long sendTimeout = -1;
    private WsSession wsSession;
    // Populated by setEncoders() and searched (in order) by findEncoder().
    private List<EncoderEntry> encoderEntries = new ArrayList<>();

    // Native request used to emit WebSocket frames and to close the connection.
    private Request request;
90 
91 
setTransformation(Transformation transformation)92     protected void setTransformation(Transformation transformation) {
93         this.transformation = transformation;
94     }
95 
96 
getSendTimeout()97     public long getSendTimeout() {
98         return sendTimeout;
99     }
100 
101 
setSendTimeout(long timeout)102     public void setSendTimeout(long timeout) {
103         this.sendTimeout = timeout;
104     }
105 
106 
107     @Override
setBatchingAllowed(boolean batchingAllowed)108     public void setBatchingAllowed(boolean batchingAllowed) throws IOException {
109         boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
110 
111         if (oldValue && !batchingAllowed) {
112             flushBatch();
113         }
114     }
115 
116 
117     @Override
getBatchingAllowed()118     public boolean getBatchingAllowed() {
119         return batchingAllowed.get();
120     }
121 
122 
123     @Override
flushBatch()124     public void flushBatch() throws IOException {
125         sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true);
126     }
127 
128 
sendBytes(ByteBuffer data)129     public void sendBytes(ByteBuffer data) throws IOException {
130         if (data == null) {
131             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
132         }
133         stateMachine.binaryStart();
134         sendMessageBlock(Constants.OPCODE_BINARY, data, true);
135         stateMachine.complete(true);
136     }
137 
138 
sendBytesByFuture(ByteBuffer data)139     public Future<Void> sendBytesByFuture(ByteBuffer data) {
140         FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
141         sendBytesByCompletion(data, f2sh);
142         return f2sh;
143     }
144 
145 
sendBytesByCompletion(ByteBuffer data, SendHandler handler)146     public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) {
147         if (data == null) {
148             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
149         }
150         if (handler == null) {
151             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
152         }
153         StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine);
154         stateMachine.binaryStart();
155         startMessage(Constants.OPCODE_BINARY, data, true, sush);
156     }
157 
158 
sendPartialBytes(ByteBuffer partialByte, boolean last)159     public void sendPartialBytes(ByteBuffer partialByte, boolean last)
160             throws IOException {
161         if (partialByte == null) {
162             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
163         }
164         stateMachine.binaryPartialStart();
165         sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last);
166         stateMachine.complete(last);
167     }
168 
169 
170     @Override
sendPing(ByteBuffer applicationData)171     public void sendPing(ByteBuffer applicationData) throws IOException,
172             IllegalArgumentException {
173         if (applicationData.remaining() > 125) {
174             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
175         }
176         sendMessageBlock(Constants.OPCODE_PING, applicationData, true);
177     }
178 
179 
180     @Override
sendPong(ByteBuffer applicationData)181     public void sendPong(ByteBuffer applicationData) throws IOException,
182             IllegalArgumentException {
183         if (applicationData.remaining() > 125) {
184             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
185         }
186         sendMessageBlock(Constants.OPCODE_PONG, applicationData, true);
187     }
188 
189 
sendString(String text)190     public void sendString(String text) throws IOException {
191         if (text == null) {
192             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
193         }
194         stateMachine.textStart();
195         sendMessageBlock(CharBuffer.wrap(text), true);
196     }
197 
198 
sendStringByFuture(String text)199     public Future<Void> sendStringByFuture(String text) {
200         FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
201         sendStringByCompletion(text, f2sh);
202         return f2sh;
203     }
204 
205 
sendStringByCompletion(String text, SendHandler handler)206     public void sendStringByCompletion(String text, SendHandler handler) {
207         if (text == null) {
208             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
209         }
210         if (handler == null) {
211             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
212         }
213         stateMachine.textStart();
214         TextMessageSendHandler tmsh = new TextMessageSendHandler(handler,
215                 CharBuffer.wrap(text), true, encoder, encoderBuffer, this);
216         tmsh.write();
217         // TextMessageSendHandler will update stateMachine when it completes
218     }
219 
220 
sendPartialString(String fragment, boolean isLast)221     public void sendPartialString(String fragment, boolean isLast)
222             throws IOException {
223         if (fragment == null) {
224             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
225         }
226         stateMachine.textPartialStart();
227         sendMessageBlock(CharBuffer.wrap(fragment), isLast);
228     }
229 
230 
getSendStream()231     public OutputStream getSendStream() {
232         stateMachine.streamStart();
233         return new WsOutputStream(this);
234     }
235 
236 
getSendWriter()237     public Writer getSendWriter() {
238         stateMachine.writeStart();
239         return new WsWriter(this);
240     }
241 
242 
sendMessageBlock(CharBuffer part, boolean last)243     void sendMessageBlock(CharBuffer part, boolean last) throws IOException {
244         long timeoutExpiry = getTimeoutExpiry();
245         boolean isDone = false;
246         while (!isDone) {
247             encoderBuffer.clear();
248             CoderResult cr = encoder.encode(part, encoderBuffer, true);
249             if (cr.isError()) {
250                 throw new IllegalArgumentException(cr.toString());
251             }
252             isDone = !cr.isOverflow();
253             encoderBuffer.flip();
254             sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry);
255         }
256         stateMachine.complete(last);
257     }
258 
259 
sendMessageBlock(byte opCode, ByteBuffer payload, boolean last)260     void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last)
261             throws IOException {
262         sendMessageBlock(opCode, payload, last, getTimeoutExpiry());
263     }
264 
265 
getTimeoutExpiry()266     private long getTimeoutExpiry() {
267         // Get the timeout before we send the message. The message may
268         // trigger a session close and depending on timing the client
269         // session may close before we can read the timeout.
270         long timeout = getBlockingSendTimeout();
271         if (timeout < 0) {
272             return Long.MAX_VALUE;
273         } else {
274             return System.currentTimeMillis() + timeout;
275         }
276     }
277 
278     private byte currentOpCode = Constants.OPCODE_CONTINUATION;
279 
sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, long timeoutExpiry)280     private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last,
281             long timeoutExpiry) throws IOException {
282         wsSession.updateLastActive();
283 
284         if (opCode == currentOpCode) {
285             opCode = Constants.OPCODE_CONTINUATION;
286         }
287 
288         request.sendWsFrame(payload, opCode, last, timeoutExpiry);
289 
290         if (!last && opCode != Constants.OPCODE_CONTINUATION) {
291             currentOpCode = opCode;
292         }
293 
294         if (last && opCode == Constants.OPCODE_CONTINUATION) {
295             currentOpCode = Constants.OPCODE_CONTINUATION;
296         }
297     }
298 
299 
startMessage(byte opCode, ByteBuffer payload, boolean last, SendHandler handler)300     void startMessage(byte opCode, ByteBuffer payload, boolean last,
301             SendHandler handler) {
302 
303         wsSession.updateLastActive();
304 
305         List<MessagePart> messageParts = new ArrayList<>();
306         messageParts.add(new MessagePart(last, 0, opCode, payload,
307                 intermediateMessageHandler,
308                 new EndMessageHandler(this, handler), -1));
309 
310         messageParts = transformation.sendMessagePart(messageParts);
311 
312         // Some extensions/transformations may buffer messages so it is possible
313         // that no message parts will be returned. If this is the case the
314         // trigger the supplied SendHandler
315         if (messageParts.size() == 0) {
316             handler.onResult(new SendResult());
317             return;
318         }
319 
320         MessagePart mp = messageParts.remove(0);
321 
322         boolean doWrite = false;
323         synchronized (messagePartLock) {
324             if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) {
325                 // Should not happen. To late to send batched messages now since
326                 // the session has been closed. Complain loudly.
327                 log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed"));
328             }
329             if (messagePartInProgress.tryAcquire()) {
330                 doWrite = true;
331             } else {
332                 // When a control message is sent while another message is being
333                 // sent, the control message is queued. Chances are the
334                 // subsequent data message part will end up queued while the
335                 // control message is sent. The logic in this class (state
336                 // machine, EndMessageHandler, TextMessageSendHandler) ensures
337                 // that there will only ever be one data message part in the
338                 // queue. There could be multiple control messages in the queue.
339 
340                 // Add it to the queue
341                 messagePartQueue.add(mp);
342             }
343             // Add any remaining messages to the queue
344             messagePartQueue.addAll(messageParts);
345         }
346         if (doWrite) {
347             // Actual write has to be outside sync block to avoid possible
348             // deadlock between messagePartLock and writeLock in
349             // o.a.coyote.http11.upgrade.AbstractServletOutputStream
350             writeMessagePart(mp);
351         }
352     }
353 
354 
endMessage(SendHandler handler, SendResult result)355     void endMessage(SendHandler handler, SendResult result) {
356         boolean doWrite = false;
357         MessagePart mpNext = null;
358         synchronized (messagePartLock) {
359 
360             fragmented = nextFragmented;
361             text = nextText;
362 
363             mpNext = messagePartQueue.poll();
364             if (mpNext == null) {
365                 messagePartInProgress.release();
366             } else if (!closed){
367                 // Session may have been closed unexpectedly in the middle of
368                 // sending a fragmented message closing the endpoint. If this
369                 // happens, clearly there is no point trying to send the rest of
370                 // the message.
371                 doWrite = true;
372             }
373         }
374         if (doWrite) {
375             // Actual write has to be outside sync block to avoid possible
376             // deadlock between messagePartLock and writeLock in
377             // o.a.coyote.http11.upgrade.AbstractServletOutputStream
378             writeMessagePart(mpNext);
379         }
380 
381         wsSession.updateLastActive();
382 
383         // Some handlers, such as the IntermediateMessageHandler, do not have a
384         // nested handler so handler may be null.
385         if (handler != null) {
386             handler.onResult(result);
387         }
388     }
389 
390 
writeMessagePart(MessagePart mp)391     void writeMessagePart(MessagePart mp) {
392         if (closed) {
393             throw new IllegalStateException(
394                     sm.getString("wsRemoteEndpoint.closed"));
395         }
396 
397         if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) {
398             nextFragmented = fragmented;
399             nextText = text;
400             outputBuffer.flip();
401             SendHandler flushHandler = new OutputBufferFlushSendHandler(
402                     outputBuffer, mp.getEndHandler());
403             doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer);
404             return;
405         }
406 
407         // Control messages may be sent in the middle of fragmented message
408         // so they have no effect on the fragmented or text flags
409         boolean first;
410         if (Util.isControl(mp.getOpCode())) {
411             nextFragmented = fragmented;
412             nextText = text;
413             if (mp.getOpCode() == Constants.OPCODE_CLOSE) {
414                 closed = true;
415             }
416             first = true;
417         } else {
418             boolean isText = Util.isText(mp.getOpCode());
419 
420             if (fragmented) {
421                 // Currently fragmented
422                 if (text != isText) {
423                     throw new IllegalStateException(
424                             sm.getString("wsRemoteEndpoint.changeType"));
425                 }
426                 nextText = text;
427                 nextFragmented = !mp.isFin();
428                 first = false;
429             } else {
430                 // Wasn't fragmented. Might be now
431                 if (mp.isFin()) {
432                     nextFragmented = false;
433                 } else {
434                     nextFragmented = true;
435                     nextText = isText;
436                 }
437                 first = true;
438             }
439         }
440 
441         byte[] mask;
442 
443         if (isMasked()) {
444             mask = Util.generateMask();
445         } else {
446             mask = null;
447         }
448 
449         headerBuffer.clear();
450         writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(),
451                 isMasked(), mp.getPayload(), mask, first);
452         headerBuffer.flip();
453 
454         if (getBatchingAllowed() || isMasked()) {
455             // Need to write via output buffer
456             OutputBufferSendHandler obsh = new OutputBufferSendHandler(
457                     mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
458                     headerBuffer, mp.getPayload(), mask,
459                     outputBuffer, !getBatchingAllowed(), this);
460             obsh.write();
461         } else {
462             // Can write directly
463             doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
464                     headerBuffer, mp.getPayload());
465         }
466     }
467 
468 
getBlockingSendTimeout()469     private long getBlockingSendTimeout() {
470         Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY);
471         Long userTimeout = null;
472         if (obj instanceof Long) {
473             userTimeout = (Long) obj;
474         }
475         if (userTimeout == null) {
476             return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT;
477         } else {
478             return userTimeout.longValue();
479         }
480     }
481 
482 
483     /**
484      * Wraps the user provided handler so that the end point is notified when
485      * the message is complete.
486      */
487     private static class EndMessageHandler implements SendHandler {
488 
489         private final WsRemoteEndpointImplBase endpoint;
490         private final SendHandler handler;
491 
EndMessageHandler(WsRemoteEndpointImplBase endpoint, SendHandler handler)492         public EndMessageHandler(WsRemoteEndpointImplBase endpoint,
493                 SendHandler handler) {
494             this.endpoint = endpoint;
495             this.handler = handler;
496         }
497 
498 
499         @Override
onResult(SendResult result)500         public void onResult(SendResult result) {
501             endpoint.endMessage(handler, result);
502         }
503     }
504 
505 
506     /**
507      * If a transformation needs to split a {@link MessagePart} into multiple
508      * {@link MessagePart}s, it uses this handler as the end handler for each of
     * the additional {@link MessagePart}s. This handler notifies this
510      * class that the {@link MessagePart} has been processed and that the next
511      * {@link MessagePart} in the queue should be started. The final
512      * {@link MessagePart} will use the {@link EndMessageHandler} provided with
513      * the original {@link MessagePart}.
514      */
515     private static class IntermediateMessageHandler implements SendHandler {
516 
517         private final WsRemoteEndpointImplBase endpoint;
518 
IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint)519         public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) {
520             this.endpoint = endpoint;
521         }
522 
523 
524         @Override
onResult(SendResult result)525         public void onResult(SendResult result) {
526             endpoint.endMessage(null, result);
527         }
528     }
529 
530 
531     @SuppressWarnings({"unchecked", "rawtypes"})
sendObject(Object obj)532     public void sendObject(Object obj) throws IOException, EncodeException {
533         if (obj == null) {
534             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
535         }
536         /*
537          * Note that the implementation will convert primitives and their object
538          * equivalents by default but that users are free to specify their own
539          * encoders and decoders for this if they wish.
540          */
541         Encoder encoder = findEncoder(obj);
542         if (encoder == null && Util.isPrimitive(obj.getClass())) {
543             String msg = obj.toString();
544             sendString(msg);
545             return;
546         }
547         if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
548             ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
549             sendBytes(msg);
550             return;
551         }
552 
553         if (encoder instanceof Encoder.Text) {
554             String msg = ((Encoder.Text) encoder).encode(obj);
555             sendString(msg);
556         } else if (encoder instanceof Encoder.TextStream) {
557             try (Writer w = getSendWriter()) {
558                 ((Encoder.TextStream) encoder).encode(obj, w);
559             }
560         } else if (encoder instanceof Encoder.Binary) {
561             ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
562             sendBytes(msg);
563         } else if (encoder instanceof Encoder.BinaryStream) {
564             try (OutputStream os = getSendStream()) {
565                 ((Encoder.BinaryStream) encoder).encode(obj, os);
566             }
567         } else {
568             throw new EncodeException(obj, sm.getString(
569                     "wsRemoteEndpoint.noEncoder", obj.getClass()));
570         }
571     }
572 
573 
sendObjectByFuture(Object obj)574     public Future<Void> sendObjectByFuture(Object obj) {
575         FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
576         sendObjectByCompletion(obj, f2sh);
577         return f2sh;
578     }
579 
580 
581     @SuppressWarnings({"unchecked", "rawtypes"})
sendObjectByCompletion(Object obj, SendHandler completion)582     public void sendObjectByCompletion(Object obj, SendHandler completion) {
583 
584         if (obj == null) {
585             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
586         }
587         if (completion == null) {
588             throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
589         }
590 
591         /*
592          * Note that the implementation will convert primitives and their object
593          * equivalents by default but that users are free to specify their own
594          * encoders and decoders for this if they wish.
595          */
596         Encoder encoder = findEncoder(obj);
597         if (encoder == null && Util.isPrimitive(obj.getClass())) {
598             String msg = obj.toString();
599             sendStringByCompletion(msg, completion);
600             return;
601         }
602         if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
603             ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
604             sendBytesByCompletion(msg, completion);
605             return;
606         }
607 
608         try {
609             if (encoder instanceof Encoder.Text) {
610                 String msg = ((Encoder.Text) encoder).encode(obj);
611                 sendStringByCompletion(msg, completion);
612             } else if (encoder instanceof Encoder.TextStream) {
613                 try (Writer w = getSendWriter()) {
614                     ((Encoder.TextStream) encoder).encode(obj, w);
615                 }
616                 completion.onResult(new SendResult());
617             } else if (encoder instanceof Encoder.Binary) {
618                 ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
619                 sendBytesByCompletion(msg, completion);
620             } else if (encoder instanceof Encoder.BinaryStream) {
621                 try (OutputStream os = getSendStream()) {
622                     ((Encoder.BinaryStream) encoder).encode(obj, os);
623                 }
624                 completion.onResult(new SendResult());
625             } else {
626                 throw new EncodeException(obj, sm.getString(
627                         "wsRemoteEndpoint.noEncoder", obj.getClass()));
628             }
629         } catch (Exception e) {
630             SendResult sr = new SendResult(e);
631             completion.onResult(sr);
632         }
633     }
634 
635 
setSession(WsSession wsSession)636     protected void setSession(WsSession wsSession) {
637         this.wsSession = wsSession;
638     }
639 
640 
setRequest(Request request)641     protected void setRequest(Request request) {
642         this.request = request;
643     }
644 
setEncoders(EndpointConfig endpointConfig)645     protected void setEncoders(EndpointConfig endpointConfig)
646             throws DeploymentException {
647         encoderEntries.clear();
648         for (Class<? extends Encoder> encoderClazz :
649                 endpointConfig.getEncoders()) {
650             Encoder instance;
651             try {
652                 instance = encoderClazz.getConstructor().newInstance();
653                 instance.init(endpointConfig);
654             } catch (ReflectiveOperationException e) {
655                 throw new DeploymentException(
656                         sm.getString("wsRemoteEndpoint.invalidEncoder",
657                                 encoderClazz.getName()), e);
658             }
659             EncoderEntry entry = new EncoderEntry(
660                     Util.getEncoderType(encoderClazz), instance);
661             encoderEntries.add(entry);
662         }
663     }
664 
665 
findEncoder(Object obj)666     private Encoder findEncoder(Object obj) {
667         for (EncoderEntry entry : encoderEntries) {
668             if (entry.getClazz().isAssignableFrom(obj.getClass())) {
669                 return entry.getEncoder();
670             }
671         }
672         return null;
673     }
674 
675 
close()676     public final void close() {
677         for (EncoderEntry entry : encoderEntries) {
678             entry.getEncoder().destroy();
679         }
680 
681         request.closeWs();
682     }
683 
684 
doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, ByteBuffer... data)685     protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
686             ByteBuffer... data);
isMasked()687     protected abstract boolean isMasked();
doClose()688     protected abstract void doClose();
689 
writeHeader(ByteBuffer headerBuffer, boolean fin, int rsv, byte opCode, boolean masked, ByteBuffer payload, byte[] mask, boolean first)690     private static void writeHeader(ByteBuffer headerBuffer, boolean fin,
691             int rsv, byte opCode, boolean masked, ByteBuffer payload,
692             byte[] mask, boolean first) {
693 
694         byte b = 0;
695 
696         if (fin) {
697             // Set the fin bit
698             b -= 128;
699         }
700 
701         b += (rsv << 4);
702 
703         if (first) {
704             // This is the first fragment of this message
705             b += opCode;
706         }
707         // If not the first fragment, it is a continuation with opCode of zero
708 
709         headerBuffer.put(b);
710 
711         if (masked) {
712             b = (byte) 0x80;
713         } else {
714             b = 0;
715         }
716 
717         // Next write the mask && length length
718         if (payload.limit() < 126) {
719             headerBuffer.put((byte) (payload.limit() | b));
720         } else if (payload.limit() < 65536) {
721             headerBuffer.put((byte) (126 | b));
722             headerBuffer.put((byte) (payload.limit() >>> 8));
723             headerBuffer.put((byte) (payload.limit() & 0xFF));
724         } else {
725             // Will never be more than 2^31-1
726             headerBuffer.put((byte) (127 | b));
727             headerBuffer.put((byte) 0);
728             headerBuffer.put((byte) 0);
729             headerBuffer.put((byte) 0);
730             headerBuffer.put((byte) 0);
731             headerBuffer.put((byte) (payload.limit() >>> 24));
732             headerBuffer.put((byte) (payload.limit() >>> 16));
733             headerBuffer.put((byte) (payload.limit() >>> 8));
734             headerBuffer.put((byte) (payload.limit() & 0xFF));
735         }
736         if (masked) {
737             headerBuffer.put(mask[0]);
738             headerBuffer.put(mask[1]);
739             headerBuffer.put(mask[2]);
740             headerBuffer.put(mask[3]);
741         }
742     }
743 
744 
745     private class TextMessageSendHandler implements SendHandler {
746 
747         private final SendHandler handler;
748         private final CharBuffer message;
749         private final boolean isLast;
750         private final CharsetEncoder encoder;
751         private final ByteBuffer buffer;
752         private final WsRemoteEndpointImplBase endpoint;
753         private volatile boolean isDone = false;
754 
TextMessageSendHandler(SendHandler handler, CharBuffer message, boolean isLast, CharsetEncoder encoder, ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint)755         public TextMessageSendHandler(SendHandler handler, CharBuffer message,
756                 boolean isLast, CharsetEncoder encoder,
757                 ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) {
758             this.handler = handler;
759             this.message = message;
760             this.isLast = isLast;
761             this.encoder = encoder.reset();
762             this.buffer = encoderBuffer;
763             this.endpoint = endpoint;
764         }
765 
write()766         public void write() {
767             buffer.clear();
768             CoderResult cr = encoder.encode(message, buffer, true);
769             if (cr.isError()) {
770                 throw new IllegalArgumentException(cr.toString());
771             }
772             isDone = !cr.isOverflow();
773             buffer.flip();
774             endpoint.startMessage(Constants.OPCODE_TEXT, buffer,
775                     isDone && isLast, this);
776         }
777 
778         @Override
onResult(SendResult result)779         public void onResult(SendResult result) {
780             if (isDone) {
781                 endpoint.stateMachine.complete(isLast);
782                 handler.onResult(result);
783             } else if(!result.isOK()) {
784                 handler.onResult(result);
785             } else if (closed){
786                 SendResult sr = new SendResult(new IOException(
787                         sm.getString("wsRemoteEndpoint.closedDuringMessage")));
788                 handler.onResult(sr);
789             } else {
790                 write();
791             }
792         }
793     }
794 
795 
796     /**
797      * Used to write data to the output buffer, flushing the buffer if it fills
798      * up.
799      */
800     private static class OutputBufferSendHandler implements SendHandler {
801 
802         private final SendHandler handler;
803         private final long blockingWriteTimeoutExpiry;
804         private final ByteBuffer headerBuffer;
805         private final ByteBuffer payload;
806         private final byte[] mask;
807         private final ByteBuffer outputBuffer;
808         private final boolean flushRequired;
809         private final WsRemoteEndpointImplBase endpoint;
810         private int maskIndex = 0;
811 
OutputBufferSendHandler(SendHandler completion, long blockingWriteTimeoutExpiry, ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, ByteBuffer outputBuffer, boolean flushRequired, WsRemoteEndpointImplBase endpoint)812         public OutputBufferSendHandler(SendHandler completion,
813                 long blockingWriteTimeoutExpiry,
814                 ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask,
815                 ByteBuffer outputBuffer, boolean flushRequired,
816                 WsRemoteEndpointImplBase endpoint) {
817             this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry;
818             this.handler = completion;
819             this.headerBuffer = headerBuffer;
820             this.payload = payload;
821             this.mask = mask;
822             this.outputBuffer = outputBuffer;
823             this.flushRequired = flushRequired;
824             this.endpoint = endpoint;
825         }
826 
write()827         public void write() {
828             // Write the header
829             while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) {
830                 outputBuffer.put(headerBuffer.get());
831             }
832             if (headerBuffer.hasRemaining()) {
833                 // Still more headers to write, need to flush
834                 outputBuffer.flip();
835                 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
836                 return;
837             }
838 
839             // Write the payload
840             int payloadLeft = payload.remaining();
841             int payloadLimit = payload.limit();
842             int outputSpace = outputBuffer.remaining();
843             int toWrite = payloadLeft;
844 
845             if (payloadLeft > outputSpace) {
846                 toWrite = outputSpace;
847                 // Temporarily reduce the limit
848                 payload.limit(payload.position() + toWrite);
849             }
850 
851             if (mask == null) {
852                 // Use a bulk copy
853                 outputBuffer.put(payload);
854             } else {
855                 for (int i = 0; i < toWrite; i++) {
856                     outputBuffer.put(
857                             (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF)));
858                     if (maskIndex > 3) {
859                         maskIndex = 0;
860                     }
861                 }
862             }
863 
864             if (payloadLeft > outputSpace) {
865                 // Restore the original limit
866                 payload.limit(payloadLimit);
867                 // Still more data to write, need to flush
868                 outputBuffer.flip();
869                 endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
870                 return;
871             }
872 
873             if (flushRequired) {
874                 outputBuffer.flip();
875                 if (outputBuffer.remaining() == 0) {
876                     handler.onResult(SENDRESULT_OK);
877                 } else {
878                     endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
879                 }
880             } else {
881                 handler.onResult(SENDRESULT_OK);
882             }
883         }
884 
885         // ------------------------------------------------- SendHandler methods
886         @Override
onResult(SendResult result)887         public void onResult(SendResult result) {
888             if (result.isOK()) {
889                 if (outputBuffer.hasRemaining()) {
890                     endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
891                 } else {
892                     outputBuffer.clear();
893                     write();
894                 }
895             } else {
896                 handler.onResult(result);
897             }
898         }
899     }
900 
901 
902     /**
903      * Ensures that the output buffer is cleared after it has been flushed.
904      */
905     private static class OutputBufferFlushSendHandler implements SendHandler {
906 
907         private final ByteBuffer outputBuffer;
908         private final SendHandler handler;
909 
OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler)910         public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) {
911             this.outputBuffer = outputBuffer;
912             this.handler = handler;
913         }
914 
915         @Override
onResult(SendResult result)916         public void onResult(SendResult result) {
917             if (result.isOK()) {
918                 outputBuffer.clear();
919             }
920             handler.onResult(result);
921         }
922     }
923 
924 
925     private static class WsOutputStream extends OutputStream {
926 
927         private final WsRemoteEndpointImplBase endpoint;
928         private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
929         private final Object closeLock = new Object();
930         private volatile boolean closed = false;
931         private volatile boolean used = false;
932 
WsOutputStream(WsRemoteEndpointImplBase endpoint)933         public WsOutputStream(WsRemoteEndpointImplBase endpoint) {
934             this.endpoint = endpoint;
935         }
936 
937         @Override
write(int b)938         public void write(int b) throws IOException {
939             if (closed) {
940                 throw new IllegalStateException(
941                         sm.getString("wsRemoteEndpoint.closedOutputStream"));
942             }
943 
944             used = true;
945             if (buffer.remaining() == 0) {
946                 flush();
947             }
948             buffer.put((byte) b);
949         }
950 
951         @Override
write(byte[] b, int off, int len)952         public void write(byte[] b, int off, int len) throws IOException {
953             if (closed) {
954                 throw new IllegalStateException(
955                         sm.getString("wsRemoteEndpoint.closedOutputStream"));
956             }
957             if (len == 0) {
958                 return;
959             }
960             if ((off < 0) || (off > b.length) || (len < 0) ||
961                 ((off + len) > b.length) || ((off + len) < 0)) {
962                 throw new IndexOutOfBoundsException();
963             }
964 
965             used = true;
966             if (buffer.remaining() == 0) {
967                 flush();
968             }
969             int remaining = buffer.remaining();
970             int written = 0;
971 
972             while (remaining < len - written) {
973                 buffer.put(b, off + written, remaining);
974                 written += remaining;
975                 flush();
976                 remaining = buffer.remaining();
977             }
978             buffer.put(b, off + written, len - written);
979         }
980 
981         @Override
flush()982         public void flush() throws IOException {
983             if (closed) {
984                 throw new IllegalStateException(
985                         sm.getString("wsRemoteEndpoint.closedOutputStream"));
986             }
987 
988             // Optimisation. If there is no data to flush then do not send an
989             // empty message.
990             if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
991                 doWrite(false);
992             }
993         }
994 
995         @Override
close()996         public void close() throws IOException {
997             synchronized (closeLock) {
998                 if (closed) {
999                     return;
1000                 }
1001                 closed = true;
1002             }
1003 
1004             doWrite(true);
1005         }
1006 
doWrite(boolean last)1007         private void doWrite(boolean last) throws IOException {
1008             if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
1009                 buffer.flip();
1010                 endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last);
1011             }
1012             endpoint.stateMachine.complete(last);
1013             buffer.clear();
1014         }
1015     }
1016 
1017 
1018     private static class WsWriter extends Writer {
1019 
1020         private final WsRemoteEndpointImplBase endpoint;
1021         private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
1022         private final Object closeLock = new Object();
1023         private volatile boolean closed = false;
1024         private volatile boolean used = false;
1025 
WsWriter(WsRemoteEndpointImplBase endpoint)1026         public WsWriter(WsRemoteEndpointImplBase endpoint) {
1027             this.endpoint = endpoint;
1028         }
1029 
1030         @Override
write(char[] cbuf, int off, int len)1031         public void write(char[] cbuf, int off, int len) throws IOException {
1032             if (closed) {
1033                 throw new IllegalStateException(
1034                         sm.getString("wsRemoteEndpoint.closedWriter"));
1035             }
1036             if (len == 0) {
1037                 return;
1038             }
1039             if ((off < 0) || (off > cbuf.length) || (len < 0) ||
1040                     ((off + len) > cbuf.length) || ((off + len) < 0)) {
1041                 throw new IndexOutOfBoundsException();
1042             }
1043 
1044             used = true;
1045             if (buffer.remaining() == 0) {
1046                 flush();
1047             }
1048             int remaining = buffer.remaining();
1049             int written = 0;
1050 
1051             while (remaining < len - written) {
1052                 buffer.put(cbuf, off + written, remaining);
1053                 written += remaining;
1054                 flush();
1055                 remaining = buffer.remaining();
1056             }
1057             buffer.put(cbuf, off + written, len - written);
1058         }
1059 
1060         @Override
flush()1061         public void flush() throws IOException {
1062             if (closed) {
1063                 throw new IllegalStateException(
1064                         sm.getString("wsRemoteEndpoint.closedWriter"));
1065             }
1066 
1067             if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
1068                 doWrite(false);
1069             }
1070         }
1071 
1072         @Override
close()1073         public void close() throws IOException {
1074             synchronized (closeLock) {
1075                 if (closed) {
1076                     return;
1077                 }
1078                 closed = true;
1079             }
1080 
1081             doWrite(true);
1082         }
1083 
doWrite(boolean last)1084         private void doWrite(boolean last) throws IOException {
1085             if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
1086                 buffer.flip();
1087                 endpoint.sendMessageBlock(buffer, last);
1088                 buffer.clear();
1089             } else {
1090                 endpoint.stateMachine.complete(last);
1091             }
1092         }
1093     }
1094 
1095 
1096     private static class EncoderEntry {
1097 
1098         private final Class<?> clazz;
1099         private final Encoder encoder;
1100 
EncoderEntry(Class<?> clazz, Encoder encoder)1101         public EncoderEntry(Class<?> clazz, Encoder encoder) {
1102             this.clazz = clazz;
1103             this.encoder = encoder;
1104         }
1105 
getClazz()1106         public Class<?> getClazz() {
1107             return clazz;
1108         }
1109 
getEncoder()1110         public Encoder getEncoder() {
1111             return encoder;
1112         }
1113     }
1114 
1115 
    /**
     * Send states tracked by {@link StateMachine}. The {@code *Start()} methods
     * move from {@code OPEN} (or, for partial sends, the matching
     * {@code *_PARTIAL_READY} state) into a {@code *_WRITING} state.
     * {@code complete(true)} returns to {@code OPEN}; {@code complete(false)}
     * moves a partial write to its {@code *_PARTIAL_READY} state and leaves
     * stream/writer states unchanged.
     */
    private enum State {
        // Idle: accepted by every *Start() method.
        OPEN,
        // Entered via streamStart(); unchanged by complete(false).
        STREAM_WRITING,
        // Entered via writeStart(); unchanged by complete(false).
        WRITER_WRITING,
        // Entered via binaryPartialStart().
        BINARY_PARTIAL_WRITING,
        // A binary part completed with last == false; another part may start.
        BINARY_PARTIAL_READY,
        // Entered via binaryStart().
        BINARY_FULL_WRITING,
        // Entered via textPartialStart().
        TEXT_PARTIAL_WRITING,
        // A text part completed with last == false; another part may start.
        TEXT_PARTIAL_READY,
        // Entered via textStart().
        TEXT_FULL_WRITING
    }
1127 
1128 
1129     private static class StateMachine {
1130         private State state = State.OPEN;
1131 
streamStart()1132         public synchronized void streamStart() {
1133             checkState(State.OPEN);
1134             state = State.STREAM_WRITING;
1135         }
1136 
writeStart()1137         public synchronized void writeStart() {
1138             checkState(State.OPEN);
1139             state = State.WRITER_WRITING;
1140         }
1141 
binaryPartialStart()1142         public synchronized void binaryPartialStart() {
1143             checkState(State.OPEN, State.BINARY_PARTIAL_READY);
1144             state = State.BINARY_PARTIAL_WRITING;
1145         }
1146 
binaryStart()1147         public synchronized void binaryStart() {
1148             checkState(State.OPEN);
1149             state = State.BINARY_FULL_WRITING;
1150         }
1151 
textPartialStart()1152         public synchronized void textPartialStart() {
1153             checkState(State.OPEN, State.TEXT_PARTIAL_READY);
1154             state = State.TEXT_PARTIAL_WRITING;
1155         }
1156 
textStart()1157         public synchronized void textStart() {
1158             checkState(State.OPEN);
1159             state = State.TEXT_FULL_WRITING;
1160         }
1161 
complete(boolean last)1162         public synchronized void complete(boolean last) {
1163             if (last) {
1164                 checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING,
1165                         State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING,
1166                         State.STREAM_WRITING, State.WRITER_WRITING);
1167                 state = State.OPEN;
1168             } else {
1169                 checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING,
1170                         State.STREAM_WRITING, State.WRITER_WRITING);
1171                 if (state == State.TEXT_PARTIAL_WRITING) {
1172                     state = State.TEXT_PARTIAL_READY;
1173                 } else if (state == State.BINARY_PARTIAL_WRITING){
1174                     state = State.BINARY_PARTIAL_READY;
1175                 } else if (state == State.WRITER_WRITING) {
1176                     // NO-OP. Leave state as is.
1177                 } else if (state == State.STREAM_WRITING) {
1178                     // NO-OP. Leave state as is.
1179                 } else {
1180                     // Should never happen
1181                     // The if ... else ... blocks above should cover all states
1182                     // permitted by the preceding checkState() call
1183                     throw new IllegalStateException(
1184                             "BUG: This code should never be called");
1185                 }
1186             }
1187         }
1188 
checkState(State... required)1189         private void checkState(State... required) {
1190             for (State state : required) {
1191                 if (this.state == state) {
1192                     return;
1193                 }
1194             }
1195             throw new IllegalStateException(
1196                     sm.getString("wsRemoteEndpoint.wrongState", this.state));
1197         }
1198     }
1199 
1200 
1201     private static class StateUpdateSendHandler implements SendHandler {
1202 
1203         private final SendHandler handler;
1204         private final StateMachine stateMachine;
1205 
StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine)1206         public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) {
1207             this.handler = handler;
1208             this.stateMachine = stateMachine;
1209         }
1210 
1211         @Override
onResult(SendResult result)1212         public void onResult(SendResult result) {
1213             if (result.isOK()) {
1214                 stateMachine.complete(true);
1215             }
1216             handler.onResult(result);
1217         }
1218     }
1219 
1220 
1221     private static class BlockingSendHandler implements SendHandler {
1222 
1223         private SendResult sendResult = null;
1224 
1225         @Override
onResult(SendResult result)1226         public void onResult(SendResult result) {
1227             sendResult = result;
1228         }
1229 
getSendResult()1230         public SendResult getSendResult() {
1231             return sendResult;
1232         }
1233     }
1234 }
1235