From 98df7c70f1db8d767c7775b423c4d678bd61ca22 Mon Sep 17 00:00:00 2001 From: Shahin Ramezanzadeh Date: Tue, 19 May 2026 10:03:52 +0200 Subject: [PATCH] New Organised one --- PROJECT_STRUCTURE.md | 31 + app.py => archive/old_scripts/app.py | 0 .../old_scripts/total_app.py | 0 organize_project.sh | 384 +++ {attach => prompts}/Komplett.txt | 0 {attach => prompts}/just_edss_schema.gbnf | 0 {attach => prompts}/just_edss_text.txt | 0 certainty.py => scripts/analyze_certainty.py | 1924 +++++++------- audit.py => scripts/audit_outputs.py | 0 .../certainty_show.py | 0 figure1.py => scripts/figure1.py | 0 show_plots.py => scripts/show_plots.py | 0 show_plots.py.orig | 2320 ----------------- 13 files changed, 1377 insertions(+), 3282 deletions(-) create mode 100644 PROJECT_STRUCTURE.md rename app.py => archive/old_scripts/app.py (100%) rename total_app.py => archive/old_scripts/total_app.py (100%) create mode 100755 organize_project.sh rename {attach => prompts}/Komplett.txt (100%) rename {attach => prompts}/just_edss_schema.gbnf (100%) rename {attach => prompts}/just_edss_text.txt (100%) rename certainty.py => scripts/analyze_certainty.py (75%) rename audit.py => scripts/audit_outputs.py (100%) rename certainty_show.py => scripts/certainty_show.py (100%) rename figure1.py => scripts/figure1.py (100%) rename show_plots.py => scripts/show_plots.py (100%) delete mode 100644 show_plots.py.orig diff --git a/PROJECT_STRUCTURE.md b/PROJECT_STRUCTURE.md new file mode 100644 index 0000000..bb0f527 --- /dev/null +++ b/PROJECT_STRUCTURE.md @@ -0,0 +1,31 @@ +# Project Structure + +This project was reorganized into: + +- `data/` + - `raw/`: original raw data, if retained locally + - `processed/`: cleaned or derived input data + - `ground_truth/`: manually annotated reference data + - `external/`: externally provided data + +- `prompts/` + - EDSS instructions and prompt/schema assets + +- `scripts/` + - runnable analysis and plotting scripts + +- `results/` + - `benchmark_runs/`: full model benchmark runs + - `final_results/`: final selected model outputs + - `figures/`: generated figures + - `tables/`: generated tables + - `logs/`: terminal logs + +- `manuscript/` + - final figures and tables for paper/thesis writing + +- `archive/` + - old scripts, old results, temporary files, and unclear legacy files + +Important: +The reorganization was performed after creating a full timestamped backup. diff --git a/app.py b/archive/old_scripts/app.py similarity index 100% rename from app.py rename to archive/old_scripts/app.py diff --git a/total_app.py b/archive/old_scripts/total_app.py similarity index 100% rename from total_app.py rename to archive/old_scripts/total_app.py diff --git a/organize_project.sh b/organize_project.sh new file mode 100755 index 0000000..af2c35a --- /dev/null +++ b/organize_project.sh @@ -0,0 +1,384 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# ============================================================ +# Organize Barcelona EDSS project safely +# - Creates a timestamped backup first +# - Creates a cleaner folder structure +# - Moves files conservatively +# - Does NOT delete anything +# ============================================================ + +PROJECT_ROOT="$(pwd)" +TIMESTAMP="$(date +%Y%m%d_%H%M%S)" +BACKUP_PARENT="${PROJECT_ROOT}/../Barcelona_backups" +BACKUP_DIR="${BACKUP_PARENT}/Barcelona_backup_${TIMESTAMP}" + +echo "Project root: ${PROJECT_ROOT}" +echo "Backup dir: ${BACKUP_DIR}" +echo + +# ------------------------------------------------------------ +# Safety checks +# ------------------------------------------------------------ + +if [ ! -f "${PROJECT_ROOT}/README.md" ]; then + echo "WARNING: README.md not found. Are you sure you are in the project root?" + echo "Current directory: ${PROJECT_ROOT}" + read -r -p "Continue anyway? [y/N] " answer + case "$answer" in + y|Y|yes|YES) ;; + *) echo "Aborted."; exit 1 ;; + esac +fi + +if [ -d "${PROJECT_ROOT}/.git" ]; then + if ! git diff --quiet || ! git diff --cached --quiet; then + echo "ERROR: Git working tree is not clean." + echo "Please commit or stash changes before organizing." + exit 1 + fi +fi + +echo "This script will:" +echo "1. Create a full backup." +echo "2. Create organized folders." +echo "3. Move files into data/, prompts/, scripts/, results/, archive/." +echo "4. Keep your original files in the backup." +echo +read -r -p "Proceed? [y/N] " answer +case "$answer" in + y|Y|yes|YES) ;; + *) echo "Aborted."; exit 1 ;; +esac + +# ------------------------------------------------------------ +# Backup +# ------------------------------------------------------------ + +mkdir -p "${BACKUP_PARENT}" + +echo +echo "Creating backup..." +rsync -a \ + --exclude "enarcelona/" \ + --exclude "env/" \ + --exclude ".venv/" \ + --exclude "__pycache__/" \ + "${PROJECT_ROOT}/" "${BACKUP_DIR}/" + +echo "Backup created at:" +echo "${BACKUP_DIR}" + +# ------------------------------------------------------------ +# Create target structure +# ------------------------------------------------------------ + +echo +echo "Creating new directory structure..." + +mkdir -p \ + data/raw \ + data/processed \ + data/ground_truth \ + data/external \ + prompts \ + scripts \ + results/benchmark_runs \ + results/final_results/model_outputs \ + results/figures \ + results/tables \ + results/logs \ + manuscript/figures \ + manuscript/tables \ + archive/old_scripts \ + archive/old_results \ + archive/tmp \ + archive/old_data \ + archive/old_project_files + +# ------------------------------------------------------------ +# Helper move functions +# ------------------------------------------------------------ + +move_if_exists() { + src="$1" + dest="$2" + + if [ -e "$src" ]; then + mkdir -p "$(dirname "$dest")" + + if [ -e "$dest" ]; then + echo "SKIP: destination exists: $dest" + else + echo "MOVE: $src -> $dest" + mv "$src" "$dest" + fi + fi +} + +move_glob_if_exists() { + pattern="$1" + dest_dir="$2" + + mkdir -p "$dest_dir" + + shopt -s nullglob + files=( $pattern ) + shopt -u nullglob + + for f in "${files[@]}"; do + base="$(basename "$f")" + dest="${dest_dir}/${base}" + + if [ -e "$dest" ]; then + echo "SKIP: destination exists: $dest" + else + echo "MOVE: $f -> $dest" + mv "$f" "$dest" + fi + done +} + +# ------------------------------------------------------------ +# Move prompts / attached instruction files +# ------------------------------------------------------------ + +echo +echo "Moving prompt and instruction files..." + +move_if_exists "attach/Komplett.txt" "prompts/Komplett.txt" +move_if_exists "attach/just_edss_schema.gbnf" "prompts/just_edss_schema.gbnf" +move_if_exists "attach/just_edss_text.txt" "prompts/just_edss_text.txt" + +# Move leftover attach folder if empty or archive it +if [ -d "attach" ]; then + if [ -z "$(ls -A attach)" ]; then + rmdir attach + else + move_if_exists "attach" "archive/old_project_files/attach" + fi +fi + +# ------------------------------------------------------------ +# Move important data files +# ------------------------------------------------------------ + +echo +echo "Moving data files..." + +move_if_exists "Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.csv" \ + "data/processed/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.csv" + +move_if_exists "Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" \ + "data/processed/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" + +move_if_exists "Data/Join_edssandsub.tsv" \ + "data/ground_truth/Join_edssandsub.tsv" + +move_if_exists "Data/GT_Numbers.csv" \ + "data/ground_truth/GT_Numbers.csv" + +move_if_exists "Data/Annika1.csv" \ + "data/ground_truth/Annika1.csv" + +move_if_exists "Data/comparison.tsv" \ + "data/ground_truth/comparison.tsv" + +move_if_exists "Data/edss_distribution_summary.csv" \ + "data/processed/edss_distribution_summary.csv" + +move_if_exists "Data/empirical_confidence_table.csv" \ + "data/processed/empirical_confidence_table.csv" + +move_if_exists "Data/functional_system_colors.json" \ + "data/processed/functional_system_colors.json" + +move_if_exists "Data/Test.csv" \ + "archive/tmp/Test.csv" + +move_if_exists "Data/Hernan" \ + "data/external/Hernan" + +move_if_exists "Data/iteration" \ + "archive/old_data/iteration" + +# Old generated JSON/results from Data folder +move_glob_if_exists "Data/*results*.json" "archive/old_results" +move_glob_if_exists "Data/join_*.tsv" "archive/old_results" + +# Move remaining Data folder if anything left +if [ -d "Data" ]; then + if [ -z "$(ls -A Data)" ]; then + rmdir Data + else + move_if_exists "Data" "archive/old_data/Data_remaining" + fi +fi + +# ------------------------------------------------------------ +# Move benchmark results +# ------------------------------------------------------------ + +echo +echo "Moving benchmark results..." + +if [ -d "results_edss_benchmark" ]; then + move_glob_if_exists "results_edss_benchmark/run_*" "results/benchmark_runs" + + move_if_exists "results_edss_benchmark/endresults" \ + "results/final_results/model_outputs" + + move_if_exists "results_edss_benchmark/confusion_matrices" \ + "results/figures/confusion_matrices" + + if [ -z "$(ls -A results_edss_benchmark 2>/dev/null || true)" ]; then + rmdir results_edss_benchmark + else + move_if_exists "results_edss_benchmark" \ + "archive/old_results/results_edss_benchmark_remaining" + fi +fi + +# ------------------------------------------------------------ +# Move old/general results +# ------------------------------------------------------------ + +echo +echo "Moving existing results files..." + +if [ -d "results" ]; then + # Figures + move_glob_if_exists "results/*.png" "results/figures" + move_glob_if_exists "results/*.PNG" "results/figures" + move_glob_if_exists "results/*.jpg" "results/figures" + move_glob_if_exists "results/*.jpeg" "results/figures" + move_glob_if_exists "results/*.svg" "results/figures" + + # Tables + move_glob_if_exists "results/*.csv" "results/tables" + move_glob_if_exists "results/*.tsv" "results/tables" + move_glob_if_exists "results/*.xlsx" "results/tables" + + # Subfolders that look like old results + move_if_exists "results/Jan_visual" "archive/old_results/Jan_visual" + move_if_exists "results/Lab_meeting" "archive/old_results/Lab_meeting" + move_if_exists "results/just_edss" "archive/old_results/just_edss" +fi + +# Root-level result tables +move_if_exists "edss_distribution_summary.csv" \ + "results/tables/edss_distribution_summary.csv" + +# Logs +move_if_exists "edss_benchmark_terminal.log" \ + "results/logs/edss_benchmark_terminal.log" + +# ------------------------------------------------------------ +# Move scripts +# ------------------------------------------------------------ + +echo +echo "Moving scripts..." + +move_if_exists "audit.py" "scripts/audit_outputs.py" +move_if_exists "certainty.py" "scripts/analyze_certainty.py" +move_if_exists "certainty_show.py" "scripts/certainty_show.py" +move_if_exists "figure1.py" "scripts/figure1.py" +move_if_exists "show_plots.py" "scripts/show_plots.py" + +move_if_exists "show_plots.py.orig" "archive/old_scripts/show_plots.py.orig" + +# Apps / old entry points +move_if_exists "app.py" "archive/old_scripts/app.py" +move_if_exists "total_app.py" "archive/old_scripts/total_app.py" + +# Existing project visuals folder +move_if_exists "project/visuals" "results/figures/project_visuals" + +if [ -d "project" ]; then + if [ -z "$(ls -A project)" ]; then + rmdir project + else + move_if_exists "project" "archive/old_project_files/project" + fi +fi + +# ------------------------------------------------------------ +# Environment folder +# ------------------------------------------------------------ + +echo +echo "Handling virtual environment..." + +if [ -d "enarcelona" ]; then + echo "Leaving virtual environment in place: enarcelona/" + echo "It should remain ignored by .gitignore." +fi + +# ------------------------------------------------------------ +# Create README notes +# ------------------------------------------------------------ + +echo +echo "Writing organization notes..." + +cat > "PROJECT_STRUCTURE.md" <<'EOF' +# Project Structure + +This project was reorganized into: + +- `data/` + - `raw/`: original raw data, if retained locally + - `processed/`: cleaned or derived input data + - `ground_truth/`: manually annotated reference data + - `external/`: externally provided data + +- `prompts/` + - EDSS instructions and prompt/schema assets + +- `scripts/` + - runnable analysis and plotting scripts + +- `results/` + - `benchmark_runs/`: full model benchmark runs + - `final_results/`: final selected model outputs + - `figures/`: generated figures + - `tables/`: generated tables + - `logs/`: terminal logs + +- `manuscript/` + - final figures and tables for paper/thesis writing + +- `archive/` + - old scripts, old results, temporary files, and unclear legacy files + +Important: +The reorganization was performed after creating a full timestamped backup. +EOF + +# ------------------------------------------------------------ +# Final checks +# ------------------------------------------------------------ + +echo +echo "Organization complete." +echo +echo "Backup is here:" +echo "${BACKUP_DIR}" +echo +echo "New top-level structure:" +find . -maxdepth 2 -type d | sort +echo + +if [ -d ".git" ]; then + echo "Git status:" + git status --short +fi + +echo +echo "Next recommended commands:" +echo " git status" +echo " git add ." +echo " git commit -m \"Reorganize project structure\"" diff --git a/attach/Komplett.txt b/prompts/Komplett.txt similarity index 100% rename from attach/Komplett.txt rename to prompts/Komplett.txt diff --git a/attach/just_edss_schema.gbnf b/prompts/just_edss_schema.gbnf similarity index 100% rename from attach/just_edss_schema.gbnf rename to prompts/just_edss_schema.gbnf diff --git a/attach/just_edss_text.txt b/prompts/just_edss_text.txt similarity index 100% rename from attach/just_edss_text.txt rename to prompts/just_edss_text.txt diff --git a/certainty.py b/scripts/analyze_certainty.py similarity index 75% rename from certainty.py rename to scripts/analyze_certainty.py index fac6998..a9f1038 100644 --- a/certainty.py +++ b/scripts/analyze_certainty.py @@ -1258,973 +1258,973 @@ # %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark - -import time -import json -import os -import re -import threading -from datetime import datetime -from pathlib import Path - -import pandas as pd -from openai import OpenAI -from dotenv import load_dotenv - -try: - import psutil -except ImportError: - psutil = None - print("⚠️ psutil is not installed. Resource metrics will be limited.") - print("Install with: pip install psutil") - - -# ========================= -# CONFIGURATION -# ========================= - -load_dotenv() - -OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") -OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") - -MODEL_CONFIGS = [ +# +#import time +#import json +#import os +#import re +#import threading +#from datetime import datetime +#from pathlib import Path +# +#import pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +#try: +# import psutil +#except ImportError: +# psutil = None +# print("⚠️ psutil is not installed. Resource metrics will be limited.") +# print("Install with: pip install psutil") +# +# +## ========================= +## CONFIGURATION +## ========================= +# +#load_dotenv() +# +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +# +#MODEL_CONFIGS = [ +## { +## "model_name": "qwen3.6-35b-a3b", +## "use_response_format": False, +## "temperature": 0.0, +## "max_tokens": 4096, +## +## # If your backend is vLLM / Qwen chat-template compatible, +## # this may reduce long hidden reasoning and JSON truncation. +## # If your server errors because of extra_body, set this to None. +## "extra_body": { +## "chat_template_kwargs": { +## "enable_thinking": False +## } +## }, +## }, # { -# "model_name": "qwen3.6-35b-a3b", +# "model_name": "gemma-4-31B-it", # "use_response_format": False, # "temperature": 0.0, # "max_tokens": 4096, -# -# # If your backend is vLLM / Qwen chat-template compatible, -# # this may reduce long hidden reasoning and JSON truncation. -# # If your server errors because of extra_body, set this to None. -# "extra_body": { -# "chat_template_kwargs": { -# "enable_thinking": False -# } -# }, +# "extra_body": None, # }, - { - "model_name": "gemma-4-31B-it", - "use_response_format": False, - "temperature": 0.0, - "max_tokens": 4096, - "extra_body": None, - }, - # { - # "model_name": "GPT-OSS-120B", - # "use_response_format": True, - # "temperature": 0.0, - # "max_tokens": 4096, - # "extra_body": None, - # }, -] - -INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" -EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" - -RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark" - -NUM_ITERATIONS = 10 -STOP_ON_FIRST_ERROR = False - -# For testing, set to e.g. 2. -# For full run, set to None. -MAX_ROWS = 2 -# MAX_ROWS = 2 - -MAX_TOKENS = 4096 -TEMPERATURE = 0.0 - -RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 - -SAVE_EVERY_N_ROWS = 1 - -# Retries for invalid JSON / truncated JSON -MAX_JSON_RETRIES = 2 -RETRY_SLEEP_SEC = 2 - - -# ========================= -# CLIENT -# ========================= - -client = OpenAI( - api_key=OPENAI_API_KEY, - base_url=OPENAI_BASE_URL -) - - -# ========================= -# HELPERS -# ========================= - -def safe_dir_name(name: str) -> str: - name = str(name).strip() - name = re.sub(r"[^\w\-.]+", "_", name) - return name[:150] - - -def now_timestamp() -> str: - return datetime.now().strftime("%Y%m%d_%H%M%S") - - -def get_process(): - if psutil is None: - return None - return psutil.Process(os.getpid()) - - -def get_memory_rss_mb(process=None): - if psutil is None: - return None - if process is None: - process = get_process() - return process.memory_info().rss / (1024 * 1024) - - -def get_cpu_times_sec(process=None): - if psutil is None: - return None - if process is None: - process = get_process() - cpu_times = process.cpu_times() - return cpu_times.user + cpu_times.system - - -class ResourceSampler: - def __init__(self, interval_sec=0.05): - self.interval_sec = interval_sec - self.process = get_process() - self.running = False - self.thread = None - self.samples_mb = [] - - def start(self): - if psutil is None: - return - - self.running = True - self.samples_mb = [] - self.thread = threading.Thread(target=self._sample_loop, daemon=True) - self.thread.start() - - def stop(self): - if psutil is None: - return - - self.running = False - if self.thread is not None: - self.thread.join(timeout=1.0) - - def _sample_loop(self): - while self.running: - try: - rss_mb = get_memory_rss_mb(self.process) - self.samples_mb.append(rss_mb) - except Exception: - pass - time.sleep(self.interval_sec) - - @property - def peak_rss_mb(self): - if not self.samples_mb: - return None - return max(self.samples_mb) - - -# ========================= -# JSON EXTRACTION -# ========================= - -def extract_json_from_text(text): - if text is None: - raise ValueError("Model returned empty content: message.content is None") - - text = str(text).strip() - - if not text: - raise ValueError("Model returned empty content") - - text = ( - text.replace("```json", "") - .replace("```JSON", "") - .replace("```Json", "") - .replace("```", "") - .strip() - ) - - # Direct parse - try: - parsed = json.loads(text) - if isinstance(parsed, dict): - return parsed - except json.JSONDecodeError: - pass - - # Balanced JSON candidates - candidates = [] - stack = [] - start_idx = None - in_string = False - escape = False - - for i, ch in enumerate(text): - if escape: - escape = False - continue - - if ch == "\\": - escape = True - continue - - if ch == '"': - in_string = not in_string - continue - - if in_string: - continue - - if ch == "{": - if not stack: - start_idx = i - stack.append(ch) - - elif ch == "}": - if stack: - stack.pop() - if not stack and start_idx is not None: - candidates.append(text[start_idx:i + 1]) - start_idx = None - - valid_objects = [] - - for candidate in candidates: - candidate = candidate.strip() - lowered = candidate.lower() - - invalid_markers = [ - "true/false", - "null or", - "oder zahl", - "0.0-6.0", - "0.0-10.0", - "zahl zwischen", - "...", - ] - - if any(marker in lowered for marker in invalid_markers): - continue - - try: - parsed = json.loads(candidate) - if isinstance(parsed, dict): - valid_objects.append(parsed) - except json.JSONDecodeError: - continue - - for obj in reversed(valid_objects): - if ( - "klassifizierbar" in obj - and "certainty_percent" in obj - and "subcategories" in obj - ): - return obj - - if valid_objects: - return valid_objects[-1] - - # Check if it looks like truncated JSON - stripped = text.strip() - if stripped.startswith("{") and not stripped.endswith("}"): - raise ValueError( - "Model output looks like truncated JSON. " - f"Raw output starts with: {text[:1000]}" - ) - - raise ValueError( - "No valid JSON object found in model output. " - f"Raw output starts with: {text[:1000]}" - ) - - -def extract_message_content(message): - raw_content = getattr(message, "content", None) - - if raw_content is not None: - return raw_content - - msg_dict = None - - try: - msg_dict = message.model_dump() - except Exception: - try: - msg_dict = dict(message) - except Exception: - msg_dict = None - - if not isinstance(msg_dict, dict): - return None - - for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: - value = msg_dict.get(key) - if value: - return value - - possible_content = msg_dict.get("content") - if isinstance(possible_content, list): - parts = [] - for block in possible_content: - if isinstance(block, dict): - if "text" in block: - parts.append(str(block["text"])) - elif "content" in block: - parts.append(str(block["content"])) - if parts: - return "\n".join(parts).strip() - - return None - - -# ========================= -# READ INSTRUCTIONS -# ========================= - -with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: - EDSS_INSTRUCTIONS = f.read().strip() - - -# ========================= -# PROMPT -# ========================= - -def build_prompt(patient_text): - return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. - -Extrahiere: -1. Gesamt-EDSS-Score von 0.0 bis 10.0 -2. Alle 8 EDSS-Unterkategorien -3. Sicherheit als Ganzzahl von 0 bis 100 - -Antworte ausschließlich mit EINEM validen JSON-Objekt. -Kein Markdown. -Keine Code-Fences. -Kein Text vor oder nach JSON. -Keine Platzhalter. -Kopiere kein Schema. - -Das JSON muss exakt diese Schlüssel enthalten: -- reason -- klassifizierbar -- EDSS -- certainty_percent -- subcategories - -Die subcategories müssen exakt diese 8 Schlüssel enthalten: -- VISUAL_OPTIC_FUNCTIONS -- BRAINSTEM_FUNCTIONS -- PYRAMIDAL_FUNCTIONS -- CEREBELLAR_FUNCTIONS -- SENSORY_FUNCTIONS -- BOWEL_AND_BLADDER_FUNCTIONS -- CEREBRAL_FUNCTIONS -- AMBULATION - -Werte: -- klassifizierbar: true oder false -- EDSS: Zahl von 0.0 bis 10.0 oder null -- certainty_percent: Ganzzahl von 0 bis 100 -- Unterkategorien: Zahl oder null -- VISUAL_OPTIC_FUNCTIONS maximal 6.0 -- BRAINSTEM_FUNCTIONS maximal 6.0 -- PYRAMIDAL_FUNCTIONS maximal 6.0 -- CEREBELLAR_FUNCTIONS maximal 6.0 -- SENSORY_FUNCTIONS maximal 6.0 -- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 -- CEREBRAL_FUNCTIONS maximal 6.0 -- AMBULATION maximal 10.0 -- reason: maximal 250 Zeichen, Deutsch - -Wenn klassifizierbar false ist, setze EDSS auf null. - -Valide Beispielausgabe: -{{ - "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", - "klassifizierbar": true, - "EDSS": 2.0, - "certainty_percent": 90, - "subcategories": {{ - "VISUAL_OPTIC_FUNCTIONS": null, - "BRAINSTEM_FUNCTIONS": null, - "PYRAMIDAL_FUNCTIONS": 1.0, - "CEREBELLAR_FUNCTIONS": 1.0, - "SENSORY_FUNCTIONS": 1.0, - "BOWEL_AND_BLADDER_FUNCTIONS": null, - "CEREBRAL_FUNCTIONS": null, - "AMBULATION": 0.0 - }} -}} - -EDSS-Bewertungsrichtlinien: -{EDSS_INSTRUCTIONS} - -Patientenbericht: -{patient_text} - -Gib ausschließlich das finale JSON-Objekt zurück. -''' - - -# ========================= -# VALIDATION / NORMALIZATION -# ========================= - -def normalize_model_output(parsed): - if not isinstance(parsed, dict): - raise ValueError("Parsed model output is not a JSON object") - - if "certainty_percent" not in parsed: - parsed["certainty_percent"] = 0 - elif not isinstance(parsed["certainty_percent"], (int, float)): - try: - parsed["certainty_percent"] = int(parsed["certainty_percent"]) - except Exception: - parsed["certainty_percent"] = 0 - - parsed["certainty_percent"] = max(0, min(100, int(parsed["certainty_percent"]))) - - if "klassifizierbar" not in parsed: - parsed["klassifizierbar"] = False - - if not isinstance(parsed["klassifizierbar"], bool): - if str(parsed["klassifizierbar"]).lower() in ["true", "1", "yes", "ja"]: - parsed["klassifizierbar"] = True - else: - parsed["klassifizierbar"] = False - - if not parsed.get("klassifizierbar", False): - parsed["EDSS"] = None - else: - if "EDSS" not in parsed or parsed["EDSS"] is None: - parsed["EDSS"] = 7.0 - else: - try: - parsed["EDSS"] = float(parsed["EDSS"]) - parsed["EDSS"] = max(0.0, min(10.0, parsed["EDSS"])) - except Exception: - parsed["EDSS"] = 7.0 - - required_subcategories = { - "VISUAL_OPTIC_FUNCTIONS": 6.0, - "BRAINSTEM_FUNCTIONS": 6.0, - "PYRAMIDAL_FUNCTIONS": 6.0, - "CEREBELLAR_FUNCTIONS": 6.0, - "SENSORY_FUNCTIONS": 6.0, - "BOWEL_AND_BLADDER_FUNCTIONS": 6.0, - "CEREBRAL_FUNCTIONS": 6.0, - "AMBULATION": 10.0, - } - - if "subcategories" not in parsed or not isinstance(parsed["subcategories"], dict): - parsed["subcategories"] = {} - - for key, max_value in required_subcategories.items(): - value = parsed["subcategories"].get(key, None) - - if value is None: - parsed["subcategories"][key] = None - else: - try: - value = float(value) - value = max(0.0, min(max_value, value)) - parsed["subcategories"][key] = value - except Exception: - parsed["subcategories"][key] = None - - parsed["subcategories"] = { - key: parsed["subcategories"].get(key, None) - for key in required_subcategories.keys() - } - - if "reason" not in parsed or parsed["reason"] is None: - parsed["reason"] = "" - - parsed["reason"] = str(parsed["reason"])[:250] - - return parsed - - -# ========================= -# API CALL -# ========================= - -def make_chat_completion(model_config, prompt): - model_name = model_config["model_name"] - - kwargs = dict( - messages=[ - { - "role": "system", - "content": ( - "Du bist ein JSON-Generator. " - "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " - "Keine Erklärung. Kein Markdown. Keine Code-Fences. " - "Keine Platzhalter. Kein Schema kopieren. " - "Das JSON muss vollständig geschlossen sein." - ) - }, - { - "role": "user", - "content": prompt - } - ], - model=model_name, - max_tokens=model_config.get("max_tokens", MAX_TOKENS), - temperature=model_config.get("temperature", TEMPERATURE), - ) - - if model_config.get("use_response_format", False): - kwargs["response_format"] = {"type": "json_object"} - - extra_body = model_config.get("extra_body") - if extra_body is not None: - kwargs["extra_body"] = extra_body - - return client.chat.completions.create(**kwargs) - - -# ========================= -# INFERENCE FUNCTION WITH RETRIES -# ========================= - -def run_inference(patient_text, model_config): - model_name = model_config["model_name"] - prompt = build_prompt(patient_text) - - process = get_process() - sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) - - wall_start = time.perf_counter() - cpu_start = get_cpu_times_sec(process) - rss_start_mb = get_memory_rss_mb(process) - - sampler.start() - - raw_content = None - raw_response_debug = None - last_error = None - prompt_tokens = None - completion_tokens = None - total_tokens = None - - try: - for attempt in range(1, MAX_JSON_RETRIES + 2): - try: - response = make_chat_completion( - model_config=model_config, - prompt=prompt - ) - - message = response.choices[0].message - raw_content = extract_message_content(message) - - try: - raw_response_debug = response.model_dump() - except Exception: - raw_response_debug = str(response) - - usage = getattr(response, "usage", None) - if usage is not None: - prompt_tokens = getattr(usage, "prompt_tokens", None) - completion_tokens = getattr(usage, "completion_tokens", None) - total_tokens = getattr(usage, "total_tokens", None) - - parsed = extract_json_from_text(raw_content) - parsed = normalize_model_output(parsed) - - success = True - error = None - result = parsed - break - - except Exception as e: - last_error = str(e) - - if attempt <= MAX_JSON_RETRIES: - print( - f"\n⚠️ JSON failed on attempt {attempt}. " - f"Retrying row. Error: {last_error[:300]}" - ) - time.sleep(RETRY_SLEEP_SEC) - continue - - raise - - except Exception as e: - print(f"❌ Inference error: {e}") - - success = False - error = str(e) - result = None - - finally: - sampler.stop() - - wall_end = time.perf_counter() - cpu_end = get_cpu_times_sec(process) - rss_end_mb = get_memory_rss_mb(process) - - wall_time_sec = wall_end - wall_start - - if cpu_start is not None and cpu_end is not None: - process_cpu_time_sec = cpu_end - cpu_start - else: - process_cpu_time_sec = None - - if rss_start_mb is not None and rss_end_mb is not None: - rss_delta_mb = rss_end_mb - rss_start_mb - else: - rss_delta_mb = None - - return { - "success": success, - "error": error, - "result": result, - - "model": model_name, - - "inference_time_sec": wall_time_sec, - - "process_cpu_time_sec": process_cpu_time_sec, - "rss_before_mb": rss_start_mb, - "rss_after_mb": rss_end_mb, - "rss_delta_mb": rss_delta_mb, - "peak_rss_mb": sampler.peak_rss_mb, - - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - - "raw_content": raw_content if not success else None, - "raw_response_debug": raw_response_debug if not success else None, - "last_error": last_error, - } - - -# ========================= -# BUILD PATIENT TEXT -# ========================= - -def build_patient_text(row): - return ( - str(row.get("T_Zusammenfassung", "")) + "\n" + - str(row.get("Diagnosen", "")) + "\n" + - str(row.get("T_KlinBef", "")) + "\n" + - str(row.get("T_Befunde", "")) - ) - - -# ========================= -# FLATTEN RESULTS FOR CSV -# ========================= - -def flatten_result(record): - flat = { - "model": record.get("model"), - "iteration": record.get("iteration"), - "row_index": record.get("row_index"), - "row_number_in_run": record.get("row_number_in_run"), - "unique_id": record.get("unique_id"), - "MedDatum": record.get("MedDatum"), - - "success": record.get("success"), - "error": record.get("error"), - "last_error": record.get("last_error"), - - "inference_time_sec": record.get("inference_time_sec"), - "process_cpu_time_sec": record.get("process_cpu_time_sec"), - "rss_before_mb": record.get("rss_before_mb"), - "rss_after_mb": record.get("rss_after_mb"), - "rss_delta_mb": record.get("rss_delta_mb"), - "peak_rss_mb": record.get("peak_rss_mb"), - - "prompt_tokens": record.get("prompt_tokens"), - "completion_tokens": record.get("completion_tokens"), - "total_tokens": record.get("total_tokens"), - - "raw_content": record.get("raw_content"), - } - - result = record.get("result") - - if isinstance(result, dict): - flat["reason"] = result.get("reason") - flat["klassifizierbar"] = result.get("klassifizierbar") - flat["EDSS"] = result.get("EDSS") - flat["certainty_percent"] = result.get("certainty_percent") - - subcats = result.get("subcategories", {}) - if isinstance(subcats, dict): - for key, value in subcats.items(): - flat[f"subcat_{key}"] = value - - return flat - - -def summarize_records(records): - df = pd.DataFrame([flatten_result(r) for r in records]) - - if df.empty: - return pd.DataFrame() - - success_series = df["success"].fillna(False).astype(bool) - - summary = { - "model": df["model"].iloc[0] if "model" in df.columns else None, - "n_records": len(df), - "n_success": int(success_series.sum()), - "n_failed": int((~success_series).sum()), - "success_rate": float(success_series.mean()), - } - - numeric_cols = [ - "inference_time_sec", - "process_cpu_time_sec", - "rss_delta_mb", - "peak_rss_mb", - "prompt_tokens", - "completion_tokens", - "total_tokens", - "certainty_percent", - "EDSS", - ] - - for col in numeric_cols: - if col in df.columns: - values = pd.to_numeric(df[col], errors="coerce") - summary[f"{col}_mean"] = values.mean() - summary[f"{col}_median"] = values.median() - summary[f"{col}_std"] = values.std() - summary[f"{col}_min"] = values.min() - summary[f"{col}_max"] = values.max() - - return pd.DataFrame([summary]) - - -# ========================= -# INCREMENTAL SAVE HELPERS -# ========================= - -def append_jsonl(path, record): - with open(path, "a", encoding="utf-8") as f: - f.write(json.dumps(record, ensure_ascii=False) + "\n") - f.flush() - os.fsync(f.fileno()) - - -def append_csv(path, record): - flat = flatten_result(record) - df_one = pd.DataFrame([flat]) - file_exists = Path(path).exists() - df_one.to_csv(path, mode="a", header=not file_exists, index=False) - - -# ========================= -# MAIN LOOP -# ========================= - -if __name__ == "__main__": - - run_timestamp = now_timestamp() - - results_root = Path(RESULTS_ROOT) - results_root.mkdir(parents=True, exist_ok=True) - - run_root = results_root / f"run_{run_timestamp}" - run_root.mkdir(parents=True, exist_ok=True) - - print(f"Results root: {run_root}") - - df = pd.read_csv(INPUT_CSV, sep=";") - - if MAX_ROWS is not None: - df = df.head(MAX_ROWS) - - total_rows = len(df) - - model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] - - print(f"Loaded {total_rows} patient records.") - print(f"Models: {model_names_for_print}") - print(f"Iterations per model: {NUM_ITERATIONS}") - - all_model_summaries = [] - - for model_config in MODEL_CONFIGS: - model_name = model_config["model_name"] - safe_model = safe_dir_name(model_name) - - model_dir = run_root / safe_model - model_dir.mkdir(parents=True, exist_ok=True) - - print(f"\n{'#' * 80}") - print(f"MODEL: {model_name}") - print(f"use_response_format: {model_config.get('use_response_format', False)}") - print(f"temperature: {model_config.get('temperature', TEMPERATURE)}") - print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") - print(f"Saving to: {model_dir}") - print(f"{'#' * 80}") - - model_records = [] - model_start = time.perf_counter() - - for iteration in range(1, NUM_ITERATIONS + 1): - print(f"\n{'=' * 60}") - print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") - print(f"{'=' * 60}") - - iteration_results = [] - iteration_start = time.perf_counter() - - incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" - incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" - - print(f"Incremental JSONL: {incremental_jsonl_path}") - print(f"Incremental CSV: {incremental_csv_path}") - - for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): - print( - f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", - end="", - flush=True - ) - - try: - patient_text = build_patient_text(row) - - record = run_inference( - patient_text=patient_text, - model_config=model_config - ) - - record["iteration"] = iteration - record["row_index"] = int(idx) - record["row_number_in_run"] = int(loop_i) - record["unique_id"] = row.get("unique_id", f"row_{idx}") - record["MedDatum"] = row.get("MedDatum", None) - - iteration_results.append(record) - model_records.append(record) - - if loop_i % SAVE_EVERY_N_ROWS == 0: - append_jsonl(incremental_jsonl_path, record) - append_csv(incremental_csv_path, record) - - if record["success"]: - res = record["result"] - edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" - print( - f" ✅ EDSS={edss}, " - f"cert={res.get('certainty_percent', '?')}%, " - f"time={record['inference_time_sec']:.2f}s" - ) - else: - print(f" ❌ {record.get('error', 'Unknown error')}") - - except Exception as e: - print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") - - fallback_record = { - "success": False, - "error": str(e), - "last_error": str(e), - "result": None, - - "model": model_name, - "iteration": iteration, - "row_index": int(idx), - "row_number_in_run": int(loop_i), - "unique_id": row.get("unique_id", f"row_{idx}"), - "MedDatum": row.get("MedDatum", None), - - "inference_time_sec": None, - "process_cpu_time_sec": None, - "rss_before_mb": None, - "rss_after_mb": None, - "rss_delta_mb": None, - "peak_rss_mb": None, - - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - - "raw_content": None, - "raw_response_debug": None, - } - - iteration_results.append(fallback_record) - model_records.append(fallback_record) - - append_jsonl(incremental_jsonl_path, fallback_record) - append_csv(incremental_csv_path, fallback_record) - - if STOP_ON_FIRST_ERROR: - break - - iteration_elapsed = time.perf_counter() - iteration_start - - iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" - with open(iter_json_path, "w", encoding="utf-8") as f: - json.dump(iteration_results, f, indent=2, ensure_ascii=False) - - iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" - iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) - iter_flat_df.to_csv(iter_csv_path, index=False) - - print(f"\n✅ Iteration {iteration} complete.") - print(f"Incremental JSONL saved to: {incremental_jsonl_path}") - print(f"Incremental CSV saved to: {incremental_csv_path}") - print(f"Final JSON saved to: {iter_json_path}") - print(f"Final CSV saved to: {iter_csv_path}") - print( - f"⏱️ Iteration time: {iteration_elapsed:.1f}s " - f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" - ) - - model_elapsed = time.perf_counter() - model_start - - model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" - with open(model_json_path, "w", encoding="utf-8") as f: - json.dump(model_records, f, indent=2, ensure_ascii=False) - - model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" - model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) - model_flat_df.to_csv(model_csv_path, index=False) - - model_summary_df = summarize_records(model_records) - model_summary_df["model_total_wall_time_sec"] = model_elapsed - model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 - - model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" - model_summary_df.to_csv(model_summary_path, index=False) - - all_model_summaries.append(model_summary_df) - - print(f"\n🎉 Model completed: {model_name}") - print(f"All JSON: {model_json_path}") - print(f"All CSV: {model_csv_path}") - print(f"Summary: {model_summary_path}") - print(f"Total model time: {model_elapsed / 60:.2f} min") - - if all_model_summaries: - combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) - combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" - combined_summary_df.to_csv(combined_summary_path, index=False) - - print(f"\n📊 Combined summary saved to: {combined_summary_path}") - - print(f"\n🎉 All models and all iterations completed!") +# # { +# # "model_name": "GPT-OSS-120B", +# # "use_response_format": True, +# # "temperature": 0.0, +# # "max_tokens": 4096, +# # "extra_body": None, +# # }, +#] +# +#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" +# +#RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark" +# +#NUM_ITERATIONS = 10 +#STOP_ON_FIRST_ERROR = False +# +## For testing, set to e.g. 2. +## For full run, set to None. +#MAX_ROWS = 2 +## MAX_ROWS = 2 +# +#MAX_TOKENS = 4096 +#TEMPERATURE = 0.0 +# +#RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 +# +#SAVE_EVERY_N_ROWS = 1 +# +## Retries for invalid JSON / truncated JSON +#MAX_JSON_RETRIES = 2 +#RETRY_SLEEP_SEC = 2 +# +# +## ========================= +## CLIENT +## ========================= +# +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +# +## ========================= +## HELPERS +## ========================= +# +#def safe_dir_name(name: str) -> str: +# name = str(name).strip() +# name = re.sub(r"[^\w\-.]+", "_", name) +# return name[:150] +# +# +#def now_timestamp() -> str: +# return datetime.now().strftime("%Y%m%d_%H%M%S") +# +# +#def get_process(): +# if psutil is None: +# return None +# return psutil.Process(os.getpid()) +# +# +#def get_memory_rss_mb(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# return process.memory_info().rss / (1024 * 1024) +# +# +#def get_cpu_times_sec(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# cpu_times = process.cpu_times() +# return cpu_times.user + cpu_times.system +# +# +#class ResourceSampler: +# def __init__(self, interval_sec=0.05): +# self.interval_sec = interval_sec +# self.process = get_process() +# self.running = False +# self.thread = None +# self.samples_mb = [] +# +# def start(self): +# if psutil is None: +# return +# +# self.running = True +# self.samples_mb = [] +# self.thread = threading.Thread(target=self._sample_loop, daemon=True) +# self.thread.start() +# +# def stop(self): +# if psutil is None: +# return +# +# self.running = False +# if self.thread is not None: +# self.thread.join(timeout=1.0) +# +# def _sample_loop(self): +# while self.running: +# try: +# rss_mb = get_memory_rss_mb(self.process) +# self.samples_mb.append(rss_mb) +# except Exception: +# pass +# time.sleep(self.interval_sec) +# +# @property +# def peak_rss_mb(self): +# if not self.samples_mb: +# return None +# return max(self.samples_mb) +# +# +## ========================= +## JSON EXTRACTION +## ========================= +# +#def extract_json_from_text(text): +# if text is None: +# raise ValueError("Model returned empty content: message.content is None") +# +# text = str(text).strip() +# +# if not text: +# raise ValueError("Model returned empty content") +# +# text = ( +# text.replace("```json", "") +# .replace("```JSON", "") +# .replace("```Json", "") +# .replace("```", "") +# .strip() +# ) +# +# # Direct parse +# try: +# parsed = json.loads(text) +# if isinstance(parsed, dict): +# return parsed +# except json.JSONDecodeError: +# pass +# +# # Balanced JSON candidates +# candidates = [] +# stack = [] +# start_idx = None +# in_string = False +# escape = False +# +# for i, ch in enumerate(text): +# if escape: +# escape = False +# continue +# +# if ch == "\\": +# escape = True +# continue +# +# if ch == '"': +# in_string = not in_string +# continue +# +# if in_string: +# continue +# +# if ch == "{": +# if not stack: +# start_idx = i +# stack.append(ch) +# +# elif ch == "}": +# if stack: +# stack.pop() +# if not stack and start_idx is not None: +# candidates.append(text[start_idx:i + 1]) +# start_idx = None +# +# valid_objects = [] +# +# for candidate in candidates: +# candidate = candidate.strip() +# lowered = candidate.lower() +# +# invalid_markers = [ +# "true/false", +# "null or", +# "oder zahl", +# "0.0-6.0", +# "0.0-10.0", +# "zahl zwischen", +# "...", +# ] +# +# if any(marker in lowered for marker in invalid_markers): +# continue +# +# try: +# parsed = json.loads(candidate) +# if isinstance(parsed, dict): +# valid_objects.append(parsed) +# except json.JSONDecodeError: +# continue +# +# for obj in reversed(valid_objects): +# if ( +# "klassifizierbar" in obj +# and "certainty_percent" in obj +# and "subcategories" in obj +# ): +# return obj +# +# if valid_objects: +# return valid_objects[-1] +# +# # Check if it looks like truncated JSON +# stripped = text.strip() +# if stripped.startswith("{") and not stripped.endswith("}"): +# raise ValueError( +# "Model output looks like truncated JSON. " +# f"Raw output starts with: {text[:1000]}" +# ) +# +# raise ValueError( +# "No valid JSON object found in model output. " +# f"Raw output starts with: {text[:1000]}" +# ) +# +# +#def extract_message_content(message): +# raw_content = getattr(message, "content", None) +# +# if raw_content is not None: +# return raw_content +# +# msg_dict = None +# +# try: +# msg_dict = message.model_dump() +# except Exception: +# try: +# msg_dict = dict(message) +# except Exception: +# msg_dict = None +# +# if not isinstance(msg_dict, dict): +# return None +# +# for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: +# value = msg_dict.get(key) +# if value: +# return value +# +# possible_content = msg_dict.get("content") +# if isinstance(possible_content, list): +# parts = [] +# for block in possible_content: +# if isinstance(block, dict): +# if "text" in block: +# parts.append(str(block["text"])) +# elif "content" in block: +# parts.append(str(block["content"])) +# if parts: +# return "\n".join(parts).strip() +# +# return None +# +# +## ========================= +## READ INSTRUCTIONS +## ========================= +# +#with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: +# EDSS_INSTRUCTIONS = f.read().strip() +# +# +## ========================= +## PROMPT +## ========================= +# +#def build_prompt(patient_text): +# return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. +# +#Extrahiere: +#1. Gesamt-EDSS-Score von 0.0 bis 10.0 +#2. Alle 8 EDSS-Unterkategorien +#3. Sicherheit als Ganzzahl von 0 bis 100 +# +#Antworte ausschließlich mit EINEM validen JSON-Objekt. +#Kein Markdown. +#Keine Code-Fences. +#Kein Text vor oder nach JSON. +#Keine Platzhalter. +#Kopiere kein Schema. +# +#Das JSON muss exakt diese Schlüssel enthalten: +#- reason +#- klassifizierbar +#- EDSS +#- certainty_percent +#- subcategories +# +#Die subcategories müssen exakt diese 8 Schlüssel enthalten: +#- VISUAL_OPTIC_FUNCTIONS +#- BRAINSTEM_FUNCTIONS +#- PYRAMIDAL_FUNCTIONS +#- CEREBELLAR_FUNCTIONS +#- SENSORY_FUNCTIONS +#- BOWEL_AND_BLADDER_FUNCTIONS +#- CEREBRAL_FUNCTIONS +#- AMBULATION +# +#Werte: +#- klassifizierbar: true oder false +#- EDSS: Zahl von 0.0 bis 10.0 oder null +#- certainty_percent: Ganzzahl von 0 bis 100 +#- Unterkategorien: Zahl oder null +#- VISUAL_OPTIC_FUNCTIONS maximal 6.0 +#- BRAINSTEM_FUNCTIONS maximal 6.0 +#- PYRAMIDAL_FUNCTIONS maximal 6.0 +#- CEREBELLAR_FUNCTIONS maximal 6.0 +#- SENSORY_FUNCTIONS maximal 6.0 +#- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 +#- CEREBRAL_FUNCTIONS maximal 6.0 +#- AMBULATION maximal 10.0 +#- reason: maximal 250 Zeichen, Deutsch +# +#Wenn klassifizierbar false ist, setze EDSS auf null. +# +#Valide Beispielausgabe: +#{{ +# "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", +# "klassifizierbar": true, +# "EDSS": 2.0, +# "certainty_percent": 90, +# "subcategories": {{ +# "VISUAL_OPTIC_FUNCTIONS": null, +# "BRAINSTEM_FUNCTIONS": null, +# "PYRAMIDAL_FUNCTIONS": 1.0, +# "CEREBELLAR_FUNCTIONS": 1.0, +# "SENSORY_FUNCTIONS": 1.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": null, +# "CEREBRAL_FUNCTIONS": null, +# "AMBULATION": 0.0 +# }} +#}} +# +#EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +# +#Gib ausschließlich das finale JSON-Objekt zurück. +#''' +# +# +## ========================= +## VALIDATION / NORMALIZATION +## ========================= +# +#def normalize_model_output(parsed): +# if not isinstance(parsed, dict): +# raise ValueError("Parsed model output is not a JSON object") +# +# if "certainty_percent" not in parsed: +# parsed["certainty_percent"] = 0 +# elif not isinstance(parsed["certainty_percent"], (int, float)): +# try: +# parsed["certainty_percent"] = int(parsed["certainty_percent"]) +# except Exception: +# parsed["certainty_percent"] = 0 +# +# parsed["certainty_percent"] = max(0, min(100, int(parsed["certainty_percent"]))) +# +# if "klassifizierbar" not in parsed: +# parsed["klassifizierbar"] = False +# +# if not isinstance(parsed["klassifizierbar"], bool): +# if str(parsed["klassifizierbar"]).lower() in ["true", "1", "yes", "ja"]: +# parsed["klassifizierbar"] = True +# else: +# parsed["klassifizierbar"] = False +# +# if not parsed.get("klassifizierbar", False): +# parsed["EDSS"] = None +# else: +# if "EDSS" not in parsed or parsed["EDSS"] is None: +# parsed["EDSS"] = 7.0 +# else: +# try: +# parsed["EDSS"] = float(parsed["EDSS"]) +# parsed["EDSS"] = max(0.0, min(10.0, parsed["EDSS"])) +# except Exception: +# parsed["EDSS"] = 7.0 +# +# required_subcategories = { +# "VISUAL_OPTIC_FUNCTIONS": 6.0, +# "BRAINSTEM_FUNCTIONS": 6.0, +# "PYRAMIDAL_FUNCTIONS": 6.0, +# "CEREBELLAR_FUNCTIONS": 6.0, +# "SENSORY_FUNCTIONS": 6.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": 6.0, +# "CEREBRAL_FUNCTIONS": 6.0, +# "AMBULATION": 10.0, +# } +# +# if "subcategories" not in parsed or not isinstance(parsed["subcategories"], dict): +# parsed["subcategories"] = {} +# +# for key, max_value in required_subcategories.items(): +# value = parsed["subcategories"].get(key, None) +# +# if value is None: +# parsed["subcategories"][key] = None +# else: +# try: +# value = float(value) +# value = max(0.0, min(max_value, value)) +# parsed["subcategories"][key] = value +# except Exception: +# parsed["subcategories"][key] = None +# +# parsed["subcategories"] = { +# key: parsed["subcategories"].get(key, None) +# for key in required_subcategories.keys() +# } +# +# if "reason" not in parsed or parsed["reason"] is None: +# parsed["reason"] = "" +# +# parsed["reason"] = str(parsed["reason"])[:250] +# +# return parsed +# +# +## ========================= +## API CALL +## ========================= +# +#def make_chat_completion(model_config, prompt): +# model_name = model_config["model_name"] +# +# kwargs = dict( +# messages=[ +# { +# "role": "system", +# "content": ( +# "Du bist ein JSON-Generator. " +# "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " +# "Keine Erklärung. Kein Markdown. Keine Code-Fences. " +# "Keine Platzhalter. Kein Schema kopieren. " +# "Das JSON muss vollständig geschlossen sein." +# ) +# }, +# { +# "role": "user", +# "content": prompt +# } +# ], +# model=model_name, +# max_tokens=model_config.get("max_tokens", MAX_TOKENS), +# temperature=model_config.get("temperature", TEMPERATURE), +# ) +# +# if model_config.get("use_response_format", False): +# kwargs["response_format"] = {"type": "json_object"} +# +# extra_body = model_config.get("extra_body") +# if extra_body is not None: +# kwargs["extra_body"] = extra_body +# +# return client.chat.completions.create(**kwargs) +# +# +## ========================= +## INFERENCE FUNCTION WITH RETRIES +## ========================= +# +#def run_inference(patient_text, model_config): +# model_name = model_config["model_name"] +# prompt = build_prompt(patient_text) +# +# process = get_process() +# sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) +# +# wall_start = time.perf_counter() +# cpu_start = get_cpu_times_sec(process) +# rss_start_mb = get_memory_rss_mb(process) +# +# sampler.start() +# +# raw_content = None +# raw_response_debug = None +# last_error = None +# prompt_tokens = None +# completion_tokens = None +# total_tokens = None +# +# try: +# for attempt in range(1, MAX_JSON_RETRIES + 2): +# try: +# response = make_chat_completion( +# model_config=model_config, +# prompt=prompt +# ) +# +# message = response.choices[0].message +# raw_content = extract_message_content(message) +# +# try: +# raw_response_debug = response.model_dump() +# except Exception: +# raw_response_debug = str(response) +# +# usage = getattr(response, "usage", None) +# if usage is not None: +# prompt_tokens = getattr(usage, "prompt_tokens", None) +# completion_tokens = getattr(usage, "completion_tokens", None) +# total_tokens = getattr(usage, "total_tokens", None) +# +# parsed = extract_json_from_text(raw_content) +# parsed = normalize_model_output(parsed) +# +# success = True +# error = None +# result = parsed +# break +# +# except Exception as e: +# last_error = str(e) +# +# if attempt <= MAX_JSON_RETRIES: +# print( +# f"\n⚠️ JSON failed on attempt {attempt}. " +# f"Retrying row. Error: {last_error[:300]}" +# ) +# time.sleep(RETRY_SLEEP_SEC) +# continue +# +# raise +# +# except Exception as e: +# print(f"❌ Inference error: {e}") +# +# success = False +# error = str(e) +# result = None +# +# finally: +# sampler.stop() +# +# wall_end = time.perf_counter() +# cpu_end = get_cpu_times_sec(process) +# rss_end_mb = get_memory_rss_mb(process) +# +# wall_time_sec = wall_end - wall_start +# +# if cpu_start is not None and cpu_end is not None: +# process_cpu_time_sec = cpu_end - cpu_start +# else: +# process_cpu_time_sec = None +# +# if rss_start_mb is not None and rss_end_mb is not None: +# rss_delta_mb = rss_end_mb - rss_start_mb +# else: +# rss_delta_mb = None +# +# return { +# "success": success, +# "error": error, +# "result": result, +# +# "model": model_name, +# +# "inference_time_sec": wall_time_sec, +# +# "process_cpu_time_sec": process_cpu_time_sec, +# "rss_before_mb": rss_start_mb, +# "rss_after_mb": rss_end_mb, +# "rss_delta_mb": rss_delta_mb, +# "peak_rss_mb": sampler.peak_rss_mb, +# +# "prompt_tokens": prompt_tokens, +# "completion_tokens": completion_tokens, +# "total_tokens": total_tokens, +# +# "raw_content": raw_content if not success else None, +# "raw_response_debug": raw_response_debug if not success else None, +# "last_error": last_error, +# } +# +# +## ========================= +## BUILD PATIENT TEXT +## ========================= +# +#def build_patient_text(row): +# return ( +# str(row.get("T_Zusammenfassung", "")) + "\n" + +# str(row.get("Diagnosen", "")) + "\n" + +# str(row.get("T_KlinBef", "")) + "\n" + +# str(row.get("T_Befunde", "")) +# ) +# +# +## ========================= +## FLATTEN RESULTS FOR CSV +## ========================= +# +#def flatten_result(record): +# flat = { +# "model": record.get("model"), +# "iteration": record.get("iteration"), +# "row_index": record.get("row_index"), +# "row_number_in_run": record.get("row_number_in_run"), +# "unique_id": record.get("unique_id"), +# "MedDatum": record.get("MedDatum"), +# +# "success": record.get("success"), +# "error": record.get("error"), +# "last_error": record.get("last_error"), +# +# "inference_time_sec": record.get("inference_time_sec"), +# "process_cpu_time_sec": record.get("process_cpu_time_sec"), +# "rss_before_mb": record.get("rss_before_mb"), +# "rss_after_mb": record.get("rss_after_mb"), +# "rss_delta_mb": record.get("rss_delta_mb"), +# "peak_rss_mb": record.get("peak_rss_mb"), +# +# "prompt_tokens": record.get("prompt_tokens"), +# "completion_tokens": record.get("completion_tokens"), +# "total_tokens": record.get("total_tokens"), +# +# "raw_content": record.get("raw_content"), +# } +# +# result = record.get("result") +# +# if isinstance(result, dict): +# flat["reason"] = result.get("reason") +# flat["klassifizierbar"] = result.get("klassifizierbar") +# flat["EDSS"] = result.get("EDSS") +# flat["certainty_percent"] = result.get("certainty_percent") +# +# subcats = result.get("subcategories", {}) +# if isinstance(subcats, dict): +# for key, value in subcats.items(): +# flat[f"subcat_{key}"] = value +# +# return flat +# +# +#def summarize_records(records): +# df = pd.DataFrame([flatten_result(r) for r in records]) +# +# if df.empty: +# return pd.DataFrame() +# +# success_series = df["success"].fillna(False).astype(bool) +# +# summary = { +# "model": df["model"].iloc[0] if "model" in df.columns else None, +# "n_records": len(df), +# "n_success": int(success_series.sum()), +# "n_failed": int((~success_series).sum()), +# "success_rate": float(success_series.mean()), +# } +# +# numeric_cols = [ +# "inference_time_sec", +# "process_cpu_time_sec", +# "rss_delta_mb", +# "peak_rss_mb", +# "prompt_tokens", +# "completion_tokens", +# "total_tokens", +# "certainty_percent", +# "EDSS", +# ] +# +# for col in numeric_cols: +# if col in df.columns: +# values = pd.to_numeric(df[col], errors="coerce") +# summary[f"{col}_mean"] = values.mean() +# summary[f"{col}_median"] = values.median() +# summary[f"{col}_std"] = values.std() +# summary[f"{col}_min"] = values.min() +# summary[f"{col}_max"] = values.max() +# +# return pd.DataFrame([summary]) +# +# +## ========================= +## INCREMENTAL SAVE HELPERS +## ========================= +# +#def append_jsonl(path, record): +# with open(path, "a", encoding="utf-8") as f: +# f.write(json.dumps(record, ensure_ascii=False) + "\n") +# f.flush() +# os.fsync(f.fileno()) +# +# +#def append_csv(path, record): +# flat = flatten_result(record) +# df_one = pd.DataFrame([flat]) +# file_exists = Path(path).exists() +# df_one.to_csv(path, mode="a", header=not file_exists, index=False) +# +# +## ========================= +## MAIN LOOP +## ========================= +# +#if __name__ == "__main__": +# +# run_timestamp = now_timestamp() +# +# results_root = Path(RESULTS_ROOT) +# results_root.mkdir(parents=True, exist_ok=True) +# +# run_root = results_root / f"run_{run_timestamp}" +# run_root.mkdir(parents=True, exist_ok=True) +# +# print(f"Results root: {run_root}") +# +# df = pd.read_csv(INPUT_CSV, sep=";") +# +# if MAX_ROWS is not None: +# df = df.head(MAX_ROWS) +# +# total_rows = len(df) +# +# model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] +# +# print(f"Loaded {total_rows} patient records.") +# print(f"Models: {model_names_for_print}") +# print(f"Iterations per model: {NUM_ITERATIONS}") +# +# all_model_summaries = [] +# +# for model_config in MODEL_CONFIGS: +# model_name = model_config["model_name"] +# safe_model = safe_dir_name(model_name) +# +# model_dir = run_root / safe_model +# model_dir.mkdir(parents=True, exist_ok=True) +# +# print(f"\n{'#' * 80}") +# print(f"MODEL: {model_name}") +# print(f"use_response_format: {model_config.get('use_response_format', False)}") +# print(f"temperature: {model_config.get('temperature', TEMPERATURE)}") +# print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") +# print(f"Saving to: {model_dir}") +# print(f"{'#' * 80}") +# +# model_records = [] +# model_start = time.perf_counter() +# +# for iteration in range(1, NUM_ITERATIONS + 1): +# print(f"\n{'=' * 60}") +# print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") +# print(f"{'=' * 60}") +# +# iteration_results = [] +# iteration_start = time.perf_counter() +# +# incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" +# incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" +# +# print(f"Incremental JSONL: {incremental_jsonl_path}") +# print(f"Incremental CSV: {incremental_csv_path}") +# +# for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): +# print( +# f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", +# end="", +# flush=True +# ) +# +# try: +# patient_text = build_patient_text(row) +# +# record = run_inference( +# patient_text=patient_text, +# model_config=model_config +# ) +# +# record["iteration"] = iteration +# record["row_index"] = int(idx) +# record["row_number_in_run"] = int(loop_i) +# record["unique_id"] = row.get("unique_id", f"row_{idx}") +# record["MedDatum"] = row.get("MedDatum", None) +# +# iteration_results.append(record) +# model_records.append(record) +# +# if loop_i % SAVE_EVERY_N_ROWS == 0: +# append_jsonl(incremental_jsonl_path, record) +# append_csv(incremental_csv_path, record) +# +# if record["success"]: +# res = record["result"] +# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" +# print( +# f" ✅ EDSS={edss}, " +# f"cert={res.get('certainty_percent', '?')}%, " +# f"time={record['inference_time_sec']:.2f}s" +# ) +# else: +# print(f" ❌ {record.get('error', 'Unknown error')}") +# +# except Exception as e: +# print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") +# +# fallback_record = { +# "success": False, +# "error": str(e), +# "last_error": str(e), +# "result": None, +# +# "model": model_name, +# "iteration": iteration, +# "row_index": int(idx), +# "row_number_in_run": int(loop_i), +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None), +# +# "inference_time_sec": None, +# "process_cpu_time_sec": None, +# "rss_before_mb": None, +# "rss_after_mb": None, +# "rss_delta_mb": None, +# "peak_rss_mb": None, +# +# "prompt_tokens": None, +# "completion_tokens": None, +# "total_tokens": None, +# +# "raw_content": None, +# "raw_response_debug": None, +# } +# +# iteration_results.append(fallback_record) +# model_records.append(fallback_record) +# +# append_jsonl(incremental_jsonl_path, fallback_record) +# append_csv(incremental_csv_path, fallback_record) +# +# if STOP_ON_FIRST_ERROR: +# break +# +# iteration_elapsed = time.perf_counter() - iteration_start +# +# iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" +# with open(iter_json_path, "w", encoding="utf-8") as f: +# json.dump(iteration_results, f, indent=2, ensure_ascii=False) +# +# iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" +# iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) +# iter_flat_df.to_csv(iter_csv_path, index=False) +# +# print(f"\n✅ Iteration {iteration} complete.") +# print(f"Incremental JSONL saved to: {incremental_jsonl_path}") +# print(f"Incremental CSV saved to: {incremental_csv_path}") +# print(f"Final JSON saved to: {iter_json_path}") +# print(f"Final CSV saved to: {iter_csv_path}") +# print( +# f"⏱️ Iteration time: {iteration_elapsed:.1f}s " +# f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" +# ) +# +# model_elapsed = time.perf_counter() - model_start +# +# model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" +# with open(model_json_path, "w", encoding="utf-8") as f: +# json.dump(model_records, f, indent=2, ensure_ascii=False) +# +# model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" +# model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) +# model_flat_df.to_csv(model_csv_path, index=False) +# +# model_summary_df = summarize_records(model_records) +# model_summary_df["model_total_wall_time_sec"] = model_elapsed +# model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 +# +# model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" +# model_summary_df.to_csv(model_summary_path, index=False) +# +# all_model_summaries.append(model_summary_df) +# +# print(f"\n🎉 Model completed: {model_name}") +# print(f"All JSON: {model_json_path}") +# print(f"All CSV: {model_csv_path}") +# print(f"Summary: {model_summary_path}") +# print(f"Total model time: {model_elapsed / 60:.2f} min") +# +# if all_model_summaries: +# combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) +# combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" +# combined_summary_df.to_csv(combined_summary_path, index=False) +# +# print(f"\n📊 Combined summary saved to: {combined_summary_path}") +# +# print(f"\n🎉 All models and all iterations completed!") ## # %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark2 diff --git a/audit.py b/scripts/audit_outputs.py similarity index 100% rename from audit.py rename to scripts/audit_outputs.py diff --git a/certainty_show.py b/scripts/certainty_show.py similarity index 100% rename from certainty_show.py rename to scripts/certainty_show.py diff --git a/figure1.py b/scripts/figure1.py similarity index 100% rename from figure1.py rename to scripts/figure1.py diff --git a/show_plots.py b/scripts/show_plots.py similarity index 100% rename from show_plots.py rename to scripts/show_plots.py diff --git a/show_plots.py.orig b/show_plots.py.orig deleted file mode 100644 index 4332b98..0000000 --- a/show_plots.py.orig +++ /dev/null @@ -1,2320 +0,0 @@ -# %% Scatter -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np - -# Load your data from TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results -df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') -df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') - -# Convert to float (handle invalid entries gracefully) -df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') -df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') - -# Drop rows where either column is NaN -df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) - -# Create scatter plot -plt.figure(figsize=(8, 6)) -plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue') - -# Add labels and title -plt.xlabel('GT_EDSS') -plt.ylabel('LLM_Results') -plt.title('Comparison of GT_EDSS vs LLM_Results') - -# Optional: Add a diagonal line for reference (perfect prediction) -plt.plot([0, max(df_clean['GT_EDSS'])], [0, max(df_clean['GT_EDSS'])], color='red', linestyle='--', label='Perfect Prediction') -plt.legend() - -# Show plot -plt.grid(True) -plt.tight_layout() -plt.show() - -## - - -# %% Bland0-altman - -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -import statsmodels.api as sm - -# Load your data from TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results -df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') -df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') - -# Convert to float (handle invalid entries gracefully) -df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') -df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') - -# Drop rows where either column is NaN -df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) - -# Create Bland-Altman plot -f, ax = plt.subplots(1, figsize=(8, 5)) -sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax) - -# Add labels and title -ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results') -ax.set_xlabel('Mean of GT_EDSS and LLM_Results') -ax.set_ylabel('Difference between GT_EDSS and LLM_Results') - -# Display Bland-Altman plot -plt.tight_layout() -plt.show() - -# Print some statistics -mean_diff = np.mean(df_clean['GT_EDSS'] - df_clean['LLM_Results']) -std_diff = np.std(df_clean['GT_EDSS'] - df_clean['LLM_Results']) -print(f"Mean difference: {mean_diff:.3f}") -print(f"Standard deviation of differences: {std_diff:.3f}") -print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]") - -## - - - -# %% Confusion matrix -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -from sklearn.metrics import confusion_matrix, classification_report -import seaborn as sns - -# Load your data from TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS -df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') -df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') - -# Convert to float (handle invalid entries gracefully) -df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') -df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') - -# Drop rows where either column is NaN -df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) - -# For confusion matrix, we need to categorize the values -# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) -def categorize_edss(value): - if pd.isna(value): - return np.nan - elif value <= 1.0: - return '0-1' - elif value <= 2.0: - return '1-2' - elif value <= 3.0: - return '2-3' - elif value <= 4.0: - return '3-4' - elif value <= 5.0: - return '4-5' - elif value <= 6.0: - return '5-6' - elif value <= 7.0: - return '6-7' - elif value <= 8.0: - return '7-8' - elif value <= 9.0: - return '8-9' - elif value <= 10.0: - return '9-10' - else: - return '10+' - -# Create categorical versions -df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) -df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) - -# Remove any NaN categories -df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) - -# Create confusion matrix -cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], - labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) - -# Plot confusion matrix -plt.figure(figsize=(10, 8)) -sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', - xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], - yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) -#plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)') -plt.xlabel('LLM Generated EDSS') -plt.ylabel('Ground Truth EDSS') -plt.tight_layout() -plt.show() - -# Print classification report -print("Classification Report:") -print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) - -# Print raw counts -print("\nConfusion Matrix (Raw Counts):") -print(cm) - -## - - -# %% Confusion matrix -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -from sklearn.metrics import confusion_matrix, classification_report -import seaborn as sns - -# Load your data from TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS -df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') -df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') - -# Convert to float (handle invalid entries gracefully) -df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') -df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') - -# Drop rows where either column is NaN -df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) - -# For confusion matrix, we need to categorize the values -# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) -def categorize_edss(value): - if pd.isna(value): - return np.nan - elif value <= 1.0: - return '0-1' - elif value <= 2.0: - return '1-2' - elif value <= 3.0: - return '2-3' - elif value <= 4.0: - return '3-4' - elif value <= 5.0: - return '4-5' - elif value <= 6.0: - return '5-6' - elif value <= 7.0: - return '6-7' - elif value <= 8.0: - return '7-8' - elif value <= 9.0: - return '8-9' - elif value <= 10.0: - return '9-10' - else: - return '10+' - -# Create categorical versions -df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) -df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) - -# Remove any NaN categories -df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) - -# Create confusion matrix -cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], - labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) - -# Plot confusion matrix -plt.figure(figsize=(10, 8)) -ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', - xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], - yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) - -# Add legend text above the color bar -# Get the colorbar object -cbar = ax.collections[0].colorbar -# Add text above the colorbar -cbar.set_label('Number of Cases', rotation=270, labelpad=20) - -plt.xlabel('LLM Generated EDSS') -plt.ylabel('Ground Truth EDSS') -#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)') -plt.tight_layout() -plt.show() - -# Print classification report -print("Classification Report:") -print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) - -# Print raw counts -print("\nConfusion Matrix (Raw Counts):") -print(cm) - - -## - - - - -# %% Classification -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.metrics import confusion_matrix -import numpy as np - -# Load your data from TSV file -file_path ='/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' - -df = pd.read_csv(file_path, sep='\t') - -# Check data structure -print("Data shape:", df.shape) -print("First few rows:") -print(df.head()) -print("\nColumn names:") -for col in df.columns: - print(f" {col}") - -# Function to safely convert to boolean -def safe_bool_convert(series): - '''Safely convert series to boolean, handling various input formats''' - # Convert to string first, then to boolean - series_str = series.astype(str).str.strip().str.lower() - - # Handle different true/false representations - bool_map = { - 'true': True, '1': True, 'yes': True, 'y': True, - 'false': False, '0': False, 'no': False, 'n': False - } - - converted = series_str.map(bool_map) - - # Handle remaining NaN values - converted = converted.fillna(False) # or True, depending on your preference - - return converted - -# Convert columns safely -if 'result.klassifizierbar' in df.columns: - print("\nresult.klassifizierbar column info:") - print(df['result.klassifizierbar'].head(10)) - print("Unique values:", df['result.klassifizierbar'].unique()) - - df['result.klassifizierbar'] = safe_bool_convert(df['result.klassifizierbar']) - print("After conversion:") - print(df['result.klassifizierbar'].value_counts()) - -if 'GT.klassifizierbar' in df.columns: - print("\nGT.klassifizierbar column info:") - print(df['GT.klassifizierbar'].head(10)) - print("Unique values:", df['GT.klassifizierbar'].unique()) - - df['GT.klassifizierbar'] = safe_bool_convert(df['GT.klassifizierbar']) - print("After conversion:") - print(df['GT.klassifizierbar'].value_counts()) - -# Create bar chart showing only True values for klassifizierbar -if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: - # Get counts for True values only - llm_true_count = df['result.klassifizierbar'].sum() - gt_true_count = df['GT.klassifizierbar'].sum() - - # Plot using matplotlib directly - fig, ax = plt.subplots(figsize=(8, 6)) - - x = np.arange(2) - width = 0.35 - - bars1 = ax.bar(x[0] - width/2, llm_true_count, width, label='LLM', color='skyblue', alpha=0.8) - bars2 = ax.bar(x[1] + width/2, gt_true_count, width, label='GT', color='lightcoral', alpha=0.8) - - # Add value labels on bars - ax.annotate(f'{llm_true_count}', - xy=(x[0], llm_true_count), - xytext=(0, 3), - textcoords="offset points", - ha='center', va='bottom') - - ax.annotate(f'{gt_true_count}', - xy=(x[1], gt_true_count), - xytext=(0, 3), - textcoords="offset points", - ha='center', va='bottom') - - ax.set_xlabel('Classification Status (klassifizierbar)') - ax.set_ylabel('Count') - ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"') - ax.set_xticks(x) - ax.set_xticklabels(['LLM', 'GT']) - ax.legend() - - plt.tight_layout() - plt.show() - -# Create confusion matrix if both columns exist -if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: - try: - # Ensure both columns are boolean - llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool) - gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool) - - cm = confusion_matrix(gt_bool, llm_bool) - - # Plot confusion matrix - fig, ax = plt.subplots(figsize=(8, 6)) - sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', - xticklabels=['False ', 'True '], - yticklabels=['False', 'True '], - ax=ax) - ax.set_xlabel('LLM Predictions ') - ax.set_ylabel('GT Labels ') - ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"') - - plt.tight_layout() - plt.show() - - print("Confusion Matrix:") - print(cm) - - except Exception as e: - print(f"Error creating confusion matrix: {e}") - -# Show final data info -print("\nFinal DataFrame info:") -print(df[['result.klassifizierbar', 'GT.klassifizierbar']].info()) - -## - - - - -# %% Boxplot -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -import numpy as np - -# Load your data from TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/results/join_results_unique.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS -df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') -df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') - -# Convert to float (handle invalid entries gracefully) -df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') -df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') - -# Drop rows where either column is NaN -df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) - -# 1. DEFINE CATEGORY ORDER -# This ensures the X-axis is numerically logical (0-1 comes before 1-2) -category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+'] - -# Convert the column to a Categorical type with the specific order -df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss), - categories=category_order, - ordered=True) - -plt.figure(figsize=(14, 8)) - -# 2. ADD HUE FOR LEGEND -# Assigning x to 'hue' allows Seaborn to generate a legend automatically -box_plot = sns.boxplot( - data=df_clean, - x='GT.EDSS_cat', - y='result.EDSS', - hue='GT.EDSS_cat', # Added hue - palette='viridis', - linewidth=1.5, - legend=True # Ensure legend is enabled -) - -# 3. CUSTOMIZE PLOT -plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20) -plt.xlabel('Ground Truth EDSS Category', fontsize=14) -plt.ylabel('LLM Predicted EDSS', fontsize=14) - -# Move legend to the side or top -plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left') - -plt.xticks(rotation=45, ha='right', fontsize=10) -plt.grid(True, axis='y', alpha=0.3) -plt.tight_layout() - -plt.show() -## - - -# %% Postproccessing Column names - -import pandas as pd - -# Read the TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Create a mapping dictionary for German to English column names -column_mapping = { - 'EDSS':'GT.EDSS', - 'klassifizierbar': 'GT.klassifizierbar', - 'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS', - 'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS', - 'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS', - 'Sensibiliät': 'GT.SENSORY_FUNCTIONS', - 'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS', - 'Ambulation': 'GT.AMBULATION', - 'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS', - 'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS' -} - -# Rename columns -df = df.rename(columns=column_mapping) - -# Save the modified dataframe back to TSV file -df.to_csv(file_path, sep='\t', index=False) - -print("Columns have been successfully renamed!") -print("Renamed columns:") -for old_name, new_name in column_mapping.items(): - if old_name in df.columns: - print(f" {old_name} -> {new_name}") - - -## - - - - -# %% Styled table -import pandas as pd -import numpy as np -import seaborn as sns -import matplotlib.pyplot as plt -import dataframe_image as dfi -# Load data -df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t') - -# 1. Identify all GT and result columns -gt_columns = [col for col in df.columns if col.startswith('GT.')] -result_columns = [col for col in df.columns if col.startswith('result.')] - -print("GT Columns found:", gt_columns) -print("Result Columns found:", result_columns) - -# 2. Create proper mapping between GT and result columns -# Handle various naming conventions (spaces, underscores, etc.) -column_mapping = {} - -for gt_col in gt_columns: - base_name = gt_col.replace('GT.', '') - - # Clean the base name for matching - remove spaces, underscores, etc. - # Try different matching approaches - candidates = [ - f'result.{base_name}', # Exact match - f'result.{base_name.replace(" ", "_")}', # With underscores - f'result.{base_name.replace("_", " ")}', # With spaces - f'result.{base_name.replace(" ", "")}', # No spaces - f'result.{base_name.replace("_", "")}' # No underscores - ] - - # Also try case-insensitive matching - candidates.append(f'result.{base_name.lower()}') - candidates.append(f'result.{base_name.upper()}') - - # Try to find matching result column - matched = False - for candidate in candidates: - if candidate in result_columns: - column_mapping[gt_col] = candidate - matched = True - break - - # If no exact match found, try partial matching - if not matched: - # Try to match by removing special characters and comparing - base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' ']) - for result_col in result_columns: - result_base = result_col.replace('result.', '') - result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' ']) - if base_clean.lower() == result_clean.lower(): - column_mapping[gt_col] = result_col - matched = True - break - -print("Column mapping:", column_mapping) - -# 3. Faster, vectorized computation using the corrected mapping -data_list = [] - -for gt_col, result_col in column_mapping.items(): - print(f"Processing {gt_col} vs {result_col}") - - # Convert to numeric, forcing errors to NaN - s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float) - s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float) - - # Calculate matches (abs difference <= 0.5) - diff = np.abs(s1 - s2) - matches = (diff <= 0.5).sum() - - # Determine the denominator (total valid comparisons) - valid_count = diff.notna().sum() - - if valid_count > 0: - percentage = (matches / valid_count) * 100 - else: - percentage = 0 - - # Extract clean base name for display - base_name = gt_col.replace('GT.', '') - - data_list.append({ - 'GT': base_name, - 'Match %': round(percentage, 1) - }) - - - - -# 4. Prepare Data -match_df = pd.DataFrame(data_list) -# Clean up labels: Replace underscores with spaces and capitalize -match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title() -match_df = match_df.sort_values('Match %', ascending=False) - -# 5. Create a "Beautiful" Table using Seaborn Heatmap -def create_luxury_table(df, output_file="edss_agreement.png"): - # Set the aesthetic style - sns.set_theme(style="white", font="sans-serif") - - # Prepare data for heatmap - plot_data = df.set_index('GT')[['Match %']] - - # Initialize the figure - # Height is dynamic based on number of rows - fig, ax = plt.subplots(figsize=(8, len(df) * 0.6)) - - # Create a custom diverging color map (Deep Red -> Mustard -> Emerald) - # This looks more professional than standard 'RdYlGn' - cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True) - - # Draw the heatmap - sns.heatmap( - plot_data, - annot=True, - fmt=".1f", - cmap=cmap, - center=85, # Centers the color transition - vmin=50, vmax=100, # Range of the gradient - linewidths=2, - linecolor='white', - cbar=False, # Remove color bar for a "table" look - annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"} - ) - - # Styling the Axes (Turning the heatmap into a table) - ax.set_xlabel("") - ax.set_ylabel("") - ax.xaxis.tick_top() # Move "Match %" label to top - ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50') - ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0) - - # Add a thin border around the plot - for _, spine in ax.spines.items(): - spine.set_visible(True) - spine.set_color('#ecf0f1') - - plt.title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50') - - # Add a subtle footer - plt.figtext(0.5, 0.0, "Tolerance: ±0.5 points", - wrap=True, horizontalalignment='center', fontsize=10, color='gray', style='italic') - - # Save with high resolution - plt.tight_layout() - plt.savefig(output_file, dpi=300, bbox_inches='tight') - print(f"Beautiful table saved as {output_file}") - -# Execute -create_luxury_table(match_df) - - -# Run the function -save_styled_table(match_df) -# 6. Save as SVG - -plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight') -print("Successfully saved agreement_table.svg") - -# Show plot if running in a GUI environment -plt.show() -## - - - -# %% Time Plot -import numpy as np -import matplotlib.pyplot as plt -import pandas as pd -from scipy import stats - -# Load the TSV file -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Extract the inference_time_sec column -inference_times = df['inference_time_sec'].dropna() # Remove NaN values - -# Calculate statistics -mean_time = inference_times.mean() -std_time = inference_times.std() -median_time = np.median(inference_times) - -# Create the histogram -fig, ax = plt.subplots(figsize=(10, 6)) - -# Create histogram with bins of 1 second width -min_time = int(inference_times.min()) -max_time = int(inference_times.max()) + 1 -bins = np.arange(min_time, max_time + 1, 1) # Bins of 1 second width - -# Create histogram with counts (not probability density) -n, bins, patches = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7, edgecolor='black', linewidth=0.5) - -# Generate Gaussian curve for fit -x = np.linspace(inference_times.min(), inference_times.max(), 100) -# Scale Gaussian to match histogram counts -gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * (bins[1] - bins[0]) - -# Plot Gaussian fit -ax.plot(x, gaussian_counts, color='red', linewidth=2, label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)') - -# Add vertical lines for mean and median -ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s') -ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s') - -# Add standard deviation as vertical lines -ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'+1σ = {mean_time + std_time:.1f}s') -ax.axvline(mean_time - std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'-1σ = {mean_time - std_time:.1f}s') - -ax.set_xlabel('Inference Time (seconds)') -ax.set_ylabel('Frequency') -ax.set_title('Inference Time Distribution with Gaussian Fit') -ax.legend() -ax.grid(True, alpha=0.3) - -plt.tight_layout() -plt.show() - -## - -# %% Dashboard -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.dates as mdates -import numpy as np -from matplotlib.gridspec import GridSpec - -def to_numeric_comma(s: pd.Series) -> pd.Series: - # accepts 1.5 and 1,5 - return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce") - -# Load the data -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Rename columns to remove 'result.' prefix and replace spaces -column_mapping = {} -for col in df.columns: - if col.startswith('result.'): - new_name = col.replace('result.', '').replace(' ', '_') - column_mapping[col] = new_name -df = df.rename(columns=column_mapping) - -# Parse MedDatum safely -df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce') - -# Patient -patient_id = 'd13e4aa3' -patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy() -if patient_data.empty: - raise ValueError(f"No data found for patient: {patient_id}") - -# Functional systems + EDSS -edss_col, edss_title = ('GT.EDSS', 'EDSS') - -functional_systems = [ - ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'), - ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'), - ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'), - ('GT.SENSORY_FUNCTIONS', 'Sensory'), - ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'), - ('GT.AMBULATION', 'Ambulation'), - ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'), - ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'), -] - -# y-axis max rules -ymax_by_col = { - 'GT.PYRAMIDAL_FUNCTIONS': 6, - 'GT.SENSORY_FUNCTIONS': 6, - 'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6, - 'GT.VISUAL_OPTIC_FUNCTIONS': 6, - 'GT.CEREBELLAR_FUNCTIONS': 5, - 'GT.CEREBRAL_FUNCTIONS': 5, - 'GT.BRAINSTEM_FUNCTIONS': 5, - 'GT.EDSS': 10, -} -default_ymax = 6 - -# ---------- Build shared "event dates" ticks ---------- -cols_for_dates = [edss_col] + [c for c, _ in functional_systems] -event_dates = [] - -for c in cols_for_dates: - if c in patient_data.columns: - y = to_numeric_comma(patient_data[c]) # <-- changed - x = patient_data['MedDatum'] - tmp = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]) - event_dates.extend(tmp["x"].tolist()) - -event_dates = sorted(pd.Series(event_dates).drop_duplicates().tolist()) - -max_ticks = 8 -if len(event_dates) > max_ticks: - idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) - event_dates = [event_dates[i] for i in idx] - -# ---------- A4 figure ---------- -fig = plt.figure(figsize=(11.69, 8.27)) -gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) - -def style_time_axis(ax, show_labels=True): - ax.set_xticks(event_dates) - ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) - ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) - if not show_labels: - ax.tick_params(labelbottom=False) - -# ---------- EDSS main plot ---------- -ax_main = fig.add_subplot(gs[0, :]) - -if edss_col in patient_data.columns: - y = to_numeric_comma(patient_data[edss_col]) # <-- changed - x = patient_data['MedDatum'] - plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") - - ax_main.set_title(edss_title, fontsize=14, fontweight='bold') - ax_main.set_ylabel("Score") - ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) - ax_main.grid(True, alpha=0.3) - - if not plot_df.empty: - ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') - else: - ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') -else: - ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') - ax_main.set_ylim(0, ymax_by_col.get(edss_col, 10)) - ax_main.grid(True, alpha=0.3) - -style_time_axis(ax_main) - -# ---------- Small aligned plots ---------- -small_axes = [] -for k, (col, title) in enumerate(functional_systems): - r = 1 + (k // 4) - c = (k % 4) - ax = fig.add_subplot(gs[r, c], sharex=ax_main) - small_axes.append(ax) - - ymax = ymax_by_col.get(col, default_ymax) - ax.set_title(title, fontsize=10) - ax.set_ylabel("Score") - ax.set_ylim(0, ymax) - ax.grid(True, alpha=0.3) - - if col in patient_data.columns: - y = to_numeric_comma(patient_data[col]) # <-- changed - x = patient_data['MedDatum'] - plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") - - if not plot_df.empty: - ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue') - else: - ax.set_title(f"{title} (no data)", fontsize=10) - else: - ax.set_title(f"{title} (missing)", fontsize=10) - - style_time_axis(ax) - -# Hide x tick labels on first row of small plots -for ax in small_axes[:4]: - ax.tick_params(labelbottom=False) - -plt.tight_layout() -fig.subplots_adjust(hspace=0.7) -plt.show() -## - -<<<<<<< Updated upstream -======= - - - -# %% Dashboard Angepasst -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.dates as mdates -import numpy as np -from matplotlib.gridspec import GridSpec - -def to_numeric_comma(s: pd.Series) -> pd.Series: - # accepts 1.5 and 1,5 - return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce") - -# Load the data -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Rename columns to remove 'result.' prefix and replace spaces -column_mapping = {} -for col in df.columns: - if col.startswith('result.'): - new_name = col.replace('result.', '').replace(' ', '_') - column_mapping[col] = new_name -df = df.rename(columns=column_mapping) - -# Parse MedDatum safely -df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce') - -# Patient -patient_id = '3d942c60' - -patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy() -if patient_data.empty: - raise ValueError(f"No data found for patient: {patient_id}") - -# Functional systems + EDSS -edss_col, edss_title = ('GT.EDSS', 'EDSS') - -functional_systems = [ - ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'), - ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'), - ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'), - ('GT.SENSORY_FUNCTIONS', 'Sensory'), - ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'), - ('GT.AMBULATION', 'Ambulation'), - ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'), - ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'), -] - -# y-axis max rules -ymax_by_col = { - 'GT.PYRAMIDAL_FUNCTIONS': 6, - 'GT.SENSORY_FUNCTIONS': 6, - 'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6, - 'GT.VISUAL_OPTIC_FUNCTIONS': 6, - 'GT.CEREBELLAR_FUNCTIONS': 5, - 'GT.CEREBRAL_FUNCTIONS': 5, - 'GT.BRAINSTEM_FUNCTIONS': 5, - 'GT.EDSS': 10, -} -default_ymax = 6 - -# ---------- Build shared visit dates ticks ---------- -# Use ALL patient visit dates, not only dates with valid numeric values -event_dates = sorted(patient_data['MedDatum'].dropna().drop_duplicates().tolist()) - -max_ticks = 8 -if len(event_dates) > max_ticks: - idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) - event_dates = [event_dates[i] for i in idx] - -# Base timeline for plotting: one row per patient visit date -timeline = ( - patient_data[['MedDatum']] - .dropna() - .drop_duplicates() - .sort_values('MedDatum') - .rename(columns={'MedDatum': 'x'}) -) - -# ---------- A4 figure ---------- -fig = plt.figure(figsize=(11.69, 8.27)) -gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) - -def style_time_axis(ax, show_labels=True): - ax.set_xticks(event_dates) - ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) - ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) - if not show_labels: - ax.tick_params(labelbottom=False) - -def get_plot_df(patient_data, col): - """ - Keep all visit dates. - Missing values stay NaN so matplotlib draws gaps instead of zeros. - """ - tmp = patient_data[['MedDatum', col]].copy() - tmp = tmp.rename(columns={'MedDatum': 'x', col: 'raw_y'}) - tmp['y'] = to_numeric_comma(tmp['raw_y']) - - # aggregate if multiple rows exist on same date - tmp = tmp.groupby('x', as_index=False)['y'].max() - - # merge onto full timeline so all dates remain visible - plot_df = timeline.merge(tmp, on='x', how='left').sort_values('x') - return plot_df - -# ---------- EDSS main plot ---------- -ax_main = fig.add_subplot(gs[0, :]) -ax_main.set_title(edss_title, fontsize=14, fontweight='bold') -ax_main.set_ylabel("Score") -ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) -ax_main.grid(True, alpha=0.3) - -if edss_col in patient_data.columns: - plot_df = get_plot_df(patient_data, edss_col) - - if plot_df['y'].notna().any(): - # NaNs create visible gaps in the line - ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') - else: - ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') -else: - ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') - -style_time_axis(ax_main) - -# ---------- Small aligned plots ---------- -small_axes = [] -for k, (col, title) in enumerate(functional_systems): - r = 1 + (k // 4) - c = (k % 4) - ax = fig.add_subplot(gs[r, c], sharex=ax_main) - small_axes.append(ax) - - ymax = ymax_by_col.get(col, default_ymax) - ax.set_title(title, fontsize=10) - ax.set_ylabel("Score") - ax.set_ylim(0, ymax) - ax.grid(True, alpha=0.3) - - if col in patient_data.columns: - plot_df = get_plot_df(patient_data, col) - - if plot_df['y'].notna().any(): - # NaNs remain in y -> line breaks where data is missing - ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue') - else: - ax.set_title(f"{title} (no numeric data)", fontsize=10) - else: - ax.set_title(f"{title} (missing)", fontsize=10) - - style_time_axis(ax) - -# Hide x tick labels on first row of small plots -for ax in small_axes[:4]: - ax.tick_params(labelbottom=False) - -plt.tight_layout() -fig.subplots_adjust(hspace=0.7) -plt.show() - - - -## - ->>>>>>> Stashed changes -# %% Table -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -from datetime import datetime -import numpy as np - -# Load the data -file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' -df = pd.read_csv(file_path, sep='\t') - -# Convert MedDatum to datetime -df['MedDatum'] = pd.to_datetime(df['MedDatum']) - -# Check what columns actually exist in the dataset -print("Available columns:") -print(df.columns.tolist()) -print("\nFirst few rows:") -print(df.head()) - -# Check data types -print("\nData types:") -print(df.dtypes) - -# Hardcode specific patient names -patient_names = ['6ccda8c6'] - -# Define the functional systems (columns to plot) -functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual'] - -# Create subplots -num_plots = len(functional_systems) -num_cols = 2 -num_rows = (num_plots + num_cols - 1) // num_cols - -fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False) -if num_plots == 1: - axes = [axes] -elif num_rows == 1: - axes = axes -else: - axes = axes.flatten() - -# Plot for the hardcoded patient -for i, system in enumerate(functional_systems): - # Filter data for this specific patient - patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum') - - # Check if patient data exists - if patient_data.empty: - print(f"No data found for patient: {patient_names[0]}") - axes[i].set_title(f'Functional System: {system} (No data)') - axes[i].set_ylabel('Score') - continue - - # Check if the system column exists - if system in patient_data.columns: - # Plot only valid data (non-null values) - valid_data = patient_data.dropna(subset=[system]) - - if not valid_data.empty: - # Ensure MedDatum is properly formatted for plotting - axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system) - axes[i].set_ylabel('Score') - axes[i].set_title(f'Functional System: {system}') - axes[i].grid(True, alpha=0.3) - axes[i].legend() - else: - axes[i].set_title(f'Functional System: {system} (No valid data)') - axes[i].set_ylabel('Score') - else: - # Try to find similar column names - found_column = None - for col in df.columns: - if system.lower() in col.lower(): - found_column = col - break - - if found_column: - valid_data = patient_data.dropna(subset=[found_column]) - if not valid_data.empty: - axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column) - axes[i].set_ylabel('Score') - axes[i].set_title(f'Functional System: {system} (found as: {found_column})') - axes[i].grid(True, alpha=0.3) - axes[i].legend() - else: - axes[i].set_title(f'Functional System: {system} (No valid data)') - axes[i].set_ylabel('Score') - else: - axes[i].set_title(f'Functional System: {system} (Column not found)') - axes[i].set_ylabel('Score') - -# Hide empty subplots -for i in range(len(functional_systems), len(axes)): - axes[i].set_visible(False) - -# Set x-axis label for the last row only -for i in range(len(functional_systems)): - if i >= len(axes) - num_cols: # Last row - axes[i].set_xlabel('Date') - -# Format x-axis dates -for ax in axes: - if ax.get_lines(): # Only format if there are lines to plot - ax.tick_params(axis='x', rotation=45) - ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d')) - -# Automatically adjust layout -plt.tight_layout() -plt.show() - - - - -## - - - - -# %% Histogram Fig1 -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.font_manager as fm -import json -import os - -def create_visit_frequency_plot( - file_path, - output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data', - output_filename='visit_frequency_distribution.svg', - fontsize=10, - color_scheme_path='colors.json' -): - """ - Creates a publication-ready bar chart of patient visit frequency. - - Args: - file_path (str): Path to the input TSV file. - output_dir (str): Directory to save the output SVG file. - output_filename (str): Name of the output SVG file. - fontsize (int): Font size for all text elements (labels, title). - color_scheme_path (str): Path to the JSON file containing the color palette. - """ - # --- 1. Load Data and Color Scheme --- - try: - df = pd.read_csv(file_path, sep='\t') - print("Data loaded successfully.") - # Sort data for easier visual comparison - df = df.sort_values(by='Visits Count') - except FileNotFoundError: - print(f"Error: The file was not found at {file_path}") - return - - try: - with open(color_scheme_path, 'r') as f: - colors = json.load(f) - # Select a blue from the sequential palette for the bars - bar_color = colors['sequential']['blues'][-2] # A saturated blue - except FileNotFoundError: - print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.") - bar_color = '#2171b5' # A common matplotlib blue - - # --- 2. Set up the Plot with Scientific Style --- - plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height - - # Set the font to Arial - arial_font = fm.FontProperties(family='Arial', size=fontsize) - plt.rcParams['font.family'] = 'Arial' - plt.rcParams['font.size'] = fontsize - - # --- 3. Create the Bar Chart --- - ax = plt.gca() - bars = plt.bar( - x=df['Visits Count'], - height=df['Unique Patients'], - color=bar_color, - edgecolor='black', - linewidth=0.5, # Minimum line thickness - width=0.7 - ) - - # --- NEW: Explicitly set x-ticks and labels to ensure all are shown --- - # Get the unique visit counts to use as tick labels - visit_counts = df['Visits Count'].unique() - # Set the x-ticks to be at the center of each bar - ax.set_xticks(visit_counts) - # Set the x-tick labels to be the visit counts, using the specified font - ax.set_xticklabels(visit_counts, fontproperties=arial_font) - # --- END OF NEW SECTION --- - - # --- 4. Customize Axes and Layout (Nature style) --- - # Display only left and bottom axes - ax.spines['top'].set_visible(False) - ax.spines['right'].set_visible(False) - - # Turn off axis ticks (the marks, not the labels) - plt.tick_params(axis='both', which='both', length=0) - - # Remove grid lines - plt.grid(False) - - # Set background to white (no shading) - ax.set_facecolor('white') - plt.gcf().set_facecolor('white') - - # --- 5. Add Labels and Title --- - plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10) - plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10) - plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20) - - # --- 6. Add y-axis values on top of each bar --- - # This adds the count of unique patients directly above each bar. - ax.bar_label(bars, fmt='%d', padding=3) - - # --- 7. Export the Figure --- - # Ensure the output directory exists - os.makedirs(output_dir, exist_ok=True) - - full_output_path = os.path.join(output_dir, output_filename) - plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight') - print(f"\nFigure saved as '{full_output_path}'") - - # --- 8. (Optional) Display the Plot --- - # plt.show() - -# --- Main execution --- -if __name__ == '__main__': - # Define the file path - input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv' - - # Call the function to create and save the plot - create_visit_frequency_plot( - file_path=input_file, - fontsize=10 # Using a 10 pt font size as per guidelines - ) - -## - - - -# %% Scatter Plot functional system - -import pandas as pd -import matplotlib.pyplot as plt -import json -import os - -# --- Configuration --- -# Set the font to Arial for all text in the plot, as per the guidelines -plt.rcParams['font.family'] = 'Arial' - -# Define the path to your data file -data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' - -# Define the path to save the color mapping JSON file -color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' - -# Define the path to save the final figure -figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg' - -# --- 1. Load the Dataset --- -try: - # Load the TSV file - df = pd.read_csv(data_path, sep='\t') - print(f"Successfully loaded data from {data_path}") - print(f"Data shape: {df.shape}") -except FileNotFoundError: - print(f"Error: The file at {data_path} was not found.") - # Exit or handle the error appropriately - raise - -# --- 2. Define Functional Systems and Create Color Mapping --- -# List of tuples containing (ground_truth_column, result_column) -functional_systems_to_plot = [ - ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), - ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), - ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), - ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), - ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), - ('GT.AMBULATION', 'result.AMBULATION'), - ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), - ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') -] - -# Extract system names for color mapping and legend -system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] - -# Define a professional color palette (dark blue theme) -# This is a qualitative palette with distinct, accessible colors -colors = [ - '#003366', # Dark Blue - '#336699', # Medium Blue - '#6699CC', # Light Blue - '#99CCFF', # Very Light Blue - '#FF9966', # Coral - '#FF6666', # Light Red - '#CC6699', # Magenta - '#9966CC' # Purple -] - -# Create a dictionary mapping system names to colors -color_map = dict(zip(system_names, colors)) - -# Ensure the directory for the JSON file exists -os.makedirs(os.path.dirname(color_json_path), exist_ok=True) - -# Save the color map to a JSON file -with open(color_json_path, 'w') as f: - json.dump(color_map, f, indent=4) - -print(f"Color mapping saved to {color_json_path}") - -# --- 3. Calculate Agreement Percentages and Format Legend Labels --- -agreement_percentages = {} -legend_labels = {} - -for gt_col, res_col in functional_systems_to_plot: - system_name = gt_col.split('.')[1] - - # Convert columns to numeric, setting errors to NaN - gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') - res_numeric = pd.to_numeric(df[res_col], errors='coerce') - - # Ensure we are comparing the same rows - common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index) - gt_data = gt_numeric.loc[common_index] - res_data = res_numeric.loc[common_index] - - # Calculate agreement percentage - if len(gt_data) > 0: - agreement = (gt_data == res_data).mean() * 100 - else: - agreement = 0 # Handle case with no valid data - - agreement_percentages[system_name] = agreement - - # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions") - formatted_name = " ".join(word.capitalize() for word in system_name.split('_')) - legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)" - -# --- 4. Reshape Data for Plotting --- -plot_data = [] -for gt_col, res_col in functional_systems_to_plot: - system_name = gt_col.split('.')[1] - - # Convert columns to numeric, setting errors to NaN - gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') - res_numeric = pd.to_numeric(df[res_col], errors='coerce') - - # Create a temporary DataFrame with the numeric data - temp_df = pd.DataFrame({ - 'system': system_name, - 'ground_truth': gt_numeric, - 'inference': res_numeric - }) - - # Drop rows where either value is NaN, as they cannot be plotted - temp_df = temp_df.dropna() - - plot_data.append(temp_df) - -# Concatenate all the temporary DataFrames into one -plot_df = pd.concat(plot_data, ignore_index=True) - -if plot_df.empty: - print("Warning: No valid numeric data to plot after conversion. The plot will be blank.") -else: - print(f"Prepared plot data with {len(plot_df)} data points.") - -# --- 5. Create the Scatter Plot --- -plt.figure(figsize=(10, 8)) - -# Plot each functional system with its assigned color and formatted legend label -for system, group in plot_df.groupby('system'): - plt.scatter( - group['ground_truth'], - group['inference'], - label=legend_labels[system], - color=color_map[system], - alpha=0.7, - s=30 - ) - -# Add a diagonal line representing perfect agreement (y = x) -# This line helps visualize how close the predictions are to the ground truth -if not plot_df.empty: - plt.plot( - [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], - [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], - color='black', - linestyle='--', - linewidth=0.8, - alpha=0.7 - ) - -# --- 6. Apply Styling and Labels --- -plt.xlabel('Ground Truth', fontsize=12) -plt.ylabel('LLM Inference', fontsize=12) -plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14) - -# Apply scientific visualization styling rules -ax = plt.gca() -ax.spines['top'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.tick_params(axis='both', which='both', length=0) # Remove ticks -ax.grid(False) # Remove grid lines -plt.legend(title='Functional System', frameon=False, fontsize=10) - -# --- 7. Save and Display the Figure --- -# Ensure the directory for the figure exists -os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) - -plt.savefig(figure_save_path, format='svg', bbox_inches='tight') -print(f"Figure successfully saved to {figure_save_path}") - -# Display the plot -plt.show() -## - - - - -# %% Confusion Matrix functional systems - -import pandas as pd -import matplotlib.pyplot as plt -import json -import os -import numpy as np -import matplotlib.colors as mcolors - -# --- Configuration --- -plt.rcParams['font.family'] = 'Arial' -data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' -figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg' - -# --- 1. Load the Dataset --- -df = pd.read_csv(data_path, sep='\t') - -# --- 2. Define Functional Systems and Colors --- -functional_systems_to_plot = [ - ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), - ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), - ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), - ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), - ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), - ('GT.AMBULATION', 'result.AMBULATION'), - ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), - ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') -] - -system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] -colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC'] -color_map = dict(zip(system_names, colors)) - -# --- 3. Categorization Function --- -categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'] -category_to_index = {cat: i for i, cat in enumerate(categories)} -n_categories = len(categories) - -def categorize_edss(value): - if pd.isna(value): return np.nan - # Ensure value is float to avoid TypeError - val = float(value) - idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0 - return categories[min(idx, len(categories)-1)] - -# --- 4. Prepare Mixed Color Matrix with Saturation --- -cell_system_counts = np.zeros((n_categories, n_categories, len(system_names))) - -for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot): - # Fix: Ensure numeric conversion to avoid string comparison errors - temp_df = df[[gt_col, res_col]].copy() - temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce') - temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce') - valid_df = temp_df.dropna() - - for _, row in valid_df.iterrows(): - gt_cat = categorize_edss(row[gt_col]) - res_cat = categorize_edss(row[res_col]) - if gt_cat in category_to_index and res_cat in category_to_index: - cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1 - -# Create an RGBA image matrix (10x10x4) -rgba_matrix = np.zeros((n_categories, n_categories, 4)) - -total_counts = np.sum(cell_system_counts, axis=2) -max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1 - -for i in range(n_categories): - for j in range(n_categories): - count_sum = total_counts[i, j] - if count_sum > 0: - mixed_rgb = np.zeros(3) - for s_idx, s_name in enumerate(system_names): - weight = cell_system_counts[i, j, s_idx] / count_sum - system_rgb = mcolors.to_rgb(color_map[s_name]) - mixed_rgb += np.array(system_rgb) * weight - - # Set RGB channels - rgba_matrix[i, j, :3] = mixed_rgb - - # Set Alpha channel (Saturation Effect) - # Using a square root scale to make lower counts more visible but still "lighter" - alpha = np.sqrt(count_sum / max_count) - # Ensure alpha is at least 0.1 so it's not invisible - rgba_matrix[i, j, 3] = max(0.1, alpha) - else: - # Empty cells are white - rgba_matrix[i, j] = [1, 1, 1, 0] - -# --- 5. Plotting --- -fig, ax = plt.subplots(figsize=(12, 10)) - -# Show the matrix -# Note: we use origin='lower' if you want 0-1 at the bottom, -# but confusion matrices usually have 0-1 at the top (origin='upper') -im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper') - -# Add count labels -for i in range(n_categories): - for j in range(n_categories): - if total_counts[i, j] > 0: - # Background brightness for text contrast - bg_color = rgba_matrix[i, j, :3] - lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2] - # If alpha is low, background is effectively white, so use black text - text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black" - ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", - color=text_col, fontsize=10, fontweight='bold') - -# --- 6. Styling --- -ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10) -ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10) -ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20) - -ax.set_xticks(np.arange(n_categories)) -ax.set_xticklabels(categories) -ax.set_yticks(np.arange(n_categories)) -ax.set_yticklabels(categories) - -# Remove the frame/spines for a cleaner look -for spine in ax.spines.values(): - spine.set_visible(False) - -# Custom Legend -handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names] -labels = [name.replace('_', ' ').capitalize() for name in system_names] -ax.legend(handles, labels, title='Functional Systems', loc='upper left', - bbox_to_anchor=(1.05, 1), frameon=False) - -plt.tight_layout() -os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) -plt.savefig(figure_save_path, format='svg', bbox_inches='tight') -plt.show() - -## - - - - - - - - - - -# %% Difference Plot Functional system - -import pandas as pd -import matplotlib.pyplot as plt -import json -import os -import numpy as np - -# --- Configuration --- -# Set the font to Arial for all text in the plot, as per the guidelines -plt.rcParams['font.family'] = 'Arial' - -# Define the path to your data file -data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' - -# Define the path to save the color mapping JSON file -color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' - -# Define the path to save the final figure -figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg' - -# --- 1. Load the Dataset --- -try: - # Load the TSV file - df = pd.read_csv(data_path, sep='\t') - print(f"Successfully loaded data from {data_path}") - print(f"Data shape: {df.shape}") -except FileNotFoundError: - print(f"Error: The file at {data_path} was not found.") - # Exit or handle the error appropriately - raise - -# --- 2. Define Functional Systems and Create Color Mapping --- -# List of tuples containing (ground_truth_column, result_column) -functional_systems_to_plot = [ - ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), - ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), - ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), - ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), - ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), - ('GT.AMBULATION', 'result.AMBULATION'), - ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), - ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') -] - -# Extract system names for color mapping and legend -system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] - -# Define a professional color palette (dark blue theme) -# This is a qualitative palette with distinct, accessible colors -colors = [ - '#003366', # Dark Blue - '#336699', # Medium Blue - '#6699CC', # Light Blue - '#99CCFF', # Very Light Blue - '#FF9966', # Coral - '#FF6666', # Light Red - '#CC6699', # Magenta - '#9966CC' # Purple -] - -# Create a dictionary mapping system names to colors -color_map = dict(zip(system_names, colors)) - -# Ensure the directory for the JSON file exists -os.makedirs(os.path.dirname(color_json_path), exist_ok=True) - -# Save the color map to a JSON file -with open(color_json_path, 'w') as f: - json.dump(color_map, f, indent=4) - -print(f"Color mapping saved to {color_json_path}") - -# --- 3. Calculate Agreement Percentages and Format Legend Labels --- -agreement_percentages = {} -legend_labels = {} - -for gt_col, res_col in functional_systems_to_plot: - system_name = gt_col.split('.')[1] - - # Convert columns to numeric, setting errors to NaN - gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') - res_numeric = pd.to_numeric(df[res_col], errors='coerce') - - # Ensure we are comparing the same rows - common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index) - gt_data = gt_numeric.loc[common_index] - res_data = res_numeric.loc[common_index] - - # Calculate agreement percentage - if len(gt_data) > 0: - agreement = (gt_data == res_data).mean() * 100 - else: - agreement = 0 # Handle case with no valid data - - agreement_percentages[system_name] = agreement - - # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions") - formatted_name = " ".join(word.capitalize() for word in system_name.split('_')) - legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)" - - - # --- 4. Robustly Prepare Error Data for Boxplot --- - -def safe_parse(s): - '''Convert to float, handling comma decimals (e.g., '3,5' → 3.5)''' - if pd.isna(s): - return np.nan - if isinstance(s, (int, float)): - return float(s) - # Replace comma with dot, then strip whitespace - s_clean = str(s).replace(',', '.').strip() - try: - return float(s_clean) - except ValueError: - return np.nan - -plot_data = [] -for gt_col, res_col in functional_systems_to_plot: - system_name = gt_col.split('.')[1] - - # Parse both columns with robust comma handling - gt_numeric = df[gt_col].apply(safe_parse) - res_numeric = df[res_col].apply(safe_parse) - - # Compute error (only where both are finite) - error = res_numeric - gt_numeric - - # Create temp DataFrame - temp_df = pd.DataFrame({ - 'system': system_name, - 'error': error - }).dropna() # drop rows where either was unparseable - - plot_data.append(temp_df) - -plot_df = pd.concat(plot_data, ignore_index=True) - -if plot_df.empty: - print("⚠️ Warning: No valid numeric error data to plot after robust parsing.") -else: - print(f"✅ Prepared error data with {len(plot_df)} data points.") - # Diagnostic: show a few samples - print("\n📌 Sample errors by system:") - for sys, grp in plot_df.groupby('system'): - print(f" {sys:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}") - -# Ensure categorical ordering -plot_df['system'] = pd.Categorical( - plot_df['system'], - categories=[name.split('.')[1] for name, _ in functional_systems_to_plot], - ordered=True -) - - -# --- 5. Prepare Data for Diverging Stacked Bar Plot --- -print("\n📊 Preparing diverging stacked bar plot data...") - -# Define bins for error direction -def categorize_error(err): - if pd.isna(err): - return 'missing' - elif err < 0: - return 'underestimate' - elif err > 0: - return 'overestimate' - else: - return 'match' - -# Add category column (only on finite errors) -plot_df_clean = plot_df[plot_df['error'].notna()].copy() -plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error) - -# Count by system + category -category_counts = ( - plot_df_clean - .groupby(['system', 'category']) - .size() - .unstack(fill_value=0) - .reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0) -) -# Reorder systems -category_counts = category_counts.reindex(system_names) - -# Prepare for diverging plot: -# - Underestimates: plotted to the *left* (negative x) -# - Overestimates: plotted to the *right* (positive x) -# - Matches: centered (no width needed, or as a bar of width 0.2) -underestimate_counts = category_counts['underestimate'] -match_counts = category_counts['match'] -overestimate_counts = category_counts['overestimate'] - -# For diverging: left = -underestimate, right = overestimate -left_counts = underestimate_counts -right_counts = overestimate_counts - -# Compute max absolute bar height (for symmetric x-axis) -max_bar = max(left_counts.max(), right_counts.max(), 1) -plot_range = (-max_bar, max_bar) - -# X-axis positions: 0 = center, left systems to -1, -2, ..., right systems to +1, +2, ... -n_systems = len(system_names) -positions = np.arange(n_systems) -left_positions = -positions - 0.5 # left-aligned underestimates -right_positions = positions + 0.5 # right-aligned overestimates - -# --- 6. Create Diverging Stacked Bar Plot --- -plt.figure(figsize=(12, 7)) - -# Colors: diverging palette -colors = { - 'underestimate': '#E74C3C', # Red (left) - 'match': '#2ECC71', # Green (center) - 'overestimate': '#F39C12' # Orange (right) -} - -# Plot underestimates (left side) -bars_left = plt.barh( - left_positions, - left_counts.values, - height=0.8, - left=0, # starts at 0, extends left (since bars are negative width would be wrong; instead use negative values) - color=colors['underestimate'], - edgecolor='black', - linewidth=0.5, - alpha=0.9, - label='Underestimate' -) - -# Plot overestimates (right side) -bars_right = plt.barh( - right_positions, - right_counts.values, - height=0.8, - left=0, - color=colors['overestimate'], - edgecolor='black', - linewidth=0.5, - alpha=0.9, - label='Overestimate' -) - -# Plot matches (center — narrow bar) -# Use a very narrow width (0.2) centered at 0 -plt.barh( - positions, - match_counts.values, - height=0.2, - left=0, # starts at 0, extends right - color=colors['match'], - edgecolor='black', - linewidth=0.5, - alpha=0.9, - label='Exact Match' -) - -# ✨ Better: flip match to be centered symmetrically (left=-match/2, width=match) -# For perfect symmetry: -for i, count in enumerate(match_counts.values): - if count > 0: - plt.barh( - positions[i], - width=count, - left=-count/2, - height=0.25, - color=colors['match'], - edgecolor='black', - linewidth=0.5, - alpha=0.95 - ) - -# --- 7. Styling & Labels --- -# Zero reference line -plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8) - -# X-axis: symmetric around 0 -plt.xlim(plot_range[0] - max_bar*0.1, plot_range[1] + max_bar*0.1) -plt.xticks(rotation=0, fontsize=10) -plt.xlabel('Count', fontsize=12) - -# Y-axis: system names at original positions (centered) -plt.yticks(positions, [name.replace('_', '\n').replace('and', '&') for name in system_names], fontsize=10) -plt.ylabel('Functional System', fontsize=12) - -# Title & layout -plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15) - -# Clean axes -ax = plt.gca() -ax.spines['top'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['left'].set_visible(False) # We only need bottom axis -ax.xaxis.set_ticks_position('bottom') -ax.yaxis.set_ticks_position('none') - -# Grid only along x -ax.xaxis.grid(True, linestyle=':', alpha=0.5) - -# Legend -from matplotlib.patches import Patch -legend_elements = [ - Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'), - Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'), - Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate') -] -plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10) - -# Optional: Add counts on bars -for i, (left, right, match) in enumerate(zip(left_counts, right_counts, match_counts)): - if left > 0: - plt.text(-left - max_bar*0.05, left_positions[i], str(left), va='center', ha='right', fontsize=9, color='white', fontweight='bold') - if right > 0: - plt.text(right + max_bar*0.05, right_positions[i], str(right), va='center', ha='left', fontsize=9, color='white', fontweight='bold') - if match > 0: - plt.text(match_counts[i]/2, positions[i], str(match), va='center', ha='center', fontsize=8, color='black') - -plt.tight_layout() - -# --- 8. Save & Show --- -os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) -plt.savefig(figure_save_path, format='svg', bbox_inches='tight') -print(f"✅ Diverging bar plot saved to {figure_save_path}") - -plt.show() - -## - - - -# %% Difference Plot Gemini -import pandas as pd -import matplotlib.pyplot as plt -import os -import numpy as np - -# --- Configuration & Theme --- -plt.rcParams['font.family'] = 'Arial' -figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg' - -# --- 1. Process Error Data with Magnitude Breakdown --- -system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] -plot_list = [] - -for gt_col, res_col in functional_systems_to_plot: - sys_name = gt_col.split('.')[1] - - # Robust parsing - gt = df[gt_col].apply(safe_parse) - res = df[res_col].apply(safe_parse) - error = res - gt - - # Granular Counts - matches = (error == 0).sum() - u_1 = (error == -1).sum() - u_2plus = (error <= -2).sum() - o_1 = (error == 1).sum() - o_2plus = (error >= 2).sum() - - total = error.dropna().count() - divisor = max(total, 1) - - plot_list.append({ - 'System': sys_name.replace('_', ' ').title(), - 'Matches': matches, 'MatchPct': (matches / divisor) * 100, - 'U1': u_1, 'U2': u_2plus, 'UnderTotal': u_1 + u_2plus, - 'UnderPct': ((u_1 + u_2plus) / divisor) * 100, - 'O1': o_1, 'O2': o_2plus, 'OverTotal': o_1 + o_2plus, - 'OverPct': ((o_1 + o_2plus) / divisor) * 100 - }) - -stats_df = pd.DataFrame(plot_list) - -# --- 2. Plotting --- -fig, ax = plt.subplots(figsize=(13, 8)) - -# Define Magnitude Colors -c_under_dark, c_under_light = '#C0392B', '#E74C3C' # Dark Red (-2+), Soft Red (-1) -c_over_dark, c_over_light = '#2980B9', '#3498DB' # Dark Blue (+2+), Soft Blue (+1) -bar_height = 0.6 -y_pos = np.arange(len(stats_df)) - -# Plot Under-scored (Stacked: -2+ then -1) -ax.barh(y_pos, -stats_df['U2'], bar_height, color=c_under_dark, label='Under -2+', edgecolor='white') -ax.barh(y_pos, -stats_df['U1'], bar_height, left=-stats_df['U2'], color=c_under_light, label='Under -1', edgecolor='white') - -# Plot Over-scored (Stacked: +1 then +2+) -ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white') -ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white') - -# --- 3. Aesthetics & Table Labels --- -for i, row in stats_df.iterrows(): - label_text = ( - f"$\\mathbf{{{row['System']}}}$\n" - f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n" - f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)" - ) - # Position table text to the left - ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.4) - -# Formatting -ax.axvline(0, color='black', linewidth=1.2) -ax.set_yticks([]) -ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold') -#ax.set_title('Directional Error Magnitude (Under vs. Over Scoring)', fontsize=14, pad=35) - -# Absolute X-axis labels -ax.set_xticklabels([int(abs(tick)) for tick in ax.get_xticks()]) - -# Remove spines and add grid -for spine in ['top', 'right', 'left']: ax.spines[spine].set_visible(False) -ax.xaxis.grid(True, linestyle='--', alpha=0.3) - -# Legend with magnitude info -ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2) - -plt.tight_layout() -plt.show() -## - - -# %% Functional System Error Boxplots -import pandas as pd -import matplotlib.pyplot as plt -import os -import numpy as np -from matplotlib.patches import Patch -from matplotlib.lines import Line2D - -# --- Configuration & Theme --- -plt.rcParams['font.family'] = 'Arial' -figure_save_path = 'project/visuals/functional_systems_boxplot.svg' - -# --- 1. Build error data for boxplots --- -boxplot_data = [] -system_labels = [] -sample_sizes = [] - -for gt_col, res_col in functional_systems_to_plot: - sys_name = gt_col.split('.')[1] - - # Robust parsing - gt = df[gt_col].apply(safe_parse) - res = df[res_col].apply(safe_parse) - - # Error = result - ground truth - error = (res - gt).dropna() - - # Ignore all 0 errors - error = error[error != 0] - - # Keep only systems that actually have non-zero data - if len(error) > 0: - clean_name = sys_name.replace('_', ' ').title() - boxplot_data.append(error.values) - system_labels.append(clean_name) - sample_sizes.append(len(error)) - -# Safety check -if not boxplot_data: - raise ValueError("No valid non-zero error data available for any functional system.") - -# Put n into x-axis labels so it doesn't overlap the plot -xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)] - -# --- 2. Plotting --- -fig, ax = plt.subplots(figsize=(14, 8)) - -bp = ax.boxplot( - boxplot_data, - vert=True, - patch_artist=True, - labels=xtick_labels, - showmeans=True, - meanline=False -) - -# --- 3. Styling --- -box_face = '#D6EAF8' -box_edge = '#2980B9' -whisker_col = '#7F8C8D' -median_col = '#C0392B' -mean_col = '#1ABC9C' -flier_face = '#95A5A6' -flier_edge = '#7F8C8D' - -for box in bp['boxes']: - box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5) - -for whisker in bp['whiskers']: - whisker.set(color=whisker_col, linewidth=1.2) - -for cap in bp['caps']: - cap.set(color=whisker_col, linewidth=1.2) - -for median in bp['medians']: - median.set(color=median_col, linewidth=2) - -for mean in bp['means']: - mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6) - -for flier in bp['fliers']: - flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4) - -# Reference line at zero error -ax.axhline(0, color='black', linewidth=1.2, linestyle='--') - -# Labels and formatting -ax.set_xlabel('Functional System', fontsize=11, fontweight='bold') -ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold') - -# Rotate x labels for readability -plt.xticks(rotation=45, ha='right') - -# Grid and spines -ax.yaxis.grid(True, linestyle='--', alpha=0.3) -for spine in ['top', 'right']: - ax.spines[spine].set_visible(False) - -# --- 4. Legend above the plot, outside the axes --- -legend_handles = [ - Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'), - Line2D([0], [0], color=median_col, lw=2, label='Median'), - Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col, - markeredgecolor='black', markersize=7, label='Mean'), - Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face, - markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'), - Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference') -] - -ax.legend( - handles=legend_handles, - loc='lower center', - bbox_to_anchor=(0.5, 1.02), - ncol=3, - frameon=False -) - -# Leave room at the top for the legend -plt.tight_layout(rect=[0, 0, 1, 0.90]) - -# Optional save -os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) -plt.savefig(figure_save_path, format='svg', bbox_inches='tight') - -plt.show() -## - -<<<<<<< Updated upstream -======= -# %% Functional System + EDSS Error Boxplots -import pandas as pd -import matplotlib.pyplot as plt -import os -import numpy as np -from matplotlib.patches import Patch -from matplotlib.lines import Line2D - -# --- Configuration & Theme --- -plt.rcParams['font.family'] = 'Arial' -figure_save_path = 'project/visuals/functional_systems_edss_boxplot.svg' - -# ------------------------------------------------------------ -# Expect functional_systems_to_plot like: -# [ -# ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL_OPTIC_FUNCTIONS'), -# ... -# ] -# -# Add EDSS here: -# ------------------------------------------------------------ -all_systems_to_plot = list(functional_systems_to_plot) + [ - ('GT.EDSS', 'result.EDSS') -] - -# --- 1. Build error data for boxplots --- -boxplot_data = [] -system_labels = [] -sample_sizes = [] - -for gt_col, res_col in all_systems_to_plot: - # Skip safely if a column is missing - if gt_col not in df.columns or res_col not in df.columns: - print(f"Skipping missing columns: {gt_col}, {res_col}") - continue - - sys_name = gt_col.split('.')[1] - - # Robust parsing - gt = df[gt_col].apply(safe_parse) - res = df[res_col].apply(safe_parse) - - # Error = result - ground truth - error = (res - gt).dropna() - - # Ignore all 0 errors - error = error[error != 0] - - # Keep only systems that actually have non-zero data - if len(error) > 0: - if sys_name == 'EDSS': - clean_name = 'EDSS' - else: - clean_name = sys_name.replace('_', ' ').title() - - boxplot_data.append(error.values) - system_labels.append(clean_name) - sample_sizes.append(len(error)) - -# Safety check -if not boxplot_data: - raise ValueError("No valid non-zero error data available for any functional system or EDSS.") - -# Put n into x-axis labels so it doesn't overlap the plot -xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)] - -# --- 2. Plotting --- -fig, ax = plt.subplots(figsize=(15, 8)) - -bp = ax.boxplot( - boxplot_data, - vert=True, - patch_artist=True, - labels=xtick_labels, - showmeans=True, - meanline=False -) - -# --- 3. Styling --- -box_face = '#D6EAF8' -box_edge = '#2980B9' -whisker_col = '#7F8C8D' -median_col = '#C0392B' -mean_col = '#1ABC9C' -flier_face = '#95A5A6' -flier_edge = '#7F8C8D' - -for box in bp['boxes']: - box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5) - -for whisker in bp['whiskers']: - whisker.set(color=whisker_col, linewidth=1.2) - -for cap in bp['caps']: - cap.set(color=whisker_col, linewidth=1.2) - -for median in bp['medians']: - median.set(color=median_col, linewidth=2) - -for mean in bp['means']: - mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6) - -for flier in bp['fliers']: - flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4) - -# Reference line at zero error -ax.axhline(0, color='black', linewidth=1.2, linestyle='--') - -# Labels and formatting -ax.set_xlabel('Functional System / EDSS', fontsize=11, fontweight='bold') -ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold') - -# Rotate x labels for readability -plt.xticks(rotation=45, ha='right') - -# Grid and spines -ax.yaxis.grid(True, linestyle='--', alpha=0.3) -for spine in ['top', 'right']: - ax.spines[spine].set_visible(False) - -# --- 4. Legend above the plot, outside the axes --- -legend_handles = [ - Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'), - Line2D([0], [0], color=median_col, lw=2, label='Median'), - Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col, - markeredgecolor='black', markersize=7, label='Mean'), - Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face, - markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'), - Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference') -] - -ax.legend( - handles=legend_handles, - loc='lower center', - bbox_to_anchor=(0.5, 1.02), - ncol=3, - frameon=False -) - -# Leave room at the top for the legend -plt.tight_layout(rect=[0, 0, 1, 0.90]) - -# Optional save -os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) -plt.savefig(figure_save_path, format='svg', bbox_inches='tight') - -plt.show() -## ->>>>>>> Stashed changes - -# %% test -# Diagnose: what are the actual differences? -print("\n🔍 Raw differences (first 5 rows per system):") -for gt_col, res_col in functional_systems_to_plot: - gt = df[gt_col].apply(safe_parse) - res = df[res_col].apply(safe_parse) - diff = res - gt - non_zero = (diff != 0).sum() - # Check if it's due to floating point noise - abs_diff = diff.abs() - tiny = (abs_diff > 0) & (abs_diff < 1e-10) - print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}") - -##