/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
16*1157Smax.romanov@nginx.com */ 17*1157Smax.romanov@nginx.com package nginx.unit.websocket.server; 18*1157Smax.romanov@nginx.com 19*1157Smax.romanov@nginx.com import java.io.IOException; 20*1157Smax.romanov@nginx.com import java.nio.charset.StandardCharsets; 21*1157Smax.romanov@nginx.com import java.util.ArrayList; 22*1157Smax.romanov@nginx.com import java.util.Collections; 23*1157Smax.romanov@nginx.com import java.util.Enumeration; 24*1157Smax.romanov@nginx.com import java.util.LinkedHashMap; 25*1157Smax.romanov@nginx.com import java.util.List; 26*1157Smax.romanov@nginx.com import java.util.Map; 27*1157Smax.romanov@nginx.com import java.util.Map.Entry; 28*1157Smax.romanov@nginx.com 29*1157Smax.romanov@nginx.com import javax.servlet.ServletException; 30*1157Smax.romanov@nginx.com import javax.servlet.ServletRequest; 31*1157Smax.romanov@nginx.com import javax.servlet.ServletResponse; 32*1157Smax.romanov@nginx.com import javax.servlet.http.HttpServletRequest; 33*1157Smax.romanov@nginx.com import javax.servlet.http.HttpServletResponse; 34*1157Smax.romanov@nginx.com import javax.websocket.Endpoint; 35*1157Smax.romanov@nginx.com import javax.websocket.Extension; 36*1157Smax.romanov@nginx.com import javax.websocket.HandshakeResponse; 37*1157Smax.romanov@nginx.com import javax.websocket.server.ServerEndpointConfig; 38*1157Smax.romanov@nginx.com 39*1157Smax.romanov@nginx.com import nginx.unit.Request; 40*1157Smax.romanov@nginx.com 41*1157Smax.romanov@nginx.com import org.apache.tomcat.util.codec.binary.Base64; 42*1157Smax.romanov@nginx.com import org.apache.tomcat.util.res.StringManager; 43*1157Smax.romanov@nginx.com import org.apache.tomcat.util.security.ConcurrentMessageDigest; 44*1157Smax.romanov@nginx.com import nginx.unit.websocket.Constants; 45*1157Smax.romanov@nginx.com import nginx.unit.websocket.Transformation; 46*1157Smax.romanov@nginx.com import nginx.unit.websocket.TransformationFactory; 47*1157Smax.romanov@nginx.com import nginx.unit.websocket.Util; 
48*1157Smax.romanov@nginx.com import nginx.unit.websocket.WsHandshakeResponse; 49*1157Smax.romanov@nginx.com import nginx.unit.websocket.pojo.PojoEndpointServer; 50*1157Smax.romanov@nginx.com 51*1157Smax.romanov@nginx.com public class UpgradeUtil { 52*1157Smax.romanov@nginx.com 53*1157Smax.romanov@nginx.com private static final StringManager sm = 54*1157Smax.romanov@nginx.com StringManager.getManager(UpgradeUtil.class.getPackage().getName()); 55*1157Smax.romanov@nginx.com private static final byte[] WS_ACCEPT = 56*1157Smax.romanov@nginx.com "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( 57*1157Smax.romanov@nginx.com StandardCharsets.ISO_8859_1); 58*1157Smax.romanov@nginx.com UpgradeUtil()59*1157Smax.romanov@nginx.com private UpgradeUtil() { 60*1157Smax.romanov@nginx.com // Utility class. Hide default constructor. 61*1157Smax.romanov@nginx.com } 62*1157Smax.romanov@nginx.com 63*1157Smax.romanov@nginx.com /** 64*1157Smax.romanov@nginx.com * Checks to see if this is an HTTP request that includes a valid upgrade 65*1157Smax.romanov@nginx.com * request to web socket. 66*1157Smax.romanov@nginx.com * <p> 67*1157Smax.romanov@nginx.com * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java 68*1157Smax.romanov@nginx.com * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC 69*1157Smax.romanov@nginx.com * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. 
70*1157Smax.romanov@nginx.com * @param request The request to check if it is an HTTP upgrade request for 71*1157Smax.romanov@nginx.com * a WebSocket connection 72*1157Smax.romanov@nginx.com * @param response The response associated with the request 73*1157Smax.romanov@nginx.com * @return <code>true</code> if the request includes a HTTP Upgrade request 74*1157Smax.romanov@nginx.com * for the WebSocket protocol, otherwise <code>false</code> 75*1157Smax.romanov@nginx.com */ isWebSocketUpgradeRequest(ServletRequest request, ServletResponse response)76*1157Smax.romanov@nginx.com public static boolean isWebSocketUpgradeRequest(ServletRequest request, 77*1157Smax.romanov@nginx.com ServletResponse response) { 78*1157Smax.romanov@nginx.com 79*1157Smax.romanov@nginx.com Request r = (Request) request.getAttribute(Request.BARE); 80*1157Smax.romanov@nginx.com 81*1157Smax.romanov@nginx.com return ((request instanceof HttpServletRequest) && 82*1157Smax.romanov@nginx.com (response instanceof HttpServletResponse) && 83*1157Smax.romanov@nginx.com (r != null) && 84*1157Smax.romanov@nginx.com (r.isUpgrade())); 85*1157Smax.romanov@nginx.com } 86*1157Smax.romanov@nginx.com 87*1157Smax.romanov@nginx.com doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp, ServerEndpointConfig sec, Map<String,String> pathParams)88*1157Smax.romanov@nginx.com public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, 89*1157Smax.romanov@nginx.com HttpServletResponse resp, ServerEndpointConfig sec, 90*1157Smax.romanov@nginx.com Map<String,String> pathParams) 91*1157Smax.romanov@nginx.com throws ServletException, IOException { 92*1157Smax.romanov@nginx.com 93*1157Smax.romanov@nginx.com 94*1157Smax.romanov@nginx.com // Origin check 95*1157Smax.romanov@nginx.com String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); 96*1157Smax.romanov@nginx.com 97*1157Smax.romanov@nginx.com if (!sec.getConfigurator().checkOrigin(origin)) { 98*1157Smax.romanov@nginx.com 
resp.sendError(HttpServletResponse.SC_FORBIDDEN); 99*1157Smax.romanov@nginx.com return; 100*1157Smax.romanov@nginx.com } 101*1157Smax.romanov@nginx.com // Sub-protocols 102*1157Smax.romanov@nginx.com List<String> subProtocols = getTokensFromHeader(req, 103*1157Smax.romanov@nginx.com Constants.WS_PROTOCOL_HEADER_NAME); 104*1157Smax.romanov@nginx.com String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( 105*1157Smax.romanov@nginx.com sec.getSubprotocols(), subProtocols); 106*1157Smax.romanov@nginx.com 107*1157Smax.romanov@nginx.com // Extensions 108*1157Smax.romanov@nginx.com // Should normally only be one header but handle the case of multiple 109*1157Smax.romanov@nginx.com // headers 110*1157Smax.romanov@nginx.com List<Extension> extensionsRequested = new ArrayList<>(); 111*1157Smax.romanov@nginx.com Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); 112*1157Smax.romanov@nginx.com while (extHeaders.hasMoreElements()) { 113*1157Smax.romanov@nginx.com Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); 114*1157Smax.romanov@nginx.com } 115*1157Smax.romanov@nginx.com 116*1157Smax.romanov@nginx.com // Negotiation phase 1. By default this simply filters out the 117*1157Smax.romanov@nginx.com // extensions that the server does not support but applications could 118*1157Smax.romanov@nginx.com // use a custom configurator to do more than this. 
119*1157Smax.romanov@nginx.com List<Extension> installedExtensions = null; 120*1157Smax.romanov@nginx.com if (sec.getExtensions().size() == 0) { 121*1157Smax.romanov@nginx.com installedExtensions = Constants.INSTALLED_EXTENSIONS; 122*1157Smax.romanov@nginx.com } else { 123*1157Smax.romanov@nginx.com installedExtensions = new ArrayList<>(); 124*1157Smax.romanov@nginx.com installedExtensions.addAll(sec.getExtensions()); 125*1157Smax.romanov@nginx.com installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); 126*1157Smax.romanov@nginx.com } 127*1157Smax.romanov@nginx.com List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( 128*1157Smax.romanov@nginx.com installedExtensions, extensionsRequested); 129*1157Smax.romanov@nginx.com 130*1157Smax.romanov@nginx.com // Negotiation phase 2. Create the Transformations that will be applied 131*1157Smax.romanov@nginx.com // to this connection. Note than an extension may be dropped at this 132*1157Smax.romanov@nginx.com // point if the client has requested a configuration that the server is 133*1157Smax.romanov@nginx.com // unable to support. 
134*1157Smax.romanov@nginx.com List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1); 135*1157Smax.romanov@nginx.com 136*1157Smax.romanov@nginx.com List<Extension> negotiatedExtensionsPhase2; 137*1157Smax.romanov@nginx.com if (transformations.isEmpty()) { 138*1157Smax.romanov@nginx.com negotiatedExtensionsPhase2 = Collections.emptyList(); 139*1157Smax.romanov@nginx.com } else { 140*1157Smax.romanov@nginx.com negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); 141*1157Smax.romanov@nginx.com for (Transformation t : transformations) { 142*1157Smax.romanov@nginx.com negotiatedExtensionsPhase2.add(t.getExtensionResponse()); 143*1157Smax.romanov@nginx.com } 144*1157Smax.romanov@nginx.com } 145*1157Smax.romanov@nginx.com 146*1157Smax.romanov@nginx.com WsHttpUpgradeHandler wsHandler = 147*1157Smax.romanov@nginx.com req.upgrade(WsHttpUpgradeHandler.class); 148*1157Smax.romanov@nginx.com 149*1157Smax.romanov@nginx.com WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); 150*1157Smax.romanov@nginx.com WsHandshakeResponse wsResponse = new WsHandshakeResponse(); 151*1157Smax.romanov@nginx.com WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = 152*1157Smax.romanov@nginx.com new WsPerSessionServerEndpointConfig(sec); 153*1157Smax.romanov@nginx.com sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, 154*1157Smax.romanov@nginx.com wsRequest, wsResponse); 155*1157Smax.romanov@nginx.com //wsRequest.finished(); 156*1157Smax.romanov@nginx.com 157*1157Smax.romanov@nginx.com // Add any additional headers 158*1157Smax.romanov@nginx.com for (Entry<String,List<String>> entry : 159*1157Smax.romanov@nginx.com wsResponse.getHeaders().entrySet()) { 160*1157Smax.romanov@nginx.com for (String headerValue: entry.getValue()) { 161*1157Smax.romanov@nginx.com resp.addHeader(entry.getKey(), headerValue); 162*1157Smax.romanov@nginx.com } 163*1157Smax.romanov@nginx.com } 
164*1157Smax.romanov@nginx.com 165*1157Smax.romanov@nginx.com Endpoint ep; 166*1157Smax.romanov@nginx.com try { 167*1157Smax.romanov@nginx.com Class<?> clazz = sec.getEndpointClass(); 168*1157Smax.romanov@nginx.com if (Endpoint.class.isAssignableFrom(clazz)) { 169*1157Smax.romanov@nginx.com ep = (Endpoint) sec.getConfigurator().getEndpointInstance( 170*1157Smax.romanov@nginx.com clazz); 171*1157Smax.romanov@nginx.com } else { 172*1157Smax.romanov@nginx.com ep = new PojoEndpointServer(); 173*1157Smax.romanov@nginx.com // Need to make path params available to POJO 174*1157Smax.romanov@nginx.com perSessionServerEndpointConfig.getUserProperties().put( 175*1157Smax.romanov@nginx.com nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); 176*1157Smax.romanov@nginx.com } 177*1157Smax.romanov@nginx.com } catch (InstantiationException e) { 178*1157Smax.romanov@nginx.com throw new ServletException(e); 179*1157Smax.romanov@nginx.com } 180*1157Smax.romanov@nginx.com 181*1157Smax.romanov@nginx.com wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, 182*1157Smax.romanov@nginx.com negotiatedExtensionsPhase2, subProtocol, null, pathParams, 183*1157Smax.romanov@nginx.com req.isSecure()); 184*1157Smax.romanov@nginx.com 185*1157Smax.romanov@nginx.com wsHandler.init(null); 186*1157Smax.romanov@nginx.com } 187*1157Smax.romanov@nginx.com 188*1157Smax.romanov@nginx.com createTransformations( List<Extension> negotiatedExtensions)189*1157Smax.romanov@nginx.com private static List<Transformation> createTransformations( 190*1157Smax.romanov@nginx.com List<Extension> negotiatedExtensions) { 191*1157Smax.romanov@nginx.com 192*1157Smax.romanov@nginx.com TransformationFactory factory = TransformationFactory.getInstance(); 193*1157Smax.romanov@nginx.com 194*1157Smax.romanov@nginx.com LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences = 195*1157Smax.romanov@nginx.com new LinkedHashMap<>(); 196*1157Smax.romanov@nginx.com 
197*1157Smax.romanov@nginx.com // Result will likely be smaller than this 198*1157Smax.romanov@nginx.com List<Transformation> result = new ArrayList<>(negotiatedExtensions.size()); 199*1157Smax.romanov@nginx.com 200*1157Smax.romanov@nginx.com for (Extension extension : negotiatedExtensions) { 201*1157Smax.romanov@nginx.com List<List<Extension.Parameter>> preferences = 202*1157Smax.romanov@nginx.com extensionPreferences.get(extension.getName()); 203*1157Smax.romanov@nginx.com 204*1157Smax.romanov@nginx.com if (preferences == null) { 205*1157Smax.romanov@nginx.com preferences = new ArrayList<>(); 206*1157Smax.romanov@nginx.com extensionPreferences.put(extension.getName(), preferences); 207*1157Smax.romanov@nginx.com } 208*1157Smax.romanov@nginx.com 209*1157Smax.romanov@nginx.com preferences.add(extension.getParameters()); 210*1157Smax.romanov@nginx.com } 211*1157Smax.romanov@nginx.com 212*1157Smax.romanov@nginx.com for (Map.Entry<String,List<List<Extension.Parameter>>> entry : 213*1157Smax.romanov@nginx.com extensionPreferences.entrySet()) { 214*1157Smax.romanov@nginx.com Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); 215*1157Smax.romanov@nginx.com if (transformation != null) { 216*1157Smax.romanov@nginx.com result.add(transformation); 217*1157Smax.romanov@nginx.com } 218*1157Smax.romanov@nginx.com } 219*1157Smax.romanov@nginx.com return result; 220*1157Smax.romanov@nginx.com } 221*1157Smax.romanov@nginx.com 222*1157Smax.romanov@nginx.com append(StringBuilder sb, Extension extension)223*1157Smax.romanov@nginx.com private static void append(StringBuilder sb, Extension extension) { 224*1157Smax.romanov@nginx.com if (extension == null || extension.getName() == null || extension.getName().length() == 0) { 225*1157Smax.romanov@nginx.com return; 226*1157Smax.romanov@nginx.com } 227*1157Smax.romanov@nginx.com 228*1157Smax.romanov@nginx.com sb.append(extension.getName()); 229*1157Smax.romanov@nginx.com 230*1157Smax.romanov@nginx.com 
for (Extension.Parameter p : extension.getParameters()) { 231*1157Smax.romanov@nginx.com sb.append(';'); 232*1157Smax.romanov@nginx.com sb.append(p.getName()); 233*1157Smax.romanov@nginx.com if (p.getValue() != null) { 234*1157Smax.romanov@nginx.com sb.append('='); 235*1157Smax.romanov@nginx.com sb.append(p.getValue()); 236*1157Smax.romanov@nginx.com } 237*1157Smax.romanov@nginx.com } 238*1157Smax.romanov@nginx.com } 239*1157Smax.romanov@nginx.com 240*1157Smax.romanov@nginx.com 241*1157Smax.romanov@nginx.com /* 242*1157Smax.romanov@nginx.com * This only works for tokens. Quoted strings need more sophisticated 243*1157Smax.romanov@nginx.com * parsing. 244*1157Smax.romanov@nginx.com */ headerContainsToken(HttpServletRequest req, String headerName, String target)245*1157Smax.romanov@nginx.com private static boolean headerContainsToken(HttpServletRequest req, 246*1157Smax.romanov@nginx.com String headerName, String target) { 247*1157Smax.romanov@nginx.com Enumeration<String> headers = req.getHeaders(headerName); 248*1157Smax.romanov@nginx.com while (headers.hasMoreElements()) { 249*1157Smax.romanov@nginx.com String header = headers.nextElement(); 250*1157Smax.romanov@nginx.com String[] tokens = header.split(","); 251*1157Smax.romanov@nginx.com for (String token : tokens) { 252*1157Smax.romanov@nginx.com if (target.equalsIgnoreCase(token.trim())) { 253*1157Smax.romanov@nginx.com return true; 254*1157Smax.romanov@nginx.com } 255*1157Smax.romanov@nginx.com } 256*1157Smax.romanov@nginx.com } 257*1157Smax.romanov@nginx.com return false; 258*1157Smax.romanov@nginx.com } 259*1157Smax.romanov@nginx.com 260*1157Smax.romanov@nginx.com 261*1157Smax.romanov@nginx.com /* 262*1157Smax.romanov@nginx.com * This only works for tokens. Quoted strings need more sophisticated 263*1157Smax.romanov@nginx.com * parsing. 
264*1157Smax.romanov@nginx.com */ getTokensFromHeader(HttpServletRequest req, String headerName)265*1157Smax.romanov@nginx.com private static List<String> getTokensFromHeader(HttpServletRequest req, 266*1157Smax.romanov@nginx.com String headerName) { 267*1157Smax.romanov@nginx.com List<String> result = new ArrayList<>(); 268*1157Smax.romanov@nginx.com Enumeration<String> headers = req.getHeaders(headerName); 269*1157Smax.romanov@nginx.com while (headers.hasMoreElements()) { 270*1157Smax.romanov@nginx.com String header = headers.nextElement(); 271*1157Smax.romanov@nginx.com String[] tokens = header.split(","); 272*1157Smax.romanov@nginx.com for (String token : tokens) { 273*1157Smax.romanov@nginx.com result.add(token.trim()); 274*1157Smax.romanov@nginx.com } 275*1157Smax.romanov@nginx.com } 276*1157Smax.romanov@nginx.com return result; 277*1157Smax.romanov@nginx.com } 278*1157Smax.romanov@nginx.com 279*1157Smax.romanov@nginx.com getWebSocketAccept(String key)280*1157Smax.romanov@nginx.com private static String getWebSocketAccept(String key) { 281*1157Smax.romanov@nginx.com byte[] digest = ConcurrentMessageDigest.digestSHA1( 282*1157Smax.romanov@nginx.com key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); 283*1157Smax.romanov@nginx.com return Base64.encodeBase64String(digest); 284*1157Smax.romanov@nginx.com } 285*1157Smax.romanov@nginx.com } 286