Skip to content

Commit bbe62d7

Browse files
Ramdam17m2march
andcommitted
feat(accorr): add numba/torch optimizations with unified optimization API
Integrate accorr optimizations (numba JIT, PyTorch GPU) from PR #246 into the modular sync architecture. Unify the API around a single `optimization` parameter (None, 'auto', 'numba', 'torch') with graceful fallback and warnings when backends are unavailable. - BaseMetric: add _resolve_optimization() with fallback cascade - ACCorr: numpy/numba/torch backends with precompute optimization - All metrics: remove dead dispatch code for numpy-only metrics - compute_sync: pass optimization directly to get_metric() - Tests: reference-based validation for all backends, mocked fallbacks - Add optional dependency groups (optim_torch, optim_numba) Co-Authored-By: Martín A. Miguel <m2march@users.noreply.github.com>
1 parent 665e481 commit bbe62d7

20 files changed

+1349
-321
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The **Hy**perscanning **Py**thon **P**ipeline
1515
## Contributors
1616

1717
Original authors: Florence BRUN, Anaël AYROLLES, Phoebe CHEN, Amir DJALOVSKI, Yann BEAUXIS, Suzanne DIKKER, Guillaume DUMAS
18-
New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice FORTIN, Ghazaleh RANJBARAN, Quentin MOREAU, Caitriona DOUGLAS, Franck PORTEOUS, Jonas MAGO, Juan C. AVENDANO, Julie BONNAIRE
18+
New contributors: Ryssa MOFFAT, Marine Gautier MARTINS, Rémy RAMADOUR, Patrice FORTIN, Ghazaleh RANJBARAN, Quentin MOREAU, Caitriona DOUGLAS, Franck PORTEOUS, Jonas MAGO, Juan C. AVENDANO, Julie BONNAIRE, Martín A. MIGUEL
1919

2020
## Installation
2121

hypyp/analyses.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import statsmodels.stats.multitest
2121
import copy
2222
from collections import namedtuple
23-
from typing import Union, List, Tuple
23+
from typing import Union, List, Tuple, Optional
2424
import matplotlib.pyplot as plt
2525
from tqdm import tqdm
2626

@@ -437,18 +437,18 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int,
437437
return result
438438

439439

440-
def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True) -> np.ndarray:
440+
def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True,
441+
optimization: Optional[str] = None) -> np.ndarray:
441442
"""
442443
Computes frequency-domain connectivity measures from analytic signals.
443-
444+
444445
This function calculates various connectivity metrics between all possible
445446
channel pairs based on the input complex-valued signals.
446-
447+
447448
Parameters
448449
----------
449450
complex_signal : np.ndarray
450451
Complex analytic signals with shape (2, n_epochs, n_channels, n_freq_bins, n_times)
451-
452452
mode : str
453453
Connectivity measure to compute. Options:
454454
- 'envelope_corr' or 'envcorr': envelope correlation - correlation between signal envelopes
@@ -460,40 +460,24 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
460460
- 'imaginary_coh' or 'imcoh': imaginary coherence - imaginary part of coherence (volume conduction resistant)
461461
- 'pli': phase lag index - asymmetry of phase difference distribution
462462
- 'wpli': weighted phase lag index - weighted version of PLI with improved properties
463-
464463
epochs_average : bool, optional
465464
If True, connectivity values are averaged across epochs (default)
466465
If False, epoch-by-epoch connectivity is preserved
467-
466+
optimization : str, optional
467+
Optimization strategy. May require extra dependencies.
468+
Currently only available for 'accorr'. Options:
469+
- None: standard numpy implementation (default)
470+
- 'auto': best available (torch > numba > numpy)
471+
- 'numba': numba JIT compilation (falls back to numpy if unavailable)
472+
- 'torch': PyTorch with auto-detected GPU (falls back gracefully)
473+
468474
Returns
469475
-------
470476
con : np.ndarray
471477
Connectivity matrix with shape:
472478
- If epochs_average=True: (n_freq, 2*n_channels, 2*n_channels)
473479
- If epochs_average=False: (n_freq, n_epochs, 2*n_channels, 2*n_channels)
474-
475-
Notes
476-
-----
477-
Mathematical formulations for each connectivity measure:
478-
479-
- PLV: |⟨e^(i(φₓ-φᵧ))⟩|
480-
Measures consistency of phase differences across time
481-
482-
- Envelope correlation: corr(env(x), env(y))
483-
Pearson correlation between signal envelopes
484-
485-
- Coherence: |⟨XY*⟩|²/(⟨|X|²⟩⟨|Y|²⟩)
486-
Normalized cross-spectrum
487-
488-
- Imaginary coherence: |Im(⟨XY*⟩)|/√(⟨|X|²⟩⟨|Y|²⟩)
489-
Takes only imaginary part which is less affected by volume conduction
490-
491-
- PLI: |⟨sign(Im(XY*))⟩|
492-
Quantifies asymmetry in phase difference distribution
493-
494-
- wPLI: |⟨|Im(XY*)|sign(Im(XY*))⟩|/⟨|Im(XY*)|⟩
495-
Weighted version that downweights phase differences near 0 or π
496-
480+
497481
Raises
498482
------
499483
ValueError
@@ -506,7 +490,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
506490
# calculate all epochs at once, the only downside is that the disk may not have enough space
507491
complex_signal = complex_signal.transpose((1, 3, 0, 2, 4)).reshape(n_epoch, n_freq, 2 * n_ch, n_samp)
508492
transpose_axes = (0, 1, 3, 2)
509-
493+
510494
# Normalize mode names (handle aliases)
511495
mode_lower = mode.lower()
512496
mode_map = {
@@ -515,10 +499,10 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T
515499
'imaginary_coh': 'imcoh',
516500
}
517501
mode_normalized = mode_map.get(mode_lower, mode_lower)
518-
502+
519503
# Get the metric from the sync module
520504
try:
521-
metric = get_metric(mode_normalized)
505+
metric = get_metric(mode_normalized, optimization=optimization)
522506
con = metric.compute(complex_signal, n_samp, transpose_axes)
523507
except ValueError:
524508
raise ValueError(f'Metric type "{mode}" not supported.')
@@ -1190,7 +1174,7 @@ def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
11901174
11911175
.. deprecated:: 0.5.0
11921176
This function is deprecated and will be removed in version 1.0.0.
1193-
Use :class:`hypyp.sync.ACorr` instead.
1177+
Use :class:`hypyp.sync.ACCorr` instead.
11941178
11951179
This function calculates the adjusted circular correlation coefficient between
11961180
all channel pairs. It uses a vectorized computation for the numerator and an
@@ -1241,7 +1225,7 @@ def _accorr_hybrid(complex_signal: np.ndarray, epochs_average: bool = True,
12411225
"""
12421226
warnings.warn(
12431227
"_accorr_hybrid is deprecated and will be removed in version 1.0.0. "
1244-
"Use hypyp.sync.ACorr instead.",
1228+
"Use hypyp.sync.ACCorr instead.",
12451229
DeprecationWarning,
12461230
stacklevel=2
12471231
)

hypyp/sync/__init__.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
to measure neural synchronization between participants.
99
"""
1010

11-
from .base import BaseMetric, detect_backend, multiply_conjugate, multiply_conjugate_time, multiply_product
11+
from typing import Optional
12+
13+
from .base import BaseMetric, multiply_conjugate, multiply_conjugate_time, multiply_product
1214
from .plv import PLV
1315
from .ccorr import CCorr
14-
from .accorr import ACorr
16+
from .accorr import ACCorr
1517
from .coh import Coh
1618
from .imaginary_coh import ImCoh
1719
from .pli import PLI
@@ -23,7 +25,7 @@
2325
METRICS = {
2426
'plv': PLV,
2527
'ccorr': CCorr,
26-
'accorr': ACorr,
28+
'accorr': ACCorr,
2729
'coh': Coh,
2830
'imcoh': ImCoh,
2931
'pli': PLI,
@@ -35,14 +37,13 @@
3537
__all__ = [
3638
# Base classes and utilities
3739
'BaseMetric',
38-
'detect_backend',
3940
'multiply_conjugate',
4041
'multiply_conjugate_time',
4142
'multiply_product',
4243
# Metric classes
4344
'PLV',
4445
'CCorr',
45-
'ACorr',
46+
'ACCorr',
4647
'Coh',
4748
'ImCoh',
4849
'PLI',
@@ -55,39 +56,38 @@
5556
]
5657

5758

58-
def get_metric(mode: str, backend: str = 'numpy') -> BaseMetric:
59+
def get_metric(mode: str, optimization: Optional[str] = None) -> BaseMetric:
5960
"""
6061
Get a connectivity metric instance by name.
61-
62+
6263
Parameters
6364
----------
6465
mode : str
6566
Name of the connectivity metric. One of: 'plv', 'ccorr', 'accorr',
6667
'coh', 'imcoh', 'pli', 'wpli', 'envcorr', 'powcorr'.
67-
68-
backend : str, optional
69-
Computation backend to use. One of: 'numpy', 'numba', 'torch'.
70-
Default is 'numpy'.
71-
68+
optimization : str, optional
69+
Optimization strategy. Options: None, 'auto', 'numba', 'torch'.
70+
See BaseMetric for fallback behavior.
71+
7272
Returns
7373
-------
7474
metric : BaseMetric
7575
An instance of the requested metric class.
76-
76+
7777
Raises
7878
------
7979
ValueError
8080
If the mode is not recognized.
81-
81+
8282
Examples
8383
--------
8484
>>> from hypyp.sync import get_metric
85-
>>> plv = get_metric('plv', backend='numpy')
86-
>>> result = plv.compute(complex_signal, n_samp, transpose_axes)
85+
>>> accorr = get_metric('accorr', optimization='torch')
86+
>>> result = accorr.compute(complex_signal, n_samp, transpose_axes)
8787
"""
8888
mode_lower = mode.lower()
8989
if mode_lower not in METRICS:
9090
available = ', '.join(METRICS.keys())
9191
raise ValueError(f"Unknown metric mode '{mode}'. Available: {available}")
92-
93-
return METRICS[mode_lower](backend=backend)
92+
93+
return METRICS[mode_lower](optimization=optimization)

0 commit comments

Comments
 (0)