Skip to content

Commit b765238

Browse files
committed
[fix] return wait_writable on non-blocking reads
1 parent 5f8c201 commit b765238

2 files changed

Lines changed: 89 additions & 18 deletions

File tree

src/main/java/org/jruby/ext/openssl/SSLSocket.java

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static CallSite callSite(final CallSite[] sites, final CallSiteIndex ind
141141
return sites[ index.ordinal() ];
142142
}
143143

144-
private SSLContext sslContext;
144+
SSLContext sslContext;
145145
private SSLEngine engine;
146146
private RubyIO io;
147147

@@ -209,7 +209,7 @@ private IRubyObject fallback_set_io_nonblock_checked(ThreadContext context, Ruby
209209

210210
private static final String SESSION_SOCKET_ID = "socket_id";
211211

212-
private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) {
212+
SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) {
213213
SSLEngine engine = this.engine;
214214
if ( engine != null ) return engine;
215215

@@ -574,7 +574,7 @@ private IRubyObject doHandshake(final boolean blocking, final boolean exception)
574574
doTasks();
575575
break;
576576
case NEED_UNWRAP:
577-
if (readAndUnwrap(blocking) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
577+
if (readAndUnwrap(blocking, exception) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) {
578578
throw new SSLHandshakeException("Socket closed");
579579
}
580580
// during initialHandshake, calling readAndUnwrap that results UNDERFLOW does not mean writable.
@@ -721,10 +721,6 @@ private int read(final ByteBuffer dst, final boolean blocking, final boolean exc
721721
return limit;
722722
}
723723

724-
private int readAndUnwrap(final boolean blocking) throws IOException {
725-
return readAndUnwrap(blocking, true);
726-
}
727-
728724
private int readAndUnwrap(final boolean blocking, final boolean exception) throws IOException {
729725
final int bytesRead = socketChannelImpl().read(netReadData);
730726
if ( bytesRead == -1 ) {
@@ -813,10 +809,10 @@ private void doShutdown() throws IOException {
813809
}
814810

815811
/**
816-
* @return the (@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol}
812+
* @return the {@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol}
817813
*/
818-
private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject len, final IRubyObject buff,
819-
final boolean blocking, final boolean exception) {
814+
private IRubyObject sysreadImpl(final ThreadContext context,
815+
final IRubyObject len, final IRubyObject buff, final boolean blocking, final boolean exception) {
820816
final Ruby runtime = context.runtime;
821817

822818
final int length = RubyNumeric.fix2int(len);
@@ -836,11 +832,12 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l
836832
}
837833

838834
try {
839-
// flush any pending encrypted write data before reading; after write_nonblock,
840-
// encrypted bytes may remain in the buffer that haven't been sent, if we read wout flushing,
841-
// server may not have received the complete request (e.g. net/http POST body) and will not respond
835+
// Flush pending write data before reading (after write_nonblock encrypted bytes may still be buffered)
842836
if ( engine != null && netWriteData.hasRemaining() ) {
843-
flushData(blocking);
837+
if ( flushData(blocking) && ! blocking ) {
838+
if ( exception ) throw newSSLErrorWaitWritable(runtime, "write would block");
839+
return runtime.newSymbol("wait_writable");
840+
}
844841
}
845842

846843
// So we need to make sure to only block when there is no data left to process
@@ -851,7 +848,7 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l
851848

852849
final ByteBuffer dst = ByteBuffer.allocate(length);
853850
int read = -1;
854-
// ensure >0 bytes read; sysread is blocking read.
851+
// ensure > 0 bytes read; sysread is blocking read
855852
while ( read <= 0 ) {
856853
if ( engine == null ) {
857854
read = socketChannelImpl().read(dst);
@@ -1238,7 +1235,7 @@ public IRubyObject ssl_version(ThreadContext context) {
12381235
return context.runtime.newString( engine.getSession().getProtocol() );
12391236
}
12401237

1241-
private transient SocketChannelImpl socketChannel;
1238+
transient SocketChannelImpl socketChannel;
12421239

12431240
private SocketChannelImpl socketChannelImpl() {
12441241
if ( socketChannel != null ) return socketChannel;
@@ -1253,7 +1250,7 @@ private SocketChannelImpl socketChannelImpl() {
12531250
throw new IllegalStateException("unknow channel impl: " + channel + " of type " + channel.getClass().getName());
12541251
}
12551252

1256-
private interface SocketChannelImpl {
1253+
interface SocketChannelImpl {
12571254

12581255
boolean isOpen() ;
12591256

src/test/java/org/jruby/ext/openssl/SSLSocketTest.java

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
package org.jruby.ext.openssl;
22

3+
import java.io.IOException;
34
import java.nio.ByteBuffer;
5+
import java.nio.channels.SelectionKey;
6+
import java.nio.channels.Selector;
7+
import javax.net.ssl.SSLEngine;
48

9+
import org.jruby.Ruby;
510
import org.jruby.RubyArray;
611
import org.jruby.RubyFixnum;
12+
import org.jruby.RubyHash;
713
import org.jruby.RubyInteger;
814
import org.jruby.RubyString;
915
import org.jruby.exceptions.RaiseException;
@@ -45,7 +51,7 @@ public void tearDown() {
4551
* discarded during partial non-blocking writes, so the server would
4652
* receive fewer bytes than sent.
4753
*/
48-
@org.junit.jupiter.api.Test
54+
@Test
4955
public void syswriteNonblockDataIntegrity() throws Exception {
5056
final RubyArray pair = (RubyArray) runtime.evalScriptlet(start_ssl_server_rb());
5157
SSLSocket client = (SSLSocket) pair.entry(0).toJava(SSLSocket.class);
@@ -173,4 +179,72 @@ private void closeQuietly(final RubyArray sslPair) {
173179
}
174180
}
175181
}
182+
183+
// ----------
184+
185+
/**
186+
* MRI's ossl_ssl_read_internal returns :wait_writable (or raises SSLErrorWaitWritable / "write would block")
187+
* when SSL_read hits SSL_ERROR_WANT_WRITE. Pending netWriteData is JRuby's equivalent state.
188+
*/
189+
@Test
190+
public void sysreadNonblockReturnsWaitWritableWhenPendingEncryptedBytesRemain() {
191+
final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1));
192+
final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false);
193+
engine.setUseClientMode(true);
194+
195+
socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 });
196+
197+
final RubyHash opts = RubyHash.newKwargs(runtime, "exception", runtime.getFalse()); // exception: false
198+
final IRubyObject result = socket.sysread_nonblock(currentContext(), runtime.newFixnum(1), opts);
199+
200+
assertEquals("wait_writable", result.asJavaString());
201+
assertEquals(1, socket.netWriteData.remaining());
202+
}
203+
204+
@Test
205+
public void sysreadNonblockRaisesWaitWritableWhenPendingEncryptedBytesRemain() {
206+
final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1));
207+
final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false);
208+
engine.setUseClientMode(true);
209+
210+
socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 });
211+
212+
try {
213+
socket.sysread_nonblock(currentContext(), runtime.newFixnum(1));
214+
fail("expected SSLErrorWaitWritable");
215+
}
216+
catch (RaiseException ex) {
217+
assertEquals("OpenSSL::SSL::SSLErrorWaitWritable", ex.getException().getMetaClass().getName());
218+
assertTrue(ex.getMessage().contains("write would block"));
219+
assertEquals(1, socket.netWriteData.remaining());
220+
}
221+
}
222+
223+
private static SSLSocket newSSLSocket(final Ruby runtime, final SSLSocket.SocketChannelImpl socketChannel) {
224+
final SSLContext sslContext = new SSLContext(runtime);
225+
sslContext.doSetup(runtime.getCurrentContext());
226+
final SSLSocket sslSocket = new SSLSocket(runtime, runtime.getObject());
227+
sslSocket.sslContext = sslContext;
228+
sslSocket.socketChannel = socketChannel;
229+
return sslSocket;
230+
}
231+
232+
private static SSLSocket.SocketChannelImpl partialWriteChannel(final int bytesPerWrite) {
233+
return new SSLSocket.SocketChannelImpl() {
234+
public boolean isOpen() { return true; }
235+
public int read(final ByteBuffer dst) { return 0; }
236+
public int write(final ByteBuffer src) {
237+
final int written = Math.min(bytesPerWrite, src.remaining());
238+
src.position(src.position() + written);
239+
return written;
240+
}
241+
public int getRemotePort() { return 443; }
242+
public boolean isSelectable() { return false; }
243+
public boolean isBlocking() { return false; }
244+
public void configureBlocking(final boolean block) { }
245+
public SelectionKey register(final Selector selector, final int ops) throws IOException {
246+
throw new UnsupportedOperationException();
247+
}
248+
};
249+
}
176250
}

0 commit comments

Comments
 (0)