diff options
-rw-r--r-- | Demo/rpc/rpc.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/Demo/rpc/rpc.py b/Demo/rpc/rpc.py index 71a2c1f..4c790f0 100644 --- a/Demo/rpc/rpc.py +++ b/Demo/rpc/rpc.py @@ -75,6 +75,11 @@ class Packer(xdr.Packer): # Caller must add procedure-specific part of reply +# Exceptions +BadRPCFormat = 'rpc.BadRPCFormat' +BadRPCVersion = 'rpc.BadRPCVersion' +GarbageArgs = 'rpc.GarbageArgs' + class Unpacker(xdr.Unpacker): def unpack_auth(self): @@ -82,6 +87,22 @@ class Unpacker(xdr.Unpacker): stuff = self.unpack_opaque() return (flavor, stuff) + def unpack_callheader(self): + xid = self.unpack_uint(xid) + temp = self.unpack_enum() + if temp <> CALL: + raise BadRPCFormat, 'no CALL but ' + `temp` + temp = self.unpack_uint() + if temp <> RPCVERSION: + raise BadRPCVerspion, 'bad RPC version ' + `temp` + prog = self.unpack_uint() + vers = self.unpack_uint() + proc = self.unpack_uint() + cred = self.unpack_auth() + verf = self.unpack_auth() + return xid, prog, vers, proc, cred, verf + # Caller must add procedure-specific part of call + def unpack_replyheader(self): xid = self.unpack_uint() mtype = self.unpack_enum() @@ -105,11 +126,17 @@ class Unpacker(xdr.Unpacker): 'Neither MSG_DENIED nor MSG_ACCEPTED: ' + `stat` verf = self.unpack_auth() stat = self.unpack_enum() + if stat == PROG_UNAVAIL: + raise RuntimeError, 'call failed: PROG_UNAVAIL' if stat == PROG_MISMATCH: low = self.unpack_uint() high = self.unpack_uint() raise RuntimeError, \ 'call failed: PROG_MISMATCH: ' + `low, high` + if stat == PROC_UNAVAIL: + raise RuntimeError, 'call failed: PROC_UNAVAIL' + if stat == GARBAGE_ARGS: + raise RuntimeError, 'call failed: GARBAGE_ARGS' if stat <> SUCCESS: raise RuntimeError, 'call failed: ' + `stat` return xid, verf @@ -193,6 +220,8 @@ def sendrecord(sock, record): def recvfrag(sock): header = sock.recv(4) + if len(header) < 4: + raise EOFError x = long(ord(header[0]))<<24 | ord(header[1])<<16 | \ ord(header[2])<<8 | ord(header[3]) last = ((x & 0x80000000) != 0) @@ -359,6 +388,22 @@ class PartialPortMapperClient: self.packer = PortMapperPacker().init() self.unpacker = PortMapperUnpacker().init('') + def Set(self, mapping): + self.start_call(PMAPPROC_SET) + self.packer.pack_mapping(mapping) + self.do_call() + res = self.unpacker.unpack_uint() + self.end_call() + return res + + def Unset(self, mapping): + self.start_call(PMAPPROC_UNSET) + self.packer.pack_mapping(mapping) + self.do_call() + res = self.unpacker.unpack_uint() + self.end_call() + return res + def Getport(self, mapping): self.start_call(PMAPPROC_GETPORT) self.packer.pack_mapping(mapping) @@ -394,6 +439,8 @@ class TCPClient(RawTCPClient): def init(self, host, prog, vers): pmap = TCPPortMapperClient().init(host) port = pmap.Getport((prog, vers, IPPROTO_TCP, 0)) + if port == 0: + raise RuntimeError, 'program not registered' pmap.close() return RawTCPClient.init(self, host, prog, vers, port) @@ -407,6 +454,156 @@ class UDPClient(RawUDPClient): return RawUDPClient.init(self, host, prog, vers, port) +# Server classes + +class Server: + + def init(self, host, prog, vers, port, type): + self.host = host # Should normally be '' for default interface + self.prog = prog + self.vers = vers + self.port = port # Should normally be 0 for random port + self.type = type # SOCK_STREAM or SOCK_DGRAM + self.sock = socket.socket(socket.AF_INET, type) + self.sock.bind((host, port)) + self.host, self.port = self.sock.getsockname() + self.addpackers() + return self + + def register(self): + if self.type == socket.SOCK_STREAM: + type = IPPROTO_TCP + elif self.type == socket.SOCK_DGRAM: + type = IPPROTO_UDP + else: + raise ValueError, 'unknown protocol type' + mapping = self.prog, self.vers, type, self.port + p = TCPPortMapperClient().init(self.host) + if not p.Set(mapping): + raise RuntimeError, 'register failed' + + def unregister(self): + if self.type == socket.SOCK_STREAM: + type = IPPROTO_TCP + elif self.type == socket.SOCK_DGRAM: + type = IPPROTO_UDP + else: + raise ValueError, 'unknown protocol type' + mapping = self.prog, self.vers, type, self.port + p = TCPPortMapperClient().init(self.host) + if not p.Unset(mapping): + raise RuntimeError, 'unregister failed' + + def handle(self, call): + # Don't use unpack_header but parse the header piecewise + # XXX I have no idea if I am using the right error responses! + self.unpacker.reset(call) + self.packer.reset() + xid = self.unpacker.unpack_uint() + self.packer.pack_uint(xid) + temp = self.unpacker.unpack_enum() + if temp <> CALL: + return None # Not worthy of a reply + self.packer.pack_uint(REPLY) + temp = self.unpacker.unpack_uint() + if temp <> RPCVERSION: + self.packer.pack_uint(MSG_DENIED) + self.packer.pack_uint(RPC_MISMATCH) + self.packer.pack_uint(RPCVERSION) + self.packer.pack_uint(RPCVERSION) + return self.packer.get_buf() + self.packer.pack_uint(MSG_ACCEPTED) + self.packer.pack_auth((AUTH_NULL, make_auth_null())) + prog = self.unpacker.unpack_uint() + if prog <> self.prog: + self.packer.pack_uint(PROG_UNAVAIL) + return self.packer.get_buf() + vers = self.unpacker.unpack_uint() + if vers <> self.vers: + self.packer.pack_uint(PROG_MISMATCH) + self.packer.pack_uint(self.vers) + self.packer.pack_uint(self.vers) + return self.packer.get_buf() + proc = self.unpacker.unpack_uint() + methname = 'handle_' + `proc` + try: + meth = getattr(self, methname) + except AttributeError: + self.packer.pack_uint(PROC_UNAVAIL) + return self.packer.get_buf() + cred = self.unpacker.unpack_auth() + verf = self.unpacker.unpack_auth() + try: + meth() # Unpack args, call turn_around(), pack reply + except (EOFError, GarbageArgs): + # Too few or too many arguments + self.packer.reset() + self.packer.pack_uint(xid) + self.packer.pack_uint(REPLY) + self.packer.pack_uint(MSG_ACCEPTED) + self.packer.pack_auth((AUTH_NULL, make_auth_null())) + self.packer.pack_uint(GARBAGE_ARGS) + return self.packer.get_buf() + + def turn_around(self): + try: + self.unpacker.done() + except RuntimeError: + raise GarbageArgs + self.packer.pack_uint(SUCCESS) + + def handle_0(self): # Handle NULL message + self.turn_around() + + # Functions that may be overridden by specific derived classes + + def addpackers(self): + self.packer = Packer().init() + self.unpacker = Unpacker().init('') + + +class TCPServer(Server): + + def init(self, host, prog, vers, port): + return Server.init(self, host, prog, vers, port, \ + socket.SOCK_STREAM) + + def loop(self): + self.sock.listen(0) + while 1: + self.session(self.sock.accept()) + + def session(self, connection): + sock, (host, port) = connection + while 1: + try: + call = recvrecord(sock) + except EOFError: + break + reply = self.handle(call) + if reply <> None: + sendrecord(sock, reply) + + +class UDPServer(Server): + + def init(self, host, prog, vers, port): + return Server.init(self, host, prog, vers, port, \ + socket.SOCK_DGRAM) + + def loop(self): + while 1: + session() + + def session(self): + call, host_port = self.sock.recvfrom(8192) + reply = self.handle(call) + if reply <> None: + self.sock.sendto(reply, host_port) + + +# Simple test program -- dump local portmapper status + def test(): import T T.TSTART() @@ -423,3 +620,50 @@ def test(): elif prot == IPPROTO_UDP: print 'udp', else: print prot, print port + + +# Server and client test program. +# On machine A: python -c 'import rpc; rpc.testsvr()' +# On machine B: python -c 'import rpc; rpc.testclt()' A +# (A may be == B) + +def testsvr(): + # Simple test class -- proc 1 doubles its string argument as reply + class S(TCPServer): + def handle_1(self): + arg = self.unpacker.unpack_string() + self.turn_around() + print 'RPC function 1 called, arg', `arg` + self.packer.pack_string(arg + arg) + # + s = S().init('', 0x20000000, 1, 0) + try: + s.unregister() + except RuntimeError, msg: + print 'RuntimeError:', msg, '(ignored)' + s.register() + print 'Service started...' + try: + s.loop() + finally: + s.unregister() + print 'Service interrupted.' + + +def testclt(): + import sys + if sys.argv[1:]: host = sys.argv[1] + else: host = '' + # Client for above server + class C(TCPClient): + def call_1(self, arg): + self.start_call(1) + self.packer.pack_string(arg) + self.do_call() + reply = self.unpacker.unpack_string() + self.end_call() + return reply + c = C().init(host, 0x20000000, 1) + print 'making call...' + reply = c.call_1('hello, world, ') + print 'call returned', `reply` |