diff --git a/designate/service.py b/designate/service.py index 773af1577..1515d4f59 100644 --- a/designate/service.py +++ b/designate/service.py @@ -420,6 +420,12 @@ class DNSService(object): client.close() def _dns_handle_udp(self, sock_udp): + """Handle a DNS Query over UDP in a dedicated thread + + :param sock_udp: UDP socket + :type sock_udp: socket + :raises: None + """ LOG.info(_LI("_handle_udp thread started")) while True: @@ -432,8 +438,8 @@ class DNSService(object): {'host': addr[0], 'port': addr[1]}) # Dispatch a thread to handle the query - self.tg.add_thread(self._dns_handle, addr, payload, - sock_udp=sock_udp) + self.tg.add_thread(self._dns_handle_udp_query, sock_udp, addr, + payload) except socket.error as e: errname = errno.errorcode[e.args[0]] @@ -446,13 +452,17 @@ class DNSService(object): "from: %(host)s:%(port)d") % {'host': addr[0], 'port': addr[1]}) - def _dns_handle(self, addr, payload, client=None, sock_udp=None): + def _dns_handle_udp_query(self, sock, addr, payload): """ - Handle a DNS Query + Handle a DNS Query over UDP + :param sock: UDP socket + :type sock: socket :param addr: Tuple of the client's (IP, Port) + :type addr: tuple :param payload: Raw DNS query payload - :param client: Client socket (for TCP only) + :type payload: string + :raises: None """ try: # Call into the DNS Application itself with the payload and addr @@ -461,24 +471,13 @@ class DNSService(object): # Send back a response only if present if response is not None: - if client: - # Handle TCP Responses - msg_length = len(response) - tcp_response = struct.pack("!H", msg_length) + response - client.sendall(tcp_response) - else: - # Handle UDP Responses - sock_udp.sendto(response, addr) + sock.sendto(response, addr) except Exception: LOG.exception(_LE("Unhandled exception while processing request " "from %(host)s:%(port)d") % {'host': addr[0], 'port': addr[1]}) - # Close the TCP connection if we have one. - if client: - client.close() - _launcher = None diff --git a/designate/tests/test_mdns/test_service.py b/designate/tests/test_mdns/test_service.py index b66768cb5..8f625cc9b 100644 --- a/designate/tests/test_mdns/test_service.py +++ b/designate/tests/test_mdns/test_service.py @@ -15,7 +15,9 @@ # under the License. import binascii +import errno import socket +import struct import dns import dns.message @@ -32,6 +34,27 @@ def hex_wire(response): class MdnsServiceTest(MdnsTestCase): + + # DNS packet with IQUERY opcode + query_payload = binascii.a2b_hex( + "271209000001000000000000076578616d706c6503636f6d0000010001" + ) + expected_response = binascii.a2b_hex( + b"271289050001000000000000076578616d706c6503636f6d0000010001" + ) + # expected response is an error code REFUSED. The other fields are + # id 10002 + # opcode IQUERY + # rcode REFUSED + # flags QR RD + # ;QUESTION + # example.com. IN A + # ;ANSWER + # ;AUTHORITY + # ;ADDITIONAL + + # Use self._print_dns_msg() to display the messages + def setUp(self): super(MdnsServiceTest, self).setUp() @@ -41,147 +64,115 @@ class MdnsServiceTest(MdnsTestCase): self.service = self.start_service('mdns') self.addr = ['0.0.0.0', 5556] + @staticmethod + def _print_dns_msg(desc, wire): + """Print DNS message for debugging""" + q = dns.message.from_wire(wire).to_text() + print("%s:\n%s\n" % (desc, q)) + def test_stop(self): # NOTE: Start is already done by the fixture in start_service() self.service.stop() @mock.patch.object(dns.message, 'make_query') def test_handle_empty_payload(self, query_mock): - self.service._dns_handle(self.addr, ' '.encode('utf-8')) + mock_socket = mock.Mock() + self.service._dns_handle_udp_query(mock_socket, self.addr, + ' '.encode('utf-8')) query_mock.assert_called_once_with('unknown', dns.rdatatype.A) - @mock.patch.object(socket.socket, 'sendto', new_callable=mock.MagicMock) - def test_handle_udp_payload(self, sendto_mock): - # DNS packet with IQUERY opcode - payload = "271209000001000000000000076578616d706c6503636f6d0000010001" + def test_handle_udp_payload(self): + mock_socket = mock.Mock() + self.service._dns_handle_udp_query(mock_socket, self.addr, + self.query_payload) + mock_socket.sendto.assert_called_once_with(self.expected_response, + self.addr) - # expected response is an error code REFUSED. The other fields are - # id 10002 - # opcode IQUERY - # rcode REFUSED - # flags QR RD - # ;QUESTION - # example.com. IN A - # ;ANSWER - # ;AUTHORITY - # ;ADDITIONAL - expected_response = (b"271289050001000000000000076578616d706c6503636f6" - b"d0000010001") + def test__dns_handle_tcp_conn_fail_unpack(self): + # will call recv() only once + mock_socket = mock.Mock() + mock_socket.recv.side_effect = ['X', 'boo'] # X will fail unpack - sock_udp = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.service._dns_handle(self.addr, binascii.a2b_hex(payload), - sock_udp=sock_udp) - sendto_mock.assert_called_once_with( - binascii.a2b_hex(expected_response), self.addr) + self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket) + self.assertEqual(1, mock_socket.recv.call_count) + self.assertEqual(1, mock_socket.close.call_count) - def _send_request_to_mdns(self, req): - """Send request to localhost""" - self.assertTrue(len(self.service._dns_socks_udp)) - port = self.service._dns_socks_udp[0].getsockname()[1] - response = dns.query.udp(req, '127.0.0.1', port=port, timeout=1) - LOG.info("\n-- RESPONSE --\n%s\n--------------\n" % response.to_text()) - return response + def test__dns_handle_tcp_conn_one_query(self): + payload = self.query_payload + mock_socket = mock.Mock() + pay_len = struct.pack("!H", len(payload)) + mock_socket.recv.side_effect = [pay_len, payload, socket.timeout] - def _query_mdns(self, qname, rdtype, rdclass=dns.rdataclass.IN): - """Send query to localhost""" - req = dns.message.make_query(qname, rdtype, rdclass=rdclass) - req.id = 123 - return self._send_request_to_mdns(req) + self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket) - def test_query(self): - zone = self.create_zone() + self.assertEqual(3, mock_socket.recv.call_count) + self.assertEqual(1, mock_socket.sendall.call_count) + self.assertEqual(1, mock_socket.close.call_count) + wire = mock_socket.sendall.call_args[0][0] + expected_length_raw = wire[:2] + (expected_length, ) = struct.unpack('!H', expected_length_raw) + self.assertEqual(len(wire), expected_length + 2) + self.assertEqual(self.expected_response, wire[2:]) - # Reply query for NS - response = self._query_mdns(zone.name, dns.rdatatype.NS) - self.assertEqual(dns.rcode.NOERROR, response.rcode()) - self.assertEqual(1, len(response.answer)) - ans = response.answer[0] - self.assertEqual(dns.rdatatype.NS, ans.rdtype) - self.assertEqual(zone.name, ans.name.to_text()) - self.assertEqual(zone.ttl, ans.ttl) + def test__dns_handle_tcp_conn_multiple_queries(self): + payload = self.query_payload + mock_socket = mock.Mock() + pay_len = struct.pack("!H", len(payload)) + # Process 5 queries, than receive a misaligned query and close the + # connection there + mock_socket.recv.side_effect = [ + pay_len, payload, + pay_len, payload, + pay_len, payload, + pay_len, payload, + pay_len, payload, + 'X', payload, + pay_len, payload, + pay_len, payload, + ] + self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket) - # Reply query for SOA - response = self._query_mdns(zone.name, dns.rdatatype.SOA) - self.assertEqual(dns.rcode.NOERROR, response.rcode()) - self.assertEqual(1, len(response.answer)) - ans = response.answer[0] - self.assertEqual(dns.rdatatype.SOA, ans.rdtype) - self.assertEqual(zone.name, ans.name.to_text()) - self.assertEqual(zone.ttl, ans.ttl) + self.assertEqual(11, mock_socket.recv.call_count) + self.assertEqual(5, mock_socket.sendall.call_count) + self.assertEqual(1, mock_socket.close.call_count) - # Refuse query for incorrect rdclass - response = self._query_mdns(zone.name, dns.rdatatype.SOA, - rdclass=dns.rdataclass.RESERVED0) - self.assertEqual(dns.rcode.REFUSED, response.rcode()) - expected = b'007b81050001000000000000076578616d706c6503636f6d0000060000' # noqa - self.assertEqual(expected, hex_wire(response)) + def test__dns_handle_tcp_conn_multiple_queries_socket_error(self): + payload = self.query_payload + mock_socket = mock.Mock() + pay_len = struct.pack("!H", len(payload)) + # Process 5 queries, than receive a socket error and close the + # connection there + mock_socket.recv.side_effect = [ + pay_len, payload, + pay_len, payload, + pay_len, payload, + pay_len, payload, + pay_len, payload, + socket.error(errno.EAGAIN), + pay_len, payload, + pay_len, payload, + ] + self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket) - # Refuse query for ANY - response = self._query_mdns("www.%s" % zone.name, dns.rdatatype.ANY) - self.assertEqual(dns.rcode.REFUSED, response.rcode()) - expected = b'007b8105000100000000000003777777076578616d706c6503636f6d0000ff0001' # noqa - self.assertEqual(expected, hex_wire(response)) + self.assertEqual(11, mock_socket.recv.call_count) + self.assertEqual(5, mock_socket.sendall.call_count) + self.assertEqual(1, mock_socket.close.call_count) - # Reply query for A against inexistent record - response = self._query_mdns("nope.%s" % zone.name, dns.rdatatype.A) - self.assertEqual(dns.rcode.REFUSED, response.rcode()) - expected = b'007b81050001000000000000046e6f7065076578616d706c6503636f6d0000010001' # noqa - self.assertEqual(expected, hex_wire(response)) + def test__dns_handle_tcp_conn_multiple_queries_ignore_bad_query(self): + payload = self.query_payload + mock_socket = mock.Mock() + pay_len = struct.pack("!H", len(payload)) + # Ignore a broken query and keep going as long as the query len + # header was correct + mock_socket.recv.side_effect = [ + pay_len, payload, + pay_len, payload[:-5] + b'hello', + pay_len, payload, + pay_len, payload, + pay_len, payload, + ] + self.service._dns_handle_tcp_conn(('1.2.3.4', 42), mock_socket) - # Reply query for A - recordset = self.create_recordset(zone) - self.create_record(zone, recordset) - response = self._query_mdns(recordset.name, dns.rdatatype.A) - self.assertEqual(dns.rcode.NOERROR, response.rcode()) - self.assertEqual(1, len(response.answer)) - ans = response.answer[0] - self.assertEqual(dns.rdatatype.A, ans.rdtype) - self.assertEqual(recordset.name, ans.name.to_text()) - self.assertEqual(zone.ttl, ans.ttl) - self.assertEqual('3600 IN A 192.0.2.1', str(ans.to_rdataset())) - expected = b'007b85000001000100000000046d61696c076578616d706c6503636f6d0000010001c00c0001000100000e100004c0000201' # noqa - self.assertEqual(expected, hex_wire(response)) - - def test_query_axfr(self): - zone = self.create_zone() - - # Query for AXFR - response = self._query_mdns(zone.name, dns.rdatatype.AXFR) - self.assertEqual(dns.rcode.NOERROR, response.rcode()) - self.assertEqual(2, len(response.answer)) - ans = response.answer[0] # SOA - self.assertEqual(dns.rdatatype.SOA, ans.rdtype) - self.assertEqual(zone.name, ans.name.to_text()) - self.assertEqual(zone.ttl, ans.ttl) - ans = response.answer[1] # NS - self.assertEqual(dns.rdatatype.NS, ans.rdtype) - self.assertEqual(zone.name, ans.name.to_text()) - self.assertEqual(zone.ttl, ans.ttl) - - def test_notify_notauth_primary_zone(self): - zone = self.create_zone() - - # Send NOTIFY to mdns: NOTAUTH for primary zone - notify = dns.message.make_query(zone.name, dns.rdatatype.SOA) - notify.id = 123 - notify.flags = 0 - notify.set_opcode(dns.opcode.NOTIFY) - notify.flags |= dns.flags.AA - response = self._send_request_to_mdns(notify) - self.assertEqual(dns.rcode.NOTAUTH, response.rcode()) - expected = b'007ba0090001000000000000076578616d706c6503636f6d0000060001' # noqa - self.assertEqual(expected, hex_wire(response)) - - def test_notify_non_master(self): - zone = self.create_zone(type='SECONDARY', email='test@example.com') - - # Send NOTIFY to mdns: refuse from non-master - notify = dns.message.make_query(zone.name, dns.rdatatype.SOA) - notify.id = 123 - notify.flags = 0 - notify.set_opcode(dns.opcode.NOTIFY) - notify.flags |= dns.flags.AA - response = self._send_request_to_mdns(notify) - self.assertEqual(dns.rcode.REFUSED, response.rcode()) - expected = b'007ba0050001000000000000076578616d706c6503636f6d0000060001' # noqa - self.assertEqual(expected, hex_wire(response)) + self.assertEqual(11, mock_socket.recv.call_count) + self.assertEqual(4, mock_socket.sendall.call_count) + self.assertEqual(1, mock_socket.close.call_count)