Cleanup DNS Middleware
- Moved DNS Middleware to separate file. - Cleaned up logic. - Fixed non-lazy loaded logging. Change-Id: Ic1d7c17fc2d7a563ff22f0b1dd42a58f87665cf1
This commit is contained in:
@@ -30,7 +30,7 @@ from oslo_config import cfg
|
||||
from designate.agent import handler
|
||||
from designate.backend import agent_backend
|
||||
from designate.conf.agent import DEFAULT_AGENT_PORT
|
||||
from designate import dnsutils
|
||||
from designate import dnsmiddleware
|
||||
from designate import service
|
||||
from designate import utils
|
||||
|
||||
@@ -80,7 +80,7 @@ class Service(service.Service):
|
||||
# Create an instance of the RequestHandler class
|
||||
application = handler.RequestHandler()
|
||||
if cfg.CONF['service:agent'].notify_delay > 0.0:
|
||||
application = dnsutils.LimitNotifyMiddleware(application)
|
||||
application = dnsutils.SerializationMiddleware(application)
|
||||
application = dnsmiddleware.LimitNotifyMiddleware(application)
|
||||
application = dnsmiddleware.SerializationMiddleware(application)
|
||||
|
||||
return application
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
# Copyright 2014 Hewlett-Packard Development Company, L.P.
|
||||
#
|
||||
# Author: Endre Karlson <endre.karlson@hpe.com>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import time
|
||||
|
||||
import dns.exception
|
||||
import dns.message
|
||||
import dns.opcode
|
||||
import dns.rcode
|
||||
import dns.rdatatype
|
||||
import dns.renderer
|
||||
import dns.tsig
|
||||
from oslo_log import log as logging
|
||||
|
||||
import designate.conf
|
||||
from designate import context
|
||||
from designate import dnsutils
|
||||
from designate import exceptions
|
||||
|
||||
CONF = designate.conf.CONF
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DNSMiddleware(object):
|
||||
"""Base DNS Middleware class with some utility methods"""
|
||||
def __init__(self, application):
|
||||
self.application = application
|
||||
|
||||
def process_request(self, request):
|
||||
"""Called on each request.
|
||||
|
||||
If this returns None, the next application down the stack will be
|
||||
executed. If it returns a response then that response will be returned
|
||||
and execution will stop here.
|
||||
"""
|
||||
return None
|
||||
|
||||
def process_response(self, response):
|
||||
"""Do whatever you'd like to the response."""
|
||||
return response
|
||||
|
||||
def __call__(self, request):
|
||||
response = self.process_request(request)
|
||||
|
||||
if response:
|
||||
return response
|
||||
|
||||
response = self.application(request)
|
||||
return self.process_response(response)
|
||||
|
||||
def _build_error_response(self):
|
||||
response = dns.message.make_response(
|
||||
dns.message.make_query('unknown', dns.rdatatype.A))
|
||||
response.set_rcode(dns.rcode.FORMERR)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SerializationMiddleware(DNSMiddleware):
|
||||
"""DNS Middleware to serialize/deserialize DNS Packets"""
|
||||
|
||||
def __init__(self, application, tsig_keyring=None):
|
||||
super(SerializationMiddleware, self).__init__(application)
|
||||
self.tsig_keyring = tsig_keyring
|
||||
|
||||
def __call__(self, request):
|
||||
# Generate the initial context. This may be updated by other middleware
|
||||
# as we learn more information about the Request.
|
||||
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
|
||||
|
||||
message = None
|
||||
try:
|
||||
message = dns.message.from_wire(request['payload'],
|
||||
self.tsig_keyring)
|
||||
|
||||
if message.had_tsig:
|
||||
LOG.debug('Request signed with TSIG key: %s', message.keyname)
|
||||
|
||||
# Create + Attach the initial "environ" dict. This is similar to
|
||||
# the environ dict used in typical WSGI middleware.
|
||||
message.environ = {
|
||||
'context': ctxt,
|
||||
'addr': request['addr'],
|
||||
}
|
||||
except dns.message.UnknownTSIGKey:
|
||||
LOG.error(
|
||||
'Unknown TSIG key from %(host)s:%(port)d',
|
||||
{
|
||||
'host': request['addr'][0],
|
||||
'port': request['addr'][1]
|
||||
}
|
||||
)
|
||||
except dns.tsig.BadSignature:
|
||||
LOG.error(
|
||||
'Invalid TSIG signature from %(host)s:%(port)d',
|
||||
{
|
||||
'host': request['addr'][0],
|
||||
'port': request['addr'][1]
|
||||
}
|
||||
)
|
||||
except dns.exception.DNSException:
|
||||
LOG.error(
|
||||
'Failed to deserialize packet from %(host)s:%(port)d',
|
||||
{
|
||||
'host': request['addr'][0],
|
||||
'port': request['addr'][1]
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
LOG.exception(
|
||||
'Unknown exception deserializing packet '
|
||||
'from %(host)s %(port)d',
|
||||
{
|
||||
'host': request['addr'][0],
|
||||
'port': request['addr'][1]
|
||||
}
|
||||
)
|
||||
|
||||
if message is None:
|
||||
# NOTE(eandersson): Unsure on the intent of the error handling
|
||||
# in this code. Cleaning the code path up, but
|
||||
# leaving functionality as it was.
|
||||
# error_response = self._build_error_response()
|
||||
# yield error_response.to_wire()
|
||||
return
|
||||
|
||||
# Hand the Deserialized packet onto the Application
|
||||
for response in self.application(message):
|
||||
# Serialize and return the response if present
|
||||
if isinstance(response, dns.message.Message):
|
||||
yield response.to_wire(max_size=65535)
|
||||
elif isinstance(response, dns.renderer.Renderer):
|
||||
yield response.get_wire()
|
||||
else:
|
||||
LOG.error('Unexpected response %r', response)
|
||||
|
||||
|
||||
class TsigInfoMiddleware(DNSMiddleware):
|
||||
"""Middleware which looks up the information available for a TsigKey"""
|
||||
|
||||
def __init__(self, application, storage):
|
||||
super(TsigInfoMiddleware, self).__init__(application)
|
||||
self.storage = storage
|
||||
|
||||
def process_request(self, request):
|
||||
if not request.had_tsig:
|
||||
return None
|
||||
|
||||
try:
|
||||
name = request.keyname.to_text(True)
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode('utf-8')
|
||||
criterion = {'name': name}
|
||||
tsigkey = self.storage.find_tsigkey(
|
||||
context.get_current(), criterion
|
||||
)
|
||||
|
||||
request.environ['tsigkey'] = tsigkey
|
||||
request.environ['context'].tsigkey_id = tsigkey.id
|
||||
|
||||
except exceptions.TsigKeyNotFound:
|
||||
# This should never happen, as we just validated the key.. Except
|
||||
# for race conditions..
|
||||
return self._build_error_response()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class LimitNotifyMiddleware(DNSMiddleware):
|
||||
"""Middleware that rate limits NOTIFYs to the Agent"""
|
||||
|
||||
def __init__(self, application):
|
||||
super(LimitNotifyMiddleware, self).__init__(application)
|
||||
|
||||
self.delay = CONF['service:agent'].notify_delay
|
||||
self.locker = dnsutils.ZoneLock(self.delay)
|
||||
|
||||
def process_request(self, request):
|
||||
opcode = request.opcode()
|
||||
if opcode != dns.opcode.NOTIFY:
|
||||
return None
|
||||
|
||||
zone_name = request.question[0].name.to_text()
|
||||
if isinstance(zone_name, bytes):
|
||||
zone_name = zone_name.decode('utf-8')
|
||||
|
||||
if self.locker.acquire(zone_name):
|
||||
time.sleep(self.delay)
|
||||
self.locker.release(zone_name)
|
||||
return None
|
||||
else:
|
||||
LOG.debug(
|
||||
'Threw away NOTIFY for %(zone)s, already '
|
||||
'working on an update.',
|
||||
{
|
||||
'zone': zone_name
|
||||
}
|
||||
)
|
||||
response = dns.message.make_response(request)
|
||||
# Provide an authoritative answer
|
||||
response.flags |= dns.flags.AA
|
||||
return (response,)
|
||||
+13
-169
@@ -15,11 +15,12 @@
|
||||
# under the License.
|
||||
import random
|
||||
import socket
|
||||
from threading import Lock
|
||||
import threading
|
||||
import time
|
||||
|
||||
import dns
|
||||
import dns.exception
|
||||
import dns.message
|
||||
import dns.opcode
|
||||
import dns.query
|
||||
import dns.rdatatype
|
||||
import dns.zone
|
||||
@@ -36,137 +37,6 @@ CONF = designate.conf.CONF
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DNSMiddleware(object):
|
||||
"""Base DNS Middleware class with some utility methods"""
|
||||
def __init__(self, application):
|
||||
self.application = application
|
||||
|
||||
def process_request(self, request):
|
||||
"""Called on each request.
|
||||
|
||||
If this returns None, the next application down the stack will be
|
||||
executed. If it returns a response then that response will be returned
|
||||
and execution will stop here.
|
||||
"""
|
||||
return None
|
||||
|
||||
def process_response(self, response):
|
||||
"""Do whatever you'd like to the response."""
|
||||
return response
|
||||
|
||||
def __call__(self, request):
|
||||
response = self.process_request(request)
|
||||
|
||||
if response:
|
||||
return response
|
||||
|
||||
response = self.application(request)
|
||||
return self.process_response(response)
|
||||
|
||||
def _build_error_response(self):
|
||||
response = dns.message.make_response(
|
||||
dns.message.make_query('unknown', dns.rdatatype.A))
|
||||
response.set_rcode(dns.rcode.FORMERR)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SerializationMiddleware(DNSMiddleware):
|
||||
"""DNS Middleware to serialize/deserialize DNS Packets"""
|
||||
|
||||
def __init__(self, application, tsig_keyring=None):
|
||||
self.application = application
|
||||
self.tsig_keyring = tsig_keyring
|
||||
|
||||
def __call__(self, request):
|
||||
# Generate the initial context. This may be updated by other middleware
|
||||
# as we learn more information about the Request.
|
||||
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
|
||||
|
||||
try:
|
||||
message = dns.message.from_wire(request['payload'],
|
||||
self.tsig_keyring)
|
||||
|
||||
if message.had_tsig:
|
||||
LOG.debug('Request signed with TSIG key: %s', message.keyname)
|
||||
|
||||
# Create + Attach the initial "environ" dict. This is similar to
|
||||
# the environ dict used in typical WSGI middleware.
|
||||
message.environ = {
|
||||
'context': ctxt,
|
||||
'addr': request['addr'],
|
||||
}
|
||||
|
||||
except dns.message.UnknownTSIGKey:
|
||||
LOG.error("Unknown TSIG key from %(host)s:%(port)d",
|
||||
{'host': request['addr'][0], 'port': request['addr'][1]})
|
||||
|
||||
response = self._build_error_response()
|
||||
|
||||
except dns.tsig.BadSignature:
|
||||
LOG.error("Invalid TSIG signature from %(host)s:%(port)d",
|
||||
{'host': request['addr'][0], 'port': request['addr'][1]})
|
||||
|
||||
response = self._build_error_response()
|
||||
|
||||
except dns.exception.DNSException:
|
||||
LOG.error("Failed to deserialize packet from %(host)s:%(port)d",
|
||||
{'host': request['addr'][0], 'port': request['addr'][1]})
|
||||
|
||||
response = self._build_error_response()
|
||||
|
||||
except Exception:
|
||||
LOG.exception("Unknown exception deserializing packet "
|
||||
"from %(host)s %(port)d",
|
||||
{'host': request['addr'][0],
|
||||
'port': request['addr'][1]})
|
||||
|
||||
response = self._build_error_response()
|
||||
|
||||
else:
|
||||
# Hand the Deserialized packet onto the Application
|
||||
for response in self.application(message):
|
||||
# Serialize and return the response if present
|
||||
if isinstance(response, dns.message.Message):
|
||||
yield response.to_wire(max_size=65535)
|
||||
|
||||
elif isinstance(response, dns.renderer.Renderer):
|
||||
yield response.get_wire()
|
||||
|
||||
else:
|
||||
LOG.error("Unexpected response %r", response)
|
||||
|
||||
|
||||
class TsigInfoMiddleware(DNSMiddleware):
|
||||
"""Middleware which looks up the information available for a TsigKey"""
|
||||
|
||||
def __init__(self, application, storage):
|
||||
super(TsigInfoMiddleware, self).__init__(application)
|
||||
self.storage = storage
|
||||
|
||||
def process_request(self, request):
|
||||
if not request.had_tsig:
|
||||
return None
|
||||
|
||||
try:
|
||||
name = request.keyname.to_text(True)
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode('utf-8')
|
||||
criterion = {'name': name}
|
||||
tsigkey = self.storage.find_tsigkey(
|
||||
context.get_current(), criterion)
|
||||
|
||||
request.environ['tsigkey'] = tsigkey
|
||||
request.environ['context'].tsigkey_id = tsigkey.id
|
||||
|
||||
except exceptions.TsigKeyNotFound:
|
||||
# This should never happen, as we just validated the key.. Except
|
||||
# for race conditions..
|
||||
return self._build_error_response()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class TsigKeyring(dict):
|
||||
"""Implements the DNSPython KeyRing API, backed by the Designate DB"""
|
||||
|
||||
@@ -184,7 +54,8 @@ class TsigKeyring(dict):
|
||||
name = name.decode('utf-8')
|
||||
criterion = {'name': name}
|
||||
tsigkey = self.storage.find_tsigkey(
|
||||
context.get_current(), criterion)
|
||||
context.get_current(), criterion
|
||||
)
|
||||
|
||||
return base64.decode_as_bytes(tsigkey.secret)
|
||||
|
||||
@@ -196,7 +67,7 @@ class ZoneLock(object):
|
||||
"""A Lock across all zones that enforces a rate limit on NOTIFYs"""
|
||||
|
||||
def __init__(self, delay):
|
||||
self.lock = Lock()
|
||||
self.lock = threading.Lock()
|
||||
self.data = {}
|
||||
self.delay = delay
|
||||
|
||||
@@ -219,9 +90,13 @@ class ZoneLock(object):
|
||||
self.data[zone] = now
|
||||
return True
|
||||
|
||||
LOG.debug('Lock for %(zone)s can\'t be released for %(period)s'
|
||||
'seconds' % {'zone': zone,
|
||||
'period': str(self.delay - period)})
|
||||
LOG.debug(
|
||||
'Lock for %(zone)s can\'t be released for %(period)s seconds',
|
||||
{
|
||||
'zone': zone,
|
||||
'period': str(self.delay - period)
|
||||
}
|
||||
)
|
||||
|
||||
# Don't grant the lock for the zone
|
||||
return False
|
||||
@@ -235,37 +110,6 @@ class ZoneLock(object):
|
||||
pass
|
||||
|
||||
|
||||
class LimitNotifyMiddleware(DNSMiddleware):
|
||||
"""Middleware that rate limits NOTIFYs to the Agent"""
|
||||
|
||||
def __init__(self, application):
|
||||
super(LimitNotifyMiddleware, self).__init__(application)
|
||||
|
||||
self.delay = CONF['service:agent'].notify_delay
|
||||
self.locker = ZoneLock(self.delay)
|
||||
|
||||
def process_request(self, request):
|
||||
opcode = request.opcode()
|
||||
if opcode != dns.opcode.NOTIFY:
|
||||
return None
|
||||
|
||||
zone_name = request.question[0].name.to_text()
|
||||
if isinstance(zone_name, bytes):
|
||||
zone_name = zone_name.decode('utf-8')
|
||||
|
||||
if self.locker.acquire(zone_name):
|
||||
time.sleep(self.delay)
|
||||
self.locker.release(zone_name)
|
||||
return None
|
||||
else:
|
||||
LOG.debug('Threw away NOTIFY for %(zone)s, already '
|
||||
'working on an update.' % {'zone': zone_name})
|
||||
response = dns.message.make_response(request)
|
||||
# Provide an authoritative answer
|
||||
response.flags |= dns.flags.AA
|
||||
return (response,)
|
||||
|
||||
|
||||
def from_dnspython_zone(dnspython_zone):
|
||||
# dnspython never builds a zone with more than one SOA, even if we give
|
||||
# it a zonefile that contains more than one
|
||||
|
||||
@@ -17,6 +17,7 @@ from oslo_config import cfg
|
||||
from oslo_log import log as logging
|
||||
|
||||
from designate.conf.mdns import DEFAULT_MDNS_PORT
|
||||
from designate import dnsmiddleware
|
||||
from designate import dnsutils
|
||||
from designate.mdns import handler
|
||||
from designate import service
|
||||
@@ -67,8 +68,10 @@ class Service(service.Service):
|
||||
# Create an instance of the RequestHandler class and wrap with
|
||||
# necessary middleware.
|
||||
application = handler.RequestHandler(self.storage, self.tg)
|
||||
application = dnsutils.TsigInfoMiddleware(application, self.storage)
|
||||
application = dnsutils.SerializationMiddleware(
|
||||
application = dnsmiddleware.TsigInfoMiddleware(
|
||||
application, self.storage
|
||||
)
|
||||
application = dnsmiddleware.SerializationMiddleware(
|
||||
application, dnsutils.TsigKeyring(self.storage)
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import dns
|
||||
import dns.query
|
||||
import dns.tsigkeyring
|
||||
from oslo_config import cfg
|
||||
|
||||
from designate import dnsmiddleware
|
||||
from designate import dnsutils
|
||||
from designate.mdns import handler
|
||||
from designate import storage
|
||||
import designate.tests
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
|
||||
class TestSerializationMiddleware(designate.tests.TestCase):
|
||||
def setUp(self):
|
||||
super(TestSerializationMiddleware, self).setUp()
|
||||
self.storage = storage.get_storage()
|
||||
self.tg = mock.Mock()
|
||||
|
||||
def test_with_tsigkeyring(self):
|
||||
self.create_tsigkey(fixture=1)
|
||||
|
||||
query = dns.message.make_query(
|
||||
'example.com.', dns.rdatatype.SOA,
|
||||
)
|
||||
query.use_tsig(dns.tsigkeyring.from_text(
|
||||
{'test-key-two': 'AnotherSecretKey'})
|
||||
)
|
||||
payload = query.to_wire()
|
||||
|
||||
application = handler.RequestHandler(self.storage, self.tg)
|
||||
application = dnsmiddleware.SerializationMiddleware(
|
||||
application, dnsutils.TsigKeyring(self.storage)
|
||||
)
|
||||
|
||||
self.assertTrue(next(application(
|
||||
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
|
||||
)))
|
||||
|
||||
def test_without_tsigkeyring(self):
|
||||
query = dns.message.make_query(
|
||||
'example.com.', dns.rdatatype.SOA,
|
||||
)
|
||||
payload = query.to_wire()
|
||||
|
||||
application = handler.RequestHandler(self.storage, self.tg)
|
||||
application = dnsmiddleware.SerializationMiddleware(
|
||||
application, dnsutils.TsigKeyring(self.storage)
|
||||
)
|
||||
|
||||
self.assertTrue(next(application(
|
||||
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
|
||||
)))
|
||||
@@ -20,9 +20,9 @@ import dns.query
|
||||
import dns.tsigkeyring
|
||||
from oslo_config import cfg
|
||||
|
||||
from designate import dnsmiddleware
|
||||
from designate import dnsutils
|
||||
from designate import exceptions
|
||||
from designate.mdns import handler
|
||||
from designate import objects
|
||||
from designate import storage
|
||||
import designate.tests
|
||||
@@ -84,48 +84,6 @@ SAMPLES = {
|
||||
}
|
||||
|
||||
|
||||
class TestSerializationMiddleware(designate.tests.TestCase):
|
||||
def setUp(self):
|
||||
super(TestSerializationMiddleware, self).setUp()
|
||||
self.storage = storage.get_storage()
|
||||
self.tg = mock.Mock()
|
||||
|
||||
def test_with_tsigkeyring(self):
|
||||
self.create_tsigkey(fixture=1)
|
||||
|
||||
query = dns.message.make_query(
|
||||
'example.com.', dns.rdatatype.SOA,
|
||||
)
|
||||
query.use_tsig(dns.tsigkeyring.from_text(
|
||||
{'test-key-two': 'AnotherSecretKey'})
|
||||
)
|
||||
payload = query.to_wire()
|
||||
|
||||
application = handler.RequestHandler(self.storage, self.tg)
|
||||
application = dnsutils.SerializationMiddleware(
|
||||
application, dnsutils.TsigKeyring(self.storage)
|
||||
)
|
||||
|
||||
self.assertTrue(next(application(
|
||||
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
|
||||
)))
|
||||
|
||||
def test_without_tsigkeyring(self):
|
||||
query = dns.message.make_query(
|
||||
'example.com.', dns.rdatatype.SOA,
|
||||
)
|
||||
payload = query.to_wire()
|
||||
|
||||
application = handler.RequestHandler(self.storage, self.tg)
|
||||
application = dnsutils.SerializationMiddleware(
|
||||
application, dnsutils.TsigKeyring(self.storage)
|
||||
)
|
||||
|
||||
self.assertTrue(next(application(
|
||||
{'payload': payload, 'addr': ['192.0.2.1', 5353]}
|
||||
)))
|
||||
|
||||
|
||||
class TestTsigUtils(designate.tests.TestCase):
|
||||
def setUp(self):
|
||||
super(TestTsigUtils, self).setUp()
|
||||
@@ -242,7 +200,7 @@ class TestUtils(designate.tests.TestCase):
|
||||
|
||||
# Initialize the middlware
|
||||
placeholder_app = None
|
||||
middleware = dnsutils.LimitNotifyMiddleware(placeholder_app)
|
||||
middleware = dnsmiddleware.LimitNotifyMiddleware(placeholder_app)
|
||||
|
||||
# Prepare a NOTIFY
|
||||
zone_name = 'example.com.'
|
||||
@@ -261,7 +219,7 @@ class TestUtils(designate.tests.TestCase):
|
||||
|
||||
# Initialize the middlware
|
||||
placeholder_app = None
|
||||
middleware = dnsutils.LimitNotifyMiddleware(placeholder_app)
|
||||
middleware = dnsmiddleware.LimitNotifyMiddleware(placeholder_app)
|
||||
|
||||
# Prepare a NOTIFY
|
||||
zone_name = 'example.com.'
|
||||
|
||||
@@ -74,12 +74,12 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
# NOTE: Start is already done by the fixture in start_service()
|
||||
self.service.stop()
|
||||
|
||||
@mock.patch.object(dns.message, 'make_query')
|
||||
def test_handle_empty_payload(self, query_mock):
|
||||
@mock.patch.object(dns.message, 'from_wire')
|
||||
def test_handle_empty_payload(self, mock_from_wire):
|
||||
mock_socket = mock.Mock()
|
||||
self.dns_service._dns_handle_udp_query(mock_socket, self.addr,
|
||||
' '.encode('utf-8'))
|
||||
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
|
||||
mock_from_wire.assert_called_once_with(b' ', {})
|
||||
|
||||
def test_handle_udp_payload(self):
|
||||
mock_socket = mock.Mock()
|
||||
@@ -88,7 +88,7 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
mock_socket.sendto.assert_called_once_with(self.expected_response,
|
||||
self.addr)
|
||||
|
||||
def test__dns_handle_tcp_conn_fail_unpack(self):
|
||||
def test_dns_handle_tcp_conn_fail_unpack(self):
|
||||
# will call recv() only once
|
||||
mock_socket = mock.Mock()
|
||||
mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack
|
||||
@@ -97,7 +97,7 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
self.assertEqual(1, mock_socket.recv.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
def test__dns_handle_tcp_conn_one_query(self):
|
||||
def test_dns_handle_tcp_conn_one_query(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
@@ -114,7 +114,7 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
self.assertEqual(len(wire), expected_length + 2)
|
||||
self.assertEqual(self.expected_response, wire[2:])
|
||||
|
||||
def test__dns_handle_tcp_conn_multiple_queries(self):
|
||||
def test_dns_handle_tcp_conn_multiple_queries(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
@@ -136,7 +136,7 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
def test__dns_handle_tcp_conn_multiple_queries_socket_error(self):
|
||||
def test_dns_handle_tcp_conn_multiple_queries_socket_error(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
@@ -158,7 +158,7 @@ class MdnsServiceTest(designate.tests.TestCase):
|
||||
self.assertEqual(5, mock_socket.sendall.call_count)
|
||||
self.assertEqual(1, mock_socket.close.call_count)
|
||||
|
||||
def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
|
||||
def test_dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self):
|
||||
payload = self.query_payload
|
||||
mock_socket = mock.Mock()
|
||||
pay_len = struct.pack("!H", len(payload))
|
||||
|
||||
@@ -18,7 +18,7 @@ from unittest import mock
|
||||
from designate.agent import service
|
||||
from designate.backend import agent_backend
|
||||
from designate.backend.agent_backend import impl_fake
|
||||
from designate import dnsutils
|
||||
from designate import dnsmiddleware
|
||||
import designate.tests
|
||||
from designate.tests import fixtures
|
||||
from designate import utils
|
||||
@@ -62,7 +62,8 @@ class AgentServiceTest(designate.tests.TestCase):
|
||||
@mock.patch.object(utils, 'cache_result')
|
||||
def test_get_dns_application(self, mock_cache_result):
|
||||
self.assertIsInstance(
|
||||
self.service.dns_application, dnsutils.SerializationMiddleware
|
||||
self.service.dns_application,
|
||||
dnsmiddleware.SerializationMiddleware
|
||||
)
|
||||
|
||||
@mock.patch.object(utils, 'cache_result')
|
||||
@@ -72,5 +73,6 @@ class AgentServiceTest(designate.tests.TestCase):
|
||||
self.CONF.set_override('notify_delay', 1.0, 'service:agent')
|
||||
|
||||
self.assertIsInstance(
|
||||
self.service.dns_application, dnsutils.SerializationMiddleware
|
||||
self.service.dns_application,
|
||||
dnsmiddleware.SerializationMiddleware
|
||||
)
|
||||
|
||||
@@ -19,10 +19,9 @@ from oslo_config import cfg
|
||||
from oslo_config import fixture as cfg_fixture
|
||||
import oslotest.base
|
||||
|
||||
import designate.dnsutils
|
||||
from designate import dnsmiddleware
|
||||
from designate.mdns import handler
|
||||
from designate.mdns import service
|
||||
import designate.rpc
|
||||
import designate.service
|
||||
from designate import storage
|
||||
from designate.tests import fixtures
|
||||
@@ -81,4 +80,4 @@ class MdnsServiceTest(oslotest.base.BaseTestCase):
|
||||
|
||||
app = self.service.dns_application
|
||||
|
||||
self.assertIsInstance(app, designate.dnsutils.DNSMiddleware)
|
||||
self.assertIsInstance(app, dnsmiddleware.DNSMiddleware)
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock
|
||||
|
||||
import dns.message
|
||||
from oslo_config import cfg
|
||||
|
||||
from designate import dnsmiddleware
|
||||
import oslotest.base
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
||||
|
||||
class TestDNSMiddleware(oslotest.base.BaseTestCase):
|
||||
def setUp(self):
|
||||
super(TestDNSMiddleware, self).setUp()
|
||||
self.application = mock.Mock(name='application')
|
||||
self.dns_application = dnsmiddleware.DNSMiddleware(self.application)
|
||||
|
||||
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_request')
|
||||
def test_call(self, mock_process_request):
|
||||
request = mock.Mock()
|
||||
self.dns_application(request)
|
||||
|
||||
mock_process_request.assert_called_with(request)
|
||||
|
||||
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_response')
|
||||
@mock.patch.object(dnsmiddleware.DNSMiddleware, 'process_request')
|
||||
def test_call_with_none(self, mock_process_request, mock_process_response):
|
||||
mock_process_request.return_value = None
|
||||
|
||||
self.dns_application(None)
|
||||
|
||||
mock_process_request.assert_called_with(None)
|
||||
mock_process_response.assert_called_with(self.application())
|
||||
|
||||
def test_process_request(self):
|
||||
self.assertIsNone(self.dns_application.process_request(mock.Mock()))
|
||||
|
||||
def test_process_response(self):
|
||||
response = mock.Mock()
|
||||
self.assertEqual(
|
||||
response, self.dns_application.process_response(response)
|
||||
)
|
||||
|
||||
def test_build_error_response(self):
|
||||
self.assertIsInstance(
|
||||
self.dns_application._build_error_response(), dns.message.Message
|
||||
)
|
||||
Reference in New Issue
Block a user