1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket; 18 19 import java.io.EOFException; 20 import java.io.File; 21 import java.io.FileInputStream; 22 import java.io.IOException; 23 import java.io.InputStream; 24 import java.net.InetSocketAddress; 25 import java.net.Proxy; 26 import java.net.ProxySelector; 27 import java.net.SocketAddress; 28 import java.net.URI; 29 import java.net.URISyntaxException; 30 import java.nio.ByteBuffer; 31 import java.nio.channels.AsynchronousChannelGroup; 32 import java.nio.channels.AsynchronousSocketChannel; 33 import java.nio.charset.StandardCharsets; 34 import java.security.KeyStore; 35 import java.util.ArrayList; 36 import java.util.Arrays; 37 import java.util.Collections; 38 import java.util.HashMap; 39 import java.util.HashSet; 40 import java.util.List; 41 import java.util.Locale; 42 import java.util.Map; 43 import java.util.Map.Entry; 44 import java.util.Random; 45 import java.util.Set; 46 import java.util.concurrent.ConcurrentHashMap; 47 import java.util.concurrent.ExecutionException; 48 import java.util.concurrent.Future; 49 import java.util.concurrent.TimeUnit; 50 import java.util.concurrent.TimeoutException; 51 52 import javax.net.ssl.SSLContext; 53 import javax.net.ssl.SSLEngine; 54 import javax.net.ssl.SSLException; 55 import javax.net.ssl.SSLParameters; 56 import javax.net.ssl.TrustManagerFactory; 57 import javax.websocket.ClientEndpoint; 58 import javax.websocket.ClientEndpointConfig; 59 import javax.websocket.CloseReason; 60 import javax.websocket.CloseReason.CloseCodes; 61 import javax.websocket.DeploymentException; 62 import javax.websocket.Endpoint; 63 import javax.websocket.Extension; 64 import javax.websocket.HandshakeResponse; 65 import javax.websocket.Session; 66 import javax.websocket.WebSocketContainer; 67 68 import org.apache.juli.logging.Log; 69 import org.apache.juli.logging.LogFactory; 70 import org.apache.tomcat.InstanceManager; 71 import org.apache.tomcat.util.buf.StringUtils; 72 import org.apache.tomcat.util.codec.binary.Base64; 73 import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; 74 import org.apache.tomcat.util.res.StringManager; 75 import nginx.unit.websocket.pojo.PojoEndpointClient; 76 77 public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { 78 79 private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class); 80 private static final Random RANDOM = new Random(); 81 private static final byte[] CRLF = new byte[] { 13, 10 }; 82 83 private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1); 84 private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1); 85 private static final byte[] HTTP_VERSION_BYTES = 86 " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1); 87 88 private volatile AsynchronousChannelGroup asynchronousChannelGroup = null; 89 private final Object asynchronousChannelGroupLock = new Object(); 90 91 private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static 92 private final Map<Endpoint, Set<WsSession>> endpointSessionMap = 93 new HashMap<>(); 94 private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>(); 95 private final Object endPointSessionMapLock = new Object(); 96 97 private long defaultAsyncTimeout = -1; 98 private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; 99 private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; 100 private volatile long defaultMaxSessionIdleTimeout = 0; 101 private int backgroundProcessCount = 0; 102 private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD; 103 104 private InstanceManager instanceManager; 105 getInstanceManager()106 InstanceManager getInstanceManager() { 107 return instanceManager; 108 } 109 setInstanceManager(InstanceManager instanceManager)110 protected void setInstanceManager(InstanceManager instanceManager) { 111 this.instanceManager = instanceManager; 112 } 113 114 @Override connectToServer(Object pojo, URI path)115 public Session connectToServer(Object pojo, URI path) 116 throws DeploymentException { 117 118 ClientEndpoint annotation = 119 pojo.getClass().getAnnotation(ClientEndpoint.class); 120 if (annotation == null) { 121 throw new DeploymentException( 122 sm.getString("wsWebSocketContainer.missingAnnotation", 123 pojo.getClass().getName())); 124 } 125 126 Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders())); 127 128 Class<? extends ClientEndpointConfig.Configurator> configuratorClazz = 129 annotation.configurator(); 130 131 ClientEndpointConfig.Configurator configurator = null; 132 if (!ClientEndpointConfig.Configurator.class.equals( 133 configuratorClazz)) { 134 try { 135 configurator = configuratorClazz.getConstructor().newInstance(); 136 } catch (ReflectiveOperationException e) { 137 throw new DeploymentException(sm.getString( 138 "wsWebSocketContainer.defaultConfiguratorFail"), e); 139 } 140 } 141 142 ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create(); 143 // Avoid NPE when using RI API JAR - see BZ 56343 144 if (configurator != null) { 145 builder.configurator(configurator); 146 } 147 ClientEndpointConfig config = builder. 148 decoders(Arrays.asList(annotation.decoders())). 149 encoders(Arrays.asList(annotation.encoders())). 150 preferredSubprotocols(Arrays.asList(annotation.subprotocols())). 151 build(); 152 return connectToServer(ep, config, path); 153 } 154 155 156 @Override connectToServer(Class<?> annotatedEndpointClass, URI path)157 public Session connectToServer(Class<?> annotatedEndpointClass, URI path) 158 throws DeploymentException { 159 160 Object pojo; 161 try { 162 pojo = annotatedEndpointClass.getConstructor().newInstance(); 163 } catch (ReflectiveOperationException e) { 164 throw new DeploymentException(sm.getString( 165 "wsWebSocketContainer.endpointCreateFail", 166 annotatedEndpointClass.getName()), e); 167 } 168 169 return connectToServer(pojo, path); 170 } 171 172 173 @Override connectToServer(Class<? extends Endpoint> clazz, ClientEndpointConfig clientEndpointConfiguration, URI path)174 public Session connectToServer(Class<? extends Endpoint> clazz, 175 ClientEndpointConfig clientEndpointConfiguration, URI path) 176 throws DeploymentException { 177 178 Endpoint endpoint; 179 try { 180 endpoint = clazz.getConstructor().newInstance(); 181 } catch (ReflectiveOperationException e) { 182 throw new DeploymentException(sm.getString( 183 "wsWebSocketContainer.endpointCreateFail", clazz.getName()), 184 e); 185 } 186 187 return connectToServer(endpoint, clientEndpointConfiguration, path); 188 } 189 190 191 @Override connectToServer(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path)192 public Session connectToServer(Endpoint endpoint, 193 ClientEndpointConfig clientEndpointConfiguration, URI path) 194 throws DeploymentException { 195 return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>()); 196 } 197 connectToServerRecursive(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet)198 private Session connectToServerRecursive(Endpoint endpoint, 199 ClientEndpointConfig clientEndpointConfiguration, URI path, 200 Set<URI> redirectSet) 201 throws DeploymentException { 202 203 boolean secure = false; 204 ByteBuffer proxyConnect = null; 205 URI proxyPath; 206 207 // Validate scheme (and build proxyPath) 208 String scheme = path.getScheme(); 209 if ("ws".equalsIgnoreCase(scheme)) { 210 proxyPath = URI.create("http" + path.toString().substring(2)); 211 } else if ("wss".equalsIgnoreCase(scheme)) { 212 proxyPath = URI.create("https" + path.toString().substring(3)); 213 secure = true; 214 } else { 215 throw new DeploymentException(sm.getString( 216 "wsWebSocketContainer.pathWrongScheme", scheme)); 217 } 218 219 // Validate host 220 String host = path.getHost(); 221 if (host == null) { 222 throw new DeploymentException( 223 sm.getString("wsWebSocketContainer.pathNoHost")); 224 } 225 int port = path.getPort(); 226 227 SocketAddress sa = null; 228 229 // Check to see if a proxy is configured. Javadoc indicates return value 230 // will never be null 231 List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath); 232 Proxy selectedProxy = null; 233 for (Proxy proxy : proxies) { 234 if (proxy.type().equals(Proxy.Type.HTTP)) { 235 sa = proxy.address(); 236 if (sa instanceof InetSocketAddress) { 237 InetSocketAddress inet = (InetSocketAddress) sa; 238 if (inet.isUnresolved()) { 239 sa = new InetSocketAddress(inet.getHostName(), inet.getPort()); 240 } 241 } 242 selectedProxy = proxy; 243 break; 244 } 245 } 246 247 // If the port is not explicitly specified, compute it based on the 248 // scheme 249 if (port == -1) { 250 if ("ws".equalsIgnoreCase(scheme)) { 251 port = 80; 252 } else { 253 // Must be wss due to scheme validation above 254 port = 443; 255 } 256 } 257 258 // If sa is null, no proxy is configured so need to create sa 259 if (sa == null) { 260 sa = new InetSocketAddress(host, port); 261 } else { 262 proxyConnect = createProxyRequest(host, port); 263 } 264 265 // Create the initial HTTP request to open the WebSocket connection 266 Map<String, List<String>> reqHeaders = createRequestHeaders(host, port, 267 clientEndpointConfiguration); 268 clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders); 269 if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null 270 && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) { 271 List<String> originValues = new ArrayList<>(1); 272 originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE); 273 reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues); 274 } 275 ByteBuffer request = createRequest(path, reqHeaders); 276 277 AsynchronousSocketChannel socketChannel; 278 try { 279 socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup()); 280 } catch (IOException ioe) { 281 throw new DeploymentException(sm.getString( 282 "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe); 283 } 284 285 Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties(); 286 287 // Get the connection timeout 288 long timeout = Constants.IO_TIMEOUT_MS_DEFAULT; 289 String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY); 290 if (timeoutValue != null) { 291 timeout = Long.valueOf(timeoutValue).intValue(); 292 } 293 294 // Set-up 295 // Same size as the WsFrame input buffer 296 ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize()); 297 String subProtocol; 298 boolean success = false; 299 List<Extension> extensionsAgreed = new ArrayList<>(); 300 Transformation transformation = null; 301 302 // Open the connection 303 Future<Void> fConnect = socketChannel.connect(sa); 304 AsyncChannelWrapper channel = null; 305 306 if (proxyConnect != null) { 307 try { 308 fConnect.get(timeout, TimeUnit.MILLISECONDS); 309 // Proxy CONNECT is clear text 310 channel = new AsyncChannelWrapperNonSecure(socketChannel); 311 writeRequest(channel, proxyConnect, timeout); 312 HttpResponse httpResponse = processResponse(response, channel, timeout); 313 if (httpResponse.getStatus() != 200) { 314 throw new DeploymentException(sm.getString( 315 "wsWebSocketContainer.proxyConnectFail", selectedProxy, 316 Integer.toString(httpResponse.getStatus()))); 317 } 318 } catch (TimeoutException | InterruptedException | ExecutionException | 319 EOFException e) { 320 if (channel != null) { 321 channel.close(); 322 } 323 throw new DeploymentException( 324 sm.getString("wsWebSocketContainer.httpRequestFailed"), e); 325 } 326 } 327 328 if (secure) { 329 // Regardless of whether a non-secure wrapper was created for a 330 // proxy CONNECT, need to use TLS from this point on so wrap the 331 // original AsynchronousSocketChannel 332 SSLEngine sslEngine = createSSLEngine(userProperties, host, port); 333 channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); 334 } else if (channel == null) { 335 // Only need to wrap as this point if it wasn't wrapped to process a 336 // proxy CONNECT 337 channel = new AsyncChannelWrapperNonSecure(socketChannel); 338 } 339 340 try { 341 fConnect.get(timeout, TimeUnit.MILLISECONDS); 342 343 Future<Void> fHandshake = channel.handshake(); 344 fHandshake.get(timeout, TimeUnit.MILLISECONDS); 345 346 writeRequest(channel, request, timeout); 347 348 HttpResponse httpResponse = processResponse(response, channel, timeout); 349 350 // Check maximum permitted redirects 351 int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT; 352 String maxRedirectsValue = 353 (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY); 354 if (maxRedirectsValue != null) { 355 maxRedirects = Integer.parseInt(maxRedirectsValue); 356 } 357 358 if (httpResponse.status != 101) { 359 if(isRedirectStatus(httpResponse.status)){ 360 List<String> locationHeader = 361 httpResponse.getHandshakeResponse().getHeaders().get( 362 Constants.LOCATION_HEADER_NAME); 363 364 if (locationHeader == null || locationHeader.isEmpty() || 365 locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) { 366 throw new DeploymentException(sm.getString( 367 "wsWebSocketContainer.missingLocationHeader", 368 Integer.toString(httpResponse.status))); 369 } 370 371 URI redirectLocation = URI.create(locationHeader.get(0)).normalize(); 372 373 if (!redirectLocation.isAbsolute()) { 374 redirectLocation = path.resolve(redirectLocation); 375 } 376 377 String redirectScheme = redirectLocation.getScheme().toLowerCase(); 378 379 if (redirectScheme.startsWith("http")) { 380 redirectLocation = new URI(redirectScheme.replace("http", "ws"), 381 redirectLocation.getUserInfo(), redirectLocation.getHost(), 382 redirectLocation.getPort(), redirectLocation.getPath(), 383 redirectLocation.getQuery(), redirectLocation.getFragment()); 384 } 385 386 if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) { 387 throw new DeploymentException(sm.getString( 388 "wsWebSocketContainer.redirectThreshold", redirectLocation, 389 Integer.toString(redirectSet.size()), 390 Integer.toString(maxRedirects))); 391 } 392 393 return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet); 394 395 } 396 397 else if (httpResponse.status == 401) { 398 399 if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { 400 throw new DeploymentException(sm.getString( 401 "wsWebSocketContainer.failedAuthentication", 402 Integer.valueOf(httpResponse.status))); 403 } 404 405 List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse() 406 .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME); 407 408 if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() || 409 wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) { 410 throw new DeploymentException(sm.getString( 411 "wsWebSocketContainer.missingWWWAuthenticateHeader", 412 Integer.toString(httpResponse.status))); 413 } 414 415 String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0]; 416 String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1) 417 .split("\\s", 3)[1]; 418 419 Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme); 420 421 if (auth == null) { 422 throw new DeploymentException( 423 sm.getString("wsWebSocketContainer.unsupportedAuthScheme", 424 Integer.valueOf(httpResponse.status), authScheme)); 425 } 426 427 userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization( 428 requestUri, wwwAuthenticateHeaders.get(0), userProperties)); 429 430 return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet); 431 432 } 433 434 else { 435 throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", 436 Integer.toString(httpResponse.status))); 437 } 438 } 439 HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse(); 440 clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse); 441 442 // Sub-protocol 443 List<String> protocolHeaders = handshakeResponse.getHeaders().get( 444 Constants.WS_PROTOCOL_HEADER_NAME); 445 if (protocolHeaders == null || protocolHeaders.size() == 0) { 446 subProtocol = null; 447 } else if (protocolHeaders.size() == 1) { 448 subProtocol = protocolHeaders.get(0); 449 } else { 450 throw new DeploymentException( 451 sm.getString("wsWebSocketContainer.invalidSubProtocol")); 452 } 453 454 // Extensions 455 // Should normally only be one header but handle the case of 456 // multiple headers 457 List<String> extHeaders = handshakeResponse.getHeaders().get( 458 Constants.WS_EXTENSIONS_HEADER_NAME); 459 if (extHeaders != null) { 460 for (String extHeader : extHeaders) { 461 Util.parseExtensionHeader(extensionsAgreed, extHeader); 462 } 463 } 464 465 // Build the transformations 466 TransformationFactory factory = TransformationFactory.getInstance(); 467 for (Extension extension : extensionsAgreed) { 468 List<List<Extension.Parameter>> wrapper = new ArrayList<>(1); 469 wrapper.add(extension.getParameters()); 470 Transformation t = factory.create(extension.getName(), wrapper, false); 471 if (t == null) { 472 throw new DeploymentException(sm.getString( 473 "wsWebSocketContainer.invalidExtensionParameters")); 474 } 475 if (transformation == null) { 476 transformation = t; 477 } else { 478 transformation.setNext(t); 479 } 480 } 481 482 success = true; 483 } catch (ExecutionException | InterruptedException | SSLException | 484 EOFException | TimeoutException | URISyntaxException | AuthenticationException e) { 485 throw new DeploymentException( 486 sm.getString("wsWebSocketContainer.httpRequestFailed"), e); 487 } finally { 488 if (!success) { 489 channel.close(); 490 } 491 } 492 493 // Switch to WebSocket 494 WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel); 495 496 WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient, 497 this, null, null, null, null, null, extensionsAgreed, 498 subProtocol, Collections.<String,String>emptyMap(), secure, 499 clientEndpointConfiguration, null); 500 501 WsFrameClient wsFrameClient = new WsFrameClient(response, channel, 502 wsSession, transformation); 503 // WsFrame adds the necessary final transformations. Copy the 504 // completed transformation chain to the remote end point. 505 wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation()); 506 507 endpoint.onOpen(wsSession, clientEndpointConfiguration); 508 registerSession(endpoint, wsSession); 509 510 /* It is possible that the server sent one or more messages as soon as 511 * the WebSocket connection was established. Depending on the exact 512 * timing of when those messages were sent they could be sat in the 513 * input buffer waiting to be read and will not trigger a "data 514 * available to read" event. Therefore, it is necessary to process the 515 * input buffer here. Note that this happens on the current thread which 516 * means that this thread will be used for any onMessage notifications. 517 * This is a special case. Subsequent "data available to read" events 518 * will be handled by threads from the AsyncChannelGroup's executor. 519 */ 520 wsFrameClient.startInputProcessing(); 521 522 return wsSession; 523 } 524 525 writeRequest(AsyncChannelWrapper channel, ByteBuffer request, long timeout)526 private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, 527 long timeout) throws TimeoutException, InterruptedException, ExecutionException { 528 int toWrite = request.limit(); 529 530 Future<Integer> fWrite = channel.write(request); 531 Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); 532 toWrite -= thisWrite.intValue(); 533 534 while (toWrite > 0) { 535 fWrite = channel.write(request); 536 thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); 537 toWrite -= thisWrite.intValue(); 538 } 539 } 540 541 isRedirectStatus(int httpResponseCode)542 private static boolean isRedirectStatus(int httpResponseCode) { 543 544 boolean isRedirect = false; 545 546 switch (httpResponseCode) { 547 case Constants.MULTIPLE_CHOICES: 548 case Constants.MOVED_PERMANENTLY: 549 case Constants.FOUND: 550 case Constants.SEE_OTHER: 551 case Constants.USE_PROXY: 552 case Constants.TEMPORARY_REDIRECT: 553 isRedirect = true; 554 break; 555 default: 556 break; 557 } 558 559 return isRedirect; 560 } 561 562 createProxyRequest(String host, int port)563 private static ByteBuffer createProxyRequest(String host, int port) { 564 StringBuilder request = new StringBuilder(); 565 request.append("CONNECT "); 566 request.append(host); 567 request.append(':'); 568 request.append(port); 569 570 request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: "); 571 request.append(host); 572 request.append(':'); 573 request.append(port); 574 575 request.append("\r\n\r\n"); 576 577 byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1); 578 return ByteBuffer.wrap(bytes); 579 } 580 registerSession(Endpoint endpoint, WsSession wsSession)581 protected void registerSession(Endpoint endpoint, WsSession wsSession) { 582 583 if (!wsSession.isOpen()) { 584 // The session was closed during onOpen. No need to register it. 585 return; 586 } 587 synchronized (endPointSessionMapLock) { 588 if (endpointSessionMap.size() == 0) { 589 BackgroundProcessManager.getInstance().register(this); 590 } 591 Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); 592 if (wsSessions == null) { 593 wsSessions = new HashSet<>(); 594 endpointSessionMap.put(endpoint, wsSessions); 595 } 596 wsSessions.add(wsSession); 597 } 598 sessions.put(wsSession, wsSession); 599 } 600 601 unregisterSession(Endpoint endpoint, WsSession wsSession)602 protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { 603 604 synchronized (endPointSessionMapLock) { 605 Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); 606 if (wsSessions != null) { 607 wsSessions.remove(wsSession); 608 if (wsSessions.size() == 0) { 609 endpointSessionMap.remove(endpoint); 610 } 611 } 612 if (endpointSessionMap.size() == 0) { 613 BackgroundProcessManager.getInstance().unregister(this); 614 } 615 } 616 sessions.remove(wsSession); 617 } 618 619 getOpenSessions(Endpoint endpoint)620 Set<Session> getOpenSessions(Endpoint endpoint) { 621 HashSet<Session> result = new HashSet<>(); 622 synchronized (endPointSessionMapLock) { 623 Set<WsSession> sessions = endpointSessionMap.get(endpoint); 624 if (sessions != null) { 625 result.addAll(sessions); 626 } 627 } 628 return result; 629 } 630 createRequestHeaders(String host, int port, ClientEndpointConfig clientEndpointConfiguration)631 private static Map<String, List<String>> createRequestHeaders(String host, int port, 632 ClientEndpointConfig clientEndpointConfiguration) { 633 634 Map<String, List<String>> headers = new HashMap<>(); 635 List<Extension> extensions = clientEndpointConfiguration.getExtensions(); 636 List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols(); 637 Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties(); 638 639 if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { 640 List<String> authValues = new ArrayList<>(1); 641 authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME)); 642 headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues); 643 } 644 645 // Host header 646 List<String> hostValues = new ArrayList<>(1); 647 if (port == -1) { 648 hostValues.add(host); 649 } else { 650 hostValues.add(host + ':' + port); 651 } 652 653 headers.put(Constants.HOST_HEADER_NAME, hostValues); 654 655 // Upgrade header 656 List<String> upgradeValues = new ArrayList<>(1); 657 upgradeValues.add(Constants.UPGRADE_HEADER_VALUE); 658 headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues); 659 660 // Connection header 661 List<String> connectionValues = new ArrayList<>(1); 662 connectionValues.add(Constants.CONNECTION_HEADER_VALUE); 663 headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues); 664 665 // WebSocket version header 666 List<String> wsVersionValues = new ArrayList<>(1); 667 wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE); 668 headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues); 669 670 // WebSocket key 671 List<String> wsKeyValues = new ArrayList<>(1); 672 wsKeyValues.add(generateWsKeyValue()); 673 headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues); 674 675 // WebSocket sub-protocols 676 if (subProtocols != null && subProtocols.size() > 0) { 677 headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols); 678 } 679 680 // WebSocket extensions 681 if (extensions != null && extensions.size() > 0) { 682 headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, 683 generateExtensionHeaders(extensions)); 684 } 685 686 return headers; 687 } 688 689 generateExtensionHeaders(List<Extension> extensions)690 private static List<String> generateExtensionHeaders(List<Extension> extensions) { 691 List<String> result = new ArrayList<>(extensions.size()); 692 for (Extension extension : extensions) { 693 StringBuilder header = new StringBuilder(); 694 header.append(extension.getName()); 695 for (Extension.Parameter param : extension.getParameters()) { 696 header.append(';'); 697 header.append(param.getName()); 698 String value = param.getValue(); 699 if (value != null && value.length() > 0) { 700 header.append('='); 701 header.append(value); 702 } 703 } 704 result.add(header.toString()); 705 } 706 return result; 707 } 708 709 generateWsKeyValue()710 private static String generateWsKeyValue() { 711 byte[] keyBytes = new byte[16]; 712 RANDOM.nextBytes(keyBytes); 713 return Base64.encodeBase64String(keyBytes); 714 } 715 716 createRequest(URI uri, Map<String,List<String>> reqHeaders)717 private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) { 718 ByteBuffer result = ByteBuffer.allocate(4 * 1024); 719 720 // Request line 721 result.put(GET_BYTES); 722 if (null == uri.getPath() || "".equals(uri.getPath())) { 723 result.put(ROOT_URI_BYTES); 724 } else { 725 result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1)); 726 } 727 String query = uri.getRawQuery(); 728 if (query != null) { 729 result.put((byte) '?'); 730 result.put(query.getBytes(StandardCharsets.ISO_8859_1)); 731 } 732 result.put(HTTP_VERSION_BYTES); 733 734 // Headers 735 for (Entry<String, List<String>> entry : reqHeaders.entrySet()) { 736 result = addHeader(result, entry.getKey(), entry.getValue()); 737 } 738 739 // Terminating CRLF 740 result.put(CRLF); 741 742 result.flip(); 743 744 return result; 745 } 746 747 addHeader(ByteBuffer result, String key, List<String> values)748 private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) { 749 if (values.isEmpty()) { 750 return result; 751 } 752 753 result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1)); 754 result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1)); 755 result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1)); 756 result = putWithExpand(result, CRLF); 757 758 return result; 759 } 760 761 putWithExpand(ByteBuffer input, byte[] bytes)762 private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) { 763 if (bytes.length > input.remaining()) { 764 int newSize; 765 if (bytes.length > input.capacity()) { 766 newSize = 2 * bytes.length; 767 } else { 768 newSize = input.capacity() * 2; 769 } 770 ByteBuffer expanded = ByteBuffer.allocate(newSize); 771 input.flip(); 772 expanded.put(input); 773 input = expanded; 774 } 775 return input.put(bytes); 776 } 777 778 779 /** 780 * Process response, blocking until HTTP response has been fully received. 781 * @throws ExecutionException 782 * @throws InterruptedException 783 * @throws DeploymentException 784 * @throws TimeoutException 785 */ processResponse(ByteBuffer response, AsyncChannelWrapper channel, long timeout)786 private HttpResponse processResponse(ByteBuffer response, 787 AsyncChannelWrapper channel, long timeout) throws InterruptedException, 788 ExecutionException, DeploymentException, EOFException, 789 TimeoutException { 790 791 Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); 792 793 int status = 0; 794 boolean readStatus = false; 795 boolean readHeaders = false; 796 String line = null; 797 while (!readHeaders) { 798 // On entering loop buffer will be empty and at the start of a new 799 // loop the buffer will have been fully read. 800 response.clear(); 801 // Blocking read 802 Future<Integer> read = channel.read(response); 803 Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS); 804 if (bytesRead.intValue() == -1) { 805 throw new EOFException(); 806 } 807 response.flip(); 808 while (response.hasRemaining() && !readHeaders) { 809 if (line == null) { 810 line = readLine(response); 811 } else { 812 line += readLine(response); 813 } 814 if ("\r\n".equals(line)) { 815 readHeaders = true; 816 } else if (line.endsWith("\r\n")) { 817 if (readStatus) { 818 parseHeaders(line, headers); 819 } else { 820 status = parseStatus(line); 821 readStatus = true; 822 } 823 line = null; 824 } 825 } 826 } 827 828 return new HttpResponse(status, new WsHandshakeResponse(headers)); 829 } 830 831 parseStatus(String line)832 private int parseStatus(String line) throws DeploymentException { 833 // This client only understands HTTP 1. 834 // RFC2616 is case specific 835 String[] parts = line.trim().split(" "); 836 // CONNECT for proxy may return a 1.0 response 837 if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) { 838 throw new DeploymentException(sm.getString( 839 "wsWebSocketContainer.invalidStatus", line)); 840 } 841 try { 842 return Integer.parseInt(parts[1]); 843 } catch (NumberFormatException nfe) { 844 throw new DeploymentException(sm.getString( 845 "wsWebSocketContainer.invalidStatus", line)); 846 } 847 } 848 849 parseHeaders(String line, Map<String,List<String>> headers)850 private void parseHeaders(String line, Map<String,List<String>> headers) { 851 // Treat headers as single values by default. 852 853 int index = line.indexOf(':'); 854 if (index == -1) { 855 log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line)); 856 return; 857 } 858 // Header names are case insensitive so always use lower case 859 String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH); 860 // Multi-value headers are stored as a single header and the client is 861 // expected to handle splitting into individual values 862 String headerValue = line.substring(index + 1).trim(); 863 864 List<String> values = headers.get(headerName); 865 if (values == null) { 866 values = new ArrayList<>(1); 867 headers.put(headerName, values); 868 } 869 values.add(headerValue); 870 } 871 readLine(ByteBuffer response)872 private String readLine(ByteBuffer response) { 873 // All ISO-8859-1 874 StringBuilder sb = new StringBuilder(); 875 876 char c = 0; 877 while (response.hasRemaining()) { 878 c = (char) response.get(); 879 sb.append(c); 880 if (c == 10) { 881 break; 882 } 883 } 884 885 return sb.toString(); 886 } 887 888 createSSLEngine(Map<String,Object> userProperties, String host, int port)889 private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port) 890 throws DeploymentException { 891 892 try { 893 // See if a custom SSLContext has been provided 894 SSLContext sslContext = 895 (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY); 896 897 if (sslContext == null) { 898 // Create the SSL Context 899 sslContext = SSLContext.getInstance("TLS"); 900 901 // Trust store 902 String sslTrustStoreValue = 903 (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY); 904 if (sslTrustStoreValue != null) { 905 String sslTrustStorePwdValue = (String) userProperties.get( 906 Constants.SSL_TRUSTSTORE_PWD_PROPERTY); 907 if (sslTrustStorePwdValue == null) { 908 sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT; 909 } 910 911 File keyStoreFile = new File(sslTrustStoreValue); 912 KeyStore ks = KeyStore.getInstance("JKS"); 913 try (InputStream is = new FileInputStream(keyStoreFile)) { 914 ks.load(is, sslTrustStorePwdValue.toCharArray()); 915 } 916 917 TrustManagerFactory tmf = TrustManagerFactory.getInstance( 918 TrustManagerFactory.getDefaultAlgorithm()); 919 tmf.init(ks); 920 921 sslContext.init(null, tmf.getTrustManagers(), null); 922 } else { 923 sslContext.init(null, null, null); 924 } 925 } 926 927 SSLEngine engine = sslContext.createSSLEngine(host, port); 928 929 String sslProtocolsValue = 930 (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY); 931 if (sslProtocolsValue != null) { 932 engine.setEnabledProtocols(sslProtocolsValue.split(",")); 933 } 934 935 engine.setUseClientMode(true); 936 937 // Enable host verification 938 // Start with current settings (returns a copy) 939 SSLParameters sslParams = engine.getSSLParameters(); 940 // Use HTTPS since WebSocket starts over HTTP(S) 941 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 942 // Write the parameters back 943 engine.setSSLParameters(sslParams); 944 945 return engine; 946 } catch (Exception e) { 947 throw new DeploymentException(sm.getString( 948 "wsWebSocketContainer.sslEngineFail"), e); 949 } 950 } 951 952 953 @Override getDefaultMaxSessionIdleTimeout()954 public long getDefaultMaxSessionIdleTimeout() { 955 return defaultMaxSessionIdleTimeout; 956 } 957 958 959 @Override setDefaultMaxSessionIdleTimeout(long timeout)960 public void setDefaultMaxSessionIdleTimeout(long timeout) { 961 this.defaultMaxSessionIdleTimeout = timeout; 962 } 963 964 965 @Override getDefaultMaxBinaryMessageBufferSize()966 public int getDefaultMaxBinaryMessageBufferSize() { 967 return maxBinaryMessageBufferSize; 968 } 969 970 971 @Override setDefaultMaxBinaryMessageBufferSize(int max)972 public void setDefaultMaxBinaryMessageBufferSize(int max) { 973 maxBinaryMessageBufferSize = max; 974 } 975 976 977 @Override getDefaultMaxTextMessageBufferSize()978 public int getDefaultMaxTextMessageBufferSize() { 979 return maxTextMessageBufferSize; 980 } 981 982 983 @Override setDefaultMaxTextMessageBufferSize(int max)984 public void setDefaultMaxTextMessageBufferSize(int max) { 985 maxTextMessageBufferSize = max; 986 } 987 988 989 /** 990 * {@inheritDoc} 991 * 992 * Currently, this implementation does not support any extensions. 993 */ 994 @Override getInstalledExtensions()995 public Set<Extension> getInstalledExtensions() { 996 return Collections.emptySet(); 997 } 998 999 1000 /** 1001 * {@inheritDoc} 1002 * 1003 * The default value for this implementation is -1. 1004 */ 1005 @Override getDefaultAsyncSendTimeout()1006 public long getDefaultAsyncSendTimeout() { 1007 return defaultAsyncTimeout; 1008 } 1009 1010 1011 /** 1012 * {@inheritDoc} 1013 * 1014 * The default value for this implementation is -1. 1015 */ 1016 @Override setAsyncSendTimeout(long timeout)1017 public void setAsyncSendTimeout(long timeout) { 1018 this.defaultAsyncTimeout = timeout; 1019 } 1020 1021 1022 /** 1023 * Cleans up the resources still in use by WebSocket sessions created from 1024 * this container. This includes closing sessions and cancelling 1025 * {@link Future}s associated with blocking read/writes. 1026 */ destroy()1027 public void destroy() { 1028 CloseReason cr = new CloseReason( 1029 CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown")); 1030 1031 for (WsSession session : sessions.keySet()) { 1032 try { 1033 session.close(cr); 1034 } catch (IOException ioe) { 1035 log.debug(sm.getString( 1036 "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe); 1037 } 1038 } 1039 1040 // Only unregister with AsyncChannelGroupUtil if this instance 1041 // registered with it 1042 if (asynchronousChannelGroup != null) { 1043 synchronized (asynchronousChannelGroupLock) { 1044 if (asynchronousChannelGroup != null) { 1045 AsyncChannelGroupUtil.unregister(); 1046 asynchronousChannelGroup = null; 1047 } 1048 } 1049 } 1050 } 1051 1052 getAsynchronousChannelGroup()1053 private AsynchronousChannelGroup getAsynchronousChannelGroup() { 1054 // Use AsyncChannelGroupUtil to share a common group amongst all 1055 // WebSocket clients 1056 AsynchronousChannelGroup result = asynchronousChannelGroup; 1057 if (result == null) { 1058 synchronized (asynchronousChannelGroupLock) { 1059 if (asynchronousChannelGroup == null) { 1060 asynchronousChannelGroup = AsyncChannelGroupUtil.register(); 1061 } 1062 result = asynchronousChannelGroup; 1063 } 1064 } 1065 return result; 1066 } 1067 1068 1069 // ----------------------------------------------- BackgroundProcess methods 1070 1071 @Override backgroundProcess()1072 public void backgroundProcess() { 1073 // This method gets called once a second. 1074 backgroundProcessCount ++; 1075 if (backgroundProcessCount >= processPeriod) { 1076 backgroundProcessCount = 0; 1077 1078 for (WsSession wsSession : sessions.keySet()) { 1079 wsSession.checkExpiration(); 1080 } 1081 } 1082 1083 } 1084 1085 1086 @Override setProcessPeriod(int period)1087 public void setProcessPeriod(int period) { 1088 this.processPeriod = period; 1089 } 1090 1091 1092 /** 1093 * {@inheritDoc} 1094 * 1095 * The default value is 10 which means session expirations are processed 1096 * every 10 seconds. 1097 */ 1098 @Override getProcessPeriod()1099 public int getProcessPeriod() { 1100 return processPeriod; 1101 } 1102 1103 1104 private static class HttpResponse { 1105 private final int status; 1106 private final HandshakeResponse handshakeResponse; 1107 HttpResponse(int status, HandshakeResponse handshakeResponse)1108 public HttpResponse(int status, HandshakeResponse handshakeResponse) { 1109 this.status = status; 1110 this.handshakeResponse = handshakeResponse; 1111 } 1112 1113 getStatus()1114 public int getStatus() { 1115 return status; 1116 } 1117 1118 getHandshakeResponse()1119 public HandshakeResponse getHandshakeResponse() { 1120 return handshakeResponse; 1121 } 1122 } 1123 } 1124