xref: /unit/src/java/nginx/unit/websocket/server/UpgradeUtil.java (revision 1157:7ae152bda303)
1*1157Smax.romanov@nginx.com /*
2*1157Smax.romanov@nginx.com  * Licensed to the Apache Software Foundation (ASF) under one or more
3*1157Smax.romanov@nginx.com  * contributor license agreements.  See the NOTICE file distributed with
4*1157Smax.romanov@nginx.com  * this work for additional information regarding copyright ownership.
5*1157Smax.romanov@nginx.com  * The ASF licenses this file to You under the Apache License, Version 2.0
6*1157Smax.romanov@nginx.com  * (the "License"); you may not use this file except in compliance with
7*1157Smax.romanov@nginx.com  * the License.  You may obtain a copy of the License at
8*1157Smax.romanov@nginx.com  *
9*1157Smax.romanov@nginx.com  *     http://www.apache.org/licenses/LICENSE-2.0
10*1157Smax.romanov@nginx.com  *
11*1157Smax.romanov@nginx.com  * Unless required by applicable law or agreed to in writing, software
12*1157Smax.romanov@nginx.com  * distributed under the License is distributed on an "AS IS" BASIS,
13*1157Smax.romanov@nginx.com  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14*1157Smax.romanov@nginx.com  * See the License for the specific language governing permissions and
15*1157Smax.romanov@nginx.com  * limitations under the License.
16*1157Smax.romanov@nginx.com  */
17*1157Smax.romanov@nginx.com package nginx.unit.websocket.server;
18*1157Smax.romanov@nginx.com 
19*1157Smax.romanov@nginx.com import java.io.IOException;
20*1157Smax.romanov@nginx.com import java.nio.charset.StandardCharsets;
21*1157Smax.romanov@nginx.com import java.util.ArrayList;
22*1157Smax.romanov@nginx.com import java.util.Collections;
23*1157Smax.romanov@nginx.com import java.util.Enumeration;
24*1157Smax.romanov@nginx.com import java.util.LinkedHashMap;
25*1157Smax.romanov@nginx.com import java.util.List;
26*1157Smax.romanov@nginx.com import java.util.Map;
27*1157Smax.romanov@nginx.com import java.util.Map.Entry;
28*1157Smax.romanov@nginx.com 
29*1157Smax.romanov@nginx.com import javax.servlet.ServletException;
30*1157Smax.romanov@nginx.com import javax.servlet.ServletRequest;
31*1157Smax.romanov@nginx.com import javax.servlet.ServletResponse;
32*1157Smax.romanov@nginx.com import javax.servlet.http.HttpServletRequest;
33*1157Smax.romanov@nginx.com import javax.servlet.http.HttpServletResponse;
34*1157Smax.romanov@nginx.com import javax.websocket.Endpoint;
35*1157Smax.romanov@nginx.com import javax.websocket.Extension;
36*1157Smax.romanov@nginx.com import javax.websocket.HandshakeResponse;
37*1157Smax.romanov@nginx.com import javax.websocket.server.ServerEndpointConfig;
38*1157Smax.romanov@nginx.com 
39*1157Smax.romanov@nginx.com import nginx.unit.Request;
40*1157Smax.romanov@nginx.com 
41*1157Smax.romanov@nginx.com import org.apache.tomcat.util.codec.binary.Base64;
42*1157Smax.romanov@nginx.com import org.apache.tomcat.util.res.StringManager;
43*1157Smax.romanov@nginx.com import org.apache.tomcat.util.security.ConcurrentMessageDigest;
44*1157Smax.romanov@nginx.com import nginx.unit.websocket.Constants;
45*1157Smax.romanov@nginx.com import nginx.unit.websocket.Transformation;
46*1157Smax.romanov@nginx.com import nginx.unit.websocket.TransformationFactory;
47*1157Smax.romanov@nginx.com import nginx.unit.websocket.Util;
48*1157Smax.romanov@nginx.com import nginx.unit.websocket.WsHandshakeResponse;
49*1157Smax.romanov@nginx.com import nginx.unit.websocket.pojo.PojoEndpointServer;
50*1157Smax.romanov@nginx.com 
51*1157Smax.romanov@nginx.com public class UpgradeUtil {
52*1157Smax.romanov@nginx.com 
53*1157Smax.romanov@nginx.com     private static final StringManager sm =
54*1157Smax.romanov@nginx.com             StringManager.getManager(UpgradeUtil.class.getPackage().getName());
55*1157Smax.romanov@nginx.com     private static final byte[] WS_ACCEPT =
56*1157Smax.romanov@nginx.com             "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
57*1157Smax.romanov@nginx.com                     StandardCharsets.ISO_8859_1);
58*1157Smax.romanov@nginx.com 
UpgradeUtil()59*1157Smax.romanov@nginx.com     private UpgradeUtil() {
60*1157Smax.romanov@nginx.com         // Utility class. Hide default constructor.
61*1157Smax.romanov@nginx.com     }
62*1157Smax.romanov@nginx.com 
63*1157Smax.romanov@nginx.com     /**
64*1157Smax.romanov@nginx.com      * Checks to see if this is an HTTP request that includes a valid upgrade
65*1157Smax.romanov@nginx.com      * request to web socket.
66*1157Smax.romanov@nginx.com      * <p>
67*1157Smax.romanov@nginx.com      * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
68*1157Smax.romanov@nginx.com      *       WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
69*1157Smax.romanov@nginx.com      *       6455 section 4.1 requires that a WebSocket Upgrade uses GET.
70*1157Smax.romanov@nginx.com      * @param request  The request to check if it is an HTTP upgrade request for
71*1157Smax.romanov@nginx.com      *                 a WebSocket connection
72*1157Smax.romanov@nginx.com      * @param response The response associated with the request
73*1157Smax.romanov@nginx.com      * @return <code>true</code> if the request includes a HTTP Upgrade request
74*1157Smax.romanov@nginx.com      *         for the WebSocket protocol, otherwise <code>false</code>
75*1157Smax.romanov@nginx.com      */
isWebSocketUpgradeRequest(ServletRequest request, ServletResponse response)76*1157Smax.romanov@nginx.com     public static boolean isWebSocketUpgradeRequest(ServletRequest request,
77*1157Smax.romanov@nginx.com             ServletResponse response) {
78*1157Smax.romanov@nginx.com 
79*1157Smax.romanov@nginx.com         Request r = (Request) request.getAttribute(Request.BARE);
80*1157Smax.romanov@nginx.com 
81*1157Smax.romanov@nginx.com         return ((request instanceof HttpServletRequest) &&
82*1157Smax.romanov@nginx.com                 (response instanceof HttpServletResponse) &&
83*1157Smax.romanov@nginx.com                 (r != null) &&
84*1157Smax.romanov@nginx.com                 (r.isUpgrade()));
85*1157Smax.romanov@nginx.com     }
86*1157Smax.romanov@nginx.com 
87*1157Smax.romanov@nginx.com 
doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp, ServerEndpointConfig sec, Map<String,String> pathParams)88*1157Smax.romanov@nginx.com     public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
89*1157Smax.romanov@nginx.com             HttpServletResponse resp, ServerEndpointConfig sec,
90*1157Smax.romanov@nginx.com             Map<String,String> pathParams)
91*1157Smax.romanov@nginx.com             throws ServletException, IOException {
92*1157Smax.romanov@nginx.com 
93*1157Smax.romanov@nginx.com 
94*1157Smax.romanov@nginx.com         // Origin check
95*1157Smax.romanov@nginx.com         String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
96*1157Smax.romanov@nginx.com 
97*1157Smax.romanov@nginx.com         if (!sec.getConfigurator().checkOrigin(origin)) {
98*1157Smax.romanov@nginx.com             resp.sendError(HttpServletResponse.SC_FORBIDDEN);
99*1157Smax.romanov@nginx.com             return;
100*1157Smax.romanov@nginx.com         }
101*1157Smax.romanov@nginx.com         // Sub-protocols
102*1157Smax.romanov@nginx.com         List<String> subProtocols = getTokensFromHeader(req,
103*1157Smax.romanov@nginx.com                 Constants.WS_PROTOCOL_HEADER_NAME);
104*1157Smax.romanov@nginx.com         String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
105*1157Smax.romanov@nginx.com                 sec.getSubprotocols(), subProtocols);
106*1157Smax.romanov@nginx.com 
107*1157Smax.romanov@nginx.com         // Extensions
108*1157Smax.romanov@nginx.com         // Should normally only be one header but handle the case of multiple
109*1157Smax.romanov@nginx.com         // headers
110*1157Smax.romanov@nginx.com         List<Extension> extensionsRequested = new ArrayList<>();
111*1157Smax.romanov@nginx.com         Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
112*1157Smax.romanov@nginx.com         while (extHeaders.hasMoreElements()) {
113*1157Smax.romanov@nginx.com             Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
114*1157Smax.romanov@nginx.com         }
115*1157Smax.romanov@nginx.com 
116*1157Smax.romanov@nginx.com         // Negotiation phase 1. By default this simply filters out the
117*1157Smax.romanov@nginx.com         // extensions that the server does not support but applications could
118*1157Smax.romanov@nginx.com         // use a custom configurator to do more than this.
119*1157Smax.romanov@nginx.com         List<Extension> installedExtensions = null;
120*1157Smax.romanov@nginx.com         if (sec.getExtensions().size() == 0) {
121*1157Smax.romanov@nginx.com             installedExtensions = Constants.INSTALLED_EXTENSIONS;
122*1157Smax.romanov@nginx.com         } else {
123*1157Smax.romanov@nginx.com             installedExtensions = new ArrayList<>();
124*1157Smax.romanov@nginx.com             installedExtensions.addAll(sec.getExtensions());
125*1157Smax.romanov@nginx.com             installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
126*1157Smax.romanov@nginx.com         }
127*1157Smax.romanov@nginx.com         List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
128*1157Smax.romanov@nginx.com                 installedExtensions, extensionsRequested);
129*1157Smax.romanov@nginx.com 
130*1157Smax.romanov@nginx.com         // Negotiation phase 2. Create the Transformations that will be applied
131*1157Smax.romanov@nginx.com         // to this connection. Note than an extension may be dropped at this
132*1157Smax.romanov@nginx.com         // point if the client has requested a configuration that the server is
133*1157Smax.romanov@nginx.com         // unable to support.
134*1157Smax.romanov@nginx.com         List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
135*1157Smax.romanov@nginx.com 
136*1157Smax.romanov@nginx.com         List<Extension> negotiatedExtensionsPhase2;
137*1157Smax.romanov@nginx.com         if (transformations.isEmpty()) {
138*1157Smax.romanov@nginx.com             negotiatedExtensionsPhase2 = Collections.emptyList();
139*1157Smax.romanov@nginx.com         } else {
140*1157Smax.romanov@nginx.com             negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
141*1157Smax.romanov@nginx.com             for (Transformation t : transformations) {
142*1157Smax.romanov@nginx.com                 negotiatedExtensionsPhase2.add(t.getExtensionResponse());
143*1157Smax.romanov@nginx.com             }
144*1157Smax.romanov@nginx.com         }
145*1157Smax.romanov@nginx.com 
146*1157Smax.romanov@nginx.com         WsHttpUpgradeHandler wsHandler =
147*1157Smax.romanov@nginx.com                 req.upgrade(WsHttpUpgradeHandler.class);
148*1157Smax.romanov@nginx.com 
149*1157Smax.romanov@nginx.com         WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
150*1157Smax.romanov@nginx.com         WsHandshakeResponse wsResponse = new WsHandshakeResponse();
151*1157Smax.romanov@nginx.com         WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
152*1157Smax.romanov@nginx.com                 new WsPerSessionServerEndpointConfig(sec);
153*1157Smax.romanov@nginx.com         sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
154*1157Smax.romanov@nginx.com                 wsRequest, wsResponse);
155*1157Smax.romanov@nginx.com         //wsRequest.finished();
156*1157Smax.romanov@nginx.com 
157*1157Smax.romanov@nginx.com         // Add any additional headers
158*1157Smax.romanov@nginx.com         for (Entry<String,List<String>> entry :
159*1157Smax.romanov@nginx.com                 wsResponse.getHeaders().entrySet()) {
160*1157Smax.romanov@nginx.com             for (String headerValue: entry.getValue()) {
161*1157Smax.romanov@nginx.com                 resp.addHeader(entry.getKey(), headerValue);
162*1157Smax.romanov@nginx.com             }
163*1157Smax.romanov@nginx.com         }
164*1157Smax.romanov@nginx.com 
165*1157Smax.romanov@nginx.com         Endpoint ep;
166*1157Smax.romanov@nginx.com         try {
167*1157Smax.romanov@nginx.com             Class<?> clazz = sec.getEndpointClass();
168*1157Smax.romanov@nginx.com             if (Endpoint.class.isAssignableFrom(clazz)) {
169*1157Smax.romanov@nginx.com                 ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
170*1157Smax.romanov@nginx.com                         clazz);
171*1157Smax.romanov@nginx.com             } else {
172*1157Smax.romanov@nginx.com                 ep = new PojoEndpointServer();
173*1157Smax.romanov@nginx.com                 // Need to make path params available to POJO
174*1157Smax.romanov@nginx.com                 perSessionServerEndpointConfig.getUserProperties().put(
175*1157Smax.romanov@nginx.com                         nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
176*1157Smax.romanov@nginx.com             }
177*1157Smax.romanov@nginx.com         } catch (InstantiationException e) {
178*1157Smax.romanov@nginx.com             throw new ServletException(e);
179*1157Smax.romanov@nginx.com         }
180*1157Smax.romanov@nginx.com 
181*1157Smax.romanov@nginx.com         wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
182*1157Smax.romanov@nginx.com                 negotiatedExtensionsPhase2, subProtocol, null, pathParams,
183*1157Smax.romanov@nginx.com                 req.isSecure());
184*1157Smax.romanov@nginx.com 
185*1157Smax.romanov@nginx.com         wsHandler.init(null);
186*1157Smax.romanov@nginx.com     }
187*1157Smax.romanov@nginx.com 
188*1157Smax.romanov@nginx.com 
createTransformations( List<Extension> negotiatedExtensions)189*1157Smax.romanov@nginx.com     private static List<Transformation> createTransformations(
190*1157Smax.romanov@nginx.com             List<Extension> negotiatedExtensions) {
191*1157Smax.romanov@nginx.com 
192*1157Smax.romanov@nginx.com         TransformationFactory factory = TransformationFactory.getInstance();
193*1157Smax.romanov@nginx.com 
194*1157Smax.romanov@nginx.com         LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
195*1157Smax.romanov@nginx.com                 new LinkedHashMap<>();
196*1157Smax.romanov@nginx.com 
197*1157Smax.romanov@nginx.com         // Result will likely be smaller than this
198*1157Smax.romanov@nginx.com         List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
199*1157Smax.romanov@nginx.com 
200*1157Smax.romanov@nginx.com         for (Extension extension : negotiatedExtensions) {
201*1157Smax.romanov@nginx.com             List<List<Extension.Parameter>> preferences =
202*1157Smax.romanov@nginx.com                     extensionPreferences.get(extension.getName());
203*1157Smax.romanov@nginx.com 
204*1157Smax.romanov@nginx.com             if (preferences == null) {
205*1157Smax.romanov@nginx.com                 preferences = new ArrayList<>();
206*1157Smax.romanov@nginx.com                 extensionPreferences.put(extension.getName(), preferences);
207*1157Smax.romanov@nginx.com             }
208*1157Smax.romanov@nginx.com 
209*1157Smax.romanov@nginx.com             preferences.add(extension.getParameters());
210*1157Smax.romanov@nginx.com         }
211*1157Smax.romanov@nginx.com 
212*1157Smax.romanov@nginx.com         for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
213*1157Smax.romanov@nginx.com             extensionPreferences.entrySet()) {
214*1157Smax.romanov@nginx.com             Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
215*1157Smax.romanov@nginx.com             if (transformation != null) {
216*1157Smax.romanov@nginx.com                 result.add(transformation);
217*1157Smax.romanov@nginx.com             }
218*1157Smax.romanov@nginx.com         }
219*1157Smax.romanov@nginx.com         return result;
220*1157Smax.romanov@nginx.com     }
221*1157Smax.romanov@nginx.com 
222*1157Smax.romanov@nginx.com 
append(StringBuilder sb, Extension extension)223*1157Smax.romanov@nginx.com     private static void append(StringBuilder sb, Extension extension) {
224*1157Smax.romanov@nginx.com         if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
225*1157Smax.romanov@nginx.com             return;
226*1157Smax.romanov@nginx.com         }
227*1157Smax.romanov@nginx.com 
228*1157Smax.romanov@nginx.com         sb.append(extension.getName());
229*1157Smax.romanov@nginx.com 
230*1157Smax.romanov@nginx.com         for (Extension.Parameter p : extension.getParameters()) {
231*1157Smax.romanov@nginx.com             sb.append(';');
232*1157Smax.romanov@nginx.com             sb.append(p.getName());
233*1157Smax.romanov@nginx.com             if (p.getValue() != null) {
234*1157Smax.romanov@nginx.com                 sb.append('=');
235*1157Smax.romanov@nginx.com                 sb.append(p.getValue());
236*1157Smax.romanov@nginx.com             }
237*1157Smax.romanov@nginx.com         }
238*1157Smax.romanov@nginx.com     }
239*1157Smax.romanov@nginx.com 
240*1157Smax.romanov@nginx.com 
241*1157Smax.romanov@nginx.com     /*
242*1157Smax.romanov@nginx.com      * This only works for tokens. Quoted strings need more sophisticated
243*1157Smax.romanov@nginx.com      * parsing.
244*1157Smax.romanov@nginx.com      */
headerContainsToken(HttpServletRequest req, String headerName, String target)245*1157Smax.romanov@nginx.com     private static boolean headerContainsToken(HttpServletRequest req,
246*1157Smax.romanov@nginx.com             String headerName, String target) {
247*1157Smax.romanov@nginx.com         Enumeration<String> headers = req.getHeaders(headerName);
248*1157Smax.romanov@nginx.com         while (headers.hasMoreElements()) {
249*1157Smax.romanov@nginx.com             String header = headers.nextElement();
250*1157Smax.romanov@nginx.com             String[] tokens = header.split(",");
251*1157Smax.romanov@nginx.com             for (String token : tokens) {
252*1157Smax.romanov@nginx.com                 if (target.equalsIgnoreCase(token.trim())) {
253*1157Smax.romanov@nginx.com                     return true;
254*1157Smax.romanov@nginx.com                 }
255*1157Smax.romanov@nginx.com             }
256*1157Smax.romanov@nginx.com         }
257*1157Smax.romanov@nginx.com         return false;
258*1157Smax.romanov@nginx.com     }
259*1157Smax.romanov@nginx.com 
260*1157Smax.romanov@nginx.com 
261*1157Smax.romanov@nginx.com     /*
262*1157Smax.romanov@nginx.com      * This only works for tokens. Quoted strings need more sophisticated
263*1157Smax.romanov@nginx.com      * parsing.
264*1157Smax.romanov@nginx.com      */
getTokensFromHeader(HttpServletRequest req, String headerName)265*1157Smax.romanov@nginx.com     private static List<String> getTokensFromHeader(HttpServletRequest req,
266*1157Smax.romanov@nginx.com             String headerName) {
267*1157Smax.romanov@nginx.com         List<String> result = new ArrayList<>();
268*1157Smax.romanov@nginx.com         Enumeration<String> headers = req.getHeaders(headerName);
269*1157Smax.romanov@nginx.com         while (headers.hasMoreElements()) {
270*1157Smax.romanov@nginx.com             String header = headers.nextElement();
271*1157Smax.romanov@nginx.com             String[] tokens = header.split(",");
272*1157Smax.romanov@nginx.com             for (String token : tokens) {
273*1157Smax.romanov@nginx.com                 result.add(token.trim());
274*1157Smax.romanov@nginx.com             }
275*1157Smax.romanov@nginx.com         }
276*1157Smax.romanov@nginx.com         return result;
277*1157Smax.romanov@nginx.com     }
278*1157Smax.romanov@nginx.com 
279*1157Smax.romanov@nginx.com 
getWebSocketAccept(String key)280*1157Smax.romanov@nginx.com     private static String getWebSocketAccept(String key) {
281*1157Smax.romanov@nginx.com         byte[] digest = ConcurrentMessageDigest.digestSHA1(
282*1157Smax.romanov@nginx.com                 key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
283*1157Smax.romanov@nginx.com         return Base64.encodeBase64String(digest);
284*1157Smax.romanov@nginx.com     }
285*1157Smax.romanov@nginx.com }
286