feat: make python2/3 compatible

Closes #3
This commit is contained in:
jrconlin 2016-06-06 14:01:05 -07:00
parent 4464fb972d
commit 1d218ce802
5 changed files with 35 additions and 21 deletions

View File

@ -1,2 +1,3 @@
[report] [report]
omit = *noseplugin* omit = *noseplugin*
show_missing = True

View File

@ -1,3 +1,6 @@
## 0.4.0 (2016-06-05)
feat: make python 2.7 / 3.5 polyglot
## 0.3.4 (2016-05-17) ## 0.3.4 (2016-05-17)
bug: make header keys case insenstive bug: make header keys case insenstive

View File

@ -86,6 +86,11 @@ class WebPusher:
the client. the client.
""" """
# Python 2 v. 3 hack
try:
self.basestr = basestring
except NameError:
self.basestr = str
if 'endpoint' not in subscription_info: if 'endpoint' not in subscription_info:
raise WebPushException("subscription_info missing endpoint URL") raise WebPushException("subscription_info missing endpoint URL")
if 'keys' not in subscription_info: if 'keys' not in subscription_info:
@ -95,22 +100,22 @@ class WebPusher:
for k in ['p256dh', 'auth']: for k in ['p256dh', 'auth']:
if keys.get(k) is None: if keys.get(k) is None:
raise WebPushException("Missing keys value: %s", k) raise WebPushException("Missing keys value: %s", k)
receiver_raw = base64.urlsafe_b64decode( if isinstance(keys[k], self.basestr):
self._repad(keys['p256dh'].encode('utf8'))) keys[k] = bytes(keys[k].encode('utf8'))
receiver_raw = base64.urlsafe_b64decode(self._repad(keys['p256dh']))
if len(receiver_raw) != 65 and receiver_raw[0] != "\x04": if len(receiver_raw) != 65 and receiver_raw[0] != "\x04":
raise WebPushException("Invalid p256dh key specified") raise WebPushException("Invalid p256dh key specified")
self.receiver_key = receiver_raw self.receiver_key = receiver_raw
self.auth_key = base64.urlsafe_b64decode( self.auth_key = base64.urlsafe_b64decode(self._repad(keys['auth']))
self._repad(keys['auth'].encode('utf8')))
def _repad(self, str): def _repad(self, data):
"""Add base64 padding to the end of a string, if required""" """Add base64 padding to the end of a string, if required"""
return str + "===="[:len(str) % 4] return data + b"===="[:len(data) % 4]
def encode(self, data): def encode(self, data):
"""Encrypt the data. """Encrypt the data.
:param data: A serialized block of data (String, JSON, bit array, :param data: A serialized block of byte data (String, JSON, bit array,
etc.) Make sure that whatever you send, your client knows how etc.) Make sure that whatever you send, your client knows how
to understand it. to understand it.
@ -124,6 +129,9 @@ class WebPusher:
# ID tag. # ID tag.
server_key_id = base64.urlsafe_b64encode(server_key.get_pubkey()[1:]) server_key_id = base64.urlsafe_b64encode(server_key.get_pubkey()[1:])
if isinstance(data, self.basestr):
data = bytes(data.encode('utf8'))
# http_ece requires that these both be set BEFORE encrypt or # http_ece requires that these both be set BEFORE encrypt or
# decrypt is called if you specify the key as "dh". # decrypt is called if you specify the key as "dh".
http_ece.keys[server_key_id] = server_key http_ece.keys[server_key_id] = server_key
@ -138,8 +146,8 @@ class WebPusher:
return CaseInsensitiveDict({ return CaseInsensitiveDict({
'crypto_key': base64.urlsafe_b64encode( 'crypto_key': base64.urlsafe_b64encode(
server_key.get_pubkey()).strip('='), server_key.get_pubkey()).strip(b'='),
'salt': base64.urlsafe_b64encode(salt).strip("="), 'salt': base64.urlsafe_b64encode(salt).strip(b'='),
'body': encrypted, 'body': encrypted,
}) })
@ -160,11 +168,12 @@ class WebPusher:
crypto_key = headers.get("crypto-key", "") crypto_key = headers.get("crypto-key", "")
if crypto_key: if crypto_key:
crypto_key += ',' crypto_key += ','
crypto_key += "keyid=p256dh;dh=" + encoded["crypto_key"] crypto_key += "keyid=p256dh;dh=" + encoded["crypto_key"].decode('utf8')
headers.update({ headers.update({
'crypto-key': crypto_key, 'crypto-key': crypto_key,
'content-encoding': 'aesgcm', 'content-encoding': 'aesgcm',
'encryption': "keyid=p256dh;salt=" + encoded['salt'], 'encryption': "keyid=p256dh;salt=" +
encoded['salt'].decode('utf8'),
}) })
if 'ttl' not in headers or ttl: if 'ttl' not in headers or ttl:
headers['ttl'] = ttl headers['ttl'] = ttl

View File

@ -15,9 +15,9 @@ class WebpushTestCase(unittest.TestCase):
return { return {
"endpoint": "https://example.com/", "endpoint": "https://example.com/",
"keys": { "keys": {
'auth': base64.urlsafe_b64encode(os.urandom(16)).strip('='), 'auth': base64.urlsafe_b64encode(os.urandom(16)).strip(b'='),
'p256dh': base64.urlsafe_b64encode( 'p256dh': base64.urlsafe_b64encode(
recv_key.get_pubkey()).strip('='), recv_key.get_pubkey()).strip(b'='),
} }
} }
@ -31,6 +31,11 @@ class WebpushTestCase(unittest.TestCase):
u"auth": u"k8JV6sjdbhAi1n3_LDBLvA" u"auth": u"k8JV6sjdbhAi1n3_LDBLvA"
} }
} }
rk_decode = (b'\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b'
b'\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1'
b'|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3'
b'\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00'
b'\xfb\xfc\xcc=\x1f')
self.assertRaises( self.assertRaises(
WebPushException, WebPushException,
WebPusher, WebPusher,
@ -55,12 +60,8 @@ class WebpushTestCase(unittest.TestCase):
push = WebPusher(subscription_info) push = WebPusher(subscription_info)
eq_(push.subscription_info, subscription_info) eq_(push.subscription_info, subscription_info)
eq_(push.receiver_key, ('\x04\xea\xe7"\xc9W\xadJ0\xd9P3(%\x00\x13\x8b' eq_(push.receiver_key, rk_decode)
'\x08l\xad4u\xa1\x19\n\xcc\x0eq\xff&\xdd1' eq_(push.auth_key, b'\x93\xc2U\xea\xc8\xddn\x10"\xd6}\xff,0K\xbc')
'|W\xcd\x81\xf8\xeaN\x83\x92[\x99\x82\xe0\xe3'
'\x89\x11r\xf4\x02\xd4M\xa00\x9b!\xb1F\x00'
'\xfb\xfc\xcc=\x1f'))
eq_(push.auth_key, '\x93\xc2U\xea\xc8\xddn\x10"\xd6}\xff,0K\xbc')
def test_encode(self): def test_encode(self):
recv_key = pyelliptic.ECC(curve="prime256v1") recv_key = pyelliptic.ECC(curve="prime256v1")
@ -88,7 +89,7 @@ class WebpushTestCase(unittest.TestCase):
authSecret=raw_auth authSecret=raw_auth
) )
eq_(decoded, data) eq_(decoded.decode('utf8'), data)
@patch("requests.post") @patch("requests.post")
def test_send(self, mock_post): def test_send(self, mock_post):

View File

@ -3,7 +3,7 @@ import os
from setuptools import find_packages, setup from setuptools import find_packages, setup
__version__ = "0.3.4" __version__ = "0.4.0"
def read_from(file): def read_from(file):