diff options
Diffstat (limited to 'Modules/_ssl.c')
| -rw-r--r-- | Modules/_ssl.c | 477 | 
1 files changed, 386 insertions, 91 deletions
diff --git a/Modules/_ssl.c b/Modules/_ssl.c index d0f7115..f5b9199 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -1,4 +1,4 @@ -/* SSL socket module  +/* SSL socket module     SSL support based on patches by Brian E Gallew and Laszlo Kovacs. @@ -8,25 +8,44 @@  */  #include "Python.h" +  enum py_ssl_error {  	/* these mirror ssl.h */ -	PY_SSL_ERROR_NONE,                  -	PY_SSL_ERROR_SSL,                    -	PY_SSL_ERROR_WANT_READ,              -	PY_SSL_ERROR_WANT_WRITE,             -	PY_SSL_ERROR_WANT_X509_LOOKUP,       +	PY_SSL_ERROR_NONE, +	PY_SSL_ERROR_SSL, +	PY_SSL_ERROR_WANT_READ, +	PY_SSL_ERROR_WANT_WRITE, +	PY_SSL_ERROR_WANT_X509_LOOKUP,  	PY_SSL_ERROR_SYSCALL,     /* look at error stack/return value/errno */ -	PY_SSL_ERROR_ZERO_RETURN,            +	PY_SSL_ERROR_ZERO_RETURN,  	PY_SSL_ERROR_WANT_CONNECT, -	/* start of non ssl.h errorcodes */  +	/* start of non ssl.h errorcodes */  	PY_SSL_ERROR_EOF,         /* special case of SSL_ERROR_SYSCALL */  	PY_SSL_ERROR_INVALID_ERROR_CODE  }; +enum py_ssl_server_or_client { +	PY_SSL_CLIENT, +	PY_SSL_SERVER +}; + +enum py_ssl_cert_requirements { +	PY_SSL_CERT_NONE, +	PY_SSL_CERT_OPTIONAL, +	PY_SSL_CERT_REQUIRED +}; + +enum py_ssl_version { +	PY_SSL_VERSION_SSL2, +	PY_SSL_VERSION_SSL3, +	PY_SSL_VERSION_SSL23, +	PY_SSL_VERSION_TLS1, +}; +  /* Include symbols from _socket module */  #include "socketmodule.h" -#if defined(HAVE_POLL_H)  +#if defined(HAVE_POLL_H)  #include <poll.h>  #elif defined(HAVE_SYS_POLL_H)  #include <sys/poll.h> @@ -58,10 +77,10 @@ static PyObject *PySSLErrorObject;  typedef struct {  	PyObject_HEAD  	PySocketSockObject *Socket;	/* Socket on which we're layered */ -	SSL_CTX* 	ctx; -	SSL*     	ssl; -	X509*    	server_cert; -	char    	server[X509_NAME_MAXLEN]; +	SSL_CTX*	ctx; +	SSL*		ssl; +	X509*		peer_cert; +	char		server[X509_NAME_MAXLEN];  	char		issuer[X509_NAME_MAXLEN];  } PySSLObject; @@ -69,8 +88,10 @@ typedef struct {  static PyTypeObject PySSL_Type;  static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args);  static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args); -static int check_socket_and_wait_for_timeout(PySocketSockObject *s,  +static int check_socket_and_wait_for_timeout(PySocketSockObject *s,  					     int writing); +static PyObject *PySSL_peercert(PySSLObject *self); +  #define PySSLObject_Check(v)	(Py_Type(v) == &PySSL_Type) @@ -83,21 +104,27 @@ typedef enum {  	SOCKET_OPERATION_OK  } timeout_state; +/* Wrap error strings with filename and line # */ +#define STRINGIFY1(x) #x +#define STRINGIFY2(x) STRINGIFY1(x) +#define ERRSTR1(x,y,z) (x ":" y ": " z) +#define ERRSTR(x) ERRSTR1("_ssl.c", STRINGIFY2(__LINE__), x) +  /* XXX It might be helpful to augment the error message generated     below with the name of the SSL function that generated the error.     I expect it's obvious most of the time.  */  static PyObject * -PySSL_SetError(PySSLObject *obj, int ret) +PySSL_SetError(PySSLObject *obj, int ret, char *filename, int lineno)  { -	PyObject *v, *n, *s; +	PyObject *v;  	char *errstr;  	int err;  	enum py_ssl_error p;  	assert(ret <= 0); -     +  	err = SSL_get_error(obj->ssl, ret);  	switch (err) { @@ -141,12 +168,12 @@ PySSL_SetError(PySSLObject *obj, int ret)  			errstr = ERR_error_string(e, NULL);  		}  		break; -	}    +	}  	case SSL_ERROR_SSL:  	{  		unsigned long e = ERR_get_error();  		p = PY_SSL_ERROR_SSL; -		if (e != 0)  +		if (e != 0)  			/* XXX Protected by global interpreter lock */  			errstr = ERR_error_string(e, NULL);  		else { /* possible? */ @@ -158,29 +185,23 @@ PySSL_SetError(PySSLObject *obj, int ret)  		p = PY_SSL_ERROR_INVALID_ERROR_CODE;  		errstr = "Invalid error code";  	} -	n = PyInt_FromLong((long) p); -	if (n == NULL) -		return NULL; -	v = PyTuple_New(2); -	if (v == NULL) { -		Py_DECREF(n); -		return NULL; -	} -	s = PyString_FromString(errstr); -	if (s == NULL) { +	char buf[2048]; +	PyOS_snprintf(buf, sizeof(buf), "_ssl.c:%d: %s", lineno, errstr); +	v = Py_BuildValue("(is)", p, buf); +	if (v != NULL) { +		PyErr_SetObject(PySSLErrorObject, v);  		Py_DECREF(v); -		Py_DECREF(n);  	} -	PyTuple_SET_ITEM(v, 0, n); -	PyTuple_SET_ITEM(v, 1, s); -	PyErr_SetObject(PySSLErrorObject, v); -	Py_DECREF(v);  	return NULL;  }  static PySSLObject * -newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file) +newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file, +	       enum py_ssl_server_or_client socket_type, +	       enum py_ssl_cert_requirements certreq, +	       enum py_ssl_version proto_version, +	       char *cacerts_file)  {  	PySSLObject *self;  	char *errstr = NULL; @@ -193,31 +214,60 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  		return NULL;  	memset(self->server, '\0', sizeof(char) * X509_NAME_MAXLEN);  	memset(self->issuer, '\0', sizeof(char) * X509_NAME_MAXLEN); -	self->server_cert = NULL; +	self->peer_cert = NULL;  	self->ssl = NULL;  	self->ctx = NULL;  	self->Socket = NULL;  	if ((key_file && !cert_file) || (!key_file && cert_file)) { -		errstr = "Both the key & certificate files must be specified"; +		errstr = ERRSTR("Both the key & certificate files must be specified"); +		goto fail; +	} + +	if ((socket_type == PY_SSL_SERVER) && +	    ((key_file == NULL) || (cert_file == NULL))) { +		errstr = ERRSTR("Both the key & certificate files must be specified for server-side operation");  		goto fail;  	}  	Py_BEGIN_ALLOW_THREADS -	self->ctx = SSL_CTX_new(SSLv23_method()); /* Set up context */ +	if (proto_version == PY_SSL_VERSION_TLS1) +		self->ctx = SSL_CTX_new(TLSv1_method()); /* Set up context */ +	else if (proto_version == PY_SSL_VERSION_SSL3) +		self->ctx = SSL_CTX_new(SSLv3_method()); /* Set up context */ +	else if (proto_version == PY_SSL_VERSION_SSL2) +		self->ctx = SSL_CTX_new(SSLv2_method()); /* Set up context */ +	else +		self->ctx = SSL_CTX_new(SSLv23_method()); /* Set up context */  	Py_END_ALLOW_THREADS +  	if (self->ctx == NULL) { -		errstr = "SSL_CTX_new error"; +		errstr = ERRSTR("Invalid SSL protocol variant specified.");  		goto fail;  	} +	if (certreq != PY_SSL_CERT_NONE) { +		if (cacerts_file == NULL) { +			errstr = ERRSTR("No root certificates specified for verification of other-side certificates."); +			goto fail; +		} else { +			Py_BEGIN_ALLOW_THREADS +			ret = SSL_CTX_load_verify_locations(self->ctx, +							    cacerts_file, NULL); +			Py_END_ALLOW_THREADS +			if (ret < 1) { +				errstr = ERRSTR("SSL_CTX_load_verify_locations"); +				goto fail; +			} +		} +	}  	if (key_file) {  		Py_BEGIN_ALLOW_THREADS  		ret = SSL_CTX_use_PrivateKey_file(self->ctx, key_file, -						SSL_FILETYPE_PEM); +						  SSL_FILETYPE_PEM);  		Py_END_ALLOW_THREADS  		if (ret < 1) { -			errstr = "SSL_CTX_use_PrivateKey_file error"; +			errstr = ERRSTR("SSL_CTX_use_PrivateKey_file error");  			goto fail;  		} @@ -225,16 +275,23 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  		ret = SSL_CTX_use_certificate_chain_file(self->ctx,  						       cert_file);  		Py_END_ALLOW_THREADS -		SSL_CTX_set_options(self->ctx, SSL_OP_ALL); /* ssl compatibility */  		if (ret < 1) { -			errstr = "SSL_CTX_use_certificate_chain_file error"; +			errstr = ERRSTR("SSL_CTX_use_certificate_chain_file error") ;  			goto fail;  		} +		SSL_CTX_set_options(self->ctx, SSL_OP_ALL); /* ssl compatibility */  	} +	int verification_mode = SSL_VERIFY_NONE; +	if (certreq == PY_SSL_CERT_OPTIONAL) +		verification_mode = SSL_VERIFY_PEER; +	else if (certreq == PY_SSL_CERT_REQUIRED) +		verification_mode = (SSL_VERIFY_PEER | +				     SSL_VERIFY_FAIL_IF_NO_PEER_CERT); +	SSL_CTX_set_verify(self->ctx, verification_mode, +			   NULL); /* set verify lvl */ +  	Py_BEGIN_ALLOW_THREADS -	SSL_CTX_set_verify(self->ctx, -			   SSL_VERIFY_NONE, NULL); /* set verify lvl */  	self->ssl = SSL_new(self->ctx); /* New ssl struct */  	Py_END_ALLOW_THREADS  	SSL_set_fd(self->ssl, Sock->sock_fd);	/* Set the socket for SSL */ @@ -249,7 +306,10 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  	}  	Py_BEGIN_ALLOW_THREADS -	SSL_set_connect_state(self->ssl); +	if (socket_type == PY_SSL_CLIENT) +		SSL_set_connect_state(self->ssl); +	else +		SSL_set_accept_state(self->ssl);  	Py_END_ALLOW_THREADS  	/* Actually negotiate SSL connection */ @@ -257,11 +317,14 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  	sockstate = 0;  	do {  		Py_BEGIN_ALLOW_THREADS -		ret = SSL_connect(self->ssl); +		if (socket_type == PY_SSL_CLIENT) +			ret = SSL_connect(self->ssl); +		else +			ret = SSL_accept(self->ssl);  		err = SSL_get_error(self->ssl, ret);  		Py_END_ALLOW_THREADS  		if(PyErr_CheckSignals()) { -                        goto fail; +			goto fail;  		}  		if (err == SSL_ERROR_WANT_READ) {  			sockstate = check_socket_and_wait_for_timeout(Sock, 0); @@ -270,30 +333,33 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  		} else {  			sockstate = SOCKET_OPERATION_OK;  		} -	        if (sockstate == SOCKET_HAS_TIMED_OUT) { -			PyErr_SetString(PySSLErrorObject, "The connect operation timed out"); +		if (sockstate == SOCKET_HAS_TIMED_OUT) { +			PyErr_SetString(PySSLErrorObject, +				ERRSTR("The connect operation timed out"));  			goto fail;  		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { -			PyErr_SetString(PySSLErrorObject, "Underlying socket has been closed."); +			PyErr_SetString(PySSLErrorObject, +				ERRSTR("Underlying socket has been closed."));  			goto fail;  		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { -			PyErr_SetString(PySSLErrorObject, "Underlying socket too large for select()."); +			PyErr_SetString(PySSLErrorObject, +			  ERRSTR("Underlying socket too large for select()."));  			goto fail;  		} else if (sockstate == SOCKET_IS_NONBLOCKING) {  			break;  		}  	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); -	if (ret <= 0) { -		PySSL_SetError(self, ret); +	if (ret < 1) { +		PySSL_SetError(self, ret, __FILE__, __LINE__);  		goto fail;  	}  	self->ssl->debug = 1;  	Py_BEGIN_ALLOW_THREADS -	if ((self->server_cert = SSL_get_peer_certificate(self->ssl))) { -		X509_NAME_oneline(X509_get_subject_name(self->server_cert), +	if ((self->peer_cert = SSL_get_peer_certificate(self->ssl))) { +		X509_NAME_oneline(X509_get_subject_name(self->peer_cert),  				  self->server, X509_NAME_MAXLEN); -		X509_NAME_oneline(X509_get_issuer_name(self->server_cert), +		X509_NAME_oneline(X509_get_issuer_name(self->peer_cert),  				  self->issuer, X509_NAME_MAXLEN);  	}  	Py_END_ALLOW_THREADS @@ -310,25 +376,39 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file)  static PyObject *  PySocket_ssl(PyObject *self, PyObject *args)  { -	PySSLObject *rv;  	PySocketSockObject *Sock; +	int server_side = 0; +	int verification_mode = PY_SSL_CERT_NONE; +	int protocol = PY_SSL_VERSION_SSL23;  	char *key_file = NULL;  	char *cert_file = NULL; +	char *cacerts_file = NULL; -	if (!PyArg_ParseTuple(args, "O!|zz:ssl", +	if (!PyArg_ParseTuple(args, "O!i|zziiz:sslwrap",  			      PySocketModule.Sock_Type,  			      &Sock, -			      &key_file, &cert_file)) +			      &server_side, +			      &key_file, &cert_file, +			      &verification_mode, &protocol, +			      &cacerts_file))  		return NULL; -	rv = newPySSLObject(Sock, key_file, cert_file); -	if (rv == NULL) -		return NULL; -	return (PyObject *)rv; +	/* +	fprintf(stderr, +		"server_side is %d, keyfile %p, certfile %p, verify_mode %d, " +		"protocol %d, certs %p\n", +		server_side, key_file, cert_file, verification_mode, +		protocol, cacerts_file); +	 */ + +	return (PyObject *) newPySSLObject(Sock, key_file, cert_file, +					   server_side, verification_mode, +					   protocol, cacerts_file);  }  PyDoc_STRVAR(ssl_doc, -"ssl(socket, [keyfile, certfile]) -> sslobject"); +"sslwrap(socket, server_side, [keyfile, certfile, certs_mode, protocol,\n" +"                              cacertsfile]) -> sslobject");  /* SSL object methods */ @@ -344,15 +424,153 @@ PySSL_issuer(PySSLObject *self)  	return PyString_FromString(self->issuer);  } +static PyObject * +_create_dict_for_X509_NAME (X509_NAME *xname) +{ +	PyObject *pd = PyDict_New(); +	int index_counter; + +	for (index_counter = 0; +	     index_counter < X509_NAME_entry_count(xname); +	     index_counter++) +	{ +		char namebuf[X509_NAME_MAXLEN]; +		int buflen; + +		X509_NAME_ENTRY *entry = X509_NAME_get_entry(xname, +							     index_counter); + +		ASN1_OBJECT *name = X509_NAME_ENTRY_get_object(entry); +		buflen = OBJ_obj2txt(namebuf, sizeof(namebuf), name, 0); +		if (buflen < 0) +			goto fail0; +		PyObject *name_obj = PyString_FromStringAndSize(namebuf, +								buflen); +		if (name_obj == NULL) +			goto fail0; + +		ASN1_STRING *value = X509_NAME_ENTRY_get_data(entry); +		unsigned char *valuebuf = NULL; +		buflen = ASN1_STRING_to_UTF8(&valuebuf, value); +		if (buflen < 0) { +			Py_DECREF(name_obj); +			goto fail0; +		} +		PyObject *value_obj = PyUnicode_DecodeUTF8((char *) valuebuf, +							   buflen, "strict"); +		OPENSSL_free(valuebuf); +		if (value_obj == NULL) { +			Py_DECREF(name_obj); +			goto fail0; +		} +		if (PyDict_SetItem(pd, name_obj, value_obj) < 0) { +			Py_DECREF(name_obj); +			Py_DECREF(value_obj); +			goto fail0; +		} +		Py_DECREF(name_obj); +		Py_DECREF(value_obj); +	} +	return pd; + +  fail0: +	Py_XDECREF(pd); +	return NULL; +} + +static PyObject * +PySSL_peercert(PySSLObject *self) +{ +	PyObject *retval = NULL; +	BIO *biobuf = NULL; + +	if (!self->peer_cert) +		Py_RETURN_NONE; + +	retval = PyDict_New(); +	if (retval == NULL) +		return NULL; + +	int verification = SSL_CTX_get_verify_mode(self->ctx); +	if ((verification & SSL_VERIFY_PEER) == 0) +		return retval; + +	PyObject *peer = _create_dict_for_X509_NAME( +		X509_get_subject_name(self->peer_cert)); +	if (peer == NULL) +		goto fail0; +	if (PyDict_SetItemString(retval, (const char *) "subject", peer) < 0) { +		Py_DECREF(peer); +		goto fail0; +	} +	Py_DECREF(peer); + +	PyObject *issuer = _create_dict_for_X509_NAME( +		X509_get_issuer_name(self->peer_cert)); +	if (issuer == NULL) +		goto fail0; +	if (PyDict_SetItemString(retval, (const char *) "issuer", issuer) < 0) { +		Py_DECREF(issuer); +		goto fail0; +	} +	Py_DECREF(issuer); + +	PyObject *version = PyInt_FromLong(X509_get_version(self->peer_cert)); +	if (PyDict_SetItemString(retval, "version", version) < 0) { +		Py_DECREF(version); +		goto fail0; +	} +	Py_DECREF(version); + +	char buf[2048]; +	int len; + +	/* get a memory buffer */ +	biobuf = BIO_new(BIO_s_mem()); + +	ASN1_TIME *notBefore = X509_get_notBefore(self->peer_cert); +	ASN1_TIME_print(biobuf, notBefore); +	len = BIO_gets(biobuf, buf, sizeof(buf)-1); +	PyObject *pnotBefore = PyString_FromStringAndSize(buf, len); +	if (pnotBefore == NULL) +		goto fail1; +	if (PyDict_SetItemString(retval, "notBefore", pnotBefore) < 0) { +		Py_DECREF(pnotBefore); +		goto fail1; +	} +	Py_DECREF(pnotBefore); + +	BIO_reset(biobuf); +	ASN1_TIME *notAfter = X509_get_notAfter(self->peer_cert); +	ASN1_TIME_print(biobuf, notAfter); +	len = BIO_gets(biobuf, buf, sizeof(buf)-1); +	BIO_free(biobuf); +	PyObject *pnotAfter = PyString_FromStringAndSize(buf, len); +	if (pnotAfter == NULL) +		goto fail0; +	if (PyDict_SetItemString(retval, "notAfter", pnotAfter) < 0) { +		Py_DECREF(pnotAfter); +		goto fail0; +	} +	Py_DECREF(pnotAfter); +	return retval; + +  fail1: +	if (biobuf != NULL) +		BIO_free(biobuf); +  fail0: +	Py_XDECREF(retval); +	return NULL; +}  static void PySSL_dealloc(PySSLObject *self)  { -	if (self->server_cert)	/* Possible not to have one? */ -		X509_free (self->server_cert); +	if (self->peer_cert)	/* Possible not to have one? */ +		X509_free (self->peer_cert);  	if (self->ssl) -	    SSL_free(self->ssl); +		SSL_free(self->ssl);  	if (self->ctx) -	    SSL_CTX_free(self->ctx); +		SSL_CTX_free(self->ctx);  	Py_XDECREF(self->Socket);  	PyObject_Del(self);  } @@ -463,7 +681,7 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)  		} else {  			sockstate = SOCKET_OPERATION_OK;  		} -	        if (sockstate == SOCKET_HAS_TIMED_OUT) { +		if (sockstate == SOCKET_HAS_TIMED_OUT) {  			PyErr_SetString(PySSLErrorObject, "The write operation timed out");  			return NULL;  		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { @@ -476,7 +694,7 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)  	if (len > 0)  		return PyInt_FromLong(len);  	else -		return PySSL_SetError(self, len); +		return PySSL_SetError(self, len, __FILE__, __LINE__);  }  PyDoc_STRVAR(PySSL_SSLwrite_doc, @@ -498,7 +716,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)  	if (!(buf = PyString_FromStringAndSize((char *) 0, len)))  		return NULL; -	 +  	/* first check if there are bytes ready to be read */  	Py_BEGIN_ALLOW_THREADS  	count = SSL_pending(self->ssl); @@ -507,12 +725,28 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)  	if (!count) {  		sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);  		if (sockstate == SOCKET_HAS_TIMED_OUT) { -			PyErr_SetString(PySSLErrorObject, "The read operation timed out"); +			PyErr_SetString(PySSLErrorObject, +					"The read operation timed out");  			Py_DECREF(buf);  			return NULL;  		} else if (sockstate == SOCKET_TOO_LARGE_FOR_SELECT) { -			PyErr_SetString(PySSLErrorObject, "Underlying socket too large for select()."); +			PyErr_SetString(PySSLErrorObject, +				"Underlying socket too large for select()."); +			Py_DECREF(buf);  			return NULL; +		} else if (sockstate == SOCKET_HAS_BEEN_CLOSED) { +			if (SSL_get_shutdown(self->ssl) != +			    SSL_RECEIVED_SHUTDOWN) +			{ +				Py_DECREF(buf); +				PyErr_SetString(PySSLErrorObject, +				"Socket closed without SSL shutdown handshake"); +				return NULL; +			} else { +				/* should contain a zero-length string */ +				_PyString_Resize(&buf, 0); +				return buf; +			}  		}  	}  	do { @@ -526,23 +760,32 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)  			return NULL;  		}  		if (err == SSL_ERROR_WANT_READ) { -			sockstate = check_socket_and_wait_for_timeout(self->Socket, 0); +			sockstate =  +			  check_socket_and_wait_for_timeout(self->Socket, 0);  		} else if (err == SSL_ERROR_WANT_WRITE) { -			sockstate = check_socket_and_wait_for_timeout(self->Socket, 1); +			sockstate = +			  check_socket_and_wait_for_timeout(self->Socket, 1); +		} else if ((err == SSL_ERROR_ZERO_RETURN) && +			   (SSL_get_shutdown(self->ssl) == +			    SSL_RECEIVED_SHUTDOWN))  +		{ +			_PyString_Resize(&buf, 0); +			return buf;  		} else {  			sockstate = SOCKET_OPERATION_OK;  		} -	        if (sockstate == SOCKET_HAS_TIMED_OUT) { -			PyErr_SetString(PySSLErrorObject, "The read operation timed out"); +		if (sockstate == SOCKET_HAS_TIMED_OUT) { +			PyErr_SetString(PySSLErrorObject, +					"The read operation timed out");  			Py_DECREF(buf);  			return NULL;  		} else if (sockstate == SOCKET_IS_NONBLOCKING) {  			break;  		}  	} while (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE); - 	if (count <= 0) { +	if (count <= 0) {  		Py_DECREF(buf); -		return PySSL_SetError(self, count); +		return PySSL_SetError(self, count, __FILE__, __LINE__);  	}  	if (count != len)  		_PyString_Resize(&buf, count); @@ -554,13 +797,48 @@ PyDoc_STRVAR(PySSL_SSLread_doc,  \n\  Read up to len bytes from the SSL socket."); +static PyObject *PySSL_SSLshutdown(PySSLObject *self, PyObject *args) +{ +	int err; + +	/* Guard against closed socket */ +	if (self->Socket->sock_fd < 0) { +		PyErr_SetString(PySSLErrorObject, +				"Underlying socket has been closed."); +		return NULL; +	} + +	Py_BEGIN_ALLOW_THREADS +	err = SSL_shutdown(self->ssl); +	if (err == 0) { +		/* we need to call it again to finish the shutdown */ +		err = SSL_shutdown(self->ssl); +	} +	Py_END_ALLOW_THREADS + +	if (err < 0) +		return PySSL_SetError(self, err, __FILE__, __LINE__); +	else { +		Py_INCREF(self->Socket); +		return (PyObject *) (self->Socket); +	} +} + +PyDoc_STRVAR(PySSL_SSLshutdown_doc, +"shutdown(s) -> socket\n\ +\n\ +Does the SSL shutdown handshake with the remote end, and returns\n\ +the underlying socket object."); +  static PyMethodDef PySSLMethods[] = {  	{"write", (PyCFunction)PySSL_SSLwrite, METH_VARARGS, -	          PySSL_SSLwrite_doc}, +		  PySSL_SSLwrite_doc},  	{"read", (PyCFunction)PySSL_SSLread, METH_VARARGS, -	          PySSL_SSLread_doc}, +		  PySSL_SSLread_doc},  	{"server", (PyCFunction)PySSL_server, METH_NOARGS},  	{"issuer", (PyCFunction)PySSL_issuer, METH_NOARGS}, +	{"peer_certificate", (PyCFunction)PySSL_peercert, METH_NOARGS}, +	{"shutdown", (PyCFunction)PySSL_SSLshutdown, METH_NOARGS, PySSL_SSLshutdown_doc},  	{NULL, NULL}  }; @@ -654,17 +932,17 @@ if it does provide enough data to seed PRNG.");  /* List of functions exported by this module. */  static PyMethodDef PySSL_methods[] = { -	{"ssl",			PySocket_ssl, -	 METH_VARARGS, ssl_doc}, +	{"sslwrap",             PySocket_ssl, +         METH_VARARGS, ssl_doc},  #ifdef HAVE_OPENSSL_RAND -	{"RAND_add",            PySSL_RAND_add, METH_VARARGS,  +	{"RAND_add",            PySSL_RAND_add, METH_VARARGS,  	 PySSL_RAND_add_doc},  	{"RAND_egd",            PySSL_RAND_egd, METH_O,  	 PySSL_RAND_egd_doc},  	{"RAND_status",         (PyCFunction)PySSL_RAND_status, METH_NOARGS,  	 PySSL_RAND_status_doc},  #endif -	{NULL,			NULL}		 /* Sentinel */ +	{NULL,                  NULL}            /* Sentinel */  }; @@ -686,7 +964,7 @@ init_ssl(void)  	/* Load _socket module and its C API */  	if (PySocketModule_ImportModuleAndAPI()) - 	    	return; +		return;  	/* Init OpenSSL */  	SSL_load_error_strings(); @@ -694,11 +972,12 @@ init_ssl(void)  	/* Add symbols to module dict */  	PySSLErrorObject = PyErr_NewException("socket.sslerror", -                                               PySocketModule.error, -                                               NULL); +					      PySocketModule.error, +					      NULL);  	if (PySSLErrorObject == NULL)  		return; -	PyDict_SetItemString(d, "sslerror", PySSLErrorObject); +	if (PyDict_SetItemString(d, "sslerror", PySSLErrorObject) != 0) +		return;  	if (PyDict_SetItemString(d, "SSLType",  				 (PyObject *)&PySSL_Type) != 0)  		return; @@ -721,5 +1000,21 @@ init_ssl(void)  				PY_SSL_ERROR_EOF);  	PyModule_AddIntConstant(m, "SSL_ERROR_INVALID_ERROR_CODE",  				PY_SSL_ERROR_INVALID_ERROR_CODE); - +	/* cert requirements */ +	PyModule_AddIntConstant(m, "CERT_NONE", +				PY_SSL_CERT_NONE); +	PyModule_AddIntConstant(m, "CERT_OPTIONAL", +				PY_SSL_CERT_OPTIONAL); +	PyModule_AddIntConstant(m, "CERT_REQUIRED", +				PY_SSL_CERT_REQUIRED); + +	/* protocol versions */ +	PyModule_AddIntConstant(m, "PROTOCOL_SSLv2", +				PY_SSL_VERSION_SSL2); +	PyModule_AddIntConstant(m, "PROTOCOL_SSLv3", +				PY_SSL_VERSION_SSL3); +	PyModule_AddIntConstant(m, "PROTOCOL_SSLv23", +				PY_SSL_VERSION_SSL23); +	PyModule_AddIntConstant(m, "PROTOCOL_TLSv1", +				PY_SSL_VERSION_TLS1);  }  | 
