#!/usr/bin/env python
#
# Small TFTP (RFC 1350) server; each transfer runs in its own thread.
#
# The code runs on Python 2 (2.6+) and Python 3: packets are handled as
# bytes, and text fields from the wire are converted with _text()/_bytes().
#

#
# Standard imports
#
import os
import select
import socket
import string
import struct
import sys

try:
    import thread  # Python 2
except ImportError:  # Python 3 renamed the low-level threading module
    import _thread as thread

#
# Default variables
#
TFTP_PORT = 69
TFTP_PORT_MVP = 16869  # For MediaMVP, if bootp runs on 16867


def _text(raw):
    """Return bytes received from the network as a native ``str``.

    On Python 2 ``str`` already is a byte string and is returned unchanged;
    on Python 3 the bytes are decoded as Latin-1, which maps every byte value
    to a character and so cannot fail.
    """
    if isinstance(raw, str):
        return raw
    return raw.decode("latin-1")


def _bytes(text):
    """Return *text* as bytes suitable for putting in a packet."""
    if isinstance(text, bytes):
        return text
    return text.encode("latin-1", "replace")


#
# TFTP Errors
#
class TFTPError(Exception):
    """Protocol failure, raised as ``TFTPError(errnum, errtext)``.

    TFTPConnection.connect() reports it to the client as an ERROR packet.
    """
    pass


#
# A class for a TFTP Connection
#
class TFTPConnection:
    """One lock-step TFTP transfer with a single client.

    Subclasses provide the data by overriding readRequest()/writeRequest()
    and may override the handle* hooks.  The request ``mode`` is passed
    through to those methods unchanged; no netascii translation is done here.

    NOTE(review): ``blksize``/``timeout`` request options are applied, but no
    OACK (RFC 2347) is ever sent back, so an option-aware client is not told
    that its options were accepted.
    """

    RRQ = 1
    WRQ = 2
    DATA = 3
    ACK = 4
    ERR = 5

    HDRSIZE = 4  # number of bytes for OPCODE and BLOCK in header
    BLOCKMASK = 0xFFFF  # block numbers are unsigned 16-bit and wrap to 0
    MINBLKSIZE = 8  # accepted range for the blksize option (RFC 2348)
    MAXBLKSIZE = 65464

    def __init__(self, port=0, blocksize=512, timeout=2.0, retry=5):
        self.port = port
        self.blocksize = blocksize
        self.timeout = timeout
        self.retry = retry
        self.client_addr = None
        self.sock = None
        self.active = 0
        self.blockNumber = 0
        self.sentLast = 0  # true once the final (short) DATA block is sent
        self.lastpkt = b""
        self.mode = ""
        self.filename = ""
        self.file = None
        self.bind('', port)

    def bind(self, host="", port=0):
        """Create this transfer's UDP socket.

        The socket is bound only if *host* or *port* is given; otherwise it
        is left unbound.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock = sock
        if host or port:
            sock.bind((host, port))

    def send(self, pkt=b""):
        """Send *pkt* to the client and remember it for retransmit()."""
        self.sock.sendto(pkt, self.client_addr)
        self.lastpkt = pkt

    def recv(self):
        """Wait for the next packet from the client and return it parsed.

        Packets from any other address are dropped.  Whenever ``timeout``
        seconds pass with nothing to read, the last packet sent is
        retransmitted; after ``retry`` such timeouts TFTPError(4, ...) is
        raised.
        """
        sock = self.sock
        fd = sock.fileno()
        retry = self.retry
        while retry:
            readable, _, _ = select.select([fd], [], [fd], self.timeout)
            if not readable:
                # We timed out -- our last packet may have been lost.
                retry = retry - 1
                self.retransmit()
                continue
            data, addr = sock.recvfrom(self.blocksize + self.HDRSIZE)
            if addr == self.client_addr:
                return self.parse(data)
        raise TFTPError(4, "Transfer timed out")

    def parse(self, data, unpack=struct.unpack):
        """Decode one raw TFTP packet into a dict.

        Every packet yields ``opcode``.  RRQ/WRQ add ``filename``, ``mode``
        and one entry per option (``blksize`` and ``timeout`` also update
        this connection); ACK and DATA add ``block``; DATA adds ``data``;
        ERR adds ``errnum`` and ``errtxt``.

        Raises TFTPError(4, ...) for truncated or malformed packets and for
        unknown opcodes.
        """
        if len(data) < 2:
            raise TFTPError(4, "Malformed packet")
        pkt = {}
        # Header fields are unsigned 16-bit, so blocks above 32767 decode
        # correctly.
        opcode = pkt["opcode"] = unpack("!H", data[:2])[0]
        if opcode in (self.RRQ, self.WRQ):
            self._parseRequest(data[2:], pkt)
        elif opcode in (self.ACK, self.DATA, self.ERR):
            if len(data) < self.HDRSIZE:
                raise TFTPError(4, "Malformed packet")
            field = unpack("!H", data[2:4])[0]
            if opcode == self.ERR:
                pkt["errnum"] = field
                pkt["errtxt"] = _text(data[4:-1])  # drop the trailing NUL
            else:
                pkt["block"] = field
                if opcode == self.DATA:
                    pkt["data"] = data[4:]
        else:
            raise TFTPError(4, "Unknown packet type")
        return pkt

    def _parseRequest(self, body, pkt):
        """Fill *pkt* from the NUL-separated fields of an RRQ/WRQ body."""
        try:
            filename, mode, rest = body.split(b"\0", 2)
            pkt["filename"] = _text(filename)
            pkt["mode"] = _text(mode)
            while rest:
                key, value, rest = rest.split(b"\0", 2)
                key, value = _text(key), _text(value)
                option = key.lower()  # option names are case-insensitive
                if option == 'blksize':
                    size = int(value)
                    # Ignore sizes outside the RFC range (e.g. 0 would stall
                    # the transfer) and keep the current block size.
                    if self.MINBLKSIZE <= size <= self.MAXBLKSIZE:
                        self.blocksize = size
                elif option == 'timeout':
                    seconds = float(value)
                    if seconds > 0:
                        self.timeout = seconds
                pkt[key] = value
        except ValueError:
            # Too few fields, an option without a value, or a non-numeric
            # option value.
            raise TFTPError(4, "Malformed request")

    def retransmit(self):
        """Resend the last packet passed to send()."""
        self.sock.sendto(self.lastpkt, self.client_addr)

    def connect(self, addr, data):
        """Serve one transfer for request *data* received from *addr*.

        Handles the initial RRQ/WRQ, then runs the lock-step loop until the
        transfer finishes.  A TFTPError is reported to the client as an
        ERROR packet.  The file and socket are closed via close() on exit,
        so an instance serves a single transfer.
        """
        self.client_addr = addr
        print("Client: %s" % (addr,))
        try:
            pkt = self.parse(data)
            opcode = pkt["opcode"]
            if opcode not in (self.RRQ, self.WRQ):
                raise TFTPError(4, "Bad request")

            # Start lock-step transfer
            self.active = 1
            if opcode == self.RRQ:
                self.handleRRQ(pkt)
            else:
                self.handleWRQ(pkt)

            # Loop until done
            dispatch = {
                self.DATA: self.recvData,
                self.ACK: self.recvAck,
                self.ERR: self.recvErr,
            }
            while self.active:
                pkt = self.recv()
                handler = dispatch.get(pkt["opcode"])
                if handler is None:
                    raise TFTPError(4, "Invalid opcode")
                handler(pkt)
        except TFTPError as detail:
            args = detail.args
            if len(args) >= 2:
                errnum, errtext = args[0], args[1]
            else:
                errnum, errtext = 0, str(detail)
            self.sendError(errnum, errtext)
        finally:
            self.close()
        print("Done.")

    def close(self):
        """Close the transfer's file (if any) and socket; safe to repeat."""
        f, self.file = self.file, None
        sock, self.sock = self.sock, None
        try:
            closer = getattr(f, "close", None)
            if closer is not None:
                closer()
        finally:
            if sock is not None:
                sock.close()

    def recvAck(self, pkt):
        """Handle an ACK; ACKs for any block but the current one are ignored."""
        if pkt["block"] != self.blockNumber:
            return
        if self.sentLast:
            # The final DATA block has been acknowledged: transfer complete.
            self.active = 0
        else:
            self.handleACK(pkt)

    def recvData(self, pkt):
        """Handle a DATA packet; only the expected block is accepted."""
        if pkt["block"] == self.blockNumber:
            # A block shorter than blocksize is the last one of the upload.
            self.active = (self.blocksize == len(pkt["data"]))
            self.handleDATA(pkt)

    def recvErr(self, pkt):
        """Handle an ERROR packet from the client, which ends the transfer."""
        self.handleERR(pkt)
        self.active = 0

    def sendData(self, data, pack=struct.pack):
        """Send *data* as the next DATA block.

        A block whose length differs from ``blocksize`` is the last one;
        the transfer stays active until the client acknowledges it (see
        recvAck).
        """
        self.blockNumber = (self.blockNumber + 1) & self.BLOCKMASK
        self.send(pack("!HH", self.DATA, self.blockNumber) + data)
        self.sentLast = (len(data) != self.blocksize)

    def sendAck(self, pack=struct.pack):
        """Acknowledge the current block and advance to the next one."""
        block = self.blockNumber
        self.blockNumber = (self.blockNumber + 1) & self.BLOCKMASK
        self.send(pack("!HH", self.ACK, block))

    def sendError(self, errnum, errtext, pack=struct.pack):
        """Send an ERROR packet; it is not stored for retransmission."""
        outdata = pack("!HH", self.ERR, errnum) + _bytes(errtext) + b"\0"
        self.sock.sendto(outdata, self.client_addr)

    #
    #
    # Override these handle* methods as needed
    #
    #
    def handleRRQ(self, pkt):
        """Open the requested file via readRequest() and send block 1.

        If readRequest() raises, an ERROR packet is sent and the transfer
        is ended.
        """
        filename = pkt["filename"]
        mode = pkt["mode"]
        try:
            self.file = self.readRequest(filename, mode)
        except Exception:
            self.sendError(1, "Cannot open file")
            print("Cannot open file for reading %s\n%s: %s"
                  % ((filename,) + tuple(sys.exc_info()[:2])))
            self.active = 0  # nothing to send; don't wait on the client
            return
        self.sendData(self.file.read(self.blocksize))

    def handleWRQ(self, pkt):
        """Open the target via writeRequest() and send ACK 0.

        If writeRequest() raises, an ERROR packet is sent and the transfer
        is ended.
        """
        filename = pkt["filename"]
        mode = pkt["mode"]
        try:
            self.file = self.writeRequest(filename, mode)
        except Exception:
            self.sendError(1, "Cannot open file")
            print("Cannot open file for writing %s\n%s: %s"
                  % ((filename,) + tuple(sys.exc_info()[:2])))
            self.active = 0  # nothing to receive; don't wait on the client
            return
        self.sendAck()

    def handleACK(self, pkt):
        """Send the next block of the file after a good ACK."""
        if self.active:
            self.sendData(self.file.read(self.blocksize))

    def handleDATA(self, pkt):
        """Store a received DATA block, then acknowledge it."""
        try:
            self.file.write(pkt["data"])
        except (IOError, OSError):
            raise TFTPError(0, "Error writing file")
        # ACK only after a successful write, so the client is never told a
        # block was stored when it was not.
        self.sendAck()

    def handleERR(self, pkt):
        """Report an ERROR packet received from the client."""
        print(pkt["errtxt"])

    #
    # You should definitely override these
    #
    def readRequest(self, filename, mode):
        """Return a binary file-like object to serve; default is empty."""
        from io import BytesIO
        return BytesIO(b"")

    def writeRequest(self, filename, mode):
        """Return a binary file-like object to receive into; default discards."""
        from io import BytesIO
        return BytesIO()


#
# Simple TFTP Server
# Each connection is handled in its own thread.
#
class TFTPServer:
    """TFTP Server

    Implements a threaded TFTP Server.  Each request is handled in its own
    thread by a new instance of ``conn``.  When ``srcports`` is given, each
    transfer's socket is bound to the next port from it, round-robin.
    """

    def __init__(self, conn=TFTPConnection, srcports=None):
        self.conn = conn
        # Private copy: handle() rotates this list in place.
        self.srcports = list(srcports or [])
        self.sock = []

    def bind(self, host, port):
        """Add a listening UDP socket on (*host*, *port*)."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.append(sock)
        sock.bind((host, port))

    def forever(self):
        """Wait for requests on all bound sockets and dispatch them."""
        while 1:
            r, w, e = select.select(self.sock, [], self.sock)
            for sock in r:
                data, addr = sock.recvfrom(516)
                self.handle(addr, data)

    def handle(self, addr, data):
        """Start a new thread running ``conn(...).connect(addr, data)``."""
        if self.srcports:
            nextport = self.srcports.pop(0)
            self.srcports.append(nextport)
            T = self.conn(nextport)
        else:
            T = self.conn()
        thread.start_new_thread(T.connect, (addr, data))


class FileTFTP(TFTPConnection):
    """TFTP connection that serves and stores files below ``root``."""

    # Directory requests are resolved against; paths that would leave it
    # (e.g. via "..") are refused because filenames come from the network.
    root = os.curdir

    def resolvePath(self, filename):
        """Map a client-supplied *filename* to a path inside ``root``.

        Leading slashes are stripped, so absolute requests are looked up
        relative to ``root``.  Raises TFTPError(2, "Access violation") if the
        result still lies outside ``root``.
        """
        root = os.path.abspath(self.root)
        relative = filename.lstrip("/" + os.sep)
        path = os.path.abspath(os.path.join(root, relative))
        prefix = root if root.endswith(os.sep) else root + os.sep
        if path != root and not path.startswith(prefix):
            raise TFTPError(2, "Access violation")
        return path

    def readRequest(self, filename, mode):
        print("Sending file: %s" % (filename,))
        return open(self.resolvePath(filename), 'rb')

    def writeRequest(self, filename, mode):
        print("Receiving file: %s" % (filename,))
        return open(self.resolvePath(filename), 'wb')


if __name__ == "__main__":
    from io import BytesIO

    #
    # Subclass to create our own TFTP Connection object
    #
    class TestTFTP(TFTPConnection):
        def readRequest(self, filename, mode):
            randomstring = b"Here is a sample string"
            return BytesIO(randomstring)

        def writeRequest(self, filename, mode):
            return BytesIO()
    # end class

    # An explicit port on the command line replaces the default ports.
    if sys.argv[1:]:
        ports = [int(sys.argv[1])]
    else:
        ports = [TFTP_PORT, TFTP_PORT_MVP]
    try:
        serv = TFTPServer(conn=FileTFTP)
        for port in ports:
            serv.bind("", port)
        serv.forever()
    except (KeyboardInterrupt, SystemExit):
        pass