Improve DoH query speed

This commit is contained in:
Xylitol
2023-12-30 16:35:30 +08:00
parent f710f8e460
commit f5c4e8988a

View File

@ -6,18 +6,17 @@ import json
import logging
import socket
import struct
import time
import urllib
import urllib.request
from dataclasses import dataclass
from typing import Dict, List, Optional
from scraper.exceptions import RequestSendError
from scraper.functions import Args, Func
_logger = logging.getLogger(__name__)
_timeout = 5
_registered_hosts = set()
_executor = concurrent.futures.ThreadPoolExecutor()
_doh_cache: Dict[str, str] = {}
_doh_resolvers = [
# https://developers.cloudflare.com/1.1.1.1/encryption/dns-over-https
@ -52,27 +51,18 @@ def _patched_getaddrinfo(host, *args, **kwargs):
return _orig_getaddrinfo(ip, *args, **kwargs)
# resolve the host using DoH
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for resolver in _doh_resolvers:
futures.append(executor.submit(_doh_query, resolver, host))
futures = []
for resolver in _doh_resolvers:
futures.append(_executor.submit(_doh_query, resolver, host))
done, not_done = concurrent.futures.wait(
futures, return_when=concurrent.futures.FIRST_COMPLETED
)
for future in concurrent.futures.as_completed(futures):
ip = future.result()
if ip is not None:
_logger.info("Resolved [%s] to [%s]", host, ip)
_doh_cache[host] = ip
host = ip
break
for future in done:
ip = future.result()
if ip is not None:
_logger.info("Resolved [%s] to [%s]", host, ip)
_doh_cache[host] = ip
host = ip
break
for future in not_done:
future.cancel()
_logger.info("Calling original getaddrinfo with [%s]", host)
return _orig_getaddrinfo(host, *args, **kwargs)
@ -121,7 +111,7 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
with urllib.request.urlopen(request, timeout=_timeout) as response:
_logger.info("Resolver(%s) response: %s", resolver, response.status)
if response.status != 200:
raise RequestSendError
return None
resp_body = response.read()
# parse DNS response message (RFC 1035)
@ -133,8 +123,7 @@ def _doh_query(resolver: str, host: str) -> Optional[str]:
return socket.inet_ntoa(resp_body[first_rdata_start:first_rdata_end])
except Exception as e:
_logger.error("Resolver(%s) request error: %s", resolver, e)
time.sleep(_timeout)
raise RequestSendError from e
return None
def _doh_query_json(resolver: str, host: str) -> Optional[str]:
@ -147,15 +136,14 @@ def _doh_query_json(resolver: str, host: str) -> Optional[str]:
with urllib.request.urlopen(request, timeout=_timeout) as response:
_logger.info("Resolver(%s) response: %s", resolver, response.status)
if response.status != 200:
raise RequestSendError
return None
response_body = response.read().decode("utf-8")
_logger.debug("<== body: %s", response_body)
answer = json.loads(response_body)["Answer"]
return answer[0]["data"]
except Exception as e:
_logger.error("Resolver(%s) request error: %s", resolver, e)
time.sleep(_timeout)
raise RequestSendError from e
return None
@dataclass(init=False)