]> git.openstreetmap.org Git - osqa.git/blob - forum_modules/oauthauth/lib/oauth.py
deleting the test file
[osqa.git] / forum_modules / oauthauth / lib / oauth.py
1 """\r
2 The MIT License\r
3 \r
4 Copyright (c) 2007 Leah Culver\r
5 \r
6 Permission is hereby granted, free of charge, to any person obtaining a copy\r
7 of this software and associated documentation files (the "Software"), to deal\r
8 in the Software without restriction, including without limitation the rights\r
9 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r
10 copies of the Software, and to permit persons to whom the Software is\r
11 furnished to do so, subject to the following conditions:\r
12 \r
13 The above copyright notice and this permission notice shall be included in\r
14 all copies or substantial portions of the Software.\r
15 \r
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r
17 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r
18 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r
19 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r
20 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r
21 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\r
22 THE SOFTWARE.\r
23 """\r
24 \r
25 import cgi\r
26 import urllib\r
27 import time\r
28 import random\r
29 import urlparse\r
30 import hmac\r
31 import binascii\r
32 \r
33 \r
34 VERSION = '1.0' # Hi Blaine!\r
35 HTTP_METHOD = 'GET'\r
36 SIGNATURE_METHOD = 'PLAINTEXT'\r
37 \r
38 \r
39 class OAuthError(RuntimeError):\r
40     """Generic exception class."""\r
41     def __init__(self, message='OAuth error occured.'):\r
42         self.message = message\r
43 \r
44 def build_authenticate_header(realm=''):\r
45     """Optional WWW-Authenticate header (401 error)"""\r
46     return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}\r
47 \r
48 def escape(s):\r
49     """Escape a URL including any /."""\r
50     return urllib.quote(s, safe='~')\r
51 \r
52 def _utf8_str(s):\r
53     """Convert unicode to utf-8."""\r
54     if isinstance(s, unicode):\r
55         return s.encode("utf-8")\r
56     else:\r
57         return str(s)\r
58 \r
59 def generate_timestamp():\r
60     """Get seconds since epoch (UTC)."""\r
61     return int(time.time())\r
62 \r
63 def generate_nonce(length=8):\r
64     """Generate pseudorandom number."""\r
65     return ''.join([str(random.randint(0, 9)) for i in range(length)])\r
66 \r
67 \r
68 class OAuthConsumer(object):\r
69     """Consumer of OAuth authentication.\r
70 \r
71     OAuthConsumer is a data type that represents the identity of the Consumer\r
72     via its shared secret with the Service Provider.\r
73 \r
74     """\r
75     key = None\r
76     secret = None\r
77 \r
78     def __init__(self, key, secret):\r
79         self.key = key\r
80         self.secret = secret\r
81 \r
82 \r
83 class OAuthToken(object):\r
84     """OAuthToken is a data type that represents an End User via either an access\r
85     or request token.\r
86 \r
87     key -- the token\r
88     secret -- the token secret\r
89 \r
90     """\r
91     key = None\r
92     secret = None\r
93 \r
94     def __init__(self, key, secret):\r
95         self.key = key\r
96         self.secret = secret\r
97 \r
98     def to_string(self):\r
99         return urllib.urlencode({'oauth_token': self.key,\r
100             'oauth_token_secret': self.secret})\r
101 \r
102     def from_string(s):\r
103         """ Returns a token from something like:\r
104         oauth_token_secret=xxx&oauth_token=xxx\r
105         """\r
106         params = cgi.parse_qs(s, keep_blank_values=False)\r
107         key = params['oauth_token'][0]\r
108         secret = params['oauth_token_secret'][0]\r
109         return OAuthToken(key, secret)\r
110     from_string = staticmethod(from_string)\r
111 \r
112     def __str__(self):\r
113         return self.to_string()\r
114 \r
115 \r
116 class OAuthRequest(object):\r
117     """OAuthRequest represents the request and can be serialized.\r
118 \r
119     OAuth parameters:\r
120         - oauth_consumer_key\r
121         - oauth_token\r
122         - oauth_signature_method\r
123         - oauth_signature\r
124         - oauth_timestamp\r
125         - oauth_nonce\r
126         - oauth_version\r
127         ... any additional parameters, as defined by the Service Provider.\r
128     """\r
129     parameters = None # OAuth parameters.\r
130     http_method = HTTP_METHOD\r
131     http_url = None\r
132     version = VERSION\r
133 \r
134     def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):\r
135         self.http_method = http_method\r
136         self.http_url = http_url\r
137         self.parameters = parameters or {}\r
138 \r
139     def set_parameter(self, parameter, value):\r
140         self.parameters[parameter] = value\r
141 \r
142     def get_parameter(self, parameter):\r
143         try:\r
144             return self.parameters[parameter]\r
145         except:\r
146             raise OAuthError('Parameter not found: %s' % parameter)\r
147 \r
148     def _get_timestamp_nonce(self):\r
149         return self.get_parameter('oauth_timestamp'), self.get_parameter(\r
150             'oauth_nonce')\r
151 \r
152     def get_nonoauth_parameters(self):\r
153         """Get any non-OAuth parameters."""\r
154         parameters = {}\r
155         for k, v in self.parameters.iteritems():\r
156             # Ignore oauth parameters.\r
157             if k.find('oauth_') < 0:\r
158                 parameters[k] = v\r
159         return parameters\r
160 \r
161     def to_header(self, realm=''):\r
162         """Serialize as a header for an HTTPAuth request."""\r
163         auth_header = 'OAuth realm="%s"' % realm\r
164         # Add the oauth parameters.\r
165         if self.parameters:\r
166             for k, v in self.parameters.iteritems():\r
167                 if k[:6] == 'oauth_':\r
168                     auth_header += ', %s="%s"' % (k, escape(str(v)))\r
169         return {'Authorization': auth_header}\r
170 \r
171     def to_postdata(self):\r
172         """Serialize as post data for a POST request."""\r
173         return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \\r
174             for k, v in self.parameters.iteritems()])\r
175 \r
176     def to_url(self):\r
177         """Serialize as a URL for a GET request."""\r
178         return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())\r
179 \r
180     def get_normalized_parameters(self):\r
181         """Return a string that contains the parameters that must be signed."""\r
182         params = self.parameters\r
183         try:\r
184             # Exclude the signature if it exists.\r
185             del params['oauth_signature']\r
186         except:\r
187             pass\r
188         # Escape key values before sorting.\r
189         key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \\r
190             for k,v in params.items()]\r
191         # Sort lexicographically, first after key, then after value.\r
192         key_values.sort()\r
193         # Combine key value pairs into a string.\r
194         return '&'.join(['%s=%s' % (k, v) for k, v in key_values])\r
195 \r
196     def get_normalized_http_method(self):\r
197         """Uppercases the http method."""\r
198         return self.http_method.upper()\r
199 \r
200     def get_normalized_http_url(self):\r
201         """Parses the URL and rebuilds it to be scheme://host/path."""\r
202         parts = urlparse.urlparse(self.http_url)\r
203         scheme, netloc, path = parts[:3]\r
204         # Exclude default port numbers.\r
205         if scheme == 'http' and netloc[-3:] == ':80':\r
206             netloc = netloc[:-3]\r
207         elif scheme == 'https' and netloc[-4:] == ':443':\r
208             netloc = netloc[:-4]\r
209         return '%s://%s%s' % (scheme, netloc, path)\r
210 \r
211     def sign_request(self, signature_method, consumer, token):\r
212         """Set the signature parameter to the result of build_signature."""\r
213         # Set the signature method.\r
214         self.set_parameter('oauth_signature_method',\r
215             signature_method.get_name())\r
216         # Set the signature.\r
217         self.set_parameter('oauth_signature',\r
218             self.build_signature(signature_method, consumer, token))\r
219 \r
220     def build_signature(self, signature_method, consumer, token):\r
221         """Calls the build signature method within the signature method."""\r
222         return signature_method.build_signature(self, consumer, token)\r
223 \r
224     def from_request(http_method, http_url, headers=None, parameters=None,\r
225             query_string=None):\r
226         """Combines multiple parameter sources."""\r
227         if parameters is None:\r
228             parameters = {}\r
229 \r
230         # Headers\r
231         if headers and 'Authorization' in headers:\r
232             auth_header = headers['Authorization']\r
233             # Check that the authorization header is OAuth.\r
234             if auth_header.index('OAuth') > -1:\r
235                 auth_header = auth_header.lstrip('OAuth ')\r
236                 try:\r
237                     # Get the parameters from the header.\r
238                     header_params = OAuthRequest._split_header(auth_header)\r
239                     parameters.update(header_params)\r
240                 except:\r
241                     raise OAuthError('Unable to parse OAuth parameters from '\r
242                         'Authorization header.')\r
243 \r
244         # GET or POST query string.\r
245         if query_string:\r
246             query_params = OAuthRequest._split_url_string(query_string)\r
247             parameters.update(query_params)\r
248 \r
249         # URL parameters.\r
250         param_str = urlparse.urlparse(http_url)[4] # query\r
251         url_params = OAuthRequest._split_url_string(param_str)\r
252         parameters.update(url_params)\r
253 \r
254         if parameters:\r
255             return OAuthRequest(http_method, http_url, parameters)\r
256 \r
257         return None\r
258     from_request = staticmethod(from_request)\r
259 \r
260     def from_consumer_and_token(oauth_consumer, token=None,\r
261             http_method=HTTP_METHOD, http_url=None, parameters=None):\r
262         if not parameters:\r
263             parameters = {}\r
264 \r
265         defaults = {\r
266             'oauth_consumer_key': oauth_consumer.key,\r
267             'oauth_timestamp': generate_timestamp(),\r
268             'oauth_nonce': generate_nonce(),\r
269             'oauth_version': OAuthRequest.version,\r
270         }\r
271 \r
272         defaults.update(parameters)\r
273         parameters = defaults\r
274 \r
275         if token:\r
276             parameters['oauth_token'] = token.key\r
277 \r
278         return OAuthRequest(http_method, http_url, parameters)\r
279     from_consumer_and_token = staticmethod(from_consumer_and_token)\r
280 \r
281     def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,\r
282             http_url=None, parameters=None):\r
283         if not parameters:\r
284             parameters = {}\r
285 \r
286         parameters['oauth_token'] = token.key\r
287 \r
288         if callback:\r
289             parameters['oauth_callback'] = callback\r
290 \r
291         return OAuthRequest(http_method, http_url, parameters)\r
292     from_token_and_callback = staticmethod(from_token_and_callback)\r
293 \r
294     def _split_header(header):\r
295         """Turn Authorization: header into parameters."""\r
296         params = {}\r
297         parts = header.split(',')\r
298         for param in parts:\r
299             # Ignore realm parameter.\r
300             if param.find('realm') > -1:\r
301                 continue\r
302             # Remove whitespace.\r
303             param = param.strip()\r
304             # Split key-value.\r
305             param_parts = param.split('=', 1)\r
306             # Remove quotes and unescape the value.\r
307             params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))\r
308         return params\r
309     _split_header = staticmethod(_split_header)\r
310 \r
311     def _split_url_string(param_str):\r
312         """Turn URL string into parameters."""\r
313         parameters = cgi.parse_qs(param_str, keep_blank_values=False)\r
314         for k, v in parameters.iteritems():\r
315             parameters[k] = urllib.unquote(v[0])\r
316         return parameters\r
317     _split_url_string = staticmethod(_split_url_string)\r
318 \r
319 class OAuthServer(object):\r
320     """A worker to check the validity of a request against a data store."""\r
321     timestamp_threshold = 300 # In seconds, five minutes.\r
322     version = VERSION\r
323     signature_methods = None\r
324     data_store = None\r
325 \r
326     def __init__(self, data_store=None, signature_methods=None):\r
327         self.data_store = data_store\r
328         self.signature_methods = signature_methods or {}\r
329 \r
330     def set_data_store(self, data_store):\r
331         self.data_store = data_store\r
332 \r
333     def get_data_store(self):\r
334         return self.data_store\r
335 \r
336     def add_signature_method(self, signature_method):\r
337         self.signature_methods[signature_method.get_name()] = signature_method\r
338         return self.signature_methods\r
339 \r
340     def fetch_request_token(self, oauth_request):\r
341         """Processes a request_token request and returns the\r
342         request token on success.\r
343         """\r
344         try:\r
345             # Get the request token for authorization.\r
346             token = self._get_token(oauth_request, 'request')\r
347         except OAuthError:\r
348             # No token required for the initial token request.\r
349             version = self._get_version(oauth_request)\r
350             consumer = self._get_consumer(oauth_request)\r
351             self._check_signature(oauth_request, consumer, None)\r
352             # Fetch a new token.\r
353             token = self.data_store.fetch_request_token(consumer)\r
354         return token\r
355 \r
356     def fetch_access_token(self, oauth_request):\r
357         """Processes an access_token request and returns the\r
358         access token on success.\r
359         """\r
360         version = self._get_version(oauth_request)\r
361         consumer = self._get_consumer(oauth_request)\r
362         # Get the request token.\r
363         token = self._get_token(oauth_request, 'request')\r
364         self._check_signature(oauth_request, consumer, token)\r
365         new_token = self.data_store.fetch_access_token(consumer, token)\r
366         return new_token\r
367 \r
368     def verify_request(self, oauth_request):\r
369         """Verifies an api call and checks all the parameters."""\r
370         # -> consumer and token\r
371         version = self._get_version(oauth_request)\r
372         consumer = self._get_consumer(oauth_request)\r
373         # Get the access token.\r
374         token = self._get_token(oauth_request, 'access')\r
375         self._check_signature(oauth_request, consumer, token)\r
376         parameters = oauth_request.get_nonoauth_parameters()\r
377         return consumer, token, parameters\r
378 \r
379     def authorize_token(self, token, user):\r
380         """Authorize a request token."""\r
381         return self.data_store.authorize_request_token(token, user)\r
382 \r
383     def get_callback(self, oauth_request):\r
384         """Get the callback URL."""\r
385         return oauth_request.get_parameter('oauth_callback')\r
386 \r
387     def build_authenticate_header(self, realm=''):\r
388         """Optional support for the authenticate header."""\r
389         return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}\r
390 \r
391     def _get_version(self, oauth_request):\r
392         """Verify the correct version request for this server."""\r
393         try:\r
394             version = oauth_request.get_parameter('oauth_version')\r
395         except:\r
396             version = VERSION\r
397         if version and version != self.version:\r
398             raise OAuthError('OAuth version %s not supported.' % str(version))\r
399         return version\r
400 \r
401     def _get_signature_method(self, oauth_request):\r
402         """Figure out the signature with some defaults."""\r
403         try:\r
404             signature_method = oauth_request.get_parameter(\r
405                 'oauth_signature_method')\r
406         except:\r
407             signature_method = SIGNATURE_METHOD\r
408         try:\r
409             # Get the signature method object.\r
410             signature_method = self.signature_methods[signature_method]\r
411         except:\r
412             signature_method_names = ', '.join(self.signature_methods.keys())\r
413             raise OAuthError('Signature method %s not supported try one of the '\r
414                 'following: %s' % (signature_method, signature_method_names))\r
415 \r
416         return signature_method\r
417 \r
418     def _get_consumer(self, oauth_request):\r
419         consumer_key = oauth_request.get_parameter('oauth_consumer_key')\r
420         consumer = self.data_store.lookup_consumer(consumer_key)\r
421         if not consumer:\r
422             raise OAuthError('Invalid consumer.')\r
423         return consumer\r
424 \r
425     def _get_token(self, oauth_request, token_type='access'):\r
426         """Try to find the token for the provided request token key."""\r
427         token_field = oauth_request.get_parameter('oauth_token')\r
428         token = self.data_store.lookup_token(token_type, token_field)\r
429         if not token:\r
430             raise OAuthError('Invalid %s token: %s' % (token_type, token_field))\r
431         return token\r
432 \r
433     def _check_signature(self, oauth_request, consumer, token):\r
434         timestamp, nonce = oauth_request._get_timestamp_nonce()\r
435         self._check_timestamp(timestamp)\r
436         self._check_nonce(consumer, token, nonce)\r
437         signature_method = self._get_signature_method(oauth_request)\r
438         try:\r
439             signature = oauth_request.get_parameter('oauth_signature')\r
440         except:\r
441             raise OAuthError('Missing signature.')\r
442         # Validate the signature.\r
443         valid_sig = signature_method.check_signature(oauth_request, consumer,\r
444             token, signature)\r
445         if not valid_sig:\r
446             key, base = signature_method.build_signature_base_string(\r
447                 oauth_request, consumer, token)\r
448             raise OAuthError('Invalid signature. Expected signature base '\r
449                 'string: %s' % base)\r
450         built = signature_method.build_signature(oauth_request, consumer, token)\r
451 \r
452     def _check_timestamp(self, timestamp):\r
453         """Verify that timestamp is recentish."""\r
454         timestamp = int(timestamp)\r
455         now = int(time.time())\r
456         lapsed = now - timestamp\r
457         if lapsed > self.timestamp_threshold:\r
458             raise OAuthError('Expired timestamp: given %d and now %s has a '\r
459                 'greater difference than threshold %d' %\r
460                 (timestamp, now, self.timestamp_threshold))\r
461 \r
462     def _check_nonce(self, consumer, token, nonce):\r
463         """Verify that the nonce is uniqueish."""\r
464         nonce = self.data_store.lookup_nonce(consumer, token, nonce)\r
465         if nonce:\r
466             raise OAuthError('Nonce already used: %s' % str(nonce))\r
467 \r
468 \r
469 class OAuthClient(object):\r
470     """OAuthClient is a worker to attempt to execute a request."""\r
471     consumer = None\r
472     token = None\r
473 \r
474     def __init__(self, oauth_consumer, oauth_token):\r
475         self.consumer = oauth_consumer\r
476         self.token = oauth_token\r
477 \r
478     def get_consumer(self):\r
479         return self.consumer\r
480 \r
481     def get_token(self):\r
482         return self.token\r
483 \r
484     def fetch_request_token(self, oauth_request):\r
485         """-> OAuthToken."""\r
486         raise NotImplementedError\r
487 \r
488     def fetch_access_token(self, oauth_request):\r
489         """-> OAuthToken."""\r
490         raise NotImplementedError\r
491 \r
492     def access_resource(self, oauth_request):\r
493         """-> Some protected resource."""\r
494         raise NotImplementedError\r
495 \r
496 \r
497 class OAuthDataStore(object):\r
498     """A database abstraction used to lookup consumers and tokens."""\r
499 \r
500     def lookup_consumer(self, key):\r
501         """-> OAuthConsumer."""\r
502         raise NotImplementedError\r
503 \r
504     def lookup_token(self, oauth_consumer, token_type, token_token):\r
505         """-> OAuthToken."""\r
506         raise NotImplementedError\r
507 \r
508     def lookup_nonce(self, oauth_consumer, oauth_token, nonce):\r
509         """-> OAuthToken."""\r
510         raise NotImplementedError\r
511 \r
512     def fetch_request_token(self, oauth_consumer):\r
513         """-> OAuthToken."""\r
514         raise NotImplementedError\r
515 \r
516     def fetch_access_token(self, oauth_consumer, oauth_token):\r
517         """-> OAuthToken."""\r
518         raise NotImplementedError\r
519 \r
520     def authorize_request_token(self, oauth_token, user):\r
521         """-> OAuthToken."""\r
522         raise NotImplementedError\r
523 \r
524 \r
525 class OAuthSignatureMethod(object):\r
526     """A strategy class that implements a signature method."""\r
527     def get_name(self):\r
528         """-> str."""\r
529         raise NotImplementedError\r
530 \r
531     def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):\r
532         """-> str key, str raw."""\r
533         raise NotImplementedError\r
534 \r
535     def build_signature(self, oauth_request, oauth_consumer, oauth_token):\r
536         """-> str."""\r
537         raise NotImplementedError\r
538 \r
539     def check_signature(self, oauth_request, consumer, token, signature):\r
540         built = self.build_signature(oauth_request, consumer, token)\r
541         return built == signature\r
542 \r
543 \r
544 class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):\r
545 \r
546     def get_name(self):\r
547         return 'HMAC-SHA1'\r
548 \r
549     def build_signature_base_string(self, oauth_request, consumer, token):\r
550         sig = (\r
551             escape(oauth_request.get_normalized_http_method()),\r
552             escape(oauth_request.get_normalized_http_url()),\r
553             escape(oauth_request.get_normalized_parameters()),\r
554         )\r
555 \r
556         key = '%s&' % escape(consumer.secret)\r
557         if token:\r
558             key += escape(token.secret)\r
559         raw = '&'.join(sig)\r
560         return key, raw\r
561 \r
562     def build_signature(self, oauth_request, consumer, token):\r
563         """Builds the base signature string."""\r
564         key, raw = self.build_signature_base_string(oauth_request, consumer,\r
565             token)\r
566 \r
567         # HMAC object.\r
568         try:\r
569             import hashlib # 2.5\r
570             hashed = hmac.new(key, raw, hashlib.sha1)\r
571         except:\r
572             import sha # Deprecated\r
573             hashed = hmac.new(key, raw, sha)\r
574 \r
575         # Calculate the digest base 64.\r
576         return binascii.b2a_base64(hashed.digest())[:-1]\r
577 \r
578 \r
579 class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):\r
580 \r
581     def get_name(self):\r
582         return 'PLAINTEXT'\r
583 \r
584     def build_signature_base_string(self, oauth_request, consumer, token):\r
585         """Concatenates the consumer key and secret."""\r
586         sig = '%s&' % escape(consumer.secret)\r
587         if token:\r
588             sig = sig + escape(token.secret)\r
589         return sig, sig\r
590 \r
591     def build_signature(self, oauth_request, consumer, token):\r
592         key, raw = self.build_signature_base_string(oauth_request, consumer,\r
593             token)\r
594         return key