Skip to content

Commit 655811f

Browse files
authored
Merge pull request #153 from simvue-io/feature/148-add-checkextra-decorator
Use `check_extra` decorator as user alert
2 parents 61d7813 + 92bc2c5 commit 655811f

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

simvue/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import requests
66

77
from .serialization import Deserializer
8-
from .utilities import get_auth, get_server_version
8+
from .utilities import get_auth, get_server_version, check_extra
99
from .converters import to_dataframe, metrics_to_dataframe
1010

1111
CONCURRENT_DOWNLOADS = 10
@@ -433,7 +433,8 @@ def get_metrics_multiple(self, runs, names, xaxis, max_points=0, aggregate=False
433433
return data
434434

435435
raise Exception(response.text)
436-
436+
437+
@check_extra("plot")
437438
def plot_metrics(self, runs, names, xaxis, max_points=0):
438439
"""
439440
Plot time series metrics from multiple runs and/or metrics

simvue/serialization.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from io import BytesIO
22
import pickle
3-
import plotly
3+
4+
5+
from .utilities import check_extra
46

57
class Serializer:
68
def serialize(self, data, allow_pickle=False):
@@ -51,16 +53,27 @@ def get_serializer(data, allow_pickle):
5153
return _serialize_pickle
5254
return None
5355

56+
@check_extra("plot")
5457
def _serialize_plotly_figure(data):
58+
try:
59+
import plotly
60+
except ImportError:
61+
return
5562
mimetype = 'application/vnd.plotly.v1+json'
5663
data = plotly.io.to_json(data, 'json')
5764
return data, mimetype
5865

66+
@check_extra("plot")
5967
def _serialize_matplotlib_figure(data):
68+
try:
69+
import plotly
70+
except ImportError:
71+
return None
6072
mimetype = 'application/vnd.plotly.v1+json'
6173
data = plotly.io.to_json(plotly.tools.mpl_to_plotly(data), 'json')
6274
return data, mimetype
6375

76+
@check_extra("dataset")
6477
def _serialize_numpy_array(data):
6578
try:
6679
import numpy as np
@@ -75,6 +88,7 @@ def _serialize_numpy_array(data):
7588
data = mfile.read()
7689
return data, mimetype
7790

91+
@check_extra("dataset")
7892
def _serialize_dataframe(data):
7993
mimetype = 'application/vnd.simvue.df.v1'
8094
mfile = BytesIO()
@@ -83,6 +97,7 @@ def _serialize_dataframe(data):
8397
data = mfile.read()
8498
return data, mimetype
8599

100+
@check_extra("torch")
86101
def _serialize_torch_tensor(data):
87102
try:
88103
import torch
@@ -127,14 +142,25 @@ def get_deserializer(mimetype, allow_pickle):
127142
return _deserialize_pickle
128143
return None
129144

145+
@check_extra("plot")
130146
def _deserialize_plotly_figure(data):
147+
try:
148+
import plotly
149+
except ImportError:
150+
return None
131151
data = plotly.io.from_json(data)
132152
return data
133153

154+
@check_extra("plot")
134155
def _deserialize_matplotlib_figure(data):
156+
try:
157+
import plotly
158+
except ImportError:
159+
return None
135160
data = plotly.io.from_json(data)
136161
return data
137162

163+
@check_extra("dataset")
138164
def _deserialize_numpy_array(data):
139165
try:
140166
import numpy as np
@@ -147,6 +173,7 @@ def _deserialize_numpy_array(data):
147173
data = np.load(mfile, allow_pickle=False)
148174
return data
149175

176+
@check_extra("dataset")
150177
def _deserialize_dataframe(data):
151178
try:
152179
import pandas as pd
@@ -159,6 +186,7 @@ def _deserialize_dataframe(data):
159186
data = pd.read_csv(mfile, index_col=0)
160187
return data
161188

189+
@check_extra("torch")
162190
def _deserialize_torch_tensor(data):
163191
try:
164192
import torch

simvue/utilities.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,36 @@
33
import logging
44
import os
55
import requests
6+
import typing
67

78
logger = logging.getLogger(__name__)
89

10+
def check_extra(extra_name: str) -> typing.Callable:
11+
def decorator(class_func: typing.Callable) -> typing.Callable:
12+
def wrapper(self, *args, **kwargs) -> typing.Any:
13+
if extra_name == "plot":
14+
try:
15+
import matplotlib
16+
import plotly
17+
except ImportError:
18+
raise RuntimeError(f"Plotting features require the '{extra_name}' extension to Simvue")
19+
elif extra_name == "torch":
20+
try:
21+
import torch
22+
except ImportError:
23+
raise RuntimeError(f"PyTorch features require the '{extra_name}' extension to Simvue")
24+
elif extra_name == "pandas":
25+
try:
26+
import pandas
27+
import numpy
28+
except ImportError:
29+
raise RuntimeError(f"Dataset features require the '{extra_name}' extension to Simvue")
30+
else:
31+
raise RuntimeError(f"Unrecognised extra '{extra_name}'")
32+
return class_func(self, *args, **kwargs)
33+
return wrapper
34+
return decorator
35+
936
def get_auth():
1037
"""
1138
Get the URL and access token

0 commit comments

Comments
 (0)