Skip to content

Commit ad03919

Browse files
committed
Move stat functions from core to stats
1 parent 46a18b1 commit ad03919

6 files changed

Lines changed: 73 additions & 6 deletions

File tree

docs/source/api/api_statistics.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,17 @@ Statistics
99
:no-members:
1010
:no-inherited-members:
1111

12-
Random effect (rfx)
12+
Non-parametric statistics
13+
+++++++++++++++++++++++++
14+
15+
.. autosummary::
16+
:toctree: generated/
17+
18+
permute_mi_vector
19+
permute_mi_trials
20+
bootstrap_partitions
21+
22+
Random-effect (rfx)
1323
+++++++++++++++++++
1424

1525
.. autosummary::

frites/core/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,3 @@
1313
from .gcmi_nd import (mi_nd_gg, mi_model_nd_gd, cmi_nd_ggg, gcmi_nd_cc, # noqa
1414
gcmi_model_nd_cd, gccmi_nd_ccnd, gccmi_model_nd_cdnd,
1515
gccmi_nd_ccc, cmi_nd_ggd)
16-
from .mi_stats import (permute_mi_vector, permute_mi_trials) # noqa

frites/stats/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
Most of those stastical functions are using
1111
`MNE Python <https://mne.tools/stable/index.html>`_
1212
"""
13-
from .stats_param import (ttest_1samp, rfx_ttest) # noqa
1413
from .stats_mcp import (testwise_correction_mcp, cluster_correction_mcp, # noqa
1514
cluster_threshold)
15+
from .stats_nonparam import (permute_mi_vector, permute_mi_trials, # noqa
16+
bootstrap_partitions)
17+
from .stats_param import (ttest_1samp, rfx_ttest, dist_to_ci) # noqa
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Utility functions for stat evaluation."""
22
import numpy as np
33

4+
from frites.utils import nonsorted_unique
5+
from frites.dataset.ds_utils import multi_to_uni_conditions
6+
47

58
def permute_mi_vector(y, suj, mi_type='cc', inference='rfx', n_perm=1000,
69
random_state=None):
@@ -93,3 +96,56 @@ def permute_mi_trials(suj, inference='rfx', n_perm=1000, random_state=None):
9396
assert len(y_p) == n_perm
9497

9598
return y_p
99+
100+
101+
def bootstrap_partitions(n_epochs, *groups, n_partitions=200,
102+
random_state=None):
103+
"""Generate partitions for bootstrap.
104+
105+
Parameters
106+
----------
107+
n_epochs : int
108+
Number of epochs
109+
groups : array_like
110+
Groups within which permutations are performed. Should be arrays of
111+
shape (n_epochs,) and of type int
112+
n_partitions : int | 200
113+
Number of partitions to get
114+
random_state : int | None
115+
Fix the random state of the machine (use it for reproducibility). If
116+
None, a random state is randomly assigned.
117+
118+
Returns
119+
-------
120+
partitions : list
121+
List of arrays describing the partitions within groups or not
122+
"""
123+
from sklearn.utils import resample
124+
125+
# define the random state
126+
rnd = np.random.randint(1000) if not isinstance(
127+
random_state, int) else random_state
128+
129+
# manage groups
130+
if not len(groups):
131+
groups = np.zeros((n_epochs), dtype=int)
132+
else:
133+
if len(groups) == 1:
134+
groups = groups[0]
135+
else:
136+
groups = multi_to_uni_conditions(
137+
[np.stack(groups, axis=1)], var_name='boot', verbose=False)[0]
138+
u_groups = nonsorted_unique(groups)
139+
140+
# generate the partitions
141+
partitions = []
142+
for n_p in range(n_partitions):
143+
_part = np.arange(n_epochs)
144+
for n_g in u_groups:
145+
is_group = groups == n_g
146+
n_group = is_group.sum()
147+
_part[is_group] = resample(
148+
_part[is_group], n_samples=n_group, random_state=rnd + n_p)
149+
partitions.append(_part)
150+
151+
return partitions
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Test high-level mutual information functions."""
22
import numpy as np
33

4-
from frites.core import permute_mi_vector, permute_mi_trials
4+
from frites.stats import permute_mi_vector, permute_mi_trials
55

66
rnd = np.random.RandomState(0)
77

@@ -14,7 +14,7 @@
1414
suj = np.round(np.linspace(0, n_suj, n_epochs)).astype(int)
1515

1616

17-
class TestMiStats(object): # noqa
17+
class TestNonParam(object): # noqa
1818

1919
def test_permute_mi_vector(self):
2020
"""Test function permute_mi_vector."""

frites/workflow/wf_conn_comod.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mne.utils import ProgressBar
66

77
from frites.io import (set_log_level, logger)
8-
from frites.core import permute_mi_trials
8+
from frites.stats import permute_mi_trials
99
from frites.utils import parallel_func, kernel_smoothing
1010
from frites.workflow.wf_stats import WfStats
1111
from frites.workflow.wf_base import WfBase

0 commit comments

Comments
 (0)