"""Source code for ``tierproxy.async_client``: the asynchronous TierProxy client."""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any

import httpx

from tierproxy._internal.cache import ResponseCache
from tierproxy._internal.cookies import CookieJar
from tierproxy._internal.http import Transport
from tierproxy._internal.streaming import AsyncStreamCM, is_stream
from tierproxy.errors import AuthenticationError
from tierproxy.resources.health import AsyncHealthResource
from tierproxy.resources.me import AsyncMeResource
from tierproxy.resources.ratelimits import AsyncRateLimitsResource
from tierproxy.resources.usage import AsyncUsageResource
from tierproxy.resources.usage_recent import AsyncUsageRecentResource
from tierproxy.retry import RetryPolicy

if TYPE_CHECKING:
    from tierproxy._internal.cost import AsyncCostAttributor
    from tierproxy.proxy.selector import AsyncSmartSelector


class AsyncTierProxy:
    """Asynchronous client for the TierProxy gateway.

    Exposes lazily-created API resources (``me``, ``usage``, ``health``,
    ``usage_recent``, ``rate_limits``), cost-attribution helpers, and
    proxy-through helpers (``request``/``get``/``post``/``session``/``target``)
    that send traffic to arbitrary URLs via the gateway.

    Use as an async context manager (``async with AsyncTierProxy() as tp:``)
    or call :meth:`close` to release the transport's HTTP client.
    """

    def __init__(
        self,
        api_key: str | None = None,
        *,
        base_url: str = "https://gw.tierproxy.com:8444",
        timeout: float = 30.0,
        http_timeout: float = 30.0,
        max_retries: int = 3,
        retry_policy: RetryPolicy | None = None,
        http_client: httpx.AsyncClient | None = None,
        user_agent_suffix: str | None = None,
        routing: str | None = None,
        monthly_budget_usd: float | None = None,
        cache_ttl: float = 0.0,
        cache_max_size: int = 256,
        cache_max_response_size: int = 262144,
        auto_failover: bool = False,
        auto_failover_max_attempts: int = 3,
    ) -> None:
        """Create a client.

        Args:
            api_key: Gateway API key; falls back to ``TIERPROXY_API_KEY``.
            base_url: Gateway URL for API calls; its host is also used as the
                proxy host for proxy-through requests.
            timeout: Timeout passed to the API ``Transport``.
            http_timeout: Timeout for proxy-through requests.
            max_retries: Retry count used when ``retry_policy`` is not given.
            retry_policy: Explicit retry policy; takes precedence over
                ``max_retries``.
            http_client: Optional ``httpx.AsyncClient`` handed to the transport.
            user_agent_suffix: Passed to the transport as ``ua_suffix``.
            routing: Selector strategy; enables the smart selector.
            monthly_budget_usd: When set, proxy-through requests run a budget
                check before being sent.
            cache_ttl: TTL for caching successful GET responses; ``0`` (or
                less) disables the response cache.
            cache_max_size: Maximum number of cached responses.
            cache_max_response_size: Largest response the cache will keep
                (unit defined by ``ResponseCache`` — presumably bytes).
            auto_failover: Retry 429/5xx/network failures on other upstreams.
            auto_failover_max_attempts: Upstreams tried per request under
                failover; clamped to at least 1.

        Raises:
            AuthenticationError: If no API key is passed or set in the env.
        """
        key = api_key or os.environ.get("TIERPROXY_API_KEY")
        if not key:
            raise AuthenticationError("No api_key passed and TIERPROXY_API_KEY env var not set")
        retry = retry_policy or RetryPolicy(max_retries=max_retries)
        self._transport = Transport(
            api_key=key,
            base_url=base_url,
            timeout=timeout,
            retry=retry,
            ua_suffix=user_agent_suffix,
            http_client=http_client,
            is_async=True,
        )
        # Resource objects are created on first property access.
        self._me: AsyncMeResource | None = None
        self._usage: AsyncUsageResource | None = None
        self._health: AsyncHealthResource | None = None
        self._usage_recent: AsyncUsageRecentResource | None = None
        self._cost_attributor: AsyncCostAttributor | None = None
        self._rate_limits: AsyncRateLimitsResource | None = None
        # Target hosts that answered 429; reported to the gateway on the next request.
        self._pending_429_reports: set[str] = set()
        self._http_timeout = http_timeout
        self._routing = routing
        self._monthly_budget_usd = monthly_budget_usd
        self._cache_ttl = cache_ttl
        self._response_cache: ResponseCache | None = (
            ResponseCache(
                max_size=cache_max_size,
                default_ttl=cache_ttl,
                max_response_size=cache_max_response_size,
            )
            if cache_ttl > 0
            else None
        )
        self._selector: AsyncSmartSelector | None = None
        # Month-to-date spend and the time.time() it was last refreshed.
        self._spent_usd_mtd: float = 0.0
        self._budget_cache_ts: float = 0.0
        self._auto_failover = auto_failover
        self._auto_failover_max_attempts = max(1, auto_failover_max_attempts)
        self._cookie_jar = CookieJar()
        if routing or auto_failover:
            from tierproxy.proxy.selector import AsyncSmartSelector

            self._selector = AsyncSmartSelector(self, strategy=routing or "balanced")  # type: ignore[arg-type]

    @property
    def me(self) -> AsyncMeResource:
        """Account resource (created on first access)."""
        if self._me is None:
            self._me = AsyncMeResource(self)
        return self._me

    @property
    def usage(self) -> AsyncUsageResource:
        """Usage resource (created on first access)."""
        if self._usage is None:
            self._usage = AsyncUsageResource(self)
        return self._usage

    @property
    def health(self) -> AsyncHealthResource:
        """Health resource (created on first access)."""
        if self._health is None:
            self._health = AsyncHealthResource(self)
        return self._health

    @property
    def usage_recent(self) -> AsyncUsageRecentResource:
        """Recent-usage resource (created on first access)."""
        if self._usage_recent is None:
            self._usage_recent = AsyncUsageRecentResource(self)
        return self._usage_recent

    @property
    def rate_limits(self) -> AsyncRateLimitsResource:
        """Rate-limits resource (created on first access)."""
        if self._rate_limits is None:
            self._rate_limits = AsyncRateLimitsResource(self)
        return self._rate_limits

    @property
    def cookies(self) -> CookieJar:
        """Per-session cookie jar used for requests carrying ``session_id``."""
        return self._cookie_jar

    def _get_cost_attributor(self) -> AsyncCostAttributor:
        """Return the cost attributor, importing and creating it on first use."""
        if self._cost_attributor is None:
            from tierproxy._internal.cost import AsyncCostAttributor

            self._cost_attributor = AsyncCostAttributor(self)
        return self._cost_attributor

    async def cost_for(self, resp: httpx.Response) -> float | None:
        """Return the attributed cost of ``resp``, or ``None`` if unknown."""
        return await self._get_cost_attributor().cost_for(resp)

    async def upstream_for(self, resp: httpx.Response) -> str | None:
        """Return the upstream that served ``resp``, or ``None`` if unknown."""
        return await self._get_cost_attributor().upstream_for(resp)

    async def close(self) -> None:
        """Close the transport's HTTP client if it is an ``httpx.AsyncClient``."""
        client = self._transport._client
        if isinstance(client, httpx.AsyncClient):
            await client.aclose()

    async def __aenter__(self) -> AsyncTierProxy:
        return self

    async def __aexit__(self, *_: Any) -> None:
        await self.close()

    # -------- Level 1-2: proxy-through helpers --------

    @staticmethod
    def _cache_params(kwargs: dict[str, Any]) -> tuple[bool, dict[str, Any] | None]:
        """Return ``(cacheable, params)`` for a request's keyword arguments.

        The response cache is keyed on ``(method, url, params)`` with
        ``params`` as a dict. Query params given in any other form (a list of
        pairs, ``httpx.QueryParams``, a string) cannot be represented in that
        key. Previously they were keyed as ``None`` and collided with each
        other, so such requests now bypass the cache.
        """
        params = kwargs.get("params")
        if params is None:
            return True, None
        if isinstance(params, dict):
            return True, params
        return False, None

    async def request(self, method: str, url: str, **kwargs: Any) -> Any:
        """Async request through the gateway proxy.

        With ``stream=True``, returns an async context manager yielding an
        ``httpx.Response`` for streaming. Otherwise returns an
        ``httpx.Response``. When the response cache is enabled, successful
        (2xx) GET responses are cached and served from cache until expiry.

        NOTE(review): the cache key covers only method, URL and dict params;
        GETs that differ only in headers or targeting kwargs share an entry.
        """
        if is_stream(kwargs):
            return self._stream_request(method, url, **kwargs)
        cache = self._response_cache
        cacheable, params = self._cache_params(kwargs)
        if cache is None or method.upper() != "GET" or not cacheable:
            return await self._dispatch(method, url, **kwargs)
        cached = cache.get(method, url, params)
        if cached is not None:
            return cached
        resp = await self._dispatch(method, url, **kwargs)
        if 200 <= resp.status_code < 300:
            cache.set(method, url, params, resp, ttl=self._cache_ttl)
        return resp

    async def _dispatch(self, method: str, url: str, **kwargs: Any) -> httpx.Response:
        """Send a request, failing over across upstreams when enabled.

        Without auto-failover (or without a selector) this is one
        ``_do_request``. With it, up to ``_auto_failover_max_attempts``
        distinct upstreams are tried. A 429/5xx response or a network/timeout
        error moves on to the next upstream from the selector. The last
        response is returned if any attempt produced one; otherwise the last
        exception is re-raised.
        """
        if not self._auto_failover or self._selector is None:
            return await self._do_request(method, url, **kwargs)

        tried: set[str] = set()
        last_resp: httpx.Response | None = None
        last_exc: Exception | None = None
        for attempt in range(self._auto_failover_max_attempts):
            if attempt == 0:
                chosen = await self._selector.pick()
            else:
                nxt = await self._selector.pick_next(tried)
                if nxt is None:
                    break  # no untried upstream left
                chosen = nxt
            tried.add(chosen.upstream_id)
            try:
                resp = await self._do_request(
                    method, url, _upstream_override=chosen.upstream_id, **kwargs
                )
            except (httpx.NetworkError, httpx.TimeoutException) as e:
                # httpx.ConnectError subclasses NetworkError, so it is covered.
                last_exc = e
                continue
            last_resp = resp
            if resp.status_code == 429 or 500 <= resp.status_code < 600:
                continue
            return resp

        if last_resp is not None:
            return last_resp
        if last_exc is not None:
            raise last_exc
        # Unreachable in practice: the first attempt always runs.
        raise RuntimeError("auto-failover made no request attempts")

    async def _do_request(
        self,
        method: str,
        url: str,
        _upstream_override: str | None = None,
        **kwargs: Any,
    ) -> httpx.Response:
        """Send a single request through the gateway proxy.

        Steps:

        - Run the budget guard (if a budget is set).
        - Split targeting kwargs from httpx kwargs and build the proxy.
        - Pin an upstream: the explicit override, or a selector pick when the
          targeting carries no hint.
        - Attach pending 429 reports and an optional TLS profile header.
        - Send the request and update the session cookie jar.
        - Remember the target host if it answered 429.
        """
        from tierproxy._targeting import build_proxy, split_kwargs

        # Budget check first: if it raises, no upstream pick has been made and
        # the pending 429 reports are kept for the next request.
        if self._monthly_budget_usd is not None:
            await self._guard_budget_async()

        targeting, passthrough = split_kwargs(kwargs)
        host = httpx.URL(self._transport.base_url).host
        proxy = build_proxy(self._transport.api_key, host, 443, targeting)
        if _upstream_override is not None:
            proxy.upstream_hint = _upstream_override
        elif self._selector is not None and not proxy.upstream_hint:
            chosen = await self._selector.pick()
            proxy.upstream_hint = chosen.upstream_id

        headers = passthrough.pop("headers", None) or {}
        headers = {**headers, **proxy.headers()}
        if self._pending_429_reports:
            headers["X-TierProxy-Report-429"] = ",".join(sorted(self._pending_429_reports))
            self._pending_429_reports.clear()
        if (tls_fp := passthrough.pop("tls_fingerprint", None)) is not None:
            headers["X-TierProxy-TLS-Profile"] = str(tls_fp)

        session_id = targeting.get("session_id")
        cookies = self._cookie_jar.get(session_id) if session_id else None
        async with httpx.AsyncClient(
            proxy=proxy.http_url(), timeout=self._http_timeout, cookies=cookies
        ) as h:
            resp = await h.request(method, url, headers=headers, **passthrough)
        if session_id:
            self._cookie_jar.update_from(session_id, resp)
        if resp.status_code == 429:
            target_host = httpx.URL(url).host
            if target_host:
                self._pending_429_reports.add(target_host)
        return resp

    def _stream_request(self, method: str, url: str, **kwargs: Any) -> AsyncStreamCM:
        """Build a streaming request context manager through the gateway proxy.

        Unlike ``_do_request``, this path does not apply the following:

        - selector upstream hints
        - pending 429 reports
        - the budget guard
        - the session cookie jar
        """
        from tierproxy._targeting import build_proxy, split_kwargs

        targeting, passthrough = split_kwargs(kwargs)
        host = httpx.URL(self._transport.base_url).host
        proxy = build_proxy(self._transport.api_key, host, 443, targeting)
        headers = passthrough.pop("headers", None) or {}
        headers = {**headers, **proxy.headers()}
        if (tls_fp := passthrough.pop("tls_fingerprint", None)) is not None:
            headers["X-TierProxy-TLS-Profile"] = str(tls_fp)
        return AsyncStreamCM(
            method,
            url,
            proxy=proxy.http_url(),
            timeout=self._http_timeout,
            headers=headers,
            **passthrough,
        )

    async def get(self, url: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("GET", url, **kwargs)``."""
        return await self.request("GET", url, **kwargs)

    async def post(self, url: str, **kwargs: Any) -> Any:
        """Shorthand for ``request("POST", url, **kwargs)``."""
        return await self.request("POST", url, **kwargs)

    def session(self, **targeting_kwargs: Any) -> httpx.AsyncClient:
        """Return a new ``httpx.AsyncClient`` pre-configured with the proxy.

        The caller owns the returned client and must close it.
        """
        from tierproxy._targeting import build_proxy

        host = httpx.URL(self._transport.base_url).host
        proxy = build_proxy(self._transport.api_key, host, 443, targeting_kwargs)
        return httpx.AsyncClient(
            proxy=proxy.http_url(),
            headers=proxy.headers(),
            timeout=self._http_timeout,
        )

    def target(self, **kwargs: Any) -> AsyncTargetedRequest:
        """Return a helper that applies ``kwargs`` as targeting to every call."""
        return AsyncTargetedRequest(self, kwargs)

    # -------- Level 3: cost guard --------

    async def _guard_budget_async(self) -> None:
        """Check an estimated per-request cost against the monthly budget.

        Month-to-date spend is refreshed from ``usage.get()`` at most once
        every 60 seconds. ``check_budget`` (from ``tierproxy._guard``)
        performs the actual comparison.
        """
        import time

        from tierproxy._guard import check_budget

        budget = self._monthly_budget_usd
        if budget is None:
            return
        if time.time() - self._budget_cache_ts > 60:
            usage = await self.usage.get()
            self._spent_usd_mtd = usage.total_cost_usd
            self._budget_cache_ts = time.time()
        # Fixed estimate: a 1 MiB request priced at avg_cost per GiB
        # (presumably USD/GiB, matching the *_usd fields — confirm).
        avg_cost = 4.0
        estimated = avg_cost * (1024 * 1024) / (1024**3)
        check_budget(estimated, self._spent_usd_mtd, budget)
class AsyncTargetedRequest:
    """Bind a fixed set of targeting kwargs to an ``AsyncTierProxy``.

    Every call forwards to the wrapped client with the bound targeting merged
    in. Keyword arguments passed at call time override bound ones with the
    same name.
    """

    def __init__(self, client: AsyncTierProxy, targeting: dict[str, Any]) -> None:
        self._client = client
        self._targeting = targeting

    def _with_targeting(self, overrides: dict[str, Any]) -> dict[str, Any]:
        """Return a fresh dict of the bound targeting updated by ``overrides``."""
        combined = dict(self._targeting)
        combined.update(overrides)
        return combined

    async def get(self, url: str, **kw: Any) -> Any:
        """Forward to ``client.get`` with the bound targeting applied."""
        call_kwargs = self._with_targeting(kw)
        return await self._client.get(url, **call_kwargs)

    async def post(self, url: str, **kw: Any) -> Any:
        """Forward to ``client.post`` with the bound targeting applied."""
        call_kwargs = self._with_targeting(kw)
        return await self._client.post(url, **call_kwargs)

    async def request(self, method: str, url: str, **kw: Any) -> Any:
        """Forward to ``client.request`` with the bound targeting applied."""
        call_kwargs = self._with_targeting(kw)
        return await self._client.request(method, url, **call_kwargs)