summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Demo/rpc/rpc.py281
1 files changed, 198 insertions, 83 deletions
diff --git a/Demo/rpc/rpc.py b/Demo/rpc/rpc.py
index d1c2c5e..00397dd 100644
--- a/Demo/rpc/rpc.py
+++ b/Demo/rpc/rpc.py
@@ -1,4 +1,4 @@
-# Implement (a subset of) Sun RPC, version 2 -- RFC1057.
+# Sun RPC version 2 -- RFC1057.
# XXX There should be separate exceptions for the various reasons why
# XXX an RPC can fail, rather than using RuntimeError for everything
@@ -177,8 +177,8 @@ class Client:
self.port = port
self.makesocket() # Assigns to self.sock
self.bindsocket()
- self.sock.connect((host, port))
- self.lastxid = 0
+ self.connsocket()
+ self.lastxid = 0 # XXX should be more random?
self.addpackers()
self.cred = None
self.verf = None
@@ -191,6 +191,10 @@ class Client:
# This MUST be overridden
raise RuntimeError, 'makesocket not defined'
+ def connsocket(self):
+ # Override this if you don't want/need a connection
+ self.sock.connect((self.host, self.port))
+
def bindsocket(self):
# Override this to bind to a different port (e.g. reserved)
self.sock.bind(('', 0))
@@ -200,6 +204,21 @@ class Client:
self.packer = Packer().init()
self.unpacker = Unpacker().init('')
+ def make_call(self, proc, args, pack_func, unpack_func):
+ # Don't normally override this (but see Broadcast)
+ if pack_func is None and args is not None:
+ raise TypeError, 'non-null args with null pack_func'
+ self.start_call(proc)
+ if pack_func:
+ pack_func(args)
+ self.do_call()
+ if unpack_func:
+ result = unpack_func()
+ else:
+ result = None
+ self.unpacker.done()
+ return result
+
def start_call(self, proc):
# Don't override this
self.lastxid = xid = self.lastxid + 1
@@ -209,14 +228,10 @@ class Client:
p.reset()
p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
- def do_call(self, *rest):
+ def do_call(self):
# This MUST be overridden
raise RuntimeError, 'do_call not defined'
- def end_call(self):
- # Don't override this
- self.unpacker.done()
-
def mkcred(self):
# Override this to use more powerful credentials
if self.cred == None:
@@ -230,9 +245,7 @@ class Client:
return self.verf
def Null(self): # Procedure 0 is always like this
- self.start_call(0)
- self.do_call(0)
- self.end_call()
+ return self.make_call(0, None, None, None)
# Record-Marking standard support
@@ -293,23 +306,14 @@ def bindresvport(sock, host):
raise RuntimeError, 'can\'t assign reserved port'
-# Raw TCP-based client
+# Client using TCP to a specific port
class RawTCPClient(Client):
def makesocket(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- def start_call(self, proc):
- self.lastxid = xid = self.lastxid + 1
- cred = self.mkcred()
- verf = self.mkverf()
- p = self.packer
- p.reset()
- p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
- def do_call(self, *rest):
- # rest is used for UDP buffer size; ignored for TCP
+ def do_call(self):
call = self.packer.get_buf()
sendrecord(self.sock, call)
reply = recvrecord(self.sock)
@@ -321,41 +325,25 @@ class RawTCPClient(Client):
raise RuntimeError, 'wrong xid in reply ' + `xid` + \
' instead of ' + `self.lastxid`
- def end_call(self):
- self.unpacker.done()
-
-# Raw UDP-based client
+# Client using UDP to a specific port
class RawUDPClient(Client):
def makesocket(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- def start_call(self, proc):
- self.lastxid = xid = self.lastxid + 1
- cred = self.mkcred()
- verf = self.mkverf()
- p = self.packer
- p.reset()
- p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
- def do_call(self, *rest):
+ def do_call(self):
+ call = self.packer.get_buf()
+ self.sock.send(call)
try:
from select import select
except ImportError:
print 'WARNING: select not found, RPC may hang'
select = None
- if len(rest) == 0:
- bufsize = 8192
- elif len(rest) > 1:
- raise TypeError, 'too many args'
- else:
- bufsize = rest[0] + 512
- call = self.packer.get_buf()
+ BUFSIZE = 8192 # Max UDP buffer size
timeout = 1
count = 5
- self.sock.send(call)
while 1:
r, w, x = [self.sock], [], []
if select:
@@ -367,7 +355,7 @@ class RawUDPClient(Client):
## print 'RESEND', timeout, count
self.sock.send(call)
continue
- reply = self.sock.recv(bufsize)
+ reply = self.sock.recv(BUFSIZE)
u = self.unpacker
u.reset(reply)
xid, verf = u.unpack_replyheader()
@@ -376,13 +364,70 @@ class RawUDPClient(Client):
continue
break
- def end_call(self):
- self.unpacker.done()
+# Client using UDP broadcast to a specific port
-# Port mapper interface
+class RawBroadcastUDPClient(RawUDPClient):
+
+ def init(self, bcastaddr, prog, vers, port):
+ self = RawUDPClient.init(self, bcastaddr, prog, vers, port)
+ self.reply_handler = None
+ self.timeout = 30
+ return self
+
+ def connsocket(self):
+ # Don't connect -- use sendto
+ self.sock.allowbroadcast(1)
+
+ def set_reply_handler(self, reply_handler):
+ self.reply_handler = reply_handler
-# XXX CALLIT is not implemented
+ def set_timeout(self, timeout):
+ self.timeout = timeout # Use None for infinite timeout
+
+ def make_call(self, proc, args, pack_func, unpack_func):
+ if pack_func is None and args is not None:
+ raise TypeError, 'non-null args with null pack_func'
+ self.start_call(proc)
+ if pack_func:
+ pack_func(args)
+ call = self.packer.get_buf()
+ self.sock.sendto(call, (self.host, self.port))
+ try:
+ from select import select
+ except ImportError:
+ print 'WARNING: select not found, broadcast will hang'
+ select = None
+ BUFSIZE = 8192 # Max UDP buffer size (for reply)
+ replies = []
+ if unpack_func is None:
+ def dummy(): pass
+ unpack_func = dummy
+ while 1:
+ r, w, x = [self.sock], [], []
+ if select:
+ if self.timeout is None:
+ r, w, x = select(r, w, x)
+ else:
+ r, w, x = select(r, w, x, self.timeout)
+ if self.sock not in r:
+ break
+ reply, fromaddr = self.sock.recvfrom(BUFSIZE)
+ u = self.unpacker
+ u.reset(reply)
+ xid, verf = u.unpack_replyheader()
+ if xid <> self.lastxid:
+## print 'BAD xid'
+ continue
+ reply = unpack_func()
+ self.unpacker.done()
+ replies.append((reply, fromaddr))
+ if self.reply_handler:
+ self.reply_handler(reply, fromaddr)
+ return replies
+
+
+# Port mapper interface
# Program number, version and (fixed!) port number
PMAP_PROG = 100000
@@ -421,6 +466,13 @@ class PortMapperPacker(Packer):
def pack_pmaplist(self, list):
self.pack_list(list, self.pack_mapping)
+ def pack_call_args(self, ca):
+ prog, vers, proc, args = ca
+ self.pack_uint(prog)
+ self.pack_uint(vers)
+ self.pack_uint(proc)
+ self.pack_opaque(args)
+
class PortMapperUnpacker(Unpacker):
@@ -434,6 +486,11 @@ class PortMapperUnpacker(Unpacker):
def unpack_pmaplist(self):
return self.unpack_list(self.unpack_mapping)
+ def unpack_call_result(self):
+ port = self.unpack_uint()
+ res = self.unpack_opaque()
+ return port, res
+
class PartialPortMapperClient:
@@ -442,35 +499,29 @@ class PartialPortMapperClient:
self.unpacker = PortMapperUnpacker().init('')
def Set(self, mapping):
- self.start_call(PMAPPROC_SET)
- self.packer.pack_mapping(mapping)
- self.do_call()
- res = self.unpacker.unpack_uint()
- self.end_call()
- return res
+ return self.make_call(PMAPPROC_SET, mapping, \
+ self.packer.pack_mapping, \
+ self.unpacker.unpack_uint)
def Unset(self, mapping):
- self.start_call(PMAPPROC_UNSET)
- self.packer.pack_mapping(mapping)
- self.do_call()
- res = self.unpacker.unpack_uint()
- self.end_call()
- return res
+ return self.make_call(PMAPPROC_UNSET, mapping, \
+ self.packer.pack_mapping, \
+ self.unpacker.unpack_uint)
def Getport(self, mapping):
- self.start_call(PMAPPROC_GETPORT)
- self.packer.pack_mapping(mapping)
- self.do_call(4)
- port = self.unpacker.unpack_uint()
- self.end_call()
- return port
+ return self.make_call(PMAPPROC_GETPORT, mapping, \
+ self.packer.pack_mapping, \
+ self.unpacker.unpack_uint)
def Dump(self):
- self.start_call(PMAPPROC_DUMP)
- self.do_call(8192-512)
- list = self.unpacker.unpack_pmaplist()
- self.end_call()
- return list
+ return self.make_call(PMAPPROC_DUMP, None, \
+ None, \
+ self.unpacker.unpack_pmaplist)
+
+ def Callit(self, ca):
+ return self.make_call(PMAPPROC_CALLIT, ca, \
+ self.packer.pack_call_args, \
+ self.unpacker.unpack_call_result)
class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
@@ -487,6 +538,16 @@ class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient):
host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
+class BroadcastUDPPortMapperClient(PartialPortMapperClient, \
+ RawBroadcastUDPClient):
+
+ def init(self, bcastaddr):
+ return RawBroadcastUDPClient.init(self, \
+ bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT)
+
+
+# Generic clients that find their server through the Port mapper
+
class TCPClient(RawTCPClient):
def init(self, host, prog, vers):
@@ -509,6 +570,51 @@ class UDPClient(RawUDPClient):
return RawUDPClient.init(self, host, prog, vers, port)
+class BroadcastUDPClient(Client):
+
+ def init(self, bcastaddr, prog, vers):
+ self.pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+ self.pmap.set_reply_handler(self.my_reply_handler)
+ self.prog = prog
+ self.vers = vers
+ self.user_reply_handler = None
+ self.addpackers()
+ return self
+
+ def close(self):
+ self.pmap.close()
+
+ def set_reply_handler(self, reply_handler):
+ self.user_reply_handler = reply_handler
+
+ def set_timeout(self, timeout):
+ self.pmap.set_timeout(timeout)
+
+ def my_reply_handler(self, reply, fromaddr):
+ port, res = reply
+ self.unpacker.reset(res)
+ result = self.unpack_func()
+ self.unpacker.done()
+ self.replies.append((result, fromaddr))
+ if self.user_reply_handler is not None:
+ self.user_reply_handler(result, fromaddr)
+
+ def make_call(self, proc, args, pack_func, unpack_func):
+ self.packer.reset()
+ if pack_func:
+ pack_func(args)
+ if unpack_func is None:
+ def dummy(): pass
+ self.unpack_func = dummy
+ else:
+ self.unpack_func = unpack_func
+ self.replies = []
+ packed_args = self.packer.get_buf()
+ dummy_replies = self.pmap.Callit( \
+ (self.prog, self.vers, proc, packed_args))
+ return self.replies
+
+
# Server classes
# These are not symmetric to the Client classes
@@ -657,14 +763,9 @@ class UDPServer(Server):
# Simple test program -- dump local portmapper status
def test():
- import T
- T.TSTART()
pmap = UDPPortMapperClient().init('')
- T.TSTOP()
pmap.Null()
- T.TSTOP()
list = pmap.Dump()
- T.TSTOP()
list.sort()
for prog, vers, prot, port in list:
print prog, vers,
@@ -674,7 +775,24 @@ def test():
print port
-# Server and client test program.
+# Test program for broadcast operation -- dump everybody's portmapper status
+
+def testbcast():
+ import sys
+ if sys.argv[1:]:
+ bcastaddr = sys.argv[1]
+ else:
+ bcastaddr = '<broadcast>'
+ def rh(reply, fromaddr):
+ host, port = fromaddr
+ print host + '\t' + `reply`
+ pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+ pmap.set_reply_handler(rh)
+ pmap.set_timeout(5)
+ replies = pmap.Getport((100002, 1, IPPROTO_UDP, 0))
+
+
+# Test program for server, with corresponding client
# On machine A: python -c 'import rpc; rpc.testsvr()'
# On machine B: python -c 'import rpc; rpc.testclt()' A
# (A may be == B)
@@ -709,12 +827,9 @@ def testclt():
# Client for above server
class C(UDPClient):
def call_1(self, arg):
- self.start_call(1)
- self.packer.pack_string(arg)
- self.do_call()
- reply = self.unpacker.unpack_string()
- self.end_call()
- return reply
+ return self.make_call(1, arg, \
+ self.packer.pack_string, \
+ self.unpacker.unpack_string)
c = C().init(host, 0x20000000, 1)
print 'making call...'
reply = c.call_1('hello, world, ')