Service Class Cleanup - Part 3/3

Here, we rework the MiniDNS service to use the DNSService class, and move
code that was temporarily placed in the dnsutils module into the DNSService
class. Once this is merged, the newly moved code will be updated to handle
TSIG verification for the mDNS service.

Long term, we should probably follow the RPCService vs RPCServer (from
oslo.messaging) model, allowing for getting all of the packet parsing
code into a single separate class.

Change-Id: I2ef936e570a8c19c7b0145a32e0aed1ab0718fa7
This commit is contained in:
Kiall Mac Innes 2015-02-28 20:30:15 +00:00 committed by Kiall Mac Innes
parent deb65468e9
commit b256a6fc9e
6 changed files with 157 additions and 203 deletions

View File

@ -43,7 +43,8 @@ class Service(service.DNSService, service.Service):
def _dns_application(self):
# Create an instance of the RequestHandler class
application = handler.RequestHandler()
application = dnsutils.DNSMiddleware(application)
application = dnsutils.ContextMiddleware(application)
application = dnsutils.SerializationMiddleware(application)
return application

View File

@ -14,23 +14,57 @@
# License for the specific language governing permissions and limitations
# under the License.
import socket
import struct
import dns
import dns.zone
from dns import rdatatype
from oslo_log import log as logging
from designate import context
from designate import exceptions
from designate import objects
from designate.i18n import _LE
from designate.i18n import _LI
from designate.i18n import _LW
LOG = logging.getLogger(__name__)
class SerializationMiddleware(object):
"""DNS Middleware to serialize/deserialize DNS Packets"""
def __init__(self, application):
self.application = application
def __call__(self, request):
try:
message = dns.message.from_wire(request['payload'])
# Create + Attach the initial "environ" dict. This is similar to
# the environ dict used in typical WSGI middleware.
message.environ = {'addr': request['addr']}
except dns.exception.DNSException:
LOG.error(_LE("Failed to deserialize packet from %(host)s:"
"%(port)d") % {'host': request['addr'][0],
'port': request['addr'][1]})
# We failed to deserialize the request, generate a failure
# response using a made up request.
response = dns.message.make_response(
dns.message.make_query('unknown', dns.rdatatype.A))
response.set_rcode(dns.rcode.FORMERR)
else:
# Hand the Deserialized packet on
response = self.application(message)
# Serialize and return the response if present
if response is not None:
return response.to_wire()
class DNSMiddleware(object):
"""Base DNS Middleware class with some utility methods"""
def __init__(self, application):
self.application = application
@ -57,6 +91,20 @@ class DNSMiddleware(object):
return self.process_response(response)
class ContextMiddleware(DNSMiddleware):
"""Temporary ContextMiddleware which attaches an admin context to every
request
This will be replaced with a piece of middleware which generates, from
a TSIG signed request, an appropriate Request Context.
"""
def process_request(self, request):
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
request.environ['context'] = ctxt
return None
def from_dnspython_zone(dnspython_zone):
# dnspython never builds a zone with more than one SOA, even if we give
# it a zonefile that contains more than one
@ -113,35 +161,6 @@ def dnspythonrecord_to_recordset(rname, rdataset):
return rrset
def _deserialize_request(payload, addr):
"""
Deserialize a DNS Request Packet
:param payload: Raw DNS query payload
:param addr: Tuple of the client's (IP, Port)
"""
try:
request = dns.message.from_wire(payload)
except dns.exception.DNSException:
LOG.error(_LE("Failed to deserialize packet from %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]})
return None
else:
# Create + Attach the initial "environ" dict. This is similar to
# the environ dict used in typical WSGI middleware.
request.environ = {'addr': addr}
return request
def _serialize_response(response):
"""
Serialize a DNS Response Packet
:param response: DNS Response Message
"""
return response.to_wire()
def bind_tcp(host, port, tcp_backlog):
# Bind to the TCP port
LOG.info(_LI('Opening TCP Listening Socket on %(host)s:%(port)d') %
@ -165,93 +184,6 @@ def bind_udp(host, port):
return sock_udp
def handle_tcp(sock_tcp, tg, handle, application, timeout=None):
LOG.info(_LI("_handle_tcp thread started"))
while True:
client, addr = sock_tcp.accept()
if timeout:
client.settimeout(timeout)
LOG.debug("Handling TCP Request from: %(host)s:%(port)d" %
{'host': addr[0], 'port': addr[1]})
# Prepare a variable for the payload to be buffered
payload = ""
try:
# Receive the first 2 bytes containing the payload length
expected_length_raw = client.recv(2)
(expected_length, ) = struct.unpack('!H', expected_length_raw)
# Keep receiving data until we've got all the data we expect
while len(payload) < expected_length:
data = client.recv(65535)
if not data:
break
payload += data
except socket.timeout:
client.close()
LOG.warn(_LW("TCP Timeout from: %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]})
# Dispatch a thread to handle the query
tg.add_thread(handle, addr, payload, application, client=client)
def handle_udp(sock_udp, tg, handle, application):
LOG.info(_LI("_handle_udp thread started"))
while True:
# TODO(kiall): Determine the appropriate default value for
# UDP recvfrom.
payload, addr = sock_udp.recvfrom(8192)
LOG.debug("Handling UDP Request from: %(host)s:%(port)d" %
{'host': addr[0], 'port': addr[1]})
tg.add_thread(handle, addr, payload, application, sock_udp=sock_udp)
def handle(addr, payload, application, sock_udp=None, client=None):
"""
Handle a DNS Query
:param addr: Tuple of the client's (IP, Port)
:param payload: Raw DNS query payload
:param client: Client socket (for TCP only)
"""
try:
request = _deserialize_request(payload, addr)
if request is None:
# We failed to deserialize the request, generate a failure
# response using a made up request.
response = dns.message.make_response(
dns.message.make_query('unknown', dns.rdatatype.A))
response.set_rcode(dns.rcode.FORMERR)
else:
response = application(request)
# send back a response only if present
if response:
response = _serialize_response(response)
if client:
# Handle TCP Responses
msg_length = len(response)
tcp_response = struct.pack("!H", msg_length) + response
client.send(tcp_response)
client.close()
elif sock_udp:
# Handle UDP Responses
sock_udp.sendto(response, addr)
else:
LOG.warn(_LW("Both sock_udp and client are None"))
except Exception:
LOG.exception(_LE("Unhandled exception while processing request "
"from %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]})
def do_axfr(zone_name, masters):
"""
Performs an AXFR for a given zone name

View File

@ -1,31 +0,0 @@
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Author: Kiall Mac Innes <kiall@hp.com>
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from designate import context
from designate import dnsutils
class ContextMiddleware(dnsutils.DNSMiddleware):
"""Temporary ContextMiddleware which attaches an admin context to every
request
This will be replaced with a piece of middleware which generates, from
a TSIG signed request, an appropriate Request Context.
"""
def process_request(self, request):
ctxt = context.DesignateContext.get_admin_context(all_tenants=True)
request.environ['context'] = ctxt
return None

View File

@ -17,36 +17,19 @@ from oslo.config import cfg
from oslo_log import log as logging
from designate import utils
from designate import dnsutils
from designate import service
from designate import dnsutils
from designate.mdns import handler
from designate.mdns import middleware
from designate.mdns import notify
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class Service(service.RPCService, service.Service):
class Service(service.DNSService, service.RPCService, service.Service):
def __init__(self, threads=None):
super(Service, self).__init__(threads=threads)
# Create an instance of the RequestHandler class
self.application = handler.RequestHandler()
# Wrap the application in any middleware required
# TODO(kiall): In the future, we want to allow users to pick+choose
# the middleware to be applied, similar to how we do this
# in the API.
self.application = middleware.ContextMiddleware(self.application)
self._sock_tcp = dnsutils.bind_tcp(
CONF['service:mdns'].host, CONF['service:mdns'].port,
CONF['service:mdns'].tcp_backlog)
self._sock_udp = dnsutils.bind_udp(
CONF['service:mdns'].host, CONF['service:mdns'].port)
@property
def service_name(self):
return 'mdns'
@ -56,17 +39,12 @@ class Service(service.RPCService, service.Service):
def _rpc_endpoints(self):
return [notify.NotifyEndpoint()]
def start(self):
super(Service, self).start()
@property
@utils.cache_result
def _dns_application(self):
# Create an instance of the RequestHandler class
application = handler.RequestHandler()
application = dnsutils.ContextMiddleware(application)
application = dnsutils.SerializationMiddleware(application)
self.tg.add_thread(
dnsutils.handle_tcp, self._sock_tcp, self.tg, dnsutils.handle,
self.application, timeout=CONF['service:mdns'].tcp_recv_timeout)
self.tg.add_thread(
dnsutils.handle_udp, self._sock_udp, self.tg, dnsutils.handle,
self.application)
def stop(self):
# When the service is stopped, the threads for _handle_tcp and
# _handle_udp are stopped too.
super(Service, self).stop()
return application

View File

@ -19,6 +19,7 @@
# under the License.
import abc
import socket
import struct
import errno
import time
@ -32,6 +33,9 @@ from oslo_log import loggers
from designate.openstack.common import service
from designate.openstack.common import sslutils
from designate.i18n import _
from designate.i18n import _LE
from designate.i18n import _LI
from designate.i18n import _LW
from designate import rpc
from designate import policy
from designate import version
@ -161,6 +165,14 @@ class WSGIService(object):
def _wsgi_application(self):
pass
def start(self):
super(WSGIService, self).start()
socket = self._wsgi_get_socket()
application = self._wsgi_application
self.tg.add_thread(self._wsgi_handle, application, socket)
def _wsgi_get_socket(self):
# TODO(dims): eventlet's green dns/socket module does not actually
# support IPv6 in getaddrinfo(). We need to get around this in the
@ -205,14 +217,6 @@ class WSGIService(object):
return sock
def start(self):
super(WSGIService, self).start()
socket = self._wsgi_get_socket()
application = self._wsgi_application
self.tg.add_thread(self._wsgi_handle, application, socket)
def _wsgi_handle(self, application, socket):
logger = logging.getLogger('eventlet.wsgi')
eventlet.wsgi.server(socket,
@ -245,13 +249,8 @@ class DNSService(object):
def start(self):
super(DNSService, self).start()
self.tg.add_thread(
dnsutils.handle_tcp, self._dns_sock_tcp, self.tg, dnsutils.handle,
self._dns_application, self._service_config.tcp_recv_timeout)
self.tg.add_thread(
dnsutils.handle_udp, self._dns_sock_udp, self.tg, dnsutils.handle,
self._dns_application)
self.tg.add_thread(self._dns_handle_tcp)
self.tg.add_thread(self._dns_handle_udp)
def wait(self):
super(DNSService, self).wait()
@ -261,6 +260,85 @@ class DNSService(object):
# _handle_udp are stopped too.
super(DNSService, self).stop()
def _dns_handle_tcp(self):
LOG.info(_LI("_handle_tcp thread started"))
while True:
client, addr = self._dns_sock_tcp.accept()
if self._service_config.tcp_recv_timeout:
client.settimeout(self._service_config.tcp_recv_timeout)
LOG.debug("Handling TCP Request from: %(host)s:%(port)d" %
{'host': addr[0], 'port': addr[1]})
# Prepare a variable for the payload to be buffered
payload = ""
try:
# Receive the first 2 bytes containing the payload length
expected_length_raw = client.recv(2)
(expected_length, ) = struct.unpack('!H', expected_length_raw)
# Keep receiving data until we've got all the data we expect
while len(payload) < expected_length:
data = client.recv(65535)
if not data:
break
payload += data
except socket.timeout:
client.close()
LOG.warn(_LW("TCP Timeout from: %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]})
# Dispatch a thread to handle the query
self.tg.add_thread(self._dns_handle, addr, payload, client=client)
def _dns_handle_udp(self):
LOG.info(_LI("_handle_udp thread started"))
while True:
# TODO(kiall): Determine the appropriate default value for
# UDP recvfrom.
payload, addr = self._dns_sock_udp.recvfrom(8192)
LOG.debug("Handling UDP Request from: %(host)s:%(port)d" %
{'host': addr[0], 'port': addr[1]})
self.tg.add_thread(self._dns_handle, addr, payload)
def _dns_handle(self, addr, payload, client=None):
"""
Handle a DNS Query
:param addr: Tuple of the client's (IP, Port)
:param payload: Raw DNS query payload
:param client: Client socket (for TCP only)
"""
try:
# Call into the DNS Application itself with the payload and addr
response = self._dns_application({
'payload': payload,
'addr': addr
})
# Send back a response only if present
if response is not None:
if client:
# Handle TCP Responses
msg_length = len(response)
tcp_response = struct.pack("!H", msg_length) + response
client.send(tcp_response)
client.close()
else:
# Handle UDP Responses
self._dns_sock_udp.sendto(response, addr)
except Exception:
LOG.exception(_LE("Unhandled exception while processing request "
"from %(host)s:%(port)d") %
{'host': addr[0], 'port': addr[1]})
_launcher = None

View File

@ -20,7 +20,6 @@ import dns
import dns.message
import mock
from designate import dnsutils
from designate.tests.test_mdns import MdnsTestCase
@ -40,8 +39,7 @@ class MdnsServiceTest(MdnsTestCase):
@mock.patch.object(dns.message, 'make_query')
def test_handle_empty_payload(self, query_mock):
dnsutils.handle(self.addr, None, self.service.application,
sock_udp=self.service._sock_udp)
self.service._dns_handle(self.addr, None)
query_mock.assert_called_once_with('unknown', dns.rdatatype.A)
@mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock)
@ -62,8 +60,6 @@ class MdnsServiceTest(MdnsTestCase):
expected_response = ("271289050001000000000000076578616d706c6503636f6d"
"0000010001")
dnsutils.handle(
self.addr, binascii.a2b_hex(payload), self.service.application,
sock_udp=self.service._sock_udp)
self.service._dns_handle(self.addr, binascii.a2b_hex(payload))
sendto_mock.assert_called_once_with(
binascii.a2b_hex(expected_response), self.addr)