Skip to content

Commit 77d7d17

Browse files
committed
Enforce max message size during permessage-deflate inflation on the server path to prevent decompression-bomb DoS
1 parent 3d9285a commit 77d7d17

5 files changed

Lines changed: 88 additions & 4 deletions

File tree

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/PerMessageDeflateExtension.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import java.util.zip.Deflater;
3434
import java.util.zip.Inflater;
3535

36+
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
37+
3638
public final class PerMessageDeflateExtension implements WebSocketExtension {
3739

3840
private static final byte[] TAIL = new byte[]{0x00, 0x00, (byte) 0xFF, (byte) 0xFF};
@@ -76,6 +78,14 @@ public boolean usesRsv1() {
7678

7779
@Override
7880
public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final ByteBuffer payload) throws WebSocketException {
81+
return decode(type, fin, payload, 0L);
82+
}
83+
84+
@Override
85+
public ByteBuffer decode(final WebSocketFrameType type,
86+
final boolean fin,
87+
final ByteBuffer payload,
88+
final long maxOutputSize) throws WebSocketException {
7989
if (!isDataFrame(type) && type != WebSocketFrameType.CONTINUATION) {
8090
throw new WebSocketException("Unsupported frame type for permessage-deflate: " + type);
8191
}
@@ -94,14 +104,23 @@ public ByteBuffer decode(final WebSocketFrameType type, final boolean fin, final
94104
inflater.setInput(withTail);
95105
final ByteArrayOutputStream out = new ByteArrayOutputStream(Math.max(128, input.length));
96106
final byte[] buffer = new byte[Math.min(16384, Math.max(1024, input.length * 2))];
107+
long produced = 0L;
97108
try {
98109
while (!inflater.needsInput()) {
99110
final int count = inflater.inflate(buffer);
100111
if (count == 0 && inflater.needsInput()) {
101112
break;
102113
}
114+
// Enforce the decoded size cap during inflation, not after, so a small
115+
// compressed payload cannot expand into a huge buffer before we react.
116+
if (maxOutputSize > 0L && produced + count > maxOutputSize) {
117+
throw new WebSocketProtocolException(1009, "Message too big");
118+
}
103119
out.write(buffer, 0, count);
120+
produced += count;
104121
}
122+
} catch (final WebSocketProtocolException wspe) {
123+
throw wspe;
105124
} catch (final Exception ex) {
106125
throw new WebSocketException("Unable to inflate payload", ex);
107126
}

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketExtension.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,22 @@ default ByteBuffer decode(
5151
return payload;
5252
}
5353

54+
/**
55+
* Decode a frame payload, aborting as soon as the produced output exceeds
56+
* {@code maxOutputSize}. A non-positive limit means no limit. Implementations
57+
* that may expand input (e.g. permessage-deflate) MUST honour the limit during
58+
* the expansion step, not only after it, to prevent decompression-bomb attacks.
59+
*
60+
* @since 5.7
61+
*/
62+
default ByteBuffer decode(
63+
final WebSocketFrameType type,
64+
final boolean fin,
65+
final ByteBuffer payload,
66+
final long maxOutputSize) throws WebSocketException {
67+
return decode(type, fin, payload);
68+
}
69+
5470
default ByteBuffer encode(
5571
final WebSocketFrameType type,
5672
final boolean fin,

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/WebSocketFrameReader.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,22 +112,23 @@ WebSocketFrame readFrame() throws IOException {
112112
payload[i] = (byte) (payload[i] ^ maskKey[i % 4]);
113113
}
114114
ByteBuffer data = ByteBuffer.wrap(payload);
115+
final long maxOutputSize = config.getMaxMessageSize();
115116
if (rsv1 && rsv1Extension != null) {
116-
data = rsv1Extension.decode(type, fin, data);
117+
data = rsv1Extension.decode(type, fin, data, maxOutputSize);
117118
continuationCompressed = !fin && (type == WebSocketFrameType.TEXT || type == WebSocketFrameType.BINARY);
118119
} else if (type == WebSocketFrameType.CONTINUATION && continuationCompressed && rsv1Extension != null) {
119-
data = rsv1Extension.decode(type, fin, data);
120+
data = rsv1Extension.decode(type, fin, data, maxOutputSize);
120121
if (fin) {
121122
continuationCompressed = false;
122123
}
123124
} else if (type == WebSocketFrameType.CONTINUATION && fin) {
124125
continuationCompressed = false;
125126
}
126127
if (rsv2 && rsv2Extension != null) {
127-
data = rsv2Extension.decode(type, fin, data);
128+
data = rsv2Extension.decode(type, fin, data, maxOutputSize);
128129
}
129130
if (rsv3 && rsv3Extension != null) {
130-
data = rsv3Extension.decode(type, fin, data);
131+
data = rsv3Extension.decode(type, fin, data, maxOutputSize);
131132
}
132133
return new WebSocketFrame(fin, false, false, false, type, data);
133134
}

httpclient5-websocket/src/main/java/org/apache/hc/core5/websocket/server/WebSocketH2ServerExchangeHandler.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.apache.hc.core5.websocket.WebSocketHandler;
6262
import org.apache.hc.core5.websocket.WebSocketHandshake;
6363
import org.apache.hc.core5.websocket.WebSocketSession;
64+
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
6465

6566
final class WebSocketH2ServerExchangeHandler implements AsyncServerExchangeHandler {
6667

@@ -160,6 +161,13 @@ public void handleRequest(
160161
try {
161162
handler.onOpen(session);
162163
new WebSocketServerProcessor(session, handler, config.getMaxMessageSize()).process();
164+
} catch (final WebSocketProtocolException ex) {
165+
handler.onError(session, ex);
166+
try {
167+
session.close(ex.closeCode, ex.getMessage());
168+
} catch (final IOException ignore) {
169+
// ignore
170+
}
163171
} catch (final Exception ex) {
164172
handler.onError(session, ex);
165173
try {

httpclient5-websocket/src/test/java/org/apache/hc/core5/websocket/PerMessageDeflateExtensionTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@
2626
*/
2727
package org.apache.hc.core5.websocket;
2828

29+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
2930
import static org.junit.jupiter.api.Assertions.assertEquals;
31+
import static org.junit.jupiter.api.Assertions.assertThrows;
3032

3133
import java.io.ByteArrayOutputStream;
3234
import java.nio.ByteBuffer;
3335
import java.nio.charset.StandardCharsets;
36+
import java.util.Arrays;
3437
import java.util.zip.Deflater;
3538

39+
import org.apache.hc.core5.websocket.exceptions.WebSocketProtocolException;
3640
import org.junit.jupiter.api.Test;
3741

3842
class PerMessageDeflateExtensionTest {
@@ -56,6 +60,42 @@ void decodesFragmentedMessage() throws Exception {
5660
assertEquals("fragmented message", WebSocketSession.decodeText(ByteBuffer.wrap(joined.toByteArray())));
5761
}
5862

63+
@Test
64+
void decodeWithinLimitSucceeds() throws Exception {
65+
final byte[] plain = "hello world hello world hello world".getBytes(StandardCharsets.UTF_8);
66+
final byte[] compressed = deflateWithSyncFlush(plain);
67+
68+
final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
69+
final ByteBuffer out = ext.decode(WebSocketFrameType.TEXT, true, ByteBuffer.wrap(compressed), plain.length + 16L);
70+
71+
assertArrayEquals(plain, toBytes(out));
72+
}
73+
74+
@Test
75+
void decodeInflationBombIsRejectedDuringInflate() {
76+
final byte[] plain = new byte[64 * 1024];
77+
Arrays.fill(plain, (byte) 'A');
78+
final byte[] compressed = deflateWithSyncFlush(plain);
79+
80+
final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
81+
final WebSocketProtocolException ex = assertThrows(WebSocketProtocolException.class,
82+
() -> ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 1024L));
83+
assertEquals(1009, ex.closeCode);
84+
assertEquals("Message too big", ex.getMessage());
85+
}
86+
87+
@Test
88+
void decodeZeroLimitMeansUnlimited() throws Exception {
89+
final byte[] plain = new byte[8 * 1024];
90+
Arrays.fill(plain, (byte) 'B');
91+
final byte[] compressed = deflateWithSyncFlush(plain);
92+
93+
final PerMessageDeflateExtension ext = new PerMessageDeflateExtension();
94+
final ByteBuffer out = ext.decode(WebSocketFrameType.BINARY, true, ByteBuffer.wrap(compressed), 0L);
95+
96+
assertArrayEquals(plain, toBytes(out));
97+
}
98+
5999
private static byte[] deflateWithSyncFlush(final byte[] input) {
60100
final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
61101
deflater.setInput(input);

0 commit comments

Comments
 (0)