Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions DASMatrix/processing/backends/numba_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,36 @@

import numba
import numpy as np
from scipy import signal

from ...core.computation_graph import FusionNode


@numba.njit(fastmath=True)
def sos_filter_sample(val, sos, zi):
# sos: (n_sections, 6)
# zi: (n_sections, 2)
# Direct Form II Transposed
for s in range(sos.shape[0]):
b0, b1, b2, a0, a1, a2 = sos[s]

# In Scipy, a0 is usually 1.0. If not, we should normalize.
# But usually sos output from scipy is normalized.

x = val
# y[n] = b0*x[n] + z1[n-1]
y = b0 * x + zi[s, 0]

# z1[n] = b1*x[n] - a1*y[n] + z2[n-1]
zi[s, 0] = b1 * x - a1 * y + zi[s, 1]

# z2[n] = b2*x[n] - a2*y[n]
zi[s, 1] = b2 * x - a2 * y

val = y
return val


class NumbaBackend:
"""Numba 高性能计算后端。"""

Expand Down Expand Up @@ -53,7 +79,7 @@ def execute(self, node: FusionNode, data: np.ndarray) -> np.ndarray:
def _prepare_aux_params(
self, node: FusionNode, data: np.ndarray
) -> List[np.ndarray]:
"""为需要归约的算子准备参数 (如 mean, trend)。"""
"""为需要归约的算子准备参数 (如 mean, trend, filter coeffs)。"""
params = []
n_samples, n_channels = data.shape

Expand Down Expand Up @@ -93,6 +119,22 @@ def _prepare_aux_params(
means = np.mean(data, axis=0)
params.append(means.astype(data.dtype))

elif op.operation == "bandpass":
# Calculate SOS coefficients
low = op.kwargs.get("low")
high = op.kwargs.get("high")
order = op.kwargs.get("order", 4)
fs = op.kwargs.get("fs", 1000.0) # Default to 1000 if not provided

nyq = 0.5 * fs
sos = signal.butter(
order, [low / nyq, high / nyq], btype="band", output="sos"
)
# SOS shape: (n_sections, 6)
# It is global for all channels,
# but Numba kernel needs it as an argument
params.append(sos.astype(data.dtype))

return params

def _get_kernel_signature(self, node: FusionNode) -> str:
Expand All @@ -108,6 +150,8 @@ def _compile_kernel(self, node: FusionNode, has_aux: bool):
# 辅助参数名列表 (在 kernel 签名中使用)
aux_arg_names = []

pre_loop_code = [] # Code to execute before inner loop (inside j loop)

for op in node.fused_nodes:
if op.operation == "detrend":
# 需要 slope 和 intercept
Expand All @@ -133,24 +177,36 @@ def _compile_kernel(self, node: FusionNode, has_aux: bool):
ops_code.append(f"val = val * {factor}")

elif op.operation == "bandpass":
# Placeholder: Pass-through
# TODO: Implement IIR/FIR filter state or sosfilt
pass
sos_name = f"aux_{aux_idx}"
aux_arg_names.extend([sos_name])
aux_idx += 1

# Setup state variable for this channel
zi_name = f"zi_{aux_idx}"
# Get n_sections from sos shape at runtime or assuming fixed?
# We can use sos_name.shape[0]
pre_loop_code.append(
f"{zi_name} = np.zeros(({sos_name}.shape[0], 2), dtype=inp.dtype)"
)

# TODO: Add filter support (requires stateful loop or simple FIR/IIR)
# Apply filter
ops_code.append(f"val = sos_filter_sample(val, {sos_name}, {zi_name})")

kernel_body = "\n ".join(ops_code)
pre_loop_body = "\n ".join(pre_loop_code)

# 构建函数签名
base_args = ["inp", "out"]
all_args = base_args + aux_arg_names
args_str = ", ".join(all_args)

# Swapped loops: Parallel over cols (channels), Serial over rows (time)
code = f"""
def fused_kernel({args_str}):
rows, cols = inp.shape
for i in prange(rows):
for j in range(cols):
for j in prange(cols):
{pre_loop_body}
for i in range(rows):
val = inp[i, j]
{kernel_body}
out[i, j] = val
Expand All @@ -160,6 +216,8 @@ def fused_kernel({args_str}):
"numba": numba,
"prange": numba.prange,
"abs": abs,
"np": np,
"sos_filter_sample": sos_filter_sample
}

exec(code, global_scope)
Expand Down
2 changes: 1 addition & 1 deletion DASMatrix/processing/planner/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"envelope",
"normalize",
# 滤波器操作 (需要 SciPy)
"bandpass",
"lowpass",
"highpass",
"notch",
Expand All @@ -33,6 +32,7 @@
"demean",
"abs",
"scale",
"bandpass",
}


Expand Down
114 changes: 114 additions & 0 deletions tests/unit/test_numba_bandpass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Test Numba implementation of bandpass filter."""

import numpy as np
from scipy import signal

from DASMatrix.core.computation_graph import (
FusionNode,
OperationNode,
SourceNode,
)
from DASMatrix.processing.backends.numba_backend import NumbaBackend


class TestNumbaBandpass:
"""Test Bandpass implementation in Numba Backend."""

def test_bandpass_correctness(self):
"""Verify that Numba bandpass matches Scipy implementation (causal sosfilt)."""
# Setup data
rows = 2000 # Time
cols = 5 # Channels
fs = 1000.0
# Create random signal
np.random.seed(42)
data = np.random.randn(rows, cols).astype(np.float64)

# Create Graph Node
source = SourceNode(data, name="source")
# bandpass: low=10, high=100
low = 10.0
high = 100.0
order = 4

op_bandpass = OperationNode(
operation="bandpass",
inputs=[source],
name="bandpass_op",
kwargs={"low": low, "high": high, "order": order, "fs": fs}
)

# Fusion Node
fusion_node = FusionNode([op_bandpass], name="fused_bandpass")

# Execute with NumbaBackend
backend = NumbaBackend()
result_numba = backend.execute(fusion_node, data)

# Execute with Scipy (Reference)
nyq = 0.5 * fs
sos = signal.butter(
order, [low / nyq, high / nyq], btype="band", output="sos"
)
# We implemented sosfilt (Direct Form II Transposed), causal.
# Scipy default sosfilt is axis=-1, but here data is (Time, Channel).
# We need axis=0.
result_scipy = signal.sosfilt(sos, data, axis=0)

# Comparison
# Since we use float64, precision should be high.
# sosfilt might have slight numerical differences
# due to implementation details or order,
# but Numba implementation follows standard DF-II Transposed logic.

np.testing.assert_allclose(
result_numba,
result_scipy,
rtol=1e-5,
atol=1e-5,
err_msg="Numba bandpass result does not match Scipy sosfilt"
)

def test_bandpass_chain(self):
"""Verify bandpass works in a chain (e.g. detrend -> bandpass -> abs)."""
rows = 1000
cols = 3
fs = 500.0
data = np.random.randn(rows, cols) * 10.0 + 100.0 # Offset

# Add trend
t = np.arange(rows)
trend = t.reshape(-1, 1) * 0.1
data += trend

# 1. Detrend
# 2. Bandpass
# 3. Abs

# Reference Scipy Calculation
# Detrend
step1 = signal.detrend(data, axis=0)

# Bandpass
sos = signal.butter(4, [5.0/250.0, 50.0/250.0], btype="band", output="sos")
step2 = signal.sosfilt(sos, step1, axis=0)

# Abs
expected = np.abs(step2)

# Numba Calculation
source = SourceNode(data)
op1 = OperationNode("detrend", [source])
op2 = OperationNode(
"bandpass",
[op1],
kwargs={"low": 5.0, "high": 50.0, "order": 4, "fs": fs},
)
op3 = OperationNode("abs", [op2])

fusion_node = FusionNode([op1, op2, op3])

backend = NumbaBackend()
result_numba = backend.execute(fusion_node, data)

np.testing.assert_allclose(result_numba, expected, rtol=1e-5, atol=1e-5)
Loading