356 lines
9.5 KiB
Python
356 lines
9.5 KiB
Python
import enum
|
|
import ipaddress
|
|
import socket
|
|
from typing import Optional, Union
|
|
from dataclasses import dataclass, field
|
|
|
|
from .errors import ReplyError
|
|
from .._helpers import is_ip_address
|
|
|
|
|
|
# Single-byte protocol constants. The reserved byte (RSV), NULL and the
# "authentication granted" status (AUTH_GRANTED) all share the value 0x00.
RSV = 0x00
NULL = RSV
AUTH_GRANTED = RSV

# SOCKS version byte written into every request and expected in replies.
SOCKS_VER = 0x05
|
|
|
|
|
|
class AuthMethod(enum.IntEnum):
    """Authentication method codes exchanged during SOCKS5 method negotiation."""

    # Selecting this method lets the client proceed without an auth step.
    ANONYMOUS = 0
    GSSAPI = 1
    USERNAME_PASSWORD = 2
    # Returned by the server when none of the offered methods is acceptable.
    NO_ACCEPTABLE = 255
|
|
|
|
|
|
class AddressType(enum.IntEnum):
    """Address-type codes (ATYP) used in SOCKS5 requests and replies."""

    IPV4 = 1
    DOMAIN = 3
    IPV6 = 4

    @classmethod
    def from_ip_ver(cls, ver: int):
        """Map an IP version number to its address type.

        :param ver: IP version, expected to be 4 or 6.
        :returns: ``IPV4`` for 4, ``IPV6`` for 6.
        :raises ValueError: for any other version.
        """
        if ver not in (4, 6):
            raise ValueError('Invalid IP version')
        return cls.IPV4 if ver == 4 else cls.IPV6
|
|
|
|
|
|
class Command(enum.IntEnum):
    """SOCKS5 request command codes (CMD byte)."""

    CONNECT = 1
    BIND = 2
    UDP_ASSOCIATE = 3
|
|
|
|
|
|
class ReplyCode(enum.IntEnum):
    """Reply codes (REP byte) a SOCKS5 server returns for a request."""

    SUCCEEDED = 0
    GENERAL_FAILURE = 1
    CONNECTION_NOT_ALLOWED = 2
    NETWORK_UNREACHABLE = 3
    HOST_UNREACHABLE = 4
    CONNECTION_REFUSED = 5
    TTL_EXPIRED = 6
    COMMAND_NOT_SUPPORTED = 7
    ADDRESS_TYPE_NOT_SUPPORTED = 8
|
|
|
|
|
|
# Human-readable text for each ReplyCode; used as the ReplyError message when
# a connect reply reports failure.
ReplyMessages = dict(
    [
        (ReplyCode.SUCCEEDED, 'Request granted'),
        (ReplyCode.GENERAL_FAILURE, 'General SOCKS server failure'),
        (ReplyCode.CONNECTION_NOT_ALLOWED, 'Connection not allowed by ruleset'),
        (ReplyCode.NETWORK_UNREACHABLE, 'Network unreachable'),
        (ReplyCode.HOST_UNREACHABLE, 'Host unreachable'),
        (ReplyCode.CONNECTION_REFUSED, 'Connection refused by destination host'),
        (ReplyCode.TTL_EXPIRED, 'TTL expired'),
        (ReplyCode.COMMAND_NOT_SUPPORTED, 'Command not supported or protocol error'),
        (ReplyCode.ADDRESS_TYPE_NOT_SUPPORTED, 'Address type not supported'),
    ]
)
|
|
|
|
|
|
@dataclass
class AuthMethodsRequest:
    """Client greeting that lists the authentication methods it offers.

    ANONYMOUS is always offered; USERNAME_PASSWORD is added only when both
    ``username`` and ``password`` are truthy.
    """

    username: Optional[str]
    password: Optional[str]
    # Computed in __post_init__ from the credentials; not a constructor argument.
    methods: bytearray = field(init=False)

    def __post_init__(self):
        offered = [AuthMethod.ANONYMOUS]
        if self.username and self.password:
            offered.append(AuthMethod.USERNAME_PASSWORD)
        self.methods = bytearray(offered)

    def dumps(self) -> bytes:
        """Serialize as: version byte, method count, then one byte per method."""
        return bytes([SOCKS_VER, len(self.methods), *self.methods])
|
|
|
|
|
|
@dataclass
class AuthMethodReply:
    """Server's method-selection reply: a version byte followed by the chosen method."""

    SIZE = 2  # exact wire length in bytes

    ver: int
    method: AuthMethod

    def validate(self, request: AuthMethodsRequest):
        """Raise ReplyError if the chosen method is not one ``request`` offered."""
        chosen_was_offered = self.method in request.methods
        if not chosen_was_offered:  # pragma: no cover
            raise ReplyError(f'Unexpected SOCKS authentication method: {self.method}')

    @classmethod
    def loads(cls, data: bytes) -> 'AuthMethodReply':
        """Parse a method-selection reply.

        :raises ReplyError: on wrong length, unexpected version byte, an unknown
            method code, or when the server answered NO_ACCEPTABLE.
        """
        if len(data) != cls.SIZE:
            raise ReplyError('Malformed authentication method reply')

        version, raw_method = data[0], data[1]
        if version != SOCKS_VER:  # pragma: no cover
            raise ReplyError(f'Unexpected SOCKS version number: {version}')

        try:
            chosen = AuthMethod(raw_method)
        except ValueError:
            raise ReplyError(f'Invalid authentication method: {raw_method:#02X}')

        if chosen == AuthMethod.NO_ACCEPTABLE:  # pragma: no cover
            raise ReplyError('No acceptable authentication methods were offered')

        return cls(ver=version, method=chosen)
|
|
|
|
|
|
@dataclass
class AuthRequest:
    """Username/password authentication request.

    Wire format: VER (0x01) | ULEN | UNAME | PLEN | PASSWD, where ULEN and
    PLEN are single bytes holding the encoded lengths of the fields that
    follow them.
    """

    VER = 0x01  # version byte of the username/password sub-negotiation

    username: str
    password: str

    def dumps(self) -> bytes:
        """Serialize the request to bytes.

        :raises UnicodeEncodeError: if the username or password is not ASCII.
        :raises ValueError: if the encoded username or password exceeds 255
            bytes (the length prefix is a single byte).
        """
        username = self.username.encode('ascii')
        password = self.password.encode('ascii')

        # Length prefixes describe the encoded bytes, so measure those, and
        # report over-long values clearly instead of failing inside a
        # bytes/bytearray constructor.
        for label, value in (('username', username), ('password', password)):
            if len(value) > 0xFF:
                raise ValueError(
                    f'SOCKS5 {label} must be at most 255 bytes, got {len(value)}'
                )

        return (
            bytes([self.VER, len(username)])
            + username
            + bytes([len(password)])
            + password
        )
|
|
|
|
|
|
@dataclass
class AuthReply:
    """Server's answer to an AuthRequest: sub-negotiation version plus status."""

    SIZE = 2  # exact wire length in bytes

    ver: int
    status: int

    @classmethod
    def loads(cls, data: bytes) -> 'AuthReply':
        """Parse the reply; any status other than AUTH_GRANTED is a failure.

        :raises ReplyError: on wrong length, an unexpected version byte, or a
            non-granted status.
        """
        if len(data) != cls.SIZE:
            raise ReplyError('Malformed auth reply')

        version, status_code = data[0], data[1]
        if version != AuthRequest.VER:  # pragma: no cover
            raise ReplyError('Invalid authentication response')

        if status_code != AUTH_GRANTED:  # pragma: no cover
            raise ReplyError('Username and password authentication failure')

        return cls(ver=version, status=status_code)
|
|
|
|
|
|
@dataclass
class ConnectRequest:
    """SOCKS5 CONNECT request for ``host``:``port``."""

    host: str  # hostname or IPv4 or IPv6 address
    port: int

    def dumps(self) -> bytes:
        """Serialize as VER | CMD=CONNECT | RSV followed by the encoded address."""
        header = bytes([SOCKS_VER, Command.CONNECT, RSV])
        return header + self._build_addr_request()

    def _build_addr_request(self) -> bytes:
        """Encode ATYP | address | 2-byte big-endian port.

        IP literals are sent in packed form; anything else is IDNA-encoded and
        sent as a length-prefixed domain name.
        """
        port_bytes = self.port.to_bytes(2, 'big')

        if not is_ip_address(self.host):
            encoded_host = self.host.encode('idna')
            return bytes([AddressType.DOMAIN, len(encoded_host)]) + encoded_host + port_bytes

        parsed = ipaddress.ip_address(self.host)
        atyp = AddressType.from_ip_ver(parsed.version)
        return bytes([atyp]) + parsed.packed + port_bytes
|
|
|
|
|
|
@dataclass
class ConnectReply:
    """Parsed SOCKS5 reply to a CONNECT request.

    Wire format: VER | REP | RSV | ATYP | BND.ADDR | BND.PORT (2 bytes, big-endian).
    """

    ver: int
    reply: ReplyCode
    rsv: int
    bound_host: str
    bound_port: int

    def validate(self):
        # Nothing to check beyond what ``loads`` already enforces.
        pass

    @classmethod
    def loads(cls, data: bytes) -> 'ConnectReply':
        """Parse a complete CONNECT reply.

        :raises ReplyError: if the reply is empty or truncated, has an
            unexpected version or reserved byte, carries an unknown or
            non-success reply code, uses an unknown address type, or its
            address field does not match the length implied by the address type.
        """
        if not data:
            raise ReplyError('Empty connect reply')

        ver = data[0]
        if ver != SOCKS_VER:  # pragma: no cover
            raise ReplyError(f'Unexpected SOCKS version number: {ver:#02X}')

        try:
            reply = ReplyCode(data[1])
        except IndexError:
            raise ReplyError('Malformed connect reply')
        except ValueError:
            raise ReplyError(f'Invalid reply code: {data[1]:#02X}')

        if reply != ReplyCode.SUCCEEDED:  # pragma: no cover
            msg = ReplyMessages.get(reply, 'Unknown error')  # type: ignore
            raise ReplyError(msg, error_code=reply)

        try:
            rsv = data[2]
        except IndexError:
            raise ReplyError('Malformed connect reply')

        if rsv != RSV:  # pragma: no cover
            raise ReplyError(f'The reserved byte must be {RSV:#02X}')

        try:
            addr_type = data[3]
        except IndexError:
            raise ReplyError('Malformed connect reply')

        # The address sits between the 4-byte header and the trailing 2-byte
        # port. Slicing never raises, so lengths are validated per address
        # type below; a correct address length also guarantees a full port.
        bnd_host_data = data[4:-2]
        bnd_port_data = data[-2:]

        if addr_type == AddressType.IPV4:
            if len(bnd_host_data) != 4:
                raise ReplyError('Malformed connect reply')
            bnd_host = socket.inet_ntop(socket.AF_INET, bnd_host_data)
        elif addr_type == AddressType.IPV6:
            if len(bnd_host_data) != 16:
                raise ReplyError('Malformed connect reply')
            bnd_host = socket.inet_ntop(socket.AF_INET6, bnd_host_data)
        elif addr_type == AddressType.DOMAIN:  # pragma: no cover
            # First address byte is the domain length; it must account for
            # every remaining address byte.
            if not bnd_host_data or bnd_host_data[0] != len(bnd_host_data) - 1:
                raise ReplyError('Malformed connect reply')
            try:
                bnd_host = bnd_host_data[1:].decode()
            except UnicodeDecodeError:
                raise ReplyError('Malformed connect reply')
        else:  # pragma: no cover
            raise ReplyError(f'Invalid address type: {addr_type:#02X}')

        bnd_port = int.from_bytes(bnd_port_data, 'big')

        return cls(
            ver=ver,
            reply=reply,
            rsv=rsv,
            bound_host=bnd_host,
            bound_port=bnd_port,
        )
|
|
|
|
|
|
class StateServerWaitingForAuthMethods:
    """Initial Connection state: the next message sent must be an AuthMethodsRequest."""
|
|
|
|
|
|
@dataclass()
class StateClientSentAuthMethods:
    """State after the client has sent its offered authentication methods."""

    # The request that was sent; the server's reply is validated against it.
    data: AuthMethodsRequest
|
|
|
|
|
|
@dataclass()
class StateServerWaitingForAuth:
    """State after the server selected username/password auth; an AuthRequest must follow."""

    # The method-selection reply that led to this state.
    data: AuthMethodReply
|
|
|
|
|
|
@dataclass()
class StateClientAuthenticated:
    """State in which the client may send a ConnectRequest."""

    # The AuthReply when username/password auth took place; None when the
    # server selected a method that needs no auth step.
    data: Optional[AuthReply] = field(default=None)
|
|
|
|
|
|
@dataclass()
class StateClientSentAuthRequest:
    """State after the client has sent its username/password credentials."""

    # The credentials request that was sent.
    data: AuthRequest
|
|
|
|
|
|
@dataclass()
class StateClientSentConnectRequest:
    """State after the client has sent a CONNECT request."""

    # The CONNECT request that was sent.
    data: ConnectRequest
|
|
|
|
|
|
@dataclass()
class StateServerConnected:
    """Final state: the server accepted the CONNECT request."""

    # The parsed successful connect reply (bound host and port).
    data: ConnectReply
|
|
|
|
|
|
# Messages accepted by Connection.send().
Request = Union[AuthMethodsRequest, AuthRequest, ConnectRequest]

# Parsed messages returned by Connection.receive().
Reply = Union[AuthMethodReply, AuthReply, ConnectReply]

# Every state the Connection state machine can be in.
ConnectionState = Union[
    StateServerWaitingForAuthMethods,
    StateClientSentAuthMethods,
    StateServerWaitingForAuth,
    StateClientSentAuthRequest,
    StateClientAuthenticated,
    StateClientSentConnectRequest,
    StateServerConnected,
]
|
|
|
|
|
|
class Connection:
    """Client-side SOCKS5 handshake state machine.

    ``send`` serializes a request and advances the state; ``receive`` parses
    the server's bytes for the current state and advances it again. Out-of-order
    calls raise RuntimeError.
    """

    _state: ConnectionState

    def __init__(self):
        self._state = StateServerWaitingForAuthMethods()

    def send(self, request: Request) -> bytes:
        """Record ``request`` as sent and return its wire bytes.

        :raises RuntimeError: if the request type is unknown or not allowed in
            the current state.
        """
        # request type -> (state required beforehand, state to enter, error text)
        transitions = {
            AuthMethodsRequest: (
                StateServerWaitingForAuthMethods,
                StateClientSentAuthMethods,
                'Server is not currently waiting for auth methods',
            ),
            AuthRequest: (
                StateServerWaitingForAuth,
                StateClientSentAuthRequest,
                'Server is not currently waiting for authentication',
            ),
            ConnectRequest: (
                StateClientAuthenticated,
                StateClientSentConnectRequest,
                'Client is not authenticated',
            ),
        }

        transition = transitions.get(type(request))
        if transition is None:
            raise RuntimeError(f'Invalid request type: {type(request)}')

        required_state, next_state, error_text = transition
        if type(self._state) is not required_state:
            raise RuntimeError(error_text)

        self._state = next_state(request)
        return request.dumps()

    def receive(self, data: bytes) -> Reply:
        """Parse ``data`` as the reply expected in the current state.

        :raises RuntimeError: if no reply is expected in the current state.
        Parsing errors from the reply's ``loads``/``validate`` propagate and
        leave the state unchanged.
        """
        current = self._state
        current_type = type(current)

        if current_type is StateClientSentAuthMethods:
            method_reply = AuthMethodReply.loads(data)
            method_reply.validate(current.data)
            # Only username/password selection requires an extra auth round.
            if method_reply.method == AuthMethod.USERNAME_PASSWORD:
                self._state = StateServerWaitingForAuth(data=method_reply)
            else:
                self._state = StateClientAuthenticated()
            return method_reply

        if current_type is StateClientSentAuthRequest:
            auth_reply = AuthReply.loads(data)
            self._state = StateClientAuthenticated(data=auth_reply)
            return auth_reply

        if current_type is StateClientSentConnectRequest:
            connect_reply = ConnectReply.loads(data)
            self._state = StateServerConnected(data=connect_reply)
            return connect_reply

        raise RuntimeError(f'Invalid connection state: {self._state}')

    @property
    def state(self):
        """The current handshake state object."""
        return self._state
|