xref: /unit/src/java/nginx/unit/websocket/PerMessageDeflate.java (revision 1157:7ae152bda303)
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 package nginx.unit.websocket;
18 
19 import java.io.IOException;
20 import java.nio.ByteBuffer;
21 import java.util.ArrayList;
22 import java.util.List;
23 import java.util.zip.DataFormatException;
24 import java.util.zip.Deflater;
25 import java.util.zip.Inflater;
26 
27 import javax.websocket.Extension;
28 import javax.websocket.Extension.Parameter;
29 import javax.websocket.SendHandler;
30 
31 import org.apache.tomcat.util.res.StringManager;
32 
33 public class PerMessageDeflate implements Transformation {
34 
35     private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);
36 
37     private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
38     private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
39     private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
40     private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
41 
42     private static final int RSV_BITMASK = 0b100;
43     private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1};
44 
45     public static final String NAME = "permessage-deflate";
46 
47     private final boolean serverContextTakeover;
48     private final int serverMaxWindowBits;
49     private final boolean clientContextTakeover;
50     private final int clientMaxWindowBits;
51     private final boolean isServer;
52     private final Inflater inflater = new Inflater(true);
53     private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
54     private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
55     private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];
56 
57     private volatile Transformation next;
58     private volatile boolean skipDecompression = false;
59     private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
60     private volatile boolean firstCompressedFrameWritten = false;
61     // Flag to track if a message is completely empty
62     private volatile boolean emptyMessage = true;
63 
negotiate(List<List<Parameter>> preferences, boolean isServer)64     static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
65         // Accept the first preference that the endpoint is able to support
66         for (List<Parameter> preference : preferences) {
67             boolean ok = true;
68             boolean serverContextTakeover = true;
69             int serverMaxWindowBits = -1;
70             boolean clientContextTakeover = true;
71             int clientMaxWindowBits = -1;
72 
73             for (Parameter param : preference) {
74                 if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
75                     if (serverContextTakeover) {
76                         serverContextTakeover = false;
77                     } else {
78                         // Duplicate definition
79                         throw new IllegalArgumentException(sm.getString(
80                                 "perMessageDeflate.duplicateParameter",
81                                 SERVER_NO_CONTEXT_TAKEOVER ));
82                     }
83                 } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
84                     if (clientContextTakeover) {
85                         clientContextTakeover = false;
86                     } else {
87                         // Duplicate definition
88                         throw new IllegalArgumentException(sm.getString(
89                                 "perMessageDeflate.duplicateParameter",
90                                 CLIENT_NO_CONTEXT_TAKEOVER ));
91                     }
92                 } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) {
93                     if (serverMaxWindowBits == -1) {
94                         serverMaxWindowBits = Integer.parseInt(param.getValue());
95                         if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) {
96                             throw new IllegalArgumentException(sm.getString(
97                                     "perMessageDeflate.invalidWindowSize",
98                                     SERVER_MAX_WINDOW_BITS,
99                                     Integer.valueOf(serverMaxWindowBits)));
100                         }
101                         // Java SE API (as of Java 8) does not expose the API to
102                         // control the Window size. It is effectively hard-coded
103                         // to 15
104                         if (isServer && serverMaxWindowBits != 15) {
105                             ok = false;
106                             break;
107                             // Note server window size is not an issue for the
108                             // client since the client will assume 15 and if the
109                             // server uses a smaller window everything will
110                             // still work
111                         }
112                     } else {
113                         // Duplicate definition
114                         throw new IllegalArgumentException(sm.getString(
115                                 "perMessageDeflate.duplicateParameter",
116                                 SERVER_MAX_WINDOW_BITS ));
117                     }
118                 } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) {
119                     if (clientMaxWindowBits == -1) {
120                         if (param.getValue() == null) {
121                             // Hint to server that the client supports this
122                             // option. Java SE API (as of Java 8) does not
123                             // expose the API to control the Window size. It is
124                             // effectively hard-coded to 15
125                             clientMaxWindowBits = 15;
126                         } else {
127                             clientMaxWindowBits = Integer.parseInt(param.getValue());
128                             if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) {
129                                 throw new IllegalArgumentException(sm.getString(
130                                         "perMessageDeflate.invalidWindowSize",
131                                         CLIENT_MAX_WINDOW_BITS,
132                                         Integer.valueOf(clientMaxWindowBits)));
133                             }
134                         }
135                         // Java SE API (as of Java 8) does not expose the API to
136                         // control the Window size. It is effectively hard-coded
137                         // to 15
138                         if (!isServer && clientMaxWindowBits != 15) {
139                             ok = false;
140                             break;
141                             // Note client window size is not an issue for the
142                             // server since the server will assume 15 and if the
143                             // client uses a smaller window everything will
144                             // still work
145                         }
146                     } else {
147                         // Duplicate definition
148                         throw new IllegalArgumentException(sm.getString(
149                                 "perMessageDeflate.duplicateParameter",
150                                 CLIENT_MAX_WINDOW_BITS ));
151                     }
152                 } else {
153                     // Unknown parameter
154                     throw new IllegalArgumentException(sm.getString(
155                             "perMessageDeflate.unknownParameter", param.getName()));
156                 }
157             }
158             if (ok) {
159                 return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits,
160                         clientContextTakeover, clientMaxWindowBits, isServer);
161             }
162         }
163         // Failed to negotiate agreeable terms
164         return null;
165     }
166 
167 
PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer)168     private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits,
169             boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) {
170         this.serverContextTakeover = serverContextTakeover;
171         this.serverMaxWindowBits = serverMaxWindowBits;
172         this.clientContextTakeover = clientContextTakeover;
173         this.clientMaxWindowBits = clientMaxWindowBits;
174         this.isServer = isServer;
175     }
176 
177 
178     @Override
getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)179     public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
180             throws IOException {
181         // Control frames are never compressed and may appear in the middle of
182         // a WebSocket method. Pass them straight through.
183         if (Util.isControl(opCode)) {
184             return next.getMoreData(opCode, fin, rsv, dest);
185         }
186 
187         if (!Util.isContinuation(opCode)) {
188             // First frame in new message
189             skipDecompression = (rsv & RSV_BITMASK) == 0;
190         }
191 
192         // Pass uncompressed frames straight through.
193         if (skipDecompression) {
194             return next.getMoreData(opCode, fin, rsv, dest);
195         }
196 
197         int written;
198         boolean usedEomBytes = false;
199 
200         while (dest.remaining() > 0) {
201             // Space available in destination. Try and fill it.
202             try {
203                 written = inflater.inflate(
204                         dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
205             } catch (DataFormatException e) {
206                 throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
207             }
208             dest.position(dest.position() + written);
209 
210             if (inflater.needsInput() && !usedEomBytes ) {
211                 if (dest.hasRemaining()) {
212                     readBuffer.clear();
213                     TransformationResult nextResult =
214                             next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
215                     inflater.setInput(
216                             readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
217                     if (TransformationResult.UNDERFLOW.equals(nextResult)) {
218                         return nextResult;
219                     } else if (TransformationResult.END_OF_FRAME.equals(nextResult) &&
220                             readBuffer.position() == 0) {
221                         if (fin) {
222                             inflater.setInput(EOM_BYTES);
223                             usedEomBytes = true;
224                         } else {
225                             return TransformationResult.END_OF_FRAME;
226                         }
227                     }
228                 }
229             } else if (written == 0) {
230                 if (fin && (isServer && !clientContextTakeover ||
231                         !isServer && !serverContextTakeover)) {
232                     inflater.reset();
233                 }
234                 return TransformationResult.END_OF_FRAME;
235             }
236         }
237 
238         return TransformationResult.OVERFLOW;
239     }
240 
241 
242     @Override
validateRsv(int rsv, byte opCode)243     public boolean validateRsv(int rsv, byte opCode) {
244         if (Util.isControl(opCode)) {
245             if ((rsv & RSV_BITMASK) != 0) {
246                 return false;
247             } else {
248                 if (next == null) {
249                     return true;
250                 } else {
251                     return next.validateRsv(rsv, opCode);
252                 }
253             }
254         } else {
255             int rsvNext = rsv;
256             if ((rsv & RSV_BITMASK) != 0) {
257                 rsvNext = rsv ^ RSV_BITMASK;
258             }
259             if (next == null) {
260                 return true;
261             } else {
262                 return next.validateRsv(rsvNext, opCode);
263             }
264         }
265     }
266 
267 
268     @Override
getExtensionResponse()269     public Extension getExtensionResponse() {
270         Extension result = new WsExtension(NAME);
271 
272         List<Extension.Parameter> params = result.getParameters();
273 
274         if (!serverContextTakeover) {
275             params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
276         }
277         if (serverMaxWindowBits != -1) {
278             params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS,
279                     Integer.toString(serverMaxWindowBits)));
280         }
281         if (!clientContextTakeover) {
282             params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
283         }
284         if (clientMaxWindowBits != -1) {
285             params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS,
286                     Integer.toString(clientMaxWindowBits)));
287         }
288 
289         return result;
290     }
291 
292 
293     @Override
setNext(Transformation t)294     public void setNext(Transformation t) {
295         if (next == null) {
296             this.next = t;
297         } else {
298             next.setNext(t);
299         }
300     }
301 
302 
303     @Override
validateRsvBits(int i)304     public boolean validateRsvBits(int i) {
305         if ((i & RSV_BITMASK) != 0) {
306             return false;
307         }
308         if (next == null) {
309             return true;
310         } else {
311             return next.validateRsvBits(i | RSV_BITMASK);
312         }
313     }
314 
315 
316     @Override
sendMessagePart(List<MessagePart> uncompressedParts)317     public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) {
318         List<MessagePart> allCompressedParts = new ArrayList<>();
319 
320         for (MessagePart uncompressedPart : uncompressedParts) {
321             byte opCode = uncompressedPart.getOpCode();
322             boolean emptyPart = uncompressedPart.getPayload().limit() == 0;
323             emptyMessage = emptyMessage && emptyPart;
324             if (Util.isControl(opCode)) {
325                 // Control messages can appear in the middle of other messages
326                 // and must not be compressed. Pass it straight through
327                 allCompressedParts.add(uncompressedPart);
328             } else if (emptyMessage && uncompressedPart.isFin()) {
329                 // Zero length messages can't be compressed so pass the
330                 // final (empty) part straight through.
331                 allCompressedParts.add(uncompressedPart);
332             } else {
333                 List<MessagePart> compressedParts = new ArrayList<>();
334                 ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
335                 SendHandler uncompressedIntermediateHandler =
336                         uncompressedPart.getIntermediateHandler();
337 
338                 deflater.setInput(uncompressedPayload.array(),
339                         uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
340                         uncompressedPayload.remaining());
341 
342                 int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
343                 boolean deflateRequired = true;
344 
345                 while (deflateRequired) {
346                     ByteBuffer compressedPayload = writeBuffer;
347 
348                     int written = deflater.deflate(compressedPayload.array(),
349                             compressedPayload.arrayOffset() + compressedPayload.position(),
350                             compressedPayload.remaining(), flush);
351                     compressedPayload.position(compressedPayload.position() + written);
352 
353                     if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
354                         // This message part has been fully processed by the
355                         // deflater. Fire the send handler for this message part
356                         // and move on to the next message part.
357                         break;
358                     }
359 
360                     // If this point is reached, a new compressed message part
361                     // will be created...
362                     MessagePart compressedPart;
363 
364                     // .. and a new writeBuffer will be required.
365                     writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
366 
367                     // Flip the compressed payload ready for writing
368                     compressedPayload.flip();
369 
370                     boolean fin = uncompressedPart.isFin();
371                     boolean full = compressedPayload.limit() == compressedPayload.capacity();
372                     boolean needsInput = deflater.needsInput();
373                     long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();
374 
375                     if (fin && !full && needsInput) {
376                         // End of compressed message. Drop EOM bytes and output.
377                         compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
378                         compressedPart = new MessagePart(true, getRsv(uncompressedPart),
379                                 opCode, compressedPayload, uncompressedIntermediateHandler,
380                                 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
381                         deflateRequired = false;
382                         startNewMessage();
383                     } else if (full && !needsInput) {
384                         // Write buffer full and input message not fully read.
385                         // Output and start new compressed part.
386                         compressedPart = new MessagePart(false, getRsv(uncompressedPart),
387                                 opCode, compressedPayload, uncompressedIntermediateHandler,
388                                 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
389                     } else if (!fin && full && needsInput) {
390                         // Write buffer full and input message not fully read.
391                         // Output and get more data.
392                         compressedPart = new MessagePart(false, getRsv(uncompressedPart),
393                                 opCode, compressedPayload, uncompressedIntermediateHandler,
394                                 uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
395                         deflateRequired = false;
396                     } else if (fin && full && needsInput) {
397                         // Write buffer full. Input fully read. Deflater may be
398                         // in one of four states:
399                         // - output complete (just happened to align with end of
400                         //   buffer
401                         // - in middle of EOM bytes
402                         // - about to write EOM bytes
403                         // - more data to write
404                         int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
405                         if (eomBufferWritten < EOM_BUFFER.length) {
406                             // EOM has just been completed
407                             compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
408                             compressedPart = new MessagePart(true,
409                                     getRsv(uncompressedPart), opCode, compressedPayload,
410                                     uncompressedIntermediateHandler, uncompressedIntermediateHandler,
411                                     blockingWriteTimeoutExpiry);
412                             deflateRequired = false;
413                             startNewMessage();
414                         } else {
415                             // More data to write
416                             // Copy bytes to new write buffer
417                             writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
418                             compressedPart = new MessagePart(false,
419                                     getRsv(uncompressedPart), opCode, compressedPayload,
420                                     uncompressedIntermediateHandler, uncompressedIntermediateHandler,
421                                     blockingWriteTimeoutExpiry);
422                         }
423                     } else {
424                         throw new IllegalStateException("Should never happen");
425                     }
426 
427                     // Add the newly created compressed part to the set of parts
428                     // to pass on to the next transformation.
429                     compressedParts.add(compressedPart);
430                 }
431 
432                 SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
433                 int size = compressedParts.size();
434                 if (size > 0) {
435                     compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
436                 }
437 
438                 allCompressedParts.addAll(compressedParts);
439             }
440         }
441 
442         if (next == null) {
443             return allCompressedParts;
444         } else {
445             return next.sendMessagePart(allCompressedParts);
446         }
447     }
448 
449 
startNewMessage()450     private void startNewMessage() {
451         firstCompressedFrameWritten = false;
452         emptyMessage = true;
453         if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) {
454             deflater.reset();
455         }
456     }
457 
458 
getRsv(MessagePart uncompressedMessagePart)459     private int getRsv(MessagePart uncompressedMessagePart) {
460         int result = uncompressedMessagePart.getRsv();
461         if (!firstCompressedFrameWritten) {
462             result += RSV_BITMASK;
463             firstCompressedFrameWritten = true;
464         }
465         return result;
466     }
467 
468 
469     @Override
close()470     public void close() {
471         // There will always be a next transformation
472         next.close();
473         inflater.end();
474         deflater.end();
475     }
476 }
477