# Source listing metadata: 259 lines, 8.3 KiB, Python
import ssl
|
|
|
|
import sniffio
|
|
from httpcore import (
|
|
AsyncConnectionPool,
|
|
Origin,
|
|
AsyncConnectionInterface,
|
|
Request,
|
|
Response,
|
|
default_ssl_context,
|
|
AsyncHTTP11Connection,
|
|
ConnectionNotAvailable,
|
|
)
|
|
from httpcore import AsyncNetworkStream
|
|
from httpcore._synchronization import AsyncLock
|
|
from python_socks import ProxyType, parse_proxy_url
|
|
|
|
|
|
class AsyncProxy(AsyncConnectionPool):
    """Connection pool whose connections are opened through a proxy.

    The proxy settings are kept on the pool and handed to every
    ``AsyncProxyConnection`` built by ``create_connection``. All other
    keyword arguments are passed unchanged to ``AsyncConnectionPool``.
    """

    def __init__(
        self,
        *,
        proxy_type: ProxyType,
        proxy_host: str,
        proxy_port: int,
        username=None,
        password=None,
        rdns=None,
        proxy_ssl: ssl.SSLContext = None,
        loop=None,
        **kwargs,
    ):
        """Remember the proxy settings, then initialise the underlying pool.

        Args:
            proxy_type: Proxy protocol, as a ``python_socks.ProxyType``.
            proxy_host: Host name or address of the proxy.
            proxy_port: Port of the proxy.
            username: Optional proxy credential.
            password: Optional proxy credential.
            rdns: Forwarded as ``rdns`` when the proxy object is created.
            proxy_ssl: Optional SSL context forwarded as ``proxy_ssl`` when
                the proxy object is created.
            loop: Stored and forwarded to each connection; not otherwise
                used by this class.
            **kwargs: Forwarded to ``AsyncConnectionPool.__init__``.
        """
        # Proxy endpoint.
        self._proxy_type, self._proxy_host, self._proxy_port = (
            proxy_type,
            proxy_host,
            proxy_port,
        )
        # Proxy credentials.
        self._username, self._password = username, password
        # Remaining proxy options.
        self._rdns, self._proxy_ssl, self._loop = rdns, proxy_ssl, loop

        super().__init__(**kwargs)

    def create_connection(self, origin: Origin) -> AsyncConnectionInterface:
        """Build a not-yet-connected proxy connection for ``origin``."""
        proxy_options = dict(
            proxy_type=self._proxy_type,
            proxy_host=self._proxy_host,
            proxy_port=self._proxy_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
            proxy_ssl=self._proxy_ssl,
            loop=self._loop,
        )
        # These attributes are not assigned in this class; presumably they are
        # set by AsyncConnectionPool.__init__ -- confirm against httpcore.
        pool_options = dict(
            ssl_context=self._ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
        )
        return AsyncProxyConnection(
            remote_origin=origin,
            **proxy_options,
            **pool_options,
        )

    @classmethod
    def from_url(cls, url, **kwargs):
        """Alternate constructor: take the proxy settings from a proxy URL.

        ``url`` is parsed with ``python_socks.parse_proxy_url``; ``kwargs``
        are passed to the regular constructor alongside the parsed values.
        """
        kind, hostname, port_number, user, secret = parse_proxy_url(url)
        return cls(
            proxy_type=kind,
            proxy_host=hostname,
            proxy_port=port_number,
            username=user,
            password=secret,
            **kwargs,
        )
|
|
|
|
|
|
class AsyncProxyConnection(AsyncConnectionInterface):
    """A single connection to ``remote_origin`` opened through a proxy.

    The proxied stream is opened lazily by the first request and wrapped in
    an HTTP/1.1 or HTTP/2 connection object; the state queries and
    ``aclose`` delegate to that object once it exists.
    """

    def __init__(
        self,
        *,
        proxy_type: ProxyType,
        proxy_host: str,
        proxy_port: int,
        username=None,
        password=None,
        rdns=None,
        proxy_ssl: ssl.SSLContext = None,
        loop=None,
        remote_origin: Origin,
        ssl_context: ssl.SSLContext,
        keepalive_expiry: float = None,
        http1: bool = True,
        http2: bool = False,
    ) -> None:
        """Store the proxy and destination settings; no I/O happens here.

        Args:
            proxy_type: Proxy protocol, as a ``python_socks.ProxyType``.
            proxy_host: Host name or address of the proxy.
            proxy_port: Port of the proxy.
            username: Optional proxy credential.
            password: Optional proxy credential.
            rdns: Forwarded as ``rdns`` when the proxy object is created.
            proxy_ssl: Optional SSL context forwarded as ``proxy_ssl`` when
                the proxy object is created.
            loop: Stored on the instance; not used elsewhere in this class.
            remote_origin: The origin this connection serves.
            ssl_context: SSL context for ``https`` origins; when ``None`` the
                result of ``default_ssl_context()`` is used.
            keepalive_expiry: Passed to the wrapped HTTP connection.
            http1: Whether HTTP/1.1 is enabled.
            http2: Whether HTTP/2 is enabled.
        """
        # Proxy side.
        self._proxy_type, self._proxy_host, self._proxy_port = (
            proxy_type,
            proxy_host,
            proxy_port,
        )
        self._username, self._password = username, password
        self._rdns, self._proxy_ssl, self._loop = rdns, proxy_ssl, loop

        # Destination side.
        self._remote_origin = remote_origin
        self._ssl_context = (
            default_ssl_context() if ssl_context is None else ssl_context
        )
        self._keepalive_expiry = keepalive_expiry
        self._http1, self._http2 = http1, http2

        # Connection state: the lock serialises the lazy connect below.
        self._connect_lock = AsyncLock()
        self._connection = None
        self._connect_failed: bool = False

    async def handle_async_request(self, request: Request) -> Response:
        """Send ``request``, connecting through the proxy first if needed.

        The connect timeout is read from
        ``request.extensions['timeout']['connect']`` when present.

        Raises:
            ConnectionNotAvailable: If a connection already exists but
                reports that it is not available.

        Any exception raised inside the locked section (including the one
        above) sets ``_connect_failed`` before it is re-raised.
        """
        connect_timeout = request.extensions.get('timeout', {}).get('connect', None)

        try:
            async with self._connect_lock:
                if self._connection is not None:
                    if not self._connection.is_available():  # pragma: no cover
                        raise ConnectionNotAvailable()
                else:
                    stream = await self._connect_via_proxy(
                        origin=self._remote_origin,
                        connect_timeout=connect_timeout,
                    )

                    # HTTP/2 is used when ALPN selected it, or when it is the
                    # only protocol enabled.
                    tls = stream.get_extra_info("ssl_object")
                    negotiated_h2 = (
                        tls is not None and tls.selected_alpn_protocol() == "h2"
                    )
                    if negotiated_h2 or (self._http2 and not self._http1):
                        # Imported only when HTTP/2 is actually selected.
                        from httpcore import AsyncHTTP2Connection

                        connection_cls = AsyncHTTP2Connection
                    else:
                        connection_cls = AsyncHTTP11Connection

                    self._connection = connection_cls(
                        origin=self._remote_origin,
                        stream=stream,
                        keepalive_expiry=self._keepalive_expiry,
                    )
        except BaseException as exc:
            self._connect_failed = True
            raise exc

        return await self._connection.handle_async_request(request)

    async def _connect_via_proxy(self, origin, connect_timeout) -> AsyncNetworkStream:
        """Open a stream to ``origin`` via the proxy; TLS only for https."""
        dest_ssl = None
        if origin.scheme == b'https':
            dest_ssl = self._ssl_context

        # NOTE(review): assumes the origin host is ASCII-only bytes -- confirm.
        return await self._open_stream(
            host=origin.host.decode('ascii'),
            port=origin.port,
            connect_timeout=connect_timeout,
            ssl_context=dest_ssl,
        )

    async def _open_stream(self, host, port, connect_timeout, ssl_context):
        """Dispatch to the opener matching the running async library.

        Raises:
            RuntimeError: If sniffio reports a library other than asyncio
                or trio.
        """
        backend = sniffio.current_async_library()

        # Curio support has been dropped in httpcore 0.14.0, so only these
        # two backends are handled.
        openers = {
            'asyncio': self._open_aio_stream,
            'trio': self._open_trio_stream,
        }
        opener = openers.get(backend)
        if opener is None:
            raise RuntimeError(f'Unsupported concurrency backend {backend!r}')  # pragma: no cover

        return await opener(host, port, connect_timeout, ssl_context)

    def _create_proxy(self, proxy_cls):
        """Create a ``proxy_cls`` proxy object from the stored settings."""
        return proxy_cls.create(
            proxy_type=self._proxy_type,
            host=self._proxy_host,
            port=self._proxy_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
            proxy_ssl=self._proxy_ssl,
        )

    async def _open_aio_stream(self, host, port, connect_timeout, ssl_context):
        """Connect through the proxy under asyncio, using the anyio classes."""
        from httpcore._backends.anyio import AnyIOStream
        from python_socks.async_.anyio import Proxy

        tunnel = await self._create_proxy(Proxy).connect(
            host,
            port,
            dest_ssl=ssl_context,
            timeout=connect_timeout,
        )
        return AnyIOStream(tunnel.anyio_stream)

    async def _open_trio_stream(self, host, port, connect_timeout, ssl_context):
        """Connect through the proxy under trio."""
        from httpcore._backends.trio import TrioStream
        from python_socks.async_.trio.v2 import Proxy

        tunnel = await self._create_proxy(Proxy).connect(
            host,
            port,
            dest_ssl=ssl_context,
            timeout=connect_timeout,
        )
        return TrioStream(tunnel.trio_stream)

    async def aclose(self) -> None:
        """Close the wrapped connection, if one was ever established."""
        if self._connection is None:
            return
        await self._connection.aclose()

    def can_handle_request(self, origin: Origin) -> bool:
        """Return True when ``origin`` equals this connection's origin."""
        return origin == self._remote_origin

    def is_available(self) -> bool:
        """Return the wrapped connection's availability; False before connect."""
        connection = self._connection
        if connection is not None:
            return connection.is_available()
        return False  # pragma: no cover

    def has_expired(self) -> bool:
        """Delegate once connected; before that, report whether connect failed."""
        connection = self._connection
        return self._connect_failed if connection is None else connection.has_expired()

    def is_idle(self) -> bool:
        """Delegate once connected; before that, report whether connect failed."""
        connection = self._connection
        return self._connect_failed if connection is None else connection.is_idle()

    def is_closed(self) -> bool:
        """Delegate once connected; before that, report whether connect failed."""
        connection = self._connection
        return self._connect_failed if connection is None else connection.is_closed()

    def info(self) -> str:  # pragma: no cover
        """Return a status string, delegating to the wrapped connection if any."""
        if self._connection is not None:
            return self._connection.info()
        if self._connect_failed:
            return "CONNECTION FAILED"
        return "CONNECTING"
|