Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hypyp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from importlib.metadata import version
from hypyp import analyses, prep, stats, utils, viz
from hypyp import analyses, prep, stats, utils, viz, fnirs, eeg, multimodal, ext, signal

__version__ = version("hypyp")
__all__ = ["analyses", "prep", "stats", "utils", "viz", "fnirs", "ext"]
__all__ = ["analyses", "prep", "stats", "utils", "viz", "fnirs", "eeg", "multimodal", "ext", "signal"]
1 change: 1 addition & 0 deletions hypyp/analyses.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def pow(epochs: mne.Epochs, fmin: float, fmax: float, n_fft: int, n_per_seg: int
tmin=None, tmax=None, method='welch', picks='all', exclude=[],
proj=False, remove_dc=True, n_jobs=1)
spectrum = EpochsSpectrum(epochs, **kwargs)
print(spectrum.freqs)
psds = spectrum.get_data()
freq_list = spectrum.freqs

Expand Down
147 changes: 147 additions & 0 deletions hypyp/connectivity/connectivities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from collections import OrderedDict

import numpy as np
import pandas as pd

from ..dataclasses.freq_band import FreqBands
from .connectivity import Connectivity
from ..plots import plot_coherence_matrix

class Connectivities():
mode: str
inter: list[Connectivity]
intras: list[list[Connectivity]]
is_averaged: bool

@staticmethod
def matrix_is_averaged(matrix: np.ndarray):
"""
Expected shape when averaged: (n_bands, 2*n_ch, 2*n_ch)
Expected shape when not averaged: (n_bands, n_epochs, 2*n_ch, 2*n_ch)
"""
if len(matrix.shape) == 3:
return True
elif len(matrix.shape) == 4:
return False
else:
raise ValueError(f"Received matrix have an invalid shape: {np.shape(matrix)}")


def __init__(
self,
mode: str,
freq_bands: FreqBands,
matrix: np.ndarray,
ch_names: list[str] | tuple[list[str], list[str]],
):
self.mode = mode
self.inter = []
self.intras = [[], []]
self.is_averaged = Connectivities.matrix_is_averaged(matrix)

# Determine the number of channels
n_ch = matrix.shape[-1] // 2

if self.is_averaged:
# add a "epoch" dimension to simplify code below, even of there is a single value when averaged
matrix = np.expand_dims(matrix, axis=1)

if not isinstance(ch_names, tuple):
ch_names = (ch_names, ch_names)

for i, freq_band in enumerate(freq_bands):
range_axis_1 = slice(0, n_ch)
range_axis_2 = slice(n_ch, 2*n_ch)
values = matrix[i, :, range_axis_1, range_axis_2]
C = (values - np.mean(values[:])) / np.std(values[:])

# drop the epoch dimension when averaged
if self.is_averaged:
values = np.squeeze(values, axis=0)

self.inter.append(Connectivity(freq_band, values, C, ch_names))

for subject_idx in [0, 1]:
for i, freq_band in enumerate(freq_bands):
range_axis_1 = slice((subject_idx * n_ch), ((subject_idx + 1) * n_ch))
range_axis_2 = range_axis_1

values = matrix[i, :, range_axis_1, range_axis_2]
for epoch_idx in range(values.shape[0]):
values[epoch_idx] -= np.diag(np.diag(values[epoch_idx]))
C = (values - np.mean(values[:])) / np.std(values[:])

# drop the epoch dimension when averaged
if self.is_averaged:
values = np.squeeze(values, axis=0)

ch_names_pair = (ch_names[subject_idx], ch_names[subject_idx])
self.intras[subject_idx].append(Connectivity(freq_band, values, C, ch_names_pair))

@property
def intra1(self) -> list[Connectivity]:
return self.intras[0]

@property
def intra2(self) -> list[Connectivity]:
return self.intras[1]

def get_based_on_subject_id(self, subject_id: int = None):
if subject_id == 0:
raise ValueError("This method expects subject_id starting at 1, not 0")

if subject_id is None:
return self.inter

if subject_id == 1:
return self.intra1

if subject_id == 2:
return self.intra2

raise ValueError(f"Cannot have connectivity of subject_id '{subject_id}'")

def get_for_freq_band(self, freq_band_name, subject_id: int | None):
for connectivity in self.get_based_on_subject_id(subject_id):
if connectivity.freq_band.name == freq_band_name:
return connectivity

raise ValueError(f"Cannot find connectivity for freq_band {freq_band_name}")

def get_inter_for_freq_band(self, freq_band_name):
return self.get_for_freq_band(freq_band_name, None)

def get_intra_for_freq_band(self, freq_band_name, subject_id: int):
return self.get_for_freq_band(freq_band_name, subject_id)

def plot_connectivity_for_freq_band(self, freq_band_name):
conn = self.get_inter_for_freq_band(freq_band_name)
flat = conn.zscore.flatten()
dfs = []
df_inter = pd.DataFrame({
'coherence': flat,
'channel1': np.repeat(conn.ch_names[0], len(conn.ch_names[1])),
'channel2': np.array(conn.ch_names[1] * len(conn.ch_names[0])),
'is_intra': np.full_like(flat, False),
'is_intra_of': np.full_like(flat, None),
})
dfs.append(df_inter)

for subject_id in [1, 2]:
conn = self.get_inter_for_freq_band(freq_band_name, subject_id)
flat = conn.zscore.flatten()
df_intra = pd.DataFrame({
'coherence': flat,
'channel1': np.repeat(conn.ch_names[0], len(conn.ch_names[0])),
'channel2': np.array(conn.ch_names[0] * len(conn.ch_names[0])),
'is_intra': np.full_like(flat, True),
'is_intra_of': np.full_like(flat, subject_id),
})
dfs.append(df_intra)

df = pd.concat(dfs, ignore_index=True)

return plot_coherence_matrix(df, 'subject1', 'subject2', 'channel1', 'channel2', [])



35 changes: 35 additions & 0 deletions hypyp/connectivity/connectivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from dataclasses import dataclass

import numpy as np
from matplotlib.axes import Axes
import matplotlib.pyplot as plt
import seaborn as sns

from ..dataclasses.freq_band import FreqBand

@dataclass
class Connectivity():
freq_band: FreqBand
values: np.ndarray
zscore: np.ndarray
ch_names: tuple[list[str], list[str]]

def plot_zscore(self, ax:Axes = None, title: str = None):
if title is None:
title = f"Z Score {self.freq_band.name}"

if ax is None:
fig, ax = plt.subplots(1, 1)
else:
fig = ax.get_figure()

# If zscore was not averaged, we need to average it for display
if len(self.zscore.shape) == 3:
zscore = np.mean(self.zscore, axis=0)
else:
zscore = self.zscore

sns.heatmap(zscore, xticklabels=self.ch_names[0], yticklabels=self.ch_names[1], cmap='viridis', cbar=True, ax=ax)
ax.set_title(title)
return fig

9 changes: 9 additions & 0 deletions hypyp/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .base_dyad import BaseDyad
from .base_step import BaseStep

__all__ = [
'BaseDyad',
'BaseStep',
]


15 changes: 15 additions & 0 deletions hypyp/core/base_dyad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod

from ..dataclasses.synchrony import SynchronyTimeSeries

class BaseDyad(ABC):
def __init__(self):
pass

@abstractmethod
def get_synchrony_time_series() -> SynchronyTimeSeries:
pass

@abstractmethod
def plot_synchrony_time_series(self, ax):
pass
23 changes: 4 additions & 19 deletions hypyp/fnirs/preprocessor/base_step.py → hypyp/core/base_step.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,6 @@
from abc import ABC, abstractmethod
from typing import List, Generic, TypeVar

PREPROCESS_STEP_BASE_KEY = 'base'
PREPROCESS_STEP_BASE_DESC = 'Loaded data'

PREPROCESS_STEP_OD_KEY = 'od'
PREPROCESS_STEP_OD_DESC = 'Optical density'

PREPROCESS_STEP_OD_CLEAN_KEY = 'od_clean'
PREPROCESS_STEP_OD_CLEAN_DESC = 'Optical density cleaned'

PREPROCESS_STEP_HAEMO_KEY = 'haemo'
PREPROCESS_STEP_HAEMO_DESC = 'Hemoglobin'

PREPROCESS_STEP_HAEMO_FILTERED_KEY = 'haemo_filtered'
PREPROCESS_STEP_HAEMO_FILTERED_DESC = 'Hemoglobin Band-pass Filtered'

# Generic type for underlying fnirs implementation (mne raw / cedalion recording)
T = TypeVar('T')

Expand All @@ -29,14 +14,14 @@ class BaseStep(ABC, Generic[T]):
desc (str | None, optional): description of the setup. Defaults to "key" value.
"""
obj: T
key: str
name: str
desc: str

def __init__(self, obj:T, key:str, desc:str|None=None):
def __init__(self, obj:T, name:str, desc:str|None=None):
self.obj = obj
self.key = key
self.name = name
if desc is None:
self.desc = key
self.desc = name
else:
self.desc = desc

Expand Down
Loading
Loading