// xref: /unit/src/java/nginx/unit/websocket/server/WsServerContainer.java (revision 1157:7ae152bda303)
/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package nginx.unit.websocket.server;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;

import javax.servlet.DispatcherType;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.DeploymentException;
import javax.websocket.Encoder;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import javax.websocket.server.ServerEndpointConfig.Configurator;

import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.util.res.StringManager;
import nginx.unit.websocket.WsSession;
import nginx.unit.websocket.WsWebSocketContainer;
import nginx.unit.websocket.pojo.PojoMethodMapping;

52 /**
53  * Provides a per class loader (i.e. per web application) instance of a
54  * ServerContainer. Web application wide defaults may be configured by setting
55  * the following servlet context initialisation parameters to the desired
56  * values.
57  * <ul>
58  * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
59  * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
60  * </ul>
61  */
62 public class WsServerContainer extends WsWebSocketContainer
63         implements ServerContainer {
64 
65     private static final StringManager sm = StringManager.getManager(WsServerContainer.class);
66 
67     private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED =
68             new CloseReason(CloseCodes.VIOLATED_POLICY,
69                     "This connection was established under an authenticated " +
70                     "HTTP session that has ended.");
71 
72     private final ServletContext servletContext;
73     private final Map<String,ServerEndpointConfig> configExactMatchMap =
74             new ConcurrentHashMap<>();
75     private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap =
76             new ConcurrentHashMap<>();
77     private volatile boolean enforceNoAddAfterHandshake =
78             nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE;
79     private volatile boolean addAllowed = true;
80     private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>();
81     private volatile boolean endpointsRegistered = false;
82 
WsServerContainer(ServletContext servletContext)83     WsServerContainer(ServletContext servletContext) {
84 
85         this.servletContext = servletContext;
86         setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName()));
87 
88         // Configure servlet context wide defaults
89         String value = servletContext.getInitParameter(
90                 Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
91         if (value != null) {
92             setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value));
93         }
94 
95         value = servletContext.getInitParameter(
96                 Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
97         if (value != null) {
98             setDefaultMaxTextMessageBufferSize(Integer.parseInt(value));
99         }
100 
101         value = servletContext.getInitParameter(
102                 Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM);
103         if (value != null) {
104             setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value));
105         }
106 
107         FilterRegistration.Dynamic fr = servletContext.addFilter(
108                 "Tomcat WebSocket (JSR356) Filter", new WsFilter());
109         fr.setAsyncSupported(true);
110 
111         EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST,
112                 DispatcherType.FORWARD);
113 
114         fr.addMappingForUrlPatterns(types, true, "/*");
115     }
116 
117 
118     /**
119      * Published the provided endpoint implementation at the specified path with
120      * the specified configuration. {@link #WsServerContainer(ServletContext)}
121      * must be called before calling this method.
122      *
123      * @param sec   The configuration to use when creating endpoint instances
124      * @throws DeploymentException if the endpoint cannot be published as
125      *         requested
126      */
127     @Override
addEndpoint(ServerEndpointConfig sec)128     public void addEndpoint(ServerEndpointConfig sec)
129             throws DeploymentException {
130 
131         if (enforceNoAddAfterHandshake && !addAllowed) {
132             throw new DeploymentException(
133                     sm.getString("serverContainer.addNotAllowed"));
134         }
135 
136         if (servletContext == null) {
137             throw new DeploymentException(
138                     sm.getString("serverContainer.servletContextMissing"));
139         }
140         String path = sec.getPath();
141 
142         // Add method mapping to user properties
143         PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(),
144                 sec.getDecoders(), path);
145         if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null
146                 || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) {
147             sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY,
148                     methodMapping);
149         }
150 
151         UriTemplate uriTemplate = new UriTemplate(path);
152         if (uriTemplate.hasParameters()) {
153             Integer key = Integer.valueOf(uriTemplate.getSegmentCount());
154             SortedSet<TemplatePathMatch> templateMatches =
155                     configTemplateMatchMap.get(key);
156             if (templateMatches == null) {
157                 // Ensure that if concurrent threads execute this block they
158                 // both end up using the same TreeSet instance
159                 templateMatches = new TreeSet<>(
160                         TemplatePathMatchComparator.getInstance());
161                 configTemplateMatchMap.putIfAbsent(key, templateMatches);
162                 templateMatches = configTemplateMatchMap.get(key);
163             }
164             if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) {
165                 // Duplicate uriTemplate;
166                 throw new DeploymentException(
167                         sm.getString("serverContainer.duplicatePaths", path,
168                                      sec.getEndpointClass(),
169                                      sec.getEndpointClass()));
170             }
171         } else {
172             // Exact match
173             ServerEndpointConfig old = configExactMatchMap.put(path, sec);
174             if (old != null) {
175                 // Duplicate path mappings
176                 throw new DeploymentException(
177                         sm.getString("serverContainer.duplicatePaths", path,
178                                      old.getEndpointClass(),
179                                      sec.getEndpointClass()));
180             }
181         }
182 
183         endpointsRegistered = true;
184     }
185 
186 
187     /**
188      * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)}
189      * for publishing plain old java objects (POJOs) that have been annotated as
190      * WebSocket endpoints.
191      *
192      * @param pojo   The annotated POJO
193      */
194     @Override
addEndpoint(Class<?> pojo)195     public void addEndpoint(Class<?> pojo) throws DeploymentException {
196 
197         ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class);
198         if (annotation == null) {
199             throw new DeploymentException(
200                     sm.getString("serverContainer.missingAnnotation",
201                             pojo.getName()));
202         }
203         String path = annotation.value();
204 
205         // Validate encoders
206         validateEncoders(annotation.encoders());
207 
208         // ServerEndpointConfig
209         ServerEndpointConfig sec;
210         Class<? extends Configurator> configuratorClazz =
211                 annotation.configurator();
212         Configurator configurator = null;
213         if (!configuratorClazz.equals(Configurator.class)) {
214             try {
215                 configurator = annotation.configurator().getConstructor().newInstance();
216             } catch (ReflectiveOperationException e) {
217                 throw new DeploymentException(sm.getString(
218                         "serverContainer.configuratorFail",
219                         annotation.configurator().getName(),
220                         pojo.getClass().getName()), e);
221             }
222         }
223         if (configurator == null) {
224             configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator();
225         }
226         sec = ServerEndpointConfig.Builder.create(pojo, path).
227                 decoders(Arrays.asList(annotation.decoders())).
228                 encoders(Arrays.asList(annotation.encoders())).
229                 subprotocols(Arrays.asList(annotation.subprotocols())).
230                 configurator(configurator).
231                 build();
232 
233         addEndpoint(sec);
234     }
235 
236 
areEndpointsRegistered()237     boolean areEndpointsRegistered() {
238         return endpointsRegistered;
239     }
240 
241 
242     /**
243      * Until the WebSocket specification provides such a mechanism, this Tomcat
244      * proprietary method is provided to enable applications to programmatically
245      * determine whether or not to upgrade an individual request to WebSocket.
246      * <p>
247      * Note: This method is not used by Tomcat but is used directly by
248      *       third-party code and must not be removed.
249      *
250      * @param request The request object to be upgraded
251      * @param response The response object to be populated with the result of
252      *                 the upgrade
253      * @param sec The server endpoint to use to process the upgrade request
254      * @param pathParams The path parameters associated with the upgrade request
255      *
256      * @throws ServletException If a configuration error prevents the upgrade
257      *         from taking place
258      * @throws IOException If an I/O error occurs during the upgrade process
259      */
doUpgrade(HttpServletRequest request, HttpServletResponse response, ServerEndpointConfig sec, Map<String,String> pathParams)260     public void doUpgrade(HttpServletRequest request,
261             HttpServletResponse response, ServerEndpointConfig sec,
262             Map<String,String> pathParams)
263             throws ServletException, IOException {
264         UpgradeUtil.doUpgrade(this, request, response, sec, pathParams);
265     }
266 
267 
findMapping(String path)268     public WsMappingResult findMapping(String path) {
269 
270         // Prevent registering additional endpoints once the first attempt has
271         // been made to use one
272         if (addAllowed) {
273             addAllowed = false;
274         }
275 
276         // Check an exact match. Simple case as there are no templates.
277         ServerEndpointConfig sec = configExactMatchMap.get(path);
278         if (sec != null) {
279             return new WsMappingResult(sec, Collections.<String, String>emptyMap());
280         }
281 
282         // No exact match. Need to look for template matches.
283         UriTemplate pathUriTemplate = null;
284         try {
285             pathUriTemplate = new UriTemplate(path);
286         } catch (DeploymentException e) {
287             // Path is not valid so can't be matched to a WebSocketEndpoint
288             return null;
289         }
290 
291         // Number of segments has to match
292         Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount());
293         SortedSet<TemplatePathMatch> templateMatches =
294                 configTemplateMatchMap.get(key);
295 
296         if (templateMatches == null) {
297             // No templates with an equal number of segments so there will be
298             // no matches
299             return null;
300         }
301 
302         // List is in alphabetical order of normalised templates.
303         // Correct match is the first one that matches.
304         Map<String,String> pathParams = null;
305         for (TemplatePathMatch templateMatch : templateMatches) {
306             pathParams = templateMatch.getUriTemplate().match(pathUriTemplate);
307             if (pathParams != null) {
308                 sec = templateMatch.getConfig();
309                 break;
310             }
311         }
312 
313         if (sec == null) {
314             // No match
315             return null;
316         }
317 
318         return new WsMappingResult(sec, pathParams);
319     }
320 
321 
322 
isEnforceNoAddAfterHandshake()323     public boolean isEnforceNoAddAfterHandshake() {
324         return enforceNoAddAfterHandshake;
325     }
326 
327 
setEnforceNoAddAfterHandshake( boolean enforceNoAddAfterHandshake)328     public void setEnforceNoAddAfterHandshake(
329             boolean enforceNoAddAfterHandshake) {
330         this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake;
331     }
332 
333 
334     /**
335      * {@inheritDoc}
336      *
337      * Overridden to make it visible to other classes in this package.
338      */
339     @Override
registerSession(Endpoint endpoint, WsSession wsSession)340     protected void registerSession(Endpoint endpoint, WsSession wsSession) {
341         super.registerSession(endpoint, wsSession);
342         if (wsSession.isOpen() &&
343                 wsSession.getUserPrincipal() != null &&
344                 wsSession.getHttpSessionId() != null) {
345             registerAuthenticatedSession(wsSession,
346                     wsSession.getHttpSessionId());
347         }
348     }
349 
350 
351     /**
352      * {@inheritDoc}
353      *
354      * Overridden to make it visible to other classes in this package.
355      */
356     @Override
unregisterSession(Endpoint endpoint, WsSession wsSession)357     protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
358         if (wsSession.getUserPrincipal() != null &&
359                 wsSession.getHttpSessionId() != null) {
360             unregisterAuthenticatedSession(wsSession,
361                     wsSession.getHttpSessionId());
362         }
363         super.unregisterSession(endpoint, wsSession);
364     }
365 
366 
registerAuthenticatedSession(WsSession wsSession, String httpSessionId)367     private void registerAuthenticatedSession(WsSession wsSession,
368             String httpSessionId) {
369         Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
370         if (wsSessions == null) {
371             wsSessions = Collections.newSetFromMap(
372                      new ConcurrentHashMap<WsSession,Boolean>());
373              authenticatedSessions.putIfAbsent(httpSessionId, wsSessions);
374              wsSessions = authenticatedSessions.get(httpSessionId);
375         }
376         wsSessions.add(wsSession);
377     }
378 
379 
unregisterAuthenticatedSession(WsSession wsSession, String httpSessionId)380     private void unregisterAuthenticatedSession(WsSession wsSession,
381             String httpSessionId) {
382         Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
383         // wsSessions will be null if the HTTP session has ended
384         if (wsSessions != null) {
385             wsSessions.remove(wsSession);
386         }
387     }
388 
389 
closeAuthenticatedSession(String httpSessionId)390     public void closeAuthenticatedSession(String httpSessionId) {
391         Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId);
392 
393         if (wsSessions != null && !wsSessions.isEmpty()) {
394             for (WsSession wsSession : wsSessions) {
395                 try {
396                     wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED);
397                 } catch (IOException e) {
398                     // Any IOExceptions during close will have been caught and the
399                     // onError method called.
400                 }
401             }
402         }
403     }
404 
405 
validateEncoders(Class<? extends Encoder>[] encoders)406     private static void validateEncoders(Class<? extends Encoder>[] encoders)
407             throws DeploymentException {
408 
409         for (Class<? extends Encoder> encoder : encoders) {
410             // Need to instantiate decoder to ensure it is valid and that
411             // deployment can be failed if it is not
412             @SuppressWarnings("unused")
413             Encoder instance;
414             try {
415                 encoder.getConstructor().newInstance();
416             } catch(ReflectiveOperationException e) {
417                 throw new DeploymentException(sm.getString(
418                         "serverContainer.encoderFail", encoder.getName()), e);
419             }
420         }
421     }
422 
423 
424     private static class TemplatePathMatch {
425         private final ServerEndpointConfig config;
426         private final UriTemplate uriTemplate;
427 
TemplatePathMatch(ServerEndpointConfig config, UriTemplate uriTemplate)428         public TemplatePathMatch(ServerEndpointConfig config,
429                 UriTemplate uriTemplate) {
430             this.config = config;
431             this.uriTemplate = uriTemplate;
432         }
433 
434 
getConfig()435         public ServerEndpointConfig getConfig() {
436             return config;
437         }
438 
439 
getUriTemplate()440         public UriTemplate getUriTemplate() {
441             return uriTemplate;
442         }
443     }
444 
445 
446     /**
447      * This Comparator implementation is thread-safe so only create a single
448      * instance.
449      */
450     private static class TemplatePathMatchComparator
451             implements Comparator<TemplatePathMatch> {
452 
453         private static final TemplatePathMatchComparator INSTANCE =
454                 new TemplatePathMatchComparator();
455 
getInstance()456         public static TemplatePathMatchComparator getInstance() {
457             return INSTANCE;
458         }
459 
TemplatePathMatchComparator()460         private TemplatePathMatchComparator() {
461             // Hide default constructor
462         }
463 
464         @Override
compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2)465         public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) {
466             return tpm1.getUriTemplate().getNormalizedPath().compareTo(
467                     tpm2.getUriTemplate().getNormalizedPath());
468         }
469     }
470 }
471