Skip to content

Commit

Permalink
Typing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
AA-Turner committed Mar 2, 2021
1 parent c433201 commit d609b72
Show file tree
Hide file tree
Showing 11 changed files with 276 additions and 169 deletions.
26 changes: 17 additions & 9 deletions src/auth.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from collections.abc import Callable
import functools

from flask import session

import dash
from dash.dependencies import Input, Output, State
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State

from src.config import password, user
from src.config import password
from src.config import user

from typing import Any, Optional, TypeVar, Union

users = {user: password}

# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
F = TypeVar('F', bound=Callable[..., Any])


def authenticate_user(credentials):
def authenticate_user(credentials: tuple[Optional[str], Optional[str]]) -> bool:
"""
generic authentication function
returns True if user is correct and False otherwise
Expand All @@ -26,14 +33,15 @@ def authenticate_user(credentials):
return authed


def validate_login_session(f):
def validate_login_session(f: F) -> Union[F, html.Div]:
"""
takes a layout function that returns layout objects
checks if the user is logged in or not through the session.
If not, returns an error with link to the login page
"""
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Union[F, html.Div]:
# return f(*args, **kwargs)
if session.get("authed"):
return f(*args, **kwargs)
return html.Div(
Expand All @@ -48,7 +56,7 @@ def wrapper(*args, **kwargs):


# login layout content
def login_layout():
def login_layout() -> html.Div:
return html.Div(
html.Div([
html.H4("Login"),
Expand All @@ -62,7 +70,7 @@ def login_layout():
)


def setup_callbacks(app: dash.Dash):
def setup_callbacks(app: dash.Dash) -> None:
# authenticate
@app.callback(
[Output("login-url", "pathname"),
Expand All @@ -72,7 +80,7 @@ def setup_callbacks(app: dash.Dash):
State("login-password", "value"),
State("old-url", "data"), ],
prevent_initial_call=True)
def login_auth(_, username, pw, pathname):
def login_auth(_: int, username: Optional[str], pw: Optional[str], pathname: str) -> Union[tuple[dash.no_update, html.Div], tuple[str, str]]:
"""
check credentials
if correct, authenticate the session
Expand All @@ -82,7 +90,7 @@ def login_auth(_, username, pw, pathname):
pathname = "/home"

credentials = (username, pw)
if credentials.count(None) == len(credentials):
if credentials == (None, None):
return dash.no_update, dash.no_update

if authenticate_user(credentials):
Expand Down
35 changes: 19 additions & 16 deletions src/cache.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import datetime
import json

import dash
import flask_caching
import redis

import dash

import src.config as config

from typing import Any, Optional


class CacheInterface:
cache_path = config.DATA_ROOT / "cache" / "cache.json"
Expand All @@ -27,55 +30,55 @@ def __init__(self, app: dash.Dash):
self.environment = "prod" if config.is_production else "dev"
self.app = app

def _serialize_path(self, *path) -> str:
def _serialize_path(self, *path: str) -> str:
return "/".join([self.environment, *path])

@staticmethod
def _set_by_path(root: dict, items, value):
def _set_by_path(root: dict[str, Any], items: list[str], value: Any) -> None:
"""Set a value in a nested object in root by item sequence."""
for key in items[:-1]:
root.setdefault(key, {})
root = root[key]
root[items[-1]] = value

@staticmethod
def _keys_to_strings(keys: list, key_prefix: str = "") -> list:
return [key.decode("UTF8")[len(key_prefix):] for key in keys] # TODO key.removeprefix(root) when py3.9 is released (Oct 20)
def _keys_to_strings(keys: list[bytes], key_prefix: str = "") -> list[str]:
return [key.decode("UTF8").removeprefix(key_prefix) for key in keys]

def get_keys_from_partial(self, *path) -> list:
def get_keys_from_partial(self, *path: str) -> list[bytes]:
"""Return a list of keys from the global cache"""
return self.r.keys(f"{self.key_prefix}{self._serialize_path(*path)}*")

def get_string_keys_from_partial(self, *path) -> list:
def get_string_keys_from_partial(self, *path: str) -> list[str]:
keys = self.get_keys_from_partial(*path)
return self._keys_to_strings(keys, key_prefix=self.key_prefix)

def get_dict_from_partial(self, *path):
def get_dict_from_partial(self, *path: str) -> dict[str, Any]:
keys = self.get_keys_from_partial(*path)
root = self._serialize_path(*path) + "/"
stringified_keys = self._keys_to_strings(keys, self.key_prefix)
results = {key[len(root):]: self.cache.get(key) for key in stringified_keys} # TODO key.removeprefix(root) when py3.9 is released (Oct 20)
results = {key.removeprefix(root): self.cache.get(key) for key in stringified_keys}
tree = self._build_tree(results)
return tree

def get_from_cache(self, *path):
def get_from_cache(self, *path: str) -> Any:
"""Access a value from the global cache"""
return self.cache.get(self._serialize_path(*path))

def set_to_cache(self, *path, value=None):
def set_to_cache(self, *path: str, value: Optional[Any] = None) -> Any:
"""Access a value from the global cache"""
return self.cache.set(self._serialize_path(*path), value)

def build_tree(self):
def build_tree(self) -> dict[str, Any]:
keys = self.r.keys()
vals = self.cache.get_many(keys)
pairs = dict(zip(keys, vals))
tree = self._build_tree(pairs)
return tree

def _build_tree(self, pairs: dict) -> dict:
def _build_tree(self, pairs: dict[str, str]) -> dict[str, Any]:
"""Builds nested directory tree from key-val dict with path-like keys"""
tree = {}
tree: dict[str, Any] = {}
for key, val in pairs.items():
self._set_by_path(tree, key.split("/"), val)
return tree
Expand All @@ -87,7 +90,7 @@ def get_key_timestamp(self, key: str) -> datetime.datetime:
return datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc) # If key doesn't exist
return datetime.datetime.fromtimestamp(int(last_accessed), tz=datetime.timezone.utc)

def save_to_disk(self):
def save_to_disk(self) -> None:
# Ensure state is not changed whilst getting data
lock_key = "saving"
with self.r.lock(lock_key, timeout=5, blocking_timeout=2.5):
Expand Down Expand Up @@ -118,7 +121,7 @@ def save_to_disk(self):
json_pairs = json.dumps(pairs)
self.cache_path.write_text(json_pairs, encoding="UTF8")

def load_from_disk(self):
def load_from_disk(self) -> None:
try:
loaded_json = json.loads(self.cache_path.read_text(encoding="UTF8"))
except (FileNotFoundError, json.JSONDecodeError):
Expand Down
6 changes: 3 additions & 3 deletions src/components/homepage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import dash_core_components as dcc
import dash_html_components as html

from src.components.navbar import Navbar
from src.components import navbar

from typing import TYPE_CHECKING

Expand All @@ -23,8 +23,8 @@
], className="page-container static-container vertical-center home")


def Homepage(app: Dash):
def Homepage(app: Dash) -> html.Div:
return html.Div([
Navbar(app),
navbar.Navbar(app),
body
], className="page")
11 changes: 6 additions & 5 deletions src/components/navbar.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from dash.dependencies import ClientsideFunction, Input, Output, State
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import ClientsideFunction, Input, Output, State

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dash import Dash
from dash.development.base_component import Component


def Navbar(app: Dash, download: bool = False):
def Navbar(app: Dash, download: bool = False) -> html.Nav:
return html.Nav([
html.Div([
dcc.Link([
Expand All @@ -22,7 +23,7 @@ def Navbar(app: Dash, download: bool = False):
], className="navigation bg-nav")


def generate_nav_links(query: str = "", download: bool = False):
def generate_nav_links(query: str = "", download: bool = False) -> list[Component]:
query = query if query is not None else ""
links = []
if download:
Expand All @@ -34,12 +35,12 @@ def generate_nav_links(query: str = "", download: bool = False):
return links


def setup_callbacks(app: Dash):
def setup_callbacks(app: Dash) -> None:
@app.callback(Output("nav-list", "children"),
[Input('url', 'href')],
[State("report-query", "data")],
prevent_initial_call=True)
def update_nav_links(_, query_string):
def update_nav_links(_: str, query_string: str) -> list[Component]:
return generate_nav_links(query_string)

# app.clientside_callback(
Expand Down
32 changes: 20 additions & 12 deletions src/components/new_dashboard.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
import time

import dash
from dash.dependencies import ClientsideFunction, Input, Output, State
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import ClientsideFunction, Input, Output, State

import src.create_dashbord_helper as create_dashbord_helper
import src.components.navbar as navbar
from src import create_dashbord_helper
from src.components import navbar

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
from dash.dash import _NoUpdate
from dash.development.base_component import Component

from src import cache as cache_int


def setup_callbacks(app: dash.Dash, cache):
def setup_callbacks(app: dash.Dash, cache: cache_int.CacheInterface) -> None:
create_upload_callback(app, "compliance-report", cache)
create_upload_callback(app, "training-report", cache)
create_button_callback(app, "button", "compliance-report", "training-report", cache)
Expand All @@ -35,7 +43,7 @@ def setup_callbacks(app: dash.Dash, cache):
)


def create_upload_callback(app: dash.Dash, upload_id: str, cache):
def create_upload_callback(app: dash.Dash, upload_id: str, cache: cache_int.CacheInterface) -> None:
@app.callback(Output(f"upload-{upload_id}", "children"),
[Input(f"upload-{upload_id}", "filename")],
[State(f"upload-{upload_id}", "contents")])
Expand All @@ -50,7 +58,7 @@ def update_output(filename: str, contents: str) -> html.Div:
return _upload_text(children=html.H5(filename))


def create_button_callback(app: dash.Dash, button_id: str, compliance_upload_id: str, training_upload_id: str, cache):
def create_button_callback(app: dash.Dash, button_id: str, compliance_upload_id: str, training_upload_id: str, cache: cache_int.CacheInterface) -> None:
@app.callback([Output(button_id, "children"),
Output("url", "pathname"),
Output("url", "search"),
Expand All @@ -61,7 +69,7 @@ def create_button_callback(app: dash.Dash, button_id: str, compliance_upload_id:
State("input-title", "value"),
State("input-disclosures", "value"), ],
prevent_initial_call=True)
def update_button(clicked: int, c_contents: str, t_contents: str, title: str, valid_disclosures: float) -> tuple:
def update_button(clicked: int, c_contents: str, t_contents: str, title: str, valid_disclosures: float) -> tuple[Union[str, _NoUpdate], ...]:
# TODO accept only CSV/XLS(X)?

# Short circuit if empty
Expand All @@ -70,10 +78,10 @@ def update_button(clicked: int, c_contents: str, t_contents: str, title: str, va

start_time = time.time()
ctx = dash.callback_context
num_outputs = len(ctx.outputs_list)
num_outputs = len(ctx.outputs_list) # TODO check this is 4

def update_button_text(new_text: str) -> tuple:
return (new_text,) + (dash.no_update, ) * (num_outputs - 1)
def update_button_text(new_text: str) -> tuple[str, _NoUpdate, _NoUpdate, _NoUpdate]:
return new_text, dash.no_update, dash.no_update, dash.no_update

# Check all values are present
inputs = [
Expand All @@ -95,7 +103,7 @@ def update_button_text(new_text: str) -> tuple:

app.server.logger.info(f"Input validation took: {time.time() - start_time}")

reports_parser = create_dashbord_helper.ReportsParser(app=app, session_id=True, cache=cache)
reports_parser = create_dashbord_helper.ReportsParser(cache=cache, app=app, session_id=True)
out = reports_parser.create_query_string(title, valid_disclosures)
value = out["value"]
app.server.logger.info(f"Report processing took: {time.time() - start_time}")
Expand All @@ -107,7 +115,7 @@ def update_button_text(new_text: str) -> tuple:
return update_button_text(out["value"])


def _upload_text(file_desc: str = None, children: str = None) -> html.Div:
def _upload_text(file_desc: Optional[str] = None, children: Optional[Union[str, Component]] = None) -> html.Div:
if children is None:
children = [
html.Span(f"Upload a {file_desc}"),
Expand Down
8 changes: 4 additions & 4 deletions src/components/nopage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
import dash_core_components as dcc
import dash_html_components as html

from src.components.navbar import Navbar
from src.components import navbar

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dash import Dash


def noPage(app: Dash):
def noPage(app: Dash) -> html.Div:
return html.Div([
# CC Header
Navbar(app),
navbar.Navbar(app),
page_not_found()
], className="page")


def page_not_found():
def page_not_found() -> html.Div:
return html.Div([
html.H3("The page you requested does't exist."),
dcc.Link([
Expand Down
Loading

0 comments on commit d609b72

Please sign in to comment.