Source code for systole.interact.interact

# Author: Nicolas Legrand <nicolas.legrand@cfin.au.dk>

import functools
import json
from os import PathLike
from pathlib import Path
from typing import List, Optional, Tuple, Union

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.dates import date2num
from matplotlib.widgets import SpanSelector

from systole.detection import ecg_peaks, ppg_peaks, rsp_peaks
from systole.plots import plot_raw
from systole.utils import ecg_strings, norm_bad_segments, ppg_strings, resp_strings


[docs]class Viewer: """This class handles the interaction with BIDS structured folders. It calls the` Editor` class internally to generate the interactive plots. Parameters ---------- bids_folder : Path to the input BIDS folder. If the BIDS folder is used as input, the `Viewer` tries to read the preprocessed physiological recordings generated by the command line for reports in `BIDS/derivatives/systole/`. .. note:: If this parameter is provided, `preprocessed_folder` will be ignored (implicitely here, `preprocessed_folder=BIDS/derivatives/systole/`). preprocessed_folder : Path to the folder where preprocessed physiological recording have been saved. This should be used if not working directly inside the BIDS folder and the preprocessed data have been save loaclly. .. note:: If this parameter is provided, `bids_folder` will be ignored. output_folder : Path to the output folder. This is where the JSON files containing peaks correction, bad segments and signal validity logs will be saved. If an empty strimg is provided (default), the results will be saved in `BIDS/derivative/systole/corrected/` when working in the BIDS folder, or in `preprocessed_folder/corrected/` when working whith a local folder. session : The BIDS sub-session where the pysio files are stored. Defaults to `"ses-session1"`. modality : The BIDS sub-modality where the pysio files are stored (e.g. `"func"` or `"beh"`). pattern : The string pattern that the pysio files should contain. This allows to refine the selection of possible physio files, in case the folders contains many `_physio-gz.tsv`. participant_id : The participant ID as registered in the BIDS folder. If `None` (default), the first participant in the list of available recordings is selected. signal_type : The type of signal that are being analyzed. Can be `"PPG"`, `"ECG"` or `"RESP"`. Defaults to `"PPG"`. figsize : The size of the interactive Matplotlib figure for peaks edition. Defaults to `(15, 7)`. See also -------- Editor Raises ------ ValueError If both `bids_folder` and `preprocessed_folder` are provided. """
[docs] def __init__( self, bids_folder: Optional[Union[str, PathLike]] = None, preprocessed_folder: Optional[Union[str, PathLike]] = None, output_folder: Union[str, PathLike] = "", session: Union[str, PathLike] = "ses-session1", modality: Union[str, PathLike] = "beh", pattern: Union[str, PathLike] = "task-", participant_id: Optional[str] = None, signal_type: Union[str, PathLike] = "PPG", figsize: Tuple[int, int] = (15, 7), ) -> None: self.figsize = figsize if bids_folder is not None: self.bids_folder = bids_folder if not Path(bids_folder, "derivatives", "systole").exists(): print(f"The BIDS folder {bids_folder} does not contains derivatives.") else: self.preprocessed_folder = Path(bids_folder) else: self.bids_folder = "" self.preprocessed_folder = preprocessed_folder # type: ignore ################## # Create widgets # ################## self.bids_folder_ = widgets.Textarea( value=str(self.bids_folder), placeholder="Type something", description="BIDS folder:", disabled=False, layout=widgets.Layout(width="250px"), ) self.preprocessed_folder_ = widgets.Textarea( value=str(self.preprocessed_folder), placeholder="Type something", description="Preprocessed folder:", disabled=False, layout=widgets.Layout(width="250px"), ) self.session_ = widgets.Textarea( value=session, placeholder="Type something", description="Session:", disabled=False, layout=widgets.Layout(width="250px"), ) self.modality_ = widgets.Textarea( value=modality, placeholder="Type something", description="Modality:", disabled=False, layout=widgets.Layout(width="250px"), ) self.pattern_ = widgets.Textarea( value=pattern, placeholder="Type something", description="Pattern:", disabled=False, layout=widgets.Layout(width="250px"), ) self.signal_type_ = widgets.Dropdown( options=["PPG", "ECG", "RESP"], value=signal_type, description="Signal:", layout=widgets.Layout(width="250px"), ) self.output_folder_ = widgets.Textarea( value=output_folder, placeholder="Type something", description="Output:", disabled=False, layout=widgets.Layout(width="250px"), ) # Update the participant list from the BIDS parameters try: # Get the list of all participant from the folders self.participants_list = [ f.stem for f in list(Path(self.preprocessed_folder_.value).glob("sub-*/")) ] if self.participants_list: self.participants_list.sort() # Filter participants that have no physio recording filter_participants_list = [ part for part in self.participants_list if any( Path( self.preprocessed_folder_.value, part, self.session_.value, self.modality_.value, ).glob(f"*{self.pattern_.value}*.tsv.gz") ) ] if len(filter_participants_list) == 0: print( "No file is matching the given paterns.\n" f"... Preprocessed folder: {self.preprocessed_folder_.value}\n" f"... Session: {self.session_.value}\n" f"... Modality: {self.modality_.value}\n" f"... Pattern: {self.pattern_.value}" ) self.participants_list = ["sub-"] else: self.participants_list = filter_participants_list except FileNotFoundError: print("Directory not found.") self.participants_list = ["sub-"] if participant_id is None: self.participant_id = self.participants_list[0] self.participants_ = widgets.Dropdown( options=self.participants_list, value=self.participant_id, description="Participant ID", layout=widgets.Layout(width="200px"), ) # Keep updated if dropdown menus are used self.bids_folder_.observe(self.update_list, names="value") self.preprocessed_folder_.observe(self.update_list, names="value") self.session_.observe(self.update_list, names="value") self.modality_.observe(self.update_list, names="value") self.pattern_.observe(self.update_list, names="value") self.signal_type_.observe(self.plot_signal, names="value") self.participants_.observe(self.plot_signal, names="value") # Show the navigator and main plot self.io_box = widgets.VBox( [ widgets.HBox( [ self.bids_folder_, self.preprocessed_folder_, self.session_, self.participants_, self.modality_, ] ), widgets.HBox([self.pattern_, self.signal_type_, self.output_folder_]), ] ) self.output = widgets.Output() # Plot the first pysio file if any self.plot_signal(change=None)
def update_list(self, change): """Updating the list of participants available in the folder when the text boxes are used.""" self.participants_list = [ f.stem for f in list(Path(self.bids_path.value).glob("sub-*/")) ] self.participants_list = [ part for part in self.participants_list if any( Path( self.bids_path.value, self.participants_.value, self.session_.value, self.modality_.value, ).glob(f"*{self.pattern_.value}*.tsv.gz") ) ] self.participants_.option = self.participants_list def plot_signal(self, change): # Load the physio files and store parameters in the Viewer class # then load the signal from the physio file and perform peaks detection self = self.load_file().load_signal() self.output.clear_output() with self.output: # Start the interactive editor for peaks correction self.editor = Editor( signal=self.input_signal, sfreq=1000, corrected_json=self.corrected_json, signal_type=self.signal_type_.value, figsize=self.figsize, corrected_peaks=self.corrected_peaks, bad_segments=self.bad_segments, viewer=self, ) plt.show() def load_signal(self): """Load the signal from the input folder (BIDS or local).""" # In case no file was match the requirements if self.physio_file is None: return self self.physio_df = None self.input_signal = None self.corrected_peaks = None self.bad_segments = None # Load the physiological signal from the BIDS/preprocessed folder self.physio_df = pd.read_csv( self.physio_file, sep="\t", compression="gzip", names=self.input_columns_names, ) self.physio_df.columns = self.physio_df.columns.str.lower() # Path to the corrected JSON files (if the signal has already been checked) if self.bids_folder is None: self.corrected_json = Path( self.preprocessed_folder, "corrected", str(self.participants_.value), str(self.session_.value), self.modality_.value, f"{self.physio_file.stem[:-11]}_corrected.json", ) else: # When reading the raw data directly, the JSON file # must be loaded from the derivatives self.corrected_json = Path( self.preprocessed_folder, "derivatives", "systole", "corrected", str(self.participants_.value), str(self.session_.value), self.modality_.value, f"{self.physio_file.stem[:-11]}_corrected.json", ) if self.signal_type_.value.lower() == "ecg": ecg_col = [col for col in self.physio_df.columns if col in ecg_strings] ecg_col = ecg_col[0] if len(ecg_col) > 0 else None self.input_signal = self.physio_df[ecg_col].to_numpy() print(f"Loading electrocardiogram - sfreq={self.sfreq} Hz.") elif self.signal_type_.value.lower() == "ppg": ppg_col = [col for col in self.physio_df.columns if col in ppg_strings] ppg_col = ppg_col[0] if len(ppg_col) > 0 else None self.input_signal = self.physio_df[ppg_col].to_numpy() print(f"Loading photoplethysmogram - sfreq={self.sfreq} Hz.") elif self.signal_type_.value.lower() == "resp": res_col = [col for col in self.physio_df.columns if col in resp_strings] res_col = res_col[0] if len(res_col) > 0 else None self.input_signal = self.physio_df[res_col].to_numpy() print(f"Loading respiratory signal - sfreq={self.sfreq} Hz.") # Resample the input signal to fit with the peaks vector if self.sfreq is not None: time = np.arange(0, len(self.input_signal) / self.sfreq, 1 / self.sfreq) new_time = np.arange(0, len(self.input_signal) / self.sfreq, 1 / 1000) self.input_signal = np.interp(new_time, time, self.input_signal) # Load peaks, bad segments and reject signal from the JSON logs if self.corrected_json.exists(): # Opening JSON file and extract metadata f = open(self.corrected_json) json_data = json.load(f) self.bad_segments = json_data[self.signal_type_.value.lower()][ "bad_segments" ] # If corrected peaks already exist, load here and replace the revious ones # The peaks vector is resampled to match 1 kHz self.corrected_peaks = np.zeros(len(self.input_signal), dtype=bool) self.corrected_peaks[ np.array(json_data[self.signal_type_.value.lower()]["corrected_peaks"]) ] = True f.close() # If the signal is invalid, set it to None if np.isnan(self.input_signal).all(): print("The signal only contains NaNs / zeros, settings everything to None.") self.input_signal = None self.corrected_peaks = None self.bad_segments = None return self def load_file(self): """Load the files containing the physiological recordings and the metadat JSON files for one participant.""" self.recording_start_time = None self.recording_end_time = None self.sfreq = None self.input_columns_names = None self.json_file = None self.physio_file = None # List the files matching the requirements physio_files = list( Path( self.preprocessed_folder_.value, str(self.participants_.value), str(self.session_.value), self.modality_.value, ).glob(f"*{self.pattern_.value}*.tsv.gz") ) if len(physio_files) == 0: self.physio_file, self.json_file = None, None print("No file matching the requirements.") return elif len(physio_files) > 1: self.physio_file, self.json_file = None, None print( "More than one recording match the provided string pattern." "Use a more explicit/longer string pattern to find your recording." ) return else: self.physio_file = physio_files[0] print(f"Loading physiological recording from {self.physio_file}") # Try to load the accompagning JSON metadata json_files = list( Path( self.preprocessed_folder, str(self.participants_.value), str(self.session_.value), self.modality_.value, ).glob(f"*{self.pattern_.value}*.json") ) if len(json_files) == 0: self.physio_file, self.json_file = None, None print("No JSON metadat found.") return elif len(json_files) > 1: self.physio_file, self.json_file = None, None print( "More than one JSON file match the provided string pattern. " "Use a more explicit/longer string pattern to find your recording." ) return else: self.json_file = json_files[0] if self.json_file is not None: # Opening JSON file and extract metadata f = open(self.json_file) json_data = json.load(f) self.sfreq = json_data["SamplingFrequency"] self.input_columns_names = json_data["Columns"] try: self.recording_start_time = json_data["StartTime"] self.recording_end_time = json_data["EndTime"] except KeyError: pass f.close() return self
[docs]class Editor: """This class handle the visualization and manual edition of peaks vectors associated with physiological signals. Parameters ---------- signal : The physiological signal. sfreq : The sampling frequency of the pysiological signal. signal_type : The type of signal that are being analyzed. Can be `"PPG"`, `"ECG"` or `"RESP"`. Defaults to `"PPG"`. corrected_json : Path to the corrected JSON file. figsize : The size of the interactive Matplotlib figure for peaks edition. Defaults to `(15, 7)`. viewer : The viewer instance from which the editor is called. corrected_peaks : The 1d array of corrected peaks indexes, in case the signal was previously edited. This is mostly relevant for the :py:class`systole.interact.Viewer` when a pre-existing JSON file is found in the derivatives. bad_segments : List of `start_idx` and `end_idx` annotating bad segments, in case the signal was previously edited. This is mostly relevant for the :py:class`systole.interact.Viewer` when a pre-existing JSON file is found in the derivatives. Attributes ---------- bad_segments : List of `start_idx` and `end_idx` listing bad segments. The list is automatically generated by:py:func:`systole.utils.norm_bad_segments` to avoid overlaping segments. uncorrected_peaks : The peaks vector as detected using the default peaks detection algorithm. If the signal was edited previously, this variable is directly imported from the JSON file. json_file : Path to the sidecar JSON file. peaks : The corrected peaks vector after manual insertion/deletion. physio_file : PathLike | None Path to the physiological recording. time : Time vector. edition_, rejection_, command_box_, save_button_ : Widgets controlling the type of modification to perform. See also -------- Viewer Notes ----- This module was largely inspired by the peakdet toolbox (https://github.com/physiopy/peakdet). """
[docs] def __init__( self, signal: np.ndarray, sfreq: int, signal_type: str, corrected_json: Union[str, PathLike] = "corrected.json", figsize: Tuple[int, int] = (15, 7), viewer: Optional[Viewer] = None, corrected_peaks: Optional[np.ndarray] = None, bad_segments: Optional[list] = None, ) -> None: if viewer is not None: self.viewer = viewer self.sfreq = sfreq self.signal = signal self.figsize = figsize self.bad_segments: List[int] = [] if viewer is not None: self.bad_segments = viewer.bad_segments self.corrected_json = corrected_json self.signal_type = signal_type self.peaks = corrected_peaks # Widgets for correction, rejection, valid recording and saving self.edition_ = widgets.ToggleButtons( options=["Correction", "Rejection"], diabled=False ) self.rejection_ = widgets.Checkbox( value=True, descrition="Valid recording", disabled=False, indent=True ) self.save_button_ = widgets.Button( description="Save modifications", disabled=False, button_style="", tooltip="Description", icon="save", layout=widgets.Layout(width="250px"), ) self.save_button_.on_click(self.save) self.commands_box = widgets.HBox( [self.edition_, self.rejection_, self.save_button_] ) # If a signal is available, call the main plotting method if self.signal is not None: # Peaks detection self = self.find_peaks() # Create a time vector from signal length and convert it to Matplotlib ax values self.time = pd.to_datetime( np.arange(0, len(self.signal)), unit="ms", origin="unix" ) self.x_vec = date2num(self.time) # Create the main plot_raw instance self.fig, self.ax = plt.subplots(nrows=2, figsize=self.figsize, sharex=True) if self.bad_segments: # Convert the list into list of tuples that can fit in the plot_raw bad_segments = [ (self.bad_segments[i], self.bad_segments[i + 1]) for i in range(0, len(self.bad_segments), 2) ] else: bad_segments = None plot_raw( signal=self.signal, peaks=self.peaks, modality=self.signal_type.lower(), backend="matplotlib", show_heart_rate=True, show_artefacts=True, bad_segments=bad_segments, sfreq=1000, ax=[self.ax[0], self.ax[1]], ) self.fig.canvas.mpl_connect("key_press_event", self.on_key) # two selectors for rejection (left mouse) and deletion (right mouse) self.delete = functools.partial(self.on_remove) self.span1 = SpanSelector( self.ax[0], self.delete, "horizontal", button=1, props=dict(facecolor="red", alpha=0.2), useblit=True, ) self.add = functools.partial(self.on_add) self.span2 = SpanSelector( self.ax[0], self.add, "horizontal", button=3, props=dict(facecolor="green", alpha=0.2), useblit=True, )
def on_remove(self, xmin, xmax): """Removes specified peaks by either rejection / deletion, or mark bad segments.""" # Get the interval in sample idexes if self.edition_.value == "Correction": tmin, tmax = np.searchsorted(self.x_vec, (xmin, xmax)) self.peaks[tmin:tmax] = False self.plot_signals() elif self.edition_.value == "Rejection": tmin, tmax = np.searchsorted(self.x_vec, (xmin, xmax)) self.bad_segments.append(int(tmin)) self.bad_segments.append(int(tmax)) # Makes it a list of tuple bad_segments = [ (self.bad_segments[i], self.bad_segments[i + 1]) for i in range(0, len(self.bad_segments), 2) ] # Merge overlapping segments if any bad_segments = norm_bad_segments(bad_segments) self.bad_segments = list(np.array(bad_segments).flatten()) print(self.bad_segments) self.plot_signals() def on_add(self, xmin, xmax): """Add a new peak on the maximum signal value from the selected range.""" # Get the interval in sample idexes tmin, tmax = np.searchsorted(self.x_vec, (xmin, xmax)) self.peaks[tmin + np.argmax(self.signal[tmin:tmax])] = True self.plot_signals() def on_key(self, event): """Undoes last span select or quits peak editor""" # accept both control or Mac command key as selector if event.key in ["ctrl+q", "super+d"]: self.quit() elif event.key in ["left"]: xlo, xhi = self.ax[0].get_xlim() step = xhi - xlo self.ax[0].set_xlim(xlo - step, xhi - step) self.fig.canvas.draw() elif event.key in ["right"]: xlo, xhi = self.ax[0].get_xlim() step = xhi - xlo self.ax[0].set_xlim(xlo + step, xhi + step) self.fig.canvas.draw() def plot_signals(self): """Clears axes and plots data / peaks / troughs.""" if self.signal is not None: # Clear axes and redraw, retaining x-/y-axis zooms xlim, ylim = self.ax[0].get_xlim(), self.ax[0].get_ylim() xlim2, ylim2 = self.ax[1].get_xlim(), self.ax[1].get_ylim() self.ax[0].clear() self.ax[1].clear() # Convert bad segments into list of tuple if self.bad_segments: bad_segments = [ (self.bad_segments[i], self.bad_segments[i + 1]) for i in range(0, len(self.bad_segments), 2) ] else: bad_segments = None plot_raw( signal=self.signal, peaks=self.peaks, modality=self.signal_type.lower(), backend="matplotlib", show_heart_rate=True, show_artefacts=True, bad_segments=bad_segments, sfreq=1000, ax=[self.ax[0], self.ax[1]], ) self.ax[0].set(xlim=xlim, ylim=ylim) self.ax[1].set(xlim=xlim2, ylim=ylim2) # Show span selectors # two selectors for rejection (left mouse) and deletion (right mouse) self.delete = functools.partial(self.on_remove) self.span1 = SpanSelector( self.ax[0], self.delete, "horizontal", button=1, props=dict(facecolor="red", alpha=0.2), useblit=True, ) self.add = functools.partial(self.on_add) self.span2 = SpanSelector( self.ax[0], self.add, "horizontal", button=3, props=dict(facecolor="green", alpha=0.2), useblit=True, ) # Customize the plot a bit for ax in self.ax: ax.spines.right.set_visible(False) ax.spines.top.set_visible(False) ax.tick_params( direction="in", width=1.5, which="major", size=8, ) ax.tick_params(direction="in", width=1, which="minor", size=4) ax.grid(which="major", alpha=0.5, linewidth=0.5) self.fig.set_tight_layout() plt.margins(x=0, y=0) plt.minorticks_on() plt.subplots_adjust(left=0.1, bottom=0.1, right=0.1, top=0.1) self.fig.canvas.draw() return self def quit(self): """Quits editor""" plt.close(self.fig) def save(self): """Save the JSON file containing the corrected peaks, bad segments and signal quality. The path is specified by `corrected_json`.""" if not Path(self.corrected_json).parent.exists(): Path(self.corrected_json).parent.mkdir(parents=True) if Path(self.corrected_json).exists(): # Load the existing corrected JSON data f = open(Path(self.corrected_json)) metadata = json.load(f) f.close() else: metadata = {} # Create the JSON metadata if self.bad_segments: bad_segments = [int(x) for x in self.bad_segments] else: bad_segments = None corrected_info = { "valid": self.rejection_.value, "corrected_peaks": np.where(self.peaks)[0].tolist(), "bad_segments": bad_segments, } metadata[self.signal_type.lower()] = corrected_info print(f"Saving modification in {self.corrected_json}") with open(self.corrected_json, "w") as f: json.dump(metadata, f, ensure_ascii=False, indent=4) def find_peaks(self): """Find peaks depending on the signal type.""" if self.peaks is None: if self.signal_type == "ECG": self.signal, self.peaks = ecg_peaks( signal=self.signal, sfreq=self.sfreq ) elif self.signal_type == "PPG": self.signal, self.peaks = ppg_peaks( signal=self.signal, sfreq=self.sfreq ) elif self.signal_type == "RESP": self.signal, (self.peaks, _) = rsp_peaks( signal=self.signal, sfreq=self.sfreq ) else: raise ValueError("Invalid signal_type. Must be 'ECG', 'PPG' or 'RESP'.") # The peaks vector before manual edition self.uncorrected_peaks = self.peaks.copy() return self