xref: /unit/src/java/nginx/unit/websocket/WsWebSocketContainer.java (revision 1157:7ae152bda303)
1 /*
2  *  Licensed to the Apache Software Foundation (ASF) under one or more
3  *  contributor license agreements.  See the NOTICE file distributed with
4  *  this work for additional information regarding copyright ownership.
5  *  The ASF licenses this file to You under the Apache License, Version 2.0
6  *  (the "License"); you may not use this file except in compliance with
7  *  the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  *  Unless required by applicable law or agreed to in writing, software
12  *  distributed under the License is distributed on an "AS IS" BASIS,
13  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  *  See the License for the specific language governing permissions and
15  *  limitations under the License.
16  */
17 package nginx.unit.websocket;
18 
19 import java.io.EOFException;
20 import java.io.File;
21 import java.io.FileInputStream;
22 import java.io.IOException;
23 import java.io.InputStream;
24 import java.net.InetSocketAddress;
25 import java.net.Proxy;
26 import java.net.ProxySelector;
27 import java.net.SocketAddress;
28 import java.net.URI;
29 import java.net.URISyntaxException;
30 import java.nio.ByteBuffer;
31 import java.nio.channels.AsynchronousChannelGroup;
32 import java.nio.channels.AsynchronousSocketChannel;
33 import java.nio.charset.StandardCharsets;
34 import java.security.KeyStore;
35 import java.util.ArrayList;
36 import java.util.Arrays;
37 import java.util.Collections;
38 import java.util.HashMap;
39 import java.util.HashSet;
40 import java.util.List;
41 import java.util.Locale;
42 import java.util.Map;
43 import java.util.Map.Entry;
44 import java.util.Random;
45 import java.util.Set;
46 import java.util.concurrent.ConcurrentHashMap;
47 import java.util.concurrent.ExecutionException;
48 import java.util.concurrent.Future;
49 import java.util.concurrent.TimeUnit;
50 import java.util.concurrent.TimeoutException;
51 
52 import javax.net.ssl.SSLContext;
53 import javax.net.ssl.SSLEngine;
54 import javax.net.ssl.SSLException;
55 import javax.net.ssl.SSLParameters;
56 import javax.net.ssl.TrustManagerFactory;
57 import javax.websocket.ClientEndpoint;
58 import javax.websocket.ClientEndpointConfig;
59 import javax.websocket.CloseReason;
60 import javax.websocket.CloseReason.CloseCodes;
61 import javax.websocket.DeploymentException;
62 import javax.websocket.Endpoint;
63 import javax.websocket.Extension;
64 import javax.websocket.HandshakeResponse;
65 import javax.websocket.Session;
66 import javax.websocket.WebSocketContainer;
67 
68 import org.apache.juli.logging.Log;
69 import org.apache.juli.logging.LogFactory;
70 import org.apache.tomcat.InstanceManager;
71 import org.apache.tomcat.util.buf.StringUtils;
72 import org.apache.tomcat.util.codec.binary.Base64;
73 import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
74 import org.apache.tomcat.util.res.StringManager;
75 import nginx.unit.websocket.pojo.PojoEndpointClient;
76 
77 public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess {
78 
    private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class);
    // Source of the 16 random bytes used for each Sec-WebSocket-Key (see generateWsKeyValue()).
    private static final Random RANDOM = new Random();
    // Carriage return + line feed, the HTTP line terminator.
    private static final byte[] CRLF = new byte[] { 13, 10 };

    // Pre-encoded fragments of the handshake request line, used by createRequest().
    private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1);
    private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1);
    private static final byte[] HTTP_VERSION_BYTES =
            " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1);

    // Channel group handed to AsynchronousSocketChannel.open() for every client
    // connection. Presumably created lazily by getAsynchronousChannelGroup()
    // under asynchronousChannelGroupLock (not visible here) - TODO confirm.
    private volatile AsynchronousChannelGroup asynchronousChannelGroup = null;
    private final Object asynchronousChannelGroupLock = new Object();

    // NOTE(review): presumably non-static so the logger is resolved per
    // instance/context rather than once per class loader - confirm.
    private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static
    // Open sessions grouped by endpoint. Plain HashMap: every access in this
    // view is guarded by synchronized (endPointSessionMapLock).
    private final Map<Endpoint, Set<WsSession>> endpointSessionMap =
            new HashMap<>();
    // Concurrent set of all registered sessions (each session maps to itself).
    private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>();
    private final Object endPointSessionMapLock = new Object();

    // Container-wide defaults; their accessors are outside this view.
    private long defaultAsyncTimeout = -1;
    private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
    private volatile long defaultMaxSessionIdleTimeout = 0;
    // Presumably state for the BackgroundProcess implementation (not visible here).
    private int backgroundProcessCount = 0;
    private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD;

104     private InstanceManager instanceManager;
105 
getInstanceManager()106     InstanceManager getInstanceManager() {
107         return instanceManager;
108     }
109 
setInstanceManager(InstanceManager instanceManager)110     protected void setInstanceManager(InstanceManager instanceManager) {
111         this.instanceManager = instanceManager;
112     }
113 
114     @Override
connectToServer(Object pojo, URI path)115     public Session connectToServer(Object pojo, URI path)
116             throws DeploymentException {
117 
118         ClientEndpoint annotation =
119                 pojo.getClass().getAnnotation(ClientEndpoint.class);
120         if (annotation == null) {
121             throw new DeploymentException(
122                     sm.getString("wsWebSocketContainer.missingAnnotation",
123                             pojo.getClass().getName()));
124         }
125 
126         Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders()));
127 
128         Class<? extends ClientEndpointConfig.Configurator> configuratorClazz =
129                 annotation.configurator();
130 
131         ClientEndpointConfig.Configurator configurator = null;
132         if (!ClientEndpointConfig.Configurator.class.equals(
133                 configuratorClazz)) {
134             try {
135                 configurator = configuratorClazz.getConstructor().newInstance();
136             } catch (ReflectiveOperationException e) {
137                 throw new DeploymentException(sm.getString(
138                         "wsWebSocketContainer.defaultConfiguratorFail"), e);
139             }
140         }
141 
142         ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create();
143         // Avoid NPE when using RI API JAR - see BZ 56343
144         if (configurator != null) {
145             builder.configurator(configurator);
146         }
147         ClientEndpointConfig config = builder.
148                 decoders(Arrays.asList(annotation.decoders())).
149                 encoders(Arrays.asList(annotation.encoders())).
150                 preferredSubprotocols(Arrays.asList(annotation.subprotocols())).
151                 build();
152         return connectToServer(ep, config, path);
153     }
154 
155 
156     @Override
connectToServer(Class<?> annotatedEndpointClass, URI path)157     public Session connectToServer(Class<?> annotatedEndpointClass, URI path)
158             throws DeploymentException {
159 
160         Object pojo;
161         try {
162             pojo = annotatedEndpointClass.getConstructor().newInstance();
163         } catch (ReflectiveOperationException e) {
164             throw new DeploymentException(sm.getString(
165                     "wsWebSocketContainer.endpointCreateFail",
166                     annotatedEndpointClass.getName()), e);
167         }
168 
169         return connectToServer(pojo, path);
170     }
171 
172 
173     @Override
connectToServer(Class<? extends Endpoint> clazz, ClientEndpointConfig clientEndpointConfiguration, URI path)174     public Session connectToServer(Class<? extends Endpoint> clazz,
175             ClientEndpointConfig clientEndpointConfiguration, URI path)
176             throws DeploymentException {
177 
178         Endpoint endpoint;
179         try {
180             endpoint = clazz.getConstructor().newInstance();
181         } catch (ReflectiveOperationException e) {
182             throw new DeploymentException(sm.getString(
183                     "wsWebSocketContainer.endpointCreateFail", clazz.getName()),
184                     e);
185         }
186 
187         return connectToServer(endpoint, clientEndpointConfiguration, path);
188     }
189 
190 
191     @Override
connectToServer(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path)192     public Session connectToServer(Endpoint endpoint,
193             ClientEndpointConfig clientEndpointConfiguration, URI path)
194             throws DeploymentException {
195         return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>());
196     }
197 
connectToServerRecursive(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet)198     private Session connectToServerRecursive(Endpoint endpoint,
199             ClientEndpointConfig clientEndpointConfiguration, URI path,
200             Set<URI> redirectSet)
201             throws DeploymentException {
202 
203         boolean secure = false;
204         ByteBuffer proxyConnect = null;
205         URI proxyPath;
206 
207         // Validate scheme (and build proxyPath)
208         String scheme = path.getScheme();
209         if ("ws".equalsIgnoreCase(scheme)) {
210             proxyPath = URI.create("http" + path.toString().substring(2));
211         } else if ("wss".equalsIgnoreCase(scheme)) {
212             proxyPath = URI.create("https" + path.toString().substring(3));
213             secure = true;
214         } else {
215             throw new DeploymentException(sm.getString(
216                     "wsWebSocketContainer.pathWrongScheme", scheme));
217         }
218 
219         // Validate host
220         String host = path.getHost();
221         if (host == null) {
222             throw new DeploymentException(
223                     sm.getString("wsWebSocketContainer.pathNoHost"));
224         }
225         int port = path.getPort();
226 
227         SocketAddress sa = null;
228 
229         // Check to see if a proxy is configured. Javadoc indicates return value
230         // will never be null
231         List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath);
232         Proxy selectedProxy = null;
233         for (Proxy proxy : proxies) {
234             if (proxy.type().equals(Proxy.Type.HTTP)) {
235                 sa = proxy.address();
236                 if (sa instanceof InetSocketAddress) {
237                     InetSocketAddress inet = (InetSocketAddress) sa;
238                     if (inet.isUnresolved()) {
239                         sa = new InetSocketAddress(inet.getHostName(), inet.getPort());
240                     }
241                 }
242                 selectedProxy = proxy;
243                 break;
244             }
245         }
246 
247         // If the port is not explicitly specified, compute it based on the
248         // scheme
249         if (port == -1) {
250             if ("ws".equalsIgnoreCase(scheme)) {
251                 port = 80;
252             } else {
253                 // Must be wss due to scheme validation above
254                 port = 443;
255             }
256         }
257 
258         // If sa is null, no proxy is configured so need to create sa
259         if (sa == null) {
260             sa = new InetSocketAddress(host, port);
261         } else {
262             proxyConnect = createProxyRequest(host, port);
263         }
264 
265         // Create the initial HTTP request to open the WebSocket connection
266         Map<String, List<String>> reqHeaders = createRequestHeaders(host, port,
267                 clientEndpointConfiguration);
268         clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders);
269         if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null
270                 && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) {
271             List<String> originValues = new ArrayList<>(1);
272             originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE);
273             reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues);
274         }
275         ByteBuffer request = createRequest(path, reqHeaders);
276 
277         AsynchronousSocketChannel socketChannel;
278         try {
279             socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup());
280         } catch (IOException ioe) {
281             throw new DeploymentException(sm.getString(
282                     "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe);
283         }
284 
285         Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties();
286 
287         // Get the connection timeout
288         long timeout = Constants.IO_TIMEOUT_MS_DEFAULT;
289         String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY);
290         if (timeoutValue != null) {
291             timeout = Long.valueOf(timeoutValue).intValue();
292         }
293 
294         // Set-up
295         // Same size as the WsFrame input buffer
296         ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize());
297         String subProtocol;
298         boolean success = false;
299         List<Extension> extensionsAgreed = new ArrayList<>();
300         Transformation transformation = null;
301 
302         // Open the connection
303         Future<Void> fConnect = socketChannel.connect(sa);
304         AsyncChannelWrapper channel = null;
305 
306         if (proxyConnect != null) {
307             try {
308                 fConnect.get(timeout, TimeUnit.MILLISECONDS);
309                 // Proxy CONNECT is clear text
310                 channel = new AsyncChannelWrapperNonSecure(socketChannel);
311                 writeRequest(channel, proxyConnect, timeout);
312                 HttpResponse httpResponse = processResponse(response, channel, timeout);
313                 if (httpResponse.getStatus() != 200) {
314                     throw new DeploymentException(sm.getString(
315                             "wsWebSocketContainer.proxyConnectFail", selectedProxy,
316                             Integer.toString(httpResponse.getStatus())));
317                 }
318             } catch (TimeoutException | InterruptedException | ExecutionException |
319                     EOFException e) {
320                 if (channel != null) {
321                     channel.close();
322                 }
323                 throw new DeploymentException(
324                         sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
325             }
326         }
327 
328         if (secure) {
329             // Regardless of whether a non-secure wrapper was created for a
330             // proxy CONNECT, need to use TLS from this point on so wrap the
331             // original AsynchronousSocketChannel
332             SSLEngine sslEngine = createSSLEngine(userProperties, host, port);
333             channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
334         } else if (channel == null) {
335             // Only need to wrap as this point if it wasn't wrapped to process a
336             // proxy CONNECT
337             channel = new AsyncChannelWrapperNonSecure(socketChannel);
338         }
339 
340         try {
341             fConnect.get(timeout, TimeUnit.MILLISECONDS);
342 
343             Future<Void> fHandshake = channel.handshake();
344             fHandshake.get(timeout, TimeUnit.MILLISECONDS);
345 
346             writeRequest(channel, request, timeout);
347 
348             HttpResponse httpResponse = processResponse(response, channel, timeout);
349 
350             // Check maximum permitted redirects
351             int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT;
352             String maxRedirectsValue =
353                     (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY);
354             if (maxRedirectsValue != null) {
355                 maxRedirects = Integer.parseInt(maxRedirectsValue);
356             }
357 
358             if (httpResponse.status != 101) {
359                 if(isRedirectStatus(httpResponse.status)){
360                     List<String> locationHeader =
361                             httpResponse.getHandshakeResponse().getHeaders().get(
362                                     Constants.LOCATION_HEADER_NAME);
363 
364                     if (locationHeader == null || locationHeader.isEmpty() ||
365                             locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) {
366                         throw new DeploymentException(sm.getString(
367                                 "wsWebSocketContainer.missingLocationHeader",
368                                 Integer.toString(httpResponse.status)));
369                     }
370 
371                     URI redirectLocation = URI.create(locationHeader.get(0)).normalize();
372 
373                     if (!redirectLocation.isAbsolute()) {
374                         redirectLocation = path.resolve(redirectLocation);
375                     }
376 
377                     String redirectScheme = redirectLocation.getScheme().toLowerCase();
378 
379                     if (redirectScheme.startsWith("http")) {
380                         redirectLocation = new URI(redirectScheme.replace("http", "ws"),
381                                 redirectLocation.getUserInfo(), redirectLocation.getHost(),
382                                 redirectLocation.getPort(), redirectLocation.getPath(),
383                                 redirectLocation.getQuery(), redirectLocation.getFragment());
384                     }
385 
386                     if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) {
387                         throw new DeploymentException(sm.getString(
388                                 "wsWebSocketContainer.redirectThreshold", redirectLocation,
389                                 Integer.toString(redirectSet.size()),
390                                 Integer.toString(maxRedirects)));
391                     }
392 
393                     return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet);
394 
395                 }
396 
397                 else if (httpResponse.status == 401) {
398 
399                     if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
400                         throw new DeploymentException(sm.getString(
401                                 "wsWebSocketContainer.failedAuthentication",
402                                 Integer.valueOf(httpResponse.status)));
403                     }
404 
405                     List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse()
406                             .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME);
407 
408                     if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() ||
409                             wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) {
410                         throw new DeploymentException(sm.getString(
411                                 "wsWebSocketContainer.missingWWWAuthenticateHeader",
412                                 Integer.toString(httpResponse.status)));
413                     }
414 
415                     String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0];
416                     String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1)
417                             .split("\\s", 3)[1];
418 
419                     Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme);
420 
421                     if (auth == null) {
422                         throw new DeploymentException(
423                                 sm.getString("wsWebSocketContainer.unsupportedAuthScheme",
424                                         Integer.valueOf(httpResponse.status), authScheme));
425                     }
426 
427                     userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization(
428                             requestUri, wwwAuthenticateHeaders.get(0), userProperties));
429 
430                     return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet);
431 
432                 }
433 
434                 else {
435                     throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus",
436                             Integer.toString(httpResponse.status)));
437                 }
438             }
439             HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse();
440             clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse);
441 
442             // Sub-protocol
443             List<String> protocolHeaders = handshakeResponse.getHeaders().get(
444                     Constants.WS_PROTOCOL_HEADER_NAME);
445             if (protocolHeaders == null || protocolHeaders.size() == 0) {
446                 subProtocol = null;
447             } else if (protocolHeaders.size() == 1) {
448                 subProtocol = protocolHeaders.get(0);
449             } else {
450                 throw new DeploymentException(
451                         sm.getString("wsWebSocketContainer.invalidSubProtocol"));
452             }
453 
454             // Extensions
455             // Should normally only be one header but handle the case of
456             // multiple headers
457             List<String> extHeaders = handshakeResponse.getHeaders().get(
458                     Constants.WS_EXTENSIONS_HEADER_NAME);
459             if (extHeaders != null) {
460                 for (String extHeader : extHeaders) {
461                     Util.parseExtensionHeader(extensionsAgreed, extHeader);
462                 }
463             }
464 
465             // Build the transformations
466             TransformationFactory factory = TransformationFactory.getInstance();
467             for (Extension extension : extensionsAgreed) {
468                 List<List<Extension.Parameter>> wrapper = new ArrayList<>(1);
469                 wrapper.add(extension.getParameters());
470                 Transformation t = factory.create(extension.getName(), wrapper, false);
471                 if (t == null) {
472                     throw new DeploymentException(sm.getString(
473                             "wsWebSocketContainer.invalidExtensionParameters"));
474                 }
475                 if (transformation == null) {
476                     transformation = t;
477                 } else {
478                     transformation.setNext(t);
479                 }
480             }
481 
482             success = true;
483         } catch (ExecutionException | InterruptedException | SSLException |
484                 EOFException | TimeoutException | URISyntaxException | AuthenticationException e) {
485             throw new DeploymentException(
486                     sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
487         } finally {
488             if (!success) {
489                 channel.close();
490             }
491         }
492 
493         // Switch to WebSocket
494         WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel);
495 
496         WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient,
497                 this, null, null, null, null, null, extensionsAgreed,
498                 subProtocol, Collections.<String,String>emptyMap(), secure,
499                 clientEndpointConfiguration, null);
500 
501         WsFrameClient wsFrameClient = new WsFrameClient(response, channel,
502                 wsSession, transformation);
503         // WsFrame adds the necessary final transformations. Copy the
504         // completed transformation chain to the remote end point.
505         wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation());
506 
507         endpoint.onOpen(wsSession, clientEndpointConfiguration);
508         registerSession(endpoint, wsSession);
509 
510         /* It is possible that the server sent one or more messages as soon as
511          * the WebSocket connection was established. Depending on the exact
512          * timing of when those messages were sent they could be sat in the
513          * input buffer waiting to be read and will not trigger a "data
514          * available to read" event. Therefore, it is necessary to process the
515          * input buffer here. Note that this happens on the current thread which
516          * means that this thread will be used for any onMessage notifications.
517          * This is a special case. Subsequent "data available to read" events
518          * will be handled by threads from the AsyncChannelGroup's executor.
519          */
520         wsFrameClient.startInputProcessing();
521 
522         return wsSession;
523     }
524 
525 
writeRequest(AsyncChannelWrapper channel, ByteBuffer request, long timeout)526     private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request,
527             long timeout) throws TimeoutException, InterruptedException, ExecutionException {
528         int toWrite = request.limit();
529 
530         Future<Integer> fWrite = channel.write(request);
531         Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
532         toWrite -= thisWrite.intValue();
533 
534         while (toWrite > 0) {
535             fWrite = channel.write(request);
536             thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
537             toWrite -= thisWrite.intValue();
538         }
539     }
540 
541 
isRedirectStatus(int httpResponseCode)542     private static boolean isRedirectStatus(int httpResponseCode) {
543 
544         boolean isRedirect = false;
545 
546         switch (httpResponseCode) {
547         case Constants.MULTIPLE_CHOICES:
548         case Constants.MOVED_PERMANENTLY:
549         case Constants.FOUND:
550         case Constants.SEE_OTHER:
551         case Constants.USE_PROXY:
552         case Constants.TEMPORARY_REDIRECT:
553             isRedirect = true;
554             break;
555         default:
556             break;
557         }
558 
559         return isRedirect;
560     }
561 
562 
createProxyRequest(String host, int port)563     private static ByteBuffer createProxyRequest(String host, int port) {
564         StringBuilder request = new StringBuilder();
565         request.append("CONNECT ");
566         request.append(host);
567         request.append(':');
568         request.append(port);
569 
570         request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: ");
571         request.append(host);
572         request.append(':');
573         request.append(port);
574 
575         request.append("\r\n\r\n");
576 
577         byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1);
578         return ByteBuffer.wrap(bytes);
579     }
580 
registerSession(Endpoint endpoint, WsSession wsSession)581     protected void registerSession(Endpoint endpoint, WsSession wsSession) {
582 
583         if (!wsSession.isOpen()) {
584             // The session was closed during onOpen. No need to register it.
585             return;
586         }
587         synchronized (endPointSessionMapLock) {
588             if (endpointSessionMap.size() == 0) {
589                 BackgroundProcessManager.getInstance().register(this);
590             }
591             Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
592             if (wsSessions == null) {
593                 wsSessions = new HashSet<>();
594                 endpointSessionMap.put(endpoint, wsSessions);
595             }
596             wsSessions.add(wsSession);
597         }
598         sessions.put(wsSession, wsSession);
599     }
600 
601 
unregisterSession(Endpoint endpoint, WsSession wsSession)602     protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
603 
604         synchronized (endPointSessionMapLock) {
605             Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
606             if (wsSessions != null) {
607                 wsSessions.remove(wsSession);
608                 if (wsSessions.size() == 0) {
609                     endpointSessionMap.remove(endpoint);
610                 }
611             }
612             if (endpointSessionMap.size() == 0) {
613                 BackgroundProcessManager.getInstance().unregister(this);
614             }
615         }
616         sessions.remove(wsSession);
617     }
618 
619 
getOpenSessions(Endpoint endpoint)620     Set<Session> getOpenSessions(Endpoint endpoint) {
621         HashSet<Session> result = new HashSet<>();
622         synchronized (endPointSessionMapLock) {
623             Set<WsSession> sessions = endpointSessionMap.get(endpoint);
624             if (sessions != null) {
625                 result.addAll(sessions);
626             }
627         }
628         return result;
629     }
630 
createRequestHeaders(String host, int port, ClientEndpointConfig clientEndpointConfiguration)631     private static Map<String, List<String>> createRequestHeaders(String host, int port,
632             ClientEndpointConfig clientEndpointConfiguration) {
633 
634         Map<String, List<String>> headers = new HashMap<>();
635         List<Extension> extensions = clientEndpointConfiguration.getExtensions();
636         List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols();
637         Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties();
638 
639         if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
640             List<String> authValues = new ArrayList<>(1);
641             authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME));
642             headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues);
643         }
644 
645         // Host header
646         List<String> hostValues = new ArrayList<>(1);
647         if (port == -1) {
648             hostValues.add(host);
649         } else {
650             hostValues.add(host + ':' + port);
651         }
652 
653         headers.put(Constants.HOST_HEADER_NAME, hostValues);
654 
655         // Upgrade header
656         List<String> upgradeValues = new ArrayList<>(1);
657         upgradeValues.add(Constants.UPGRADE_HEADER_VALUE);
658         headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues);
659 
660         // Connection header
661         List<String> connectionValues = new ArrayList<>(1);
662         connectionValues.add(Constants.CONNECTION_HEADER_VALUE);
663         headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues);
664 
665         // WebSocket version header
666         List<String> wsVersionValues = new ArrayList<>(1);
667         wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE);
668         headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues);
669 
670         // WebSocket key
671         List<String> wsKeyValues = new ArrayList<>(1);
672         wsKeyValues.add(generateWsKeyValue());
673         headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues);
674 
675         // WebSocket sub-protocols
676         if (subProtocols != null && subProtocols.size() > 0) {
677             headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols);
678         }
679 
680         // WebSocket extensions
681         if (extensions != null && extensions.size() > 0) {
682             headers.put(Constants.WS_EXTENSIONS_HEADER_NAME,
683                     generateExtensionHeaders(extensions));
684         }
685 
686         return headers;
687     }
688 
689 
generateExtensionHeaders(List<Extension> extensions)690     private static List<String> generateExtensionHeaders(List<Extension> extensions) {
691         List<String> result = new ArrayList<>(extensions.size());
692         for (Extension extension : extensions) {
693             StringBuilder header = new StringBuilder();
694             header.append(extension.getName());
695             for (Extension.Parameter param : extension.getParameters()) {
696                 header.append(';');
697                 header.append(param.getName());
698                 String value = param.getValue();
699                 if (value != null && value.length() > 0) {
700                     header.append('=');
701                     header.append(value);
702                 }
703             }
704             result.add(header.toString());
705         }
706         return result;
707     }
708 
709 
generateWsKeyValue()710     private static String generateWsKeyValue() {
711         byte[] keyBytes = new byte[16];
712         RANDOM.nextBytes(keyBytes);
713         return Base64.encodeBase64String(keyBytes);
714     }
715 
716 
createRequest(URI uri, Map<String,List<String>> reqHeaders)717     private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) {
718         ByteBuffer result = ByteBuffer.allocate(4 * 1024);
719 
720         // Request line
721         result.put(GET_BYTES);
722         if (null == uri.getPath() || "".equals(uri.getPath())) {
723             result.put(ROOT_URI_BYTES);
724         } else {
725             result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1));
726         }
727         String query = uri.getRawQuery();
728         if (query != null) {
729             result.put((byte) '?');
730             result.put(query.getBytes(StandardCharsets.ISO_8859_1));
731         }
732         result.put(HTTP_VERSION_BYTES);
733 
734         // Headers
735         for (Entry<String, List<String>> entry : reqHeaders.entrySet()) {
736             result = addHeader(result, entry.getKey(), entry.getValue());
737         }
738 
739         // Terminating CRLF
740         result.put(CRLF);
741 
742         result.flip();
743 
744         return result;
745     }
746 
747 
addHeader(ByteBuffer result, String key, List<String> values)748     private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) {
749         if (values.isEmpty()) {
750             return result;
751         }
752 
753         result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1));
754         result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1));
755         result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1));
756         result = putWithExpand(result, CRLF);
757 
758         return result;
759     }
760 
761 
putWithExpand(ByteBuffer input, byte[] bytes)762     private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) {
763         if (bytes.length > input.remaining()) {
764             int newSize;
765             if (bytes.length > input.capacity()) {
766                 newSize = 2 * bytes.length;
767             } else {
768                 newSize = input.capacity() * 2;
769             }
770             ByteBuffer expanded = ByteBuffer.allocate(newSize);
771             input.flip();
772             expanded.put(input);
773             input = expanded;
774         }
775         return input.put(bytes);
776     }
777 
778 
779     /**
780      * Process response, blocking until HTTP response has been fully received.
781      * @throws ExecutionException
782      * @throws InterruptedException
783      * @throws DeploymentException
784      * @throws TimeoutException
785      */
processResponse(ByteBuffer response, AsyncChannelWrapper channel, long timeout)786     private HttpResponse processResponse(ByteBuffer response,
787             AsyncChannelWrapper channel, long timeout) throws InterruptedException,
788             ExecutionException, DeploymentException, EOFException,
789             TimeoutException {
790 
791         Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();
792 
793         int status = 0;
794         boolean readStatus = false;
795         boolean readHeaders = false;
796         String line = null;
797         while (!readHeaders) {
798             // On entering loop buffer will be empty and at the start of a new
799             // loop the buffer will have been fully read.
800             response.clear();
801             // Blocking read
802             Future<Integer> read = channel.read(response);
803             Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS);
804             if (bytesRead.intValue() == -1) {
805                 throw new EOFException();
806             }
807             response.flip();
808             while (response.hasRemaining() && !readHeaders) {
809                 if (line == null) {
810                     line = readLine(response);
811                 } else {
812                     line += readLine(response);
813                 }
814                 if ("\r\n".equals(line)) {
815                     readHeaders = true;
816                 } else if (line.endsWith("\r\n")) {
817                     if (readStatus) {
818                         parseHeaders(line, headers);
819                     } else {
820                         status = parseStatus(line);
821                         readStatus = true;
822                     }
823                     line = null;
824                 }
825             }
826         }
827 
828         return new HttpResponse(status, new WsHandshakeResponse(headers));
829     }
830 
831 
parseStatus(String line)832     private int parseStatus(String line) throws DeploymentException {
833         // This client only understands HTTP 1.
834         // RFC2616 is case specific
835         String[] parts = line.trim().split(" ");
836         // CONNECT for proxy may return a 1.0 response
837         if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) {
838             throw new DeploymentException(sm.getString(
839                     "wsWebSocketContainer.invalidStatus", line));
840         }
841         try {
842             return Integer.parseInt(parts[1]);
843         } catch (NumberFormatException nfe) {
844             throw new DeploymentException(sm.getString(
845                     "wsWebSocketContainer.invalidStatus", line));
846         }
847     }
848 
849 
parseHeaders(String line, Map<String,List<String>> headers)850     private void parseHeaders(String line, Map<String,List<String>> headers) {
851         // Treat headers as single values by default.
852 
853         int index = line.indexOf(':');
854         if (index == -1) {
855             log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line));
856             return;
857         }
858         // Header names are case insensitive so always use lower case
859         String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH);
860         // Multi-value headers are stored as a single header and the client is
861         // expected to handle splitting into individual values
862         String headerValue = line.substring(index + 1).trim();
863 
864         List<String> values = headers.get(headerName);
865         if (values == null) {
866             values = new ArrayList<>(1);
867             headers.put(headerName, values);
868         }
869         values.add(headerValue);
870     }
871 
readLine(ByteBuffer response)872     private String readLine(ByteBuffer response) {
873         // All ISO-8859-1
874         StringBuilder sb = new StringBuilder();
875 
876         char c = 0;
877         while (response.hasRemaining()) {
878             c = (char) response.get();
879             sb.append(c);
880             if (c == 10) {
881                 break;
882             }
883         }
884 
885         return sb.toString();
886     }
887 
888 
createSSLEngine(Map<String,Object> userProperties, String host, int port)889     private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port)
890             throws DeploymentException {
891 
892         try {
893             // See if a custom SSLContext has been provided
894             SSLContext sslContext =
895                     (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY);
896 
897             if (sslContext == null) {
898                 // Create the SSL Context
899                 sslContext = SSLContext.getInstance("TLS");
900 
901                 // Trust store
902                 String sslTrustStoreValue =
903                         (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY);
904                 if (sslTrustStoreValue != null) {
905                     String sslTrustStorePwdValue = (String) userProperties.get(
906                             Constants.SSL_TRUSTSTORE_PWD_PROPERTY);
907                     if (sslTrustStorePwdValue == null) {
908                         sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT;
909                     }
910 
911                     File keyStoreFile = new File(sslTrustStoreValue);
912                     KeyStore ks = KeyStore.getInstance("JKS");
913                     try (InputStream is = new FileInputStream(keyStoreFile)) {
914                         ks.load(is, sslTrustStorePwdValue.toCharArray());
915                     }
916 
917                     TrustManagerFactory tmf = TrustManagerFactory.getInstance(
918                             TrustManagerFactory.getDefaultAlgorithm());
919                     tmf.init(ks);
920 
921                     sslContext.init(null, tmf.getTrustManagers(), null);
922                 } else {
923                     sslContext.init(null, null, null);
924                 }
925             }
926 
927             SSLEngine engine = sslContext.createSSLEngine(host, port);
928 
929             String sslProtocolsValue =
930                     (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY);
931             if (sslProtocolsValue != null) {
932                 engine.setEnabledProtocols(sslProtocolsValue.split(","));
933             }
934 
935             engine.setUseClientMode(true);
936 
937             // Enable host verification
938             // Start with current settings (returns a copy)
939             SSLParameters sslParams = engine.getSSLParameters();
940             // Use HTTPS since WebSocket starts over HTTP(S)
941             sslParams.setEndpointIdentificationAlgorithm("HTTPS");
942             // Write the parameters back
943             engine.setSSLParameters(sslParams);
944 
945             return engine;
946         } catch (Exception e) {
947             throw new DeploymentException(sm.getString(
948                     "wsWebSocketContainer.sslEngineFail"), e);
949         }
950     }
951 
952 
953     @Override
getDefaultMaxSessionIdleTimeout()954     public long getDefaultMaxSessionIdleTimeout() {
955         return defaultMaxSessionIdleTimeout;
956     }
957 
958 
959     @Override
setDefaultMaxSessionIdleTimeout(long timeout)960     public void setDefaultMaxSessionIdleTimeout(long timeout) {
961         this.defaultMaxSessionIdleTimeout = timeout;
962     }
963 
964 
965     @Override
getDefaultMaxBinaryMessageBufferSize()966     public int getDefaultMaxBinaryMessageBufferSize() {
967         return maxBinaryMessageBufferSize;
968     }
969 
970 
971     @Override
setDefaultMaxBinaryMessageBufferSize(int max)972     public void setDefaultMaxBinaryMessageBufferSize(int max) {
973         maxBinaryMessageBufferSize = max;
974     }
975 
976 
977     @Override
getDefaultMaxTextMessageBufferSize()978     public int getDefaultMaxTextMessageBufferSize() {
979         return maxTextMessageBufferSize;
980     }
981 
982 
983     @Override
setDefaultMaxTextMessageBufferSize(int max)984     public void setDefaultMaxTextMessageBufferSize(int max) {
985         maxTextMessageBufferSize = max;
986     }
987 
988 
989     /**
990      * {@inheritDoc}
991      *
992      * Currently, this implementation does not support any extensions.
993      */
994     @Override
getInstalledExtensions()995     public Set<Extension> getInstalledExtensions() {
996         return Collections.emptySet();
997     }
998 
999 
1000     /**
1001      * {@inheritDoc}
1002      *
1003      * The default value for this implementation is -1.
1004      */
1005     @Override
getDefaultAsyncSendTimeout()1006     public long getDefaultAsyncSendTimeout() {
1007         return defaultAsyncTimeout;
1008     }
1009 
1010 
1011     /**
1012      * {@inheritDoc}
1013      *
1014      * The default value for this implementation is -1.
1015      */
1016     @Override
setAsyncSendTimeout(long timeout)1017     public void setAsyncSendTimeout(long timeout) {
1018         this.defaultAsyncTimeout = timeout;
1019     }
1020 
1021 
1022     /**
1023      * Cleans up the resources still in use by WebSocket sessions created from
1024      * this container. This includes closing sessions and cancelling
1025      * {@link Future}s associated with blocking read/writes.
1026      */
destroy()1027     public void destroy() {
1028         CloseReason cr = new CloseReason(
1029                 CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown"));
1030 
1031         for (WsSession session : sessions.keySet()) {
1032             try {
1033                 session.close(cr);
1034             } catch (IOException ioe) {
1035                 log.debug(sm.getString(
1036                         "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe);
1037             }
1038         }
1039 
1040         // Only unregister with AsyncChannelGroupUtil if this instance
1041         // registered with it
1042         if (asynchronousChannelGroup != null) {
1043             synchronized (asynchronousChannelGroupLock) {
1044                 if (asynchronousChannelGroup != null) {
1045                     AsyncChannelGroupUtil.unregister();
1046                     asynchronousChannelGroup = null;
1047                 }
1048             }
1049         }
1050     }
1051 
1052 
getAsynchronousChannelGroup()1053     private AsynchronousChannelGroup getAsynchronousChannelGroup() {
1054         // Use AsyncChannelGroupUtil to share a common group amongst all
1055         // WebSocket clients
1056         AsynchronousChannelGroup result = asynchronousChannelGroup;
1057         if (result == null) {
1058             synchronized (asynchronousChannelGroupLock) {
1059                 if (asynchronousChannelGroup == null) {
1060                     asynchronousChannelGroup = AsyncChannelGroupUtil.register();
1061                 }
1062                 result = asynchronousChannelGroup;
1063             }
1064         }
1065         return result;
1066     }
1067 
1068 
1069     // ----------------------------------------------- BackgroundProcess methods
1070 
1071     @Override
backgroundProcess()1072     public void backgroundProcess() {
1073         // This method gets called once a second.
1074         backgroundProcessCount ++;
1075         if (backgroundProcessCount >= processPeriod) {
1076             backgroundProcessCount = 0;
1077 
1078             for (WsSession wsSession : sessions.keySet()) {
1079                 wsSession.checkExpiration();
1080             }
1081         }
1082 
1083     }
1084 
1085 
1086     @Override
setProcessPeriod(int period)1087     public void setProcessPeriod(int period) {
1088         this.processPeriod = period;
1089     }
1090 
1091 
1092     /**
1093      * {@inheritDoc}
1094      *
1095      * The default value is 10 which means session expirations are processed
1096      * every 10 seconds.
1097      */
1098     @Override
getProcessPeriod()1099     public int getProcessPeriod() {
1100         return processPeriod;
1101     }
1102 
1103 
1104     private static class HttpResponse {
1105         private final int status;
1106         private final HandshakeResponse handshakeResponse;
1107 
HttpResponse(int status, HandshakeResponse handshakeResponse)1108         public HttpResponse(int status, HandshakeResponse handshakeResponse) {
1109             this.status = status;
1110             this.handshakeResponse = handshakeResponse;
1111         }
1112 
1113 
getStatus()1114         public int getStatus() {
1115             return status;
1116         }
1117 
1118 
getHandshakeResponse()1119         public HandshakeResponse getHandshakeResponse() {
1120             return handshakeResponse;
1121         }
1122     }
1123 }
1124