|
1 | 1 | from __future__ import absolute_import |
2 | 2 | from abc import ABC, abstractmethod |
3 | 3 |
|
| 4 | +from dash import Dash |
| 5 | + |
4 | 6 |
|
5 | 7 | class Auth(ABC): |
6 | | - def __init__(self, app, authorization_hook=None, _overwrite_index=True): |
| 8 | + def __init__(self, app: Dash, **obsolete): |
| 9 | + """Auth base class for authentication in Dash. |
| 10 | +
|
| 11 | + :param app: Dash app |
| 12 | + """ |
| 13 | + |
| 14 | + # Deprecated arguments |
| 15 | + if obsolete: |
| 16 | + raise TypeError( |
| 17 | + f"Auth got unexpected keyword arguments: {list(obsolete)}" |
| 18 | + ) |
| 19 | + |
7 | 20 | self.app = app |
8 | | - self._index_view_name = app.config['routes_pathname_prefix'] |
9 | | - if _overwrite_index: |
10 | | - self._overwrite_index() |
11 | | - self._protect_views() |
12 | | - self._index_view_name = app.config['routes_pathname_prefix'] |
13 | | - self._auth_hooks = [authorization_hook] if authorization_hook else [] |
14 | | - |
15 | | - def _overwrite_index(self): |
16 | | - original_index = self.app.server.view_functions[self._index_view_name] |
17 | | - |
18 | | - self.app.server.view_functions[self._index_view_name] = \ |
19 | | - self.index_auth_wrapper(original_index) |
20 | | - |
21 | | - def _protect_views(self): |
22 | | - # TODO - allow users to white list in case they add their own views |
23 | | - for view_name, view_method in self.app.server.view_functions.items(): |
24 | | - if view_name != self._index_view_name: |
25 | | - self.app.server.view_functions[view_name] = \ |
26 | | - self.auth_wrapper(view_method) |
| 21 | + self._protect() |
| 22 | + |
| 23 | + def _protect(self): |
| 24 | + """Add a before_request authentication check on all routes. |
| 25 | +
|
| 26 | + The authentication check will pass if the request |
| 27 | + is authorised by `Auth.is_authorised` |
| 28 | + """ |
| 29 | + |
| 30 | + server = self.app.server |
| 31 | + |
| 32 | + @server.before_request |
| 33 | + def before_request_auth(): |
| 34 | + |
| 35 | + # Check whether the request is authorised |
| 36 | + if self.is_authorized(): |
| 37 | + return None |
| 38 | + |
| 39 | + # Otherwise, ask the user to log in |
| 40 | + return self.login_request() |
27 | 41 |
|
28 | 42 | def is_authorized_hook(self, func): |
29 | 43 | self._auth_hooks.append(func) |
|
0 commit comments