@numba.njit(fastmath=True)
def sos_filter_sample(val, sos, zi):
    """Push a single sample through a cascade of second-order sections.

    Each section is evaluated in Direct Form II Transposed and its output is
    fed to the next section.

    Args:
        val: Input sample for the current time step.
        sos: Coefficient array indexed as ``sos[section]``; each row unpacks
            into ``b0, b1, b2, a0, a1, a2``. ``a0`` is read but never used,
            i.e. the coefficients are taken to be already normalized.
        zi: Per-section filter state indexed as ``zi[section, 0..1]``.
            Updated in place so the next call continues from this sample.

    Returns:
        The output of the last section (``val`` itself if there are no
        sections).
    """
    n_sections = sos.shape[0]
    sample = val
    for sec in range(n_sections):
        # a0 is unpacked only to keep the six-column layout; it is not applied.
        b0, b1, b2, _a0, a1, a2 = sos[sec]

        # Snapshot both delay elements before either one is overwritten.
        z1 = zi[sec, 0]
        z2 = zi[sec, 1]

        # y[n] = b0*x[n] + z1[n-1]
        out = b0 * sample + z1

        # z1[n] = b1*x[n] - a1*y[n] + z2[n-1]
        zi[sec, 0] = b1 * sample - a1 * out + z2
        # z2[n] = b2*x[n] - a2*y[n]
        zi[sec, 1] = b2 * sample - a2 * out

        # This section's output is the next section's input.
        sample = out
    return sample
_compile_kernel(self, node: FusionNode, has_aux: bool): # 辅助参数名列表 (在 kernel 签名中使用) aux_arg_names = [] + pre_loop_code = [] # Code to execute before inner loop (inside j loop) + for op in node.fused_nodes: if op.operation == "detrend": # 需要 slope 和 intercept @@ -133,24 +177,36 @@ def _compile_kernel(self, node: FusionNode, has_aux: bool): ops_code.append(f"val = val * {factor}") elif op.operation == "bandpass": - # Placeholder: Pass-through - # TODO: Implement IIR/FIR filter state or sosfilt - pass + sos_name = f"aux_{aux_idx}" + aux_arg_names.extend([sos_name]) + aux_idx += 1 + + # Setup state variable for this channel + zi_name = f"zi_{aux_idx}" + # Get n_sections from sos shape at runtime or assuming fixed? + # We can use sos_name.shape[0] + pre_loop_code.append( + f"{zi_name} = np.zeros(({sos_name}.shape[0], 2), dtype=inp.dtype)" + ) - # TODO: Add filter support (requires stateful loop or simple FIR/IIR) + # Apply filter + ops_code.append(f"val = sos_filter_sample(val, {sos_name}, {zi_name})") kernel_body = "\n ".join(ops_code) + pre_loop_body = "\n ".join(pre_loop_code) # 构建函数签名 base_args = ["inp", "out"] all_args = base_args + aux_arg_names args_str = ", ".join(all_args) + # Swapped loops: Parallel over cols (channels), Serial over rows (time) code = f""" def fused_kernel({args_str}): rows, cols = inp.shape - for i in prange(rows): - for j in range(cols): + for j in prange(cols): + {pre_loop_body} + for i in range(rows): val = inp[i, j] {kernel_body} out[i, j] = val @@ -160,6 +216,8 @@ def fused_kernel({args_str}): "numba": numba, "prange": numba.prange, "abs": abs, + "np": np, + "sos_filter_sample": sos_filter_sample } exec(code, global_scope) diff --git a/DASMatrix/processing/planner/optimizer.py b/DASMatrix/processing/planner/optimizer.py index 1aa5871..36da49f 100644 --- a/DASMatrix/processing/planner/optimizer.py +++ b/DASMatrix/processing/planner/optimizer.py @@ -20,7 +20,6 @@ "envelope", "normalize", # 滤波器操作 (需要 SciPy) - "bandpass", "lowpass", 
"highpass", "notch", @@ -33,6 +32,7 @@ "demean", "abs", "scale", + "bandpass", } diff --git a/tests/unit/test_numba_bandpass.py b/tests/unit/test_numba_bandpass.py new file mode 100644 index 0000000..62d52fd --- /dev/null +++ b/tests/unit/test_numba_bandpass.py @@ -0,0 +1,114 @@ +"""Test Numba implementation of bandpass filter.""" + +import numpy as np +from scipy import signal + +from DASMatrix.core.computation_graph import ( + FusionNode, + OperationNode, + SourceNode, +) +from DASMatrix.processing.backends.numba_backend import NumbaBackend + + +class TestNumbaBandpass: + """Test Bandpass implementation in Numba Backend.""" + + def test_bandpass_correctness(self): + """Verify that Numba bandpass matches Scipy implementation (causal sosfilt).""" + # Setup data + rows = 2000 # Time + cols = 5 # Channels + fs = 1000.0 + # Create random signal + np.random.seed(42) + data = np.random.randn(rows, cols).astype(np.float64) + + # Create Graph Node + source = SourceNode(data, name="source") + # bandpass: low=10, high=100 + low = 10.0 + high = 100.0 + order = 4 + + op_bandpass = OperationNode( + operation="bandpass", + inputs=[source], + name="bandpass_op", + kwargs={"low": low, "high": high, "order": order, "fs": fs} + ) + + # Fusion Node + fusion_node = FusionNode([op_bandpass], name="fused_bandpass") + + # Execute with NumbaBackend + backend = NumbaBackend() + result_numba = backend.execute(fusion_node, data) + + # Execute with Scipy (Reference) + nyq = 0.5 * fs + sos = signal.butter( + order, [low / nyq, high / nyq], btype="band", output="sos" + ) + # We implemented sosfilt (Direct Form II Transposed), causal. + # Scipy default sosfilt is axis=-1, but here data is (Time, Channel). + # We need axis=0. + result_scipy = signal.sosfilt(sos, data, axis=0) + + # Comparison + # Since we use float64, precision should be high. 
+ # sosfilt might have slight numerical differences + # due to implementation details or order, + # but Numba implementation follows standard DF-II Transposed logic. + + np.testing.assert_allclose( + result_numba, + result_scipy, + rtol=1e-5, + atol=1e-5, + err_msg="Numba bandpass result does not match Scipy sosfilt" + ) + + def test_bandpass_chain(self): + """Verify bandpass works in a chain (e.g. detrend -> bandpass -> abs).""" + rows = 1000 + cols = 3 + fs = 500.0 + data = np.random.randn(rows, cols) * 10.0 + 100.0 # Offset + + # Add trend + t = np.arange(rows) + trend = t.reshape(-1, 1) * 0.1 + data += trend + + # 1. Detrend + # 2. Bandpass + # 3. Abs + + # Reference Scipy Calculation + # Detrend + step1 = signal.detrend(data, axis=0) + + # Bandpass + sos = signal.butter(4, [5.0/250.0, 50.0/250.0], btype="band", output="sos") + step2 = signal.sosfilt(sos, step1, axis=0) + + # Abs + expected = np.abs(step2) + + # Numba Calculation + source = SourceNode(data) + op1 = OperationNode("detrend", [source]) + op2 = OperationNode( + "bandpass", + [op1], + kwargs={"low": 5.0, "high": 50.0, "order": 4, "fs": fs}, + ) + op3 = OperationNode("abs", [op2]) + + fusion_node = FusionNode([op1, op2, op3]) + + backend = NumbaBackend() + result_numba = backend.execute(fusion_node, data) + + np.testing.assert_allclose(result_numba, expected, rtol=1e-5, atol=1e-5)