Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
git-svn-id: https://svn.apache.org/repos/asf/activemq/trunk@1399438 13f79535-47bb-0310-9956-ffa450edef68
  • Loading branch information
Timothy A. Bish committed Oct 17, 2012
1 parent b5e46ef commit 65af81e
Showing 1 changed file with 88 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NIOSSLTransport extends NIOTransport {
public class NIOSSLTransport extends NIOTransport {

private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);

protected boolean needClientAuth;
protected boolean wantClientAuth;
Expand Down Expand Up @@ -79,15 +83,36 @@ protected void initializeStreams() throws IOException {
sslContext = SSLContext.getDefault();
}

String remoteHost = null;
int remotePort = -1;

try {
URI remoteAddress = new URI(this.getRemoteAddress());
remoteHost = remoteAddress.getHost();
remotePort = remoteAddress.getPort();
} catch (Exception e) {
}

// initialize engine, the initial sslSession we get will need to be
// updated once the ssl handshake process is completed.
sslEngine = sslContext.createSSLEngine();
if (remoteHost != null && remotePort != -1) {
sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
} else {
sslEngine = sslContext.createSSLEngine();
}

sslEngine.setUseClientMode(false);
if (enabledCipherSuites != null) {
sslEngine.setEnabledCipherSuites(enabledCipherSuites);
}
sslEngine.setNeedClientAuth(needClientAuth);
sslEngine.setWantClientAuth(wantClientAuth);

if (wantClientAuth) {
sslEngine.setWantClientAuth(wantClientAuth);
}

if (needClientAuth) {
sslEngine.setNeedClientAuth(needClientAuth);
}

sslSession = sslEngine.getSession();

Expand All @@ -107,31 +132,31 @@ protected void initializeStreams() throws IOException {
}
}

protected void finishHandshake() throws Exception {
if (handshakeInProgress) {
handshakeInProgress = false;
nextFrameSize = -1;

// Once handshake completes we need to ask for the now real sslSession
// otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
// cipher suite.
sslSession = sslEngine.getSession();

// listen for events telling us when the socket is readable.
selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
public void onSelect(SelectorSelection selection) {
serviceRead();
}

public void onError(SelectorSelection selection, Throwable error) {
if (error instanceof IOException) {
onException((IOException) error);
} else {
onException(IOExceptionSupport.create(error));
}
}
});
}
protected void finishHandshake() throws Exception {
if (handshakeInProgress) {
handshakeInProgress = false;
nextFrameSize = -1;

// Once handshake completes we need to ask for the now real sslSession
// otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
// cipher suite.
sslSession = sslEngine.getSession();

// listen for events telling us when the socket is readable.
selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
public void onSelect(SelectorSelection selection) {
serviceRead();
}

public void onError(SelectorSelection selection, Throwable error) {
if (error instanceof IOException) {
onException((IOException) error);
} else {
onException(IOExceptionSupport.create(error));
}
}
});
}
}

protected void serviceRead() {
Expand All @@ -143,7 +168,7 @@ protected void serviceRead() {
ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
plain.position(plain.limit());

while(true) {
while (true) {
if (!plain.hasRemaining()) {

if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
Expand All @@ -153,12 +178,11 @@ protected void serviceRead() {
}
int readCount = secureRead(plain);


if (readCount == 0)
break;

// channel is closed, cleanup
if (readCount== -1) {
if (readCount == -1) {
onException(new EOFException());
selection.close();
break;
Expand All @@ -181,7 +205,8 @@ protected void processCommand(ByteBuffer plain) throws Exception {
if (wireFormat instanceof OpenWireFormat) {
long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
if (nextFrameSize > maxFrameSize) {
throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
" MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
}
}
currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
Expand Down Expand Up @@ -213,8 +238,7 @@ protected int secureRead(ByteBuffer plain) throws Exception {

if (bytesRead == -1) {
sslEngine.closeInbound();
if (inputBuffer.position() == 0 ||
status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
return -1;
}
}
Expand All @@ -226,18 +250,17 @@ protected int secureRead(ByteBuffer plain) throws Exception {
SSLEngineResult res;
do {
res = sslEngine.unwrap(inputBuffer, plain);
} while (res.getStatus() == SSLEngineResult.Status.OK &&
res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
res.bytesProduced() == 0);
} while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
&& res.bytesProduced() == 0);

if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
finishHandshake();
finishHandshake();
}

status = res.getStatus();
handshakeStatus = res.getHandshakeStatus();

//TODO deal with BUFFER_OVERFLOW
// TODO deal with BUFFER_OVERFLOW

if (status == SSLEngineResult.Status.CLOSED) {
sslEngine.closeInbound();
Expand All @@ -254,22 +277,22 @@ protected void doHandshake() throws Exception {
handshakeInProgress = true;
while (true) {
switch (sslEngine.getHandshakeStatus()) {
case NEED_UNWRAP:
secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
break;
case NEED_TASK:
Runnable task;
while ((task = sslEngine.getDelegatedTask()) != null) {
taskRunnerFactory.execute(task);
}
break;
case NEED_WRAP:
((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
break;
case FINISHED:
case NOT_HANDSHAKING:
finishHandshake();
return;
case NEED_UNWRAP:
secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
break;
case NEED_TASK:
Runnable task;
while ((task = sslEngine.getDelegatedTask()) != null) {
taskRunnerFactory.execute(task);
}
break;
case NEED_WRAP:
((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
break;
case FINISHED:
case NOT_HANDSHAKING:
finishHandshake();
return;
}
}
}
Expand All @@ -295,14 +318,15 @@ protected void doStop(ServiceStopper stopper) throws Exception {
}

/**
* Overriding in order to add the client's certificates to ConnectionInfo Commmands.
* Overriding in order to add the client's certificates to ConnectionInfo Commands.
*
* @param command The Command coming in.
* @param command
* The Command coming in.
*/
@Override
public void doConsume(Object command) {
if (command instanceof ConnectionInfo) {
ConnectionInfo connectionInfo = (ConnectionInfo)command;
ConnectionInfo connectionInfo = (ConnectionInfo) command;
connectionInfo.setTransportContext(getPeerCertificates());
}
super.doConsume(command);
Expand All @@ -315,10 +339,13 @@ public X509Certificate[] getPeerCertificates() {

X509Certificate[] clientCertChain = null;
try {
if (sslSession != null) {
clientCertChain = (X509Certificate[])sslSession.getPeerCertificates();
if (sslEngine.getSession() != null) {
clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
}
} catch (SSLPeerUnverifiedException e) {
if (LOG.isTraceEnabled()) {
LOG.trace("Failed to get peer certificates.", e);
}
}

return clientCertChain;
Expand Down

0 comments on commit 65af81e

Please sign in to comment.