Add support for non-blocking DNS calls
This takes advantage of some of the purepydns code from the gogreen package open-sourced by Slide, Inc: http://github.com/slideinc/gogreen
This commit is contained in:
@@ -15,6 +15,13 @@ except AttributeError:
|
||||
__all__ = __socket.__all__
|
||||
__patched__ = __socket.__patched__ + ['gethostbyname', 'getaddrinfo']
|
||||
|
||||
|
||||
greendns = None
|
||||
try:
|
||||
from eventlet.support import greendns
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__original_gethostbyname__ = __socket.gethostbyname
|
||||
# the thread primitives on Darwin have some bugs that make
|
||||
# it undesirable to use tpool for hostname lookups
|
||||
@@ -33,6 +40,8 @@ def _gethostbyname_tpool(name):
|
||||
|
||||
if getattr(get_hub(), 'uses_twisted_reactor', None):
|
||||
gethostbyname = _gethostbyname_twisted
|
||||
elif greendns:
|
||||
gethostbyname = greendns.gethostbyname
|
||||
elif _can_use_tpool:
|
||||
gethostbyname = _gethostbyname_tpool
|
||||
else:
|
||||
@@ -45,8 +54,15 @@ def _getaddrinfo_tpool(*args, **kw):
|
||||
return tpool.execute(
|
||||
__original_getaddrinfo__, *args, **kw)
|
||||
|
||||
if _can_use_tpool:
|
||||
if greendns:
|
||||
getaddrinfo = greendns.getaddrinfo
|
||||
elif _can_use_tpool:
|
||||
getaddrinfo = _getaddrinfo_tpool
|
||||
else:
|
||||
getaddrinfo = __original_getaddrinfo__
|
||||
|
||||
if greendns:
|
||||
gethostbyname_ex = greendns.gethostbyname_ex
|
||||
getnameinfo = greendns.getnameinfo
|
||||
|
||||
|
||||
|
451
eventlet/support/greendns.py
Normal file
451
eventlet/support/greendns.py
Normal file
@@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python
|
||||
'''
|
||||
greendns - non-blocking DNS support for Eventlet
|
||||
'''
|
||||
|
||||
# Portions of this code taken from the gogreen project:
|
||||
# http://github.com/slideinc/gogreen
|
||||
#
|
||||
# Copyright (c) 2005-2010 Slide, Inc.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are
|
||||
# met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above
|
||||
# copyright notice, this list of conditions and the following
|
||||
# disclaimer in the documentation and/or other materials provided
|
||||
# with the distribution.
|
||||
# * Neither the name of the author nor the names of other
|
||||
# contributors may be used to endorse or promote products derived
|
||||
# from this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import random
|
||||
from eventlet import patcher
|
||||
from eventlet.green import _socket_nodns
|
||||
from eventlet.green import time
|
||||
from eventlet.green import select
|
||||
|
||||
start = time.time()
|
||||
|
||||
__imports = []
|
||||
for package in ('dns', 'dns.query', 'dns.exception', 'dns.inet',
|
||||
'dns.message', 'dns.name', 'dns.rdata', 'dns.rdataset',
|
||||
'dns.rdatatype', 'dns.resolver', 'dns.reversename'):
|
||||
__imports.append('%(pkg)s = patcher.import_patched(\'%(pkg)s\', socket=_socket_nodns, time=time, select=select)' % dict(pkg=package))
|
||||
exec '\n'.join(__imports)
|
||||
|
||||
end = time.time()
|
||||
|
||||
with open('times_%s.txt' % random.random(), 'w') as fd:
|
||||
fd.write('%s\n' % (end - start))
|
||||
|
||||
socket = _socket_nodns
|
||||
|
||||
DNS_QUERY_TIMEOUT = 10.0
|
||||
|
||||
#
|
||||
# Resolver instance used to perfrom DNS lookups.
|
||||
#
|
||||
class FakeAnswer(list):
|
||||
expiration = 0
|
||||
class FakeRecord(object):
|
||||
pass
|
||||
|
||||
class ResolverProxy(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._resolver = None
|
||||
self._filename = kwargs.get('filename', '/etc/resolv.conf')
|
||||
self._hosts = {}
|
||||
if kwargs.pop('dev', False):
|
||||
self._load_etc_hosts()
|
||||
|
||||
def _load_etc_hosts(self):
|
||||
fd = open('/etc/hosts', 'r')
|
||||
contents = fd.read()
|
||||
fd.close()
|
||||
contents = [line for line in contents.split('\n') if line and not line[0] == '#']
|
||||
for line in contents:
|
||||
line = line.replace('\t', ' ')
|
||||
parts = line.split(' ')
|
||||
parts = [p for p in parts if p]
|
||||
if not len(parts):
|
||||
continue
|
||||
ip = parts[0]
|
||||
for part in parts[1:]:
|
||||
self._hosts[part] = ip
|
||||
|
||||
def clear(self):
|
||||
self._resolver = None
|
||||
|
||||
def query(self, *args, **kwargs):
|
||||
if self._resolver is None:
|
||||
self._resolver = dns.resolver.Resolver(filename = self._filename)
|
||||
self._resolver.cache = dns.resolver.Cache()
|
||||
|
||||
query = args[0]
|
||||
if self._hosts and self._hosts.get(query):
|
||||
answer = FakeAnswer()
|
||||
record = FakeRecord()
|
||||
setattr(record, 'address', self._hosts[query])
|
||||
answer.append(record)
|
||||
return answer
|
||||
return self._resolver.query(*args, **kwargs)
|
||||
#
|
||||
# cache
|
||||
#
|
||||
resolver = ResolverProxy()
|
||||
|
||||
def resolve(name):
|
||||
error = None
|
||||
rrset = None
|
||||
|
||||
if rrset is None or time.time() > rrset.expiration:
|
||||
try:
|
||||
rrset = resolver.query(name)
|
||||
except dns.exception.Timeout, e:
|
||||
error = (socket.EAI_AGAIN, 'Lookup timed out')
|
||||
except dns.exception.DNSException, e:
|
||||
error = (socket.EAI_NODATA, 'No address associated with hostname')
|
||||
else:
|
||||
pass
|
||||
#responses.insert(name, rrset)
|
||||
|
||||
if error:
|
||||
if rrset is None:
|
||||
raise socket.gaierror(error)
|
||||
else:
|
||||
sys.stderr.write('DNS error: %r %r\n' % (name, error))
|
||||
return rrset
|
||||
#
|
||||
# methods
|
||||
#
|
||||
def getaliases(host):
|
||||
"""Checks for aliases of the given hostname (cname records)
|
||||
returns a list of alias targets
|
||||
will return an empty list if no aliases
|
||||
"""
|
||||
cnames = []
|
||||
error = None
|
||||
|
||||
try:
|
||||
answers = dns.resolver.query(host, 'cname')
|
||||
except dns.exception.Timeout, e:
|
||||
error = (socket.EAI_AGAIN, 'Lookup timed out')
|
||||
except dns.exception.DNSException, e:
|
||||
error = (socket.EAI_NODATA, 'No address associated with hostname')
|
||||
else:
|
||||
for record in answers:
|
||||
cnames.append(str(answers[0].target))
|
||||
|
||||
if error:
|
||||
sys.stderr.write('DNS error: %r %r\n' % (host, error))
|
||||
|
||||
return cnames
|
||||
|
||||
def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0):
|
||||
"""Replacement for Python's socket.getaddrinfo.
|
||||
|
||||
Currently only supports IPv4. At present, flags are not
|
||||
implemented.
|
||||
"""
|
||||
socktype = socktype or socket.SOCK_STREAM
|
||||
|
||||
if is_ipv4_addr(host):
|
||||
return [(socket.AF_INET, socktype, proto, '', (host, port))]
|
||||
|
||||
rrset = resolve(host)
|
||||
value = []
|
||||
|
||||
for rr in rrset:
|
||||
value.append((socket.AF_INET, socktype, proto, '', (rr.address, port)))
|
||||
return value
|
||||
|
||||
def gethostbyname(hostname):
|
||||
"""Replacement for Python's socket.gethostbyname.
|
||||
|
||||
Currently only supports IPv4.
|
||||
"""
|
||||
if is_ipv4_addr(hostname):
|
||||
return hostname
|
||||
|
||||
rrset = resolve(hostname)
|
||||
return rrset[0].address
|
||||
|
||||
def gethostbyname_ex(hostname):
|
||||
"""Replacement for Python's socket.gethostbyname_ex.
|
||||
|
||||
Currently only supports IPv4.
|
||||
"""
|
||||
if is_ipv4_addr(hostname):
|
||||
return (hostname, [], [hostname])
|
||||
|
||||
rrset = resolve(hostname)
|
||||
addrs = []
|
||||
|
||||
for rr in rrset:
|
||||
addrs.append(rr.address)
|
||||
return (hostname, [], addrs)
|
||||
|
||||
def getnameinfo(sockaddr, flags):
|
||||
"""Replacement for Python's socket.getnameinfo.
|
||||
|
||||
Currently only supports IPv4.
|
||||
"""
|
||||
host, port = sockaddr
|
||||
|
||||
if (flags & socket.NI_NAMEREQD) and (flags & socket.NI_NUMERICHOST):
|
||||
# Conflicting flags. Punt.
|
||||
raise socket.gaierror(
|
||||
(socket.EAI_NONAME, 'Name or service not known'))
|
||||
|
||||
if is_ipv4_addr(host):
|
||||
try:
|
||||
rrset = resolver.query(
|
||||
dns.reversename.from_address(host), dns.rdatatype.PTR)
|
||||
if len(rrset) > 1:
|
||||
raise socket.error('sockaddr resolved to multiple addresses')
|
||||
host = rrset[0].target.to_text(omit_final_dot=True)
|
||||
except dns.exception.Timeout, e:
|
||||
if flags & socket.NI_NAMEREQD:
|
||||
raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out'))
|
||||
except dns.exception.DNSException, e:
|
||||
if flags & socket.NI_NAMEREQD:
|
||||
raise socket.gaierror(
|
||||
(socket.EAI_NONAME, 'Name or service not known'))
|
||||
else:
|
||||
try:
|
||||
rrset = resolver.query(host)
|
||||
if len(rrset) > 1:
|
||||
raise socket.error('sockaddr resolved to multiple addresses')
|
||||
if flags & socket.NI_NUMERICHOST:
|
||||
host = rrset[0].address
|
||||
except dns.exception.Timeout, e:
|
||||
raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out'))
|
||||
except dns.exception.DNSException, e:
|
||||
raise socket.gaierror(
|
||||
(socket.EAI_NODATA, 'No address associated with hostname'))
|
||||
|
||||
if not (flags & socket.NI_NUMERICSERV):
|
||||
proto = (flags & socket.NI_DGRAM) and 'udp' or 'tcp'
|
||||
port = socket.getservbyport(port, proto)
|
||||
|
||||
return (host, port)
|
||||
|
||||
def is_ipv4_addr(host):
|
||||
"""is_ipv4_addr returns true if host is a valid IPv4 address in
|
||||
dotted quad notation.
|
||||
"""
|
||||
try:
|
||||
d1, d2, d3, d4 = map(int, host.split('.'))
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
if 0 <= d1 <= 255 and 0 <= d2 <= 255 and 0 <= d3 <= 255 and 0 <= d4 <= 255:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _net_read(sock, count, expiration):
|
||||
"""coro friendly replacement for dns.query._net_write
|
||||
Read the specified number of bytes from sock. Keep trying until we
|
||||
either get the desired amount, or we hit EOF.
|
||||
A Timeout exception will be raised if the operation is not completed
|
||||
by the expiration time.
|
||||
"""
|
||||
s = ''
|
||||
while count > 0:
|
||||
try:
|
||||
n = sock.recv(count)
|
||||
except socket.timeout:
|
||||
## Q: Do we also need to catch coro.CoroutineSocketWake and pass?
|
||||
if expiration - time.time() <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
if n == '':
|
||||
raise EOFError
|
||||
count = count - len(n)
|
||||
s = s + n
|
||||
return s
|
||||
|
||||
def _net_write(sock, data, expiration):
|
||||
"""coro friendly replacement for dns.query._net_write
|
||||
Write the specified data to the socket.
|
||||
A Timeout exception will be raised if the operation is not completed
|
||||
by the expiration time.
|
||||
"""
|
||||
current = 0
|
||||
l = len(data)
|
||||
while current < l:
|
||||
try:
|
||||
current += sock.send(data[current:])
|
||||
except socket.timeout:
|
||||
## Q: Do we also need to catch coro.CoroutineSocketWake and pass?
|
||||
if expiration - time.time() <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
|
||||
def udp(
|
||||
q, where, timeout=DNS_QUERY_TIMEOUT, port=53, af=None, source=None,
|
||||
source_port=0, ignore_unexpected=False):
|
||||
"""coro friendly replacement for dns.query.udp
|
||||
Return the response obtained after sending a query via UDP.
|
||||
|
||||
@param q: the query
|
||||
@type q: dns.message.Message
|
||||
@param where: where to send the message
|
||||
@type where: string containing an IPv4 or IPv6 address
|
||||
@param timeout: The number of seconds to wait before the query times out.
|
||||
If None, the default, wait forever.
|
||||
@type timeout: float
|
||||
@param port: The port to which to send the message. The default is 53.
|
||||
@type port: int
|
||||
@param af: the address family to use. The default is None, which
|
||||
causes the address family to use to be inferred from the form of of where.
|
||||
If the inference attempt fails, AF_INET is used.
|
||||
@type af: int
|
||||
@rtype: dns.message.Message object
|
||||
@param source: source address. The default is the IPv4 wildcard address.
|
||||
@type source: string
|
||||
@param source_port: The port from which to send the message.
|
||||
The default is 0.
|
||||
@type source_port: int
|
||||
@param ignore_unexpected: If True, ignore responses from unexpected
|
||||
sources. The default is False.
|
||||
@type ignore_unexpected: bool"""
|
||||
|
||||
wire = q.to_wire()
|
||||
if af is None:
|
||||
try:
|
||||
af = dns.inet.af_for_address(where)
|
||||
except:
|
||||
af = dns.inet.AF_INET
|
||||
if af == dns.inet.AF_INET:
|
||||
destination = (where, port)
|
||||
if source is not None:
|
||||
source = (source, source_port)
|
||||
elif af == dns.inet.AF_INET6:
|
||||
destination = (where, port, 0, 0)
|
||||
if source is not None:
|
||||
source = (source, source_port, 0, 0)
|
||||
|
||||
s = socket.socket(af, socket.SOCK_DGRAM)
|
||||
s.settimeout(timeout)
|
||||
try:
|
||||
expiration = dns.query._compute_expiration(timeout)
|
||||
if source is not None:
|
||||
s.bind(source)
|
||||
try:
|
||||
s.sendto(wire, destination)
|
||||
except socket.timeout:
|
||||
## Q: Do we also need to catch coro.CoroutineSocketWake and pass?
|
||||
if expiration - time.time() <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
while 1:
|
||||
try:
|
||||
(wire, from_address) = s.recvfrom(65535)
|
||||
except socket.timeout:
|
||||
## Q: Do we also need to catch coro.CoroutineSocketWake and pass?
|
||||
if expiration - time.time() <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
if from_address == destination:
|
||||
break
|
||||
if not ignore_unexpected:
|
||||
raise dns.query.UnexpectedSource(
|
||||
'got a response from %s instead of %s'
|
||||
% (from_address, destination))
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac)
|
||||
if not q.is_response(r):
|
||||
raise dns.query.BadResponse()
|
||||
return r
|
||||
|
||||
def tcp(q, where, timeout=DNS_QUERY_TIMEOUT, port=53,
|
||||
af=None, source=None, source_port=0):
|
||||
"""coro friendly replacement for dns.query.tcp
|
||||
Return the response obtained after sending a query via TCP.
|
||||
|
||||
@param q: the query
|
||||
@type q: dns.message.Message object
|
||||
@param where: where to send the message
|
||||
@type where: string containing an IPv4 or IPv6 address
|
||||
@param timeout: The number of seconds to wait before the query times out.
|
||||
If None, the default, wait forever.
|
||||
@type timeout: float
|
||||
@param port: The port to which to send the message. The default is 53.
|
||||
@type port: int
|
||||
@param af: the address family to use. The default is None, which
|
||||
causes the address family to use to be inferred from the form of of where.
|
||||
If the inference attempt fails, AF_INET is used.
|
||||
@type af: int
|
||||
@rtype: dns.message.Message object
|
||||
@param source: source address. The default is the IPv4 wildcard address.
|
||||
@type source: string
|
||||
@param source_port: The port from which to send the message.
|
||||
The default is 0.
|
||||
@type source_port: int"""
|
||||
|
||||
wire = q.to_wire()
|
||||
if af is None:
|
||||
try:
|
||||
af = dns.inet.af_for_address(where)
|
||||
except:
|
||||
af = dns.inet.AF_INET
|
||||
if af == dns.inet.AF_INET:
|
||||
destination = (where, port)
|
||||
if source is not None:
|
||||
source = (source, source_port)
|
||||
elif af == dns.inet.AF_INET6:
|
||||
destination = (where, port, 0, 0)
|
||||
if source is not None:
|
||||
source = (source, source_port, 0, 0)
|
||||
s = socket.socket(af, socket.SOCK_STREAM)
|
||||
s.settimeout(timeout)
|
||||
try:
|
||||
expiration = dns.query._compute_expiration(timeout)
|
||||
if source is not None:
|
||||
s.bind(source)
|
||||
try:
|
||||
s.connect(destination)
|
||||
except socket.timeout:
|
||||
## Q: Do we also need to catch coro.CoroutineSocketWake and pass?
|
||||
if expiration - time.time() <= 0.0:
|
||||
raise dns.exception.Timeout
|
||||
|
||||
l = len(wire)
|
||||
# copying the wire into tcpmsg is inefficient, but lets us
|
||||
# avoid writev() or doing a short write that would get pushed
|
||||
# onto the net
|
||||
tcpmsg = struct.pack("!H", l) + wire
|
||||
_net_write(s, tcpmsg, expiration)
|
||||
ldata = _net_read(s, 2, expiration)
|
||||
(l,) = struct.unpack("!H", ldata)
|
||||
wire = _net_read(s, l, expiration)
|
||||
finally:
|
||||
s.close()
|
||||
r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac)
|
||||
if not q.is_response(r):
|
||||
raise dns.query.BadResponse()
|
||||
return r
|
||||
|
||||
def reset():
|
||||
resolver.clear()
|
||||
|
||||
# Install our coro-friendly replacements for the tcp and udp query methods.
|
||||
dns.query.tcp = tcp
|
||||
dns.query.udp = udp
|
||||
|
Reference in New Issue
Block a user