Rework how TCP resequencing goes. Breaks everything!

This commit is contained in:
Neale Pickett 2008-07-23 10:48:27 -06:00
parent 119c4e391c
commit a561d1f199
2 changed files with 65 additions and 112 deletions

View File

@ -26,11 +26,12 @@ class GapString:
return '<GapString of length %d>' % self.length return '<GapString of length %d>' % self.length
def append(self, i): def append(self, i):
self.contents.append(i) try:
if isinstance(i, int):
self.length += i
else:
self.length += len(i) self.length += len(i)
self.contents.append(i)
except TypeError:
self.length += int(i)
self.contents.append(int(i))
def __str__(self): def __str__(self):
ret = [] ret = []

168
ip.py
View File

@ -149,81 +149,6 @@ class Frame:
self.dst_addr) self.dst_addr)
class Chunk:
"""Chunk of frames, possibly with gaps.
"""
def __init__(self, seq=None):
self.collection = {}
self.length = 0
self.seq = seq
self.first = None
def add(self, frame):
print 'self.seq=%d, adding %d' % (self.seq, frame.seq)
if not self.first:
self.first = frame
if self.seq is None:
self.seq = frame.seq
assert frame.seq >= self.seq, (frame.seq, self.seq)
self.collection[frame.seq] = frame
end = frame.seq - self.seq + len(frame.payload)
self.length = max(self.length, long(end))
def setlast(self, seq):
l = seq - self.seq
print 'len', l, self.length
self.length = max(l, self.length)
def __len__(self):
return int(self.length)
def __repr__(self):
if self.first:
return '<Chunk %s:%d -> %s:%d length %d (0x%x)>' % (self.first.src_addr,
self.first.sport,
self.first.dst_addr,
self.first.dport,
len(self),
len(self))
else:
return '<Chunk (no frames)>'
def gapstr(self, drop='?'):
"""Return contents as a GapString"""
ret = gapstr.GapString(drop=drop)
while len(ret) < self.length:
f = self.collection.get(self.seq + len(ret))
if f:
ret.append(f.payload)
else:
# This is where to fix big inefficiency for dropped packets.
l = 1
while ((len(ret) + l < self.length) and
(not (self.seq + len(ret) + l) in self.collection)):
l += 1
ret.append(l)
return ret
def __str__(self):
return str(self.gapstr())
def extend(self, other):
self.seq = min(self.seq or other.seq, other.seq)
self.length = self.length + other.length
if not self.first:
self.first = other.first
self.collection.update(other.collection)
def __add__(self, next):
new = self.__class__(self.seq)
new.extend(self)
new.extend(next)
return new
FIN = 1 FIN = 1
SYN = 2 SYN = 2
RST = 4 RST = 4
@ -315,17 +240,35 @@ class TCP_Resequence:
# Does this ACK after the last output sequence number? # Does this ACK after the last output sequence number?
seq = self.lastack[idx] seq = self.lastack[idx]
if pkt.ack > seq: if pkt.ack > seq:
ret = Chunk(seq)
pending = self.pending[xdi] pending = self.pending[xdi]
for key in pending.keys():
# Get a sorted list of sequence numbers
keys = pending.keys()
keys.sort()
# Build up return value
gs = gapstr.GapString()
if keys:
f = pending[keys[0]]
ret = (xdi, f, gs)
else:
ret = (xdi, None, gs)
# Fill in gs with our frames
for key in keys:
if key >= pkt.ack: if key >= pkt.ack:
continue break
if key >= seq: if key < seq:
ret.add(pending[key]) warnings.warn('Dropping %r from mid-stream session' % pending[key])
else: elif key > seq:
warnings.warn('Dropping out-of-order packet %r from mid-stream session' % pending[key]) gs.append(key - seq)
seq = key
frame = pending[key]
gs.append(frame.payload)
seq += len(frame.payload)
del pending[key] del pending[key]
ret.setlast(pkt.ack) if seq != pkt.ack:
gs.append(pkt.ack - seq)
self.lastack[idx] = pkt.ack self.lastack[idx] = pkt.ack
# If it has a payload, stick it into pending # If it has a payload, stick it into pending
@ -350,7 +293,7 @@ class TCP_Resequence:
warnings.warn('Spurious frame after shutdown: %r %d' % (pkt, pkt.flags)) warnings.warn('Spurious frame after shutdown: %r %d' % (pkt, pkt.flags))
class Resequence: class Dispatch:
def __init__(self, *filenames): def __init__(self, *filenames):
self.pcs = {} self.pcs = {}
@ -395,9 +338,9 @@ class Resequence:
if not tcp_sess: if not tcp_sess:
tcp_sess = TCP_Resequence() tcp_sess = TCP_Resequence()
self.sessions[frame.hash] = tcp_sess self.sessions[frame.hash] = tcp_sess
chunk = tcp_sess.handle(frame) ret = tcp_sess.handle(frame)
if chunk: if ret:
yield chunk yield frame.hash, ret
self.last = None self.last = None
self._read(pc, filename, fd) self._read(pc, filename, fd)
@ -537,7 +480,8 @@ class Session:
Packet = Packet Packet = Packet
def __init__(self, frame): def __init__(self, frame):
self.frame = frame self.firstframe = frame
self.lastframe = [None, None]
self.basename = 'transfers/%s' % (frame.src_addr,) self.basename = 'transfers/%s' % (frame.src_addr,)
self.pending = {} self.pending = {}
self.count = 0 self.count = 0
@ -548,29 +492,37 @@ class Session:
pass pass
def handle(self, chunk, lastpos): def handle(self, is_srv, frame, gs, lastpos):
"""Handle a data burst. """Handle a data burst.
Pass in a chunk. @param is_srv Is this from the server?
@param frame A frame associated with this packet, or None if it's all drops
@param gs A gapstring of the data
@param lastpos Last position in the source file, for debugging
""" """
if frame:
self.lastframe[is_srv] = frame
frame = self.lastframe[is_srv]
self.lastpos = lastpos self.lastpos = lastpos
try: try:
saddr = chunk.first.saddr saddr = frame.saddr
try: try:
(first, data) = self.pending.pop(saddr) (f, data) = self.pending.pop(saddr)
except KeyError: except KeyError:
first = chunk.first f = frame
data = gapstr.GapString() data = gapstr.GapString()
data.extend(chunk.gapstr()) data.extend(gs)
try: try:
print (data, data.length)
print (data, bool(data))
while data: while data:
p = self.Packet(self, first) p = self.Packet(self, f)
data = p.handle(data) data = p.handle(data)
self.process(p) self.process(p)
except NeedMoreData: except NeedMoreData:
self.pending[saddr] = (first, data) self.pending[saddr] = (f, data)
self.count += 1 self.count += 1
except: except:
print 'Lastpos: %s:::%d' % lastpos print 'Lastpos: %s:::%d' % lastpos
@ -599,7 +551,7 @@ class Session:
os.makedirs(self.basename) os.makedirs(self.basename)
except OSError: except OSError:
pass pass
frame = self.frame frame = self.firstframe
fn = '%s:%d-%s:%d---%s' % (frame.src_addr, frame.sport, fn = '%s:%d-%s:%d---%s' % (frame.src_addr, frame.sport,
frame.dst_addr, frame.dport, frame.dst_addr, frame.dport,
urllib.quote(fn, '\:')) urllib.quote(fn, '\:'))
@ -616,9 +568,9 @@ class Session:
class HtmlSession(Session): class HtmlSession(Session):
def __init__(self, frame): def __init__(self, frame):
Session.__init__(self, frame) Session.__init__(self, frame)
self.fn = self.make_filename('session.html') self.sessfn = self.make_filename('session.html')
self.fd = file(self.fn, 'w') self.sessfd = file(self.sessfn, 'w')
self.fd.write('''<?xml version="1.0" encoding="UTF-8"?> self.sessfd.write('''<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE html <!DOCTYPE html
PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN" PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd"> "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">
@ -626,18 +578,18 @@ class HtmlSession(Session):
<head> <head>
<title>%s</title> <title>%s</title>
<style type="text/css"> <style type="text/css">
.server { background-color .server { background-color: white; color: black; }
.client { background-color: blue; color: white; } .client { background-color: #884; color: white; }
</style> </style>
</head> </head>
<body> <body>
''' % self.__class__.__name__) ''' % self.__class__.__name__)
self.fd.write('<h1>%s</h1>\n' % self.__class__.__name__) self.sessfd.write('<h1>%s</h1>\n' % self.__class__.__name__)
self.fd.write('<pre>') self.sessfd.write('<pre>')
self.srv = None self.srv = None
def __del__(self): def __del__(self):
self.fd.write('</pre></body></html>') self.sessfd.write('</pre></body></html>')
def log(self, frame, payload, escape=True): def log(self, frame, payload, escape=True):
if escape: if escape:
@ -650,6 +602,6 @@ class HtmlSession(Session):
cls = 'server' cls = 'server'
else: else:
cls = 'client' cls = 'client'
self.fd.write('<span class="%s" title="%s(%s)">' % (cls, time.ctime(frame.time), frame.time)) self.sessfd.write('<span class="%s" title="%s(%s)">' % (cls, time.ctime(frame.time), frame.time))
self.fd.write(p.replace('\r\n', '\n')) self.sessfd.write(p.replace('\r\n', '\n'))
self.fd.write('</span>') self.sessfd.write('</span>')