Cleanup DNS Middleware

- Moved DNS Middleware to separate file.
- Cleaned up logic.
- Fixed non-lazy loaded logging.

Change-Id: Ic1d7c17fc2d7a563ff22f0b1dd42a58f87665cf1
This commit is contained in:
Erik Olof Gunnar Andersson
2023-07-01 22:38:24 +02:00
parent 8cc87bce03
commit fa95018c91
10 changed files with 378 additions and 233 deletions
+3 -3
View File
@@ -30,7 +30,7 @@ from oslo_config import cfg
from designate.agent import handler
from designate.backend import agent_backend
from designate.conf.agent import DEFAULT_AGENT_PORT
from designate import dnsutils
from designate import dnsmiddleware
from designate import service
from designate import utils
@@ -80,7 +80,7 @@ class Service(service.Service):
# Create an instance of the RequestHandler class
application = handler.RequestHandler()
if cfg.CONF['service:agent'].notify_delay > 0.0:
application = dnsutils.LimitNotifyMiddleware(application)
application = dnsutils.SerializationMiddleware(application)
application = dnsmiddleware.LimitNotifyMiddleware(application)
application = dnsmiddleware.SerializationMiddleware(application)
return application
+214
View File
@@ -0,0 +1,214 @@
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Author: Endre Karlson <endre.karlson@hpe.com>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import time
import dns.exception
import dns.message
import dns.opcode
import dns.rcode
import dns.rdatatype
import dns.renderer
import dns.tsig
from oslo_log import log as logging
import designate.conf
from designate import context
from designate import dnsutils
from designate import exceptions
CONF = designate.conf.CONF
LOG = logging.getLogger(__name__)
class DNSMiddleware(object):
"""Base DNS Middleware class with some utility methods"""
def __init__(self, application):
self.application = application
def process_request(self, request):
"""Called on each request.
If this returns None, the next application down the stack will be
executed. If it returns a response then that response will be returned
and execution will stop here.
"""
return None
def process_response(self, response):
"""Do whatever you'd like to the response."""
return response
def __call__(self, request):
response = self.process_request(request)
if response:
return response
response = self.application(request)
return self.process_response(response)
def _build_error_response(self):
response = dns.message.make_response(
dns.message.make_query('unknown', dns.rdatatype.A))
response.set_rcode(dns.rcode.FORMERR)
return response
class SerializationMiddleware(DNSMiddleware):
"""DNS Middleware to serialize/deserialize DNS Packets"""
def __init__(self, application, tsig_keyring=None):
super(SerializationMiddleware, self).__init__(application)
self.tsig_keyring = tsig_keyring
def __call__(self, request):
# Generate the initial context. This may be updated by other middleware
# as we learn more information about the Request.
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
message = None
try:
message = dns.message.from_wire(request['payload'],
self.tsig_keyring)
if message.had_tsig:
LOG.debug('Request signed with TSIG key: %s', message.keyname)
# Create + Attach the initial "environ" dict. This is similar to
# the environ dict used in typical WSGI middleware.
message.environ = {
'context': ctxt,
'addr': request['addr'],
}
except dns.message.UnknownTSIGKey:
LOG.error(
'Unknown TSIG key from %(host)s:%(port)d',
{
'host': request['addr'][0],
'port': request['addr'][1]
}
)
except dns.tsig.BadSignature:
LOG.error(
'Invalid TSIG signature from %(host)s:%(port)d',
{
'host': request['addr'][0],
'port': request['addr'][1]
}
)
except dns.exception.DNSException:
LOG.error(
'Failed to deserialize packet from %(host)s:%(port)d',
{
'host': request['addr'][0],
'port': request['addr'][1]
}
)
except Exception:
LOG.exception(
'Unknown exception deserializing packet '
'from %(host)s %(port)d',
{
'host': request['addr'][0],
'port': request['addr'][1]
}
)
if message is None:
# NOTE(eandersson): Unsure on the intent of the error handling
# in this code. Cleaning the code path up, but
# leaving functionality as it was.
# error_response = self._build_error_response()
# yield error_response.to_wire()
return
# Hand the Deserialized packet onto the Application
for response in self.application(message):
# Serialize and return the response if present
if isinstance(response, dns.message.Message):
yield response.to_wire(max_size=65535)
elif isinstance(response, dns.renderer.Renderer):
yield response.get_wire()
else:
LOG.error('Unexpected response %r', response)
class TsigInfoMiddleware(DNSMiddleware):
"""Middleware which looks up the information available for a TsigKey"""
def __init__(self, application, storage):
super(TsigInfoMiddleware, self).__init__(application)
self.storage = storage
def process_request(self, request):
if not request.had_tsig:
return None
try:
name = request.keyname.to_text(True)
if isinstance(name, bytes):
name = name.decode('utf-8')
criterion = {'name': name}
tsigkey = self.storage.find_tsigkey(
context.get_current(), criterion
)
request.environ['tsigkey'] = tsigkey
request.environ['context'].tsigkey_id = tsigkey.id
except exceptions.TsigKeyNotFound:
# This should never happen, as we just validated the key.. Except
# for race conditions..
return self._build_error_response()
return None
class LimitNotifyMiddleware(DNSMiddleware):
"""Middleware that rate limits NOTIFYs to the Agent"""
def __init__(self, application):
super(LimitNotifyMiddleware, self).__init__(application)
self.delay = CONF['service:agent'].notify_delay
self.locker = dnsutils.ZoneLock(self.delay)
def process_request(self, request):
opcode = request.opcode()
if opcode != dns.opcode.NOTIFY:
return None
zone_name = request.question[0].name.to_text()
if isinstance(zone_name, bytes):
zone_name = zone_name.decode('utf-8')
if self.locker.acquire(zone_name):
time.sleep(self.delay)
self.locker.release(zone_name)
return None
else:
LOG.debug(
'Threw away NOTIFY for %(zone)s, already '
'working on an update.',
{
'zone': zone_name
}
)
response = dns.message.make_response(request)
# Provide an authoritative answer
response.flags |= dns.flags.AA
return (response,)
+13 -169
View File
@@ -15,11 +15,12 @@
# under the License.
import random
import socket
from threading import Lock
import threading
import time
import dns
import dns.exception
import dns.message
import dns.opcode
import dns.query
import dns.rdatatype
import dns.zone
@@ -36,137 +37,6 @@ CONF = designate.conf.CONF
LOG = logging.getLogger(__name__)
class DNSMiddleware(object):
"""Base DNS Middleware class with some utility methods"""
def __init__(self, application):
self.application = application
def process_request(self, request):
"""Called on each request.
If this returns None, the next application down the stack will be
executed. If it returns a response then that response will be returned
and execution will stop here.
"""
return None
def process_response(self, response):
"""Do whatever you'd like to the response."""
return response
def __call__(self, request):
response = self.process_request(request)
if response:
return response
response = self.application(request)
return self.process_response(response)
def _build_error_response(self):
response = dns.message.make_response(
dns.message.make_query('unknown', dns.rdatatype.A))
response.set_rcode(dns.rcode.FORMERR)
return response
class SerializationMiddleware(DNSMiddleware):
"""DNS Middleware to serialize/deserialize DNS Packets"""
def __init__(self, application, tsig_keyring=None):
self.application = application
self.tsig_keyring = tsig_keyring
def __call__(self, request):
# Generate the initial context. This may be updated by other middleware
# as we learn more information about the Request.
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
try:
message = dns.message.from_wire(request['payload'],
self.tsig_keyring)
if message.had_tsig:
LOG.debug('Request signed with TSIG key: %s', message.keyname)
# Create + Attach the initial "environ" dict. This is similar to
# the environ dict used in typical WSGI middleware.
message.environ = {
'context': ctxt,
'addr': request['addr'],
}
except dns.message.UnknownTSIGKey:
LOG.error("Unknown TSIG key from %(host)s:%(port)d",
{'host': request['addr'][0], 'port': request['addr'][1]})
response = self._build_error_response()
except dns.tsig.BadSignature:
LOG.error("Invalid TSIG signature from %(host)s:%(port)d",
{'host': request['addr'][0], 'port': request['addr'][1]})
response = self._build_error_response()
except dns.exception.DNSException:
LOG.error("Failed to deserialize packet from %(host)s:%(port)d",
{'host': request['addr'][0], 'port': request['addr'][1]})
response = self._build_error_response()
except Exception:
LOG.exception("Unknown exception deserializing packet "
"from %(host)s %(port)d",
{'host': request['addr'][0],
'port': request['addr'][1]})
response = self._build_error_response()
else:
# Hand the Deserialized packet onto the Application
for response in self.application(message):
# Serialize and return the response if present
if isinstance(response, dns.message.Message):
yield response.to_wire(max_size=65535)
elif isinstance(response, dns.renderer.Renderer):
yield response.get_wire()
else:
LOG.error("Unexpected response %r", response)
class TsigInfoMiddleware(DNSMiddleware):
"""Middleware which looks up the information available for a TsigKey"""
def __init__(self, application, storage):
super(TsigInfoMiddleware, self).__init__(application)
self.storage = storage
def process_request(self, request):
if not request.had_tsig:
return None
try:
name = request.keyname.to_text(True)
if isinstance(name, bytes):
name = name.decode('utf-8')
criterion = {'name': name}
tsigkey = self.storage.find_tsigkey(
context.get_current(), criterion)
request.environ['tsigkey'] = tsigkey
request.environ['context'].tsigkey_id = tsigkey.id
except exceptions.TsigKeyNotFound:
# This should never happen, as we just validated the key.. Except
# for race conditions..
return self._build_error_response()
return None
class TsigKeyring(dict):
"""Implements the DNSPython KeyRing API, backed by the Designate DB"""
@@ -184,7 +54,8 @@ class TsigKeyring(dict):
name = name.decode('utf-8')
criterion = {'name': name}
tsigkey = self.storage.find_tsigkey(
context.get_current(), criterion)
context.get_current(), criterion
)
return base64.decode_as_bytes(tsigkey.secret)
@@ -196,7 +67,7 @@ class ZoneLock(object):
"""A Lock across all zones that enforces a rate limit on NOTIFYs"""
def __init__(self, delay):
self.lock = Lock()
self.lock = threading.Lock()
self.data = {}
self.delay = delay
@@ -219,9 +90,13 @@ class ZoneLock(object):
self.data[zone] = now
return True
LOG.debug('Lock for %(zone)s can\'t be released for %(period)s'
'seconds' % {'zone': zone,
'period': str(self.delay - period)})
LOG.debug(
'Lock for %(zone)s can\'t be released for %(period)s seconds',
{
'zone': zone,
'period': str(self.delay - period)
}
)
# Don't grant the lock for the zone
return False
@@ -235,37 +110,6 @@ class ZoneLock(object):
pass
class LimitNotifyMiddleware(DNSMiddleware):
"""Middleware that rate limits NOTIFYs to the Agent"""
def __init__(self, application):
super(LimitNotifyMiddleware, self).__init__(application)
self.delay = CONF['service:agent'].notify_delay
self.locker = ZoneLock(self.delay)
def process_request(self, request):
opcode = request.opcode()
if opcode != dns.opcode.NOTIFY:
return None
zone_name = request.question[0].name.to_text()
if isinstance(zone_name, bytes):
zone_name = zone_name.decode('utf-8')
if self.locker.acquire(zone_name):
time.sleep(self.delay)
self.locker.release(zone_name)
return None
else:
LOG.debug('Threw away NOTIFY for %(zone)s, already '
'working on an update.' % {'zone': zone_name})
response = dns.message.make_response(request)
# Provide an authoritative answer
response.flags |= dns.flags.AA
return (response,)
def from_dnspython_zone(dnspython_zone):
# dnspython never builds a zone with more than one SOA, even if we give
# it a zonefile that contains more than one
+5 -2
View File
@@ -17,6 +17,7 @@ from oslo_config import cfg
from oslo_log import log as logging
from designate.conf.mdns import DEFAULT_MDNS_PORT
from designate import dnsmiddleware
from designate import dnsutils
from designate.mdns import handler
from designate import service
@@ -67,8 +68,10 @@ class Service(service.Service):
# Create an instance of the RequestHandler class and wrap with
# necessary middleware.
application = handler.RequestHandler(self.storage, self.tg)
application = dnsutils.TsigInfoMiddleware(application, self.storage)
application = dnsutils.SerializationMiddleware(
application = dnsmiddleware.TsigInfoMiddleware(
application, self.storage
)
application = dnsmiddleware.SerializationMiddleware(
application, dnsutils.TsigKeyring(self.storage)
)
+67
View File
@@ -0,0 +1,67 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
import dns
import dns.query
import dns.tsigkeyring
from oslo_config import cfg
from designate import dnsmiddleware
from designate import dnsutils
from designate.mdns import handler
from designate import storage
import designate.tests
CONF = cfg.CONF
class TestSerializationMiddleware(designate.tests.TestCase):
def setUp(self):
super(TestSerializationMiddleware, self).setUp()
self.storage = storage.get_storage()
self.tg = mock.Mock()
def test_with_tsigkeyring(self):
self.create_tsigkey(fixture=1)
query = dns.message.make_query(
'example.com.', dns.rdatatype.SOA,
)
query.use_tsig(dns.tsigkeyring.from_text(
{'test-key-two': 'AnotherSecretKey'})
)
payload = query.to_wire()
application = handler.RequestHandler(self.storage, self.tg)
application = dnsmiddleware.SerializationMiddleware(
application, dnsutils.TsigKeyring(self.storage)
)
self.assertTrue(next(application(
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
)))
def test_without_tsigkeyring(self):
query = dns.message.make_query(
'example.com.', dns.rdatatype.SOA,
)
payload = query.to_wire()
application = handler.RequestHandler(self.storage, self.tg)
application = dnsmiddleware.SerializationMiddleware(
application, dnsutils.TsigKeyring(self.storage)
)
self.assertTrue(next(application(
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
)))
+3 -45
View File
@@ -20,9 +20,9 @@ import dns.query
import dns.tsigkeyring
from oslo_config import cfg
from designate import dnsmiddleware
from designate import dnsutils
from designate import exceptions
from designate.mdns import handler
from designate import objects
from designate import storage
import designate.tests
@@ -84,48 +84,6 @@ SAMPLES = {
}
class TestSerializationMiddleware(designate.tests.TestCase):
def setUp(self):
super(TestSerializationMiddleware, self).setUp()
self.storage = storage.get_storage()
self.tg = mock.Mock()
def test_with_tsigkeyring(self):
self.create_tsigkey(fixture=1)
query = dns.message.make_query(
'example.com.', dns.rdatatype.SOA,
)
query.use_tsig(dns.tsigkeyring.from_text(
{'test-key-two': 'AnotherSecretKey'})
)
payload = query.to_wire()
application = handler.RequestHandler(self.storage, self.tg)
application = dnsutils.SerializationMiddleware(
application, dnsutils.TsigKeyring(self.storage)
)
self.assertTrue(next(application(
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
)))
def test_without_tsigkeyring(self):
query = dns.message.make_query(
'example.com.', dns.rdatatype.SOA,
)
payload = query.to_wire()
application = handler.RequestHandler(self.storage, self.tg)
application = dnsutils.SerializationMiddleware(
application, dnsutils.TsigKeyring(self.storage)
)
self.assertTrue(next(application(
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
)))
class TestTsigUtils(designate.tests.TestCase):
def setUp(self):
super(TestTsigUtils, self).setUp()
@@ -242,7 +200,7 @@ class TestUtils(designate.tests.TestCase):
# Initialize the middlware
placeholder_app = None
middleware = dnsutils.LimitNotifyMiddleware(placeholder_app)
middleware = dnsmiddleware.LimitNotifyMiddleware(placeholder_app)
# Prepare a NOTIFY
zone_name = 'example.com.'
@@ -261,7 +219,7 @@ class TestUtils(designate.tests.TestCase):
# Initialize the middlware
placeholder_app = None
middleware = dnsutils.LimitNotifyMiddleware(placeholder_app)
middleware = dnsmiddleware.LimitNotifyMiddleware(placeholder_app)
# Prepare a NOTIFY
zone_name = 'example.com.'
+8 -8
View File
@@ -74,12 +74,12 @@ class MdnsServiceTest(designate.tests.TestCase):
# NOTE: Start is already done by the fixture in start_service()
self.service.stop()
@mock.patch.object(dns.message, 'make_query')
def test_handle_empty_payload(self, query_mock):
@mock.patch.object(dns.message, 'from_wire')
def test_handle_empty_payload(self, mock_from_wire):
mock_socket = mock.Mock()
self.dns_service._dns_handle_udp_query(mock_socket, self.addr,
' '.encode('utf-8'))
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
mock_from_wire.assert_called_once_with(b' ', {})
def test_handle_udp_payload(self):
mock_socket = mock.Mock()
@@ -88,7 +88,7 @@ class MdnsServiceTest(designate.tests.TestCase):
mock_socket.sendto.assert_called_once_with(self.expected_response,
self.addr)
def test__dns_handle_tcp_conn_fail_unpack(self):
def test_dns_handle_tcp_conn_fail_unpack(self):
# will call recv() only once
mock_socket = mock.Mock()
mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack
@@ -97,7 +97,7 @@ class MdnsServiceTest(designate.tests.TestCase):
self.assertEqual(1, mock_socket.recv.call_count)
self.assertEqual(1, mock_socket.close.call_count)
def test__dns_handle_tcp_conn_one_query(self):
def test_dns_handle_tcp_conn_one_query(self):
payload = self.query_payload
mock_socket = mock.Mock()
pay_len = struct.pack("!H", len(payload))
@@ -114,7 +114,7 @@ class MdnsServiceTest(designate.tests.TestCase):
self.assertEqual(len(wire), expected_length + 2)
self.assertEqual(self.expected_response, wire[2:])
def test__dns_handle_tcp_conn_multiple_queries(self):
def test_dns_handle_tcp_conn_multiple_queries(self):
payload = self.query_payload
mock_socket = mock.Mock()
pay_len = struct.pack("!H", len(payload))
@@ -136,7 +136,7 @@ class MdnsServiceTest(designate.tests.TestCase):
self.assertEqual(5, mock_socket.sendall.call_count)
self.assertEqual(1, mock_socket.close.call_count)
def test__dns_handle_tcp_conn_multiple_queries_socket_error(self):
def test_dns_handle_tcp_conn_multiple_queries_socket_error(self):
payload = self.query_payload
mock_socket = mock.Mock()
pay_len = struct.pack("!H", len(payload))
@@ -158,7 +158,7 @@ class MdnsServiceTest(designate.tests.TestCase):
self.assertEqual(5, mock_socket.sendall.call_count)
self.assertEqual(1, mock_socket.close.call_count)
def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
def test_dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
payload = self.query_payload
mock_socket = mock.Mock()
pay_len = struct.pack("!H", len(payload))
+5 -3
View File
@@ -18,7 +18,7 @@ from unittest import mock
from designate.agent import service
from designate.backend import agent_backend
from designate.backend.agent_backend import impl_fake
from designate import dnsutils
from designate import dnsmiddleware
import designate.tests
from designate.tests import fixtures
from designate import utils
@@ -62,7 +62,8 @@ class AgentServiceTest(designate.tests.TestCase):
@mock.patch.object(utils, 'cache_result')
def test_get_dns_application(self, mock_cache_result):
self.assertIsInstance(
self.service.dns_application, dnsutils.SerializationMiddleware
self.service.dns_application,
dnsmiddleware.SerializationMiddleware
)
@mock.patch.object(utils, 'cache_result')
@@ -72,5 +73,6 @@ class AgentServiceTest(designate.tests.TestCase):
self.CONF.set_override('notify_delay', 1.0, 'service:agent')
self.assertIsInstance(
self.service.dns_application, dnsutils.SerializationMiddleware
self.service.dns_application,
dnsmiddleware.SerializationMiddleware
)
+2 -3
View File
@@ -19,10 +19,9 @@ from oslo_config import cfg
from oslo_config import fixture as cfg_fixture
import oslotest.base
import designate.dnsutils
from designate import dnsmiddleware
from designate.mdns import handler
from designate.mdns import service
import designate.rpc
import designate.service
from designate import storage
from designate.tests import fixtures
@@ -81,4 +80,4 @@ class MdnsServiceTest(oslotest.base.BaseTestCase):
app = self.service.dns_application
self.assertIsInstance(app, designate.dnsutils.DNSMiddleware)
self.assertIsInstance(app, dnsmiddleware.DNSMiddleware)
@@ -0,0 +1,58 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
import dns.message
from oslo_config import cfg
from designate import dnsmiddleware
import oslotest.base
CONF = cfg.CONF
class TestDNSMiddleware(oslotest.base.BaseTestCase):
def setUp(self):
super(TestDNSMiddleware, self).setUp()
self.application = mock.Mock(name='application')
self.dns_application = dnsmiddleware.DNSMiddleware(self.application)
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_request')
def test_call(self, mock_process_request):
request = mock.Mock()
self.dns_application(request)
mock_process_request.assert_called_with(request)
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_response')
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_request')
def test_call_with_none(self, mock_process_request, mock_process_response):
mock_process_request.return_value = None
self.dns_application(None)
mock_process_request.assert_called_with(None)
mock_process_response.assert_called_with(self.application())
def test_process_request(self):
self.assertIsNone(self.dns_application.process_request(mock.Mock()))
def test_process_response(self):
response = mock.Mock()
self.assertEqual(
response, self.dns_application.process_response(response)
)
def test_build_error_response(self):
self.assertIsInstance(
self.dns_application._build_error_response(), dns.message.Message
)