1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket.server; 18 19 import java.io.IOException; 20 import java.util.Arrays; 21 import java.util.Collections; 22 import java.util.Comparator; 23 import java.util.EnumSet; 24 import java.util.Map; 25 import java.util.Set; 26 import java.util.SortedSet; 27 import java.util.TreeSet; 28 import java.util.concurrent.ConcurrentHashMap; 29 30 import javax.servlet.DispatcherType; 31 import javax.servlet.FilterRegistration; 32 import javax.servlet.ServletContext; 33 import javax.servlet.ServletException; 34 import javax.servlet.http.HttpServletRequest; 35 import javax.servlet.http.HttpServletResponse; 36 import javax.websocket.CloseReason; 37 import javax.websocket.CloseReason.CloseCodes; 38 import javax.websocket.DeploymentException; 39 import javax.websocket.Encoder; 40 import javax.websocket.Endpoint; 41 import javax.websocket.server.ServerContainer; 42 import javax.websocket.server.ServerEndpoint; 43 import javax.websocket.server.ServerEndpointConfig; 44 import javax.websocket.server.ServerEndpointConfig.Configurator; 45 46 import org.apache.tomcat.InstanceManager; 47 import org.apache.tomcat.util.res.StringManager; 48 import nginx.unit.websocket.WsSession; 49 import nginx.unit.websocket.WsWebSocketContainer; 50 import nginx.unit.websocket.pojo.PojoMethodMapping; 51 52 /** 53 * Provides a per class loader (i.e. per web application) instance of a 54 * ServerContainer. Web application wide defaults may be configured by setting 55 * the following servlet context initialisation parameters to the desired 56 * values. 57 * <ul> 58 * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> 59 * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> 60 * </ul> 61 */ 62 public class WsServerContainer extends WsWebSocketContainer 63 implements ServerContainer { 64 65 private static final StringManager sm = StringManager.getManager(WsServerContainer.class); 66 67 private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = 68 new CloseReason(CloseCodes.VIOLATED_POLICY, 69 "This connection was established under an authenticated " + 70 "HTTP session that has ended."); 71 72 private final ServletContext servletContext; 73 private final Map<String,ServerEndpointConfig> configExactMatchMap = 74 new ConcurrentHashMap<>(); 75 private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap = 76 new ConcurrentHashMap<>(); 77 private volatile boolean enforceNoAddAfterHandshake = 78 nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE; 79 private volatile boolean addAllowed = true; 80 private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>(); 81 private volatile boolean endpointsRegistered = false; 82 WsServerContainer(ServletContext servletContext)83 WsServerContainer(ServletContext servletContext) { 84 85 this.servletContext = servletContext; 86 setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName())); 87 88 // Configure servlet context wide defaults 89 String value = servletContext.getInitParameter( 90 Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); 91 if (value != null) { 92 setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value)); 93 } 94 95 value = servletContext.getInitParameter( 96 Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); 97 if (value != null) { 98 setDefaultMaxTextMessageBufferSize(Integer.parseInt(value)); 99 } 100 101 value = servletContext.getInitParameter( 102 Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM); 103 if (value != null) { 104 setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value)); 105 } 106 107 FilterRegistration.Dynamic fr = servletContext.addFilter( 108 "Tomcat WebSocket (JSR356) Filter", new WsFilter()); 109 fr.setAsyncSupported(true); 110 111 EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST, 112 DispatcherType.FORWARD); 113 114 fr.addMappingForUrlPatterns(types, true, "/*"); 115 } 116 117 118 /** 119 * Published the provided endpoint implementation at the specified path with 120 * the specified configuration. {@link #WsServerContainer(ServletContext)} 121 * must be called before calling this method. 122 * 123 * @param sec The configuration to use when creating endpoint instances 124 * @throws DeploymentException if the endpoint cannot be published as 125 * requested 126 */ 127 @Override addEndpoint(ServerEndpointConfig sec)128 public void addEndpoint(ServerEndpointConfig sec) 129 throws DeploymentException { 130 131 if (enforceNoAddAfterHandshake && !addAllowed) { 132 throw new DeploymentException( 133 sm.getString("serverContainer.addNotAllowed")); 134 } 135 136 if (servletContext == null) { 137 throw new DeploymentException( 138 sm.getString("serverContainer.servletContextMissing")); 139 } 140 String path = sec.getPath(); 141 142 // Add method mapping to user properties 143 PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(), 144 sec.getDecoders(), path); 145 if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null 146 || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) { 147 sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY, 148 methodMapping); 149 } 150 151 UriTemplate uriTemplate = new UriTemplate(path); 152 if (uriTemplate.hasParameters()) { 153 Integer key = Integer.valueOf(uriTemplate.getSegmentCount()); 154 SortedSet<TemplatePathMatch> templateMatches = 155 configTemplateMatchMap.get(key); 156 if (templateMatches == null) { 157 // Ensure that if concurrent threads execute this block they 158 // both end up using the same TreeSet instance 159 templateMatches = new TreeSet<>( 160 TemplatePathMatchComparator.getInstance()); 161 configTemplateMatchMap.putIfAbsent(key, templateMatches); 162 templateMatches = configTemplateMatchMap.get(key); 163 } 164 if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) { 165 // Duplicate uriTemplate; 166 throw new DeploymentException( 167 sm.getString("serverContainer.duplicatePaths", path, 168 sec.getEndpointClass(), 169 sec.getEndpointClass())); 170 } 171 } else { 172 // Exact match 173 ServerEndpointConfig old = configExactMatchMap.put(path, sec); 174 if (old != null) { 175 // Duplicate path mappings 176 throw new DeploymentException( 177 sm.getString("serverContainer.duplicatePaths", path, 178 old.getEndpointClass(), 179 sec.getEndpointClass())); 180 } 181 } 182 183 endpointsRegistered = true; 184 } 185 186 187 /** 188 * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)} 189 * for publishing plain old java objects (POJOs) that have been annotated as 190 * WebSocket endpoints. 191 * 192 * @param pojo The annotated POJO 193 */ 194 @Override addEndpoint(Class<?> pojo)195 public void addEndpoint(Class<?> pojo) throws DeploymentException { 196 197 ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class); 198 if (annotation == null) { 199 throw new DeploymentException( 200 sm.getString("serverContainer.missingAnnotation", 201 pojo.getName())); 202 } 203 String path = annotation.value(); 204 205 // Validate encoders 206 validateEncoders(annotation.encoders()); 207 208 // ServerEndpointConfig 209 ServerEndpointConfig sec; 210 Class<? extends Configurator> configuratorClazz = 211 annotation.configurator(); 212 Configurator configurator = null; 213 if (!configuratorClazz.equals(Configurator.class)) { 214 try { 215 configurator = annotation.configurator().getConstructor().newInstance(); 216 } catch (ReflectiveOperationException e) { 217 throw new DeploymentException(sm.getString( 218 "serverContainer.configuratorFail", 219 annotation.configurator().getName(), 220 pojo.getClass().getName()), e); 221 } 222 } 223 if (configurator == null) { 224 configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator(); 225 } 226 sec = ServerEndpointConfig.Builder.create(pojo, path). 227 decoders(Arrays.asList(annotation.decoders())). 228 encoders(Arrays.asList(annotation.encoders())). 229 subprotocols(Arrays.asList(annotation.subprotocols())). 230 configurator(configurator). 231 build(); 232 233 addEndpoint(sec); 234 } 235 236 areEndpointsRegistered()237 boolean areEndpointsRegistered() { 238 return endpointsRegistered; 239 } 240 241 242 /** 243 * Until the WebSocket specification provides such a mechanism, this Tomcat 244 * proprietary method is provided to enable applications to programmatically 245 * determine whether or not to upgrade an individual request to WebSocket. 246 * <p> 247 * Note: This method is not used by Tomcat but is used directly by 248 * third-party code and must not be removed. 249 * 250 * @param request The request object to be upgraded 251 * @param response The response object to be populated with the result of 252 * the upgrade 253 * @param sec The server endpoint to use to process the upgrade request 254 * @param pathParams The path parameters associated with the upgrade request 255 * 256 * @throws ServletException If a configuration error prevents the upgrade 257 * from taking place 258 * @throws IOException If an I/O error occurs during the upgrade process 259 */ doUpgrade(HttpServletRequest request, HttpServletResponse response, ServerEndpointConfig sec, Map<String,String> pathParams)260 public void doUpgrade(HttpServletRequest request, 261 HttpServletResponse response, ServerEndpointConfig sec, 262 Map<String,String> pathParams) 263 throws ServletException, IOException { 264 UpgradeUtil.doUpgrade(this, request, response, sec, pathParams); 265 } 266 267 findMapping(String path)268 public WsMappingResult findMapping(String path) { 269 270 // Prevent registering additional endpoints once the first attempt has 271 // been made to use one 272 if (addAllowed) { 273 addAllowed = false; 274 } 275 276 // Check an exact match. Simple case as there are no templates. 277 ServerEndpointConfig sec = configExactMatchMap.get(path); 278 if (sec != null) { 279 return new WsMappingResult(sec, Collections.<String, String>emptyMap()); 280 } 281 282 // No exact match. Need to look for template matches. 283 UriTemplate pathUriTemplate = null; 284 try { 285 pathUriTemplate = new UriTemplate(path); 286 } catch (DeploymentException e) { 287 // Path is not valid so can't be matched to a WebSocketEndpoint 288 return null; 289 } 290 291 // Number of segments has to match 292 Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount()); 293 SortedSet<TemplatePathMatch> templateMatches = 294 configTemplateMatchMap.get(key); 295 296 if (templateMatches == null) { 297 // No templates with an equal number of segments so there will be 298 // no matches 299 return null; 300 } 301 302 // List is in alphabetical order of normalised templates. 303 // Correct match is the first one that matches. 304 Map<String,String> pathParams = null; 305 for (TemplatePathMatch templateMatch : templateMatches) { 306 pathParams = templateMatch.getUriTemplate().match(pathUriTemplate); 307 if (pathParams != null) { 308 sec = templateMatch.getConfig(); 309 break; 310 } 311 } 312 313 if (sec == null) { 314 // No match 315 return null; 316 } 317 318 return new WsMappingResult(sec, pathParams); 319 } 320 321 322 isEnforceNoAddAfterHandshake()323 public boolean isEnforceNoAddAfterHandshake() { 324 return enforceNoAddAfterHandshake; 325 } 326 327 setEnforceNoAddAfterHandshake( boolean enforceNoAddAfterHandshake)328 public void setEnforceNoAddAfterHandshake( 329 boolean enforceNoAddAfterHandshake) { 330 this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake; 331 } 332 333 334 /** 335 * {@inheritDoc} 336 * 337 * Overridden to make it visible to other classes in this package. 338 */ 339 @Override registerSession(Endpoint endpoint, WsSession wsSession)340 protected void registerSession(Endpoint endpoint, WsSession wsSession) { 341 super.registerSession(endpoint, wsSession); 342 if (wsSession.isOpen() && 343 wsSession.getUserPrincipal() != null && 344 wsSession.getHttpSessionId() != null) { 345 registerAuthenticatedSession(wsSession, 346 wsSession.getHttpSessionId()); 347 } 348 } 349 350 351 /** 352 * {@inheritDoc} 353 * 354 * Overridden to make it visible to other classes in this package. 355 */ 356 @Override unregisterSession(Endpoint endpoint, WsSession wsSession)357 protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { 358 if (wsSession.getUserPrincipal() != null && 359 wsSession.getHttpSessionId() != null) { 360 unregisterAuthenticatedSession(wsSession, 361 wsSession.getHttpSessionId()); 362 } 363 super.unregisterSession(endpoint, wsSession); 364 } 365 366 registerAuthenticatedSession(WsSession wsSession, String httpSessionId)367 private void registerAuthenticatedSession(WsSession wsSession, 368 String httpSessionId) { 369 Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); 370 if (wsSessions == null) { 371 wsSessions = Collections.newSetFromMap( 372 new ConcurrentHashMap<WsSession,Boolean>()); 373 authenticatedSessions.putIfAbsent(httpSessionId, wsSessions); 374 wsSessions = authenticatedSessions.get(httpSessionId); 375 } 376 wsSessions.add(wsSession); 377 } 378 379 unregisterAuthenticatedSession(WsSession wsSession, String httpSessionId)380 private void unregisterAuthenticatedSession(WsSession wsSession, 381 String httpSessionId) { 382 Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); 383 // wsSessions will be null if the HTTP session has ended 384 if (wsSessions != null) { 385 wsSessions.remove(wsSession); 386 } 387 } 388 389 closeAuthenticatedSession(String httpSessionId)390 public void closeAuthenticatedSession(String httpSessionId) { 391 Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId); 392 393 if (wsSessions != null && !wsSessions.isEmpty()) { 394 for (WsSession wsSession : wsSessions) { 395 try { 396 wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED); 397 } catch (IOException e) { 398 // Any IOExceptions during close will have been caught and the 399 // onError method called. 400 } 401 } 402 } 403 } 404 405 validateEncoders(Class<? extends Encoder>[] encoders)406 private static void validateEncoders(Class<? extends Encoder>[] encoders) 407 throws DeploymentException { 408 409 for (Class<? extends Encoder> encoder : encoders) { 410 // Need to instantiate decoder to ensure it is valid and that 411 // deployment can be failed if it is not 412 @SuppressWarnings("unused") 413 Encoder instance; 414 try { 415 encoder.getConstructor().newInstance(); 416 } catch(ReflectiveOperationException e) { 417 throw new DeploymentException(sm.getString( 418 "serverContainer.encoderFail", encoder.getName()), e); 419 } 420 } 421 } 422 423 424 private static class TemplatePathMatch { 425 private final ServerEndpointConfig config; 426 private final UriTemplate uriTemplate; 427 TemplatePathMatch(ServerEndpointConfig config, UriTemplate uriTemplate)428 public TemplatePathMatch(ServerEndpointConfig config, 429 UriTemplate uriTemplate) { 430 this.config = config; 431 this.uriTemplate = uriTemplate; 432 } 433 434 getConfig()435 public ServerEndpointConfig getConfig() { 436 return config; 437 } 438 439 getUriTemplate()440 public UriTemplate getUriTemplate() { 441 return uriTemplate; 442 } 443 } 444 445 446 /** 447 * This Comparator implementation is thread-safe so only create a single 448 * instance. 449 */ 450 private static class TemplatePathMatchComparator 451 implements Comparator<TemplatePathMatch> { 452 453 private static final TemplatePathMatchComparator INSTANCE = 454 new TemplatePathMatchComparator(); 455 getInstance()456 public static TemplatePathMatchComparator getInstance() { 457 return INSTANCE; 458 } 459 TemplatePathMatchComparator()460 private TemplatePathMatchComparator() { 461 // Hide default constructor 462 } 463 464 @Override compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2)465 public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) { 466 return tpm1.getUriTemplate().getNormalizedPath().compareTo( 467 tpm2.getUriTemplate().getNormalizedPath()); 468 } 469 } 470 } 471