1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 package nginx.unit.websocket; 18 19 import java.io.IOException; 20 import java.nio.ByteBuffer; 21 import java.util.ArrayList; 22 import java.util.List; 23 import java.util.zip.DataFormatException; 24 import java.util.zip.Deflater; 25 import java.util.zip.Inflater; 26 27 import javax.websocket.Extension; 28 import javax.websocket.Extension.Parameter; 29 import javax.websocket.SendHandler; 30 31 import org.apache.tomcat.util.res.StringManager; 32 33 public class PerMessageDeflate implements Transformation { 34 35 private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class); 36 37 private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover"; 38 private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover"; 39 private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits"; 40 private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits"; 41 42 private static final int RSV_BITMASK = 0b100; 43 private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1}; 44 45 public static final String NAME = "permessage-deflate"; 46 47 private final boolean serverContextTakeover; 48 private final int serverMaxWindowBits; 49 private final boolean clientContextTakeover; 50 private final int clientMaxWindowBits; 51 private final boolean isServer; 52 private final Inflater inflater = new Inflater(true); 53 private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 54 private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); 55 private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1]; 56 57 private volatile Transformation next; 58 private volatile boolean skipDecompression = false; 59 private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 60 private volatile boolean firstCompressedFrameWritten = false; 61 // Flag to track if a message is completely empty 62 private volatile boolean emptyMessage = true; 63 negotiate(List<List<Parameter>> preferences, boolean isServer)64 static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) { 65 // Accept the first preference that the endpoint is able to support 66 for (List<Parameter> preference : preferences) { 67 boolean ok = true; 68 boolean serverContextTakeover = true; 69 int serverMaxWindowBits = -1; 70 boolean clientContextTakeover = true; 71 int clientMaxWindowBits = -1; 72 73 for (Parameter param : preference) { 74 if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) { 75 if (serverContextTakeover) { 76 serverContextTakeover = false; 77 } else { 78 // Duplicate definition 79 throw new IllegalArgumentException(sm.getString( 80 "perMessageDeflate.duplicateParameter", 81 SERVER_NO_CONTEXT_TAKEOVER )); 82 } 83 } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) { 84 if (clientContextTakeover) { 85 clientContextTakeover = false; 86 } else { 87 // Duplicate definition 88 throw new IllegalArgumentException(sm.getString( 89 "perMessageDeflate.duplicateParameter", 90 CLIENT_NO_CONTEXT_TAKEOVER )); 91 } 92 } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) { 93 if (serverMaxWindowBits == -1) { 94 serverMaxWindowBits = Integer.parseInt(param.getValue()); 95 if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) { 96 throw new IllegalArgumentException(sm.getString( 97 "perMessageDeflate.invalidWindowSize", 98 SERVER_MAX_WINDOW_BITS, 99 Integer.valueOf(serverMaxWindowBits))); 100 } 101 // Java SE API (as of Java 8) does not expose the API to 102 // control the Window size. It is effectively hard-coded 103 // to 15 104 if (isServer && serverMaxWindowBits != 15) { 105 ok = false; 106 break; 107 // Note server window size is not an issue for the 108 // client since the client will assume 15 and if the 109 // server uses a smaller window everything will 110 // still work 111 } 112 } else { 113 // Duplicate definition 114 throw new IllegalArgumentException(sm.getString( 115 "perMessageDeflate.duplicateParameter", 116 SERVER_MAX_WINDOW_BITS )); 117 } 118 } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) { 119 if (clientMaxWindowBits == -1) { 120 if (param.getValue() == null) { 121 // Hint to server that the client supports this 122 // option. Java SE API (as of Java 8) does not 123 // expose the API to control the Window size. It is 124 // effectively hard-coded to 15 125 clientMaxWindowBits = 15; 126 } else { 127 clientMaxWindowBits = Integer.parseInt(param.getValue()); 128 if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) { 129 throw new IllegalArgumentException(sm.getString( 130 "perMessageDeflate.invalidWindowSize", 131 CLIENT_MAX_WINDOW_BITS, 132 Integer.valueOf(clientMaxWindowBits))); 133 } 134 } 135 // Java SE API (as of Java 8) does not expose the API to 136 // control the Window size. It is effectively hard-coded 137 // to 15 138 if (!isServer && clientMaxWindowBits != 15) { 139 ok = false; 140 break; 141 // Note client window size is not an issue for the 142 // server since the server will assume 15 and if the 143 // client uses a smaller window everything will 144 // still work 145 } 146 } else { 147 // Duplicate definition 148 throw new IllegalArgumentException(sm.getString( 149 "perMessageDeflate.duplicateParameter", 150 CLIENT_MAX_WINDOW_BITS )); 151 } 152 } else { 153 // Unknown parameter 154 throw new IllegalArgumentException(sm.getString( 155 "perMessageDeflate.unknownParameter", param.getName())); 156 } 157 } 158 if (ok) { 159 return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits, 160 clientContextTakeover, clientMaxWindowBits, isServer); 161 } 162 } 163 // Failed to negotiate agreeable terms 164 return null; 165 } 166 167 PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer)168 private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, 169 boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) { 170 this.serverContextTakeover = serverContextTakeover; 171 this.serverMaxWindowBits = serverMaxWindowBits; 172 this.clientContextTakeover = clientContextTakeover; 173 this.clientMaxWindowBits = clientMaxWindowBits; 174 this.isServer = isServer; 175 } 176 177 178 @Override getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)179 public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) 180 throws IOException { 181 // Control frames are never compressed and may appear in the middle of 182 // a WebSocket method. Pass them straight through. 183 if (Util.isControl(opCode)) { 184 return next.getMoreData(opCode, fin, rsv, dest); 185 } 186 187 if (!Util.isContinuation(opCode)) { 188 // First frame in new message 189 skipDecompression = (rsv & RSV_BITMASK) == 0; 190 } 191 192 // Pass uncompressed frames straight through. 193 if (skipDecompression) { 194 return next.getMoreData(opCode, fin, rsv, dest); 195 } 196 197 int written; 198 boolean usedEomBytes = false; 199 200 while (dest.remaining() > 0) { 201 // Space available in destination. Try and fill it. 202 try { 203 written = inflater.inflate( 204 dest.array(), dest.arrayOffset() + dest.position(), dest.remaining()); 205 } catch (DataFormatException e) { 206 throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e); 207 } 208 dest.position(dest.position() + written); 209 210 if (inflater.needsInput() && !usedEomBytes ) { 211 if (dest.hasRemaining()) { 212 readBuffer.clear(); 213 TransformationResult nextResult = 214 next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer); 215 inflater.setInput( 216 readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position()); 217 if (TransformationResult.UNDERFLOW.equals(nextResult)) { 218 return nextResult; 219 } else if (TransformationResult.END_OF_FRAME.equals(nextResult) && 220 readBuffer.position() == 0) { 221 if (fin) { 222 inflater.setInput(EOM_BYTES); 223 usedEomBytes = true; 224 } else { 225 return TransformationResult.END_OF_FRAME; 226 } 227 } 228 } 229 } else if (written == 0) { 230 if (fin && (isServer && !clientContextTakeover || 231 !isServer && !serverContextTakeover)) { 232 inflater.reset(); 233 } 234 return TransformationResult.END_OF_FRAME; 235 } 236 } 237 238 return TransformationResult.OVERFLOW; 239 } 240 241 242 @Override validateRsv(int rsv, byte opCode)243 public boolean validateRsv(int rsv, byte opCode) { 244 if (Util.isControl(opCode)) { 245 if ((rsv & RSV_BITMASK) != 0) { 246 return false; 247 } else { 248 if (next == null) { 249 return true; 250 } else { 251 return next.validateRsv(rsv, opCode); 252 } 253 } 254 } else { 255 int rsvNext = rsv; 256 if ((rsv & RSV_BITMASK) != 0) { 257 rsvNext = rsv ^ RSV_BITMASK; 258 } 259 if (next == null) { 260 return true; 261 } else { 262 return next.validateRsv(rsvNext, opCode); 263 } 264 } 265 } 266 267 268 @Override getExtensionResponse()269 public Extension getExtensionResponse() { 270 Extension result = new WsExtension(NAME); 271 272 List<Extension.Parameter> params = result.getParameters(); 273 274 if (!serverContextTakeover) { 275 params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null)); 276 } 277 if (serverMaxWindowBits != -1) { 278 params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS, 279 Integer.toString(serverMaxWindowBits))); 280 } 281 if (!clientContextTakeover) { 282 params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null)); 283 } 284 if (clientMaxWindowBits != -1) { 285 params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS, 286 Integer.toString(clientMaxWindowBits))); 287 } 288 289 return result; 290 } 291 292 293 @Override setNext(Transformation t)294 public void setNext(Transformation t) { 295 if (next == null) { 296 this.next = t; 297 } else { 298 next.setNext(t); 299 } 300 } 301 302 303 @Override validateRsvBits(int i)304 public boolean validateRsvBits(int i) { 305 if ((i & RSV_BITMASK) != 0) { 306 return false; 307 } 308 if (next == null) { 309 return true; 310 } else { 311 return next.validateRsvBits(i | RSV_BITMASK); 312 } 313 } 314 315 316 @Override sendMessagePart(List<MessagePart> uncompressedParts)317 public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) { 318 List<MessagePart> allCompressedParts = new ArrayList<>(); 319 320 for (MessagePart uncompressedPart : uncompressedParts) { 321 byte opCode = uncompressedPart.getOpCode(); 322 boolean emptyPart = uncompressedPart.getPayload().limit() == 0; 323 emptyMessage = emptyMessage && emptyPart; 324 if (Util.isControl(opCode)) { 325 // Control messages can appear in the middle of other messages 326 // and must not be compressed. Pass it straight through 327 allCompressedParts.add(uncompressedPart); 328 } else if (emptyMessage && uncompressedPart.isFin()) { 329 // Zero length messages can't be compressed so pass the 330 // final (empty) part straight through. 331 allCompressedParts.add(uncompressedPart); 332 } else { 333 List<MessagePart> compressedParts = new ArrayList<>(); 334 ByteBuffer uncompressedPayload = uncompressedPart.getPayload(); 335 SendHandler uncompressedIntermediateHandler = 336 uncompressedPart.getIntermediateHandler(); 337 338 deflater.setInput(uncompressedPayload.array(), 339 uncompressedPayload.arrayOffset() + uncompressedPayload.position(), 340 uncompressedPayload.remaining()); 341 342 int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH); 343 boolean deflateRequired = true; 344 345 while (deflateRequired) { 346 ByteBuffer compressedPayload = writeBuffer; 347 348 int written = deflater.deflate(compressedPayload.array(), 349 compressedPayload.arrayOffset() + compressedPayload.position(), 350 compressedPayload.remaining(), flush); 351 compressedPayload.position(compressedPayload.position() + written); 352 353 if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) { 354 // This message part has been fully processed by the 355 // deflater. Fire the send handler for this message part 356 // and move on to the next message part. 357 break; 358 } 359 360 // If this point is reached, a new compressed message part 361 // will be created... 362 MessagePart compressedPart; 363 364 // .. and a new writeBuffer will be required. 365 writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); 366 367 // Flip the compressed payload ready for writing 368 compressedPayload.flip(); 369 370 boolean fin = uncompressedPart.isFin(); 371 boolean full = compressedPayload.limit() == compressedPayload.capacity(); 372 boolean needsInput = deflater.needsInput(); 373 long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry(); 374 375 if (fin && !full && needsInput) { 376 // End of compressed message. Drop EOM bytes and output. 377 compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length); 378 compressedPart = new MessagePart(true, getRsv(uncompressedPart), 379 opCode, compressedPayload, uncompressedIntermediateHandler, 380 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); 381 deflateRequired = false; 382 startNewMessage(); 383 } else if (full && !needsInput) { 384 // Write buffer full and input message not fully read. 385 // Output and start new compressed part. 386 compressedPart = new MessagePart(false, getRsv(uncompressedPart), 387 opCode, compressedPayload, uncompressedIntermediateHandler, 388 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); 389 } else if (!fin && full && needsInput) { 390 // Write buffer full and input message not fully read. 391 // Output and get more data. 392 compressedPart = new MessagePart(false, getRsv(uncompressedPart), 393 opCode, compressedPayload, uncompressedIntermediateHandler, 394 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); 395 deflateRequired = false; 396 } else if (fin && full && needsInput) { 397 // Write buffer full. Input fully read. Deflater may be 398 // in one of four states: 399 // - output complete (just happened to align with end of 400 // buffer 401 // - in middle of EOM bytes 402 // - about to write EOM bytes 403 // - more data to write 404 int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH); 405 if (eomBufferWritten < EOM_BUFFER.length) { 406 // EOM has just been completed 407 compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten); 408 compressedPart = new MessagePart(true, 409 getRsv(uncompressedPart), opCode, compressedPayload, 410 uncompressedIntermediateHandler, uncompressedIntermediateHandler, 411 blockingWriteTimeoutExpiry); 412 deflateRequired = false; 413 startNewMessage(); 414 } else { 415 // More data to write 416 // Copy bytes to new write buffer 417 writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten); 418 compressedPart = new MessagePart(false, 419 getRsv(uncompressedPart), opCode, compressedPayload, 420 uncompressedIntermediateHandler, uncompressedIntermediateHandler, 421 blockingWriteTimeoutExpiry); 422 } 423 } else { 424 throw new IllegalStateException("Should never happen"); 425 } 426 427 // Add the newly created compressed part to the set of parts 428 // to pass on to the next transformation. 429 compressedParts.add(compressedPart); 430 } 431 432 SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler(); 433 int size = compressedParts.size(); 434 if (size > 0) { 435 compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler); 436 } 437 438 allCompressedParts.addAll(compressedParts); 439 } 440 } 441 442 if (next == null) { 443 return allCompressedParts; 444 } else { 445 return next.sendMessagePart(allCompressedParts); 446 } 447 } 448 449 startNewMessage()450 private void startNewMessage() { 451 firstCompressedFrameWritten = false; 452 emptyMessage = true; 453 if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) { 454 deflater.reset(); 455 } 456 } 457 458 getRsv(MessagePart uncompressedMessagePart)459 private int getRsv(MessagePart uncompressedMessagePart) { 460 int result = uncompressedMessagePart.getRsv(); 461 if (!firstCompressedFrameWritten) { 462 result += RSV_BITMASK; 463 firstCompressedFrameWritten = true; 464 } 465 return result; 466 } 467 468 469 @Override close()470 public void close() { 471 // There will always be a next transformation 472 next.close(); 473 inflater.end(); 474 deflater.end(); 475 } 476 } 477