From 164f37ed3a7eff3c49c4a3ef0163f5a47f8f8198 Mon Sep 17 00:00:00 2001 From: Ali Afshar Date: Mon, 7 Jan 2013 14:05:45 -0800 Subject: [PATCH] Add push notification subscriptions. Reviewed in https://codereview.appspot.com/6488087/. --- apiclient/http.py | 15 +++ apiclient/push.py | 274 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_http.py | 22 ++++ tests/test_push.py | 272 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 583 insertions(+) create mode 100644 apiclient/push.py create mode 100644 tests/test_push.py diff --git a/apiclient/http.py b/apiclient/http.py index cd279b1..28dba64 100644 --- a/apiclient/http.py +++ b/apiclient/http.py @@ -617,6 +617,7 @@ class HttpRequest(object): self.http = http self.postproc = postproc self.resumable = resumable + self.response_callbacks = [] self._in_error_state = False # Pull the multipart boundary out of the content-type header. @@ -673,10 +674,24 @@ class HttpRequest(object): resp, content = http.request(str(self.uri), method=str(self.method), body=self.body, headers=self.headers) + for callback in self.response_callbacks: + callback(resp) if resp.status >= 300: raise HttpError(resp, content, uri=self.uri) return self.postproc(resp, content) + @util.positional(2) + def add_response_callback(self, cb): + """add_response_headers_callback + + Args: + cb: Callback to be called on receiving the response headers, of signature: + + def cb(resp): + # Where resp is an instance of httplib2.Response + """ + self.response_callbacks.append(cb) + @util.positional(1) def next_chunk(self, http=None): """Execute the next step of a resumable upload. diff --git a/apiclient/push.py b/apiclient/push.py new file mode 100644 index 0000000..c520faf --- /dev/null +++ b/apiclient/push.py @@ -0,0 +1,274 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Push notifications support. + +This code is based on experimental APIs and is subject to change. +""" + +__author__ = 'afshar@google.com (Ali Afshar)' + +import binascii +import collections +import os +import urllib + +SUBSCRIBE = 'X-GOOG-SUBSCRIBE' +SUBSCRIPTION_ID = 'X-GOOG-SUBSCRIPTION-ID' +TOPIC_ID = 'X-GOOG-TOPIC-ID' +TOPIC_URI = 'X-GOOG-TOPIC-URI' +CLIENT_TOKEN = 'X-GOOG-CLIENT-TOKEN' +EVENT_TYPE = 'X-GOOG-EVENT-TYPE' +UNSUBSCRIBE = 'X-GOOG-UNSUBSCRIBE' + + +class InvalidSubscriptionRequestError(ValueError): + """The request cannot be subscribed.""" + + +def new_token(): + """Gets a random token for use as a client_token in push notifications. + + Returns: + str, a new random token. + """ + return binascii.hexlify(os.urandom(32)) + + +class Channel(object): + """Base class for channel types.""" + + def __init__(self, channel_type, channel_args): + """Create a new Channel. + + You probably won't need to create this channel manually, since there are + subclassed Channel for each specific type with a more customized set of + arguments to pass. However, you may wish to just create it manually here. + + Args: + channel_type: str, the type of channel. + channel_args: dict, arguments to pass to the channel. + """ + self.channel_type = channel_type + self.channel_args = channel_args + + def as_header_value(self): + """Create the appropriate header for this channel. + + Returns: + str encoded channel description suitable for use as a header. + """ + return '%s?%s' % (self.channel_type, urllib.urlencode(self.channel_args)) + + def write_header(self, headers): + """Write the appropriate subscribe header to a headers dict. + + Args: + headers: dict, headers to add subscribe header to. + """ + headers[SUBSCRIBE] = self.as_header_value() + + +class WebhookChannel(Channel): + """Channel for registering web hook notifications.""" + + def __init__(self, url, app_engine=False): + """Create a new WebhookChannel + + Args: + url: str, URL to post notifications to. + app_engine: bool, default=False, whether the destination for the + notifications is an App Engine application. + """ + super(WebhookChannel, self).__init__( + channel_type='web_hook', + channel_args={ + 'url': url, + 'app_engine': app_engine and 'true' or 'false', + } + ) + + +class Headers(collections.defaultdict): + """Headers for managing subscriptions.""" + + + ALL_HEADERS = set([SUBSCRIBE, SUBSCRIPTION_ID, TOPIC_ID, TOPIC_URI, + CLIENT_TOKEN, EVENT_TYPE, UNSUBSCRIBE]) + + def __init__(self): + """Create a new subscription configuration instance.""" + collections.defaultdict.__init__(self, str) + + def __setitem__(self, key, value): + """Set a header value, ensuring the key is an allowed value. + + Args: + key: str, the header key. + value: str, the header value. + Raises: + ValueError if key is not one of the accepted headers. + """ + normal_key = self._normalize_key(key) + if normal_key not in self.ALL_HEADERS: + raise ValueError('Header name must be one of %s.' % self.ALL_HEADERS) + else: + return collections.defaultdict.__setitem__(self, normal_key, value) + + def __getitem__(self, key): + """Get a header value, normalizing the key case. + + Args: + key: str, the header key. + Returns: + String header value. + Raises: + KeyError if the key is not one of the accepted headers. + """ + normal_key = self._normalize_key(key) + if normal_key not in self.ALL_HEADERS: + raise ValueError('Header name must be one of %s.' % self.ALL_HEADERS) + else: + return collections.defaultdict.__getitem__(self, normal_key) + + def _normalize_key(self, key): + """Normalize a header name for use as a key.""" + return key.upper() + + def items(self): + """Generator for each header.""" + for header in self.ALL_HEADERS: + value = self[header] + if value: + yield header, value + + def write(self, headers): + """Applies the subscription headers. + + Args: + headers: dict of headers to insert values into. + """ + for header, value in self.items(): + headers[header.lower()] = value + + def read(self, headers): + """Read from headers. + + Args: + headers: dict of headers to read from. + """ + for header in self.ALL_HEADERS: + if header.lower() in headers: + self[header] = headers[header.lower()] + + +class Subscription(object): + """Information about a subscription.""" + + def __init__(self): + """Create a new Subscription.""" + self.headers = Headers() + + @classmethod + def for_request(cls, request, channel, client_token=None): + """Creates a subscription and attaches it to a request. + + Args: + request: An http.HttpRequest to modify for making a subscription. + channel: A apiclient.push.Channel describing the subscription to + create. + client_token: (optional) client token to verify the notification. + + Returns: + New subscription object. + """ + subscription = cls.for_channel(channel=channel, client_token=client_token) + subscription.headers.write(request.headers) + if request.method != 'GET': + raise InvalidSubscriptionRequestError( + 'Can only subscribe to requests which are GET.') + request.method = 'POST' + + def _on_response(response, subscription=subscription): + """Called with the response headers. Reads the subscription headers.""" + subscription.headers.read(response) + + request.add_response_callback(_on_response) + return subscription + + @classmethod + def for_channel(cls, channel, client_token=None): + """Alternate constructor to create a subscription from a channel. + + Args: + channel: A apiclient.push.Channel describing the subscription to + create. + client_token: (optional) client token to verify the notification. + + Returns: + New subscription object. + """ + subscription = cls() + channel.write_header(subscription.headers) + if client_token is None: + client_token = new_token() + subscription.headers[SUBSCRIPTION_ID] = new_token() + subscription.headers[CLIENT_TOKEN] = client_token + return subscription + + def verify(self, headers): + """Verifies that a webhook notification has the correct client_token. + + Args: + headers: dict of request headers for a push notification. + + Returns: + Boolean value indicating whether the notification is verified. + """ + new_subscription = Subscription() + new_subscription.headers.read(headers) + return new_subscription.client_token == self.client_token + + @property + def subscribe(self): + """Subscribe header value.""" + return self.headers[SUBSCRIBE] + + @property + def subscription_id(self): + """Subscription ID header value.""" + return self.headers[SUBSCRIPTION_ID] + + @property + def topic_id(self): + """Topic ID header value.""" + return self.headers[TOPIC_ID] + + @property + def topic_uri(self): + """Topic URI header value.""" + return self.headers[TOPIC_URI] + + @property + def client_token(self): + """Client Token header value.""" + return self.headers[CLIENT_TOKEN] + + @property + def event_type(self): + """Event Type header value.""" + return self.headers[EVENT_TYPE] + + @property + def unsubscribe(self): + """Unsuscribe header value.""" + return self.headers[UNSUBSCRIBE] diff --git a/tests/test_http.py b/tests/test_http.py index 81248eb..7fc9c51 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -827,5 +827,27 @@ class TestStreamSlice(unittest.TestCase): s = _StreamSlice(self.stream, 2, 1) self.assertEqual('2', s.read(-1)) + +class TestResponseCallback(unittest.TestCase): + """Test adding callbacks to responses.""" + + def test_ensure_response_callback(self): + m = JsonModel() + request = HttpRequest( + None, + m.response, + 'https://www.googleapis.com/someapi/v1/collection/?foo=bar', + method='POST', + body='{}', + headers={'content-type': 'application/json'}) + h = HttpMockSequence([ ({'status': 200}, '{}')]) + responses = [] + def _on_response(resp, responses=responses): + responses.append(resp) + request.add_response_callback(_on_response) + request.execute(http=h) + self.assertEqual(1, len(responses)) + + if __name__ == '__main__': unittest.main() diff --git a/tests/test_push.py b/tests/test_push.py new file mode 100644 index 0000000..5a42835 --- /dev/null +++ b/tests/test_push.py @@ -0,0 +1,272 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Push notifications tests.""" + +__author__ = 'afshar@google.com (Ali Afshar)' + +import unittest + +from apiclient import push +from apiclient import model +from apiclient import http + + +class ClientTokenGeneratorTest(unittest.TestCase): + + def test_next(self): + t = push.new_token() + self.assertTrue(t) + + +class ChannelTest(unittest.TestCase): + + def test_creation_noargs(self): + c = push.Channel(channel_type='my_channel_type', channel_args={}) + self.assertEqual('my_channel_type', c.channel_type) + self.assertEqual({}, c.channel_args) + + def test_creation_args(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b'}) + self.assertEqual('my_channel_type', c.channel_type) + self.assertEqual({'a':'b'}, c.channel_args) + + def test_as_header_value_noargs(self): + c = push.Channel(channel_type='my_channel_type', channel_args={}) + self.assertEqual('my_channel_type?', c.as_header_value()) + + def test_as_header_value_args(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b'}) + self.assertEqual('my_channel_type?a=b', c.as_header_value()) + + def test_as_header_value_args_space(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b c'}) + self.assertEqual('my_channel_type?a=b+c', c.as_header_value()) + + def test_as_header_value_args_escape(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b%c'}) + self.assertEqual('my_channel_type?a=b%25c', c.as_header_value()) + + def test_write_header_noargs(self): + c = push.Channel(channel_type='my_channel_type', channel_args={}) + headers = {} + c.write_header(headers) + self.assertEqual('my_channel_type?', headers['X-GOOG-SUBSCRIBE']) + + def test_write_header_args(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b'}) + headers = {} + c.write_header(headers) + self.assertEqual('my_channel_type?a=b', headers['X-GOOG-SUBSCRIBE']) + + def test_write_header_args_space(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b c'}) + headers = {} + c.write_header(headers) + self.assertEqual('my_channel_type?a=b+c', headers['X-GOOG-SUBSCRIBE']) + + def test_write_header_args_escape(self): + c = push.Channel(channel_type='my_channel_type', + channel_args={'a': 'b%c'}) + headers = {} + c.write_header(headers) + self.assertEqual('my_channel_type?a=b%25c', headers['X-GOOG-SUBSCRIBE']) + + +class WebhookChannelTest(unittest.TestCase): + + def test_creation_no_appengine(self): + c = push.WebhookChannel('http://example.org') + self.assertEqual('web_hook?url=http%3A%2F%2Fexample.org&app_engine=false', + c.as_header_value()) + + def test_creation_appengine(self): + c = push.WebhookChannel('http://example.org', app_engine=True) + self.assertEqual('web_hook?url=http%3A%2F%2Fexample.org&app_engine=true', + c.as_header_value()) + + +class HeadersTest(unittest.TestCase): + + def test_creation(self): + h = push.Headers() + self.assertEqual('', h[push.SUBSCRIBE]) + + def test_items(self): + h = push.Headers() + h[push.SUBSCRIBE] = 'my_channel_type' + self.assertEqual([(push.SUBSCRIBE, 'my_channel_type')], list(h.items())) + + def test_items_non_whitelisted(self): + h = push.Headers() + def set_bad_header(h=h): + h['X-Banana'] = 'my_channel_type' + self.assertRaises(ValueError, set_bad_header) + + def test_read(self): + h = push.Headers() + h.read({'x-goog-subscribe': 'my_channel_type'}) + self.assertEqual([(push.SUBSCRIBE, 'my_channel_type')], list(h.items())) + + def test_read_non_whitelisted(self): + h = push.Headers() + h.read({'X-Banana': 'my_channel_type'}) + self.assertEqual([], list(h.items())) + + def test_write(self): + h = push.Headers() + h[push.SUBSCRIBE] = 'my_channel_type' + headers = {} + h.write(headers) + self.assertEqual({'x-goog-subscribe': 'my_channel_type'}, headers) + + +class SubscriptionTest(unittest.TestCase): + + def test_create(self): + s = push.Subscription() + self.assertEqual('', s.client_token) + + def test_create_for_channnel(self): + c = push.WebhookChannel('http://example.org') + s = push.Subscription.for_channel(c) + self.assertTrue(s.client_token) + self.assertEqual('web_hook?url=http%3A%2F%2Fexample.org&app_engine=false', + s.subscribe) + + def test_create_for_channel_client_token(self): + c = push.WebhookChannel('http://example.org') + s = push.Subscription.for_channel(c, client_token='my_token') + self.assertEqual('my_token', s.client_token) + self.assertEqual('web_hook?url=http%3A%2F%2Fexample.org&app_engine=false', + s.subscribe) + + def test_subscribe(self): + s = push.Subscription() + s.headers[push.SUBSCRIBE] = 'my_header' + self.assertEqual('my_header', s.subscribe) + + def test_subscription_id(self): + s = push.Subscription() + s.headers[push.SUBSCRIPTION_ID] = 'my_header' + self.assertEqual('my_header', s.subscription_id) + + def test_subscription_id_set(self): + c = push.WebhookChannel('http://example.org') + s = push.Subscription.for_channel(c) + self.assertTrue(s.subscription_id) + + def test_topic_id(self): + s = push.Subscription() + s.headers[push.TOPIC_ID] = 'my_header' + self.assertEqual('my_header', s.topic_id) + + def test_topic_uri(self): + s = push.Subscription() + s.headers[push.TOPIC_URI] = 'my_header' + self.assertEqual('my_header', s.topic_uri) + + def test_client_token(self): + s = push.Subscription() + s.headers[push.CLIENT_TOKEN] = 'my_header' + self.assertEqual('my_header', s.client_token) + + def test_event_type(self): + s = push.Subscription() + s.headers[push.EVENT_TYPE] = 'my_header' + self.assertEqual('my_header', s.event_type) + + def test_unsubscribe(self): + s = push.Subscription() + s.headers[push.UNSUBSCRIBE] = 'my_header' + self.assertEqual('my_header', s.unsubscribe) + + def test_do_subscribe(self): + m = model.JsonModel() + request = http.HttpRequest( + None, + m.response, + 'https://www.googleapis.com/someapi/v1/collection/?foo=bar', + method='GET', + body='{}', + headers={'content-type': 'application/json'}) + h = http.HttpMockSequence([ + ({'status': 200, + 'X-Goog-Subscription-ID': 'my_subscription'}, + '{}')]) + c = push.Channel('my_channel', {}) + s = push.Subscription.for_request(request, c) + request.execute(http=h) + self.assertEqual('my_subscription', s.subscription_id) + + def test_subscribe_with_token(self): + m = model.JsonModel() + request = http.HttpRequest( + None, + m.response, + 'https://www.googleapis.com/someapi/v1/collection/?foo=bar', + method='GET', + body='{}', + headers={'content-type': 'application/json'}) + h = http.HttpMockSequence([ + ({'status': 200, + 'X-Goog-Subscription-ID': 'my_subscription'}, + '{}')]) + c = push.Channel('my_channel', {}) + s = push.Subscription.for_request(request, c, client_token='my_token') + request.execute(http=h) + self.assertEqual('my_subscription', s.subscription_id) + self.assertEqual('my_token', s.client_token) + + def test_verify_good_token(self): + s = push.Subscription() + s.headers['X-Goog-Client-Token'] = '123' + notification_headers = {'x-goog-client-token': '123'} + self.assertTrue(s.verify(notification_headers)) + + def test_verify_bad_token(self): + s = push.Subscription() + s.headers['X-Goog-Client-Token'] = '321' + notification_headers = {'x-goog-client-token': '123'} + self.assertFalse(s.verify(notification_headers)) + + def test_request_is_post(self): + m = model.JsonModel() + request = http.HttpRequest( + None, + m.response, + 'https://www.googleapis.com/someapi/v1/collection/?foo=bar', + method='GET', + body='{}', + headers={'content-type': 'application/json'}) + c = push.Channel('my_channel', {}) + push.Subscription.for_request(request, c) + self.assertEqual('POST', request.method) + + def test_non_get_error(self): + m = model.JsonModel() + request = http.HttpRequest( + None, + m.response, + 'https://www.googleapis.com/someapi/v1/collection/?foo=bar', + method='POST', + body='{}', + headers={'content-type': 'application/json'}) + c = push.Channel('my_channel', {}) + self.assertRaises(push.InvalidSubscriptionRequestError, + push.Subscription.for_request, request, c)