lib/packet: teach packet library to truncate padding octet

The patch teaches packet library to truncate padding octets.
Change packet_base.parser() to return (header, next_type, rest_of_packet)
The protocol class that knows its payload length should rest_of_packet
where padding octets at the last of packet is truncated.

As bonus,
- fix ipv6 parser as ipv6 header doesn't have options.
  It seems copy-and-paste from ipv4
- improve ipv4, tcp a bit

Cc: YAMAMOTO Takashi <yamamoto@valinux.co.jp>
Cc: Shaun Crampton <Shaun.Crampton@metaswitch.com>
Signed-off-by: Isaku Yamahata <yamahata@valinux.co.jp>
Signed-off-by: FUJITA Tomonori <fujita.tomonori@lab.ntt.co.jp>
This commit is contained in:
Isaku Yamahata 2013-07-01 14:53:33 +09:00 committed by FUJITA Tomonori
parent 3679d7facc
commit 3837a84eec
25 changed files with 79 additions and 92 deletions

View File

@ -70,14 +70,13 @@ class arp(packet_base.PacketBase):
self.src_ip = src_ip
self.dst_mac = dst_mac
self.dst_ip = dst_ip
self.length = arp._MIN_LEN
@classmethod
def parser(cls, buf):
(hwtype, proto, hlen, plen, opcode, src_mac, src_ip,
dst_mac, dst_ip) = struct.unpack_from(cls._PACK_STR, buf)
return cls(hwtype, proto, hlen, plen, opcode, src_mac, src_ip,
dst_mac, dst_ip), None
dst_mac, dst_ip), None, buf[arp._MIN_LEN:]
def serialize(self, payload, prev):
return struct.pack(arp._PACK_STR, self.hwtype, self.proto,

View File

@ -172,10 +172,13 @@ class dhcp(packet_base.PacketBase):
(hops, xid, secs, flags, ciaddr, yiaddr, siaddr, giaddr, chaddr,
dummy, sname, boot_file
) = struct.unpack_from(unpack_str, buf)
length = min_len
if len(buf) > min_len:
parse_opt = options.parser(buf[min_len:])
return cls(op, chaddr, parse_opt, htype, hlen, hops, xid, secs, flags,
ciaddr, yiaddr, siaddr, giaddr, sname, boot_file)
length += parse_opt.options_len
return (cls(op, chaddr, parse_opt, htype, hlen, hops, xid, secs, flags,
ciaddr, yiaddr, siaddr, giaddr, sname, boot_file),
None, buf[length:])
def serialize(self, payload, prev):
seri_opt = self.options.serialize()

View File

@ -43,12 +43,12 @@ class ethernet(packet_base.PacketBase):
self.dst = dst
self.src = src
self.ethertype = ethertype
self.length = ethernet._MIN_LEN
@classmethod
def parser(cls, buf):
dst, src, ethertype = struct.unpack_from(cls._PACK_STR, buf)
return cls(dst, src, ethertype), ethernet.get_packet_type(ethertype)
return (cls(dst, src, ethertype), ethernet.get_packet_type(ethertype),
buf[ethernet._MIN_LEN:])
def serialize(self, payload, prev):
return struct.pack(ethernet._PACK_STR, self.dst, self.src,

View File

@ -94,7 +94,7 @@ class icmp(packet_base.PacketBase):
else:
msg.data = buf[offset:]
return msg, None
return msg, None, None
def serialize(self, payload, prev):
hdr = bytearray(struct.pack(icmp._PACK_STR, self.type,

View File

@ -106,7 +106,7 @@ class icmpv6(packet_base.PacketBase):
else:
msg.data = buf[offset:]
return msg, None
return msg, None, None
def serialize(self, payload, prev):
hdr = bytearray(struct.pack(icmpv6._PACK_STR, self.type_,

View File

@ -84,9 +84,11 @@ class ipv4(packet_base.PacketBase):
self.csum = csum
self.src = src
self.dst = dst
self.length = header_length * 4
self.option = option
def __len__(self):
return self.header_length * 4
@classmethod
def parser(cls, buf):
(version, tos, total_length, identification, flags, ttl, proto, csum,
@ -95,16 +97,19 @@ class ipv4(packet_base.PacketBase):
version = version >> 4
offset = flags & ((1 << 13) - 1)
flags = flags >> 13
length = header_length * 4
if length > ipv4._MIN_LEN:
option = buf[ipv4._MIN_LEN:length]
else:
option = None
msg = cls(version, header_length, tos, total_length, identification,
flags, offset, ttl, proto, csum, src, dst)
flags, offset, ttl, proto, csum, src, dst, option)
if msg.length > ipv4._MIN_LEN:
msg.option = buf[ipv4._MIN_LEN:msg.length]
return msg, ipv4.get_packet_type(proto)
return msg, ipv4.get_packet_type(proto), buf[length:total_length]
def serialize(self, payload, prev):
hdr = bytearray(self.header_length * 4)
length = len(self)
hdr = bytearray(length)
version = self.version << 4 | self.header_length
flags = self.flags << 13 | self.offset
if self.total_length == 0:
@ -114,7 +119,7 @@ class ipv4(packet_base.PacketBase):
self.ttl, self.proto, 0, self.src, self.dst)
if self.option:
assert (self.length - ipv4._MIN_LEN) >= len(self.option)
assert (length - ipv4._MIN_LEN) >= len(self.option)
hdr[ipv4._MIN_LEN:ipv4._MIN_LEN + len(self.option)] = self.option
self.csum = packet_utils.checksum(hdr)

View File

@ -63,24 +63,18 @@ class ipv6(packet_base.PacketBase):
self.hop_limit = hop_limit
self.src = src
self.dst = dst
self.length = 40
@classmethod
def parser(cls, buf):
(v_tc_flow, plen, nxt, hlim, src, dst) = struct.unpack_from(
(v_tc_flow, payload_length, nxt, hlim, src, dst) = struct.unpack_from(
cls._PACK_STR, buf)
version = v_tc_flow >> 28
traffic_class = (v_tc_flow >> 20) & 0xff
flow_label = v_tc_flow & 0xfffff
payload_length = plen
hop_limit = hlim
msg = cls(version, traffic_class, flow_label, payload_length,
nxt, hop_limit, src, dst)
if msg.length > ipv6._MIN_LEN:
msg.option = buf[ipv6._MIN_LEN:msg.length]
return msg, ipv6.get_packet_type(nxt)
return msg, ipv6.get_packet_type(nxt), buf[cls._MIN_LEN:payload_length]
def serialize(self, payload, prev):
hdr = bytearray(40)

View File

@ -111,11 +111,6 @@ class lldp(packet_base.PacketBase):
def __init__(self, tlvs):
super(lldp, self).__init__()
self.tlvs = tlvs
length = 0
for tlv in tlvs:
length += LLDP_TLV_SIZE + tlv.len
self.length = length
# at least it must have chassis id, port id, ttl and end
def _tlvs_len_valid(self):
@ -137,9 +132,9 @@ class lldp(packet_base.PacketBase):
tlv = cls._tlv_parsers[tlv_type](buf)
tlvs.append(tlv)
offset = LLDP_TLV_SIZE + tlv.len
buf = buf[offset:]
if tlv.tlv_type == LLDP_TLV_END:
break
buf = buf[offset:]
assert len(buf) > 0
lldp_pkt = cls(tlvs)
@ -147,7 +142,7 @@ class lldp(packet_base.PacketBase):
assert lldp_pkt._tlvs_len_valid()
assert lldp_pkt._tlvs_valid()
return lldp_pkt, None
return lldp_pkt, None, buf
def serialize(self, payload, prev):
data = bytearray()

View File

@ -50,7 +50,6 @@ class mpls(packet_base.PacketBase):
self.exp = exp
self.bsb = bsb
self.ttl = ttl
self.length = mpls._MIN_LEN
@classmethod
def parser(cls, buf):
@ -61,9 +60,9 @@ class mpls(packet_base.PacketBase):
label = label >> 12
msg = cls(label, exp, bsb, ttl)
if bsb:
return msg, ipv4.ipv4
return msg, ipv4.ipv4, buf[msg._MIN_LEN:]
else:
return msg, mpls
return msg, mpls, buf[msg._MIN_LEN:]
def serialize(self, payload, prev):
val = self.label << 12 | self.exp << 9 | self.bsb << 8 | self.ttl

View File

@ -42,21 +42,20 @@ class Packet(object):
else:
self.protocols = protocols
self.protocol_idx = 0
self.parsed_bytes = 0
if self.data:
self._parser(parse_cls)
def _parser(self, cls):
rest_data = self.data
while cls:
try:
proto, cls = cls.parser(self.data[self.parsed_bytes:])
if proto:
self.parsed_bytes += proto.length
self.protocols.append(proto)
proto, cls, rest_data = cls.parser(rest_data)
except struct.error:
cls = None
if len(self.data) > self.parsed_bytes:
self.protocols.append(self.data[self.parsed_bytes:])
break
if proto:
self.protocols.append(proto)
if rest_data:
self.protocols.append(rest_data)
def serialize(self):
"""Encode a packet and store the resulted bytearray in self.data.

View File

@ -39,7 +39,6 @@ class PacketBase(object):
def __init__(self):
super(PacketBase, self).__init__()
self.length = 0
@property
def protocol_name(self):

View File

@ -59,36 +59,42 @@ class tcp(packet_base.PacketBase):
self.window_size = window_size
self.csum = csum
self.urgent = urgent
self.length = self.offset * 4
self.option = option
def __len__(self):
return self.offset * 4
@classmethod
def parser(cls, buf):
(src_port, dst_port, seq, ack, offset, bits, window_size,
csum, urgent) = struct.unpack_from(cls._PACK_STR, buf)
offset = offset >> 4
bits = bits & 0x3f
length = offset * 4
if length > tcp._MIN_LEN:
option = buf[tcp._MIN_LEN:length]
else:
option = None
msg = cls(src_port, dst_port, seq, ack, offset, bits,
window_size, csum, urgent)
window_size, csum, urgent, option)
if msg.length > tcp._MIN_LEN:
msg.option = buf[tcp._MIN_LEN:msg.length]
return msg, None
return msg, None, buf[length:]
def serialize(self, payload, prev):
h = bytearray(self.length)
length = len(self)
h = bytearray(length)
offset = self.offset << 4
struct.pack_into(tcp._PACK_STR, h, 0, self.src_port, self.dst_port,
self.seq, self.ack, offset, self.bits,
self.window_size, self.csum, self.urgent)
if self.option:
assert (self.length - tcp._MIN_LEN) >= len(self.option)
assert (length - tcp._MIN_LEN) >= len(self.option)
h[tcp._MIN_LEN:tcp._MIN_LEN + len(self.option)] = self.option
if self.csum == 0:
length = self.length + len(payload)
self.csum = packet_utils.checksum_ip(prev, length, h + payload)
total_length = length + len(payload)
self.csum = packet_utils.checksum_ip(prev, total_length,
h + payload)
struct.pack_into('!H', h, 16, self.csum)
return h

View File

@ -47,14 +47,13 @@ class udp(packet_base.PacketBase):
self.dst_port = dst_port
self.total_length = total_length
self.csum = csum
self.length = udp._MIN_LEN
@classmethod
def parser(cls, buf):
(src_port, dst_port, total_length, csum) = struct.unpack_from(
cls._PACK_STR, buf)
msg = cls(src_port, dst_port, total_length, csum)
return msg, None
return msg, None, buf[msg._MIN_LEN:total_length]
def serialize(self, payload, prev):
if self.total_length == 0:

View File

@ -49,7 +49,6 @@ class vlan(packet_base.PacketBase):
self.cfi = cfi
self.vid = vid
self.ethertype = ethertype
self.length = vlan._MIN_LEN
@classmethod
def parser(cls, buf):
@ -57,7 +56,8 @@ class vlan(packet_base.PacketBase):
pcp = tci >> 13
cfi = (tci >> 12) & 1
vid = tci & ((1 << 12) - 1)
return cls(pcp, cfi, vid, ethertype), vlan.get_packet_type(ethertype)
return (cls(pcp, cfi, vid, ethertype),
vlan.get_packet_type(ethertype), buf[vlan._MIN_LEN:])
def serialize(self, payload, prev):
tci = self.pcp << 13 | self.cfi << 12 | self.vid

View File

@ -270,7 +270,6 @@ class vrrp(packet_base.PacketBase):
self.auth_data = auth_data
self._is_ipv6 = is_ipv6(self.ip_addresses[0])
self.length = len(self)
self.identification = 0 # used for ipv4 identification
def checksum_ok(self, ipvx, vrrp_buf):
@ -337,7 +336,7 @@ class vrrp(packet_base.PacketBase):
if self.is_ipv6:
traffic_class = 0xc0 # set tos to internetwork control
flow_label = 0
payload_length = ipv6.ipv6._MIN_LEN + self.length # XXX _MIN_LEN
payload_length = ipv6.ipv6._MIN_LEN + len(self) # XXX _MIN_LEN
e = ethernet.ethernet(VRRP_IPV6_DST_MAC_ADDRESS,
vrrp_ipv6_src_mac_address(self.vrid),
ether.ETH_TYPE_IPV6)
@ -456,8 +455,9 @@ class vrrpv2(vrrp):
offset += struct.calcsize(ip_addresses_pack_str)
auth_data = struct.unpack_from(cls._AUTH_DATA_PACK_STR, buf, offset)
return cls(version, type_, vrid, priority, count_ip, adver_int,
checksum, ip_addresses, auth_type, auth_data), None
msg = cls(version, type_, vrid, priority, count_ip, adver_int,
checksum, ip_addresses, auth_type, auth_data)
return msg, None, buf[len(msg):]
@staticmethod
def serialize_static(vrrp_, prev):
@ -534,7 +534,7 @@ class vrrpv3(vrrp):
# http://www.ietf.org/mail-archive/web/vrrp/current/msg01473.html
# if not self.is_ipv6:
# return packet_utils.checksum(vrrp_buf) == 0
return packet_utils.checksum_ip(ipvx, self.length, vrrp_buf) == 0
return packet_utils.checksum_ip(ipvx, len(self), vrrp_buf) == 0
@staticmethod
def create(type_, vrid, priority, max_adver_int, ip_addresses):
@ -573,8 +573,9 @@ class vrrpv3(vrrp):
address_len, count_ip))
ip_addresses = struct.unpack_from(pack_str, buf, offset)
return cls(version, type_, vrid, priority,
count_ip, max_adver_int, checksum, ip_addresses), None
msg = cls(version, type_, vrid, priority,
count_ip, max_adver_int, checksum, ip_addresses)
return msg, None, buf[len(msg):]
@staticmethod
def serialize_static(vrrp_, prev):

View File

@ -48,7 +48,6 @@ class Test_arp(unittest.TestCase):
dst_ip = int(netaddr.IPAddress('24.166.173.159'))
fmt = arp._PACK_STR
length = struct.calcsize(arp._PACK_STR)
buf = pack(fmt, hwtype, proto, hlen, plen, opcode, src_mac, src_ip,
dst_mac, dst_ip)
@ -76,7 +75,6 @@ class Test_arp(unittest.TestCase):
eq_(self.src_ip, self.a.src_ip)
eq_(self.dst_mac, self.a.dst_mac)
eq_(self.dst_ip, self.a.dst_ip)
eq_(self.length, self.a.length)
def test_parser(self):
_res = self.a.parser(self.buf)
@ -94,7 +92,6 @@ class Test_arp(unittest.TestCase):
eq_(res.src_ip, self.src_ip)
eq_(res.dst_mac, self.dst_mac)
eq_(res.dst_ip, self.dst_ip)
eq_(res.length, self.length)
def test_serialize(self):
data = bytearray()
@ -173,7 +170,6 @@ class Test_arp(unittest.TestCase):
eq_(a.src_ip, self.src_ip)
eq_(a.dst_mac, self.dst_mac)
eq_(a.dst_ip, self.dst_ip)
eq_(a.length, self.length)
@raises(Exception)
def test_malformed_arp(self):

View File

@ -39,7 +39,6 @@ class Test_ethernet(unittest.TestCase):
dst = mac.haddr_to_bin('AA:AA:AA:AA:AA:AA')
src = mac.haddr_to_bin('BB:BB:BB:BB:BB:BB')
ethertype = ether.ETH_TYPE_ARP
length = struct.calcsize(ethernet._PACK_STR)
buf = pack(ethernet._PACK_STR, dst, src, ethertype)
@ -60,16 +59,14 @@ class Test_ethernet(unittest.TestCase):
eq_(self.dst, self.e.dst)
eq_(self.src, self.e.src)
eq_(self.ethertype, self.e.ethertype)
eq_(self.length, self.e.length)
def test_parser(self):
res, ptype = self.e.parser(self.buf)
res, ptype, _ = self.e.parser(self.buf)
LOG.debug((res, ptype))
eq_(res.dst, self.dst)
eq_(res.src, self.src)
eq_(res.ethertype, self.ethertype)
eq_(res.length, self.length)
eq_(ptype, arp)
def test_serialize(self):

View File

@ -61,7 +61,7 @@ class Test_icmpv6_header(unittest.TestCase):
eq_(0, self.icmp.csum)
def test_parser(self):
msg, n = self.icmp.parser(self.buf)
msg, n, _ = self.icmp.parser(self.buf)
eq_(msg.type_, self.type_)
eq_(msg.code, self.code)
@ -110,7 +110,7 @@ class Test_icmpv6_echo_request(unittest.TestCase):
def _test_parser(self, data=None):
buf = self.buf + str(data or '')
msg, n = icmpv6.icmpv6.parser(buf)
msg, n, _ = icmpv6.icmpv6.parser(buf)
eq_(msg.type_, self.type_)
eq_(msg.code, self.code)
@ -195,7 +195,7 @@ class Test_icmpv6_neighbor_solict(unittest.TestCase):
def _test_parser(self, data=None):
buf = self.buf + str(data or '')
msg, n = icmpv6.icmpv6.parser(buf)
msg, n, _ = icmpv6.icmpv6.parser(buf)
eq_(msg.type_, self.type_)
eq_(msg.code, self.code)
@ -300,7 +300,7 @@ class Test_icmpv6_router_solict(unittest.TestCase):
def _test_parser(self, data=None):
buf = self.buf + str(data or '')
msg, n = icmpv6.icmpv6.parser(buf)
msg, n, _ = icmpv6.icmpv6.parser(buf)
eq_(msg.type_, self.type_)
eq_(msg.code, self.code)

View File

@ -85,11 +85,11 @@ class Test_ipv4(unittest.TestCase):
eq_(self.csum, self.ip.csum)
eq_(self.src, self.ip.src)
eq_(self.dst, self.ip.dst)
eq_(self.length, self.ip.length)
eq_(self.length, len(self.ip))
eq_(self.option, self.ip.option)
def test_parser(self):
res, ptype = self.ip.parser(self.buf)
res, ptype, _ = self.ip.parser(self.buf)
eq_(res.version, self.version)
eq_(res.header_length, self.header_length)

View File

@ -49,8 +49,8 @@ class TestLLDPMandatoryTLV(unittest.TestCase):
def test_parse_without_ethernet(self):
buf = self.data[ethernet.ethernet._MIN_LEN:]
(lldp_pkt, cls) = lldp.lldp.parser(buf)
eq_(lldp_pkt.length, len(buf))
(lldp_pkt, cls, rest_buf) = lldp.lldp.parser(buf)
eq_(len(rest_buf), 0)
tlvs = lldp_pkt.tlvs
eq_(tlvs[0].tlv_type, lldp.LLDP_TLV_CHASSIS_ID)
@ -170,7 +170,6 @@ class TestLLDPOptionalTLV(unittest.TestCase):
eq_(type(pkt.next()), ethernet.ethernet)
lldp_pkt = pkt.next()
eq_(type(lldp_pkt), lldp.lldp)
eq_(lldp_pkt.length, len(buf) - ethernet.ethernet._MIN_LEN)
tlvs = lldp_pkt.tlvs

View File

@ -358,7 +358,7 @@ class TestPacket(unittest.TestCase):
eq_(0b101010, p_tcp.bits)
eq_(2048, p_tcp.window_size)
eq_(0x6f, p_tcp.urgent)
eq_(len(t_buf), p_tcp.length)
eq_(len(t_buf), len(p_tcp))
t = bytearray(t_buf)
struct.pack_into('!H', t, 16, p_tcp.csum)
ph = struct.pack('!IIBBH', self.src_ip, self.dst_ip, 0,

View File

@ -74,7 +74,7 @@ class Test_tcp(unittest.TestCase):
eq_(self.option, self.t.option)
def test_parser(self):
r1, r2 = self.t.parser(self.buf)
r1, r2, _ = self.t.parser(self.buf)
eq_(self.src_port, r1.src_port)
eq_(self.dst_port, r1.dst_port)

View File

@ -57,7 +57,7 @@ class Test_udp(unittest.TestCase):
eq_(self.csum, self.u.csum)
def test_parser(self):
r1, r2 = self.u.parser(self.buf)
r1, r2, _ = self.u.parser(self.buf)
eq_(self.src_port, r1.src_port)
eq_(self.dst_port, r1.dst_port)

View File

@ -42,7 +42,6 @@ class Test_vlan(unittest.TestCase):
vid = 32
tci = pcp << 15 | cfi << 12 | vid
ethertype = ether.ETH_TYPE_IP
length = struct.calcsize(vlan._PACK_STR)
buf = pack(vlan._PACK_STR, tci, ethertype)
@ -64,16 +63,14 @@ class Test_vlan(unittest.TestCase):
eq_(self.cfi, self.v.cfi)
eq_(self.vid, self.v.vid)
eq_(self.ethertype, self.v.ethertype)
eq_(self.length, self.v.length)
def test_parser(self):
res, ptype = self.v.parser(self.buf)
res, ptype, _ = self.v.parser(self.buf)
eq_(res.pcp, self.pcp)
eq_(res.cfi, self.cfi)
eq_(res.vid, self.vid)
eq_(res.ethertype, self.ethertype)
eq_(res.length, self.length)
eq_(ptype, ipv4)
def test_serialize(self):
@ -136,7 +133,6 @@ class Test_vlan(unittest.TestCase):
eq_(v.cfi, self.cfi)
eq_(v.vid, self.vid)
eq_(v.ethertype, self.ethertype)
eq_(v.length, self.length)
@raises(Exception)
def test_malformed_vlan(self):

View File

@ -73,7 +73,7 @@ class Test_vrrpv2(unittest.TestCase):
eq_(self.auth_data, self.vrrpv2.auth_data)
def test_parser(self):
vrrpv2, _cls = self.vrrpv2.parser(self.buf)
vrrpv2, _cls, _ = self.vrrpv2.parser(self.buf)
eq_(self.version, vrrpv2.version)
eq_(self.type_, vrrpv2.type)
@ -216,7 +216,7 @@ class Test_vrrpv3_ipv4(unittest.TestCase):
eq_(self.ip_address, self.vrrpv3.ip_addresses[0])
def test_parser(self):
vrrpv3, _cls = self.vrrpv3.parser(self.buf)
vrrpv3, _cls, _ = self.vrrpv3.parser(self.buf)
eq_(self.version, vrrpv3.version)
eq_(self.type_, vrrpv3.type)
@ -357,7 +357,7 @@ class Test_vrrpv3_ipv6(unittest.TestCase):
eq_(self.ip_address, self.vrrpv3.ip_addresses[0])
def test_parser(self):
vrrpv3, _cls = self.vrrpv3.parser(self.buf)
vrrpv3, _cls, _ = self.vrrpv3.parser(self.buf)
eq_(self.version, vrrpv3.version)
eq_(self.type_, vrrpv3.type)