Skip to content

Commit c0fce4e

Browse files
authored
Merge pull request #261 from synsense/develop
v2.0.1
2 parents dbc3421 + ace99fd commit c0fce4e

File tree

7 files changed

+170
-30
lines changed

7 files changed

+170
-30
lines changed

ChangeLog

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
CHANGES
22
=======
33

4+
* Spike count plot in 'DynapcnnVisualizer' is optional
5+
* `DynapcnnVisualizer` allows custom JIT filters to make readout predictions
6+
47
v2.0.0
58
------
69

docs/speck/notebooks/nmnist_quick_start.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@
357357
"cell_type": "markdown",
358358
"metadata": {},
359359
"source": [
360-
"### Covert CNN To SNN"
360+
"### Convert CNN To SNN"
361361
]
362362
},
363363
{

docs/speck/visualizer.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ hardware_compatible_model.to(
131131
In order to visualize the class outputs as images, we need to get the images. The images should be passed in the same order as the output layer of the network. Important! <br>
132132
- If you want to visualize power measurements during streaming inference, set `add_power_monitor_plot`=`True`.
133133
- If you want to visualize readout images as class predictions during streaming you need to pass `add_readout_plot`=`True`.
134+
- If you don't want to visualize spike counts of output classes as line graphs over time during streaming you need to pass `add_spike_count_plot`=`False`.
134135
- In order to show a prediction for each `N` milliseconds, set the parameter `spike_collection_interval`=`N`.
135136
- In order to show the images, the paths of these images should be passed to `readout_images` parameter.
136137
- In order to show a prediction only if there are more than a `threshold` number of events from that output, set the `readout_prediction_threshold`=`threshold`.
@@ -172,4 +173,4 @@ The example script that runs the visualizer can be found under `/examples/visual
172173

173174

174175
#### MacOS users
175-
Due to the difference in the behaviour of python's multiprocessing library on MacOS, you should run the `examples/visualizer/gesture_viz.py` script with `-i` flag. `python -i /examples/visualizer/gesture_viz.py` .
176+
Due to the difference in the behaviour of python's multiprocessing library on MacOS, you should run the `examples/visualizer/gesture_viz.py` script with `-i` flag. `python -i /examples/visualizer/gesture_viz.py` .

sinabs/backend/dynapcnn/dynapcnn_network.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,9 +436,10 @@ def forward(self, x):
436436
# Send input
437437
self.samna_input_buffer.write(x)
438438
received_evts = []
439-
# Record at least until the last event has been replayed
440-
min_duration = max(event.timestamp for event in x) * 1e-6
441-
time.sleep(min_duration)
439+
440+
# Wait a minimum time to guarantee the events were played
441+
time.sleep(1)
442+
442443
# Keep recording if more events are being registered
443444
while True:
444445
prev_length = len(received_evts)

sinabs/backend/dynapcnn/dynapcnn_visualizer.py

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import socket
22
import warnings
3-
from typing import Dict, List, Optional, Tuple
3+
from typing import Callable, Dict, List, Optional, Tuple, Union
44

55
import samna
66

@@ -15,6 +15,7 @@ class DynapcnnVisualizer:
1515
# (tlx, tly, brx, bry)
1616
DEFAULT_LAYOUT_DS = [(0, 0, 0.5, 1), (0.5, 0, 1, 1), None, None]
1717
DEFAULT_LAYOUT_DSP = [(0, 0, 0.5, 0.66), (0.5, 0, 1, 0.66), None, (0, 0.66, 1, 1)]
18+
DEFAULT_LAYOUT_DRP = [(0, 0, 0.5, 0.66), None, (0.5, 0, 1, 0.66), (0, 0.66, 1, 1)]
1819
DEFAULT_LAYOUT_DSR = [(0, 0, 0.33, 1), (0.33, 0, 0.66, 1), (0.66, 0, 1, 1), None]
1920

2021
DEFAULT_LAYOUT_DSRP = [
@@ -27,6 +28,7 @@ class DynapcnnVisualizer:
2728
LAYOUTS_DICT = {
2829
"ds": DEFAULT_LAYOUT_DS,
2930
"dsp": DEFAULT_LAYOUT_DSP,
31+
"drp": DEFAULT_LAYOUT_DRP,
3032
"dsr": DEFAULT_LAYOUT_DSR,
3133
"dsrp": DEFAULT_LAYOUT_DSRP,
3234
}
@@ -36,6 +38,7 @@ def __init__(
3638
window_scale: Tuple[int, int] = (4, 8),
3739
dvs_shape: Tuple[int, int] = (128, 128), # height, width
3840
add_readout_plot: bool = False,
41+
add_spike_count_plot: bool = True,
3942
add_power_monitor_plot: bool = False,
4043
spike_collection_interval: int = 500,
4144
readout_prediction_threshold: int = 10,
@@ -46,6 +49,7 @@ def __init__(
4649
feature_names: Optional[List[str]] = None,
4750
readout_images: Optional[List[str]] = None,
4851
feature_count: Optional[int] = None,
52+
readout_node: Union[str, Callable] = "JitMajorityReadout",
4953
extra_arguments: Optional[Dict[str, Dict[str, any]]] = None,
5054
):
5155
"""Quick wrapper around Samna objects to get a basic dynapcnn visualizer.
@@ -58,6 +62,10 @@ def __init__(
5862
Defaults to (128, 128) -- Speck sensor resolution.
5963
add_readout_plot: bool (defaults to False)
6064
If set true adds a readout plot to the GUI
65+
It displays an icon for the currently predicted class.
66+
add_spike_count_plot: bool (defaults to True)
67+
If set true adds a spike count plot to the GUI.
68+
A line chart indicating the number of spikes over time.
6169
add_power_monitor_plot: bool (defaults to False)
6270
If set true adds a power monitor plot to the GUI.
6371
spike_collection_interval: int (defaults to 500) (in milliseconds)
@@ -91,14 +99,17 @@ def __init__(
9199
If the `feature_names` and `readout_images` was passed, this is not needed. Otherwise this parameter
92100
should be passed, so that the GUI knows how many lines should be drawn on the `Spike Count Plot` and
93101
`Readout Layer Plot`.
102+
readout_node: str or Callable
103+
Can either be a string "JitMajorityReadout" or a callable that returns a samna JIT filter
104+
to decide on the readout prediction. Function parameters can be defined freely.
94105
extra_arguments: Optional[Dict[str, Dict[str, any]]] (defaults to None)
95106
Extra arguments that can be passed to individual plots. Available keys are:
96107
- `spike_count`: Arguments that can be passed to `spike_count` plot.
97108
- `readout`: Arguments that can be passed to `readout` plot.
98109
- `power_measurement`: Arguments that can be passed `power_measurement` plot.
99110
"""
100111
# Checks if the configuration passed is valid
101-
if add_readout_plot and not readout_images:
112+
if add_readout_plot and readout_images is None:
102113
raise ValueError(
103114
"If a readout plot is to be displayed image paths should be passed as a list."
104115
+ "The order of the images, should match the model output."
@@ -112,7 +123,9 @@ def __init__(
112123
self.dvs_shape = dvs_shape
113124

114125
# Modify the GUI type based on the parameters
115-
self.gui_type = "ds"
126+
self.gui_type = "d"
127+
if add_spike_count_plot:
128+
self.gui_type += "s"
116129
if add_readout_plot:
117130
self.gui_type += "r"
118131
if add_power_monitor_plot:
@@ -126,6 +139,7 @@ def __init__(
126139
self.readout_default_return_value = readout_default_return_value
127140
self.readout_default_threshold_low = readout_default_threshold_low
128141
self.readout_default_threshold_high = readout_default_threshold_high
142+
self.readout_node = readout_node
129143

130144
# Power monitor components
131145
if power_monitor_number_of_items != 3 and power_monitor_number_of_items != 5:
@@ -338,13 +352,14 @@ def create_plots(self):
338352
plots = []
339353

340354
plots.append(self.add_dvs_plot(shape=self.dvs_shape, layout=layout[0]))
341-
if self.extra_arguments and "spike_count" in self.extra_argument.keys():
342-
spike_count_plot_args = self.extra_arguments["spike_count"]
343-
else:
344-
spike_count_plot_args = {}
345-
plots.append(
346-
self.add_spike_count_plot(layout=layout[1], **spike_count_plot_args)
347-
)
355+
if "s" in self.gui_type:
356+
if self.extra_arguments and "spike_count" in self.extra_argument.keys():
357+
spike_count_plot_args = self.extra_arguments["spike_count"]
358+
else:
359+
spike_count_plot_args = {}
360+
plots.append(
361+
self.add_spike_count_plot(layout=layout[1], **spike_count_plot_args)
362+
)
348363
if "r" in self.gui_type:
349364
try:
350365
if self.extra_arguments and "readout" in self.extra_arguments.keys():
@@ -508,19 +523,32 @@ def connect(
508523

509524
## Readout node
510525
if "r" in self.gui_type:
511-
(_, majority_readout_node, _) = self.streamer_graph.sequential(
512-
[
513-
spike_collection_node,
514-
samna.graph.JitMajorityReadout(samna.ui.Event),
515-
streamer_node,
516-
]
517-
)
518-
majority_readout_node.set_feature_count(self.feature_count)
519-
majority_readout_node.set_default_feature(self.readout_default_return_value)
520-
majority_readout_node.set_threshold_low(self.readout_default_threshold_low)
521-
majority_readout_node.set_threshold_high(
522-
self.readout_default_threshold_high
523-
)
526+
if self.readout_node == "JitMajorityReadout":
527+
(_, majority_readout_node, _) = self.streamer_graph.sequential(
528+
[
529+
spike_collection_node,
530+
samna.graph.JitMajorityReadout(samna.ui.Event),
531+
streamer_node,
532+
]
533+
)
534+
majority_readout_node.set_feature_count(self.feature_count)
535+
majority_readout_node.set_default_feature(
536+
self.readout_default_return_value
537+
)
538+
majority_readout_node.set_threshold_low(
539+
self.readout_default_threshold_low
540+
)
541+
majority_readout_node.set_threshold_high(
542+
self.readout_default_threshold_high
543+
)
544+
else:
545+
(_, majority_readout_node, _) = self.streamer_graph.sequential(
546+
[
547+
spike_collection_node,
548+
self.readout_node,
549+
streamer_node,
550+
]
551+
)
524552

525553
## Readout layer visualization
526554
if "o" in self.gui_type:
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import Optional
2+
3+
import samna
4+
5+
6+
def majority_readout_filter(
7+
feature_count: int,
8+
default_feature: Optional[int] = None,
9+
detection_threshold: int = 0,
10+
threshold_low: int = 0,
11+
threshold_high: Optional[int] = None,
12+
):
13+
"""
14+
The default reaodut filter of samna's visualizer counts the total
15+
number of events received per timestep to decide whether a detection
16+
should be made or not.
17+
18+
The filter defined here allows for an additional `detection_threshold`
19+
parameter which is compared to the number of spikes of the most
20+
active class.
21+
In other words, for a class to be detected, there needs to be
22+
a minimum number of spikes for this class.
23+
"""
24+
25+
jit_src = f"""
26+
using InputT = speck2f::event::Spike;
27+
using OutputT = ui::Event;
28+
using ReadoutT = ui::Readout;
29+
30+
template<typename Spike>
31+
class CustomMajorityReadout : public iris::FilterInterface<std::shared_ptr<const std::vector<Spike>>, std::shared_ptr<const std::vector<OutputT>>> {{
32+
private:
33+
int featureCount = {feature_count};
34+
uint32_t defaultFeature = {default_feature if default_feature is not None else feature_count};
35+
int detectionThreshold = {detection_threshold};
36+
int thresholdLow = {threshold_low};
37+
int thresholdHigh = {threshold_high if threshold_high is not None else "std::numeric_limits<int>::max()"};
38+
39+
public:
40+
void apply() override
41+
{{
42+
while (const auto maybeSpikesPtr = this->receiveInput()) {{
43+
if (0 == featureCount) {{
44+
return;
45+
}}
46+
47+
auto outputCollection = std::make_shared<std::vector<OutputT>>();
48+
if ((*maybeSpikesPtr)->size() >= thresholdLow && (*maybeSpikesPtr)->size() <= thresholdHigh) {{
49+
std::unordered_map<uint32_t, int> sum; // feature -> count
50+
int maxCount = 0;
51+
uint32_t maxCountFeature = 0;
52+
int maxCountNum = 0;
53+
54+
for (const auto& spike : (**maybeSpikesPtr)) {{
55+
sum[spike.feature]++;
56+
}}
57+
58+
for (const auto& [feature, count] : sum) {{
59+
if (feature >= featureCount) {{
60+
continue;
61+
}}
62+
63+
if (count > maxCount) {{
64+
maxCount = count;
65+
maxCountFeature = feature;
66+
maxCountNum = 1;
67+
}}
68+
else if (count == maxCount) {{
69+
maxCountNum++;
70+
}}
71+
}}
72+
73+
if (maxCount > detectionThreshold && 1 == maxCountNum) {{
74+
outputCollection->emplace_back(ReadoutT{{maxCountFeature}});
75+
}}
76+
else {{
77+
outputCollection->emplace_back(ReadoutT{{defaultFeature}});
78+
}}
79+
}}
80+
else {{
81+
outputCollection->emplace_back(ReadoutT{{defaultFeature}});
82+
}}
83+
this->forwardResult(std::move(outputCollection));
84+
}}
85+
}}
86+
}};
87+
"""
88+
return samna.graph.JitFilter("CustomMajorityReadout", jit_src)

tests/test_dynapcnn/test_visualizer.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
from itertools import product
2+
from typing import Callable, Union
3+
14
import pytest
25
import samna
6+
from custom_jit_filters import majority_readout_filter as custom_filter
37
from hw_utils import find_open_devices, is_any_samna_device_connected
48

59
from sinabs.backend.dynapcnn.dynapcnn_visualizer import DynapcnnVisualizer
@@ -13,17 +17,32 @@ def X_available() -> bool:
1317
return p.returncode == 0
1418

1519

20+
vis_init_args = product(
21+
(True, False),
22+
(True, False),
23+
("JitMajorityReadout", custom_filter),
24+
)
25+
26+
1627
@pytest.mark.skipif(
1728
True,
1829
reason="A window needs to pop. Needs UI. Makes sense to check this test manually",
1930
)
20-
def test_visualizer_initialization():
31+
@pytest.mark.parametrize("spike_count_plot,readout_plot,readout_node", vis_init_args)
32+
def test_visualizer_initialization(
33+
spike_count_plot: bool, readout_plot: bool, readout_node: Union[str, Callable]
34+
):
2135
dvs_shape = (128, 128)
2236
spike_collection_interval = 500
2337
visualizer_id = 3
2438

2539
visualizer = DynapcnnVisualizer(
26-
dvs_shape=dvs_shape, spike_collection_interval=spike_collection_interval
40+
dvs_shape=dvs_shape,
41+
spike_collection_interval=spike_collection_interval,
42+
add_spike_count_plot=spike_count_plot,
43+
add_readout_plot=readout_plot,
44+
readout_node=readout_node,
45+
readout_images=[],
2746
)
2847
visualizer.create_visualizer_process(
2948
f"tcp://0.0.0.0:{visualizer.samna_visualizer_port}"

0 commit comments

Comments
 (0)