1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package nginx.unit.websocket.server;
18 
19 import java.net.URI;
20 import java.net.URISyntaxException;
21 import java.security.Principal;
22 import java.util.Arrays;
23 import java.util.Collections;
24 import java.util.Enumeration;
25 import java.util.HashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Map.Entry;
29 
30 import javax.servlet.http.HttpServletRequest;
31 import javax.websocket.server.HandshakeRequest;
32 
33 import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
34 import org.apache.tomcat.util.res.StringManager;
35 
36 /**
37  * Represents the request that this session was opened under.
38  */
39 public class WsHandshakeRequest implements HandshakeRequest {
40 
41     private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class);
42 
43     private final URI requestUri;
44     private final Map<String,List<String>> parameterMap;
45     private final String queryString;
46     private final Principal userPrincipal;
47     private final Map<String,List<String>> headers;
48     private final Object httpSession;
49 
50     private volatile HttpServletRequest request;
51 
52 
WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams)53     public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) {
54 
55         this.request = request;
56 
57         queryString = request.getQueryString();
58         userPrincipal = request.getUserPrincipal();
59         httpSession = request.getSession(false);
60         requestUri = buildRequestUri(request);
61 
62         // ParameterMap
63         Map<String,String[]> originalParameters = request.getParameterMap();
64         Map<String,List<String>> newParameters =
65                 new HashMap<>(originalParameters.size());
66         for (Entry<String,String[]> entry : originalParameters.entrySet()) {
67             newParameters.put(entry.getKey(),
68                     Collections.unmodifiableList(
69                             Arrays.asList(entry.getValue())));
70         }
71         for (Entry<String,String> entry : pathParams.entrySet()) {
72             newParameters.put(entry.getKey(),
73                     Collections.unmodifiableList(
74                             Collections.singletonList(entry.getValue())));
75         }
76         parameterMap = Collections.unmodifiableMap(newParameters);
77 
78         // Headers
79         Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>();
80 
81         Enumeration<String> headerNames = request.getHeaderNames();
82         while (headerNames.hasMoreElements()) {
83             String headerName = headerNames.nextElement();
84 
85             newHeaders.put(headerName, Collections.unmodifiableList(
86                     Collections.list(request.getHeaders(headerName))));
87         }
88 
89         headers = Collections.unmodifiableMap(newHeaders);
90     }
91 
92     @Override
getRequestURI()93     public URI getRequestURI() {
94         return requestUri;
95     }
96 
97     @Override
getParameterMap()98     public Map<String,List<String>> getParameterMap() {
99         return parameterMap;
100     }
101 
102     @Override
getQueryString()103     public String getQueryString() {
104         return queryString;
105     }
106 
107     @Override
getUserPrincipal()108     public Principal getUserPrincipal() {
109         return userPrincipal;
110     }
111 
112     @Override
getHeaders()113     public Map<String,List<String>> getHeaders() {
114         return headers;
115     }
116 
117     @Override
isUserInRole(String role)118     public boolean isUserInRole(String role) {
119         if (request == null) {
120             throw new IllegalStateException();
121         }
122 
123         return request.isUserInRole(role);
124     }
125 
126     @Override
getHttpSession()127     public Object getHttpSession() {
128         return httpSession;
129     }
130 
131     /**
132      * Called when the HandshakeRequest is no longer required. Since an instance
133      * of this class retains a reference to the current HttpServletRequest that
134      * reference needs to be cleared as the HttpServletRequest may be reused.
135      *
136      * There is no reason for instances of this class to be accessed once the
137      * handshake has been completed.
138      */
finished()139     void finished() {
140         request = null;
141     }
142 
143 
144     /*
145      * See RequestUtil.getRequestURL()
146      */
buildRequestUri(HttpServletRequest req)147     private static URI buildRequestUri(HttpServletRequest req) {
148 
149         StringBuffer uri = new StringBuffer();
150         String scheme = req.getScheme();
151         int port = req.getServerPort();
152         if (port < 0) {
153             // Work around java.net.URL bug
154             port = 80;
155         }
156 
157         if ("http".equals(scheme)) {
158             uri.append("ws");
159         } else if ("https".equals(scheme)) {
160             uri.append("wss");
161         } else {
162             // Should never happen
163             throw new IllegalArgumentException(
164                     sm.getString("wsHandshakeRequest.unknownScheme", scheme));
165         }
166 
167         uri.append("://");
168         uri.append(req.getServerName());
169 
170         if ((scheme.equals("http") && (port != 80))
171             || (scheme.equals("https") && (port != 443))) {
172             uri.append(':');
173             uri.append(port);
174         }
175 
176         uri.append(req.getRequestURI());
177 
178         if (req.getQueryString() != null) {
179             uri.append("?");
180             uri.append(req.getQueryString());
181         }
182 
183         try {
184             return new URI(uri.toString());
185         } catch (URISyntaxException e) {
186             // Should never happen
187             throw new IllegalArgumentException(
188                     sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e);
189         }
190     }
191 
getAttribute(String name)192     public Object getAttribute(String name)
193     {
194         return request != null ? request.getAttribute(name) : null;
195     }
196 }
197