ensure all bytes are read in socks handshake. fixes #4188

This commit is contained in:
Daniel Gultsch 2021-10-02 14:24:36 +02:00
parent 586fff5485
commit da14f83a42
2 changed files with 106 additions and 70 deletions

View file

@ -1,5 +1,7 @@
package eu.siacs.conversations.utils; package eu.siacs.conversations.utils;
import com.google.common.io.ByteStreams;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@ -12,76 +14,108 @@ import eu.siacs.conversations.Config;
public class SocksSocketFactory { public class SocksSocketFactory {
private static final byte[] LOCALHOST = new byte[]{127,0,0,1}; private static final byte[] LOCALHOST = new byte[]{127, 0, 0, 1};
public static void createSocksConnection(final Socket socket, final String destination, final int port) throws IOException { public static void createSocksConnection(final Socket socket, final String destination, final int port) throws IOException {
final InputStream proxyIs = socket.getInputStream(); //TODO use different Socks Addr Type if destination is IP or IPv6
final OutputStream proxyOs = socket.getOutputStream(); final InputStream proxyIs = socket.getInputStream();
proxyOs.write(new byte[]{0x05, 0x01, 0x00}); final OutputStream proxyOs = socket.getOutputStream();
proxyOs.flush(); proxyOs.write(new byte[]{0x05, 0x01, 0x00});
final byte[] handshake = new byte[2]; proxyOs.flush();
proxyIs.read(handshake); final byte[] handshake = new byte[2];
if (handshake[0] != 0x05 || handshake[1] != 0x00) { ByteStreams.readFully(proxyIs, handshake);
throw new SocksConnectionException("Socks 5 handshake failed"); if (handshake[0] != 0x05 || handshake[1] != 0x00) {
} throw new SocksConnectionException("Socks 5 handshake failed");
final byte[] dest = destination.getBytes(); }
final ByteBuffer request = ByteBuffer.allocate(7 + dest.length); final byte[] dest = destination.getBytes();
request.put(new byte[]{0x05, 0x01, 0x00, 0x03}); final ByteBuffer request = ByteBuffer.allocate(7 + dest.length);
request.put((byte) dest.length); request.put(new byte[]{0x05, 0x01, 0x00, 0x03});
request.put(dest); request.put((byte) dest.length);
request.putShort((short) port); request.put(dest);
proxyOs.write(request.array()); request.putShort((short) port);
proxyOs.flush(); proxyOs.write(request.array());
final byte[] response = new byte[7 + dest.length]; proxyOs.flush();
proxyIs.read(response); final byte[] response = new byte[4];
if (response[1] != 0x00) { ByteStreams.readFully(proxyIs, response);
if (response[1] == 0x04) { final byte ver = response[0];
throw new HostNotFoundException("Host unreachable"); if (ver != 0x05) {
} throw new IOException(String.format("Unknown Socks version %02X ", ver));
if (response[1] == 0x05) { }
throw new HostNotFoundException("Connection refused"); final byte status = response[1];
} final byte bndAddrType = response[3];
throw new SocksConnectionException("Unable to connect to destination "+(int) (response[1])); final byte[] bndDestination = readDestination(bndAddrType, proxyIs);
} final byte[] bndPort = new byte[2];
} if (bndAddrType == 0x03) {
final String receivedDestination = new String(bndDestination);
if (!receivedDestination.equalsIgnoreCase(destination)) {
throw new IOException(String.format("Destination mismatch. Received %s Expected %s", receivedDestination, destination));
}
}
ByteStreams.readFully(proxyIs, bndPort);
if (status != 0x00) {
if (status == 0x04) {
throw new HostNotFoundException("Host unreachable");
}
if (status == 0x05) {
throw new HostNotFoundException("Connection refused");
}
throw new IOException(String.format("Unknown status code %02X ", status));
}
}
public static boolean contains(byte needle, byte[] haystack) { private static byte[] readDestination(final byte type, final InputStream inputStream) throws IOException {
for(byte hay : haystack) { final byte[] bndDestination;
if (hay == needle) { if (type == 0x01) {
return true; bndDestination = new byte[4];
} } else if (type == 0x03) {
} final int length = inputStream.read();
return false; bndDestination = new byte[length];
} } else if (type == 0x04) {
bndDestination = new byte[16];
} else {
throw new IOException(String.format("Unknown Socks address type %02X ", type));
}
ByteStreams.readFully(inputStream, bndDestination);
return bndDestination;
}
private static Socket createSocket(InetSocketAddress address, String destination, int port) throws IOException { public static boolean contains(byte needle, byte[] haystack) {
Socket socket = new Socket(); for (byte hay : haystack) {
try { if (hay == needle) {
socket.connect(address, Config.CONNECT_TIMEOUT * 1000); return true;
} catch (IOException e) { }
throw new SocksProxyNotFoundException(); }
} return false;
createSocksConnection(socket, destination, port); }
return socket;
}
public static Socket createSocketOverTor(String destination, int port) throws IOException { private static Socket createSocket(InetSocketAddress address, String destination, int port) throws IOException {
return createSocket(new InetSocketAddress(InetAddress.getByAddress(LOCALHOST), 9050), destination, port); Socket socket = new Socket();
} try {
socket.connect(address, Config.CONNECT_TIMEOUT * 1000);
} catch (IOException e) {
throw new SocksProxyNotFoundException();
}
createSocksConnection(socket, destination, port);
return socket;
}
private static class SocksConnectionException extends IOException { public static Socket createSocketOverTor(String destination, int port) throws IOException {
SocksConnectionException(String message) { return createSocket(new InetSocketAddress(InetAddress.getByAddress(LOCALHOST), 9050), destination, port);
super(message); }
}
}
public static class SocksProxyNotFoundException extends IOException { private static class SocksConnectionException extends IOException {
SocksConnectionException(String message) {
super(message);
}
}
} public static class SocksProxyNotFoundException extends IOException {
public static class HostNotFoundException extends SocksConnectionException { }
HostNotFoundException(String message) {
super(message); public static class HostNotFoundException extends SocksConnectionException {
} HostNotFoundException(String message) {
} super(message);
}
}
} }

View file

@ -3,6 +3,8 @@ package eu.siacs.conversations.xmpp.jingle;
import android.os.PowerManager; import android.os.PowerManager;
import android.util.Log; import android.util.Log;
import com.google.common.io.ByteStreams;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@ -114,26 +116,26 @@ public class JingleSocks5Transport extends JingleTransport {
final byte[] authBegin = new byte[2]; final byte[] authBegin = new byte[2];
final InputStream inputStream = socket.getInputStream(); final InputStream inputStream = socket.getInputStream();
final OutputStream outputStream = socket.getOutputStream(); final OutputStream outputStream = socket.getOutputStream();
inputStream.read(authBegin); ByteStreams.readFully(inputStream, authBegin);
if (authBegin[0] != 0x5) { if (authBegin[0] != 0x5) {
socket.close(); socket.close();
} }
final short methodCount = authBegin[1]; final short methodCount = authBegin[1];
final byte[] methods = new byte[methodCount]; final byte[] methods = new byte[methodCount];
inputStream.read(methods); ByteStreams.readFully(inputStream, methods);
if (SocksSocketFactory.contains((byte) 0x00, methods)) { if (SocksSocketFactory.contains((byte) 0x00, methods)) {
outputStream.write(new byte[]{0x05, 0x00}); outputStream.write(new byte[]{0x05, 0x00});
} else { } else {
outputStream.write(new byte[]{0x05, (byte) 0xff}); outputStream.write(new byte[]{0x05, (byte) 0xff});
} }
byte[] connectCommand = new byte[4]; final byte[] connectCommand = new byte[4];
inputStream.read(connectCommand); ByteStreams.readFully(inputStream, connectCommand);
if (connectCommand[0] == 0x05 && connectCommand[1] == 0x01 && connectCommand[3] == 0x03) { if (connectCommand[0] == 0x05 && connectCommand[1] == 0x01 && connectCommand[3] == 0x03) {
int destinationCount = inputStream.read(); int destinationCount = inputStream.read();
final byte[] destination = new byte[destinationCount]; final byte[] destination = new byte[destinationCount];
inputStream.read(destination); ByteStreams.readFully(inputStream, destination);
final byte[] port = new byte[2]; final byte[] port = new byte[2];
inputStream.read(port); ByteStreams.readFully(inputStream, port);
final String receivedDestination = new String(destination); final String receivedDestination = new String(destination);
final ByteBuffer response = ByteBuffer.allocate(7 + destination.length); final ByteBuffer response = ByteBuffer.allocate(7 + destination.length);
final byte[] responseHeader; final byte[] responseHeader;