Merge "Cleaned up mdns handler and added better test coverage"

This commit is contained in:
Zuul 2019-08-06 12:42:39 +00:00 committed by Gerrit Code Review
commit 92c60b14a5
2 changed files with 334 additions and 132 deletions

View File

@ -15,19 +15,18 @@
# under the License.
import dns
import dns.flags
import dns.message
import dns.opcode
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.message
import six
from oslo_config import cfg
from oslo_log import log as logging
import six
from designate import exceptions
from designate.mdns import xfr
from designate.central import rpcapi as central_api
from designate.mdns import xfr
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
@ -41,9 +40,7 @@ TSIG_RRSIZE = 10 + 64 + 160 + 1
class RequestHandler(xfr.XFRMixin):
def __init__(self, storage, tg):
# Get a storage connection
self.storage = storage
self.tg = tg
@ -63,7 +60,7 @@ class RequestHandler(xfr.XFRMixin):
# TSIG places the pseudo records into the additional section.
if (len(request.question) != 1 or
request.question[0].rdclass != dns.rdataclass.IN):
LOG.debug("Refusing due to numbers of questions or rdclass")
LOG.debug('Refusing due to numbers of questions or rdclass')
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
@ -88,7 +85,7 @@ class RequestHandler(xfr.XFRMixin):
else:
# Unhandled OpCode's include STATUS, IQUERY, UPDATE
LOG.debug("Refusing unhandled opcode")
LOG.debug('Refusing unhandled opcode')
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
@ -135,8 +132,14 @@ class RequestHandler(xfr.XFRMixin):
# We'll reply but don't do anything with the NOTIFY.
master_addr = zone.get_master_by_ip(notify_addr)
if not master_addr:
LOG.warning("NOTIFY for %(name)s from non-master server %(addr)s, "
"refusing.", {"name": zone.name, "addr": notify_addr})
LOG.warning(
'NOTIFY for %(name)s from non-master server %(addr)s, '
'refusing.',
{
'name': zone.name,
'addr': notify_addr
}
)
response.set_rcode(dns.rcode.REFUSED)
yield response
return
@ -147,12 +150,24 @@ class RequestHandler(xfr.XFRMixin):
soa_answer = resolver.query(zone.name, 'SOA')
soa_serial = soa_answer[0].serial
if soa_serial == zone.serial:
LOG.info("Serial %(serial)s is the same for master and us for "
"%(zone_id)s", {"serial": soa_serial, "zone_id": zone.id})
LOG.info(
'Serial %(serial)s is the same for master and us for '
'%(zone_id)s',
{
'serial': soa_serial,
'zone_id': zone.id
}
)
else:
LOG.info("Scheduling AXFR for %(zone_id)s from %(master_addr)s",
{"zone_id": zone.id, "master_addr": master_addr})
LOG.info(
'Scheduling AXFR for %(zone_id)s from %(master_addr)s',
{
'zone_id': zone.id,
'master_addr': master_addr.to_data()
}
)
self.tg.add_thread(self.zone_sync, context, zone,
[master_addr])
@ -161,18 +176,6 @@ class RequestHandler(xfr.XFRMixin):
yield response
return
def _handle_query_error(self, request, rcode):
"""
Construct an error response with the rcode passed in.
:param request: The decoded request from the wire.
:param rcode: The response code to send back.
:return: A dns response message with the response code set to rcode
"""
response = dns.message.make_response(request)
response.set_rcode(rcode)
return response
def _zone_criterion_from_request(self, request, criterion=None):
"""Builds a bare criterion dict based on the request attributes"""
criterion = criterion or {}
@ -181,45 +184,20 @@ class RequestHandler(xfr.XFRMixin):
if tsigkey is None and CONF['service:mdns'].query_enforce_tsig:
raise exceptions.Forbidden('Request is not TSIG signed')
elif tsigkey is None:
# Default to using the default_pool_id when no TSIG key is
# available
criterion['pool_id'] = CONF['service:central'].default_pool_id
else:
if tsigkey.scope == 'POOL':
criterion['pool_id'] = tsigkey.resource_id
elif tsigkey.scope == 'ZONE':
criterion['id'] = tsigkey.resource_id
else:
raise NotImplementedError("Support for %s scoped TSIG Keys is "
"not implemented")
raise NotImplementedError('Support for %s scoped TSIG Keys is '
'not implemented')
return criterion
def _convert_to_rrset(self, zone, recordset):
# Fetch the zone or the config ttl if the recordset ttl is null
ttl = recordset.ttl or zone.ttl
# construct rdata from all the records
# TODO(Ron): this should be handled in the Storage query where we
# find the recordsets.
rdata = [str(record.data) for record in recordset.records
if record.action != 'DELETE']
# Now put the records into dnspython's RRsets
# answer section has 1 RR set. If the RR set has multiple
# records, DNSpython puts each record in a separate answer
# section.
# RRSet has name, ttl, class, type and rdata
# The rdata has one or more records
if rdata:
return dns.rrset.from_text_list(
recordset.name, ttl, dns.rdataclass.IN, recordset.type, rdata)
def _handle_axfr(self, request):
context = request.environ['context']
q_rrset = request.question[0]
@ -234,17 +212,15 @@ class RequestHandler(xfr.XFRMixin):
criterion = self._zone_criterion_from_request(
request, {'name': name})
zone = self.storage.find_zone(context, criterion)
except exceptions.ZoneNotFound:
LOG.warning("ZoneNotFound while handling axfr request. "
"Question was %(qr)s", {'qr': q_rrset})
LOG.warning('ZoneNotFound while handling axfr request. '
'Question was %(qr)s', {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
except exceptions.Forbidden:
LOG.warning("Forbidden while handling axfr request. "
"Question was %(qr)s", {'qr': q_rrset})
LOG.warning('Forbidden while handling axfr request. '
'Question was %(qr)s', {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
@ -261,86 +237,54 @@ class RequestHandler(xfr.XFRMixin):
records.insert(0, soa_records[0])
records.append(soa_records[0])
# Build up a dummy response, we're stealing it's logic for building
# the Flags.
response = dns.message.make_response(request)
response.flags |= dns.flags.AA
response.set_rcode(dns.rcode.NOERROR)
max_message_size = CONF['service:mdns'].max_message_size
if max_message_size > 65535:
LOG.warning('MDNS max message size must not be greater than 65535')
max_message_size = 65535
if request.had_tsig:
# Make some room for the TSIG RR to be appended at the end of the
# rendered message.
max_message_size = max_message_size - TSIG_RRSIZE
# Render the results, yielding a packet after each TooBig exception.
i, renderer = 0, None
while i < len(records):
record = records[i]
renderer = None
while records:
record = records.pop(0)
# No renderer? Build one
if renderer is None:
renderer = dns.renderer.Renderer(
response.id, response.flags, max_message_size)
for q in request.question:
renderer.add_question(q.name, q.rdtype, q.rdclass)
rrname = str(record[3])
ttl = int(record[2]) if record[2] is not None else zone.ttl
rrtype = str(record[1])
rdata = [str(record[4])]
# Build a DNSPython RRSet from the RR
rrset = dns.rrset.from_text_list(
str(record[3]), # name
int(record[2]) if record[2] is not None else zone.ttl, # ttl
dns.rdataclass.IN, # class
str(record[1]), # rrtype
[str(record[4])], # rdata
rrname, ttl, dns.rdataclass.IN, rrtype, rdata,
)
try:
renderer.add_rrset(dns.renderer.ANSWER, rrset)
i += 1
except dns.exception.TooBig:
if renderer.counts[dns.renderer.ANSWER] == 0:
# We've received a TooBig from the first attempted RRSet in
# this packet. Log a warning and abort the AXFR.
LOG.warning('Aborted AXFR of %(zone)s, a single RR '
'(%(rrset_type)s %(rrset_name)s) '
'exceeded the max message size.',
{'zone': zone.name,
'rrset_type': record[1],
'rrset_name': record[3]})
while True:
try:
if not renderer:
renderer = self._create_axfr_renderer(request)
renderer.add_rrset(dns.renderer.ANSWER, rrset)
break
except dns.exception.TooBig:
if renderer.counts[dns.renderer.ANSWER] == 0:
# We've received a TooBig from the first attempted
# RRSet in this packet. Log a warning and abort the
# AXFR.
LOG.warning(
'Aborted AXFR of %(zone)s, a single RR '
'(%(rrset_type)s %(rrset_name)s) '
'exceeded the max message size.',
{
'zone': zone.name,
'rrset_type': rrtype,
'rrset_name': rrname,
}
)
yield self._handle_query_error(request, dns.rcode.SERVFAIL)
return
yield self._handle_query_error(
request, dns.rcode.SERVFAIL
)
return
else:
yield self._finalize_packet(renderer, request)
renderer = None
if renderer is not None:
if renderer:
yield self._finalize_packet(renderer, request)
return
def _finalize_packet(self, renderer, request):
renderer.write_header()
if request.had_tsig:
# Make the space we reserved for TSIG available for use
renderer.max_size += TSIG_RRSIZE
renderer.add_tsig(
request.keyname,
request.keyring[request.keyname],
request.fudge,
request.original_id,
request.tsig_error,
request.other_data,
request.mac,
request.keyalgorithm)
return renderer
def _handle_record_query(self, request):
"""Handle a DNS QUERY request for a record"""
context = request.environ['context']
@ -377,13 +321,13 @@ class RequestHandler(xfr.XFRMixin):
#
# To simply things currently this returns a REFUSED in all cases.
# If zone transfers needs different errors, we could revisit this.
LOG.info("NotFound, refusing. Question was %(qr)s",
LOG.info('NotFound, refusing. Question was %(qr)s',
{'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
except exceptions.Forbidden:
LOG.info("Forbidden, refusing. Question was %(qr)s",
LOG.info('Forbidden, refusing. Question was %(qr)s',
{'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
@ -394,14 +338,14 @@ class RequestHandler(xfr.XFRMixin):
zone = self.storage.find_zone(context, criterion)
except exceptions.ZoneNotFound:
LOG.warning("ZoneNotFound while handling query request. "
"Question was %(qr)s", {'qr': q_rrset})
LOG.warning('ZoneNotFound while handling query request. '
'Question was %(qr)s', {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
except exceptions.Forbidden:
LOG.warning("Forbidden while handling query request. "
"Question was %(qr)s", {'qr': q_rrset})
LOG.warning('Forbidden while handling query request. '
'Question was %(qr)s', {'qr': q_rrset})
yield self._handle_query_error(request, dns.rcode.REFUSED)
return
@ -411,3 +355,83 @@ class RequestHandler(xfr.XFRMixin):
# For all the data stored in designate mdns is Authoritative
response.flags |= dns.flags.AA
yield response
def _create_axfr_renderer(self, request):
# Build up a dummy response, we're stealing it's logic for building
# the Flags.
response = dns.message.make_response(request)
response.flags |= dns.flags.AA
response.set_rcode(dns.rcode.NOERROR)
max_message_size = self._get_max_message_size(request.had_tsig)
renderer = dns.renderer.Renderer(
response.id, response.flags, max_message_size)
for q in request.question:
renderer.add_question(q.name, q.rdtype, q.rdclass)
return renderer
@staticmethod
def _convert_to_rrset(zone, recordset):
# Fetch the zone or the config ttl if the recordset ttl is null
ttl = recordset.ttl or zone.ttl
# construct rdata from all the records
# TODO(Ron): this should be handled in the Storage query where we
# find the recordsets.
rdata = [str(record.data) for record in recordset.records
if record.action != 'DELETE']
# Now put the records into dnspython's RRsets
# answer section has 1 RR set. If the RR set has multiple
# records, DNSpython puts each record in a separate answer
# section.
# RRSet has name, ttl, class, type and rdata
# The rdata has one or more records
if not rdata:
return None
return dns.rrset.from_text_list(
recordset.name, ttl, dns.rdataclass.IN, recordset.type, rdata)
@staticmethod
def _finalize_packet(renderer, request):
renderer.write_header()
if request.had_tsig:
# Make the space we reserved for TSIG available for use
renderer.max_size += TSIG_RRSIZE
renderer.add_tsig(
request.keyname,
request.keyring[request.keyname],
request.fudge,
request.original_id,
request.tsig_error,
request.other_data,
request.mac,
request.keyalgorithm
)
return renderer
@staticmethod
def _get_max_message_size(had_tsig):
max_message_size = CONF['service:mdns'].max_message_size
if max_message_size > 65535:
LOG.warning('MDNS max message size must not be greater than 65535')
max_message_size = 65535
if had_tsig:
# Make some room for the TSIG RR to be appended at the end of the
# rendered message.
max_message_size = max_message_size - TSIG_RRSIZE
return max_message_size
@staticmethod
def _handle_query_error(request, rcode):
"""
Construct an error response with the rcode passed in.
:param request: The decoded request from the wire.
:param rcode: The response code to send back.
:return: A dns response message with the response code set to rcode
"""
response = dns.message.make_response(request)
response.set_rcode(rcode)
return response

View File

@ -15,11 +15,190 @@
# under the License.
import dns
import mock
from oslo_config import cfg
from oslo_config import fixture as cfg_fixture
import oslotest.base
from designate import exceptions
from designate import objects
from designate.mdns import handler
from designate.tests import fixtures
CONF = cfg.CONF
class MdnsHandleTest(oslotest.base.BaseTestCase):
def setUp(self):
super(MdnsHandleTest, self).setUp()
self.stdlog = fixtures.StandardLogging()
self.useFixture(self.stdlog)
self.useFixture(cfg_fixture.Config(CONF))
self.context = mock.Mock()
self.storage = mock.Mock()
self.tg = mock.Mock()
self.handler = handler.RequestHandler(self.storage, self.tg)
@mock.patch.object(dns.resolver.Resolver, 'query')
def test_notify(self, mock_query):
self.storage.find_zone.return_value = objects.Zone(
id='e2bed4dc-9d01-11e4-89d3-123b93f75cba',
serial=2,
masters=objects.ZoneMasterList.from_list([
{'host': '1.0.0.0', 'port': 53},
])
)
mock_query.return_value = [
mock.Mock(serial=1)
]
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.SOA
)
request.environ = dict(addr=['1.0.0.0'], context=self.context)
response = self.handler._handle_notify(request)
self.assertEqual(dns.rcode.NOERROR, tuple(response)[0].rcode())
self.assertIn(
'Scheduling AXFR for e2bed4dc-9d01-11e4-89d3-123b93f75cba '
'from 1.0.0.0:53',
self.stdlog.logger.output
)
@mock.patch.object(dns.resolver.Resolver, 'query')
def test_notify_same_serial(self, mock_query):
self.storage.find_zone.return_value = objects.Zone(
id='e2bed4dc-9d01-11e4-89d3-123b93f75cba',
serial=1,
masters=objects.ZoneMasterList.from_list([
{'host': '1.0.0.0', 'port': 53},
])
)
mock_query.return_value = [
mock.Mock(serial=1)
]
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.SOA
)
request.environ = dict(addr=['1.0.0.0'], context=self.context)
response = self.handler._handle_notify(request)
self.assertEqual(dns.rcode.NOERROR, tuple(response)[0].rcode())
self.assertIn(
'Serial 1 is the same for master and us for '
'e2bed4dc-9d01-11e4-89d3-123b93f75cba',
self.stdlog.logger.output
)
def test_notify_no_questions(self):
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.SOA
)
request.environ = dict(context=self.context)
request.question = []
response = self.handler._handle_notify(request)
self.assertEqual(dns.rcode.FORMERR, tuple(response)[0].rcode())
def test_notify_zone_not_found(self):
self.storage.find_zone.side_effect = exceptions.ZoneNotFound
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.SOA
)
request.environ = dict(context=self.context)
response = self.handler._handle_notify(request)
self.assertEqual(dns.rcode.NOTAUTH, tuple(response)[0].rcode())
def test_notify_no_master_addr(self):
self.storage.find_zone.return_value = objects.Zone(
masters=objects.ZoneMasterList.from_list([
{'host': '1.0.0.0', 'port': 53},
])
)
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.SOA
)
request.environ = dict(addr=['127.0.0.1', 53], context=self.context)
response = self.handler._handle_notify(request)
self.assertEqual(dns.rcode.REFUSED, tuple(response)[0].rcode())
self.assertIn(
'NOTIFY for None from non-master server 127.0.0.1, refusing.',
self.stdlog.logger.output
)
def test_axfr_zone_not_found(self):
self.storage.find_zone.side_effect = exceptions.ZoneNotFound
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.AXFR
)
request.environ = dict(context=self.context)
response = tuple(self.handler._handle_axfr(request))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
self.assertIn(
'ZoneNotFound while handling axfr request. '
'Question was www.example.org. IN AXFR',
self.stdlog.logger.output
)
def test_axfr_forbidden(self):
self.storage.find_zone.side_effect = exceptions.Forbidden
request = dns.message.make_query(
'www.example.org.', dns.rdatatype.AXFR
)
request.environ = dict(context=self.context)
response = tuple(self.handler._handle_axfr(request))
self.assertEqual(dns.rcode.REFUSED, response[0].rcode())
self.assertIn(
'Forbidden while handling axfr request. '
'Question was www.example.org. IN AXFR',
self.stdlog.logger.output
)
def test_get_max_message_size(self):
CONF.set_override('max_message_size', 32768, 'service:mdns')
self.assertEqual(
32768, self.handler._get_max_message_size(had_tsig=False)
)
def test_get_max_message_size_larger_than_allowed(self):
CONF.set_override('max_message_size', 65535 * 2, 'service:mdns')
self.assertEqual(
65535, self.handler._get_max_message_size(had_tsig=False)
)
self.assertIn(
'MDNS max message size must not be greater than 65535',
self.stdlog.logger.output
)
def test_get_max_message_with_tsig(self):
CONF.set_override('max_message_size', 65535, 'service:mdns')
self.assertEqual(
65300, self.handler._get_max_message_size(had_tsig=True)
)
class TestRequestHandlerCall(oslotest.base.BaseTestCase):
@ -138,7 +317,6 @@ class TestRequestHandlerCall(oslotest.base.BaseTestCase):
class HandleRecordQueryTest(oslotest.base.BaseTestCase):
def setUp(self):
super(HandleRecordQueryTest, self).setUp()
self.context = mock.Mock()