This repository implements a reproducible pipeline to synthesize multi-channel EEG artifact windows using two state-of-the-art generative paradigms:
- WGAN-GP with projection discriminator
- DDPM (1D U-Net + classifier-free guidance)
It targets the TUH EEG Artifact Corpus (TUAR) with subject-wise splits, robust preprocessing, and a comprehensive evaluation suite (signal-level, feature-space, and functional tasks like TRTS/TSTR and AugMix-style augmentation studies).
✅ Data Processing Complete: TUAR dataset has been processed with subject-wise stratified splits (149 train / 32 val / 32 test subjects)
✅ Exploration Analysis Done: Comprehensive data exploration completed, including:
- Label distribution analysis across 5 artifact classes (Muscle, Eye movement, Electrode, Chewing, Shiver)
- Channel-wise event frequency analysis
- Duration distribution analysis with recommended window lengths
- Dimensionality reduction visualizations (t-SNE, UMAP) of per-file artifact summaries
- Multi-label stratified splitting to ensure balanced representation
✅ Training Infrastructure Ready: Models configured and training scripts prepared for both WGAN-GP and DDPM architectures
✅ Initial Training Runs Completed: TensorBoard logs indicate successful training runs with GPU acceleration
- Set up environment
  - Python 3.12+
  - Install dependencies:

    ```bash
    pip install -r requirements.txt
    ```

  - GPU support: CUDA 12.1+ for PyTorch acceleration, CuPy for MNE GPU operations
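  To confirm the GPU stack is wired up, a minimal check (assuming only that `torch` is installed from `requirements.txt`; CuPy is optional):

  ```python
  # Sanity-check GPU availability for PyTorch and (optionally) CuPy.
  import torch

  print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
  if torch.cuda.is_available():
      print(f"Device: {torch.cuda.get_device_name(0)}")

  try:
      import cupy
      print(f"CuPy {cupy.__version__} found: MNE filtering can run on the GPU")
  except ImportError:
      print("CuPy not installed: MNE filtering will run on the CPU")
  ```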
- Prepare data
  - Point `data.dataset_root` in `configs/*.yaml` to your local TUAR path
  - Processed data already available in `data/processed/`, including:
    - Subject-wise stratified splits (`suggested_splits_subjectwise_multilabel.csv`)
    - Class mappings (`class_map.csv`)
    - Pre-computed data statistics
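  A minimal sketch of consuming the split file with pandas; the column names `subject_id` and `split` are illustrative assumptions, so check the CSV header for the actual schema:

  ```python
  # Load the subject-wise split suggestions (column names are assumed).
  import pandas as pd

  splits = pd.read_csv("data/processed/suggested_splits_subjectwise_multilabel.csv")
  train_subjects = splits.loc[splits["split"] == "train", "subject_id"].unique()
  print(f"{len(train_subjects)} training subjects")  # expected: 149
  ```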
- Train models
  - Use the provided configs to train WGAN-GP or DDPM
  - Models automatically detect and use a GPU if available
  - Monitor training with TensorBoard:

    ```bash
    tensorboard --logdir results/tensorboard/
    ```
- Evaluate
  - Run the full evaluation suite (signal, feature, functional, utility classifier):

    ```bash
    # zsh/bash
    ./scripts/run_evaluation.sh configs/ddpm_raw.yaml
    ```

  - Or call individual evaluators:

    ```bash
    # Signal-level metrics (band-power errors, channel correlation, PSD distance)
    python -m src.eval.metrics_signal --config configs/ddpm_raw.yaml \
        --ckpt results/checkpoints/ddpm_unet_best.pth --model-kind ddpm --n 256

    # Feature-space metrics placeholder (MMD/PRD, t-SNE/UMAP)
    python -m src.eval.metrics_feature --config configs/ddpm_raw.yaml

    # Functional metrics placeholder (TRTS/TSTR, AugMix-style)
    python -m src.eval.metrics_functional --config configs/ddpm_raw.yaml

    # NEW: Classifier-based utility evaluation (train on real vs real+synthetic; eval on real test)
    python -m src.eval.utility_classifier --config configs/ddpm_raw.yaml \
        --ckpt-ddpm results/checkpoints/ddpm_unet_best.pth \
        --ckpt-wgan results/checkpoints/wgan_generator_best.pth \
        --n-synth-per-class 500 --epochs 10 --batch-size 64 --lr 1e-3
    ```
  - Key outputs are written to `results/` as CSV and LaTeX tables, ready for inclusion in the paper.
Explore and visualize results in the following notebooks (open with Jupyter Lab or VS Code):
- `notebooks/exploration.ipynb` — Data exploration
- `notebooks/visualization.ipynb` — Visualizations
- `notebooks/class_comparison.ipynb` — Class comparison analysis
- `notebooks/model_comparison.ipynb` — Model comparison analysis
- Configs: `configs/` (YAML files for model and data settings)
- Processed Data & Metadata: `data/processed/` (splits, class maps, statistics)
- Checkpoints & Results: `results/checkpoints/`, `results/generated/`, `results/manifest.json`, `results/split_summary.json`
- TensorBoard Logs: `results/tensorboard/`
- Paper Source & Figures: `paper/`
- Scripts: `scripts/` (for all main pipeline steps)
- Source Code: `src/` (Python modules for all core logic)
  - `src/eval/utility_classifier.py` — trains a small 1D CNN on `real`, `real+wgan`, and `real+ddpm`, then evaluates on the held-out real test set; writes `results/utility_classifier.csv` and `results/table_utility.tex`.
- All scripts are compatible with PowerShell on Windows.
- Update config paths as needed for your local data.
- For troubleshooting, see comments in scripts and notebooks.
- `configs/` — YAML configuration files for experiments (e.g., `ddpm_raw.yaml`, `wgan_raw.yaml`)
- `data/` — Data directory containing raw and processed datasets
  - `raw/` — Raw data files
  - `processed/` — Processed data including class mappings and split suggestions
- `notebooks/` — Jupyter notebooks for data exploration and visualization
  - `exploration.ipynb` — Comprehensive TUAR dataset analysis and visualization
- `paper/` — Paper-related files for NeurIPS 2025 submission
  - `CITATIONS.bib` — Bibliography references
  - `neurips_2025.pdf` — Compiled PDF of the paper
  - `neurips_2025.sty` — LaTeX style file for NeurIPS formatting
  - `neurips_2025.tex` — LaTeX source file for the paper
- `results/` — Output directory for model checkpoints, figures, and evaluation results
  - `checkpoints/` — Saved model weights
  - `figures/` — Generated plots and visualizations
  - `manifest.json` — Metadata about results
  - `tensorboard/` — Training logs and metrics
- `scripts/` — Bash scripts for running preprocessing, training, and evaluation
  - `run_preprocessing.sh` — Script to preprocess raw data
  - `run_training.sh` — Script to train models
  - `run_evaluation.sh` — Script to evaluate trained models
- `src/` — Python source code
  - `dataset.py` — Dataset loading and preprocessing utilities
  - `preprocess.py` — Data preprocessing functions
  - `train.py` — Training scripts for WGAN and DDPM models
  - `eval/` — Evaluation modules
    - `metrics_feature.py` — Feature-space evaluation metrics
    - `metrics_functional.py` — Functional evaluation metrics
    - `metrics_signal.py` — Signal-level evaluation metrics
  - `models/` — Model implementations
    - `ddpm.py` — Denoising Diffusion Probabilistic Model
    - `wgan.py` — Wasserstein GAN with Gradient Penalty
- `ENVIRONMENT.md` — Environment setup and dependency versions
- `LICENSE` — Project license
- `README.md` — This file
- `requirements.txt` — Python dependencies
- Subject-wise splits: 213 total subjects split into 149 train / 32 val / 32 test
- Multi-label stratification: Ensures balanced representation of all 5 artifact classes
- Window extraction: Configurable window lengths (1s/2s) with overlap options
- Normalization strategies: Per-window min-max for WGAN, per-recording z-score for DDPM (see the sketch below)
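The two normalization strategies, as a minimal sketch (hypothetical helper names, not the repo's exact code):

```python
# Per-window min-max to [-1, 1] (WGAN) with stats kept for inversion,
# and per-recording z-score (DDPM). Illustrative only.
import numpy as np

def minmax_per_window(x: np.ndarray):
    """Scale one (channels, samples) window to [-1, 1]; return stats for inversion."""
    lo, hi = x.min(), x.max()
    return 2.0 * (x - lo) / (hi - lo + 1e-8) - 1.0, (lo, hi)

def invert_minmax(scaled: np.ndarray, stats) -> np.ndarray:
    """Map a [-1, 1] window back to its original amplitude range."""
    lo, hi = stats
    return (scaled + 1.0) / 2.0 * (hi - lo + 1e-8) + lo

def zscore_per_recording(recording: np.ndarray) -> np.ndarray:
    """Z-score a whole (channels, samples) recording before windowing."""
    return (recording - recording.mean()) / (recording.std() + 1e-8)
```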
- WGAN-GP: Projection discriminator, gradient penalty, spectral normalization (a gradient-penalty sketch follows this list)
- DDPM: 1D U-Net with classifier-free guidance, configurable noise schedules
- GPU acceleration: Automatic CUDA detection for both PyTorch and MNE operations
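For reference, the WGAN-GP gradient-penalty term can be sketched as follows (a generic PyTorch version of Gulrajani et al.'s penalty, not necessarily line-for-line what the repo implements):

```python
import torch

def gradient_penalty(critic, real, fake, labels, lambda_gp: float = 10.0):
    """Penalize deviations of the critic's gradient norm from 1 on interpolates."""
    eps = torch.rand(real.size(0), 1, 1, device=real.device)  # broadcast over (C, T)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp, labels)  # projection discriminator also sees class labels
    grads, = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1.0) ** 2).mean()
```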
- Signal-level metrics: Welch band-power relative errors (δ/θ/α/β), channel-wise correlation, PSD L2 distance (see the sketch after this list)
- Feature-space metrics: Distribution matching (MMD/PRD), embedding comparisons (t-SNE/UMAP)
- Functional metrics: TRTS/TSTR evaluation, AugMix-style augmentation studies
- Utility classifier (NEW): Train a small 1D CNN on (i) real-only, (ii) real+WGAN, (iii) real+DDPM; evaluate accuracy and macro-F1 on the real test set to measure whether synthetic data improve performance on real data.
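As a concrete example of the signal-level metrics, a minimal band-power sketch using Welch PSDs (conventional band edges; the repo's exact choices live in `src/eval/metrics_signal.py` and may differ):

```python
import numpy as np
from scipy.signal import welch

BANDS = {"delta": (0.5, 4.0), "theta": (4.0, 8.0), "alpha": (8.0, 13.0), "beta": (13.0, 30.0)}

def band_powers(x: np.ndarray, fs: float) -> dict:
    """Mean power per band for a (channels, samples) window."""
    freqs, psd = welch(x, fs=fs, nperseg=min(x.shape[-1], 256))
    return {name: psd[..., (freqs >= lo) & (freqs < hi)].mean()
            for name, (lo, hi) in BANDS.items()}

def relative_band_errors(real: np.ndarray, synth: np.ndarray, fs: float) -> dict:
    """Relative band-power error of a synthetic window against a real one."""
    bp_r, bp_s = band_powers(real, fs), band_powers(synth, fs)
    return {b: abs(bp_s[b] - bp_r[b]) / (bp_r[b] + 1e-12) for b in BANDS}
```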
- We measure downstream utility by training an identical classifier on three training sets and always evaluating on the held-out real test set: 1) `real` (baseline), 2) `real+wgan`, 3) `real+ddpm`.
- Synthetic windows are generated per class using the best checkpoints with explicit class conditioning (we pass `class_id=c` into the generator), so each artifact class is represented accurately.
- To avoid normalization confounds, both real and synthetic training windows are z-scored per window at classifier-training time (see the sketch below).
- We report accuracy and macro-F1 (class-balanced) on the real test set.
- Reproducible outputs: `results/utility_classifier.csv`, `results/table_utility.tex`
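How the augmented training sets are assembled can be sketched as follows; the `generator.sample(class_id)` interface is a hypothetical stand-in for the actual checkpointed generators (see `src/eval/utility_classifier.py` for the real code):

```python
import torch

def make_synthetic(generator, n_per_class: int, n_classes: int, device: str = "cuda"):
    """Generate class-balanced synthetic windows with explicit class conditioning."""
    windows, labels = [], []
    for c in range(n_classes):
        class_id = torch.full((n_per_class,), c, dtype=torch.long, device=device)
        windows.append(generator.sample(class_id))  # hypothetical interface
        labels.append(class_id)
    return torch.cat(windows), torch.cat(labels)

def zscore_windows(x: torch.Tensor) -> torch.Tensor:
    """Per-window z-score, applied to BOTH real and synthetic windows so that
    normalization differences cannot confound the comparison."""
    mu = x.mean(dim=(-2, -1), keepdim=True)
    sd = x.std(dim=(-2, -1), keepdim=True)
    return (x - mu) / (sd + 1e-8)
```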
Run it directly:

```bash
python -m src.eval.utility_classifier \
    --config configs/ddpm_raw.yaml \
    --ckpt-ddpm results/checkpoints/ddpm_unet_best.pth \
    --ckpt-wgan results/checkpoints/wgan_generator_best.pth \
    --n-synth-per-class 500 --epochs 10 --batch-size 64 --lr 1e-3
```

- Preprocess: `scripts/run_preprocessing.sh configs/wgan_raw.yaml`
- Train (WGAN example): `scripts/run_training.sh configs/wgan_raw.yaml`
- Evaluate: `scripts/run_evaluation.sh configs/wgan_raw.yaml`
- Signal metrics: `results/signal_metrics.csv`; LaTeX tables `paper/figs/table_bandpower.tex` and `paper/figs/table_channel_effects.tex` (included in the paper via `\input{...}`)
- Feature metrics: `results/feature_metrics.csv`; LaTeX table `paper/figs/table_metrics.tex`
- Utility classifier: `results/utility_classifier.csv`; LaTeX table `paper/figs/table_utility.tex`
Camera-ready constraint: all analyses are post hoc from fixed checkpoints (no retraining). Unless noted, quantitative tables use per-class N=3000 synthetic windows, matched by N=3000 real windows from the test split.
- Python Version: Updated to 3.12.11 for improved performance and compatibility
- Data Exploration: Complete TUAR dataset analysis with visualization notebooks
- Training Infrastructure: Configured for both WGAN-GP and DDPM with GPU support
- Results Tracking: TensorBoard integration for monitoring training progress
- Documentation: Updated setup instructions and project status
See ENVIRONMENT.md for pinned versions, paper/CITATIONS.bib for references, and LICENSE for licensing. Replace example configs with your desired windows (1s/2s), filtering scheme (raw/filtered), and normalization strategies per model.
- GPU Support: The pipeline automatically detects and uses CUDA GPUs where possible. PyTorch models are moved to GPU, MNE filtering uses GPU acceleration if CuPy is installed, and DataLoaders use pinned memory for faster transfers.
- Data Handling: Subject-wise splits are enforced via metadata to prevent data leakage
- Normalization:
  - WGAN uses per-window min-max normalization to [-1, 1], with min/max values stored for inversion
  - DDPM uses per-recording z-score normalization
- Models:
  - WGAN-GP includes a projection discriminator for improved stability
  - DDPM uses a 1D U-Net architecture with classifier-free guidance (a sampling sketch follows these notes)
- Evaluation: Comprehensive metrics include signal fidelity, feature distribution matching, and functional performance on downstream tasks
- Reproducibility: All dependencies are pinned in
requirements.txtandENVIRONMENT.md - Future Additions: Privacy audit, Model/Data cards, and additional configurations will be added alongside trained checkpoints
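For reference, classifier-free guidance at sampling time follows the standard recipe below (a generic sketch; the conditioning interface with `class_id=None` for the unconditional branch is an assumption, and the repo's actual sampler in `src/models/ddpm.py` may differ):

```python
import torch

@torch.no_grad()
def guided_eps(model, x_t, t, class_id, w: float):
    """Mix conditional and unconditional noise predictions:
    eps_guided = (1 + w) * eps(x_t, t, c) - w * eps(x_t, t, None)."""
    eps_cond = model(x_t, t, class_id)  # conditioned on the artifact class
    eps_uncond = model(x_t, t, None)    # null (dropped-out) condition
    return (1 + w) * eps_cond - w * eps_uncond
```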
If you use this work, please cite our preprint: arXiv:2509.08188.
All main scripts are in the `scripts/` folder and can be run from PowerShell on Windows:
- Preprocessing: `./scripts/run_preprocessing.sh`
- Training: `./scripts/run_training.sh`
- Evaluation: `./scripts/run_evaluation.sh`