fix: 修复代理问题

This commit is contained in:
丹尼尔
2026-03-15 17:16:05 +08:00
parent 8b62c445fc
commit 15c9e1772a
100 changed files with 6157 additions and 69 deletions

View File

@@ -0,0 +1,4 @@
from ._proxy import AnyioProxy as Proxy
from ._chain import ProxyChain
__all__ = ('Proxy', 'ProxyChain')

View File

@@ -0,0 +1,42 @@
from typing import Iterable
import warnings
from ._proxy import AnyioProxy
class ProxyChain:
def __init__(self, proxies: Iterable[AnyioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
_stream = None
proxies = list(self._proxies)
length = len(proxies) - 1
for i in range(length):
_stream = await proxies[i].connect(
dest_host=proxies[i + 1].proxy_host,
dest_port=proxies[i + 1].proxy_port,
timeout=timeout,
_stream=_stream,
)
_stream = await proxies[length].connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
_stream=_stream,
)
return _stream

View File

@@ -0,0 +1,16 @@
from typing import Optional
import anyio
import anyio.abc
async def connect_tcp(
host: str,
port: int,
local_host: Optional[str] = None,
) -> anyio.abc.SocketStream:
return await anyio.connect_tcp(
remote_host=host,
remote_port=port,
local_host=local_host,
)

View File

@@ -0,0 +1,137 @@
import ssl
from typing import Any, Optional
import warnings
import anyio
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._resolver import Resolver
from ._stream import AnyioSocketStream
from ._connect import connect_tcp
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class AnyioProxy:
_stream: Optional[AnyioSocketStream]
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._password = password
self._username = username
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> AnyioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
_stream = kwargs.get('_stream')
if _stream is not None:
warnings.warn(
"The '_stream' argument is deprecated and will be removed in the future",
DeprecationWarning,
stacklevel=2,
)
local_host = kwargs.get('local_host')
try:
with anyio.fail_after(timeout):
if _stream is None:
try:
_stream = AnyioSocketStream(
await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_host=local_host,
)
)
except OSError as e:
msg = 'Could not connect to proxy {}:{} [{}]'.format(
self._proxy_host,
self._proxy_port,
e.strerror,
)
raise ProxyConnectionError(e.errno, msg) from e
stream = _stream
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
return stream
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException:
await stream.close()
raise
except TimeoutError as e:
raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e
@property
def proxy_host(self):
return self._proxy_host
@property
def proxy_port(self):
return self._proxy_port
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,22 @@
import anyio
import socket
from ... import _abc as abc
class Resolver(abc.AsyncResolver):
async def resolve(self, host, port=0, family=socket.AF_UNSPEC):
infos = await anyio.getaddrinfo(
host=host,
port=port,
family=family,
type=socket.SOCK_STREAM,
)
if not infos: # pragma: no cover
raise OSError('Can`t resolve address {}:{} [{}]'.format(host, port, family))
infos = sorted(infos, key=lambda info: info[0])
family, _, _, _, address = infos[0]
return family, address[0]

View File

@@ -0,0 +1,59 @@
import ssl
from typing import Union
import anyio
import anyio.abc
from anyio.streams.tls import TLSStream
from ..._errors import ProxyError
from ... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
class AnyioSocketStream(abc.AsyncSocketStream):
_stream: AnyioStreamType
def __init__(self, stream: AnyioStreamType) -> None:
self._stream = stream
async def write_all(self, data: bytes):
await self._stream.send(item=data)
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: no cover
return b""
async def read_exact(self, n: int):
data = bytearray()
while len(data) < n:
packet = await self.read(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
) -> 'AnyioSocketStream':
ssl_stream = await TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=hostname,
standard_compatible=False,
server_side=False,
)
return AnyioSocketStream(ssl_stream)
async def close(self):
await self._stream.aclose()
@property
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
return self._stream

View File

@@ -0,0 +1,7 @@
from ._proxy import AnyioProxy as Proxy
from ._chain import ProxyChain
__all__ = (
'Proxy',
'ProxyChain',
)

View File

@@ -0,0 +1,32 @@
from typing import Sequence
import warnings
from ._proxy import AnyioProxy
class ProxyChain:
def __init__(self, proxies: Sequence[AnyioProxy]):
warnings.warn(
'This implementation of ProxyChain is deprecated and will be removed in the future',
DeprecationWarning,
stacklevel=2,
)
self._proxies = proxies
async def connect(
self,
dest_host,
dest_port,
dest_ssl=None,
timeout=None,
):
forward = None
for proxy in self._proxies:
proxy._forward = forward
forward = proxy
return await forward.connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
timeout=timeout,
)

View File

@@ -0,0 +1,17 @@
from typing import Optional
import anyio
import anyio.abc
from ._stream import AnyioSocketStream
async def connect_tcp(
host: str,
port: int,
local_host: Optional[str] = None,
) -> AnyioSocketStream:
s = await anyio.connect_tcp(
remote_host=host,
remote_port=port,
local_host=local_host,
)
return AnyioSocketStream(s)

View File

@@ -0,0 +1,135 @@
import ssl
from typing import Any, Optional
import anyio
from ._connect import connect_tcp
from ._stream import AnyioSocketStream
from .._resolver import Resolver
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector
DEFAULT_TIMEOUT = 60
class AnyioProxy:
def __init__(
self,
proxy_type: ProxyType,
host: str,
port: int,
username: Optional[str] = None,
password: Optional[str] = None,
rdns: Optional[bool] = None,
proxy_ssl: Optional[ssl.SSLContext] = None,
forward: Optional['AnyioProxy'] = None,
):
self._proxy_type = proxy_type
self._proxy_host = host
self._proxy_port = port
self._username = username
self._password = password
self._rdns = rdns
self._proxy_ssl = proxy_ssl
self._forward = forward
self._resolver = Resolver()
async def connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> AnyioSocketStream:
if timeout is None:
timeout = DEFAULT_TIMEOUT
local_host = kwargs.get('local_host')
try:
with anyio.fail_after(timeout):
return await self._connect(
dest_host=dest_host,
dest_port=dest_port,
dest_ssl=dest_ssl,
local_host=local_host,
)
except TimeoutError as e:
raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e
async def _connect(
self,
dest_host: str,
dest_port: int,
dest_ssl: Optional[ssl.SSLContext] = None,
local_host: Optional[str] = None,
) -> AnyioSocketStream:
if self._forward is None:
try:
stream = await connect_tcp(
host=self._proxy_host,
port=self._proxy_port,
local_host=local_host,
)
except OSError as e:
raise ProxyConnectionError(
e.errno,
"Couldn't connect to proxy"
f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
) from e
else:
stream = await self._forward.connect(
dest_host=self._proxy_host,
dest_port=self._proxy_port,
)
try:
if self._proxy_ssl is not None:
stream = await stream.start_tls(
hostname=self._proxy_host,
ssl_context=self._proxy_ssl,
)
connector = create_connector(
proxy_type=self._proxy_type,
username=self._username,
password=self._password,
rdns=self._rdns,
resolver=self._resolver,
)
await connector.connect(
stream=stream,
host=dest_host,
port=dest_port,
)
if dest_ssl is not None:
stream = await stream.start_tls(
hostname=dest_host,
ssl_context=dest_ssl,
)
except ReplyError as e:
await stream.close()
raise ProxyError(e, error_code=e.error_code)
except BaseException:
with anyio.CancelScope(shield=True):
await stream.close()
raise
return stream
@classmethod
def create(cls, *args, **kwargs): # for backward compatibility
return cls(*args, **kwargs)
@classmethod
def from_url(cls, url: str, **kwargs) -> 'AnyioProxy':
url_args = parse_proxy_url(url)
return cls(*url_args, **kwargs)

View File

@@ -0,0 +1,59 @@
import ssl
from typing import Union
import anyio
import anyio.abc
from anyio.streams.tls import TLSStream
from ...._errors import ProxyError
from .... import _abc as abc
DEFAULT_RECEIVE_SIZE = 65536
AnyioStreamType = Union[anyio.abc.SocketStream, TLSStream]
class AnyioSocketStream(abc.AsyncSocketStream):
_stream: AnyioStreamType
def __init__(self, stream: AnyioStreamType) -> None:
self._stream = stream
async def write_all(self, data: bytes):
await self._stream.send(item=data)
async def read(self, max_bytes: int = DEFAULT_RECEIVE_SIZE):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: no cover
return b""
async def read_exact(self, n: int):
data = bytearray()
while len(data) < n:
packet = await self.read(n - len(data))
if not packet: # pragma: no cover
raise ProxyError('Connection closed unexpectedly')
data += packet
return data
async def start_tls(
self,
hostname: str,
ssl_context: ssl.SSLContext,
) -> 'AnyioSocketStream':
ssl_stream = await TLSStream.wrap(
self._stream,
ssl_context=ssl_context,
hostname=hostname,
standard_compatible=False,
server_side=False,
)
return AnyioSocketStream(ssl_stream)
async def close(self):
await self._stream.aclose()
@property
def anyio_stream(self) -> AnyioStreamType: # pragma: no cover
return self._stream