diff --git a/eventlet/support/greendns.py b/eventlet/support/greendns.py index 8d5c7d6..ea0924c 100644 --- a/eventlet/support/greendns.py +++ b/eventlet/support/greendns.py @@ -306,6 +306,12 @@ class ResolverProxy(object): tcp=False, source=None, raise_on_no_answer=True, _hosts_rdtypes=(dns.rdatatype.A, dns.rdatatype.AAAA)): """Query the resolver, using /etc/hosts if enabled. + + Behavior: + 1. if hosts is enabled and contains answer, return it now + 2. query nameservers for qname + 3. if qname did not contain dots, pretend it was top-level domain, + query "foobar." and append to previous result """ result = [None, None, 0] @@ -328,6 +334,21 @@ class ResolverProxy(object): result[2] += len(a.rrset) return True + def end(): + if result[0] is not None: + if raise_on_no_answer and result[2] == 0: + raise dns.resolver.NoAnswer + return result[0] + if result[1] is not None: + if raise_on_no_answer or not isinstance(result[1], dns.resolver.NoAnswer): + raise result[1] + raise dns.resolver.NXDOMAIN(qnames=(qname,)) + + if (self._hosts and (rdclass == dns.rdataclass.IN) and (rdtype in _hosts_rdtypes)): + if step(self._hosts.query, qname, rdtype, raise_on_no_answer=False): + if (result[0] is not None) or (result[1] is not None): + return end() + # Main query step(self._resolver.query, qname, rdtype, rdclass, tcp, source, raise_on_no_answer=False) @@ -341,20 +362,7 @@ class ResolverProxy(object): step(self._resolver.query, qname.concatenate(dns.name.root), rdtype, rdclass, tcp, source, raise_on_no_answer=False) - # Return answers from /etc/hosts despite nameserver errors - # https://github.com/eventlet/eventlet/pull/354 - if ((result[2] == 0) and self._hosts and - (rdclass == dns.rdataclass.IN) and (rdtype in _hosts_rdtypes)): - step(self._hosts.query, qname, rdtype, raise_on_no_answer=False) - - if result[0] is not None: - if raise_on_no_answer and result[2] == 0: - raise dns.resolver.NoAnswer - return result[0] - if result[1] is not None: - if raise_on_no_answer or not isinstance(result[1], dns.resolver.NoAnswer): - raise result[1] - raise dns.resolver.NXDOMAIN(qnames=(qname,)) + return end() def getaliases(self, hostname): """Return a list of all the aliases of a given hostname""" @@ -376,8 +384,8 @@ class ResolverProxy(object): resolver = ResolverProxy(hosts_resolver=HostsResolver()) -def resolve(name, family=socket.AF_INET, raises=True): - """Resolve a name for a given family using the global resolver proxy +def resolve(name, family=socket.AF_INET, raises=True, _proxy=None): + """Resolve a name for a given family using the global resolver proxy. This method is called by the global getaddrinfo() function. @@ -391,9 +399,12 @@ def resolve(name, family=socket.AF_INET, raises=True): else: raise socket.gaierror(socket.EAI_FAMILY, 'Address family not supported') + + if _proxy is None: + _proxy = resolver try: try: - return resolver.query(name, rdtype, raise_on_no_answer=raises) + return _proxy.query(name, rdtype, raise_on_no_answer=raises) except dns.resolver.NXDOMAIN: if not raises: return HostsAnswer(dns.name.Name(name), diff --git a/tests/greendns_test.py b/tests/greendns_test.py index d16ee43..efdd628 100644 --- a/tests/greendns_test.py +++ b/tests/greendns_test.py @@ -12,26 +12,27 @@ import tests import tests.mock +def _make_host_resolver(): + """Returns a HostResolver instance + + The hosts file will be empty but accessible as a py.path.local + instance using the ``hosts`` attribute. + """ + hosts = tempfile.NamedTemporaryFile() + hr = greendns.HostsResolver(fname=hosts.name) + hr.hosts = hosts + hr._last_stat = 0 + return hr + + class TestHostsResolver(tests.LimitedTestCase): - def _make_host_resolver(self): - """Returns a HostResolver instance - - The hosts file will be empty but accessible as a py.path.local - instance using the ``hosts`` attribute. - """ - hosts = tempfile.NamedTemporaryFile() - hr = greendns.HostsResolver(fname=hosts.name) - hr.hosts = hosts - hr._last_stat = 0 - return hr - def test_default_fname(self): hr = greendns.HostsResolver() assert os.path.exists(hr.fname) def test_readlines_lines(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr.hosts.write(b'line0\n') hr.hosts.flush() assert hr._readlines() == ['line0'] @@ -44,20 +45,20 @@ class TestHostsResolver(tests.LimitedTestCase): assert hr._readlines() == ['line0', 'line1'] def test_readlines_missing_file(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr.hosts.close() hr._last_stat = 0 assert hr._readlines() == [] def test_load_no_contents(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._load() assert not hr._v4 assert not hr._v6 assert not hr._aliases def test_load_v4_v6_cname_aliases(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr.hosts.write(b'1.2.3.4 v4.example.com v4\n' b'dead:beef::1 v6.example.com v6\n') hr.hosts.flush() @@ -69,7 +70,7 @@ class TestHostsResolver(tests.LimitedTestCase): 'v6': 'v6.example.com'} def test_load_v6_link_local(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr.hosts.write(b'fe80:: foo\n' b'fe80:dead:beef::1 bar\n') hr.hosts.flush() @@ -78,14 +79,14 @@ class TestHostsResolver(tests.LimitedTestCase): assert not hr._v6 def test_query_A(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._v4 = {'v4.example.com': '1.2.3.4'} ans = hr.query('v4.example.com') assert ans[0].address == '1.2.3.4' def test_query_ans_types(self): # This assumes test_query_A above succeeds - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._v4 = {'v4.example.com': '1.2.3.4'} hr._last_stat = time.time() ans = hr.query('v4.example.com') @@ -108,18 +109,18 @@ class TestHostsResolver(tests.LimitedTestCase): assert rr.address == '1.2.3.4' def test_query_AAAA(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._v6 = {'v6.example.com': 'dead:beef::1'} ans = hr.query('v6.example.com', dns.rdatatype.AAAA) assert ans[0].address == 'dead:beef::1' def test_query_unknown_raises(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() with tests.assert_raises(greendns.dns.resolver.NoAnswer): hr.query('example.com') def test_query_unknown_no_raise(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() ans = hr.query('example.com', raise_on_no_answer=False) assert isinstance(ans, greendns.dns.resolver.Answer) assert ans.response is None @@ -134,30 +135,30 @@ class TestHostsResolver(tests.LimitedTestCase): assert len(ans.rrset) == 0 def test_query_CNAME(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._aliases = {'host': 'host.example.com'} ans = hr.query('host', dns.rdatatype.CNAME) assert ans[0].target == dns.name.from_text('host.example.com') assert str(ans[0].target) == 'host.example.com.' def test_query_unknown_type(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() with tests.assert_raises(greendns.dns.resolver.NoAnswer): hr.query('example.com', dns.rdatatype.MX) def test_getaliases(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._aliases = {'host': 'host.example.com', 'localhost': 'host.example.com'} res = set(hr.getaliases('host')) assert res == set(['host.example.com', 'localhost']) def test_getaliases_unknown(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() assert hr.getaliases('host.example.com') == [] def test_getaliases_fqdn(self): - hr = self._make_host_resolver() + hr = _make_host_resolver() hr._aliases = {'host': 'host.example.com'} res = set(hr.getaliases('host.example.com')) assert res == set(['host']) @@ -791,3 +792,29 @@ def test_proxy_resolve_unqualified(): pass assert any(call[0][0] == dns.name.from_text('machine') for call in m.call_args_list) assert any(call[0][0] == dns.name.from_text('machine.') for call in m.call_args_list) + + +def test_hosts_priority(): + name = 'example.com' + addr_from_ns = '1.0.2.0' + + hr = _make_host_resolver() + rp = greendns.ResolverProxy(hosts_resolver=hr, filename=None) + base = _make_mock_base_resolver() + base.rr.address = addr_from_ns + rp._resolver = base() + + # Default behavior + rrns = greendns.resolve(name, _proxy=rp).rrset[0] + assert rrns.address == addr_from_ns + + # Hosts result must shadow that from nameservers + hr.hosts.write(b'1.2.3.4 example.com\ndead:beef::1 example.com\n') + hr.hosts.flush() + hr._load() + rrs4 = greendns.resolve(name, family=socket.AF_INET, _proxy=rp).rrset + assert len(rrs4) == 1 + assert rrs4[0].address == '1.2.3.4', rrs4[0].address + rrs6 = greendns.resolve(name, family=socket.AF_INET6, _proxy=rp).rrset + assert len(rrs6) == 1 + assert rrs6[0].address == 'dead:beef::1', rrs6[0].address