From c9cf9ae9a0f43e4035b10a31fb48abf9bd6cc6bd Mon Sep 17 00:00:00 2001 From: Shahin Ramezanzadeh Date: Fri, 29 May 2026 00:42:40 +0200 Subject: [PATCH] optimized results and new benchmark --- scripts/analyze_certainty.py | 2880 ++++++---- scripts/show_plots.py | 9933 +++++++++++++++++++++++++++++++++- 2 files changed, 11657 insertions(+), 1156 deletions(-) diff --git a/scripts/analyze_certainty.py b/scripts/analyze_certainty.py index 1b41d81..0dfc514 100644 --- a/scripts/analyze_certainty.py +++ b/scripts/analyze_certainty.py @@ -2228,26 +2228,1394 @@ ## # %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark2 +# +#import time +#import json +#import os +#import re +#import threading +#from datetime import datetime +#from pathlib import Path +# +#import pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +#try: +# import psutil +#except ImportError: +# psutil = None +# print("⚠️ psutil is not installed. Resource metrics will be limited.") +# print("Install with: pip install psutil") +# +# +## ========================= +## CONFIGURATION +## ========================= +# +#load_dotenv() +# +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +# +#MODEL_CONFIGS = [ +# { +# "model_name": "gpt-oss-120b", +# "use_response_format": True, +# "temperature": 0.0, +# "max_tokens": 4096, +# "extra_body": None, +# }, +# { +# "model_name": "qwen3.6-27b", +# "use_response_format": False, +# "temperature": 0.0, +# "max_tokens": 4096, +# "extra_body": { +# "chat_template_kwargs": { +# "enable_thinking": False +# } +# }, +# }, +# { +# "model_name": "gemma-4-31B-it", +# "use_response_format": False, +# "temperature": 0.0, +# "max_tokens": 4096, +# "extra_body": None, +# }, +# ] +# +#INPUT_CSV ="/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/prompts/Komplett.txt" 
+# +#RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs" +# +#NUM_ITERATIONS = 10 +#STOP_ON_FIRST_ERROR = False +# +## For testing, set to e.g. 2. +## For full run, set to None. +#MAX_ROWS = None +## MAX_ROWS = 2 +# +#MAX_TOKENS = 4096 +#TEMPERATURE = 0.0 +# +#RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 +# +#SAVE_EVERY_N_ROWS = 1 +# +## Retries for invalid JSON / truncated JSON +#MAX_JSON_RETRIES = 2 +#RETRY_SLEEP_SEC = 2 +# +# +## ========================= +## VALID CLINICAL RANGES +## ========================= +# +#EDSS_MIN = 0.0 +#EDSS_MAX = 10.0 +# +#FUNCTIONAL_SYSTEM_RANGES = { +# "VISUAL_OPTIC_FUNCTIONS": (0.0, 6.0), +# "BRAINSTEM_FUNCTIONS": (0.0, 6.0), +# "PYRAMIDAL_FUNCTIONS": (0.0, 6.0), +# "CEREBELLAR_FUNCTIONS": (0.0, 6.0), +# "SENSORY_FUNCTIONS": (0.0, 6.0), +# "BOWEL_AND_BLADDER_FUNCTIONS": (0.0, 6.0), +# "CEREBRAL_FUNCTIONS": (0.0, 6.0), +# "AMBULATION": (0.0, 10.0), +#} +# +#REQUIRED_TOP_LEVEL_FIELDS = [ +# "reason", +# "klassifizierbar", +# "EDSS", +# "certainty_percent", +# "subcategories", +#] +# +# +## ========================= +## CLIENT +## ========================= +# +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +# +## ========================= +## HELPERS +## ========================= +# +#def safe_dir_name(name: str) -> str: +# name = str(name).strip() +# name = re.sub(r"[^\w\-.]+", "_", name) +# return name[:150] +# +# +#def now_timestamp() -> str: +# return datetime.now().strftime("%Y%m%d_%H%M%S") +# +# +#def get_process(): +# if psutil is None: +# return None +# return psutil.Process(os.getpid()) +# +# +#def get_memory_rss_mb(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# return process.memory_info().rss / (1024 * 1024) +# +# +#def get_cpu_times_sec(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# cpu_times = process.cpu_times() +# return cpu_times.user + 
cpu_times.system +# +# +#class ResourceSampler: +# def __init__(self, interval_sec=0.05): +# self.interval_sec = interval_sec +# self.process = get_process() +# self.running = False +# self.thread = None +# self.samples_mb = [] +# +# def start(self): +# if psutil is None: +# return +# +# self.running = True +# self.samples_mb = [] +# self.thread = threading.Thread(target=self._sample_loop, daemon=True) +# self.thread.start() +# +# def stop(self): +# if psutil is None: +# return +# +# self.running = False +# if self.thread is not None: +# self.thread.join(timeout=1.0) +# +# def _sample_loop(self): +# while self.running: +# try: +# rss_mb = get_memory_rss_mb(self.process) +# self.samples_mb.append(rss_mb) +# except Exception: +# pass +# time.sleep(self.interval_sec) +# +# @property +# def peak_rss_mb(self): +# if not self.samples_mb: +# return None +# return max(self.samples_mb) +# +# +## ========================= +## JSON EXTRACTION +## ========================= +# +#def extract_json_from_text(text): +# if text is None: +# raise ValueError("Model returned empty content: message.content is None") +# +# text = str(text).strip() +# +# if not text: +# raise ValueError("Model returned empty content") +# +# text = ( +# text.replace("```json", "") +# .replace("```JSON", "") +# .replace("```Json", "") +# .replace("```", "") +# .strip() +# ) +# +# # Direct parse +# try: +# parsed = json.loads(text) +# if isinstance(parsed, dict): +# return parsed +# except json.JSONDecodeError: +# pass +# +# # Balanced JSON candidates +# candidates = [] +# stack = [] +# start_idx = None +# in_string = False +# escape = False +# +# for i, ch in enumerate(text): +# if escape: +# escape = False +# continue +# +# if ch == "\\": +# escape = True +# continue +# +# if ch == '"': +# in_string = not in_string +# continue +# +# if in_string: +# continue +# +# if ch == "{": +# if not stack: +# start_idx = i +# stack.append(ch) +# +# elif ch == "}": +# if stack: +# stack.pop() +# if not stack and 
start_idx is not None: +# candidates.append(text[start_idx:i + 1]) +# start_idx = None +# +# valid_objects = [] +# +# for candidate in candidates: +# candidate = candidate.strip() +# lowered = candidate.lower() +# +# invalid_markers = [ +# "true/false", +# "null or", +# "oder zahl", +# "0.0-6.0", +# "0.0-10.0", +# "zahl zwischen", +# "...", +# ] +# +# if any(marker in lowered for marker in invalid_markers): +# continue +# +# try: +# parsed = json.loads(candidate) +# if isinstance(parsed, dict): +# valid_objects.append(parsed) +# except json.JSONDecodeError: +# continue +# +# for obj in reversed(valid_objects): +# if ( +# "klassifizierbar" in obj +# and "certainty_percent" in obj +# and "subcategories" in obj +# ): +# return obj +# +# if valid_objects: +# return valid_objects[-1] +# +# stripped = text.strip() +# if stripped.startswith("{") and not stripped.endswith("}"): +# raise ValueError( +# "Model output looks like truncated JSON. " +# f"Raw output starts with: {text[:1000]}" +# ) +# +# raise ValueError( +# "No valid JSON object found in model output. 
" +# f"Raw output starts with: {text[:1000]}" +# ) +# +# +#def extract_message_content(message): +# raw_content = getattr(message, "content", None) +# +# if raw_content is not None: +# return raw_content +# +# msg_dict = None +# +# try: +# msg_dict = message.model_dump() +# except Exception: +# try: +# msg_dict = dict(message) +# except Exception: +# msg_dict = None +# +# if not isinstance(msg_dict, dict): +# return None +# +# for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: +# value = msg_dict.get(key) +# if value: +# return value +# +# possible_content = msg_dict.get("content") +# if isinstance(possible_content, list): +# parts = [] +# for block in possible_content: +# if isinstance(block, dict): +# if "text" in block: +# parts.append(str(block["text"])) +# elif "content" in block: +# parts.append(str(block["content"])) +# if parts: +# return "\n".join(parts).strip() +# +# return None +# +# +## ========================= +## READ INSTRUCTIONS +## ========================= +# +#with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: +# EDSS_INSTRUCTIONS = f.read().strip() +# +# +## ========================= +## PROMPT +## ========================= +# +#def build_prompt(patient_text): +# return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. +# +#Extrahiere: +#1. Gesamt-EDSS-Score von 0.0 bis 10.0 +#2. Alle 8 EDSS-Unterkategorien +#3. Sicherheit als Ganzzahl von 0 bis 100 +# +#Antworte ausschließlich mit EINEM validen JSON-Objekt. +#Kein Markdown. +#Keine Code-Fences. +#Kein Text vor oder nach JSON. +#Keine Platzhalter. +#Kopiere kein Schema. 
+# +#Das JSON muss exakt diese Schlüssel enthalten: +#- reason +#- klassifizierbar +#- EDSS +#- certainty_percent +#- subcategories +# +#Die subcategories müssen exakt diese 8 Schlüssel enthalten: +#- VISUAL_OPTIC_FUNCTIONS +#- BRAINSTEM_FUNCTIONS +#- PYRAMIDAL_FUNCTIONS +#- CEREBELLAR_FUNCTIONS +#- SENSORY_FUNCTIONS +#- BOWEL_AND_BLADDER_FUNCTIONS +#- CEREBRAL_FUNCTIONS +#- AMBULATION +# +#Werte: +#- klassifizierbar: true oder false +#- EDSS: Zahl von 0.0 bis 10.0 oder null +#- certainty_percent: Ganzzahl von 0 bis 100 +#- Unterkategorien: Zahl oder null +#- VISUAL_OPTIC_FUNCTIONS maximal 6.0 +#- BRAINSTEM_FUNCTIONS maximal 6.0 +#- PYRAMIDAL_FUNCTIONS maximal 6.0 +#- CEREBELLAR_FUNCTIONS maximal 6.0 +#- SENSORY_FUNCTIONS maximal 6.0 +#- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 +#- CEREBRAL_FUNCTIONS maximal 6.0 +#- AMBULATION maximal 10.0 +#- reason: maximal 250 Zeichen, Deutsch +# +#Wenn klassifizierbar false ist, setze EDSS auf null. +# +#Valide Beispielausgabe: +#{{ +# "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", +# "klassifizierbar": true, +# "EDSS": 2.0, +# "certainty_percent": 90, +# "subcategories": {{ +# "VISUAL_OPTIC_FUNCTIONS": null, +# "BRAINSTEM_FUNCTIONS": null, +# "PYRAMIDAL_FUNCTIONS": 1.0, +# "CEREBELLAR_FUNCTIONS": 1.0, +# "SENSORY_FUNCTIONS": 1.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": null, +# "CEREBRAL_FUNCTIONS": null, +# "AMBULATION": 0.0 +# }} +#}} +# +#EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +# +#Gib ausschließlich das finale JSON-Objekt zurück. +#''' +# +# +## ========================= +## VALIDATION, NOT NORMALIZATION +## ========================= +# +#def parse_float_preserve_raw(value): +# """ +# Try to parse a value as float without clipping or correcting it. 
+# +# Returns: +# raw_value: original value exactly as present in parsed JSON +# numeric_value: float or None +# is_numeric: bool +# """ +# raw_value = value +# +# if value is None: +# return raw_value, None, False +# +# if isinstance(value, bool): +# return raw_value, None, False +# +# try: +# numeric_value = float(str(value).replace(",", ".")) +# return raw_value, numeric_value, True +# except Exception: +# return raw_value, None, False +# +# +#def is_in_range(value, min_value, max_value): +# """ +# Range check without clipping. +# """ +# if value is None: +# return False +# return min_value <= value <= max_value +# +# +#def validate_model_output(parsed): +# """ +# Validate parsed model output without repairing/clipping clinical values. +# +# Important: +# - Does NOT clip EDSS. +# - Does NOT clip functional system values. +# - Does NOT insert default EDSS. +# - Does NOT insert default certainty_percent. +# - Missing fields are kept as None. +# - Adds explicit validity flags for scientific transparency. 
+# """ +# +# validation = { +# "json_parse_success": isinstance(parsed, dict), +# "required_fields_present": False, +# "required_schema_success": False, +# "clinical_range_valid": False, +# "certainty_present": False, +# +# "missing_required_fields": [], +# "missing_subcategory_fields": [], +# +# "EDSS_is_numeric": False, +# "EDSS_in_valid_range": False, +# } +# +# if not isinstance(parsed, dict): +# return { +# "raw_output": parsed, +# "validated_output": {}, +# "validation": validation, +# } +# +# missing_required = [ +# field for field in REQUIRED_TOP_LEVEL_FIELDS +# if field not in parsed +# ] +# +# validation["missing_required_fields"] = missing_required +# validation["required_fields_present"] = len(missing_required) == 0 +# +# validated = {} +# +# validated["reason"] = parsed.get("reason", None) +# validated["klassifizierbar"] = parsed.get("klassifizierbar", None) +# +# raw_certainty = parsed.get("certainty_percent", None) +# validated["raw_certainty_percent"] = raw_certainty +# validation["certainty_present"] = "certainty_percent" in parsed and raw_certainty is not None +# +# _, certainty_numeric, certainty_is_numeric = parse_float_preserve_raw(raw_certainty) +# validated["certainty_percent"] = certainty_numeric if certainty_is_numeric else None +# validated["certainty_percent_is_numeric"] = certainty_is_numeric +# validated["certainty_percent_in_valid_range"] = ( +# is_in_range(certainty_numeric, 0.0, 100.0) +# if certainty_is_numeric else False +# ) +# +# raw_edss = parsed.get("EDSS", None) +# raw_edss, edss_numeric, edss_is_numeric = parse_float_preserve_raw(raw_edss) +# +# validated["raw_EDSS"] = raw_edss +# validated["EDSS_numeric"] = edss_numeric +# validated["EDSS"] = edss_numeric # Backward-compatible; parsed only, not clipped +# validated["EDSS_is_numeric"] = edss_is_numeric +# validated["EDSS_in_valid_range"] = ( +# is_in_range(edss_numeric, EDSS_MIN, EDSS_MAX) +# if edss_is_numeric else False +# ) +# +# validation["EDSS_is_numeric"] = 
validated["EDSS_is_numeric"] +# validation["EDSS_in_valid_range"] = validated["EDSS_in_valid_range"] +# +# raw_subcategories = parsed.get("subcategories", None) +# +# if isinstance(raw_subcategories, dict): +# subcategories = raw_subcategories +# else: +# subcategories = {} +# +# validated["subcategories"] = {} +# validated["raw_subcategories"] = {} +# validated["subcategory_validation"] = {} +# +# missing_subcats = [] +# +# for subcat, (min_value, max_value) in FUNCTIONAL_SYSTEM_RANGES.items(): +# if subcat not in subcategories: +# missing_subcats.append(subcat) +# +# raw_value = subcategories.get(subcat, None) +# raw_value, numeric_value, is_numeric_value = parse_float_preserve_raw(raw_value) +# in_valid_range = ( +# is_in_range(numeric_value, min_value, max_value) +# if is_numeric_value else False +# ) +# +# validated["raw_subcategories"][subcat] = raw_value +# validated["subcategories"][subcat] = numeric_value +# +# validated["subcategory_validation"][subcat] = { +# "is_numeric": is_numeric_value, +# "in_valid_range": in_valid_range, +# "min_allowed": min_value, +# "max_allowed": max_value, +# } +# +# validation["missing_subcategory_fields"] = missing_subcats +# +# subcategory_schema_present = len(missing_subcats) == 0 +# +# all_subcats_numeric = all( +# validated["subcategory_validation"][subcat]["is_numeric"] +# for subcat in FUNCTIONAL_SYSTEM_RANGES +# ) +# +# all_subcats_in_range = all( +# validated["subcategory_validation"][subcat]["in_valid_range"] +# for subcat in FUNCTIONAL_SYSTEM_RANGES +# ) +# +# validated["all_functional_systems_numeric"] = all_subcats_numeric +# validated["all_functional_systems_in_valid_range"] = all_subcats_in_range +# +# validation["clinical_range_valid"] = ( +# validated["EDSS_in_valid_range"] +# and all_subcats_in_range +# ) +# +# validation["required_schema_success"] = ( +# validation["required_fields_present"] +# and subcategory_schema_present +# ) +# +# return { +# "raw_output": parsed, +# "validated_output": validated, +# 
"validation": validation, +# } +# +# +## ========================= +## API CALL +## ========================= +# +#def make_chat_completion(model_config, prompt): +# model_name = model_config["model_name"] +# +# kwargs = dict( +# messages=[ +# { +# "role": "system", +# "content": ( +# "Du bist ein JSON-Generator. " +# "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " +# "Keine Erklärung. Kein Markdown. Keine Code-Fences. " +# "Keine Platzhalter. Kein Schema kopieren. " +# "Das JSON muss vollständig geschlossen sein." +# ) +# }, +# { +# "role": "user", +# "content": prompt +# } +# ], +# model=model_name, +# max_tokens=model_config.get("max_tokens", MAX_TOKENS), +# temperature=model_config.get("temperature", TEMPERATURE), +# ) +# +# if model_config.get("use_response_format", False): +# kwargs["response_format"] = {"type": "json_object"} +# +# extra_body = model_config.get("extra_body") +# if extra_body is not None: +# kwargs["extra_body"] = extra_body +# +# return client.chat.completions.create(**kwargs) +# +# +## ========================= +## INFERENCE FUNCTION WITH RETRIES +## ========================= +# +#def run_inference(patient_text, model_config): +# model_name = model_config["model_name"] +# prompt = build_prompt(patient_text) +# +# process = get_process() +# sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) +# +# wall_start = time.perf_counter() +# cpu_start = get_cpu_times_sec(process) +# rss_start_mb = get_memory_rss_mb(process) +# +# sampler.start() +# +# raw_content = None +# raw_response_debug = None +# raw_parsed_output = None +# validation = None +# last_error = None +# +# prompt_tokens = None +# completion_tokens = None +# total_tokens = None +# +# try: +# for attempt in range(1, MAX_JSON_RETRIES + 2): +# try: +# response = make_chat_completion( +# model_config=model_config, +# prompt=prompt +# ) +# +# message = response.choices[0].message +# raw_content = extract_message_content(message) +# +# try: +# 
raw_response_debug = response.model_dump() +# except Exception: +# raw_response_debug = str(response) +# +# usage = getattr(response, "usage", None) +# if usage is not None: +# prompt_tokens = getattr(usage, "prompt_tokens", None) +# completion_tokens = getattr(usage, "completion_tokens", None) +# total_tokens = getattr(usage, "total_tokens", None) +# +# parsed = extract_json_from_text(raw_content) +# validation_package = validate_model_output(parsed) +# +# success = True +# error = None +# +# result = validation_package["validated_output"] +# validation = validation_package["validation"] +# raw_parsed_output = validation_package["raw_output"] +# +# break +# +# except Exception as e: +# last_error = str(e) +# +# if attempt <= MAX_JSON_RETRIES: +# print( +# f"\n⚠️ JSON failed on attempt {attempt}. " +# f"Retrying row. Error: {last_error[:300]}" +# ) +# time.sleep(RETRY_SLEEP_SEC) +# continue +# +# raise +# +# except Exception as e: +# print(f"❌ Inference error: {e}") +# +# success = False +# error = str(e) +# result = None +# raw_parsed_output = None +# +# validation = { +# "json_parse_success": False, +# "required_fields_present": False, +# "required_schema_success": False, +# "clinical_range_valid": False, +# "certainty_present": False, +# "missing_required_fields": [], +# "missing_subcategory_fields": [], +# "EDSS_is_numeric": False, +# "EDSS_in_valid_range": False, +# } +# +# finally: +# sampler.stop() +# +# wall_end = time.perf_counter() +# cpu_end = get_cpu_times_sec(process) +# rss_end_mb = get_memory_rss_mb(process) +# +# wall_time_sec = wall_end - wall_start +# +# if cpu_start is not None and cpu_end is not None: +# process_cpu_time_sec = cpu_end - cpu_start +# else: +# process_cpu_time_sec = None +# +# if rss_start_mb is not None and rss_end_mb is not None: +# rss_delta_mb = rss_end_mb - rss_start_mb +# else: +# rss_delta_mb = None +# +# return { +# "success": success, +# "error": error, +# "result": result, +# +# "validation": validation, +# 
"raw_parsed_output": raw_parsed_output, +# +# "model": model_name, +# +# "inference_time_sec": wall_time_sec, +# +# "process_cpu_time_sec": process_cpu_time_sec, +# "rss_before_mb": rss_start_mb, +# "rss_after_mb": rss_end_mb, +# "rss_delta_mb": rss_delta_mb, +# "peak_rss_mb": sampler.peak_rss_mb, +# +# "prompt_tokens": prompt_tokens, +# "completion_tokens": completion_tokens, +# "total_tokens": total_tokens, +# +# # Keeping raw content improves auditability but can make files large. +# # To save space, change this to: raw_content if not success else None +# "raw_content": raw_content, +# "raw_response_debug": raw_response_debug if not success else None, +# "last_error": last_error, +# } +# +# +## ========================= +## BUILD PATIENT TEXT +## ========================= +# +#def build_patient_text(row): +# return ( +# str(row.get("T_Zusammenfassung", "")) + "\n" + +# str(row.get("Diagnosen", "")) + "\n" + +# str(row.get("T_KlinBef", "")) + "\n" + +# str(row.get("T_Befunde", "")) +# ) +# +# +## ========================= +## FLATTEN RESULTS FOR CSV +## ========================= +# +#def flatten_result(record): +# """ +# Flatten one benchmark record for CSV export. 
+# +# This preserves: +# - raw model values +# - parsed numeric values without clipping +# - validity flags +# - backward-compatible columns where possible +# """ +# +# validation = record.get("validation") or {} +# result = record.get("result") or {} +# +# flat = { +# "model": record.get("model"), +# "iteration": record.get("iteration"), +# "row_index": record.get("row_index"), +# "row_number_in_run": record.get("row_number_in_run"), +# "unique_id": record.get("unique_id"), +# "MedDatum": record.get("MedDatum"), +# +# "success": record.get("success"), +# "error": record.get("error"), +# "last_error": record.get("last_error"), +# +# "json_parse_success": validation.get("json_parse_success"), +# "required_fields_present": validation.get("required_fields_present"), +# "required_schema_success": validation.get("required_schema_success"), +# "clinical_range_valid": validation.get("clinical_range_valid"), +# "certainty_present": validation.get("certainty_present"), +# +# "missing_required_fields": json.dumps( +# validation.get("missing_required_fields", []), +# ensure_ascii=False +# ), +# "missing_subcategory_fields": json.dumps( +# validation.get("missing_subcategory_fields", []), +# ensure_ascii=False +# ), +# +# "inference_time_sec": record.get("inference_time_sec"), +# "process_cpu_time_sec": record.get("process_cpu_time_sec"), +# "rss_before_mb": record.get("rss_before_mb"), +# "rss_after_mb": record.get("rss_after_mb"), +# "rss_delta_mb": record.get("rss_delta_mb"), +# "peak_rss_mb": record.get("peak_rss_mb"), +# +# "prompt_tokens": record.get("prompt_tokens"), +# "completion_tokens": record.get("completion_tokens"), +# "total_tokens": record.get("total_tokens"), +# +# "raw_content": record.get("raw_content"), +# "raw_parsed_output": json.dumps(record.get("raw_parsed_output"), ensure_ascii=False), +# +# # Backward-compatible fields +# "reason": result.get("reason"), +# "klassifizierbar": result.get("klassifizierbar"), +# +# "raw_certainty_percent": 
result.get("raw_certainty_percent"), +# "certainty_percent": result.get("certainty_percent"), +# "certainty_percent_is_numeric": result.get("certainty_percent_is_numeric"), +# "certainty_percent_in_valid_range": result.get("certainty_percent_in_valid_range"), +# +# # EDSS raw/numeric/validity fields +# "raw_EDSS": result.get("raw_EDSS"), +# "EDSS_numeric": result.get("EDSS_numeric"), +# "EDSS": result.get("EDSS"), # backward-compatible; same as EDSS_numeric, not clipped +# "EDSS_is_numeric": result.get("EDSS_is_numeric"), +# "EDSS_in_valid_range": result.get("EDSS_in_valid_range"), +# +# "all_functional_systems_numeric": result.get("all_functional_systems_numeric"), +# "all_functional_systems_in_valid_range": result.get("all_functional_systems_in_valid_range"), +# } +# +# raw_subcategories = result.get("raw_subcategories", {}) +# numeric_subcategories = result.get("subcategories", {}) +# subcat_validation = result.get("subcategory_validation", {}) +# +# for subcat in FUNCTIONAL_SYSTEM_RANGES: +# raw_value = None +# numeric_value = None +# is_numeric = False +# in_valid_range = False +# +# if isinstance(raw_subcategories, dict): +# raw_value = raw_subcategories.get(subcat) +# +# if isinstance(numeric_subcategories, dict): +# numeric_value = numeric_subcategories.get(subcat) +# +# if isinstance(subcat_validation, dict): +# flags = subcat_validation.get(subcat, {}) +# if isinstance(flags, dict): +# is_numeric = flags.get("is_numeric", False) +# in_valid_range = flags.get("in_valid_range", False) +# +# # New transparent columns +# flat[f"raw_subcat_{subcat}"] = raw_value +# flat[f"numeric_subcat_{subcat}"] = numeric_value +# flat[f"subcat_{subcat}_is_numeric"] = is_numeric +# flat[f"subcat_{subcat}_in_valid_range"] = in_valid_range +# +# # Backward-compatible old column name. +# # This is numeric but NOT clipped. 
+# flat[f"subcat_{subcat}"] = numeric_value +# +# return flat +# +# +## ========================= +## SUMMARY STATISTICS +## ========================= +# +#def summarize_records(records): +# """ +# Create transparent summary statistics per model. +# +# Separates: +# - JSON/schema validity +# - numeric parse validity +# - clinical range validity +# - out-of-range outputs +# """ +# +# df = pd.DataFrame([flatten_result(r) for r in records]) +# +# if df.empty: +# return pd.DataFrame() +# +# def bool_mean(col): +# if col not in df.columns: +# return None +# return df[col].fillna(False).astype(bool).mean() +# +# def bool_sum(col): +# if col not in df.columns: +# return None +# return int(df[col].fillna(False).astype(bool).sum()) +# +# n_records = len(df) +# +# summary = { +# "model": df["model"].iloc[0] if "model" in df.columns else None, +# "n_total_responses": n_records, +# +# "n_success": bool_sum("success"), +# "success_rate": bool_mean("success"), +# +# "n_json_parse_success": bool_sum("json_parse_success"), +# "json_parse_success_rate": bool_mean("json_parse_success"), +# +# "n_required_fields_present": bool_sum("required_fields_present"), +# "required_fields_present_rate": bool_mean("required_fields_present"), +# +# "n_required_schema_success": bool_sum("required_schema_success"), +# "required_schema_success_rate": bool_mean("required_schema_success"), +# +# "n_clinical_range_valid": bool_sum("clinical_range_valid"), +# "clinical_range_valid_rate": bool_mean("clinical_range_valid"), +# +# "n_certainty_present": bool_sum("certainty_present"), +# "certainty_present_rate": bool_mean("certainty_present"), +# +# "n_EDSS_numeric": bool_sum("EDSS_is_numeric"), +# "EDSS_numeric_rate": bool_mean("EDSS_is_numeric"), +# +# "n_EDSS_in_valid_range": bool_sum("EDSS_in_valid_range"), +# "EDSS_valid_range_rate": bool_mean("EDSS_in_valid_range"), +# } +# +# # EDSS out-of-range among numeric EDSS outputs +# if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: 
+# edss_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) +# edss_valid = df["EDSS_in_valid_range"].fillna(False).astype(bool) +# edss_out_of_range = edss_numeric & (~edss_valid) +# +# summary["n_EDSS_out_of_range"] = int(edss_out_of_range.sum()) +# summary["EDSS_out_of_range_rate_total"] = float(edss_out_of_range.mean()) +# summary["EDSS_out_of_range_rate_among_numeric"] = ( +# float(edss_out_of_range.sum() / edss_numeric.sum()) +# if edss_numeric.sum() > 0 else None +# ) +# +# # Functional system rates +# fs_out_of_range_any = pd.Series(False, index=df.index) +# fs_valid_all = pd.Series(True, index=df.index) +# +# for subcat in FUNCTIONAL_SYSTEM_RANGES: +# numeric_col = f"subcat_{subcat}_is_numeric" +# valid_col = f"subcat_{subcat}_in_valid_range" +# +# if numeric_col in df.columns: +# numeric_series = df[numeric_col].fillna(False).astype(bool) +# else: +# numeric_series = pd.Series(False, index=df.index) +# +# if valid_col in df.columns: +# valid_series = df[valid_col].fillna(False).astype(bool) +# else: +# valid_series = pd.Series(False, index=df.index) +# +# out_of_range_series = numeric_series & (~valid_series) +# +# summary[f"n_{subcat}_numeric"] = int(numeric_series.sum()) +# summary[f"{subcat}_numeric_rate"] = float(numeric_series.mean()) +# +# summary[f"n_{subcat}_in_valid_range"] = int(valid_series.sum()) +# summary[f"{subcat}_valid_range_rate"] = float(valid_series.mean()) +# +# summary[f"n_{subcat}_out_of_range"] = int(out_of_range_series.sum()) +# summary[f"{subcat}_out_of_range_rate_total"] = float(out_of_range_series.mean()) +# summary[f"{subcat}_out_of_range_rate_among_numeric"] = ( +# float(out_of_range_series.sum() / numeric_series.sum()) +# if numeric_series.sum() > 0 else None +# ) +# +# fs_out_of_range_any = fs_out_of_range_any | out_of_range_series +# fs_valid_all = fs_valid_all & valid_series +# +# summary["n_any_functional_system_out_of_range"] = int(fs_out_of_range_any.sum()) +# 
summary["any_functional_system_out_of_range_rate_total"] = float(fs_out_of_range_any.mean()) +# +# summary["n_all_functional_systems_in_valid_range"] = int(fs_valid_all.sum()) +# summary["all_functional_systems_valid_range_rate"] = float(fs_valid_all.mean()) +# +# numeric_cols = [ +# "inference_time_sec", +# "process_cpu_time_sec", +# "rss_delta_mb", +# "peak_rss_mb", +# "prompt_tokens", +# "completion_tokens", +# "total_tokens", +# "certainty_percent", +# "EDSS_numeric", +# ] +# +# for col in numeric_cols: +# if col in df.columns: +# values = pd.to_numeric(df[col], errors="coerce") +# summary[f"{col}_mean"] = values.mean() +# summary[f"{col}_median"] = values.median() +# summary[f"{col}_std"] = values.std() +# summary[f"{col}_min"] = values.min() +# summary[f"{col}_max"] = values.max() +# +# if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: +# primary_valid_only = ( +# df["EDSS_is_numeric"].fillna(False).astype(bool) +# & df["EDSS_in_valid_range"].fillna(False).astype(bool) +# ) +# +# sensitivity_all_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) +# +# summary["n_primary_valid_only_EDSS"] = int(primary_valid_only.sum()) +# summary["primary_valid_only_EDSS_rate"] = float(primary_valid_only.mean()) +# +# summary["n_sensitivity_all_numeric_EDSS"] = int(sensitivity_all_numeric.sum()) +# summary["sensitivity_all_numeric_EDSS_rate"] = float(sensitivity_all_numeric.mean()) +# +# return pd.DataFrame([summary]) +# +# +## ========================= +## ANALYSIS DATASET HELPERS +## ========================= +# +#def create_analysis_datasets(records): +# """ +# Create two transparent EDSS analysis datasets: +# +# 1. primary_valid_only: +# Only numeric EDSS predictions within the valid clinical range. +# +# 2. sensitivity_all_numeric: +# All numeric EDSS predictions, including out-of-range values. +# No clipping is applied. 
+# """ +# +# df = pd.DataFrame([flatten_result(r) for r in records]) +# +# if df.empty: +# return df.copy(), df.copy() +# +# primary_valid_only = df[ +# df["EDSS_is_numeric"].fillna(False).astype(bool) +# & df["EDSS_in_valid_range"].fillna(False).astype(bool) +# ].copy() +# +# sensitivity_all_numeric = df[ +# df["EDSS_is_numeric"].fillna(False).astype(bool) +# ].copy() +# +# return primary_valid_only, sensitivity_all_numeric +# +# +## ========================= +## INCREMENTAL SAVE HELPERS +## ========================= +# +#def append_jsonl(path, record): +# with open(path, "a", encoding="utf-8") as f: +# f.write(json.dumps(record, ensure_ascii=False) + "\n") +# f.flush() +# os.fsync(f.fileno()) +# +# +#def append_csv(path, record): +# flat = flatten_result(record) +# df_one = pd.DataFrame([flat]) +# file_exists = Path(path).exists() +# df_one.to_csv(path, mode="a", header=not file_exists, index=False) +# +# +## ========================= +## MAIN LOOP +## ========================= +# +#if __name__ == "__main__": +# +# run_timestamp = now_timestamp() +# +# results_root = Path(RESULTS_ROOT) +# results_root.mkdir(parents=True, exist_ok=True) +# +# run_root = results_root / f"run_{run_timestamp}" +# run_root.mkdir(parents=True, exist_ok=True) +# +# print(f"Results root: {run_root}") +# +# df = pd.read_csv(INPUT_CSV, sep=";") +# +# if MAX_ROWS is not None: +# df = df.head(MAX_ROWS) +# +# total_rows = len(df) +# +# model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] +# +# print(f"Loaded {total_rows} patient records.") +# print(f"Models: {model_names_for_print}") +# print(f"Iterations per model: {NUM_ITERATIONS}") +# +# all_model_summaries = [] +# +# for model_config in MODEL_CONFIGS: +# model_name = model_config["model_name"] +# safe_model = safe_dir_name(model_name) +# +# model_dir = run_root / safe_model +# model_dir.mkdir(parents=True, exist_ok=True) +# +# print(f"\n{'#' * 80}") +# print(f"MODEL: {model_name}") +# print(f"use_response_format: 
{model_config.get('use_response_format', False)}") +# print(f"temperature: {model_config.get('temperature', TEMPERATURE)}") +# print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") +# print(f"Saving to: {model_dir}") +# print(f"{'#' * 80}") +# +# model_records = [] +# model_start = time.perf_counter() +# +# for iteration in range(1, NUM_ITERATIONS + 1): +# print(f"\n{'=' * 60}") +# print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") +# print(f"{'=' * 60}") +# +# iteration_results = [] +# iteration_start = time.perf_counter() +# +# incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" +# incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" +# +# print(f"Incremental JSONL: {incremental_jsonl_path}") +# print(f"Incremental CSV: {incremental_csv_path}") +# +# for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): +# print( +# f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", +# end="", +# flush=True +# ) +# +# try: +# patient_text = build_patient_text(row) +# +# record = run_inference( +# patient_text=patient_text, +# model_config=model_config +# ) +# +# record["iteration"] = iteration +# record["row_index"] = int(idx) +# record["row_number_in_run"] = int(loop_i) +# record["unique_id"] = row.get("unique_id", f"row_{idx}") +# record["MedDatum"] = row.get("MedDatum", None) +# +# iteration_results.append(record) +# model_records.append(record) +# +# if loop_i % SAVE_EVERY_N_ROWS == 0: +# append_jsonl(incremental_jsonl_path, record) +# append_csv(incremental_csv_path, record) +# +# if record["success"]: +# res = record["result"] or {} +# edss_display = res.get("EDSS_numeric", None) +# edss_valid = res.get("EDSS_in_valid_range", False) +# +# print( +# f" ✅ EDSS={edss_display}, " +# f"valid_range={edss_valid}, " +# f"time={record['inference_time_sec']:.2f}s" +# ) +# else: +# print(f" ❌ {record.get('error', 'Unknown 
error')}") +# +# except Exception as e: +# print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") +# +# fallback_record = { +# "success": False, +# "error": str(e), +# "last_error": str(e), +# "result": None, +# +# "validation": { +# "json_parse_success": False, +# "required_fields_present": False, +# "required_schema_success": False, +# "clinical_range_valid": False, +# "certainty_present": False, +# "missing_required_fields": [], +# "missing_subcategory_fields": [], +# "EDSS_is_numeric": False, +# "EDSS_in_valid_range": False, +# }, +# "raw_parsed_output": None, +# +# "model": model_name, +# "iteration": iteration, +# "row_index": int(idx), +# "row_number_in_run": int(loop_i), +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None), +# +# "inference_time_sec": None, +# "process_cpu_time_sec": None, +# "rss_before_mb": None, +# "rss_after_mb": None, +# "rss_delta_mb": None, +# "peak_rss_mb": None, +# +# "prompt_tokens": None, +# "completion_tokens": None, +# "total_tokens": None, +# +# "raw_content": None, +# "raw_response_debug": None, +# } +# +# iteration_results.append(fallback_record) +# model_records.append(fallback_record) +# +# append_jsonl(incremental_jsonl_path, fallback_record) +# append_csv(incremental_csv_path, fallback_record) +# +# if STOP_ON_FIRST_ERROR: +# break +# +# iteration_elapsed = time.perf_counter() - iteration_start +# +# # Final full per-iteration JSON +# iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" +# with open(iter_json_path, "w", encoding="utf-8") as f: +# json.dump(iteration_results, f, indent=2, ensure_ascii=False) +# +# # Final full per-iteration CSV +# iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" +# iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) +# iter_flat_df.to_csv(iter_csv_path, index=False) +# +# # Transparent analysis datasets +# primary_valid_only_df, 
sensitivity_all_numeric_df = create_analysis_datasets(iteration_results) +# +# primary_valid_only_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_primary_valid_only.csv" +# sensitivity_all_numeric_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_sensitivity_all_numeric.csv" +# +# primary_valid_only_df.to_csv(primary_valid_only_path, index=False) +# sensitivity_all_numeric_df.to_csv(sensitivity_all_numeric_path, index=False) +# +# print(f"\n✅ Iteration {iteration} complete.") +# print(f"Incremental JSONL saved to: {incremental_jsonl_path}") +# print(f"Incremental CSV saved to: {incremental_csv_path}") +# print(f"Final JSON saved to: {iter_json_path}") +# print(f"Final CSV saved to: {iter_csv_path}") +# print(f"Primary valid-only CSV saved to: {primary_valid_only_path}") +# print(f"Sensitivity all-numeric CSV: {sensitivity_all_numeric_path}") +# print( +# f"⏱️ Iteration time: {iteration_elapsed:.1f}s " +# f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" +# ) +# +# model_elapsed = time.perf_counter() - model_start +# +# # Save all records for this model +# model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" +# with open(model_json_path, "w", encoding="utf-8") as f: +# json.dump(model_records, f, indent=2, ensure_ascii=False) +# +# model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" +# model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) +# model_flat_df.to_csv(model_csv_path, index=False) +# +# # Save model-level analysis datasets +# primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(model_records) +# +# model_primary_valid_only_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_primary_valid_only.csv" +# model_sensitivity_all_numeric_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_sensitivity_all_numeric.csv" +# +# primary_valid_only_df.to_csv(model_primary_valid_only_path, index=False) 
+# sensitivity_all_numeric_df.to_csv(model_sensitivity_all_numeric_path, index=False) +# +# # Save model summary +# model_summary_df = summarize_records(model_records) +# model_summary_df["model_total_wall_time_sec"] = model_elapsed +# model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 +# +# model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" +# model_summary_df.to_csv(model_summary_path, index=False) +# +# all_model_summaries.append(model_summary_df) +# +# print(f"\n🎉 Model completed: {model_name}") +# print(f"All JSON: {model_json_path}") +# print(f"All CSV: {model_csv_path}") +# print(f"All primary valid-only CSV: {model_primary_valid_only_path}") +# print(f"All sensitivity all-numeric CSV: {model_sensitivity_all_numeric_path}") +# print(f"Summary: {model_summary_path}") +# print(f"Total model time: {model_elapsed / 60:.2f} min") +# +# if all_model_summaries: +# combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) +# combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" +# combined_summary_df.to_csv(combined_summary_path, index=False) +# +# print(f"\n📊 Combined summary saved to: {combined_summary_path}") +# +# print(f"\n🎉 All models and all iterations completed!") +## + + + + +# %% Minimal parallel EDSS benchmark with correct klassifizierbar/EDSS logic -import time -import json import os import re -import threading -from datetime import datetime +import json +import time from pathlib import Path +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed import pandas as pd from openai import OpenAI from dotenv import load_dotenv -try: - import psutil -except ImportError: - psutil = None - print("⚠️ psutil is not installed. 
Resource metrics will be limited.") - print("Install with: pip install psutil") - # ========================= # CONFIGURATION @@ -2258,7 +3626,18 @@ load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/prompts/Komplett.txt" +RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs" + MODEL_CONFIGS = [ + { + "model_name": "gpt-oss-120b", + "use_response_format": True, + "temperature": 0.0, + "max_tokens": 4096, + "extra_body": None, + }, { "model_name": "qwen3.6-27b", "use_response_format": False, @@ -2277,42 +3656,24 @@ MODEL_CONFIGS = [ "max_tokens": 4096, "extra_body": None, }, - { - "model_name": "gpt-oss-120b", - "use_response_format": True, - "temperature": 0.0, - "max_tokens": 4096, - "extra_body": None, - }, ] -INPUT_CSV ="/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" -EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/prompts/Komplett.txt" +NUM_ITERATIONS = 10 -RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs" - -NUM_ITERATIONS = 2 -STOP_ON_FIRST_ERROR = False - -# For testing, set to e.g. 2. -# For full run, set to None. 
-MAX_ROWS = 2 +MAX_ROWS = None # MAX_ROWS = 2 -MAX_TOKENS = 4096 -TEMPERATURE = 0.0 +PARALLEL_WORKERS = 10 +BATCH_SIZE = 100 -RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 - -SAVE_EVERY_N_ROWS = 1 - -# Retries for invalid JSON / truncated JSON -MAX_JSON_RETRIES = 2 +MAX_JSON_RETRIES = 5 RETRY_SLEEP_SEC = 2 +STOP_ON_FIRST_ERROR = False + # ========================= -# VALID CLINICAL RANGES +# CONSTANTS # ========================= EDSS_MIN = 0.0 @@ -2344,7 +3705,7 @@ REQUIRED_TOP_LEVEL_FIELDS = [ client = OpenAI( api_key=OPENAI_API_KEY, - base_url=OPENAI_BASE_URL + base_url=OPENAI_BASE_URL, ) @@ -2352,78 +3713,145 @@ client = OpenAI( # HELPERS # ========================= -def safe_dir_name(name: str) -> str: +def safe_dir_name(name): name = str(name).strip() name = re.sub(r"[^\w\-.]+", "_", name) return name[:150] -def now_timestamp() -> str: +def now_timestamp(): return datetime.now().strftime("%Y%m%d_%H%M%S") -def get_process(): - if psutil is None: - return None - return psutil.Process(os.getpid()) +def parse_float(value): + """ + Returns: + - raw value + - numeric float value or None + - is_numeric boolean + """ + if value is None or isinstance(value, bool): + return value, None, False + + try: + return value, float(str(value).replace(",", ".")), True + except Exception: + return value, None, False +def is_in_range(value, min_value, max_value): + if value is None: + return False + return min_value <= value <= max_value -def get_memory_rss_mb(process=None): - if psutil is None: - return None - if process is None: - process = get_process() - return process.memory_info().rss / (1024 * 1024) +def build_patient_text(row): + return ( + str(row.get("T_Zusammenfassung", "")) + "\n" + + str(row.get("Diagnosen", "")) + "\n" + + str(row.get("T_KlinBef", "")) + "\n" + + str(row.get("T_Befunde", "")) + ) -def get_cpu_times_sec(process=None): - if psutil is None: - return None - if process is None: - process = get_process() - cpu_times = process.cpu_times() - return cpu_times.user + 
cpu_times.system +# ========================= +# READ EDSS INSTRUCTIONS +# ========================= + +with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: + EDSS_INSTRUCTIONS = f.read().strip() -class ResourceSampler: - def __init__(self, interval_sec=0.05): - self.interval_sec = interval_sec - self.process = get_process() - self.running = False - self.thread = None - self.samples_mb = [] +# ========================= +# PROMPT +# ========================= - def start(self): - if psutil is None: - return +def build_prompt(patient_text): + return f"""Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. - self.running = True - self.samples_mb = [] - self.thread = threading.Thread(target=self._sample_loop, daemon=True) - self.thread.start() +Extrahiere: +1. Ob der Bericht für eine EDSS-Einschätzung klassifizierbar ist. +2. Falls klassifizierbar: Gesamt-EDSS-Score von 0.0 bis 10.0. +3. Alle 8 EDSS-Unterkategorien, soweit ableitbar. +4. Sicherheit als Ganzzahl von 0 bis 100. - def stop(self): - if psutil is None: - return +Antworte ausschließlich mit EINEM validen JSON-Objekt. +Kein Markdown. Keine Code-Fences. Kein Text vor oder nach JSON. +Keine Platzhalter. Kopiere kein Schema. 
- self.running = False - if self.thread is not None: - self.thread.join(timeout=1.0) +Das JSON muss exakt diese Schlüssel enthalten: +- reason +- klassifizierbar +- EDSS +- certainty_percent +- subcategories - def _sample_loop(self): - while self.running: - try: - rss_mb = get_memory_rss_mb(self.process) - self.samples_mb.append(rss_mb) - except Exception: - pass - time.sleep(self.interval_sec) +Die subcategories müssen exakt diese 8 Schlüssel enthalten: +- VISUAL_OPTIC_FUNCTIONS +- BRAINSTEM_FUNCTIONS +- PYRAMIDAL_FUNCTIONS +- CEREBELLAR_FUNCTIONS +- SENSORY_FUNCTIONS +- BOWEL_AND_BLADDER_FUNCTIONS +- CEREBRAL_FUNCTIONS +- AMBULATION - @property - def peak_rss_mb(self): - if not self.samples_mb: - return None - return max(self.samples_mb) +Wichtige Regeln: +- klassifizierbar muss true oder false sein. +- Wenn klassifizierbar=true: + - EDSS muss eine Zahl zwischen 0.0 und 10.0 sein. + - EDSS darf dann niemals null sein. +- Wenn klassifizierbar=false: + - EDSS muss null sein. + - Verwende klassifizierbar=false nur, wenn keine vernünftige EDSS-Einschätzung möglich ist. +- Setze klassifizierbar=true, wenn irgendeine plausible EDSS-Einschätzung aus neurologischem Befund, Gehfähigkeit, Funktionssystemen, Diagnoseverlauf oder explizitem EDSS-Wert ableitbar ist. +- Fehlende einzelne Unterkategorien sind kein Grund für klassifizierbar=false. +- Einzelne Unterkategorien dürfen null sein, wenn sie nicht ausreichend ableitbar sind. +- certainty_percent muss eine Ganzzahl von 0 bis 100 sein. +- reason: maximal 250 Zeichen, Deutsch. 
+ +Valide Beispielausgabe bei klassifizierbarem Bericht: +{{ + "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit.", + "klassifizierbar": true, + "EDSS": 2.0, + "certainty_percent": 90, + "subcategories": {{ + "VISUAL_OPTIC_FUNCTIONS": null, + "BRAINSTEM_FUNCTIONS": null, + "PYRAMIDAL_FUNCTIONS": 1.0, + "CEREBELLAR_FUNCTIONS": 1.0, + "SENSORY_FUNCTIONS": 1.0, + "BOWEL_AND_BLADDER_FUNCTIONS": null, + "CEREBRAL_FUNCTIONS": null, + "AMBULATION": 0.0 + }} +}} + +Valide Beispielausgabe bei nicht klassifizierbarem Bericht: +{{ + "reason": "Keine ausreichenden neurologischen Informationen für eine EDSS-Einschätzung.", + "klassifizierbar": false, + "EDSS": null, + "certainty_percent": 0, + "subcategories": {{ + "VISUAL_OPTIC_FUNCTIONS": null, + "BRAINSTEM_FUNCTIONS": null, + "PYRAMIDAL_FUNCTIONS": null, + "CEREBELLAR_FUNCTIONS": null, + "SENSORY_FUNCTIONS": null, + "BOWEL_AND_BLADDER_FUNCTIONS": null, + "CEREBRAL_FUNCTIONS": null, + "AMBULATION": null + }} +}} + +EDSS-Bewertungsrichtlinien: +{EDSS_INSTRUCTIONS} + +Patientenbericht: +{patient_text} + +Gib ausschließlich das finale JSON-Objekt zurück. 
+""" # ========================= @@ -2432,7 +3860,7 @@ class ResourceSampler: def extract_json_from_text(text): if text is None: - raise ValueError("Model returned empty content: message.content is None") + raise ValueError("Model returned empty content") text = str(text).strip() @@ -2447,7 +3875,6 @@ def extract_json_from_text(text): .strip() ) - # Direct parse try: parsed = json.loads(text) if isinstance(parsed, dict): @@ -2455,7 +3882,6 @@ def extract_json_from_text(text): except json.JSONDecodeError: pass - # Balanced JSON candidates candidates = [] stack = [] start_idx = None @@ -2490,360 +3916,157 @@ def extract_json_from_text(text): candidates.append(text[start_idx:i + 1]) start_idx = None - valid_objects = [] - - for candidate in candidates: - candidate = candidate.strip() - lowered = candidate.lower() - - invalid_markers = [ - "true/false", - "null or", - "oder zahl", - "0.0-6.0", - "0.0-10.0", - "zahl zwischen", - "...", - ] - - if any(marker in lowered for marker in invalid_markers): - continue - + for candidate in reversed(candidates): try: parsed = json.loads(candidate) if isinstance(parsed, dict): - valid_objects.append(parsed) + return parsed except json.JSONDecodeError: continue - for obj in reversed(valid_objects): - if ( - "klassifizierbar" in obj - and "certainty_percent" in obj - and "subcategories" in obj - ): - return obj - - if valid_objects: - return valid_objects[-1] - - stripped = text.strip() - if stripped.startswith("{") and not stripped.endswith("}"): - raise ValueError( - "Model output looks like truncated JSON. " - f"Raw output starts with: {text[:1000]}" - ) - - raise ValueError( - "No valid JSON object found in model output. " - f"Raw output starts with: {text[:1000]}" - ) + raise ValueError(f"No valid JSON object found. 
Raw starts with: {text[:500]}") def extract_message_content(message): - raw_content = getattr(message, "content", None) + content = getattr(message, "content", None) - if raw_content is not None: - return raw_content - - msg_dict = None + if content is not None: + return content try: - msg_dict = message.model_dump() + msg = message.model_dump() except Exception: - try: - msg_dict = dict(message) - except Exception: - msg_dict = None - - if not isinstance(msg_dict, dict): return None for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: - value = msg_dict.get(key) - if value: - return value - - possible_content = msg_dict.get("content") - if isinstance(possible_content, list): - parts = [] - for block in possible_content: - if isinstance(block, dict): - if "text" in block: - parts.append(str(block["text"])) - elif "content" in block: - parts.append(str(block["content"])) - if parts: - return "\n".join(parts).strip() + if msg.get(key): + return msg[key] return None # ========================= -# READ INSTRUCTIONS +# VALIDATION WITHOUT CLIPPING # ========================= -with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: - EDSS_INSTRUCTIONS = f.read().strip() - - -# ========================= -# PROMPT -# ========================= - -def build_prompt(patient_text): - return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. - -Extrahiere: -1. Gesamt-EDSS-Score von 0.0 bis 10.0 -2. Alle 8 EDSS-Unterkategorien -3. Sicherheit als Ganzzahl von 0 bis 100 - -Antworte ausschließlich mit EINEM validen JSON-Objekt. -Kein Markdown. -Keine Code-Fences. -Kein Text vor oder nach JSON. -Keine Platzhalter. -Kopiere kein Schema. 
- -Das JSON muss exakt diese Schlüssel enthalten: -- reason -- klassifizierbar -- EDSS -- certainty_percent -- subcategories - -Die subcategories müssen exakt diese 8 Schlüssel enthalten: -- VISUAL_OPTIC_FUNCTIONS -- BRAINSTEM_FUNCTIONS -- PYRAMIDAL_FUNCTIONS -- CEREBELLAR_FUNCTIONS -- SENSORY_FUNCTIONS -- BOWEL_AND_BLADDER_FUNCTIONS -- CEREBRAL_FUNCTIONS -- AMBULATION - -Werte: -- klassifizierbar: true oder false -- EDSS: Zahl von 0.0 bis 10.0 oder null -- certainty_percent: Ganzzahl von 0 bis 100 -- Unterkategorien: Zahl oder null -- VISUAL_OPTIC_FUNCTIONS maximal 6.0 -- BRAINSTEM_FUNCTIONS maximal 6.0 -- PYRAMIDAL_FUNCTIONS maximal 6.0 -- CEREBELLAR_FUNCTIONS maximal 6.0 -- SENSORY_FUNCTIONS maximal 6.0 -- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 -- CEREBRAL_FUNCTIONS maximal 6.0 -- AMBULATION maximal 10.0 -- reason: maximal 250 Zeichen, Deutsch - -Wenn klassifizierbar false ist, setze EDSS auf null. - -Valide Beispielausgabe: -{{ - "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", - "klassifizierbar": true, - "EDSS": 2.0, - "certainty_percent": 90, - "subcategories": {{ - "VISUAL_OPTIC_FUNCTIONS": null, - "BRAINSTEM_FUNCTIONS": null, - "PYRAMIDAL_FUNCTIONS": 1.0, - "CEREBELLAR_FUNCTIONS": 1.0, - "SENSORY_FUNCTIONS": 1.0, - "BOWEL_AND_BLADDER_FUNCTIONS": null, - "CEREBRAL_FUNCTIONS": null, - "AMBULATION": 0.0 - }} -}} - -EDSS-Bewertungsrichtlinien: -{EDSS_INSTRUCTIONS} - -Patientenbericht: -{patient_text} - -Gib ausschließlich das finale JSON-Objekt zurück. -''' - - -# ========================= -# VALIDATION, NOT NORMALIZATION -# ========================= - -def parse_float_preserve_raw(value): - """ - Try to parse a value as float without clipping or correcting it. 
- - Returns: - raw_value: original value exactly as present in parsed JSON - numeric_value: float or None - is_numeric: bool - """ - raw_value = value - - if value is None: - return raw_value, None, False - - if isinstance(value, bool): - return raw_value, None, False - - try: - numeric_value = float(str(value).replace(",", ".")) - return raw_value, numeric_value, True - except Exception: - return raw_value, None, False - - -def is_in_range(value, min_value, max_value): - """ - Range check without clipping. - """ - if value is None: - return False - return min_value <= value <= max_value - - def validate_model_output(parsed): - """ - Validate parsed model output without repairing/clipping clinical values. - - Important: - - Does NOT clip EDSS. - - Does NOT clip functional system values. - - Does NOT insert default EDSS. - - Does NOT insert default certainty_percent. - - Missing fields are kept as None. - - Adds explicit validity flags for scientific transparency. - """ - - validation = { - "json_parse_success": isinstance(parsed, dict), - "required_fields_present": False, - "required_schema_success": False, - "clinical_range_valid": False, - "certainty_present": False, - - "missing_required_fields": [], - "missing_subcategory_fields": [], - - "EDSS_is_numeric": False, - "EDSS_in_valid_range": False, - } - - if not isinstance(parsed, dict): - return { - "raw_output": parsed, - "validated_output": {}, - "validation": validation, - } - missing_required = [ field for field in REQUIRED_TOP_LEVEL_FIELDS if field not in parsed ] - validation["missing_required_fields"] = missing_required - validation["required_fields_present"] = len(missing_required) == 0 + required_fields_present = len(missing_required) == 0 - validated = {} + klassifizierbar = parsed.get("klassifizierbar", None) + klassifizierbar_is_bool = isinstance(klassifizierbar, bool) - validated["reason"] = parsed.get("reason", None) - validated["klassifizierbar"] = parsed.get("klassifizierbar", None) + 
raw_certainty = parsed.get("certainty_percent") + _, certainty_numeric, certainty_is_numeric = parse_float(raw_certainty) - raw_certainty = parsed.get("certainty_percent", None) - validated["raw_certainty_percent"] = raw_certainty - validation["certainty_present"] = "certainty_percent" in parsed and raw_certainty is not None + raw_edss = parsed.get("EDSS") + raw_edss, edss_numeric, edss_is_numeric = parse_float(raw_edss) - _, certainty_numeric, certainty_is_numeric = parse_float_preserve_raw(raw_certainty) - validated["certainty_percent"] = certainty_numeric if certainty_is_numeric else None - validated["certainty_percent_is_numeric"] = certainty_is_numeric - validated["certainty_percent_in_valid_range"] = ( - is_in_range(certainty_numeric, 0.0, 100.0) - if certainty_is_numeric else False - ) - - raw_edss = parsed.get("EDSS", None) - raw_edss, edss_numeric, edss_is_numeric = parse_float_preserve_raw(raw_edss) - - validated["raw_EDSS"] = raw_edss - validated["EDSS_numeric"] = edss_numeric - validated["EDSS"] = edss_numeric # Backward-compatible; parsed only, not clipped - validated["EDSS_is_numeric"] = edss_is_numeric - validated["EDSS_in_valid_range"] = ( + edss_in_valid_range = ( is_in_range(edss_numeric, EDSS_MIN, EDSS_MAX) if edss_is_numeric else False ) - validation["EDSS_is_numeric"] = validated["EDSS_is_numeric"] - validation["EDSS_in_valid_range"] = validated["EDSS_in_valid_range"] - - raw_subcategories = parsed.get("subcategories", None) - - if isinstance(raw_subcategories, dict): - subcategories = raw_subcategories + if klassifizierbar is True: + edss_logic_valid = edss_is_numeric and edss_in_valid_range + elif klassifizierbar is False: + edss_logic_valid = raw_edss is None else: - subcategories = {} + edss_logic_valid = False - validated["subcategories"] = {} - validated["raw_subcategories"] = {} - validated["subcategory_validation"] = {} + raw_subcats = parsed.get("subcategories") + if not isinstance(raw_subcats, dict): + raw_subcats = {} + 
subcats_numeric = {} + raw_subcats_out = {} + subcat_validation = {} missing_subcats = [] - for subcat, (min_value, max_value) in FUNCTIONAL_SYSTEM_RANGES.items(): - if subcat not in subcategories: - missing_subcats.append(subcat) + for name, (min_value, max_value) in FUNCTIONAL_SYSTEM_RANGES.items(): + if name not in raw_subcats: + missing_subcats.append(name) - raw_value = subcategories.get(subcat, None) - raw_value, numeric_value, is_numeric_value = parse_float_preserve_raw(raw_value) - in_valid_range = ( - is_in_range(numeric_value, min_value, max_value) - if is_numeric_value else False - ) + raw_value = raw_subcats.get(name) + raw_value, numeric_value, is_numeric = parse_float(raw_value) - validated["raw_subcategories"][subcat] = raw_value - validated["subcategories"][subcat] = numeric_value + if raw_value is None: + in_valid_range = True + valid_when_present = True + else: + in_valid_range = ( + is_in_range(numeric_value, min_value, max_value) + if is_numeric else False + ) + valid_when_present = is_numeric and in_valid_range - validated["subcategory_validation"][subcat] = { - "is_numeric": is_numeric_value, + raw_subcats_out[name] = raw_value + subcats_numeric[name] = numeric_value + subcat_validation[name] = { + "is_numeric": is_numeric, "in_valid_range": in_valid_range, - "min_allowed": min_value, - "max_allowed": max_value, + "valid_when_present": valid_when_present, + "is_missing_or_null": raw_value is None, } - validation["missing_subcategory_fields"] = missing_subcats - - subcategory_schema_present = len(missing_subcats) == 0 - - all_subcats_numeric = all( - validated["subcategory_validation"][subcat]["is_numeric"] - for subcat in FUNCTIONAL_SYSTEM_RANGES + all_subcats_valid_when_present = all( + subcat_validation[name]["valid_when_present"] + or subcat_validation[name]["is_missing_or_null"] + for name in FUNCTIONAL_SYSTEM_RANGES ) - all_subcats_in_range = all( - validated["subcategory_validation"][subcat]["in_valid_range"] - for subcat in 
FUNCTIONAL_SYSTEM_RANGES + required_schema_success = ( + required_fields_present + and len(missing_subcats) == 0 + and klassifizierbar_is_bool ) - validated["all_functional_systems_numeric"] = all_subcats_numeric - validated["all_functional_systems_in_valid_range"] = all_subcats_in_range - - validation["clinical_range_valid"] = ( - validated["EDSS_in_valid_range"] - and all_subcats_in_range - ) - - validation["required_schema_success"] = ( - validation["required_fields_present"] - and subcategory_schema_present + clinical_output_valid = ( + required_schema_success + and edss_logic_valid + and all_subcats_valid_when_present ) return { - "raw_output": parsed, - "validated_output": validated, - "validation": validation, + "reason": parsed.get("reason"), + "klassifizierbar": klassifizierbar, + "klassifizierbar_is_bool": klassifizierbar_is_bool, + + "raw_certainty_percent": raw_certainty, + "certainty_percent": certainty_numeric if certainty_is_numeric else None, + "certainty_present": raw_certainty is not None, + "certainty_percent_is_numeric": certainty_is_numeric, + "certainty_percent_in_valid_range": ( + is_in_range(certainty_numeric, 0.0, 100.0) + if certainty_is_numeric else False + ), + + "raw_EDSS": raw_edss, + "EDSS_numeric": edss_numeric, + "EDSS": edss_numeric, + "EDSS_is_numeric": edss_is_numeric, + "EDSS_in_valid_range": edss_in_valid_range, + "edss_logic_valid": edss_logic_valid, + + "json_parse_success": True, + "required_fields_present": required_fields_present, + "required_schema_success": required_schema_success, + + "clinical_range_valid": clinical_output_valid, + "clinical_output_valid": clinical_output_valid, + + "missing_required_fields": missing_required, + "missing_subcategory_fields": missing_subcats, + + "subcategories": subcats_numeric, + "raw_subcategories": raw_subcats_out, + "subcategory_validation": subcat_validation, + + "all_functional_systems_valid_when_present": all_subcats_valid_when_present, } @@ -2852,212 +4075,147 @@ def 
validate_model_output(parsed): # ========================= def make_chat_completion(model_config, prompt): - model_name = model_config["model_name"] - - kwargs = dict( - messages=[ + kwargs = { + "messages": [ { "role": "system", "content": ( "Du bist ein JSON-Generator. " "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " - "Keine Erklärung. Kein Markdown. Keine Code-Fences. " - "Keine Platzhalter. Kein Schema kopieren. " - "Das JSON muss vollständig geschlossen sein." - ) + "Keine Erklärung. Kein Markdown. Keine Code-Fences." + ), }, { "role": "user", - "content": prompt - } + "content": prompt, + }, ], - model=model_name, - max_tokens=model_config.get("max_tokens", MAX_TOKENS), - temperature=model_config.get("temperature", TEMPERATURE), - ) + "model": model_config["model_name"], + "max_tokens": model_config["max_tokens"], + "temperature": model_config["temperature"], + } if model_config.get("use_response_format", False): kwargs["response_format"] = {"type": "json_object"} - extra_body = model_config.get("extra_body") - if extra_body is not None: - kwargs["extra_body"] = extra_body + if model_config.get("extra_body") is not None: + kwargs["extra_body"] = model_config["extra_body"] return client.chat.completions.create(**kwargs) -# ========================= -# INFERENCE FUNCTION WITH RETRIES -# ========================= - def run_inference(patient_text, model_config): + base_prompt = build_prompt(patient_text) + prompt = base_prompt model_name = model_config["model_name"] - prompt = build_prompt(patient_text) - process = get_process() - sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) - - wall_start = time.perf_counter() - cpu_start = get_cpu_times_sec(process) - rss_start_mb = get_memory_rss_mb(process) - - sampler.start() + start = time.perf_counter() raw_content = None raw_response_debug = None - raw_parsed_output = None - validation = None last_error = None prompt_tokens = None completion_tokens = None total_tokens = None - 
try: - for attempt in range(1, MAX_JSON_RETRIES + 2): + for attempt in range(1, MAX_JSON_RETRIES + 2): + try: + response = make_chat_completion(model_config, prompt) + message = response.choices[0].message + raw_content = extract_message_content(message) + try: - response = make_chat_completion( - model_config=model_config, - prompt=prompt + raw_response_debug = response.model_dump() + except Exception: + raw_response_debug = str(response) + + usage = getattr(response, "usage", None) + if usage is not None: + prompt_tokens = getattr(usage, "prompt_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) + total_tokens = getattr(usage, "total_tokens", None) + + parsed = extract_json_from_text(raw_content) + validated = validate_model_output(parsed) + + if not validated.get("clinical_output_valid", False): + raise ValueError( + "Model output is logically invalid. " + f"klassifizierbar={validated.get('klassifizierbar')}, " + f"raw_EDSS={validated.get('raw_EDSS')}, " + f"EDSS_numeric={validated.get('EDSS_numeric')}, " + f"edss_logic_valid={validated.get('edss_logic_valid')}, " + f"missing_required={validated.get('missing_required_fields')}, " + f"missing_subcats={validated.get('missing_subcategory_fields')}" ) - message = response.choices[0].message - raw_content = extract_message_content(message) + return { + "success": True, + "error": None, + "last_error": last_error, + "model": model_name, + "result": validated, + "raw_parsed_output": parsed, + "raw_content": raw_content, + "raw_response_debug": None, + "inference_time_sec": time.perf_counter() - start, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } - try: - raw_response_debug = response.model_dump() - except Exception: - raw_response_debug = str(response) + except Exception as e: + last_error = str(e) - usage = getattr(response, "usage", None) - if usage is not None: - prompt_tokens = getattr(usage, "prompt_tokens", None) - 
completion_tokens = getattr(usage, "completion_tokens", None) - total_tokens = getattr(usage, "total_tokens", None) + if attempt <= MAX_JSON_RETRIES: + prompt = f""" +Deine vorherige Antwort war ungültig. - parsed = extract_json_from_text(raw_content) - validation_package = validate_model_output(parsed) +Fehler: +{last_error} - success = True - error = None +Antworte erneut mit genau EINEM validen JSON-Objekt. - result = validation_package["validated_output"] - validation = validation_package["validation"] - raw_parsed_output = validation_package["raw_output"] +Strikte Regeln: +- Kein Markdown. +- Keine Code-Fences. +- Kein Text außerhalb des JSON. +- klassifizierbar muss true oder false sein. +- Wenn klassifizierbar=true: EDSS muss eine Zahl zwischen 0.0 und 10.0 sein. +- Wenn klassifizierbar=false: EDSS muss null sein. +- Verwende klassifizierbar=false nur, wenn keine vernünftige EDSS-Einschätzung möglich ist. +- Fehlende einzelne Unterkategorien sind kein Grund für klassifizierbar=false. +- Alle 8 subcategories-Schlüssel müssen vorhanden sein. +- Unterkategorien dürfen null sein, wenn nicht ausreichend ableitbar. - break +Ursprüngliche Aufgabe: +{base_prompt} +""" + time.sleep(RETRY_SLEEP_SEC) + continue - except Exception as e: - last_error = str(e) - - if attempt <= MAX_JSON_RETRIES: - print( - f"\n⚠️ JSON failed on attempt {attempt}. " - f"Retrying row. 
Error: {last_error[:300]}" - ) - time.sleep(RETRY_SLEEP_SEC) - continue - - raise - - except Exception as e: - print(f"❌ Inference error: {e}") - - success = False - error = str(e) - result = None - raw_parsed_output = None - - validation = { - "json_parse_success": False, - "required_fields_present": False, - "required_schema_success": False, - "clinical_range_valid": False, - "certainty_present": False, - "missing_required_fields": [], - "missing_subcategory_fields": [], - "EDSS_is_numeric": False, - "EDSS_in_valid_range": False, - } - - finally: - sampler.stop() - - wall_end = time.perf_counter() - cpu_end = get_cpu_times_sec(process) - rss_end_mb = get_memory_rss_mb(process) - - wall_time_sec = wall_end - wall_start - - if cpu_start is not None and cpu_end is not None: - process_cpu_time_sec = cpu_end - cpu_start - else: - process_cpu_time_sec = None - - if rss_start_mb is not None and rss_end_mb is not None: - rss_delta_mb = rss_end_mb - rss_start_mb - else: - rss_delta_mb = None - - return { - "success": success, - "error": error, - "result": result, - - "validation": validation, - "raw_parsed_output": raw_parsed_output, - - "model": model_name, - - "inference_time_sec": wall_time_sec, - - "process_cpu_time_sec": process_cpu_time_sec, - "rss_before_mb": rss_start_mb, - "rss_after_mb": rss_end_mb, - "rss_delta_mb": rss_delta_mb, - "peak_rss_mb": sampler.peak_rss_mb, - - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - - # Keeping raw content improves auditability but can make files large. 
- # To save space, change this to: raw_content if not success else None - "raw_content": raw_content, - "raw_response_debug": raw_response_debug if not success else None, - "last_error": last_error, - } + return { + "success": False, + "error": str(e), + "last_error": last_error, + "model": model_name, + "result": None, + "raw_parsed_output": None, + "raw_content": raw_content, + "raw_response_debug": raw_response_debug, + "inference_time_sec": time.perf_counter() - start, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } # ========================= -# BUILD PATIENT TEXT -# ========================= - -def build_patient_text(row): - return ( - str(row.get("T_Zusammenfassung", "")) + "\n" + - str(row.get("Diagnosen", "")) + "\n" + - str(row.get("T_KlinBef", "")) + "\n" + - str(row.get("T_Befunde", "")) - ) - - -# ========================= -# FLATTEN RESULTS FOR CSV +# FLATTEN / SAVE # ========================= def flatten_result(record): - """ - Flatten one benchmark record for CSV export. 
- - This preserves: - - raw model values - - parsed numeric values without clipping - - validity flags - - backward-compatible columns where possible - """ - - validation = record.get("validation") or {} result = record.get("result") or {} flat = { @@ -3072,281 +4230,68 @@ def flatten_result(record): "error": record.get("error"), "last_error": record.get("last_error"), - "json_parse_success": validation.get("json_parse_success"), - "required_fields_present": validation.get("required_fields_present"), - "required_schema_success": validation.get("required_schema_success"), - "clinical_range_valid": validation.get("clinical_range_valid"), - "certainty_present": validation.get("certainty_present"), - - "missing_required_fields": json.dumps( - validation.get("missing_required_fields", []), - ensure_ascii=False - ), - "missing_subcategory_fields": json.dumps( - validation.get("missing_subcategory_fields", []), - ensure_ascii=False - ), + "json_parse_success": result.get("json_parse_success", False), + "required_fields_present": result.get("required_fields_present", False), + "required_schema_success": result.get("required_schema_success", False), + "clinical_range_valid": result.get("clinical_range_valid", False), + "clinical_output_valid": result.get("clinical_output_valid", False), + "klassifizierbar_is_bool": result.get("klassifizierbar_is_bool"), + "edss_logic_valid": result.get("edss_logic_valid"), + "all_functional_systems_valid_when_present": result.get("all_functional_systems_valid_when_present"), "inference_time_sec": record.get("inference_time_sec"), - "process_cpu_time_sec": record.get("process_cpu_time_sec"), - "rss_before_mb": record.get("rss_before_mb"), - "rss_after_mb": record.get("rss_after_mb"), - "rss_delta_mb": record.get("rss_delta_mb"), - "peak_rss_mb": record.get("peak_rss_mb"), - "prompt_tokens": record.get("prompt_tokens"), "completion_tokens": record.get("completion_tokens"), "total_tokens": record.get("total_tokens"), "raw_content": 
record.get("raw_content"), - "raw_parsed_output": json.dumps(record.get("raw_parsed_output"), ensure_ascii=False), + "raw_parsed_output": json.dumps( + record.get("raw_parsed_output"), + ensure_ascii=False + ), - # Backward-compatible fields "reason": result.get("reason"), "klassifizierbar": result.get("klassifizierbar"), "raw_certainty_percent": result.get("raw_certainty_percent"), "certainty_percent": result.get("certainty_percent"), + "certainty_present": result.get("certainty_present"), "certainty_percent_is_numeric": result.get("certainty_percent_is_numeric"), "certainty_percent_in_valid_range": result.get("certainty_percent_in_valid_range"), - # EDSS raw/numeric/validity fields "raw_EDSS": result.get("raw_EDSS"), "EDSS_numeric": result.get("EDSS_numeric"), - "EDSS": result.get("EDSS"), # backward-compatible; same as EDSS_numeric, not clipped + "EDSS": result.get("EDSS"), "EDSS_is_numeric": result.get("EDSS_is_numeric"), "EDSS_in_valid_range": result.get("EDSS_in_valid_range"), - "all_functional_systems_numeric": result.get("all_functional_systems_numeric"), - "all_functional_systems_in_valid_range": result.get("all_functional_systems_in_valid_range"), + "missing_required_fields": json.dumps( + result.get("missing_required_fields", []), + ensure_ascii=False + ), + "missing_subcategory_fields": json.dumps( + result.get("missing_subcategory_fields", []), + ensure_ascii=False + ), } - raw_subcategories = result.get("raw_subcategories", {}) - numeric_subcategories = result.get("subcategories", {}) - subcat_validation = result.get("subcategory_validation", {}) + raw_subcats = result.get("raw_subcategories", {}) + num_subcats = result.get("subcategories", {}) + subcat_flags = result.get("subcategory_validation", {}) - for subcat in FUNCTIONAL_SYSTEM_RANGES: - raw_value = None - numeric_value = None - is_numeric = False - in_valid_range = False + for name in FUNCTIONAL_SYSTEM_RANGES: + flags = subcat_flags.get(name, {}) - if isinstance(raw_subcategories, dict): - 
raw_value = raw_subcategories.get(subcat) - - if isinstance(numeric_subcategories, dict): - numeric_value = numeric_subcategories.get(subcat) - - if isinstance(subcat_validation, dict): - flags = subcat_validation.get(subcat, {}) - if isinstance(flags, dict): - is_numeric = flags.get("is_numeric", False) - in_valid_range = flags.get("in_valid_range", False) - - # New transparent columns - flat[f"raw_subcat_{subcat}"] = raw_value - flat[f"numeric_subcat_{subcat}"] = numeric_value - flat[f"subcat_{subcat}_is_numeric"] = is_numeric - flat[f"subcat_{subcat}_in_valid_range"] = in_valid_range - - # Backward-compatible old column name. - # This is numeric but NOT clipped. - flat[f"subcat_{subcat}"] = numeric_value + flat[f"raw_subcat_{name}"] = raw_subcats.get(name) + flat[f"numeric_subcat_{name}"] = num_subcats.get(name) + flat[f"subcat_{name}_is_numeric"] = flags.get("is_numeric", False) + flat[f"subcat_{name}_in_valid_range"] = flags.get("in_valid_range", False) + flat[f"subcat_{name}_valid_when_present"] = flags.get("valid_when_present", False) + flat[f"subcat_{name}"] = num_subcats.get(name) return flat -# ========================= -# SUMMARY STATISTICS -# ========================= - -def summarize_records(records): - """ - Create transparent summary statistics per model. 
- - Separates: - - JSON/schema validity - - numeric parse validity - - clinical range validity - - out-of-range outputs - """ - - df = pd.DataFrame([flatten_result(r) for r in records]) - - if df.empty: - return pd.DataFrame() - - def bool_mean(col): - if col not in df.columns: - return None - return df[col].fillna(False).astype(bool).mean() - - def bool_sum(col): - if col not in df.columns: - return None - return int(df[col].fillna(False).astype(bool).sum()) - - n_records = len(df) - - summary = { - "model": df["model"].iloc[0] if "model" in df.columns else None, - "n_total_responses": n_records, - - "n_success": bool_sum("success"), - "success_rate": bool_mean("success"), - - "n_json_parse_success": bool_sum("json_parse_success"), - "json_parse_success_rate": bool_mean("json_parse_success"), - - "n_required_fields_present": bool_sum("required_fields_present"), - "required_fields_present_rate": bool_mean("required_fields_present"), - - "n_required_schema_success": bool_sum("required_schema_success"), - "required_schema_success_rate": bool_mean("required_schema_success"), - - "n_clinical_range_valid": bool_sum("clinical_range_valid"), - "clinical_range_valid_rate": bool_mean("clinical_range_valid"), - - "n_certainty_present": bool_sum("certainty_present"), - "certainty_present_rate": bool_mean("certainty_present"), - - "n_EDSS_numeric": bool_sum("EDSS_is_numeric"), - "EDSS_numeric_rate": bool_mean("EDSS_is_numeric"), - - "n_EDSS_in_valid_range": bool_sum("EDSS_in_valid_range"), - "EDSS_valid_range_rate": bool_mean("EDSS_in_valid_range"), - } - - # EDSS out-of-range among numeric EDSS outputs - if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: - edss_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) - edss_valid = df["EDSS_in_valid_range"].fillna(False).astype(bool) - edss_out_of_range = edss_numeric & (~edss_valid) - - summary["n_EDSS_out_of_range"] = int(edss_out_of_range.sum()) - summary["EDSS_out_of_range_rate_total"] = 
float(edss_out_of_range.mean()) - summary["EDSS_out_of_range_rate_among_numeric"] = ( - float(edss_out_of_range.sum() / edss_numeric.sum()) - if edss_numeric.sum() > 0 else None - ) - - # Functional system rates - fs_out_of_range_any = pd.Series(False, index=df.index) - fs_valid_all = pd.Series(True, index=df.index) - - for subcat in FUNCTIONAL_SYSTEM_RANGES: - numeric_col = f"subcat_{subcat}_is_numeric" - valid_col = f"subcat_{subcat}_in_valid_range" - - if numeric_col in df.columns: - numeric_series = df[numeric_col].fillna(False).astype(bool) - else: - numeric_series = pd.Series(False, index=df.index) - - if valid_col in df.columns: - valid_series = df[valid_col].fillna(False).astype(bool) - else: - valid_series = pd.Series(False, index=df.index) - - out_of_range_series = numeric_series & (~valid_series) - - summary[f"n_{subcat}_numeric"] = int(numeric_series.sum()) - summary[f"{subcat}_numeric_rate"] = float(numeric_series.mean()) - - summary[f"n_{subcat}_in_valid_range"] = int(valid_series.sum()) - summary[f"{subcat}_valid_range_rate"] = float(valid_series.mean()) - - summary[f"n_{subcat}_out_of_range"] = int(out_of_range_series.sum()) - summary[f"{subcat}_out_of_range_rate_total"] = float(out_of_range_series.mean()) - summary[f"{subcat}_out_of_range_rate_among_numeric"] = ( - float(out_of_range_series.sum() / numeric_series.sum()) - if numeric_series.sum() > 0 else None - ) - - fs_out_of_range_any = fs_out_of_range_any | out_of_range_series - fs_valid_all = fs_valid_all & valid_series - - summary["n_any_functional_system_out_of_range"] = int(fs_out_of_range_any.sum()) - summary["any_functional_system_out_of_range_rate_total"] = float(fs_out_of_range_any.mean()) - - summary["n_all_functional_systems_in_valid_range"] = int(fs_valid_all.sum()) - summary["all_functional_systems_valid_range_rate"] = float(fs_valid_all.mean()) - - numeric_cols = [ - "inference_time_sec", - "process_cpu_time_sec", - "rss_delta_mb", - "peak_rss_mb", - "prompt_tokens", - 
"completion_tokens", - "total_tokens", - "certainty_percent", - "EDSS_numeric", - ] - - for col in numeric_cols: - if col in df.columns: - values = pd.to_numeric(df[col], errors="coerce") - summary[f"{col}_mean"] = values.mean() - summary[f"{col}_median"] = values.median() - summary[f"{col}_std"] = values.std() - summary[f"{col}_min"] = values.min() - summary[f"{col}_max"] = values.max() - - if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: - primary_valid_only = ( - df["EDSS_is_numeric"].fillna(False).astype(bool) - & df["EDSS_in_valid_range"].fillna(False).astype(bool) - ) - - sensitivity_all_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) - - summary["n_primary_valid_only_EDSS"] = int(primary_valid_only.sum()) - summary["primary_valid_only_EDSS_rate"] = float(primary_valid_only.mean()) - - summary["n_sensitivity_all_numeric_EDSS"] = int(sensitivity_all_numeric.sum()) - summary["sensitivity_all_numeric_EDSS_rate"] = float(sensitivity_all_numeric.mean()) - - return pd.DataFrame([summary]) - - -# ========================= -# ANALYSIS DATASET HELPERS -# ========================= - -def create_analysis_datasets(records): - """ - Create two transparent EDSS analysis datasets: - - 1. primary_valid_only: - Only numeric EDSS predictions within the valid clinical range. - - 2. sensitivity_all_numeric: - All numeric EDSS predictions, including out-of-range values. - No clipping is applied. 
- """ - - df = pd.DataFrame([flatten_result(r) for r in records]) - - if df.empty: - return df.copy(), df.copy() - - primary_valid_only = df[ - df["EDSS_is_numeric"].fillna(False).astype(bool) - & df["EDSS_in_valid_range"].fillna(False).astype(bool) - ].copy() - - sensitivity_all_numeric = df[ - df["EDSS_is_numeric"].fillna(False).astype(bool) - ].copy() - - return primary_valid_only, sensitivity_all_numeric - - -# ========================= -# INCREMENTAL SAVE HELPERS -# ========================= - def append_jsonl(path, record): with open(path, "a", encoding="utf-8") as f: f.write(json.dumps(record, ensure_ascii=False) + "\n") @@ -3355,14 +4300,79 @@ def append_jsonl(path, record): def append_csv(path, record): - flat = flatten_result(record) - df_one = pd.DataFrame([flat]) + one = pd.DataFrame([flatten_result(record)]) file_exists = Path(path).exists() - df_one.to_csv(path, mode="a", header=not file_exists, index=False) + one.to_csv(path, mode="a", header=not file_exists, index=False) + + +def save_json(path, records): + with open(path, "w", encoding="utf-8") as f: + json.dump(records, f, indent=2, ensure_ascii=False) + + +def save_csv(path, records): + df = pd.DataFrame([flatten_result(r) for r in records]) + df.to_csv(path, index=False) + + +def summarize_records(records): + df = pd.DataFrame([flatten_result(r) for r in records]) + + if df.empty: + return pd.DataFrame() + + def rate(col): + if col not in df.columns: + return None + return df[col].fillna(False).astype(bool).mean() + + summary = { + "model": df["model"].iloc[0], + "n_total": len(df), + "success_rate": rate("success"), + "json_parse_success_rate": rate("json_parse_success"), + "required_schema_success_rate": rate("required_schema_success"), + "clinical_output_valid_rate": rate("clinical_output_valid"), + "edss_logic_valid_rate": rate("edss_logic_valid"), + "EDSS_numeric_rate": rate("EDSS_is_numeric"), + "EDSS_valid_range_rate": rate("EDSS_in_valid_range"), + "klassifizierbar_true_rate": ( + 
df["klassifizierbar"].fillna(False).astype(bool).mean() + if "klassifizierbar" in df.columns + else None + ), + "mean_inference_time_sec": pd.to_numeric( + df["inference_time_sec"], + errors="coerce" + ).mean(), + } + + return pd.DataFrame([summary]) # ========================= -# MAIN LOOP +# PARALLEL ROW PROCESSING +# ========================= + +def process_one_row(payload, model_config, iteration): + idx, row_dict, row_number, total_rows = payload + + row = pd.Series(row_dict) + patient_text = build_patient_text(row) + + record = run_inference(patient_text, model_config) + + record["iteration"] = iteration + record["row_index"] = int(idx) + record["row_number_in_run"] = int(row_number) + record["unique_id"] = row.get("unique_id", f"row_{idx}") + record["MedDatum"] = row.get("MedDatum", None) + + return record + + +# ========================= +# MAIN # ========================= if __name__ == "__main__": @@ -3370,8 +4380,6 @@ if __name__ == "__main__": run_timestamp = now_timestamp() results_root = Path(RESULTS_ROOT) - results_root.mkdir(parents=True, exist_ok=True) - run_root = results_root / f"run_{run_timestamp}" run_root.mkdir(parents=True, exist_ok=True) @@ -3384,13 +4392,18 @@ if __name__ == "__main__": total_rows = len(df) - model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] + print(f"Loaded rows: {total_rows}") + print(f"Parallel workers: {PARALLEL_WORKERS}") + print(f"Batch size: {BATCH_SIZE}") + print(f"Iterations: {NUM_ITERATIONS}") + print(f"Models: {[m['model_name'] for m in MODEL_CONFIGS]}") - print(f"Loaded {total_rows} patient records.") - print(f"Models: {model_names_for_print}") - print(f"Iterations per model: {NUM_ITERATIONS}") + row_payloads = [ + (idx, row.to_dict(), row_number, total_rows) + for row_number, (idx, row) in enumerate(df.iterrows(), start=1) + ] - all_model_summaries = [] + all_summaries = [] for model_config in MODEL_CONFIGS: model_name = model_config["model_name"] @@ -3399,203 +4412,130 @@ if __name__ == 
"__main__": model_dir = run_root / safe_model model_dir.mkdir(parents=True, exist_ok=True) - print(f"\n{'#' * 80}") + print("\n" + "#" * 80) print(f"MODEL: {model_name}") - print(f"use_response_format: {model_config.get('use_response_format', False)}") - print(f"temperature: {model_config.get('temperature', TEMPERATURE)}") - print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") print(f"Saving to: {model_dir}") - print(f"{'#' * 80}") + print("#" * 80) model_records = [] model_start = time.perf_counter() for iteration in range(1, NUM_ITERATIONS + 1): - print(f"\n{'=' * 60}") - print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") - print(f"{'=' * 60}") + print("\n" + "=" * 80) + print(f"MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") + print("=" * 80) - iteration_results = [] iteration_start = time.perf_counter() + iteration_results = [] incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" - print(f"Incremental JSONL: {incremental_jsonl_path}") - print(f"Incremental CSV: {incremental_csv_path}") + completed = 0 + + for batch_start in range(0, len(row_payloads), BATCH_SIZE): + batch = row_payloads[batch_start:batch_start + BATCH_SIZE] - for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): print( - f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", - end="", - flush=True + f"\nSubmitting rows {batch_start + 1}-" + f"{batch_start + len(batch)} / {total_rows}" ) - try: - patient_text = build_patient_text(row) + with ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as executor: + futures = [ + executor.submit( + process_one_row, + payload, + model_config, + iteration + ) + for payload in batch + ] - record = run_inference( - patient_text=patient_text, - model_config=model_config - ) + for future in as_completed(futures): + record = future.result() - 
record["iteration"] = iteration - record["row_index"] = int(idx) - record["row_number_in_run"] = int(loop_i) - record["unique_id"] = row.get("unique_id", f"row_{idx}") - record["MedDatum"] = row.get("MedDatum", None) + completed += 1 + iteration_results.append(record) + model_records.append(record) - iteration_results.append(record) - model_records.append(record) - - if loop_i % SAVE_EVERY_N_ROWS == 0: append_jsonl(incremental_jsonl_path, record) append_csv(incremental_csv_path, record) - if record["success"]: - res = record["result"] or {} - edss_display = res.get("EDSS_numeric", None) - edss_valid = res.get("EDSS_in_valid_range", False) + if record["success"]: + result = record.get("result") or {} - print( - f" ✅ EDSS={edss_display}, " - f"valid_range={edss_valid}, " - f"time={record['inference_time_sec']:.2f}s" - ) - else: - print(f" ❌ {record.get('error', 'Unknown error')}") + print( + f"Done {completed}/{total_rows} | " + f"row={record.get('row_number_in_run')} | " + f"klassifizierbar={result.get('klassifizierbar')} | " + f"EDSS={result.get('EDSS_numeric')} | " + f"edss_logic_valid={result.get('edss_logic_valid')} | " + f"clinical_output_valid={result.get('clinical_output_valid')} | " + f"time={record.get('inference_time_sec'):.2f}s" + ) - except Exception as e: - print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") + else: + print( + f"Done {completed}/{total_rows} | " + f"row={record.get('row_number_in_run')} | " + f"ERROR={record.get('error')}" + ) - fallback_record = { - "success": False, - "error": str(e), - "last_error": str(e), - "result": None, + if STOP_ON_FIRST_ERROR: + raise RuntimeError(record.get("error")) - "validation": { - "json_parse_success": False, - "required_fields_present": False, - "required_schema_success": False, - "clinical_range_valid": False, - "certainty_present": False, - "missing_required_fields": [], - "missing_subcategory_fields": [], - "EDSS_is_numeric": False, - "EDSS_in_valid_range": False, - }, - 
"raw_parsed_output": None, - - "model": model_name, - "iteration": iteration, - "row_index": int(idx), - "row_number_in_run": int(loop_i), - "unique_id": row.get("unique_id", f"row_{idx}"), - "MedDatum": row.get("MedDatum", None), - - "inference_time_sec": None, - "process_cpu_time_sec": None, - "rss_before_mb": None, - "rss_after_mb": None, - "rss_delta_mb": None, - "peak_rss_mb": None, - - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - - "raw_content": None, - "raw_response_debug": None, - } - - iteration_results.append(fallback_record) - model_records.append(fallback_record) - - append_jsonl(incremental_jsonl_path, fallback_record) - append_csv(incremental_csv_path, fallback_record) - - if STOP_ON_FIRST_ERROR: - break - - iteration_elapsed = time.perf_counter() - iteration_start - - # Final full per-iteration JSON - iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" - with open(iter_json_path, "w", encoding="utf-8") as f: - json.dump(iteration_results, f, indent=2, ensure_ascii=False) - - # Final full per-iteration CSV - iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" - iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) - iter_flat_df.to_csv(iter_csv_path, index=False) - - # Transparent analysis datasets - primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(iteration_results) - - primary_valid_only_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_primary_valid_only.csv" - sensitivity_all_numeric_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_sensitivity_all_numeric.csv" - - primary_valid_only_df.to_csv(primary_valid_only_path, index=False) - sensitivity_all_numeric_df.to_csv(sensitivity_all_numeric_path, index=False) - - print(f"\n✅ Iteration {iteration} complete.") - print(f"Incremental JSONL saved to: {incremental_jsonl_path}") - 
print(f"Incremental CSV saved to: {incremental_csv_path}") - print(f"Final JSON saved to: {iter_json_path}") - print(f"Final CSV saved to: {iter_csv_path}") - print(f"Primary valid-only CSV saved to: {primary_valid_only_path}") - print(f"Sensitivity all-numeric CSV: {sensitivity_all_numeric_path}") - print( - f"⏱️ Iteration time: {iteration_elapsed:.1f}s " - f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" + iteration_results = sorted( + iteration_results, + key=lambda r: r.get("row_number_in_run", 10**9) ) - model_elapsed = time.perf_counter() - model_start + iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" + iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" + + save_json(iter_json_path, iteration_results) + save_csv(iter_csv_path, iteration_results) + + elapsed = time.perf_counter() - iteration_start + + print(f"\nIteration {iteration} complete.") + print(f"JSON: {iter_json_path}") + print(f"CSV: {iter_csv_path}") + print(f"Time: {elapsed / 60:.2f} min") + + model_records = sorted( + model_records, + key=lambda r: ( + r.get("iteration", 10**9), + r.get("row_number_in_run", 10**9) + ) + ) - # Save all records for this model model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" - with open(model_json_path, "w", encoding="utf-8") as f: - json.dump(model_records, f, indent=2, ensure_ascii=False) - model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" - model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) - model_flat_df.to_csv(model_csv_path, index=False) - - # Save model-level analysis datasets - primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(model_records) - - model_primary_valid_only_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_primary_valid_only.csv" - model_sensitivity_all_numeric_path = model_dir / 
f"{safe_model}_all_results_{run_timestamp}_sensitivity_all_numeric.csv" - - primary_valid_only_df.to_csv(model_primary_valid_only_path, index=False) - sensitivity_all_numeric_df.to_csv(model_sensitivity_all_numeric_path, index=False) - - # Save model summary - model_summary_df = summarize_records(model_records) - model_summary_df["model_total_wall_time_sec"] = model_elapsed - model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 - model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" - model_summary_df.to_csv(model_summary_path, index=False) - all_model_summaries.append(model_summary_df) + save_json(model_json_path, model_records) + save_csv(model_csv_path, model_records) - print(f"\n🎉 Model completed: {model_name}") - print(f"All JSON: {model_json_path}") - print(f"All CSV: {model_csv_path}") - print(f"All primary valid-only CSV: {model_primary_valid_only_path}") - print(f"All sensitivity all-numeric CSV: {model_sensitivity_all_numeric_path}") - print(f"Summary: {model_summary_path}") - print(f"Total model time: {model_elapsed / 60:.2f} min") + summary_df = summarize_records(model_records) + summary_df["model_total_wall_time_sec"] = time.perf_counter() - model_start + summary_df["model_total_wall_time_min"] = summary_df["model_total_wall_time_sec"] / 60 + summary_df.to_csv(model_summary_path, index=False) - if all_model_summaries: - combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) - combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" - combined_summary_df.to_csv(combined_summary_path, index=False) + all_summaries.append(summary_df) - print(f"\n📊 Combined summary saved to: {combined_summary_path}") + print(f"\nModel complete: {model_name}") + print(f"All JSON: {model_json_path}") + print(f"All CSV: {model_csv_path}") + print(f"Summary: {model_summary_path}") - print(f"\n🎉 All models and all iterations completed!") + if all_summaries: + combined = pd.concat(all_summaries, 
ignore_index=True) + combined_path = run_root / f"all_models_summary_{run_timestamp}.csv" + combined.to_csv(combined_path, index=False) + print(f"\nCombined summary: {combined_path}") + + print("\nAll models and iterations completed.") ## diff --git a/scripts/show_plots.py b/scripts/show_plots.py index 384cbaf..96b550d 100644 --- a/scripts/show_plots.py +++ b/scripts/show_plots.py @@ -3119,16 +3119,15 @@ plt.show() ## -# %% Confusion matrix for one EDSS benchmark result file +# %% Confusion matrices for iteration 1 of each model with shared color scale -import os from pathlib import Path +import re import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns - from sklearn.metrics import confusion_matrix, classification_report @@ -3136,25 +3135,39 @@ from sklearn.metrics import confusion_matrix, classification_report # CONFIGURATION # ========================= -REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) -RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv" +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) -OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices" +OUTPUT_DIR = RUN_DIR / "confusion_matrices_iter_1" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) TARGET_ITERATION = 1 -MERGE_KEY = "unique_id" - -# Ground truth EDSS column in the reference file GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" -# Predicted EDSS column in the result file -PRED_EDSS_COL = "EDSS" +# If you want to manually force a color maximum, set this to a number, e.g. 80. 
+# If None, the script uses the largest cell count across all model confusion matrices. +MANUAL_GLOBAL_VMAX = None EDSS_LABELS = [ - "0-1", "1-2", "2-3", "3-4", "4-5", - "5-6", "6-7", "7-8", "8-9", "9-10" + r"$0 \leq x \leq 1$", + r"$1 < x \leq 2$", + r"$2 < x \leq 3$", + r"$3 < x \leq 4$", + r"$4 < x \leq 5$", + r"$5 < x \leq 6$", + r"$6 < x \leq 7$", + r"$7 < x \leq 8$", + r"$8 < x \leq 9$", + r"$9 < x \leq 10$", ] @@ -3162,244 +3175,9792 @@ EDSS_LABELS = [ # HELPERS # ========================= -def safe_filename(name): - return ( - str(name) - .replace("/", "_") - .replace("\\", "_") - .replace(" ", "_") - .replace(":", "_") - ) - - -def parse_numeric_column(series): +def to_num(s): return pd.to_numeric( - series.astype(str).str.replace(",", ".", regex=False), + s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) -def categorize_edss(value): - if pd.isna(value): - return np.nan - elif value <= 1.0: - return "0-1" - elif value <= 2.0: - return "1-2" - elif value <= 3.0: - return "2-3" - elif value <= 4.0: - return "3-4" - elif value <= 5.0: - return "4-5" - elif value <= 6.0: - return "5-6" - elif value <= 7.0: - return "6-7" - elif value <= 8.0: - return "7-8" - elif value <= 9.0: - return "8-9" - elif value <= 10.0: - return "9-10" - else: +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def safe_name(name): + return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name)) + + +def categorize_edss(x): + if pd.isna(x): return np.nan + if x <= 1: + return EDSS_LABELS[0] + if x <= 2: + return EDSS_LABELS[1] + if x <= 3: + return EDSS_LABELS[2] + if x <= 4: + return EDSS_LABELS[3] + if x <= 5: + return EDSS_LABELS[4] + if x <= 6: + return EDSS_LABELS[5] + if x <= 7: + return EDSS_LABELS[6] + if x <= 8: + return EDSS_LABELS[7] + if x <= 9: + return EDSS_LABELS[8] + if x <= 10: + return EDSS_LABELS[9] + return np.nan -def load_reference(reference_path): - df_ref = pd.read_csv(reference_path, sep=";") +def 
find_iter_file(model_dir): + files = sorted(model_dir.glob(f"*results_iter_{TARGET_ITERATION}_*.csv")) - if MERGE_KEY not in df_ref.columns: - raise ValueError(f"Reference file does not contain column: {MERGE_KEY}") + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] - if GT_EDSS_COL not in df_ref.columns: - raise ValueError(f"Reference file does not contain column: {GT_EDSS_COL}") - - df_ref = df_ref.copy() - df_ref[MERGE_KEY] = df_ref[MERGE_KEY].astype(str) - - df_ref["GT_EDSS_numeric"] = parse_numeric_column(df_ref[GT_EDSS_COL]) - df_ref["GT_EDSS_cat"] = df_ref["GT_EDSS_numeric"].apply(categorize_edss) - - return df_ref + return files[0] if files else None -def load_result(result_path): - df_res = pd.read_csv(result_path, sep=",") - - if MERGE_KEY not in df_res.columns: - raise ValueError(f"Result file does not contain column: {MERGE_KEY}") - - if PRED_EDSS_COL not in df_res.columns: - raise ValueError(f"Result file does not contain column: {PRED_EDSS_COL}") - - df_res = df_res.copy() - df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str) - - if "success" in df_res.columns: - df_res = df_res[ - df_res["success"].astype(str).str.lower().isin(["true", "1", "yes"]) - ] - - if TARGET_ITERATION is not None and "iteration" in df_res.columns: - df_res = df_res[df_res["iteration"] == TARGET_ITERATION] - - df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL]) - df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss) - - return df_res +def get_model_name(pred, model_dir): + if "model" in pred.columns and pred["model"].notna().any(): + return str(pred["model"].dropna().iloc[0]) + return model_dir.name -def get_model_name(df_res, result_path): - if "model" in df_res.columns and df_res["model"].notna().any(): - return str(df_res["model"].dropna().iloc[0]) +# ========================= +# LOAD GROUND TRUTH +# ========================= - 
return Path(result_path).stem +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index +gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) +gt["GT_EDSS_cat"] = gt["GT_EDSS_numeric"].apply(categorize_edss) + +print(f"GT rows: {len(gt)}") +print(f"GT numeric EDSS rows: {gt['GT_EDSS_numeric'].notna().sum()}") -def plot_confusion_matrix(cm, model_name, output_path): - plt.figure(figsize=(10, 8)) +# ========================= +# FIRST PASS: COMPUTE ALL CONFUSION MATRICES +# ========================= + +model_results = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() and p.name != OUTPUT_DIR.name +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir) + + if result_file is None: + print(f"\nNo iteration {TARGET_ITERATION} result CSV found in: {model_dir}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred_raw = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred_raw.columns: + print("Skipping: row_index column missing.") + continue + + model_name = get_model_name(pred_raw, model_dir) + safe_model = safe_name(model_name) + + pred = pred_raw.copy() + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + raw_rows = len(pred) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + # For confusion matrix, use only rows where model produced numeric EDSS in valid range. 
+ if "EDSS_is_numeric" in pred.columns: + pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() + + if "EDSS_in_valid_range" in pred.columns: + pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + pred["PRED_EDSS_cat"] = pred["PRED_EDSS_numeric"].apply(categorize_edss) + + pred = pred.dropna(subset=["PRED_EDSS_numeric", "PRED_EDSS_cat"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt.merge( + pred, + on="row_index", + how="inner", + suffixes=("_gt", "_pred") + ) + + eval_df = merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy() + + print(f"Raw prediction rows: {raw_rows}") + print(f"Prediction rows after filters: {len(pred)}") + print(f"Merged rows: {len(merged)}") + print(f"Evaluable rows: {len(eval_df)}") + + if eval_df.empty: + print("No evaluable rows. Skipping.") + continue + + cm = confusion_matrix( + eval_df["GT_EDSS_cat"], + eval_df["PRED_EDSS_cat"], + labels=EDSS_LABELS + ) + + report = classification_report( + eval_df["GT_EDSS_cat"], + eval_df["PRED_EDSS_cat"], + labels=EDSS_LABELS, + zero_division=0 + ) + + cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS) + cm_df.index.name = "Ground Truth EDSS" + cm_df.columns.name = "LLM Generated EDSS" + + print("\nClassification Report:") + print(report) + + print("\nConfusion Matrix:") + print(cm_df) + + model_results.append({ + "model_name": model_name, + "safe_model": safe_model, + "model_dir": model_dir, + "result_file": result_file, + "raw_rows": raw_rows, + "pred_rows_after_filters": len(pred), + "merged_rows": len(merged), + "evaluable_rows": len(eval_df), + "cm": cm, + "cm_df": cm_df, + "report": report, + "eval_df": eval_df, + }) + + +if not model_results: + raise RuntimeError("No confusion matrices were computed. 
Check paths and result files.") + + +# ========================= +# SHARED COLOR SCALE +# ========================= + +if MANUAL_GLOBAL_VMAX is not None: + GLOBAL_VMAX = MANUAL_GLOBAL_VMAX +else: + GLOBAL_VMAX = max(item["cm"].max() for item in model_results) + +print("\n" + "=" * 100) +print(f"Shared heatmap color scale: vmin=0, vmax={GLOBAL_VMAX}") +print("=" * 100) + + +# ========================= +# SECOND PASS: SAVE PLOTS AND FILES +# ========================= + +summaries = [] + +for item in model_results: + model_name = item["model_name"] + safe_model = item["safe_model"] + result_file = item["result_file"] + cm = item["cm"] + cm_df = item["cm_df"] + report = item["report"] + eval_df = item["eval_df"] + + svg_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.svg" + png_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.png" + csv_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.csv" + report_path = OUTPUT_DIR / f"{safe_model}_classification_report_iter_{TARGET_ITERATION}.txt" + merged_path = OUTPUT_DIR / f"{safe_model}_merged_eval_rows_iter_{TARGET_ITERATION}.csv" + + plt.figure(figsize=(11, 9)) ax = sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", + vmin=0, + vmax=GLOBAL_VMAX, xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS ) - cbar = ax.collections[0].colorbar - cbar.set_label("Number of Cases", rotation=270, labelpad=20) + ax.collections[0].colorbar.set_label( + "Number of Cases", + rotation=270, + labelpad=20 + ) plt.xlabel("LLM Generated EDSS") plt.ylabel("Ground Truth EDSS") plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}") - + plt.xticks(rotation=45, ha="right") + plt.yticks(rotation=0) plt.tight_layout() - plt.savefig(output_path, dpi=300, bbox_inches="tight") + + plt.savefig(svg_path, format="svg", bbox_inches="tight") + plt.savefig(png_path, dpi=300, bbox_inches="tight") plt.show() + cm_df.to_csv(csv_path) + + with open(report_path, "w", 
encoding="utf-8") as f: + f.write(f"Model: {model_name}\n") + f.write(f"Result file: {result_file}\n") + f.write(f"Iteration: {TARGET_ITERATION}\n") + f.write(f"Shared color scale vmax: {GLOBAL_VMAX}\n") + f.write(f"Raw prediction rows: {item['raw_rows']}\n") + f.write(f"Prediction rows after filters: {item['pred_rows_after_filters']}\n") + f.write(f"Merged rows: {item['merged_rows']}\n") + f.write(f"Evaluable rows: {item['evaluable_rows']}\n\n") + f.write("Classification Report:\n") + f.write(report) + f.write("\n\nConfusion Matrix:\n") + f.write(cm_df.to_string()) + + keep_cols = [ + "row_index", + "unique_id_gt" if "unique_id_gt" in eval_df.columns else "unique_id", + "unique_id_pred" if "unique_id_pred" in eval_df.columns else None, + "MedDatum_gt" if "MedDatum_gt" in eval_df.columns else "MedDatum", + "MedDatum_pred" if "MedDatum_pred" in eval_df.columns else None, + "model", + "iteration", + "success", + "klassifizierbar", + "clinical_output_valid", + "edss_logic_valid", + "GT_EDSS_numeric", + "PRED_EDSS_numeric", + "GT_EDSS_cat", + "PRED_EDSS_cat", + "raw_EDSS", + "EDSS_numeric", + "EDSS_in_valid_range", + "certainty_percent", + "reason", + "inference_time_sec", + ] + + keep_cols = [ + c for c in keep_cols + if c is not None and c in eval_df.columns + ] + + eval_df[keep_cols].to_csv(merged_path, index=False) + + print("\nSaved:") + print(svg_path) + print(png_path) + print(csv_path) + print(report_path) + print(merged_path) + + summaries.append({ + "model": model_name, + "result_file": str(result_file), + "iteration": TARGET_ITERATION, + "raw_prediction_rows": item["raw_rows"], + "prediction_rows_after_filters": item["pred_rows_after_filters"], + "merged_rows": item["merged_rows"], + "evaluable_rows": item["evaluable_rows"], + "shared_color_vmax": GLOBAL_VMAX, + "svg_path": str(svg_path), + "png_path": str(png_path), + "csv_path": str(csv_path), + "report_path": str(report_path), + "merged_path": str(merged_path), + }) + + +# ========================= +# 
SAVE SUMMARY +# ========================= + +summary_df = pd.DataFrame(summaries) +summary_path = OUTPUT_DIR / f"confusion_matrix_summary_iter_{TARGET_ITERATION}.csv" +summary_df.to_csv(summary_path, index=False) + +print("\n" + "=" * 100) +print("Done.") +print(f"Summary saved to: {summary_path}") +print(f"Shared color scale vmax: {GLOBAL_VMAX}") +print("=" * 100) +## + +# %% EDSS metrics across models for new benchmark run + +from pathlib import Path + +import pandas as pd +import numpy as np +from sklearn.metrics import mean_absolute_error, mean_squared_error, cohen_kappa_score +from scipy.stats import spearmanr + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_PATH = RUN_DIR / f"edss_metrics_iter_{TARGET_ITERATION}.csv" + +GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def find_iter_file(model_dir, target_iteration): + files = sorted(model_dir.glob(f"*results_iter_{target_iteration}_*.csv")) + + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + + return files[0] if files else None + + +def safe_rate(numerator, denominator): + if denominator == 0: + return np.nan + return numerator / denominator + + +# ========================= +# LOAD GROUND TRUTH +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") 
+gt["row_index"] = gt.index +gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) + +n_total_gt_rows = len(gt) +n_gt_numeric = gt["GT_EDSS_numeric"].notna().sum() + +gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy() + +print(f"GT rows: {n_total_gt_rows}") +print(f"GT numeric EDSS rows: {n_gt_numeric}") + + +# ========================= +# EVALUATE MODELS +# ========================= + +rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() and not p.name.startswith("confusion") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"\nNo iter_{TARGET_ITERATION} result file found for {model_dir.name}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred_raw = pd.read_csv(result_file, sep=",") + raw_prediction_rows = len(pred_raw) + + if "row_index" not in pred_raw.columns: + print("Skipping: row_index missing") + continue + + model_name = ( + pred_raw["model"].dropna().iloc[0] + if "model" in pred_raw.columns and pred_raw["model"].notna().any() + else model_dir.name + ) + + # ------------------------- + # Diagnostics before filters + # ------------------------- + n_success = to_bool(pred_raw["success"]).sum() if "success" in pred_raw.columns else np.nan + n_clinical_output_valid = ( + to_bool(pred_raw["clinical_output_valid"]).sum() + if "clinical_output_valid" in pred_raw.columns else np.nan + ) + n_edss_logic_valid = ( + to_bool(pred_raw["edss_logic_valid"]).sum() + if "edss_logic_valid" in pred_raw.columns else np.nan + ) + n_klassifizierbar_true = ( + to_bool(pred_raw["klassifizierbar"]).sum() + if "klassifizierbar" in pred_raw.columns else np.nan + ) + n_edss_numeric = ( + to_bool(pred_raw["EDSS_is_numeric"]).sum() + if "EDSS_is_numeric" in pred_raw.columns else np.nan + ) + n_edss_valid_range = ( + to_bool(pred_raw["EDSS_in_valid_range"]).sum() + if "EDSS_in_valid_range" in 
pred_raw.columns else np.nan + ) + + print("Raw prediction rows:", raw_prediction_rows) + print("success=True:", n_success) + print("clinical_output_valid=True:", n_clinical_output_valid) + print("edss_logic_valid=True:", n_edss_logic_valid) + print("klassifizierbar=True:", n_klassifizierbar_true) + print("EDSS_is_numeric=True:", n_edss_numeric) + print("EDSS_in_valid_range=True:", n_edss_valid_range) + print("unique row_index:", pred_raw["row_index"].nunique()) + print("GT numeric EDSS:", n_gt_numeric) + + # ------------------------- + # Prepare predictions + # ------------------------- + pred = pred_raw.copy() + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + # For EDSS score accuracy, use only predictions where model actually gave a numeric EDSS. + # This automatically excludes valid abstentions with klassifizierbar=false and EDSS=null. + if "EDSS_is_numeric" in pred.columns: + pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() + + if "EDSS_in_valid_range" in pred.columns: + pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + + pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + n_after_filtering = len(pred) + + merged = gt_numeric.merge( + pred, + on="row_index", + how="inner", + suffixes=("_gt", "_pred") + ) + + n_evaluable = len(merged) + + if n_evaluable == 0: + print("No evaluable rows. 
Skipping metrics.") + continue + + # ------------------------- + # Metrics + # ------------------------- + merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] + merged["abs_error"] = merged["error"].abs() + + mae = mean_absolute_error( + merged["GT_EDSS_numeric"], + merged["PRED_EDSS_numeric"] + ) + + rmse = np.sqrt( + mean_squared_error( + merged["GT_EDSS_numeric"], + merged["PRED_EDSS_numeric"] + ) + ) + + median_abs_error = merged["abs_error"].median() + mean_signed_error = merged["error"].mean() + + exact_accuracy_valid_only = (merged["abs_error"] == 0).mean() + accuracy_within_05_valid_only = (merged["abs_error"] <= 0.5).mean() + accuracy_within_10_valid_only = (merged["abs_error"] <= 1.0).mean() + + exact_correct_count = int((merged["abs_error"] == 0).sum()) + within_05_count = int((merged["abs_error"] <= 0.5).sum()) + within_10_count = int((merged["abs_error"] <= 1.0).sum()) + + # Coverage-adjusted accuracies use all GT numeric rows as denominator. + # Missing/abstained/non-numeric predictions count as not correct. 
+ exact_accuracy_all_gt_numeric = safe_rate(exact_correct_count, n_gt_numeric) + accuracy_within_05_all_gt_numeric = safe_rate(within_05_count, n_gt_numeric) + accuracy_within_10_all_gt_numeric = safe_rate(within_10_count, n_gt_numeric) + + coverage_gt_numeric = safe_rate(n_evaluable, n_gt_numeric) + coverage_all_rows = safe_rate(n_evaluable, n_total_gt_rows) + + if n_evaluable > 1: + spearman_rho, spearman_p = spearmanr( + merged["GT_EDSS_numeric"], + merged["PRED_EDSS_numeric"] + ) + else: + spearman_rho, spearman_p = np.nan, np.nan + + gt_half_steps = (merged["GT_EDSS_numeric"] * 2).round().astype(int) + pred_half_steps = (merged["PRED_EDSS_numeric"] * 2).round().astype(int) + + quadratic_weighted_kappa = cohen_kappa_score( + gt_half_steps, + pred_half_steps, + weights="quadratic" + ) + + mean_inference_time = ( + merged["inference_time_sec"].mean() + if "inference_time_sec" in merged.columns + else np.nan + ) + + rows.append({ + "model": model_name, + "result_file": str(result_file), + "iteration": TARGET_ITERATION, + + "n_total_gt_rows": n_total_gt_rows, + "n_gt_numeric": n_gt_numeric, + "raw_prediction_rows": raw_prediction_rows, + + "n_success": n_success, + "success_rate": safe_rate(n_success, raw_prediction_rows), + + "n_clinical_output_valid": n_clinical_output_valid, + "clinical_output_valid_rate": safe_rate(n_clinical_output_valid, raw_prediction_rows), + + "n_edss_logic_valid": n_edss_logic_valid, + "edss_logic_valid_rate": safe_rate(n_edss_logic_valid, raw_prediction_rows), + + "n_klassifizierbar_true": n_klassifizierbar_true, + "klassifizierbar_true_rate": safe_rate(n_klassifizierbar_true, raw_prediction_rows), + + "n_EDSS_numeric": n_edss_numeric, + "EDSS_numeric_rate": safe_rate(n_edss_numeric, raw_prediction_rows), + + "n_EDSS_valid_range": n_edss_valid_range, + "EDSS_valid_range_rate": safe_rate(n_edss_valid_range, raw_prediction_rows), + + "n_after_filtering": n_after_filtering, + "n_evaluable": n_evaluable, + "coverage_gt_numeric": 
coverage_gt_numeric, + "coverage_gt_numeric_percent": coverage_gt_numeric * 100, + "coverage_all_rows": coverage_all_rows, + "coverage_all_rows_percent": coverage_all_rows * 100, + + "MAE_valid_only": mae, + "median_absolute_error_valid_only": median_abs_error, + "RMSE_valid_only": rmse, + "mean_signed_error_valid_only": mean_signed_error, + + "exact_accuracy_valid_only": exact_accuracy_valid_only, + "accuracy_within_0_5_valid_only": accuracy_within_05_valid_only, + "accuracy_within_1_0_valid_only": accuracy_within_10_valid_only, + + "exact_accuracy_valid_only_percent": exact_accuracy_valid_only * 100, + "accuracy_within_0_5_valid_only_percent": accuracy_within_05_valid_only * 100, + "accuracy_within_1_0_valid_only_percent": accuracy_within_10_valid_only * 100, + + "exact_accuracy_all_gt_numeric": exact_accuracy_all_gt_numeric, + "accuracy_within_0_5_all_gt_numeric": accuracy_within_05_all_gt_numeric, + "accuracy_within_1_0_all_gt_numeric": accuracy_within_10_all_gt_numeric, + + "exact_accuracy_all_gt_numeric_percent": exact_accuracy_all_gt_numeric * 100, + "accuracy_within_0_5_all_gt_numeric_percent": accuracy_within_05_all_gt_numeric * 100, + "accuracy_within_1_0_all_gt_numeric_percent": accuracy_within_10_all_gt_numeric * 100, + + "spearman_rho": spearman_rho, + "spearman_p": spearman_p, + "quadratic_weighted_kappa": quadratic_weighted_kappa, + + "mean_inference_time_sec": mean_inference_time, + }) + + print("\nMetrics:") + print(f"Model: {model_name}") + print(f"n_evaluable: {n_evaluable}") + print(f"Coverage of GT numeric rows: {coverage_gt_numeric * 100:.1f}%") + print(f"MAE: {mae:.3f}") + print(f"Median AE: {median_abs_error:.3f}") + print(f"RMSE: {rmse:.3f}") + print(f"Mean signed error: {mean_signed_error:.3f}") + print(f"Exact accuracy valid-only: {exact_accuracy_valid_only * 100:.1f}%") + print(f"Accuracy ±0.5 valid-only: {accuracy_within_05_valid_only * 100:.1f}%") + print(f"Accuracy ±1.0 valid-only: {accuracy_within_10_valid_only * 100:.1f}%") + 
print(f"Accuracy ±0.5 all GT numeric: {accuracy_within_05_all_gt_numeric * 100:.1f}%") + print(f"Spearman rho: {spearman_rho:.3f}") + print(f"Quadratic weighted kappa: {quadratic_weighted_kappa:.3f}") + print(f"Mean inference time: {mean_inference_time:.3f} sec") + + +# ========================= +# SAVE METRICS TABLE +# ========================= + +metrics_df = pd.DataFrame(rows) + +if not metrics_df.empty: + metrics_df = metrics_df.sort_values("MAE_valid_only") + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 240) + +print("\n" + "=" * 100) +print("EDSS model comparison metrics:") +print(metrics_df) + +metrics_df.to_csv(OUTPUT_PATH, index=False) + +print(f"\nSaved metrics table to:") +print(OUTPUT_PATH) +## + +# %% Per-patient repeated-run variability across 10 EDSS runs + +from pathlib import Path +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +OUTPUT_DIR = RUN_DIR / "repeated_run_variability" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +N_EXPECTED_RUNS = 10 + +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" + +# Use only valid numeric EDSS predictions. +# This excludes valid abstentions where klassifizierbar=false and EDSS=null. 
+USE_ONLY_VALID_RANGE_EDSS = True + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def find_iteration_files(model_dir): + files = sorted(model_dir.glob("*results_iter_*.csv")) + + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + + return files + + +def load_model_all_iterations(model_dir): + files = find_iteration_files(model_dir) + + if not files: + return pd.DataFrame(), [] + + dfs = [] + + for file in files: + df = pd.read_csv(file, sep=",") + + if "iteration" not in df.columns: + print(f"Skipping {file}: no iteration column") + continue + + if "row_index" not in df.columns: + print(f"Skipping {file}: no row_index column") + continue + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in df.columns else PRED_EDSS_FALLBACK_COL + + df = df.copy() + df["source_file"] = str(file) + df["row_index"] = pd.to_numeric(df["row_index"], errors="coerce") + df["iteration"] = pd.to_numeric(df["iteration"], errors="coerce") + df["EDSS_prediction"] = to_num(df[pred_col]) + + df = df.dropna(subset=["row_index", "iteration"]).copy() + df["row_index"] = df["row_index"].astype(int) + df["iteration"] = df["iteration"].astype(int) + + if "success" in df.columns: + df = df[to_bool(df["success"])].copy() + + if "EDSS_is_numeric" in df.columns: + df = df[to_bool(df["EDSS_is_numeric"])].copy() + + if USE_ONLY_VALID_RANGE_EDSS and "EDSS_in_valid_range" in df.columns: + df = df[to_bool(df["EDSS_in_valid_range"])].copy() + + df = df.dropna(subset=["EDSS_prediction"]).copy() + + keep_cols = [ + "model", + "iteration", + "row_index", + "row_number_in_run", + "unique_id", + "MedDatum", + "EDSS_prediction", + "EDSS_numeric", + "EDSS", + "EDSS_is_numeric", + 
"EDSS_in_valid_range", + "klassifizierbar", + "clinical_output_valid", + "edss_logic_valid", + "certainty_percent", + "inference_time_sec", + "source_file", + ] + + keep_cols = [c for c in keep_cols if c in df.columns] + dfs.append(df[keep_cols]) + + if not dfs: + return pd.DataFrame(), files + + all_df = pd.concat(dfs, ignore_index=True) + + # If there are duplicate row_index + iteration rows, keep first. + all_df = all_df.sort_values(["row_index", "iteration"]) + all_df = all_df.drop_duplicates(subset=["row_index", "iteration"], keep="first") + + return all_df, files + + +def summarize_patient_variability(all_df, model_name): + """ + One row per patient / row_index. + """ + grouped = all_df.groupby("row_index") + + patient_rows = [] + + for row_index, g in grouped: + preds = g["EDSS_prediction"].dropna().astype(float) + + n_valid_runs = len(preds) + + if n_valid_runs == 0: + continue + + unique_id = g["unique_id"].dropna().iloc[0] if "unique_id" in g.columns and g["unique_id"].notna().any() else None + meddatum = g["MedDatum"].dropna().iloc[0] if "MedDatum" in g.columns and g["MedDatum"].notna().any() else None + + edss_mean = preds.mean() + edss_std = preds.std(ddof=0) # population SD across repeated runs + edss_median = preds.median() + edss_min = preds.min() + edss_max = preds.max() + edss_range = edss_max - edss_min + n_unique_predictions = preds.nunique() + + identical_all_available_runs = n_unique_predictions == 1 + range_leq_0_5 = edss_range <= 0.5 + complete_10_valid_runs = n_valid_runs == N_EXPECTED_RUNS + + patient_rows.append({ + "model": model_name, + "row_index": row_index, + "unique_id": unique_id, + "MedDatum": meddatum, + + "n_valid_runs": n_valid_runs, + "complete_10_valid_runs": complete_10_valid_runs, + + "EDSS_mean_across_runs": edss_mean, + "EDSS_median_across_runs": edss_median, + "EDSS_std_across_runs": edss_std, + "EDSS_min_across_runs": edss_min, + "EDSS_max_across_runs": edss_max, + "EDSS_range_across_runs": edss_range, + 
"n_unique_EDSS_predictions": n_unique_predictions, + + "identical_EDSS_all_available_runs": identical_all_available_runs, + "EDSS_range_leq_0_5": range_leq_0_5, + + "iterations_available": ",".join(map(str, sorted(g["iteration"].unique()))), + "EDSS_predictions_by_iteration": ";".join( + f"{int(row.iteration)}:{row.EDSS_prediction}" + for row in g.sort_values("iteration").itertuples() + ), + }) + + return pd.DataFrame(patient_rows) + + +def summarize_model_variability(patient_df, all_df, model_name, n_source_files): + if patient_df.empty: + return { + "model": model_name, + "n_source_iteration_files": n_source_files, + "n_patients_with_at_least_one_valid_prediction": 0, + } + + n_patients = len(patient_df) + + complete_df = patient_df[patient_df["complete_10_valid_runs"]].copy() + + summary = { + "model": model_name, + "n_source_iteration_files": n_source_files, + + "n_total_prediction_rows_valid": len(all_df), + "n_patients_with_at_least_one_valid_prediction": n_patients, + "n_patients_with_10_valid_runs": len(complete_df), + "patients_with_10_valid_runs_percent": len(complete_df) / n_patients * 100 if n_patients else np.nan, + + # Main variability metrics across all patients with at least one valid prediction + "mean_std_EDSS_across_runs": patient_df["EDSS_std_across_runs"].mean(), + "median_std_EDSS_across_runs": patient_df["EDSS_std_across_runs"].median(), + + "mean_range_EDSS_across_runs": patient_df["EDSS_range_across_runs"].mean(), + "median_range_EDSS_across_runs": patient_df["EDSS_range_across_runs"].median(), + + "percent_identical_EDSS_all_available_runs": patient_df["identical_EDSS_all_available_runs"].mean() * 100, + "percent_EDSS_range_leq_0_5": patient_df["EDSS_range_leq_0_5"].mean() * 100, + + "mean_n_valid_runs_per_patient": patient_df["n_valid_runs"].mean(), + "median_n_valid_runs_per_patient": patient_df["n_valid_runs"].median(), + "min_n_valid_runs_per_patient": patient_df["n_valid_runs"].min(), + "max_n_valid_runs_per_patient": 
patient_df["n_valid_runs"].max(), + } + + # Same metrics restricted to patients with all 10 valid runs + if not complete_df.empty: + summary.update({ + "mean_std_EDSS_10_valid_runs_only": complete_df["EDSS_std_across_runs"].mean(), + "median_std_EDSS_10_valid_runs_only": complete_df["EDSS_std_across_runs"].median(), + + "mean_range_EDSS_10_valid_runs_only": complete_df["EDSS_range_across_runs"].mean(), + "median_range_EDSS_10_valid_runs_only": complete_df["EDSS_range_across_runs"].median(), + + "percent_identical_EDSS_10_valid_runs_only": complete_df["identical_EDSS_all_available_runs"].mean() * 100, + "percent_EDSS_range_leq_0_5_10_valid_runs_only": complete_df["EDSS_range_leq_0_5"].mean() * 100, + }) + else: + summary.update({ + "mean_std_EDSS_10_valid_runs_only": np.nan, + "median_std_EDSS_10_valid_runs_only": np.nan, + "mean_range_EDSS_10_valid_runs_only": np.nan, + "median_range_EDSS_10_valid_runs_only": np.nan, + "percent_identical_EDSS_10_valid_runs_only": np.nan, + "percent_EDSS_range_leq_0_5_10_valid_runs_only": np.nan, + }) + + return summary + # ========================= # MAIN # ========================= -if __name__ == "__main__": +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and p.name != "repeated_run_variability" +] - output_dir = Path(OUTPUT_DIR) - output_dir.mkdir(parents=True, exist_ok=True) +all_model_summaries = [] - print("Loading reference:") - print(REFERENCE_PATH) +for model_dir in model_dirs: + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") - df_ref = load_reference(REFERENCE_PATH) + all_df, source_files = load_model_all_iterations(model_dir) - print(f"Reference rows: {len(df_ref)}") - print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}") + if all_df.empty: + print("No valid prediction data found. 
Skipping.") + continue - print("\nLoading result:") - print(RESULT_PATH) + model_name = ( + all_df["model"].dropna().iloc[0] + if "model" in all_df.columns and all_df["model"].notna().any() + else model_dir.name + ) - df_res = load_result(RESULT_PATH) + print(f"Model name: {model_name}") + print(f"Iteration files found: {len(source_files)}") + print(f"Valid numeric EDSS prediction rows: {len(all_df)}") + print(f"Patients with at least one valid EDSS prediction: {all_df['row_index'].nunique()}") - model_name = get_model_name(df_res, RESULT_PATH) - safe_model = safe_filename(model_name) + patient_df = summarize_patient_variability(all_df, model_name) + model_summary = summarize_model_variability( + patient_df=patient_df, + all_df=all_df, + model_name=model_name, + n_source_files=len(source_files), + ) + + all_model_summaries.append(model_summary) + + safe_model = model_name.replace("/", "_").replace(" ", "_") + + patient_out = OUTPUT_DIR / f"{safe_model}_per_patient_repeated_run_variability.csv" + all_preds_out = OUTPUT_DIR / f"{safe_model}_all_valid_predictions_long.csv" + + patient_df.to_csv(patient_out, index=False) + all_df.to_csv(all_preds_out, index=False) + + print("\nMain variability metrics:") + print(f"Mean SD across runs: {model_summary['mean_std_EDSS_across_runs']:.3f}") + print(f"Median SD across runs: {model_summary['median_std_EDSS_across_runs']:.3f}") + print(f"Identical EDSS all available runs: {model_summary['percent_identical_EDSS_all_available_runs']:.1f}%") + print(f"Range <= 0.5: {model_summary['percent_EDSS_range_leq_0_5']:.1f}%") + print(f"Patients with 10 valid runs: {model_summary['n_patients_with_10_valid_runs']}") + + print("\nSaved:") + print(patient_out) + print(all_preds_out) + + +summary_df = pd.DataFrame(all_model_summaries) + +summary_out = OUTPUT_DIR / "repeated_run_variability_summary.csv" +summary_df.to_csv(summary_out, index=False) + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 240) + +print("\n" + 
"=" * 100) +print("Repeated-run variability summary:") +print(summary_df) +print("\nSaved summary to:") +print(summary_out) +## + +# %% Functional system performance per domain - iteration 1 + +from pathlib import Path + +import pandas as pd +import numpy as np +from sklearn.metrics import mean_absolute_error, mean_squared_error +from scipy.stats import spearmanr + + +# ========================= +# PATHS +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_DIR = RUN_DIR / f"functional_system_metrics_iter_{TARGET_ITERATION}" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_FULL_TABLE = OUTPUT_DIR / f"functional_system_metrics_full_iter_{TARGET_ITERATION}.csv" +OUTPUT_SHORT_TABLE = OUTPUT_DIR / f"functional_system_metrics_short_iter_{TARGET_ITERATION}.csv" + + +# ========================= +# FUNCTIONAL SYSTEM MAPPING +# ========================= + +FS_MAP = { + "VISUAL_OPTIC_FUNCTIONS": { + "display": "Visual/optic functions", + "gt": "Sehvermögen", + "pred": "numeric_subcat_VISUAL_OPTIC_FUNCTIONS", + "fallback": "subcat_VISUAL_OPTIC_FUNCTIONS", + "numeric_flag": "subcat_VISUAL_OPTIC_FUNCTIONS_is_numeric", + "range_flag": "subcat_VISUAL_OPTIC_FUNCTIONS_in_valid_range", + }, + "BRAINSTEM_FUNCTIONS": { + "display": "Brainstem functions", + "gt": "Hirnstamm", + "pred": "numeric_subcat_BRAINSTEM_FUNCTIONS", + "fallback": "subcat_BRAINSTEM_FUNCTIONS", + "numeric_flag": "subcat_BRAINSTEM_FUNCTIONS_is_numeric", + "range_flag": "subcat_BRAINSTEM_FUNCTIONS_in_valid_range", + }, + "PYRAMIDAL_FUNCTIONS": { + "display": "Pyramidal functions", + "gt": "Pyramidalmotorik", + "pred": "numeric_subcat_PYRAMIDAL_FUNCTIONS", + "fallback": "subcat_PYRAMIDAL_FUNCTIONS", + "numeric_flag": 
"subcat_PYRAMIDAL_FUNCTIONS_is_numeric", + "range_flag": "subcat_PYRAMIDAL_FUNCTIONS_in_valid_range", + }, + "CEREBELLAR_FUNCTIONS": { + "display": "Cerebellar functions", + "gt": "Cerebellum", + "pred": "numeric_subcat_CEREBELLAR_FUNCTIONS", + "fallback": "subcat_CEREBELLAR_FUNCTIONS", + "numeric_flag": "subcat_CEREBELLAR_FUNCTIONS_is_numeric", + "range_flag": "subcat_CEREBELLAR_FUNCTIONS_in_valid_range", + }, + "SENSORY_FUNCTIONS": { + "display": "Sensory functions", + "gt": "Sensibiliät", + "pred": "numeric_subcat_SENSORY_FUNCTIONS", + "fallback": "subcat_SENSORY_FUNCTIONS", + "numeric_flag": "subcat_SENSORY_FUNCTIONS_is_numeric", + "range_flag": "subcat_SENSORY_FUNCTIONS_in_valid_range", + }, + "BOWEL_AND_BLADDER_FUNCTIONS": { + "display": "Bowel and bladder functions", + "gt": "Blasen-_und_Mastdarmfunktion", + "pred": "numeric_subcat_BOWEL_AND_BLADDER_FUNCTIONS", + "fallback": "subcat_BOWEL_AND_BLADDER_FUNCTIONS", + "numeric_flag": "subcat_BOWEL_AND_BLADDER_FUNCTIONS_is_numeric", + "range_flag": "subcat_BOWEL_AND_BLADDER_FUNCTIONS_in_valid_range", + }, + "CEREBRAL_FUNCTIONS": { + "display": "Cerebral functions", + "gt": "Cerebrale_Funktion", + "pred": "numeric_subcat_CEREBRAL_FUNCTIONS", + "fallback": "subcat_CEREBRAL_FUNCTIONS", + "numeric_flag": "subcat_CEREBRAL_FUNCTIONS_is_numeric", + "range_flag": "subcat_CEREBRAL_FUNCTIONS_in_valid_range", + }, + "AMBULATION": { + "display": "Ambulation", + "gt": "Ambulation", + "pred": "numeric_subcat_AMBULATION", + "fallback": "subcat_AMBULATION", + "numeric_flag": "subcat_AMBULATION_is_numeric", + "range_flag": "subcat_AMBULATION_in_valid_range", + }, +} + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def rate(n, d): + if d == 0: + return np.nan + return n / d + + +def 
find_iter_file(model_dir, iteration): + files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + return files[0] if files else None + + +def get_model_name(df, model_dir): + if "model" in df.columns and df["model"].notna().any(): + return str(df["model"].dropna().iloc[0]) + return model_dir.name + + +# ========================= +# LOAD GROUND TRUTH +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index + +print(f"GT rows: {len(gt)}") + +for fs_key, info in FS_MAP.items(): + if info["gt"] not in gt.columns: + print(f"WARNING missing GT column: {info['gt']}") + else: + gt_num = to_num(gt[info["gt"]]) + print( + f"{info['display']}: " + f"GT numeric={gt_num.notna().sum()}, " + f"GT non-zero={(gt_num.dropna() != 0).sum()}" + ) + + +# ========================= +# MAIN ANALYSIS +# ========================= + +rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and p.name not in [ + f"functional_system_metrics_iter_{TARGET_ITERATION}", + "repeated_run_variability", + f"confusion_matrices_iter_{TARGET_ITERATION}", + ] + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"\nNo iteration {TARGET_ITERATION} file found for {model_dir.name}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred.columns: + print("Skipping: no row_index column.") + continue + + model_name = get_model_name(pred, model_dir) + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + 
pred["row_index"] = pred["row_index"].astype(int) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt.merge( + pred, + on="row_index", + how="left", + suffixes=("_gt", "_pred") + ) + + print(f"Model name: {model_name}") + print(f"Prediction rows after success filter: {len(pred)}") + print(f"Merged rows: {len(merged)}") + + for fs_key, info in FS_MAP.items(): + gt_col = info["gt"] + + if gt_col not in merged.columns: + print(f"Skipping {info['display']}: missing GT column {gt_col}") + continue + + pred_col = info["pred"] if info["pred"] in merged.columns else info["fallback"] + + if pred_col not in merged.columns: + print(f"Skipping {info['display']}: missing prediction column") + continue + + temp = merged.copy() + + temp["GT_value"] = to_num(temp[gt_col]) + temp["PRED_value"] = to_num(temp[pred_col]) + + gt_numeric_df = temp.dropna(subset=["GT_value"]).copy() + + n_gt_numeric = len(gt_numeric_df) + n_nonzero_gt = int((gt_numeric_df["GT_value"] != 0).sum()) + percent_nonzero_gt = rate(n_nonzero_gt, n_gt_numeric) * 100 + + if info["numeric_flag"] in gt_numeric_df.columns: + pred_numeric_flag = to_bool(gt_numeric_df[info["numeric_flag"]]) + else: + pred_numeric_flag = gt_numeric_df["PRED_value"].notna() + + if info["range_flag"] in gt_numeric_df.columns: + pred_range_flag = to_bool(gt_numeric_df[info["range_flag"]]) + else: + pred_range_flag = gt_numeric_df["PRED_value"].notna() + + valid_pred = ( + pred_numeric_flag + & pred_range_flag + & gt_numeric_df["PRED_value"].notna() + ) + + n_missing_or_invalid = int((~valid_pred).sum()) + percent_missing_or_invalid = rate(n_missing_or_invalid, n_gt_numeric) * 100 + + eval_df = gt_numeric_df[valid_pred].copy() + n_evaluable = len(eval_df) + + if n_evaluable > 0: + eval_df["error"] = eval_df["PRED_value"] - eval_df["GT_value"] + eval_df["abs_error"] = eval_df["error"].abs() + + mae = 
mean_absolute_error(eval_df["GT_value"], eval_df["PRED_value"]) + median_ae = eval_df["abs_error"].median() + rmse = np.sqrt(mean_squared_error(eval_df["GT_value"], eval_df["PRED_value"])) + + exact_acc = (eval_df["abs_error"] == 0).mean() + acc_05 = (eval_df["abs_error"] <= 0.5).mean() + acc_10 = (eval_df["abs_error"] <= 1.0).mean() + + enough_variation = ( + n_evaluable >= 3 + and eval_df["GT_value"].nunique() > 1 + and eval_df["PRED_value"].nunique() > 1 + ) + + if enough_variation: + spearman_rho, spearman_p = spearmanr( + eval_df["GT_value"], + eval_df["PRED_value"] + ) + else: + spearman_rho, spearman_p = np.nan, np.nan + + else: + mae = np.nan + median_ae = np.nan + rmse = np.nan + exact_acc = np.nan + acc_05 = np.nan + acc_10 = np.nan + spearman_rho = np.nan + spearman_p = np.nan + enough_variation = False + + row = { + "model": model_name, + "iteration": TARGET_ITERATION, + "result_file": str(result_file), + + "functional_system_key": fs_key, + "functional_system": info["display"], + + "n_gt_numeric": n_gt_numeric, + "n_evaluable": n_evaluable, + + "n_nonzero_ground_truth": n_nonzero_gt, + "percent_nonzero_ground_truth": percent_nonzero_gt, + + "n_missing_or_invalid_model_outputs": n_missing_or_invalid, + "percent_missing_or_invalid_model_outputs": percent_missing_or_invalid, + + "MAE": mae, + "median_absolute_error": median_ae, + "RMSE": rmse, + + "exact_accuracy": exact_acc, + "accuracy_within_0_5": acc_05, + "accuracy_within_1_0": acc_10, + + "exact_accuracy_percent": exact_acc * 100 if pd.notna(exact_acc) else np.nan, + "accuracy_within_0_5_percent": acc_05 * 100 if pd.notna(acc_05) else np.nan, + "accuracy_within_1_0_percent": acc_10 * 100 if pd.notna(acc_10) else np.nan, + + "spearman_rho": spearman_rho, + "spearman_p": spearman_p, + "spearman_calculated": enough_variation, + } + + rows.append(row) + + print( + f"{info['display']}: " + f"n_eval={n_evaluable}, " + f"non-zero GT={n_nonzero_gt} ({percent_nonzero_gt:.1f}%), " + f"MAE={mae:.3f}, " + 
f"±0.5={row['accuracy_within_0_5_percent']:.1f}%, " + f"missing/invalid={percent_missing_or_invalid:.1f}%" + ) + + +# ========================= +# SAVE TABLES +# ========================= + +metrics_df = pd.DataFrame(rows) + +metrics_df.to_csv(OUTPUT_FULL_TABLE, index=False) + +short_cols = [ + "model", + "functional_system", + "n_evaluable", + "n_nonzero_ground_truth", + "percent_nonzero_ground_truth", + "MAE", + "median_absolute_error", + "RMSE", + "exact_accuracy_percent", + "accuracy_within_0_5_percent", + "accuracy_within_1_0_percent", + "spearman_rho", + "percent_missing_or_invalid_model_outputs", +] + +short_df = metrics_df[short_cols].copy() +short_df.to_csv(OUTPUT_SHORT_TABLE, index=False) + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 240) + +print("\n" + "=" * 100) +print("Functional system performance table:") +print(metrics_df) + +print("\n" + "=" * 100) +print("Short table:") +print(short_df) + +print("\nSaved:") +print(OUTPUT_FULL_TABLE) +print(OUTPUT_SHORT_TABLE) +## + + +# %% Functional Systems + EDSS Error Category Stacked Bar Plot per Model + +from pathlib import Path +import re + +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_DIR = RUN_DIR / f"functional_system_error_category_plots_iter_{TARGET_ITERATION}" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +plt.rcParams["font.family"] = "Arial" + + +# ========================= +# COLUMN MAPPING +# ========================= + +SYSTEMS_TO_PLOT = [ + { + "name": "Visual Optic Functions", + "gt_col": "Sehvermögen", + "pred_col": "numeric_subcat_VISUAL_OPTIC_FUNCTIONS", + 
"pred_fallback_col": "subcat_VISUAL_OPTIC_FUNCTIONS", + }, + { + "name": "Cerebellar Functions", + "gt_col": "Cerebellum", + "pred_col": "numeric_subcat_CEREBELLAR_FUNCTIONS", + "pred_fallback_col": "subcat_CEREBELLAR_FUNCTIONS", + }, + { + "name": "Brainstem Functions", + "gt_col": "Hirnstamm", + "pred_col": "numeric_subcat_BRAINSTEM_FUNCTIONS", + "pred_fallback_col": "subcat_BRAINSTEM_FUNCTIONS", + }, + { + "name": "Sensory Functions", + "gt_col": "Sensibiliät", + "pred_col": "numeric_subcat_SENSORY_FUNCTIONS", + "pred_fallback_col": "subcat_SENSORY_FUNCTIONS", + }, + { + "name": "Pyramidal Functions", + "gt_col": "Pyramidalmotorik", + "pred_col": "numeric_subcat_PYRAMIDAL_FUNCTIONS", + "pred_fallback_col": "subcat_PYRAMIDAL_FUNCTIONS", + }, + { + "name": "Ambulation", + "gt_col": "Ambulation", + "pred_col": "numeric_subcat_AMBULATION", + "pred_fallback_col": "subcat_AMBULATION", + }, + { + "name": "Cerebral Functions", + "gt_col": "Cerebrale_Funktion", + "pred_col": "numeric_subcat_CEREBRAL_FUNCTIONS", + "pred_fallback_col": "subcat_CEREBRAL_FUNCTIONS", + }, + { + "name": "Bowel And Bladder Functions", + "gt_col": "Blasen-_und_Mastdarmfunktion", + "pred_col": "numeric_subcat_BOWEL_AND_BLADDER_FUNCTIONS", + "pred_fallback_col": "subcat_BOWEL_AND_BLADDER_FUNCTIONS", + }, + { + "name": "EDSS", + "gt_col": "EDSS", + "pred_col": "EDSS_numeric", + "pred_fallback_col": "EDSS", + }, +] + +SYSTEM_ORDER = [ + "Visual Optic Functions", + "Cerebellar Functions", + "Brainstem Functions", + "Sensory Functions", + "Pyramidal Functions", + "Ambulation", + "Cerebral Functions", + "Bowel And Bladder Functions", + "EDSS", +] + +CATEGORY_ORDER = [ + "Exact", + "≤0.5 error", + "≤1 error", + ">1 error", + "Missing/invalid", +] + +# Blue-based palette for correct predictions +COLORS = { + "Exact": "#1F77B4", # blue + "≤0.5 error": "#9ECAE1", # light blue + "≤1 error": "#FDDC7A", # yellow + ">1 error": "#F28E2B", # orange + "Missing/invalid": "#D62728" # red +} + + +# 
# =========================
# HELPERS
# =========================

def safe_name(name):
    """Return *name* as a string that is safe to embed in a file name.

    Every run of characters outside ``[A-Za-z0-9_.-]`` is collapsed into a
    single underscore. Dots are kept, so the result may still contain ".".
    """
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name))


def safe_parse_series(s):
    """Convert a Series to numbers, accepting "," as the decimal separator.

    Values are stringified first, so anything that is not a plain number
    (including missing values) comes back as NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )


def categorize_error(abs_error):
    """Map an absolute error to one of the labels used in the stacked bars.

    NaN (no usable prediction) is reported as "Missing/invalid".
    """
    if pd.isna(abs_error):
        return "Missing/invalid"
    if abs_error == 0:
        return "Exact"
    if abs_error <= 0.5:
        return "≤0.5 error"
    if abs_error <= 1.0:
        return "≤1 error"
    return ">1 error"


def find_iter_file(model_dir, iteration):
    """Return the result CSV of *iteration* inside *model_dir*, or None.

    Files whose name contains "incremental", "summary" or "all_results"
    (case-insensitive) are ignored. If several candidates remain, the first
    one in sorted path order wins.
    """
    excluded = ("incremental", "summary", "all_results")
    candidates = [
        path
        for path in sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
        if not any(marker in path.name.lower() for marker in excluded)
    ]
    return candidates[0] if candidates else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column.

    Falls back to the directory name when the column is absent or empty.
    """
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def get_column_after_merge(df, base_col, side):
    """Look up *base_col* in a frame merged with suffixes ("_gt", "_pred").

    Returns *base_col* itself if that column exists, otherwise
    ``f"{base_col}_{side}"`` if that exists, otherwise None.

    NOTE: an unsuffixed hit does not say which side of the merge the column
    came from. ``prepare_plot_data`` therefore resolves names against the
    pre-merge frames instead of calling this helper.
    """
    if base_col in df.columns:
        return base_col

    suffixed = f"{base_col}_{side}"
    if suffixed in df.columns:
        return suffixed

    return None


def _merged_column_name(name, own_columns, other_columns, suffix):
    """Return the name column *name* carries after the GT/prediction merge.

    pandas appends a suffix only to non-key columns that exist on both
    sides, so the name is suffixed exactly when the other frame has it too.
    Returns None when *name* is not a column of the owning frame.
    """
    if name not in own_columns:
        return None
    return f"{name}{suffix}" if name in other_columns else name


def _append_suffix(path, suffix):
    """Append *suffix* to the file name of *path*, keeping existing dots.

    ``Path.with_suffix`` would replace everything after the last dot, which
    truncates names built from model ids such as "qwen3.6-27b".
    """
    return path.with_name(f"{path.name}{suffix}")


def prepare_plot_data(gt, pred, systems=None, system_order=None, category_order=None):
    """Count error categories per functional system for one model.

    Parameters
    ----------
    gt, pred : DataFrame
        Ground truth and predictions; both need a "row_index" column, which
        is used as the (left) merge key.
    systems : list of dict, optional
        Entries with the keys "name", "gt_col", "pred_col" and
        "pred_fallback_col". Defaults to ``SYSTEMS_TO_PLOT``.
    system_order, category_order : list of str, optional
        Row and column order of the result. Default to ``SYSTEM_ORDER`` and
        ``CATEGORY_ORDER``.

    Returns
    -------
    (counts, percentages) : tuple of DataFrame
        Integer counts and row-wise percentages, indexed by system. Systems
        without any numeric ground-truth value are dropped.

    Raises
    ------
    ValueError
        If no system yields a single row to evaluate.
    """
    systems = SYSTEMS_TO_PLOT if systems is None else systems
    system_order = SYSTEM_ORDER if system_order is None else system_order
    category_order = CATEGORY_ORDER if category_order is None else category_order

    merged = gt.merge(
        pred,
        on="row_index",
        how="left",
        suffixes=("_gt", "_pred"),
    )

    frames = []

    for system in systems:
        system_name = system["name"]

        # Resolve names against the frame each column must come from. A
        # lookup in the merged frame could hand back the ground-truth column
        # as the "prediction" (e.g. "EDSS" present only in the GT file).
        gt_col = _merged_column_name(system["gt_col"], gt.columns, pred.columns, "_gt")

        pred_col = None
        for candidate in (system["pred_col"], system["pred_fallback_col"]):
            pred_col = _merged_column_name(candidate, pred.columns, gt.columns, "_pred")
            if pred_col is not None:
                break

        if gt_col is None:
            print(f"Skipping {system_name}: GT column not found: {system['gt_col']}")
            continue

        if pred_col is None:
            print(f"Skipping {system_name}: prediction column not found")
            continue

        gt_values = safe_parse_series(merged[gt_col])
        pred_values = safe_parse_series(merged[pred_col])

        # Evaluate only cases where ground truth exists. A missing
        # prediction gives a NaN error, which maps to "Missing/invalid".
        abs_errors = (pred_values - gt_values).abs()[gt_values.notna()]
        if abs_errors.empty:
            continue

        categories = abs_errors.map(categorize_error)
        frames.append(
            pd.DataFrame({"system": system_name, "category": categories.to_numpy()})
        )

    if not frames:
        raise ValueError("No valid data available for plotting.")

    plot_df = pd.concat(frames, ignore_index=True)

    if plot_df.empty:
        raise ValueError("No valid data available for plotting.")

    counts = (
        plot_df
        .groupby(["system", "category"])
        .size()
        .unstack(fill_value=0)
        .reindex(index=system_order)
        .reindex(columns=category_order, fill_value=0)
        .fillna(0)
        .astype(int)  # reindexing may have introduced NaN and thus floats
    )

    # Remove systems with no available GT rows.
    counts = counts[counts.sum(axis=1) > 0]

    percentages = (counts.div(counts.sum(axis=1), axis=0) * 100).fillna(0)

    return counts, percentages


def plot_error_categories(counts, percentages, model_name, output_base):
    """Draw the stacked error-category bar chart for one model and save it.

    One horizontal bar per row of *percentages*, stacked in
    ``CATEGORY_ORDER`` and coloured via ``COLORS``. Segments of at least 4%
    are labelled, and each bar is annotated with its case count and number
    of missing predictions taken from *counts*.

    *output_base* is a path without extension; ".svg" and ".png" are
    appended to it. Returns ``(svg_path, png_path)``.
    """
    fig, ax = plt.subplots(figsize=(13, 7))

    # Running left edge of the next segment, one entry per bar.
    left = np.zeros(len(percentages))

    for category in CATEGORY_ORDER:
        values = percentages[category].values

        ax.barh(
            percentages.index,
            values,
            left=left,
            color=COLORS[category],
            edgecolor="white",
            linewidth=0.8,
            label=category,
        )

        for i, value in enumerate(values):
            # Narrow segments cannot hold a readable label.
            if value >= 4:
                ax.text(
                    left[i] + value / 2,
                    i,
                    f"{value:.1f}%",
                    ha="center",
                    va="center",
                    fontsize=8,
                    fontweight="bold",
                    color="black",
                )

        left += values

    for i, system in enumerate(percentages.index):
        total_n = int(counts.loc[system].sum())

        if "Missing/invalid" in counts.columns:
            missing_n = int(counts.loc[system, "Missing/invalid"])
        else:
            missing_n = 0

        # Placed just right of the 100% mark; the x-limit below leaves room.
        ax.text(
            101,
            i,
            f"n={total_n}, missing={missing_n}",
            va="center",
            ha="left",
            fontsize=9,
        )

    ax.set_xlim(0, 115)
    ax.set_xlabel("Percentage of Cases", fontsize=11, fontweight="bold")
    ax.set_ylabel("Functional System / EDSS", fontsize=11, fontweight="bold")

    ax.set_title(
        f"Prediction Error Categories by Functional System and EDSS\n{model_name}, Iteration {TARGET_ITERATION}",
        fontsize=13,
        fontweight="bold",
        pad=35,
    )

    ax.set_xticks(np.arange(0, 101, 10))
    ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)])

    ax.xaxis.grid(True, linestyle="--", alpha=0.3)
    ax.set_axisbelow(True)

    for spine in ["top", "right", "left"]:
        ax.spines[spine].set_visible(False)

    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, 1.02),
        ncol=5,
        frameon=False,
    )

    plt.tight_layout(rect=[0, 0, 1, 0.90])

    # Append the extension instead of using with_suffix(): model names such
    # as "qwen3.6-27b" contain dots, and with_suffix() would cut the file
    # name down to "qwen3.svg".
    svg_path = _append_suffix(output_base, ".svg")
    png_path = _append_suffix(output_base, ".png")

    plt.savefig(svg_path, format="svg", bbox_inches="tight")
+ plt.savefig(png_path, dpi=300, bbox_inches="tight") + + plt.show() + + return svg_path, png_path + + +# ========================= +# LOAD GT +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index + +print(f"GT rows: {len(gt)}") + + +# ========================= +# MAIN +# ========================= + +summary_rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") + and not p.name.startswith("repeated_run") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred.columns: + print("Skipping: no row_index column.") + continue + + model_name = get_model_name(pred, model_dir) + safe_model = safe_name(model_name) + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + if "success" in pred.columns: + pred = pred[ + pred["success"] + .astype(str) + .str.lower() + .isin(["true", "1", "yes", "ja"]) + ].copy() + + pred = pred.drop_duplicates("row_index", keep="first").copy() + + counts, percentages = prepare_plot_data(gt, pred) + + output_base = OUTPUT_DIR / f"{safe_model}_functional_systems_edss_error_categories_iter_{TARGET_ITERATION}" + + svg_path, png_path = plot_error_categories( + counts=counts, + percentages=percentages, + model_name=model_name, + output_base=output_base, + ) + + counts_path = OUTPUT_DIR / f"{safe_model}_functional_systems_edss_error_category_counts_iter_{TARGET_ITERATION}.csv" + percentages_path = OUTPUT_DIR / 
f"{safe_model}_functional_systems_edss_error_category_percentages_iter_{TARGET_ITERATION}.csv" + + counts.to_csv(counts_path) + percentages.to_csv(percentages_path) + + print("Saved:") + print(svg_path) + print(png_path) + print(counts_path) + print(percentages_path) + + summary_rows.append({ + "model": model_name, + "iteration": TARGET_ITERATION, + "result_file": str(result_file), + "svg_path": str(svg_path), + "png_path": str(png_path), + "counts_path": str(counts_path), + "percentages_path": str(percentages_path), + }) + + +summary_df = pd.DataFrame(summary_rows) +summary_path = OUTPUT_DIR / f"functional_systems_edss_error_category_plot_summary_iter_{TARGET_ITERATION}.csv" +summary_df.to_csv(summary_path, index=False) + +print("\n" + "=" * 100) +print("Done.") +print(f"Summary saved to: {summary_path}") +print("=" * 100) +## + + + +# %% EDSS error distribution per model + +from pathlib import Path + +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_DIR = RUN_DIR / f"edss_error_distribution_iter_{TARGET_ITERATION}" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_PATH = OUTPUT_DIR / f"edss_error_distribution_iter_{TARGET_ITERATION}.csv" +OUTPUT_LONG_PATH = OUTPUT_DIR / f"edss_error_distribution_long_iter_{TARGET_ITERATION}.csv" + +GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def 
rate(n, d): + if d == 0: + return np.nan + return n / d + + +def find_iter_file(model_dir, iteration): + files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) + + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + + return files[0] if files else None + + +def get_model_name(df, model_dir): + if "model" in df.columns and df["model"].notna().any(): + return str(df["model"].dropna().iloc[0]) + return model_dir.name + + +def classify_error(abs_error): + if pd.isna(abs_error): + return "missing_or_invalid" + if abs_error == 0: + return "exact_match" + if abs_error == 0.5: + return "error_equal_0_5" + if 0.5 < abs_error <= 1.0: + return "error_gt_0_5_le_1_0" + if abs_error > 1.0: + return "error_gt_1_0" + return "other" + + +# ========================= +# LOAD GROUND TRUTH +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index +gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) + +gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy() + +n_total_gt_rows = len(gt) +n_gt_numeric = len(gt_numeric) + +print(f"GT rows: {n_total_gt_rows}") +print(f"GT numeric EDSS rows: {n_gt_numeric}") + + +# ========================= +# MAIN ANALYSIS +# ========================= + +summary_rows = [] +long_rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") + and not p.name.startswith("repeated_run") + and not p.name.startswith("edss_error_distribution") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred_raw = pd.read_csv(result_file, 
sep=",") + + if "row_index" not in pred_raw.columns: + print("Skipping: no row_index column.") + continue + + model_name = get_model_name(pred_raw, model_dir) + + pred = pred_raw.copy() + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + raw_prediction_rows = len(pred) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + if "EDSS_is_numeric" in pred.columns: + pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() + + if "EDSS_in_valid_range" in pred.columns: + pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + + pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt_numeric.merge( + pred, + on="row_index", + how="left", + suffixes=("_gt", "_pred") + ) + + merged["prediction_available"] = merged["PRED_EDSS_numeric"].notna() + eval_df = merged[merged["prediction_available"]].copy() + + if eval_df.empty: + print("No evaluable rows.") + continue + + eval_df["error"] = eval_df["PRED_EDSS_numeric"] - eval_df["GT_EDSS_numeric"] + eval_df["abs_error"] = eval_df["error"].abs() + eval_df["error_category"] = eval_df["abs_error"].apply(classify_error) + + n_evaluable = len(eval_df) + + n_exact = int((eval_df["abs_error"] == 0).sum()) + n_error_equal_05 = int((eval_df["abs_error"] == 0.5).sum()) + n_error_gt_05_le_10 = int(((eval_df["abs_error"] > 0.5) & (eval_df["abs_error"] <= 1.0)).sum()) + n_error_gt_10 = int((eval_df["abs_error"] > 1.0).sum()) + n_error_gt_20 = int((eval_df["abs_error"] > 2.0).sum()) + + max_abs_error = eval_df["abs_error"].max() + + n_missing_or_invalid_against_gt_numeric = int((~merged["prediction_available"]).sum()) + + summary_rows.append({ + "model": model_name, + 
"iteration": TARGET_ITERATION, + "result_file": str(result_file), + + "n_total_gt_rows": n_total_gt_rows, + "n_gt_numeric": n_gt_numeric, + "raw_prediction_rows": raw_prediction_rows, + "n_evaluable": n_evaluable, + "n_missing_or_invalid_against_gt_numeric": n_missing_or_invalid_against_gt_numeric, + + "exact_match_n": n_exact, + "exact_match_percent_valid_only": rate(n_exact, n_evaluable) * 100, + "exact_match_percent_all_gt_numeric": rate(n_exact, n_gt_numeric) * 100, + + "error_equal_0_5_n": n_error_equal_05, + "error_equal_0_5_percent_valid_only": rate(n_error_equal_05, n_evaluable) * 100, + "error_equal_0_5_percent_all_gt_numeric": rate(n_error_equal_05, n_gt_numeric) * 100, + + "error_gt_0_5_le_1_0_n": n_error_gt_05_le_10, + "error_gt_0_5_le_1_0_percent_valid_only": rate(n_error_gt_05_le_10, n_evaluable) * 100, + "error_gt_0_5_le_1_0_percent_all_gt_numeric": rate(n_error_gt_05_le_10, n_gt_numeric) * 100, + + "error_gt_1_0_n": n_error_gt_10, + "error_gt_1_0_percent_valid_only": rate(n_error_gt_10, n_evaluable) * 100, + "error_gt_1_0_percent_all_gt_numeric": rate(n_error_gt_10, n_gt_numeric) * 100, + + "error_gt_2_0_n": n_error_gt_20, + "error_gt_2_0_percent_valid_only": rate(n_error_gt_20, n_evaluable) * 100, + "error_gt_2_0_percent_all_gt_numeric": rate(n_error_gt_20, n_gt_numeric) * 100, + + "maximum_absolute_error": max_abs_error, + }) + + keep_cols = [ + "row_index", + "unique_id_gt" if "unique_id_gt" in eval_df.columns else "unique_id", + "MedDatum_gt" if "MedDatum_gt" in eval_df.columns else "MedDatum", + "model", + "iteration", + "GT_EDSS_numeric", + "PRED_EDSS_numeric", + "error", + "abs_error", + "error_category", + "raw_EDSS", + "EDSS_numeric", + "EDSS_in_valid_range", + "klassifizierbar", + "clinical_output_valid", + "edss_logic_valid", + "certainty_percent", + "reason", + "inference_time_sec", + ] + + keep_cols = [c for c in keep_cols if c in eval_df.columns] + + for _, row in eval_df[keep_cols].iterrows(): + row_dict = row.to_dict() + 
row_dict["model_for_analysis"] = model_name + long_rows.append(row_dict) print(f"Model: {model_name}") - print(f"Result rows after filtering: {len(df_res)}") + print(f"n_evaluable: {n_evaluable}") + print(f"Exact match: {n_exact} ({rate(n_exact, n_evaluable) * 100:.1f}%)") + print(f"Error = 0.5: {n_error_equal_05} ({rate(n_error_equal_05, n_evaluable) * 100:.1f}%)") + print(f"Error >0.5 and ≤1.0: {n_error_gt_05_le_10} ({rate(n_error_gt_05_le_10, n_evaluable) * 100:.1f}%)") + print(f"Error >1.0: {n_error_gt_10} ({rate(n_error_gt_10, n_evaluable) * 100:.1f}%)") + print(f"Error >2.0: {n_error_gt_20} ({rate(n_error_gt_20, n_evaluable) * 100:.1f}%)") + print(f"Maximum absolute error: {max_abs_error}") - before_dedup = len(df_res) - df_res = df_res.sort_values(by=[MERGE_KEY]).drop_duplicates(subset=[MERGE_KEY], keep="first") - after_dedup = len(df_res) - if before_dedup != after_dedup: - print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}") +# ========================= +# SAVE OUTPUT +# ========================= - df_merged = df_ref.merge( - df_res, - on=MERGE_KEY, +summary_df = pd.DataFrame(summary_rows) +long_df = pd.DataFrame(long_rows) + +summary_df.to_csv(OUTPUT_PATH, index=False) +long_df.to_csv(OUTPUT_LONG_PATH, index=False) + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 240) + +print("\n" + "=" * 100) +print("EDSS error distribution summary:") +print(summary_df) + +print("\nSaved:") +print(OUTPUT_PATH) +print(OUTPUT_LONG_PATH) +## + +# %% EDSS severity-group performance per model + +from pathlib import Path + +import pandas as pd +import numpy as np +from sklearn.metrics import confusion_matrix + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + 
"/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_DIR = RUN_DIR / f"edss_severity_group_metrics_iter_{TARGET_ITERATION}" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_SUMMARY_PATH = OUTPUT_DIR / f"edss_severity_group_metrics_iter_{TARGET_ITERATION}.csv" +OUTPUT_LONG_PATH = OUTPUT_DIR / f"edss_severity_group_predictions_long_iter_{TARGET_ITERATION}.csv" +OUTPUT_CONFUSION_PATH = OUTPUT_DIR / f"edss_severity_group_confusion_matrices_iter_{TARGET_ITERATION}.csv" + +GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" + +SEVERITY_GROUPS = [ + "0.0-3.5", + "4.0-5.5", + "6.0-10.0", +] + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def rate(n, d): + if d == 0: + return np.nan + return n / d + + +def classify_edss_group(value): + if pd.isna(value): + return np.nan + if 0.0 <= value <= 3.5: + return "0.0-3.5" + if 4.0 <= value <= 5.5: + return "4.0-5.5" + if 6.0 <= value <= 10.0: + return "6.0-10.0" + return np.nan + + +def find_iter_file(model_dir, iteration): + files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) + + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + + return files[0] if files else None + + +def get_model_name(df, model_dir): + if "model" in df.columns and df["model"].notna().any(): + return str(df["model"].dropna().iloc[0]) + return model_dir.name + + +# ========================= +# LOAD GROUND TRUTH +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index +gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) +gt["GT_EDSS_group"] = 
gt["GT_EDSS_numeric"].apply(classify_edss_group) + +gt_numeric = gt.dropna(subset=["GT_EDSS_numeric", "GT_EDSS_group"]).copy() + +n_total_gt_rows = len(gt) +n_gt_numeric = len(gt_numeric) + +print(f"GT rows: {n_total_gt_rows}") +print(f"GT numeric EDSS rows in severity groups: {n_gt_numeric}") +print("\nGT group counts:") +print(gt_numeric["GT_EDSS_group"].value_counts().reindex(SEVERITY_GROUPS, fill_value=0)) + + +# ========================= +# MAIN ANALYSIS +# ========================= + +summary_rows = [] +long_rows = [] +confusion_rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") + and not p.name.startswith("repeated_run") + and not p.name.startswith("edss_error_distribution") + and not p.name.startswith("edss_threshold_metrics") + and not p.name.startswith("edss_severity_group_metrics") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}") + continue + + print("\n" + "=" * 100) + print(f"Model folder: {model_dir.name}") + print(f"Result file: {result_file}") + + pred_raw = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred_raw.columns: + print("Skipping: no row_index column.") + continue + + model_name = get_model_name(pred_raw, model_dir) + + pred = pred_raw.copy() + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + raw_prediction_rows = len(pred) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + if "EDSS_is_numeric" in pred.columns: + pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() + + if "EDSS_in_valid_range" in pred.columns: + pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() + + pred_col = 
PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + pred["PRED_EDSS_group"] = pred["PRED_EDSS_numeric"].apply(classify_edss_group) + + pred = pred.dropna(subset=["PRED_EDSS_numeric", "PRED_EDSS_group"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt_numeric.merge( + pred, + on="row_index", + how="left", + suffixes=("_gt", "_pred") + ) + + merged["prediction_available"] = merged["PRED_EDSS_group"].notna() + eval_df = merged[merged["prediction_available"]].copy() + + if eval_df.empty: + print("No evaluable rows.") + continue + + n_evaluable = len(eval_df) + n_missing_or_invalid_against_gt_numeric = int((~merged["prediction_available"]).sum()) + + print(f"Model: {model_name}") + print(f"Raw prediction rows: {raw_prediction_rows}") + print(f"Evaluable rows: {n_evaluable}") + print(f"Missing/invalid vs GT numeric: {n_missing_or_invalid_against_gt_numeric}") + + # Multiclass confusion matrix across 3 severity groups + cm = confusion_matrix( + eval_df["GT_EDSS_group"], + eval_df["PRED_EDSS_group"], + labels=SEVERITY_GROUPS + ) + + cm_df = pd.DataFrame( + cm, + index=SEVERITY_GROUPS, + columns=SEVERITY_GROUPS + ) + cm_df.index.name = "Ground truth severity group" + cm_df.columns.name = "Predicted severity group" + + print("\nSeverity-group confusion matrix:") + print(cm_df) + + for gt_group in SEVERITY_GROUPS: + for pred_group in SEVERITY_GROUPS: + confusion_rows.append({ + "model": model_name, + "iteration": TARGET_ITERATION, + "gt_group": gt_group, + "pred_group": pred_group, + "count": int(cm_df.loc[gt_group, pred_group]), + }) + + # One-vs-rest sensitivity/specificity for each severity group + for group in SEVERITY_GROUPS: + y_true = eval_df["GT_EDSS_group"] == group + y_pred = eval_df["PRED_EDSS_group"] == group + + tn, fp, fn, tp = confusion_matrix( + y_true, + y_pred, + labels=[False, True] + ).ravel() + + sensitivity = rate(tp, tp + fn) + 
specificity = rate(tn, tn + fp) + ppv = rate(tp, tp + fp) + npv = rate(tn, tn + fn) + accuracy = rate(tp + tn, tp + tn + fp + fn) + + gt_positive_prevalence = rate(tp + fn, n_evaluable) + predicted_positive_rate = rate(tp + fp, n_evaluable) + + summary_rows.append({ + "model": model_name, + "iteration": TARGET_ITERATION, + "result_file": str(result_file), + "severity_group": group, + + "n_total_gt_rows": n_total_gt_rows, + "n_gt_numeric_in_groups": n_gt_numeric, + "raw_prediction_rows": raw_prediction_rows, + "n_evaluable": n_evaluable, + "n_missing_or_invalid_against_gt_numeric": n_missing_or_invalid_against_gt_numeric, + + "true_positives": int(tp), + "true_negatives": int(tn), + "false_positives": int(fp), + "false_negatives": int(fn), + + "sensitivity": sensitivity, + "specificity": specificity, + "positive_predictive_value": ppv, + "negative_predictive_value": npv, + "accuracy": accuracy, + + "sensitivity_percent": sensitivity * 100, + "specificity_percent": specificity * 100, + "positive_predictive_value_percent": ppv * 100, + "negative_predictive_value_percent": npv * 100, + "accuracy_percent": accuracy * 100, + + "gt_positive_prevalence": gt_positive_prevalence, + "gt_positive_prevalence_percent": gt_positive_prevalence * 100, + "predicted_positive_rate": predicted_positive_rate, + "predicted_positive_rate_percent": predicted_positive_rate * 100, + }) + + print( + f"Group {group}: " + f"TP={tp}, TN={tn}, FP={fp}, FN={fn}, " + f"sensitivity={sensitivity * 100:.1f}%, " + f"specificity={specificity * 100:.1f}%" + ) + + # Long per-case output + tmp = eval_df.copy() + + tmp["severity_match"] = tmp["GT_EDSS_group"] == tmp["PRED_EDSS_group"] + + keep_cols = [ + "model", + "iteration", + "row_index", + "unique_id_gt" if "unique_id_gt" in tmp.columns else "unique_id", + "MedDatum_gt" if "MedDatum_gt" in tmp.columns else "MedDatum", + "GT_EDSS_numeric", + "PRED_EDSS_numeric", + "GT_EDSS_group", + "PRED_EDSS_group", + "severity_match", + "raw_EDSS", + "EDSS_numeric", 
+ "EDSS_in_valid_range", + "klassifizierbar", + "clinical_output_valid", + "edss_logic_valid", + "certainty_percent", + "reason", + "inference_time_sec", + ] + + keep_cols = [c for c in keep_cols if c in tmp.columns] + + for _, row in tmp[keep_cols].iterrows(): + row_dict = row.to_dict() + row_dict["model_for_analysis"] = model_name + long_rows.append(row_dict) + + +# ========================= +# SAVE OUTPUT +# ========================= + +summary_df = pd.DataFrame(summary_rows) +long_df = pd.DataFrame(long_rows) +confusion_df = pd.DataFrame(confusion_rows) + +summary_df.to_csv(OUTPUT_SUMMARY_PATH, index=False) +long_df.to_csv(OUTPUT_LONG_PATH, index=False) +confusion_df.to_csv(OUTPUT_CONFUSION_PATH, index=False) + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 240) + +print("\n" + "=" * 100) +print("EDSS severity-group performance summary:") +print(summary_df) + +print("\nSaved:") +print(OUTPUT_SUMMARY_PATH) +print(OUTPUT_LONG_PATH) +print(OUTPUT_CONFUSION_PATH) +## + + + + +# %% Coverage table: model evaluable predictions vs ground truth + +from pathlib import Path +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_PATH = RUN_DIR / f"model_coverage_table_iter_{TARGET_ITERATION}.csv" + +GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" + + +# ========================= +# HELPERS +# ========================= + +def to_num(s): + return pd.to_numeric( + s.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def to_bool(s): + return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) + + +def find_iter_file(model_dir, 
iteration): + files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) + + files = [ + f for f in files + if "incremental" not in f.name.lower() + and "summary" not in f.name.lower() + and "all_results" not in f.name.lower() + ] + + return files[0] if files else None + + +def get_model_name(df, model_dir): + if "model" in df.columns and df["model"].notna().any(): + return str(df["model"].dropna().iloc[0]) + return model_dir.name + + +# ========================= +# LOAD GROUND TRUTH +# ========================= + +gt = pd.read_csv(GT_PATH, sep=";") +gt["row_index"] = gt.index +gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) + +total_records = len(gt) +numeric_gt_edss = gt["GT_EDSS_numeric"].notna().sum() + +gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy() + +print(f"Total records: {total_records}") +print(f"Numeric ground-truth EDSS: {numeric_gt_edss}") + + +# ========================= +# MODEL COVERAGE TABLE +# ========================= + +rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") + and not p.name.startswith("repeated_run") + and not p.name.startswith("edss_error_distribution") + and not p.name.startswith("edss_threshold_metrics") + and not p.name.startswith("edss_severity_group_metrics") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"No iteration {TARGET_ITERATION} file found for {model_dir.name}") + continue + + pred_raw = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred_raw.columns: + print(f"Skipping {model_dir.name}: no row_index column") + continue + + model_name = get_model_name(pred_raw, model_dir) + + pred = pred_raw.copy() + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + if 
"success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + if "EDSS_is_numeric" in pred.columns: + pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() + + if "EDSS_in_valid_range" in pred.columns: + pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + + pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt_numeric.merge( + pred, + on="row_index", how="inner", suffixes=("_gt", "_pred") ) - print(f"Merged rows: {len(df_merged)}") + evaluable_predictions = len(merged) - df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy() - - print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}") - - if df_eval.empty: - raise ValueError("No evaluable rows after merging and EDSS filtering.") - - cm = confusion_matrix( - df_eval["GT_EDSS_cat"], - df_eval["PRED_EDSS_cat"], - labels=EDSS_LABELS + coverage_numeric_gt = ( + evaluable_predictions / numeric_gt_edss * 100 + if numeric_gt_edss > 0 else np.nan ) - suffix = f"iter_{TARGET_ITERATION}" - - plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png" - cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv" - report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt" - merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv" - - plot_confusion_matrix(cm, model_name, plot_path) - - cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS) - cm_df.index.name = "Ground Truth EDSS" - cm_df.columns.name = "LLM Generated EDSS" - cm_df.to_csv(cm_csv_path) - - report = classification_report( - df_eval["GT_EDSS_cat"], - df_eval["PRED_EDSS_cat"], - labels=EDSS_LABELS, - zero_division=0 + coverage_all_records = ( + evaluable_predictions / total_records * 100 + if total_records > 0 
else np.nan ) - with open(report_txt_path, "w", encoding="utf-8") as f: - f.write(f"Model: {model_name}\n") - f.write(f"Result file: {RESULT_PATH}\n") - f.write(f"Target iteration: {TARGET_ITERATION}\n") - f.write(f"Merged rows: {len(df_merged)}\n") - f.write(f"Evaluable rows: {len(df_eval)}\n\n") - f.write("Classification Report:\n") - f.write(report) - f.write("\n\nConfusion Matrix Raw Counts:\n") - f.write(cm_df.to_string()) + rows.append({ + "Model": model_name, + "Total records": total_records, + "Numeric ground-truth EDSS": numeric_gt_edss, + "Evaluable predictions": evaluable_predictions, + "Coverage of numeric ground truth (%)": coverage_numeric_gt, + "Coverage of all records (%)": coverage_all_records, + }) - keep_cols = [ - MERGE_KEY, - "MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum", - "GT_EDSS_numeric", - "PRED_EDSS_numeric", - "GT_EDSS_cat", - "PRED_EDSS_cat", - "model", - "iteration", - "success", - "inference_time_sec", - "certainty_percent", - "reason", - ] + print( + f"{model_name}: " + f"evaluable={evaluable_predictions}, " + f"coverage numeric GT={coverage_numeric_gt:.1f}%, " + f"coverage all={coverage_all_records:.1f}%" + ) - keep_cols = [col for col in keep_cols if col in df_eval.columns] - df_eval[keep_cols].to_csv(merged_csv_path, index=False) - print("\nClassification Report:") - print(report) +coverage_df = pd.DataFrame(rows) - print("\nConfusion Matrix Raw Counts:") - print(cm_df) +coverage_df.to_csv(OUTPUT_PATH, index=False) - print("\nSaved files:") - print(f"Plot: {plot_path}") - print(f"Confusion matrix: {cm_csv_path}") - print(f"Report: {report_txt_path}") - print(f"Merged rows: {merged_csv_path}") +print("\nCoverage table:") +print(coverage_df) - print("\nDone.") +print(f"\nSaved to:") +print(OUTPUT_PATH) ## + +# %% Dataset exploration table for EDSS project + +from pathlib import Path +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +DATA_PATH 
= Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +OUTPUT_DIR = DATA_PATH.parent / "dataset_exploration" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_TABLE_PATH = OUTPUT_DIR / "dataset_column_exploration_table.csv" +OUTPUT_NUMERIC_SUMMARY_PATH = OUTPUT_DIR / "reference_numeric_summary.csv" +OUTPUT_VALUE_COUNTS_PATH = OUTPUT_DIR / "reference_value_counts.csv" +OUTPUT_TEXT_SUMMARY_PATH = OUTPUT_DIR / "model_input_text_summary.csv" + + +# ========================= +# COLUMN DEFINITIONS +# ========================= + +COLUMNS_TO_EXPLORE = [ + { + "output_name": "unique_id", + "source_col": "unique_id", + "variable_type": "Identifier", + "role": "Patient/record linkage", + "description": "Pseudonymized unique identifier generated by hashing patient-related identifiers", + "used_for_model_input": "No", + "used_as_ground_truth": "No", + }, + { + "output_name": "MedDatum", + "source_col": "MedDatum", + "variable_type": "Date", + "role": "Visit metadata", + "description": "Date of clinical visit or medical documentation", + "used_for_model_input": "No", + "used_as_ground_truth": "No", + }, + { + "output_name": "T_Zusammenfassung", + "source_col": "T_Zusammenfassung", + "variable_type": "Text", + "role": "Model input", + "description": "Clinical summary section", + "used_for_model_input": "Yes", + "used_as_ground_truth": "No", + }, + { + "output_name": "Diagnosen", + "source_col": "Diagnosen", + "variable_type": "Text", + "role": "Model input", + "description": "Diagnostic information and coded/free-text diagnoses", + "used_for_model_input": "Yes", + "used_as_ground_truth": "No", + }, + { + "output_name": "T_KlinBef", + "source_col": "T_KlinBef", + "variable_type": "Text", + "role": "Model input", + "description": "Clinical examination findings", + "used_for_model_input": "Yes", + "used_as_ground_truth": "No", + }, + { + "output_name": "T_Befunde", + "source_col": "T_Befunde", 
+ "variable_type": "Text", + "role": "Model input", + "description": "Additional findings and reports", + "used_for_model_input": "Yes", + "used_as_ground_truth": "No", + }, + { + "output_name": "EDSS_reference", + "source_col": "EDSS", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted reference EDSS score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "VISUAL_OPTIC_FUNCTIONS_reference", + "source_col": "Sehvermögen", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted visual/optic functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "BRAINSTEM_FUNCTIONS_reference", + "source_col": "Hirnstamm", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted brainstem functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "PYRAMIDAL_FUNCTIONS_reference", + "source_col": "Pyramidalmotorik", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted pyramidal functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "CEREBELLAR_FUNCTIONS_reference", + "source_col": "Cerebellum", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted cerebellar functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "SENSORY_FUNCTIONS_reference", + "source_col": "Sensibiliät", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted sensory functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "BOWEL_AND_BLADDER_FUNCTIONS_reference", + "source_col": "Blasen-_und_Mastdarmfunktion", + "variable_type": "Numeric", + "role": 
"Ground truth", + "description": "Manually extracted bowel and bladder functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "CEREBRAL_FUNCTIONS_reference", + "source_col": "Cerebrale_Funktion", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted cerebral functional system score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, + { + "output_name": "AMBULATION_reference", + "source_col": "Ambulation", + "variable_type": "Numeric", + "role": "Ground truth", + "description": "Manually extracted ambulation score", + "used_for_model_input": "No", + "used_as_ground_truth": "Yes", + }, +] + + +TEXT_COLUMNS = [ + "T_Zusammenfassung", + "Diagnosen", + "T_KlinBef", + "T_Befunde", +] + +NUMERIC_REFERENCE_COLUMNS = [ + item for item in COLUMNS_TO_EXPLORE + if item["variable_type"] == "Numeric" +] + + +# ========================= +# HELPERS +# ========================= + +def to_num(series): + return pd.to_numeric( + series.astype(str).str.replace(",", ".", regex=False), + errors="coerce" + ) + + +def is_non_missing_value(series): + """ + Treat NaN, empty string, whitespace-only string, and literal 'nan'/'None' as missing. 
+ """ + s = series.copy() + + missing = s.isna() + + s_str = s.astype(str).str.strip() + missing = ( + missing + | (s_str == "") + | (s_str.str.lower().isin(["nan", "none", "null", "na", "n/a"])) + ) + + return ~missing + + +def text_length_stats(series): + non_missing = is_non_missing_value(series) + lengths = series[non_missing].astype(str).str.len() + + if len(lengths) == 0: + return { + "mean_text_length_chars": np.nan, + "median_text_length_chars": np.nan, + "min_text_length_chars": np.nan, + "max_text_length_chars": np.nan, + } + + return { + "mean_text_length_chars": lengths.mean(), + "median_text_length_chars": lengths.median(), + "min_text_length_chars": lengths.min(), + "max_text_length_chars": lengths.max(), + } + + +# ========================= +# LOAD DATA +# ========================= + +df = pd.read_csv(DATA_PATH, sep=";") +total_n = len(df) + +print(f"Loaded rows: {total_n}") +print(f"Loaded columns: {len(df.columns)}") + + +# ========================= +# MAIN EXPLORATION TABLE +# ========================= + +rows = [] + +for item in COLUMNS_TO_EXPLORE: + source_col = item["source_col"] + + if source_col not in df.columns: + non_missing_n = 0 + non_missing_pct = 0.0 + status = "missing_column" + else: + if item["variable_type"] == "Numeric": + numeric_values = to_num(df[source_col]) + non_missing_n = int(numeric_values.notna().sum()) + else: + non_missing_n = int(is_non_missing_value(df[source_col]).sum()) + + non_missing_pct = non_missing_n / total_n * 100 if total_n > 0 else np.nan + status = "ok" + + rows.append({ + "Variable / Column": item["output_name"], + "Source column": source_col, + "Variable type": item["variable_type"], + "Role in study": item["role"], + "Description": item["description"], + "Non-missing n / total": f"{non_missing_n} / {total_n}", + "Non-missing n": non_missing_n, + "Total n": total_n, + "Non-missing %": round(non_missing_pct, 1), + "Used for model input": item["used_for_model_input"], + "Used as ground truth": 
item["used_as_ground_truth"], + "Status": status, + }) + +exploration_df = pd.DataFrame(rows) +exploration_df.to_csv(OUTPUT_TABLE_PATH, index=False) + + +# ========================= +# NUMERIC REFERENCE SUMMARY +# ========================= + +numeric_rows = [] + +for item in NUMERIC_REFERENCE_COLUMNS: + source_col = item["source_col"] + + if source_col not in df.columns: + continue + + values = to_num(df[source_col]).dropna() + + if values.empty: + numeric_rows.append({ + "Variable / Column": item["output_name"], + "Source column": source_col, + "n": 0, + "non_zero_n": 0, + "non_zero_percent": np.nan, + "mean": np.nan, + "median": np.nan, + "std": np.nan, + "min": np.nan, + "max": np.nan, + }) + continue + + non_zero_n = int((values != 0).sum()) + + numeric_rows.append({ + "Variable / Column": item["output_name"], + "Source column": source_col, + "n": int(values.notna().sum()), + "non_zero_n": non_zero_n, + "non_zero_percent": non_zero_n / len(values) * 100, + "mean": values.mean(), + "median": values.median(), + "std": values.std(), + "min": values.min(), + "max": values.max(), + }) + +numeric_summary_df = pd.DataFrame(numeric_rows) +numeric_summary_df.to_csv(OUTPUT_NUMERIC_SUMMARY_PATH, index=False) + + +# ========================= +# VALUE COUNTS FOR REFERENCE NUMERIC COLUMNS +# ========================= + +value_count_rows = [] + +for item in NUMERIC_REFERENCE_COLUMNS: + source_col = item["source_col"] + + if source_col not in df.columns: + continue + + values = to_num(df[source_col]).dropna() + + counts = ( + values + .value_counts() + .sort_index() + ) + + for value, count in counts.items(): + value_count_rows.append({ + "Variable / Column": item["output_name"], + "Source column": source_col, + "value": value, + "count": int(count), + "percent_of_non_missing": count / len(values) * 100 if len(values) > 0 else np.nan, + "percent_of_total": count / total_n * 100 if total_n > 0 else np.nan, + }) + +value_counts_df = pd.DataFrame(value_count_rows) 
+value_counts_df.to_csv(OUTPUT_VALUE_COUNTS_PATH, index=False) + + +# ========================= +# TEXT INPUT SUMMARY +# ========================= + +text_rows = [] + +for col in TEXT_COLUMNS: + if col not in df.columns: + text_rows.append({ + "column": col, + "non_missing_n": 0, + "non_missing_percent": 0.0, + "mean_text_length_chars": np.nan, + "median_text_length_chars": np.nan, + "min_text_length_chars": np.nan, + "max_text_length_chars": np.nan, + }) + continue + + non_missing = is_non_missing_value(df[col]) + stats = text_length_stats(df[col]) + + text_rows.append({ + "column": col, + "non_missing_n": int(non_missing.sum()), + "non_missing_percent": non_missing.sum() / total_n * 100 if total_n > 0 else np.nan, + **stats, + }) + +text_summary_df = pd.DataFrame(text_rows) +text_summary_df.to_csv(OUTPUT_TEXT_SUMMARY_PATH, index=False) + + +# ========================= +# PRINT RESULTS +# ========================= + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 220) + +print("\n" + "=" * 100) +print("Dataset column exploration table:") +print(exploration_df) + +print("\n" + "=" * 100) +print("Numeric reference summary:") +print(numeric_summary_df) + +print("\n" + "=" * 100) +print("Text input summary:") +print(text_summary_df) + +print("\nSaved:") +print(OUTPUT_TABLE_PATH) +print(OUTPUT_NUMERIC_SUMMARY_PATH) +print(OUTPUT_VALUE_COUNTS_PATH) +print(OUTPUT_TEXT_SUMMARY_PATH) +## + + +# %% Dataset characteristics table + +from pathlib import Path +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +DATA_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +OUTPUT_DIR = DATA_PATH.parent / "dataset_exploration" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_CSV = OUTPUT_DIR / "dataset_characteristics_table.csv" +OUTPUT_MD = OUTPUT_DIR / "dataset_characteristics_table.md" 
# =========================
# HELPERS
# =========================

def to_num(series):
    """Convert a Series of possibly comma-decimal values to floats.

    Values are stringified, ``,`` is replaced by ``.`` and the result is
    parsed with ``pd.to_numeric``; anything unparseable becomes NaN.
    """
    normalized = series.astype(str).str.replace(",", ".", regex=False)
    return pd.to_numeric(normalized, errors="coerce")


def parse_dates(series):
    """Parse day-first (German/European) and ISO-like date strings.

    The whole Series is parsed in one vectorised call first. Newer pandas
    versions infer a single format from the first value and coerce values in
    any other format to NaT, so values that are present but came back as NaT
    are re-parsed one by one. Values that still cannot be parsed stay NaT.
    """
    parsed = pd.to_datetime(series, errors="coerce", dayfirst=True)

    # Non-Series inputs are returned exactly as pandas parsed them.
    if not isinstance(series, pd.Series):
        return parsed

    unresolved = parsed.isna() & series.notna()
    if unresolved.any():
        parsed = parsed.copy()
        parsed.loc[unresolved] = series.loc[unresolved].map(
            lambda value: pd.to_datetime(value, errors="coerce", dayfirst=True)
        )
    return parsed


def fmt_int(x):
    """Format *x* as an integer string, or "NA" when it is missing."""
    return "NA" if pd.isna(x) else f"{int(x)}"


def fmt_float(x, digits=1):
    """Format *x* with *digits* decimals, or "NA" when it is missing."""
    return "NA" if pd.isna(x) else f"{float(x):.{digits}f}"


def fmt_percent(n, d, digits=1):
    """Format ``n / d`` as a percentage, or "NA" when *d* is zero."""
    return "NA" if d == 0 else f"{(n / d * 100):.{digits}f}%"


def fmt_n_total_percent(n, total):
    """Format a count as "n / total, p%"."""
    return f"{int(n)} / {int(total)}, {fmt_percent(n, total)}"


def fmt_range(min_value, max_value, digits=1):
    """Format a float range as "min–max", or "NA" when either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{float(min_value):.{digits}f}–{float(max_value):.{digits}f}"


def fmt_record_range(min_value, max_value):
    """Format an integer range as "min–max", or "NA" when either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{int(min_value)}–{int(max_value)}"
= int(valid_dates.max().year) + documentation_period = f"{documentation_start_year}–{documentation_end_year}" +else: + documentation_period = "NA" + +edss_numeric = to_num(df[EDSS_COL]) +records_with_numeric_edss = int(edss_numeric.notna().sum()) +records_without_numeric_edss = int(total_records - records_with_numeric_edss) + + +# ========================= +# PATIENT RECORD COUNTS +# ========================= + +patient_counts = ( + df.groupby(PATIENT_ID_COL, dropna=True) + .size() + .reset_index(name="record_count") + .sort_values("record_count", ascending=False) +) + +patient_counts.to_csv(OUTPUT_PATIENT_COUNTS, index=False) + +median_records_per_patient = patient_counts["record_count"].median() +min_records_per_patient = patient_counts["record_count"].min() +max_records_per_patient = patient_counts["record_count"].max() + +patients_with_one_record = int((patient_counts["record_count"] == 1).sum()) +patients_with_multiple_records = int((patient_counts["record_count"] > 1).sum()) + + +# ========================= +# DUPLICATE VISIT EXPLORATION +# ========================= + +# This estimates duplicate visits using patient ID + documentation date. +# If a patient has multiple rows on the same MedDatum, rows beyond the first are counted as duplicate records. 
duplicate_subset = df.copy()
duplicate_subset["_parsed_MedDatum"] = dates

# duplicated() treats missing keys (NaN / NaT) as equal to each other, so rows
# with an unparseable date or a missing patient ID would be reported as
# same-day repeats of one another. Only rows with a complete
# (patient, date) key can be duplicate visits.
has_complete_visit_key = (
    duplicate_subset[PATIENT_ID_COL].notna()
    & duplicate_subset["_parsed_MedDatum"].notna()
)

# Within a (patient, date) pair the first row is the visit itself; every
# further row is counted as a duplicate record.
duplicate_rows_mask = has_complete_visit_key & duplicate_subset.duplicated(
    subset=[PATIENT_ID_COL, "_parsed_MedDatum"],
    keep="first"
)

records_excluded_as_duplicates = int(duplicate_rows_mask.sum())

# Number of distinct patients that contribute at least one duplicate row.
duplicate_patients = duplicate_subset.loc[
    duplicate_rows_mask,
    PATIENT_ID_COL
].nunique(dropna=True)

if records_excluded_as_duplicates == 0:
    duplicate_text = "0"
else:
    duplicate_text = (
        f"{records_excluded_as_duplicates} visits from "
        f"{duplicate_patients} patients"
    )


# =========================
# EDSS SUMMARY
# =========================

# Median, quartiles and range of the numeric reference EDSS values; all NaN
# when no record carries a numeric EDSS.
edss_valid = edss_numeric.dropna()

if len(edss_valid) > 0:
    median_edss = edss_valid.median()
    q1_edss = edss_valid.quantile(0.25)
    q3_edss = edss_valid.quantile(0.75)
    min_edss = edss_valid.min()
    max_edss = edss_valid.max()
else:
    median_edss = np.nan
    q1_edss = np.nan
    q3_edss = np.nan
    min_edss = np.nan
    max_edss = np.nan
"Characteristic": "Patients with one record", + "Value": fmt_n_total_percent(patients_with_one_record, unique_patients), + }, + { + "Characteristic": "Patients with multiple records", + "Value": fmt_n_total_percent(patients_with_multiple_records, unique_patients), + }, + { + "Characteristic": "Median reference EDSS", + "Value": fmt_float(median_edss, digits=1), + }, + { + "Characteristic": "IQR reference EDSS", + "Value": fmt_range(q1_edss, q3_edss, digits=1), + }, + { + "Characteristic": "Minimum–maximum reference EDSS", + "Value": fmt_range(min_edss, max_edss, digits=1), + }, +] + +characteristics_df = pd.DataFrame(rows) + + +# ========================= +# SAVE OUTPUT +# ========================= + +characteristics_df.to_csv(OUTPUT_CSV, index=False) + +with open(OUTPUT_MD, "w", encoding="utf-8") as f: + f.write(characteristics_df.to_markdown(index=False)) + f.write("\n") + + +# ========================= +# PRINT OUTPUT +# ========================= + +pd.set_option("display.max_colwidth", None) +pd.set_option("display.width", 160) + +print("\nDataset characteristics table:") +print(characteristics_df.to_markdown(index=False)) + +print("\nSaved:") +print(OUTPUT_CSV) +print(OUTPUT_MD) +print(OUTPUT_PATIENT_COUNTS) + +print("\nDuplicate estimate note:") +print( + "Duplicates were estimated as repeated rows with the same unique_id and MedDatum. " + "If you already removed duplicates before this file, this value may be 0." 
# =========================
# HELPERS
# =========================

def to_num(series):
    """Convert a Series of possibly comma-decimal values to floats.

    Values are stringified, ``,`` is replaced by ``.`` and the result is
    parsed with ``pd.to_numeric``; anything unparseable becomes NaN.
    """
    normalized = series.astype(str).str.replace(",", ".", regex=False)
    return pd.to_numeric(normalized, errors="coerce")


def parse_dates(series):
    """Parse day-first (German/European) and ISO-like date strings.

    The whole Series is parsed in one vectorised call first. Newer pandas
    versions infer a single format from the first value and coerce values in
    any other format to NaT, so values that are present but came back as NaT
    are re-parsed one by one. Values that still cannot be parsed stay NaT.
    """
    parsed = pd.to_datetime(series, errors="coerce", dayfirst=True)

    # Non-Series inputs are returned exactly as pandas parsed them.
    if not isinstance(series, pd.Series):
        return parsed

    unresolved = parsed.isna() & series.notna()
    if unresolved.any():
        parsed = parsed.copy()
        parsed.loc[unresolved] = series.loc[unresolved].map(
            lambda value: pd.to_datetime(value, errors="coerce", dayfirst=True)
        )
    return parsed


def fmt_int(x):
    """Format *x* as an integer string, or "NA" when it is missing."""
    return "NA" if pd.isna(x) else f"{int(x)}"


def fmt_float(x, digits=1):
    """Format *x* with *digits* decimals, or "NA" when it is missing."""
    return "NA" if pd.isna(x) else f"{float(x):.{digits}f}"


def fmt_percent(n, d, digits=1):
    """Format ``n / d`` as a percentage, or "NA" when *d* is zero."""
    return "NA" if d == 0 else f"{(n / d * 100):.{digits}f}%"


def fmt_n_total_percent(n, total):
    """Format a count as "n / total, p%"."""
    return f"{int(n)} / {int(total)}, {fmt_percent(n, total)}"


def fmt_range(min_value, max_value, digits=1):
    """Format a float range as "min–max", or "NA" when either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{float(min_value):.{digits}f}–{float(max_value):.{digits}f}"


def fmt_record_range(min_value, max_value):
    """Format an integer range as "min–max", or "NA" when either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{int(min_value)}–{int(max_value)}"
total_visits = total_records

dates = parse_dates(df[DATE_COL])
valid_dates = dates.dropna()

# Documentation period is reported as "first year–last year" of the
# parseable dates, or "NA" when no date could be parsed.
if valid_dates.empty:
    documentation_period = "NA"
else:
    documentation_start_year = int(valid_dates.min().year)
    documentation_end_year = int(valid_dates.max().year)
    documentation_period = f"{documentation_start_year}–{documentation_end_year}"

edss_numeric = to_num(df[EDSS_COL])
records_with_numeric_edss = int(edss_numeric.count())
records_without_numeric_edss = int(total_records - records_with_numeric_edss)


# =========================
# PATIENT RECORD / VISIT COUNTS
# =========================

# One row per patient with the number of records; rows without a patient ID
# are left out by the groupby.
records_per_patient = df.groupby(PATIENT_ID_COL, dropna=True).size()
patient_counts = records_per_patient.reset_index(name="record_count").sort_values(
    "record_count", ascending=False
)

patient_counts.to_csv(OUTPUT_PATIENT_COUNTS, index=False)

record_count = patient_counts["record_count"]

median_records_per_patient = record_count.median()
min_records_per_patient = record_count.min()
max_records_per_patient = record_count.max()

patients_with_one_record = int(record_count.eq(1).sum())
patients_with_multiple_records = int(record_count.gt(1).sum())

# Exact-count buckets for 2..7 records; everything above goes into ">7".
patients_with_n_records = {
    n: int(record_count.eq(n).sum())
    for n in range(2, 8)
}

patients_with_more_than_7_records = int(record_count.gt(7).sum())


def _percent_of_patients(patients_n):
    """Share of all unique patients in percent; NaN when there are none."""
    if not unique_patients:
        return np.nan
    return patients_n / unique_patients * 100


visit_distribution_rows = [
    {
        "records_per_patient": "1",
        "patients_n": patients_with_one_record,
        "total_patients": unique_patients,
        "patients_percent": _percent_of_patients(patients_with_one_record),
    }
]

for n in range(2, 8):
    visit_distribution_rows.append({
        "records_per_patient": str(n),
        "patients_n": patients_with_n_records[n],
        "total_patients": unique_patients,
        "patients_percent": _percent_of_patients(patients_with_n_records[n]),
    })

visit_distribution_rows.append({
    "records_per_patient": ">7",
    "patients_n": patients_with_more_than_7_records,
    "total_patients": unique_patients,
    "patients_percent": patients_with_more_than_7_records / unique_patients * 100 if unique_patients else np.nan,
})

visit_distribution_df = pd.DataFrame(visit_distribution_rows)
visit_distribution_df.to_csv(OUTPUT_VISIT_DISTRIBUTION, index=False)


# =========================
# DUPLICATE VISIT EXPLORATION
# =========================

# This estimates duplicate visits using patient ID + documentation date.
# If a patient has multiple rows on the same MedDatum, rows beyond the first are counted as duplicate records.
duplicate_subset = df.copy()
duplicate_subset["_parsed_MedDatum"] = dates

# DataFrame.duplicated treats missing values as equal to each other, so rows
# without a patient ID or without a parseable date would otherwise be counted
# as "the same visit". Only rows with a complete key can be duplicates.
has_complete_key = (
    duplicate_subset[PATIENT_ID_COL].notna()
    & duplicate_subset["_parsed_MedDatum"].notna()
)

duplicate_rows_mask = has_complete_key & duplicate_subset.duplicated(
    subset=[PATIENT_ID_COL, "_parsed_MedDatum"],
    keep="first"
)

records_excluded_as_duplicates = int(duplicate_rows_mask.sum())

duplicate_patients = duplicate_subset.loc[
    duplicate_rows_mask,
    PATIENT_ID_COL
].nunique(dropna=True)

if records_excluded_as_duplicates == 0:
    duplicate_text = "0"
else:
    duplicate_text = (
        f"{records_excluded_as_duplicates} visits from "
        f"{duplicate_patients} patients"
    )


# =========================
# EDSS SUMMARY
# =========================

edss_valid = edss_numeric.dropna()

# Summary statistics over records with a numeric reference EDSS; all NaN
# (rendered as "NA") when there are none.
if len(edss_valid) > 0:
    median_edss = edss_valid.median()
    q1_edss = edss_valid.quantile(0.25)
    q3_edss = edss_valid.quantile(0.75)
    min_edss = edss_valid.min()
    max_edss = edss_valid.max()
else:
    median_edss = np.nan
    q1_edss = np.nan
    q3_edss = np.nan
    min_edss = np.nan
    max_edss = np.nan


# =========================
# BUILD CHARACTERISTICS TABLE
# =========================

# Two-column table (Characteristic / Value) with pre-formatted strings.
rows = [
    {
        "Characteristic": "Total clinical records",
        "Value": fmt_int(total_records),
    },
    {
        "Characteristic": "Total visits",
        "Value": fmt_int(total_visits),
    },
    {
        "Characteristic": "Unique patients",
        "Value": fmt_int(unique_patients),
    },
    {
        "Characteristic": "Documentation period",
        "Value": documentation_period,
    },
    {
        "Characteristic": "Records excluded as duplicates",
        "Value": duplicate_text,
    },
    {
        "Characteristic": "Records with numeric reference EDSS",
        "Value": fmt_n_total_percent(records_with_numeric_edss, total_records),
    },
    {
        "Characteristic": "Records without numeric reference EDSS",
        "Value": fmt_n_total_percent(records_without_numeric_edss, total_records),
    },
    {
        "Characteristic": "Median records per patient",
        "Value": fmt_float(median_records_per_patient, digits=1),
    },
    {
        "Characteristic": "Range of records per patient",
        "Value": fmt_record_range(min_records_per_patient, max_records_per_patient),
    },
    {
        "Characteristic": "Patients with one record",
        "Value": fmt_n_total_percent(patients_with_one_record, unique_patients),
    },
    {
        "Characteristic": "Patients with multiple records",
        "Value": fmt_n_total_percent(patients_with_multiple_records, unique_patients),
    },
]

for n in range(2, 8):
    rows.append({
        "Characteristic": f"Patients with {n} records",
        "Value": fmt_n_total_percent(patients_with_n_records[n], unique_patients),
    })

rows.append({
    "Characteristic": "Patients with >7 records",
    "Value": fmt_n_total_percent(patients_with_more_than_7_records, unique_patients),
})

rows.extend([
    {
        "Characteristic": "Median reference EDSS",
        "Value": fmt_float(median_edss, digits=1),
    },
    {
        "Characteristic": "IQR reference EDSS",
        "Value": fmt_range(q1_edss, q3_edss, digits=1),
    },
    {
        "Characteristic": "Minimum–maximum reference EDSS",
        "Value": fmt_range(min_edss, max_edss, digits=1),
    },
])

characteristics_df = pd.DataFrame(rows)


# =========================
# SAVE OUTPUT
# =========================

characteristics_df.to_csv(OUTPUT_CSV, index=False)

with open(OUTPUT_MD, "w", encoding="utf-8") as f:
    f.write(characteristics_df.to_markdown(index=False))
    f.write("\n")


# =========================
# PRINT OUTPUT
# =========================

pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", 180)

print("\nDataset characteristics table:")
print(characteristics_df.to_markdown(index=False))

print("\nPatient visit-count distribution:")
print(visit_distribution_df.to_markdown(index=False))

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_MD)
print(OUTPUT_PATIENT_COUNTS)
print(OUTPUT_VISIT_DISTRIBUTION)

print("\nDuplicate estimate note:")
print(
    "Duplicates were estimated as repeated rows with the same unique_id and MedDatum. "
    "If you already removed duplicates before this file, this value may be 0."
)
##



# %% Structured-output validity bar chart grouped by metric

from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


# =========================
# CONFIGURATION
# =========================

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

OUTPUT_DIR = RUN_DIR / "structured_output_validity_figure"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "structured_output_validity_bar_chart_grouped_by_metric.svg"
OUTPUT_PNG = OUTPUT_DIR / "structured_output_validity_bar_chart_grouped_by_metric.png"
OUTPUT_CSV = OUTPUT_DIR / "structured_output_validity_table_grouped_by_metric.csv"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def find_summary_file(model_dir):
    """Return the alphabetically first ``*_summary_*.csv`` in ``model_dir``, or None."""
    files = sorted(model_dir.glob("*_summary_*.csv"))
    return files[0] if files else None


def percent_from_rate(value):
    """Normalise a rate to a 0–100 percentage.

    Values <= 1.0 are treated as fractions and multiplied by 100; larger
    values are returned unchanged. Missing values give NaN.
    NOTE(review): a genuine percentage of 1.0 or below would be scaled up
    as well — confirm the summary files store fractions.
    """
    if pd.isna(value):
        return np.nan

    value = float(value)

    if value <= 1.0:
        return value * 100

    return value


def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    name = str(name)

    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }

    return replacements.get(name, name)


# =========================
# LOAD SUMMARY DATA
# =========================

rows = []

# Sub-directories of the run folder that hold derived analyses, not model output.
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    summary_file = find_summary_file(model_dir)

    if summary_file is None:
        print(f"No summary file found in {model_dir}")
        continue

    df = pd.read_csv(summary_file)

    if df.empty:
        print(f"Empty summary file: {summary_file}")
        continue

    # Only the first row of the summary file is used.
    row = df.iloc[0]

    model = row.get("model", model_dir.name)
    model_display = clean_model_name(model)

    success_rate = percent_from_rate(row.get("success_rate", np.nan))

    # Fall back to the alternative column name when the primary one is absent.
    if "clinical_output_valid_rate" in df.columns:
        clinical_output_valid_rate = percent_from_rate(
            row.get("clinical_output_valid_rate", np.nan)
        )
    else:
        clinical_output_valid_rate = percent_from_rate(
            row.get("clinical_range_valid_rate", np.nan)
        )

    edss_valid_range_rate = percent_from_rate(
        row.get("EDSS_valid_range_rate", np.nan)
    )

    rows.append({
        "model": model,
        "model_display": model_display,
        "Success rate": success_rate,
        "Clinical-output validity": clinical_output_valid_rate,
        "EDSS valid-range rate": edss_valid_range_rate,
        "summary_file": str(summary_file),
    })


validity_df = pd.DataFrame(rows)

if validity_df.empty:
    raise ValueError("No model summary data found.")


# Optional model order
model_order = ["GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it"]

# Models found on disk but missing from the preferred order are appended
# (alphabetically); otherwise the categorical would turn them into NaN and
# they would be plotted under the label "nan".
extra_models = sorted(set(validity_df["model_display"]) - set(model_order))

validity_df["model_display"] = pd.Categorical(
    validity_df["model_display"],
    categories=model_order + extra_models,
    ordered=True
)
validity_df = validity_df.sort_values("model_display").reset_index(drop=True)

validity_df.to_csv(OUTPUT_CSV, index=False)

print("\nStructured-output validity table:")
print(validity_df)


# =========================
# PLOT
# =========================

metrics = [
    "Success rate",
    "Clinical-output validity",
    "EDSS valid-range rate",
]

models = validity_df["model_display"].astype(str).tolist()

x = np.arange(len(metrics))
n_models = len(models)
# 0.22 for up to three models; narrower when more models share one group so
# neighbouring metric groups do not overlap.
bar_width = min(0.22, 0.8 / max(n_models, 1))

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(10, 6))

for i, model in enumerate(models):
    values = [
        validity_df.loc[validity_df["model_display"].astype(str) == model, metric].iloc[0]
        for metric in metrics
    ]

    # Centre the group of model bars on each metric tick.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        label=model,
        color=colors.get(model, None),
        edgecolor="white",
        linewidth=0.8,
    )

    for bar, value in zip(bars, values):
        if pd.notna(value):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1,
                f"{value:.1f}%",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                rotation=0,
            )

ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=10)

# Head-room above 100% for the value labels.
ax.set_ylim(0, 110)
ax.set_ylabel("Percentage of responses", fontsize=11, fontweight="bold")
ax.set_xlabel("Structured-output metric", fontsize=11, fontweight="bold")

#ax.set_title(
#    "Structured-output validity by metric and model",
#    fontsize=13,
#    fontweight="bold",
#    pad=15,
#)

ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.92])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_SVG)
print(OUTPUT_PNG)
##


# %% EDSS severity-group confusion heatmaps per model

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# =========================
# CONFIGURATION
# =========================

INPUT_LONG_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
    "run_20260528_103942/edss_severity_group_metrics_iter_1/"
    "edss_severity_group_predictions_long_iter_1.csv"
)

OUTPUT_DIR = INPUT_LONG_PATH.parent / "severity_group_heatmaps"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Options:
# "count"       -> cell values are raw counts
# "row_percent" -> cell values are percentages within each ground-truth row
PLOT_MODE = "row_percent"
# PLOT_MODE = "count"

# Group keys as stored in the data (ASCII hyphen) ...
GROUP_ORDER = [
    "0.0-3.5",
    "4.0-5.5",
    "6.0-10.0",
]

# ... and the matching tick labels (en dash), in the same order.
GROUP_LABELS = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def safe_model_name(name):
    """Return ``name`` with "/", " " and ":" replaced by "_" for use in file names."""
    return (
        str(name)
        .replace("/", "_")
        .replace(" ", "_")
        .replace(":", "_")
    )


def make_confusion_table(df_model):
    """Cross-tabulate ground-truth (rows) against predicted (columns) groups.

    The table always has the full GROUP_ORDER layout; absent combinations
    are 0 and labels outside GROUP_ORDER are dropped by the reindex.
    """
    cm = pd.crosstab(
        df_model["GT_EDSS_group"],
        df_model["PRED_EDSS_group"],
        dropna=False
    )

    cm = cm.reindex(index=GROUP_ORDER, columns=GROUP_ORDER, fill_value=0)
    return cm


def row_percent_table(cm):
    """Convert counts to percentages within each row; empty rows become 0."""
    row_sums = cm.sum(axis=1).replace(0, np.nan)
    pct = cm.div(row_sums, axis=0) * 100
    return pct.fillna(0)


def plot_heatmap(cm_counts, model_name, plot_mode):
    """Draw, save and show one confusion heatmap.

    Parameters:
        cm_counts: count matrix as produced by ``make_confusion_table``.
        model_name: used in the title and (sanitised) in the file names.
        plot_mode: "count" or "row_percent".

    Returns:
        (svg_path, png_path) of the files written below OUTPUT_DIR.

    Raises:
        ValueError: for an unknown ``plot_mode``.
    """
    if plot_mode == "count":
        plot_data = cm_counts.copy()
        annot = plot_data.astype(int).astype(str)
        fmt = ""
        cbar_label = "Number of cases"
        title_suffix = "Counts"
        vmax = None

    elif plot_mode == "row_percent":
        plot_data = row_percent_table(cm_counts)
        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; use whichever this pandas version provides.
        if hasattr(pd.DataFrame, "map"):
            elementwise = plot_data.map
        else:
            elementwise = plot_data.applymap
        annot = elementwise(lambda x: f"{x:.1f}%")
        fmt = ""
        cbar_label = "Row percentage"
        title_suffix = "Row percentages"
        vmax = 100

    else:
        raise ValueError(f"Unknown PLOT_MODE: {plot_mode}")

    fig, ax = plt.subplots(figsize=(7, 6))

    # Annotations are pre-formatted strings, hence fmt="".
    sns.heatmap(
        plot_data,
        annot=annot,
        fmt=fmt,
        cmap="Blues",
        vmin=0,
        vmax=vmax,
        xticklabels=GROUP_LABELS,
        yticklabels=GROUP_LABELS,
        linewidths=0.8,
        linecolor="white",
        square=True,
        cbar_kws={"label": cbar_label},
        ax=ax,
    )

    ax.set_xlabel("Predicted EDSS severity group", fontsize=11, fontweight="bold")
    ax.set_ylabel("Ground-truth EDSS severity group", fontsize=11, fontweight="bold")

    ax.set_title(
        f"EDSS Severity-Group Confusion Matrix\n{model_name} | {title_suffix}",
        fontsize=13,
        fontweight="bold",
        pad=15,
    )

    plt.xticks(rotation=0)
    plt.yticks(rotation=0)
    plt.tight_layout()

    safe_name = safe_model_name(model_name)
    svg_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_heatmap_{plot_mode}.svg"
    png_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_heatmap_{plot_mode}.png"

    plt.savefig(svg_path, format="svg", bbox_inches="tight")
    plt.savefig(png_path, dpi=300, bbox_inches="tight")
    plt.show()
    # This function is called once per model; release the figure so open
    # figures do not accumulate across calls.
    plt.close(fig)

    return svg_path, png_path


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(INPUT_LONG_PATH)

required_cols = [
    "GT_EDSS_group",
    "PRED_EDSS_group",
]

for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Missing required column: {col}")

# Prefer the analysis-specific model column when it exists.
if "model_for_analysis" in df.columns:
    model_col = "model_for_analysis"
elif "model" in df.columns:
    model_col = "model"
else:
    raise ValueError("No model column found. Expected 'model_for_analysis' or 'model'.")

# Rows without both a ground-truth and a predicted group cannot be tabulated.
df = df.dropna(subset=["GT_EDSS_group", "PRED_EDSS_group"]).copy()

print(f"Loaded rows: {len(df)}")
print(f"Models: {sorted(df[model_col].dropna().unique())}")


# =========================
# CREATE HEATMAPS
# =========================

summary_rows = []

for model_name, df_model in df.groupby(model_col):
    print("\n" + "=" * 80)
    print(f"Model: {model_name}")
    print(f"Rows: {len(df_model)}")

    cm_counts = make_confusion_table(df_model)
    cm_row_pct = row_percent_table(cm_counts)

    print("\nCount matrix:")
    print(cm_counts)

    print("\nRow percentage matrix:")
    print(cm_row_pct.round(1))

    svg_path, png_path = plot_heatmap(
        cm_counts=cm_counts,
        model_name=model_name,
        plot_mode=PLOT_MODE,
    )

    safe_name = safe_model_name(model_name)

    # Both matrices are saved regardless of which one was plotted.
    counts_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_counts.csv"
    row_pct_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_row_percent.csv"

    cm_counts.to_csv(counts_path)
    cm_row_pct.to_csv(row_pct_path)

    summary_rows.append({
        "model": model_name,
        "plot_mode": PLOT_MODE,
        "n_rows": len(df_model),
        "svg_path": str(svg_path),
        "png_path": str(png_path),
        "counts_path": str(counts_path),
        "row_percent_path": str(row_pct_path),
    })

    print("\nSaved:")
    print(svg_path)
    print(png_path)
    print(counts_path)
    print(row_pct_path)


# =========================
# SAVE SUMMARY
# =========================

summary_df = pd.DataFrame(summary_rows)
summary_path = OUTPUT_DIR / f"severity_group_heatmap_summary_{PLOT_MODE}.csv"
summary_df.to_csv(summary_path, index=False)

print("\n" + "=" * 80)
print("Done.")
print(f"Summary saved to: {summary_path}")
##



# %% Grouped bar chart of patient-level EDSS range across 10 runs

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

INPUT_FILES = [
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gemma-4-31B-it_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gpt-oss-120b_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "qwen3.6-27b_all_valid_predictions_long.csv"
    ),
]

OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "patient_level_edss_range_grouped_bar.svg"
OUTPUT_PNG = OUTPUT_DIR / "patient_level_edss_range_grouped_bar.png"
OUTPUT_PATIENT_RANGE_CSV = OUTPUT_DIR / "patient_level_edss_range_by_model.csv"
OUTPUT_GROUPED_CSV = OUTPUT_DIR / "patient_level_edss_range_grouped_counts.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# Choose whether to include all patients with at least one valid run,
# or only patients with all 10 valid runs.
USE_ONLY_COMPLETE_10_RUNS = False
# USE_ONLY_COMPLETE_10_RUNS = True

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def clean_model_name(name):
    """Translate a raw model identifier into its display name.

    Identifiers without a known display name are returned as strings,
    otherwise unchanged.
    """
    key = str(name)
    display_names = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return display_names.get(key, key)


def to_num(s):
    """Parse a column as numbers, accepting "," as decimal separator (NaN on failure)."""
    with_decimal_point = s.astype(str).str.replace(",", ".", regex=False)
    return pd.to_numeric(with_decimal_point, errors="coerce")


def categorize_edss_range(value):
    """
    Bucket a patient-level EDSS range (max - min across repeated runs).

    Returns one of the labels "0", "0.5", ">0.5–1.0", ">1.0–2.0", ">2.0",
    or NaN when the range is missing.
    """
    if pd.isna(value):
        return np.nan
    if value == 0:
        return "0"

    # Upper bounds are inclusive and checked in ascending order.
    for upper_bound, label in (
        (0.5, "0.5"),
        (1.0, ">0.5–1.0"),
        (2.0, ">1.0–2.0"),
    ):
        if value <= upper_bound:
            return label

    return ">2.0"


# =========================
# LOAD AND COMBINE DATA
# =========================

dfs = []

for path in INPUT_FILES:
    if not path.exists():
        print(f"Skipping missing file: {path}")
        continue

    df = pd.read_csv(path)

    required_cols = ["model", "row_index", EDSS_COL]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Missing column '{col}' in {path}")

    # Keep only rows whose prediction parses as a number.
    df = df.assign(EDSS_prediction_numeric=to_num(df[EDSS_COL]))
    df = df.dropna(subset=["EDSS_prediction_numeric"]).copy()

    dfs.append(df)

if len(dfs) == 0:
    raise ValueError("No input data loaded.")

all_df = pd.concat(dfs, ignore_index=True)
all_df["model_display"] = all_df["model"].apply(clean_model_name)

print(f"Loaded valid prediction rows: {len(all_df)}")
print("\nRows per model:")
print(all_df["model_display"].value_counts())


# =========================
# PATIENT-LEVEL RANGE
# =========================

# One output row per (model, input row); unique_id is carried along when present.
group_cols = ["model", "model_display", "row_index"]

if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

grouped_predictions = all_df.groupby(group_cols, dropna=False)

patient_range_df = grouped_predictions.agg(
    n_valid_runs=("EDSS_prediction_numeric", "count"),
    EDSS_min=("EDSS_prediction_numeric", "min"),
    EDSS_max=("EDSS_prediction_numeric", "max"),
    EDSS_mean=("EDSS_prediction_numeric", "mean"),
    EDSS_median=("EDSS_prediction_numeric", "median"),
    # Population standard deviation (ddof=0).
    EDSS_std=("EDSS_prediction_numeric", lambda x: x.std(ddof=0)),
).reset_index()

patient_range_df["EDSS_range"] = (
    patient_range_df["EDSS_max"] - patient_range_df["EDSS_min"]
)

patient_range_df["complete_10_valid_runs"] = (
    patient_range_df["n_valid_runs"] == N_EXPECTED_RUNS
)

patient_range_df["EDSS_range_category"] = patient_range_df["EDSS_range"].apply(
    categorize_edss_range
)

patient_range_df.to_csv(OUTPUT_PATIENT_RANGE_CSV, index=False)


# =========================
# OPTIONAL FILTER
# =========================

plot_df = patient_range_df.copy()

if USE_ONLY_COMPLETE_10_RUNS:
    plot_df = plot_df[plot_df["complete_10_valid_runs"]].copy()

if plot_df.empty:
    raise ValueError("No patient-level data available after filtering.")


# =========================
# GROUPED COUNTS AND PERCENTAGES
# =========================

range_order = [
    "0",
    "0.5",
    ">0.5–1.0",
    ">1.0–2.0",
    ">2.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only models actually present
model_order = [
    m for m in model_order
    if m in plot_df["model_display"].unique()
]

# Rows: range category, columns: model; absent combinations count as 0.
counts = (
    plot_df
    .groupby(["EDSS_range_category", "model_display"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=range_order, columns=model_order, fill_value=0)
)

# Percentages within each model column; a model without any patient gets NaN.
column_totals = counts.sum(axis=0)
percentages = counts.astype(float).div(
    column_totals.where(column_totals > 0), axis=1
) * 100

# Save combined counts and percentages
combined_out = [
    {
        "EDSS_range_category": range_cat,
        "model": model,
        "count": int(counts.loc[range_cat, model]),
        "percent_within_model": percentages.loc[range_cat, model],
        "total_patients_for_model": int(counts[model].sum()),
        "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
    }
    for range_cat in range_order
    for model in model_order
]

combined_df = pd.DataFrame(combined_out)
combined_df.to_csv(OUTPUT_GROUPED_CSV, index=False)

print("\nCounts:")
print(counts)

print("\nPercentages within model:")
print(percentages.round(1))


# =========================
# PLOT
# =========================

x = np.arange(len(range_order))
n_models = len(model_order)
bar_width = 0.22

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(10, 6))

for i, model in enumerate(model_order):
    values = percentages[model].values
    # Centre the group of model bars on each category tick.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        label=model,
        color=colors.get(model, None),
        edgecolor="white",
        linewidth=0.8,
    )

    for bar, value in zip(bars, values):
        # Bars below 2% are left unlabelled.
        if pd.isna(value) or value < 2:
            continue
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1,
            f"{value:.1f}%",
            ha="center",
            va="bottom",
            fontsize=8,
            fontweight="bold",
        )

ax.set_xticks(x)
ax.set_xticklabels(range_order, fontsize=10)

ax.set_ylabel("Patients (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("Patient-level EDSS range across repeated runs", fontsize=11, fontweight="bold")

# Only used by the commented-out title below.
if USE_ONLY_COMPLETE_10_RUNS:
    title_suffix = "patients with all 10 valid runs"
else:
    title_suffix = "patients with available valid runs"

#ax.set_title(
#    f"Repeated-run stability of EDSS predictions\n{title_suffix}",
#    fontsize=13,
#    fontweight="bold",
#    pad=15,
#)

y_ticks = np.arange(0, 101, 10)

ax.set_ylim(0, max(100, np.nanmax(percentages.values) + 10))
ax.set_yticks(y_ticks)
ax.set_yticklabels([f"{y}%" for y in y_ticks])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ("top", "right"):
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

# Add model n below legend area as text
n_text = " | ".join(
    f"{model}: n={int(counts[model].sum())}"
    for model in model_order
)

ax.text(
    0.5,
    1.08,
    n_text,
    transform=ax.transAxes,
    ha="center",
    va="bottom",
    fontsize=9,
)

plt.tight_layout(rect=[0, 0, 1, 0.90])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
for saved_path in (OUTPUT_PATIENT_RANGE_CSV, OUTPUT_GROUPED_CSV):
    print(saved_path)
##


# %% Simple stability figure: stable / minor variation / unstable

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

# Folder holding one long-format prediction file per model.
_REPEATED_RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
    "run_20260528_103942/repeated_run_variability/"
)

INPUT_FILES = [
    _REPEATED_RUN_DIR / file_name
    for file_name in (
        "gemma-4-31B-it_all_valid_predictions_long.csv",
        "gpt-oss-120b_all_valid_predictions_long.csv",
        "qwen3.6-27b_all_valid_predictions_long.csv",
    )
]

# Figures and tables go into a sub-folder next to the inputs.
OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "simple_edss_stability_stacked_bar.svg"
OUTPUT_PNG = OUTPUT_DIR / "simple_edss_stability_stacked_bar.png"
OUTPUT_CSV = OUTPUT_DIR / "simple_edss_stability_table.csv"
OUTPUT_PATIENT_LEVEL_CSV = OUTPUT_DIR / "simple_edss_stability_patient_level.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# If True, only patients with all 10 valid predictions are included.
# If False, patients with at least 2 valid predictions are included.
USE_ONLY_COMPLETE_10_RUNS = False

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def to_num(s):
    """Convert a column to numbers, accepting a decimal comma (NaN on failure)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def classify_stability(edss_range):
    """Classify a patient-level EDSS range (max - min across runs).

    Returns "Identical across runs" for 0, "Range ≤0.5" up to and including
    0.5, "Range >0.5" otherwise, and NaN when the range is missing.
    """
    if pd.isna(edss_range):
        return np.nan
    if edss_range == 0:
        return "Identical across runs"
    if edss_range <= 0.5:
        return "Range ≤0.5"
    return "Range >0.5"


# =========================
# LOAD DATA
# =========================

dfs = []

for path in INPUT_FILES:
    # Same tolerance and validation as the grouped-bar cell: a missing model
    # file is skipped with a message, a malformed one is an error.
    if not path.exists():
        print(f"Skipping missing file: {path}")
        continue

    df = pd.read_csv(path)

    for col in ["model", "row_index", EDSS_COL]:
        if col not in df.columns:
            raise ValueError(f"Missing column '{col}' in {path}")

    df = df.copy()
    df["EDSS_prediction_numeric"] = to_num(df[EDSS_COL])
    # Keep only rows whose prediction parses as a number.
    df = df.dropna(subset=["EDSS_prediction_numeric"]).copy()
    df["model_display"] = df["model"].apply(clean_model_name)
    dfs.append(df)

if not dfs:
    raise ValueError("No input data loaded.")

all_df = pd.concat(dfs, ignore_index=True)


# =========================
# PATIENT-LEVEL RANGE
# =========================

# One output row per (model, input row); unique_id is carried along when present.
group_cols = ["model", "model_display", "row_index"]

if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

patient_df = (
    all_df
    .groupby(group_cols, dropna=False)
    .agg(
        n_valid_runs=("EDSS_prediction_numeric", "count"),
        edss_min=("EDSS_prediction_numeric", "min"),
        edss_max=("EDSS_prediction_numeric", "max"),
    )
    .reset_index()
)

patient_df["edss_range"] = patient_df["edss_max"] - patient_df["edss_min"]
patient_df["complete_10_valid_runs"] = patient_df["n_valid_runs"] == N_EXPECTED_RUNS

# Need at least 2 runs to measure variability.
patient_df = patient_df[patient_df["n_valid_runs"] >= 2].copy()

if USE_ONLY_COMPLETE_10_RUNS:
    patient_df = patient_df[patient_df["complete_10_valid_runs"]].copy()

# Same guard as the grouped-bar cell: stop here instead of writing empty
# tables and an empty figure.
if patient_df.empty:
    raise ValueError("No patient-level data available after filtering.")

patient_df["stability_category"] = patient_df["edss_range"].apply(classify_stability)

patient_df.to_csv(OUTPUT_PATIENT_LEVEL_CSV, index=False)


# =========================
# SUMMARY TABLE
# =========================

category_order = [
    "Identical across runs",
    "Range ≤0.5",
    "Range >0.5",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only models actually present in the data.
model_order = [
    m for m in model_order
    if m in patient_df["model_display"].unique()
]

# Rows: model, columns: stability category; absent combinations count as 0.
counts = (
    patient_df
    .groupby(["model_display", "stability_category"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=model_order, columns=category_order, fill_value=0)
)

# Percentages within each model row; rows without patients become 0.
percentages = counts.div(counts.sum(axis=1), axis=0) * 100
percentages = percentages.fillna(0)

summary_rows = []

for model in model_order:
    total = int(counts.loc[model].sum())
    for category in category_order:
        summary_rows.append({
            "model": model,
            "stability_category": category,
            "count": int(counts.loc[model, category]),
            "percent": percentages.loc[model, category],
            "total_patients": total,
            "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
        })

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(OUTPUT_CSV, index=False)

print("\nCounts:")
print(counts)

print("\nPercentages:")
print(percentages.round(1))


# =========================
# PLOT
# =========================

colors = {
    "Identical across runs": "#1F77B4",
    "Range ≤0.5": "#9ECAE1",
    "Range >0.5": "#F28E2B",
}

fig, ax = plt.subplots(figsize=(10, 5))

# Running left edge of the next stacked segment, one entry per model.
left = np.zeros(len(model_order))

for category in category_order:
    values = percentages[category].values

    bars = ax.barh(
        model_order,
        values,
        left=left,
        color=colors[category],
        edgecolor="white",
        linewidth=0.8,
        label=category,
    )

    for i, value in enumerate(values):
        # Segments narrower than 5% are left unlabelled.
        if value >= 5:
            ax.text(
                left[i] + value / 2,
                i,
                f"{value:.1f}%",
                ha="center",
                va="center",
                fontsize=9,
                fontweight="bold",
            )

    left += values

# Sample size per model, placed just right of the 100% mark.
for i, model in enumerate(model_order):
    total = int(counts.loc[model].sum())
    ax.text(
        101,
        i,
        f"n={total}",
        va="center",
        ha="left",
        fontsize=9,
    )

ax.set_xlim(0, 110)
ax.set_xlabel("Patients (%)", fontsize=11, fontweight="bold")
ax.set_ylabel("Model", fontsize=11, fontweight="bold")

# Only used by the commented-out title below.
title_suffix = (
    "patients with all 10 valid runs"
    if USE_ONLY_COMPLETE_10_RUNS
    else "patients with at least 2 valid runs"
)

#ax.set_title(
#    f"Repeated-run stability of EDSS predictions\n{title_suffix}",
#    fontsize=13,
#    fontweight="bold",
#    pad=15,
#)

ax.set_xticks(np.arange(0, 101, 10))
ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)])

ax.xaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.90])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_CSV)
print(OUTPUT_PATIENT_LEVEL_CSV)
##
# %% Fancy simple stability figure: rounded horizontal stacked bars

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch


# =========================
# CONFIGURATION
# =========================

INPUT_FILES = [
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gemma-4-31B-it_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gpt-oss-120b_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "qwen3.6-27b_all_valid_predictions_long.csv"
    ),
]

OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "fancy_simple_edss_stability.svg"
OUTPUT_PNG = OUTPUT_DIR / "fancy_simple_edss_stability.png"
OUTPUT_CSV = OUTPUT_DIR / "fancy_simple_edss_stability_table.csv"
OUTPUT_PATIENT_LEVEL_CSV = OUTPUT_DIR / "fancy_simple_edss_stability_patient_level.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# False = include patients with at least 2 valid runs.
# True = only patients with all 10 valid runs.
USE_ONLY_COMPLETE_10_RUNS = False

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def to_num(s):
    """Convert a column to numbers, accepting a decimal comma (NaN on failure)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def classify_stability(edss_range):
    """Classify a patient-level EDSS range (max - min across runs).

    Returns "Identical" for 0, "Minor variation" up to and including 0.5,
    "Unstable" otherwise, and NaN when the range is missing.
    """
    if pd.isna(edss_range):
        return np.nan
    if edss_range == 0:
        return "Identical"
    if edss_range <= 0.5:
        return "Minor variation"
    return "Unstable"


def rounded_barh(ax, y, left, width, height, color, radius=0.16):
    """
    Draw a rounded horizontal bar segment.

    The segment is a FancyBboxPatch whose lower-left corner is at
    (left, y - height / 2), i.e. it is vertically centred on ``y``.
    ``radius`` is the rounding size of the corners. The patch is added
    to ``ax`` and returned.
    """
    patch = FancyBboxPatch(
        (left, y - height / 2),
        width,
        height,
        boxstyle=f"round,pad=0,rounding_size={radius}",
        linewidth=0,
        facecolor=color,
    )
    ax.add_patch(patch)
    return patch


# =========================
# LOAD DATA
# =========================

dfs = []

for path in INPUT_FILES:
    # Same tolerance and validation as the grouped-bar cell: a missing model
    # file is skipped with a message, a malformed one is an error.
    if not path.exists():
        print(f"Skipping missing file: {path}")
        continue

    df = pd.read_csv(path)

    for col in ["model", "row_index", EDSS_COL]:
        if col not in df.columns:
            raise ValueError(f"Missing column '{col}' in {path}")

    df = df.copy()
    df["EDSS_prediction_numeric"] = to_num(df[EDSS_COL])
    # Keep only rows whose prediction parses as a number.
    df = df.dropna(subset=["EDSS_prediction_numeric"]).copy()
    df["model_display"] = df["model"].apply(clean_model_name)
    dfs.append(df)

if not dfs:
    raise ValueError("No input data loaded.")

all_df = pd.concat(dfs, ignore_index=True)


# =========================
# PATIENT-LEVEL RANGE
# =========================

# One output row per (model, input row); unique_id is carried along when present.
group_cols = ["model", "model_display", "row_index"]

if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

patient_df = (
    all_df
    .groupby(group_cols, dropna=False)
    .agg(
        n_valid_runs=("EDSS_prediction_numeric", "count"),
        edss_min=("EDSS_prediction_numeric", "min"),
        edss_max=("EDSS_prediction_numeric", "max"),
    )
    .reset_index()
)

patient_df["edss_range"] = patient_df["edss_max"] - patient_df["edss_min"]
patient_df["complete_10_valid_runs"] = patient_df["n_valid_runs"] == N_EXPECTED_RUNS

# Need at least 2 repeated predictions to measure stability.
patient_df = patient_df[patient_df["n_valid_runs"] >= 2].copy()

if USE_ONLY_COMPLETE_10_RUNS:
    patient_df = patient_df[patient_df["complete_10_valid_runs"]].copy()

patient_df["stability_category"] = patient_df["edss_range"].apply(classify_stability)
patient_df.to_csv(OUTPUT_PATIENT_LEVEL_CSV, index=False)


# =========================
# SUMMARY TABLE
# =========================

category_order = [
    "Identical",
    "Minor variation",
    "Unstable",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only the models that actually have patients left after filtering.
model_order = [
    m for m in model_order
    if m in patient_df["model_display"].unique()
]

# Patients per (model, stability category); missing combinations become 0.
counts = (
    patient_df
    .groupby(["model_display", "stability_category"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=model_order, columns=category_order, fill_value=0)
)

percentages = counts.div(counts.sum(axis=1), axis=0) * 100
percentages = percentages.fillna(0)

summary_rows = [
    {
        "model": model,
        "stability_category": category,
        "count": int(counts.loc[model, category]),
        "percent": percentages.loc[model, category],
        "total_patients": int(counts.loc[model].sum()),
        "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
    }
    for model in model_order
    for category in category_order
]

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(OUTPUT_CSV, index=False)

print("\nPercentages:")
print(percentages.round(1))


# =========================
# FANCY PLOT
# =========================

colors = {
    "Identical": "#0B4F8A",
    "Minor variation": "#7DB9DE",
    "Unstable": "#F28E2B",
}

fig, ax = plt.subplots(figsize=(10.5, 5.3))

bar_height = 0.48
y_positions = np.arange(len(model_order))

for i, model in enumerate(model_order):
    left = 0

    for category in category_order:
        value = percentages.loc[model, category]

        if value > 0:
            rounded_barh(
                ax=ax,
                y=i,
                left=left,
                width=value,
                height=bar_height,
                color=colors[category],
                radius=0.13,
            )

            # Only label segments wide enough (>= 6 %) to hold the text.
            if value >= 6:
                ax.text(
                    left + value / 2,
                    i,
                    f"{value:.1f}%",
                    ha="center",
                    va="center",
                    fontsize=10,
                    fontweight="bold",
                    color="white" if category in ["Identical", "Unstable"] else "black",
                )

        left += value

    # Sample size to the right of the bar.
    total = int(counts.loc[model].sum())
    ax.text(
        103,
        i,
        f"n={total}",
        va="center",
        ha="left",
        fontsize=10,
        color="#333333",
    )

    # Main stability label at left
    identical = percentages.loc[model, "Identical"]
    minor = percentages.loc[model, "Minor variation"]
    stable_or_minor = identical + minor

    ax.text(
        -3,
        i - 0.33,
        f"{stable_or_minor:.1f}% ≤0.5 range",
        va="center",
        ha="right",
        fontsize=9,
        color="#444444",
    )


# Y-axis model labels
ax.set_yticks(y_positions)
ax.set_yticklabels(model_order, fontsize=11, fontweight="bold")

# Negative x range leaves room for the left-hand labels; >100 for "n=".
ax.set_xlim(-18, 112)
ax.set_ylim(-0.8, len(model_order) - 0.2)

ax.set_xlabel("Patients (%)", fontsize=11, fontweight="bold")
ax.set_title(
    "Repeated-run stability of EDSS predictions",
    fontsize=15,
    fontweight="bold",
    pad=18,
)

subtitle = (
    "Patient-level EDSS range across repeated model runs "
    "(identical, minor variation, or unstable)"
)
ax.text(
    0.5,
    1.02,
    subtitle,
    transform=ax.transAxes,
    ha="center",
    va="bottom",
    fontsize=10,
    color="#555555",
)

# X-axis formatting
ax.set_xticks(np.arange(0, 101, 20))
ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 20)])

ax.xaxis.grid(True, linestyle="--", alpha=0.25)
ax.set_axisbelow(True)

# Clean style
for spine in ["top", "right", "left", "bottom"]:
    ax.spines[spine].set_visible(False)

ax.tick_params(axis="y", length=0)
ax.tick_params(axis="x", length=0)

# Legend (proxy rectangles, because the bars are custom patches)
legend_handles = [
    plt.Rectangle((0, 0), 1, 1, color=colors["Identical"]),
    plt.Rectangle((0, 0), 1, 1, color=colors["Minor variation"]),
    plt.Rectangle((0, 0), 1, 1, color=colors["Unstable"]),
]

ax.legend(
    legend_handles,
    [
        "Identical across runs",
        "Range ≤0.5",
        "Range >0.5",
    ],
    loc="lower center",
    bbox_to_anchor=(0.5, -0.18),
    ncol=3,
    frameon=False,
    fontsize=10,
)

plt.tight_layout(rect=[0, 0.05, 1, 0.95])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_CSV)
print(OUTPUT_PATIENT_LEVEL_CSV)
##

# %% Functional system heatmap: MAE by functional system and model

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# =========================
# CONFIGURATION
# =========================

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

INPUT_METRICS_PATH = (
    RUN_DIR
    / "functional_system_metrics_iter_1"
    / "functional_system_metrics_short_iter_1.csv"
)

OUTPUT_DIR = RUN_DIR / "functional_system_heatmaps"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "functional_system_mae_heatmap.svg"
OUTPUT_PNG = OUTPUT_DIR / "functional_system_mae_heatmap.png"
OUTPUT_CSV = OUTPUT_DIR / "functional_system_mae_heatmap_table.csv"

plt.rcParams["font.family"] = "Arial"


# =========================
# SETTINGS
# =========================

MODEL_ORDER = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

FUNCTIONAL_SYSTEM_ORDER = [
    "Visual/optic functions",
    "Brainstem functions",
    "Pyramidal functions",
    "Cerebellar functions",
    "Sensory functions",
    "Bowel and bladder functions",
    "Cerebral functions",
    "Ambulation",
]

MODEL_NAME_MAP = {
    "gpt-oss-120b": "GPT-OSS-120B",
    "qwen3.6-27b": "Qwen3.6-27B",
    "gemma-4-31B-it": "Gemma-4-31B-it",
}


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(INPUT_METRICS_PATH)

required_cols = ["model", "functional_system", "MAE"]

for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Missing required column: {col}")

# Unmapped model names are kept as-is.
df["model_display"] = df["model"].map(MODEL_NAME_MAP).fillna(df["model"])

df["MAE"] = pd.to_numeric(df["MAE"], errors="coerce")

# The reindex below drops any functional system that is not listed in
# FUNCTIONAL_SYSTEM_ORDER, so report such rows instead of losing them silently.
unknown_systems = sorted(
    set(df["functional_system"].dropna().astype(str)) - set(FUNCTIONAL_SYSTEM_ORDER)
)
if unknown_systems:
    print(f"Warning: functional systems not in FUNCTIONAL_SYSTEM_ORDER are ignored: {unknown_systems}")

# Only plot models that are present, consistent with the other figures.
model_order = [
    m for m in MODEL_ORDER
    if m in df["model_display"].unique()
]

if not model_order:
    raise ValueError("None of the models in MODEL_ORDER were found in the metrics file.")


# =========================
# PIVOT TABLE
# =========================

# Rows = functional systems, columns = models, cells = mean MAE.
heatmap_df = (
    df
    .pivot_table(
        index="functional_system",
        columns="model_display",
        values="MAE",
        aggfunc="mean"
    )
    .reindex(index=FUNCTIONAL_SYSTEM_ORDER, columns=model_order)
)

heatmap_df.to_csv(OUTPUT_CSV)

print("\nMAE heatmap table:")
print(heatmap_df)


# =========================
# PLOT
# =========================

fig, ax = plt.subplots(figsize=(8, 6.5))

sns.heatmap(
    heatmap_df,
    annot=True,
    fmt=".2f",
    cmap="Blues_r",  # lower MAE appears darker/better
    linewidths=0.8,
    linecolor="white",
    cbar_kws={"label": "Mean absolute error"},
    ax=ax,
)

ax.set_xlabel("Model", fontsize=11, fontweight="bold")
ax.set_ylabel("Functional system", fontsize=11, fontweight="bold")

ax.set_title(
    "Functional system performance by model\nMean absolute error",
    fontsize=13,
    fontweight="bold",
    pad=15,
)

plt.xticks(rotation=30, ha="right")
plt.yticks(rotation=0)

plt.tight_layout()

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_SVG)
print(OUTPUT_PNG)
##


# %% Confidence bracket vs EDSS error grouped by model

from pathlib import Path
import re

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_error_analysis_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_bracket_mae_grouped_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_bracket_mae_grouped_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_mae_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_error_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

ADD_TREND_LINES = True

# Sub-folders of RUN_DIR that hold analysis outputs rather than model results.
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numeric, accepting ',' as decimal separator (bad values -> NaN)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for 'true', '1', 'yes' or 'ja' (case-insensitive)."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def safe_name(name):
    """Replace every run of characters outside [A-Za-z0-9_.-] with '_'."""
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name))


def clean_model_name(name):
    """Return the display label for a raw model name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in ``model_dir``, or None.

    Incremental, summary and all_results files are ignored. When several
    files match, the alphabetically first one is used.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))

    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]

    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the 'model' column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to its bracket label (NaN if missing or > 100)."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low (<70%)"
    if certainty < 80:
        return "Moderate (70–80%)"
    if certainty < 90:
        return "High (80–90%)"
    if certainty <= 100:
        return "Very High (90–100%)"
    return np.nan


def confidence_midpoint(bracket):
    """Return the nominal midpoint assigned to a bracket label (NaN if unknown)."""
    midpoint_map = {
        "Low (<70%)": 65,
        "Moderate (70–80%)": 75,
        "High (80–90%)": 85,
        "Very High (90–100%)": 95,
    }
    return midpoint_map.get(bracket, np.nan)


def sem(series):
    """Return the standard error of the mean; 0.0 when fewer than two values exist."""
    values = pd.to_numeric(series, errors="coerce").dropna()
    if len(values) <= 1:
        return 0.0
    return values.std(ddof=1) / np.sqrt(len(values))


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Predictions are matched to ground truth by positional row index.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# LOAD MODEL PREDICTIONS AND BUILD LONG ERROR DATA
# =========================

long_frames = []

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print("Skipping: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()

    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Apply each validity flag only when the result file provides it.
    for flag_col in ("success", "EDSS_is_numeric", "EDSS_in_valid_range"):
        if flag_col in pred.columns:
            pred = pred[to_bool(pred[flag_col])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        print("No evaluable rows.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["confidence_midpoint"] = merged["confidence_bracket"].apply(confidence_midpoint)

    merged = merged.dropna(subset=["confidence_bracket"]).copy()

    print(f"Evaluable rows with confidence bracket: {len(merged)}")

    # Prefer the ground-truth unique_id when both sides carried the column.
    if "unique_id_gt" in merged.columns:
        unique_id = merged["unique_id_gt"]
    elif "unique_id" in merged.columns:
        unique_id = merged["unique_id"]
    else:
        unique_id = None

    long_frames.append(pd.DataFrame(
        {
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": merged["row_index"],
            "unique_id": unique_id,
            "GT_EDSS_numeric": merged["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": merged["PRED_EDSS_numeric"],
            "certainty_percent": merged["certainty_numeric"],
            "confidence_bracket": merged["confidence_bracket"],
            "confidence_midpoint": merged["confidence_midpoint"],
            "error": merged["error"],
            "abs_error": merged["abs_error"],
            "inference_time_sec": (
                merged["inference_time_sec"]
                if "inference_time_sec" in merged.columns
                else np.nan
            ),
            "result_file": str(result_file),
        },
        index=merged.index,
    ))


long_df = pd.concat(long_frames, ignore_index=True) if long_frames else pd.DataFrame()

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY BY MODEL AND CONFIDENCE BRACKET
# =========================

bracket_order = [
    "Low (<70%)",
    "Moderate (70–80%)",
    "High (80–90%)",
    "Very High (90–100%)",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"], observed=False)
    .agg(
        n=("abs_error", "count"),
        MAE=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        SEM=("abs_error", sem),
        mean_certainty=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Ensure full model x bracket grid exists
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["confidence_midpoint"] = summary["confidence_bracket"].apply(confidence_midpoint)

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-bracket MAE table:")
print(summary)


# =========================
# CORRELATION PER MODEL
# =========================

corr_text = {}

for model in model_order:
    df_m = long_df[long_df["model_display"] == model].copy()

    # pearsonr needs variation in both variables.
    if len(df_m) >= 3 and df_m["certainty_percent"].nunique() > 1 and df_m["abs_error"].nunique() > 1:
        r, p = pearsonr(df_m["certainty_percent"], df_m["abs_error"])
        corr_text[model] = f"r={r:.2f}, p={p:.2g}, n={len(df_m)}"
    else:
        corr_text[model] = f"r=NA, n={len(df_m)}"


# =========================
# PLOT
# =========================

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

x = np.arange(len(bracket_order))
n_models = len(model_order)
bar_width = 0.22

fig, ax = plt.subplots(figsize=(12, 7))

for i, model in enumerate(model_order):
    df_m = summary[summary["model_display"] == model].copy()
    df_m = df_m.set_index("confidence_bracket").reindex(bracket_order).reset_index()

    values = df_m["MAE"].values
    errors = df_m["SEM"].fillna(0).values
    ns = df_m["n"].fillna(0).astype(int).values

    # Centre the group of bars on each bracket tick.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        yerr=errors,
        capsize=4,
        color=colors.get(model, None),
        edgecolor="black",
        linewidth=0.6,
        alpha=0.85,
        label=model,
    )

    for bar, value, n in zip(bars, values, ns):
        if pd.notna(value) and n > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.035,
                f"{value:.2f}\nn={n}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )

    if ADD_TREND_LINES:
        # Connect only brackets that have data, so an empty bracket does not
        # break (or completely hide) the trend line.
        valid_mask = df_m[["MAE", "confidence_midpoint"]].notna().all(axis=1).to_numpy()
        if valid_mask.sum() >= 2:
            ax.plot(
                (x + offset)[valid_mask],
                values[valid_mask],
                linestyle="--",
                linewidth=1.5,
                color=colors.get(model, None),
                alpha=0.9,
            )


ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)

ax.set_ylabel("Mean absolute EDSS error", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")

ax.set_title(
    "EDSS prediction error across LLM confidence brackets",
    fontsize=14,
    fontweight="bold",
    pad=15,
)

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="upper right",
    frameon=True,
    title="Model",
)

# Add correlation text box
corr_lines = ["Pearson correlation: confidence vs absolute error"]
for model in model_order:
    corr_lines.append(f"{model}: {corr_text[model]}")

ax.text(
    0.02,
    0.98,
    "\n".join(corr_lines),
    transform=ax.transAxes,
    ha="left",
    va="top",
    fontsize=9,
    bbox=dict(
        boxstyle="round,pad=0.4",
        facecolor="white",
        edgecolor="#999999",
        alpha=0.9,
    ),
)

# Add metric explanation
ax.text(
    0.98,
    0.02,
    "Bars: MAE\nError bars: SEM\nDashed lines: bracket trend",
    transform=ax.transAxes,
    ha="right",
    va="bottom",
    fontsize=9,
    bbox=dict(
        boxstyle="round,pad=0.4",
        facecolor="white",
        edgecolor="#CCCCCC",
        alpha=0.9,
    ),
)

plt.tight_layout()

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##

# %% Confidence bracket vs clinically acceptable EDSS accuracy grouped by model

from pathlib import Path
import re

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_analysis_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_accuracy_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numeric, accepting ',' as decimal separator (bad values -> NaN)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for 'true', '1', 'yes' or 'ja' (case-insensitive)."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def clean_model_name(name):
    """Return the display label for a raw model name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in ``model_dir``, or None.

    Incremental, summary and all_results files are ignored. When several
    files match, the alphabetically first one is used.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))

    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]

    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the 'model' column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to its two-line bracket label (NaN if missing or > 100)."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Predictions are matched to ground truth by positional row index.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

long_frames = []

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith((
        "confusion",
        "functional_system",
        "repeated_run",
        "edss_error_distribution",
        "edss_threshold_metrics",
        "edss_severity_group_metrics",
        "structured_output_validity",
        "confidence",
    ))
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()

    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Apply each validity flag only when the result file provides it.
    for flag_col in ("success", "EDSS_is_numeric", "EDSS_in_valid_range"):
        if flag_col in pred.columns:
            pred = pred[to_bool(pred[flag_col])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0
    merged["exact_match"] = merged["abs_error"] == 0
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)

    merged = merged.dropna(subset=["confidence_bracket"]).copy()

    long_frames.append(pd.DataFrame(
        {
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": merged["row_index"],
            "GT_EDSS_numeric": merged["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": merged["PRED_EDSS_numeric"],
            "certainty_percent": merged["certainty_numeric"],
            "confidence_bracket": merged["confidence_bracket"],
            "error": merged["error"],
            "abs_error": merged["abs_error"],
            "exact_match": merged["exact_match"],
            "within_0_5": merged["within_0_5"],
            "within_1_0": merged["within_1_0"],
            "result_file": str(result_file),
        },
        index=merged.index,
    ))


long_df = pd.concat(long_frames, ignore_index=True) if long_frames else pd.DataFrame()

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Ensure the full model x bracket grid exists (empty cells become NaN).
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-bracket accuracy table:")
print(summary)


# =========================
# PLOT
# =========================

x = np.arange(len(bracket_order))
n_models = len(model_order)
bar_width = 0.22

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(11, 6.5))

for i, model in enumerate(model_order):
    df_m = (
        summary[summary["model_display"] == model]
        .set_index("confidence_bracket")
        .reindex(bracket_order)
        .reset_index()
    )

    values = df_m["accuracy_within_0_5_percent"].values
    ns = df_m["n"].fillna(0).astype(int).values

    # Centre the group of bars on each bracket tick.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        color=colors.get(model),
        edgecolor="white",
        linewidth=0.8,
        label=model,
    )

    for bar, value, n in zip(bars, values, ns):
        if pd.notna(value) and n > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1.2,
                f"{value:.1f}%\nn={n}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )

ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)

ax.set_ylim(0, 110)
ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")

#ax.set_title(
#    "Accuracy of EDSS predictions by confidence bracket",
#    fontsize=14,
#    fontweight="bold",
#    pad=15,
#)

ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

ax.text(
    0.5,
    -0.18,
    "Higher bars indicate better calibration: high-confidence predictions are more often clinically close to the reference EDSS.",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.05, 1, 0.92])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##
# %% Confidence bracket accuracy + predicted EDSS range distribution

from pathlib import Path
import re

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

# NOTE(review): this folder and the accuracy SVG/PNG, table and long-CSV paths
# below are identical to those of the previous cell, so running this cell
# presumably overwrites that cell's outputs - confirm this is intended.
OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_analysis_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_ACCURACY_SVG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.svg"
OUTPUT_ACCURACY_PNG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.png"

OUTPUT_RANGE_SVG = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_iter_{TARGET_ITERATION}.svg"
OUTPUT_RANGE_PNG = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_iter_{TARGET_ITERATION}.png"

OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_accuracy_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_RANGE_TABLE = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numeric, accepting ',' as decimal separator (bad values -> NaN)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for 'true', '1', 'yes' or 'ja' (case-insensitive)."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def clean_model_name(name):
    """Return the display label for a raw model name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in ``model_dir``, or None.

    Incremental, summary and all_results files are ignored. When several
    files match, the alphabetically first one is used.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))

    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]

    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the 'model' column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to its two-line bracket label (NaN if missing or > 100)."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Map an EDSS value to its range label.

    Values outside the three closed intervals (including values that fall
    between them, e.g. 3.75) return NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan
# =========================
# LOAD GROUND TRUTH
# =========================

# Reference table (semicolon-separated). Its positional index is stored as
# `row_index`, which is the key the per-model prediction files are merged on.
gt = pd.read_csv(GT_PATH, sep=";")
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
# Rows without a numeric reference EDSS cannot be scored.
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

# One dict per evaluable (model, letter) pair; turned into `long_df` below.
long_rows = []

# Directories under RUN_DIR whose names start with one of these prefixes are
# not scanned for model results (presumably outputs of analysis cells such as
# this one, which writes to a "confidence..." folder).
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    # Normalise the join key; rows whose key is not numeric are discarded.
    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: when a flag column exists, keep only rows where
    # it is truthy.
    for flag_col in ("success", "EDSS_is_numeric", "EDSS_in_valid_range"):
        if flag_col in pred.columns:
            pred = pred[to_bool(pred[flag_col])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per letter: the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    if merged.empty:
        print(f"Skipping {model_dir.name}: no rows match the ground truth.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    # Rows that fall outside every bracket / EDSS range are not evaluated.
    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    # Build the output records column-wise instead of looping with iterrows();
    # scalar entries are broadcast over merged.index.
    records = pd.DataFrame(
        {
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": merged["row_index"],
            "GT_EDSS_numeric": merged["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": merged["PRED_EDSS_numeric"],
            "GT_EDSS_group": merged["GT_EDSS_group"],
            "PRED_EDSS_group": merged["PRED_EDSS_group"],
            "certainty_percent": merged["certainty_numeric"],
            "confidence_bracket": merged["confidence_bracket"],
            "error": merged["error"],
            "abs_error": merged["abs_error"],
            "exact_match": merged["exact_match"],
            "within_0_5": merged["within_0_5"],
            "within_1_0": merged["within_1_0"],
            "result_file": str(result_file),
        },
        index=merged.index,
    )
    long_rows.extend(records.to_dict("records"))


long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY: ACCURACY BY CONFIDENCE
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep the preferred display order, restricted to models that produced data.
present_models = set(long_df["model_display"].unique())
model_order = [m for m in model_order if m in present_models]

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Every (model, bracket) combination gets a row, including empty brackets.
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"],
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty brackets have a count of 0; without this the reindex turns `n` into a
# float column with NaN in the exported table.
summary["n"] = summary["n"].fillna(0).astype(int)

summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)


# =========================
# FIGURE 1: ACCURACY WITHIN ±0.5 BY CONFIDENCE
# =========================

x = np.arange(len(bracket_order))
n_models = len(model_order)
bar_width = 0.22

model_colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(11, 6.5))

for i, model in enumerate(model_order):
    df_m = (
        summary[summary["model_display"] == model]
        .set_index("confidence_bracket")
        .reindex(bracket_order)
        .reset_index()
    )

    values = df_m["accuracy_within_0_5_percent"].values
    ns = df_m["n"].fillna(0).astype(int).values

    # Centre the group of model bars on each bracket position.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        color=model_colors.get(model),
        edgecolor="white",
        linewidth=0.8,
        label=model,
    )

    # Label only brackets that actually contain predictions.
    for bar, value, n in zip(bars, values, ns):
        if pd.notna(value) and n > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1.2,
                f"{value:.1f}%\nn={n}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )

ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)

ax.set_ylim(0, 110)
ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")

ax.set_title(
    "Accuracy of EDSS predictions by confidence bracket",
    fontsize=14,
    fontweight="bold",
    pad=15,
)

percent_ticks = np.arange(0, 101, 10)
ax.set_yticks(percent_ticks)
ax.set_yticklabels([f"{y}%" for y in percent_ticks])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0.03, 1, 0.92])

plt.savefig(OUTPUT_ACCURACY_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_ACCURACY_PNG, dpi=300, bbox_inches="tight")

plt.show()


# =========================
# SUMMARY: PREDICTED EDSS RANGE BY CONFIDENCE
# =========================

range_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

range_colors = {
    "0.0–3.5": "#9ECAE1",
    "4.0–5.5": "#FDDC7A",
    "6.0–10.0": "#F28E2B",
}

range_rows = []

# Share of each predicted EDSS range within every (model, confidence bracket).
for model in model_order:
    df_m = long_df[long_df["model_display"] == model].copy()

    for bracket in bracket_order:
        df_b = df_m[df_m["confidence_bracket"] == bracket].copy()
        total = len(df_b)

        for edss_range in range_order:
            count = int((df_b["PRED_EDSS_group"] == edss_range).sum())
            # An empty bracket has no defined percentage.
            percent = count / total * 100 if total > 0 else np.nan

            range_rows.append({
                "model": model,
                "confidence_bracket": bracket,
                "predicted_EDSS_range": edss_range,
                "count": count,
                "total_in_confidence_bracket": total,
                "percent": percent,
            })

range_df = pd.DataFrame(range_rows)
range_df.to_csv(OUTPUT_RANGE_TABLE, index=False)


# =========================
# FIGURE 2: PREDICTED EDSS RANGE BY CONFIDENCE
# =========================

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(model_order),
    figsize=(5 * len(model_order), 5.5),
    sharey=True,
)

# plt.subplots returns a bare Axes (not an array) for a single panel.
if len(model_order) == 1:
    axes = [axes]

# Numeric bar positions, so tick locations are fixed before labels are set.
bracket_positions = np.arange(len(bracket_order))

for ax, model in zip(axes, model_order):
    df_m = range_df[range_df["model"] == model].copy()

    # Running top of the stack for each bracket.
    left = np.zeros(len(bracket_order))

    for edss_range in range_order:
        values = []

        for bracket in bracket_order:
            value = df_m.loc[
                (df_m["confidence_bracket"] == bracket)
                & (df_m["predicted_EDSS_range"] == edss_range),
                "percent",
            ]

            # Missing or undefined shares are drawn as zero-height segments.
            if len(value) == 0:
                values.append(0)
            else:
                values.append(value.iloc[0] if pd.notna(value.iloc[0]) else 0)

        ax.bar(
            bracket_positions,
            values,
            bottom=left,
            color=range_colors[edss_range],
            edgecolor="white",
            linewidth=0.8,
            label=edss_range,
        )

        # Only segments of at least 8% are tall enough to carry a label.
        for i, value in enumerate(values):
            if value >= 8:
                ax.text(
                    i,
                    left[i] + value / 2,
                    f"{value:.0f}%",
                    ha="center",
                    va="center",
                    fontsize=8,
                    fontweight="bold",
                )

        left += np.array(values)

    # n labels above bars
    for i, bracket in enumerate(bracket_order):
        total = df_m.loc[
            df_m["confidence_bracket"] == bracket,
            "total_in_confidence_bracket",
        ]

        total_n = int(total.iloc[0]) if len(total) > 0 and pd.notna(total.iloc[0]) else 0

        ax.text(
            i,
            102,
            f"n={total_n}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    ax.set_title(model, fontsize=12, fontweight="bold")
    ax.set_xlabel("Confidence bracket", fontsize=10, fontweight="bold")
    ax.set_ylim(0, 110)
    ax.set_xticks(bracket_positions)
    ax.set_xticklabels(bracket_order, rotation=0, fontsize=8)

    ax.yaxis.grid(True, linestyle="--", alpha=0.25)
    ax.set_axisbelow(True)

    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

axes[0].set_ylabel("Predicted EDSS range (%)", fontsize=11, fontweight="bold")
# One shared legend for all panels, taken from the last panel's handles.
handles, labels = axes[-1].get_legend_handles_labels()

fig.legend(
    handles,
    labels,
    title="Predicted EDSS range",
    loc="lower center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=3,
    frameon=False,
)

fig.suptitle(
    "Predicted EDSS range within each confidence bracket",
    fontsize=14,
    fontweight="bold",
    y=1.03,
)

plt.tight_layout(rect=[0, 0.07, 1, 0.96])

plt.savefig(OUTPUT_RANGE_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_RANGE_PNG, dpi=300, bbox_inches="tight")

plt.show()


# =========================
# DONE
# =========================

print("\nSaved:")
print(OUTPUT_ACCURACY_SVG)
print(OUTPUT_ACCURACY_PNG)
print(OUTPUT_RANGE_SVG)
print(OUTPUT_RANGE_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_RANGE_TABLE)
print(OUTPUT_LONG)
##
# %% Heatmap: confidence bracket x predicted EDSS range x accuracy, one panel per model

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_heatmap_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting a decimal comma.

    Values are stringified, "," is replaced by ".", and anything that still
    cannot be parsed becomes NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for truthy flag values.

    Matching is case-insensitive and ignores surrounding whitespace. "1.0" is
    accepted as well as "1", so a 0/1 flag column that pandas parsed as float
    (e.g. because it contains missing values) is still recognised. Everything
    else, including missing values, maps to False.
    """
    normalized = s.astype(str).str.strip().str.lower()
    return normalized.isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the result CSV of `iteration` inside `model_dir`, or None.

    Files whose name contains "incremental", "summary" or "all_results" are
    ignored. When several candidates remain, the alphabetically first one is
    used and a warning is printed so the choice is visible.
    """
    excluded_markers = ("incremental", "summary", "all_results")

    files = [
        f for f in sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
        if not any(marker in f.name.lower() for marker in excluded_markers)
    ]

    if not files:
        return None

    if len(files) > 1:
        print(
            f"Warning: {len(files)} result files match iteration {iteration} "
            f"in {model_dir.name}; using {files[0].name}"
        )

    return files[0]


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column, else the directory name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Return the bracket label for a certainty value (NaN if missing or above 100).

    Bounds are lower-inclusive: <70 is "Low", 70 to <80 "Moderate",
    80 to <90 "High", 90 to 100 "Very high".
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Return the EDSS range label for a score, or NaN.

    Values outside 0–10, and values in the gaps between the ranges
    (between 3.5 and 4.0, and between 5.5 and 6.0), yield NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

# Reference table (semicolon-separated). Its positional index is stored as
# `row_index`, which is the key the per-model prediction files are merged on.
gt = pd.read_csv(GT_PATH, sep=";")
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
# Rows without a numeric reference EDSS cannot be scored.
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Directories under RUN_DIR whose names start with one of these prefixes are
# not scanned for model results (presumably outputs of analysis cells such as
# this one, which writes to a "confidence..." folder).
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    # Normalise the join key; rows whose key is not numeric are discarded.
    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: when a flag column exists, keep only rows where
    # it is truthy.
    for flag_col in ("success", "EDSS_is_numeric", "EDSS_in_valid_range"):
        if flag_col in pred.columns:
            pred = pred[to_bool(pred[flag_col])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per letter: the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    if merged.empty:
        print(f"Skipping {model_dir.name}: no rows match the ground truth.")
        continue

    merged["abs_error"] = (
        merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    ).abs()

    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    # Rows that fall outside every bracket / EDSS range are not evaluated.
    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    # Build the output records column-wise instead of looping with iterrows();
    # scalar entries are broadcast over merged.index.
    records = pd.DataFrame(
        {
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": merged["row_index"],
            "GT_EDSS_numeric": merged["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": merged["PRED_EDSS_numeric"],
            "GT_EDSS_group": merged["GT_EDSS_group"],
            "PRED_EDSS_group": merged["PRED_EDSS_group"],
            "certainty_percent": merged["certainty_numeric"],
            "confidence_bracket": merged["confidence_bracket"],
            "abs_error": merged["abs_error"],
            "within_0_5": merged["within_0_5"],
            "result_file": str(result_file),
        },
        index=merged.index,
    )
    long_rows.extend(records.to_dict("records"))


long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# AGGREGATE FOR HEATMAP
# =========================

confidence_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

edss_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep the preferred display order, restricted to models that produced data.
present_models = set(long_df["model_display"].unique())
model_order = [m for m in model_order if m in present_models]

summary = (
    long_df
    .groupby(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        accuracy_within_0_5=("within_0_5", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Every (model, EDSS range, bracket) combination gets a row, even when empty.
full_index = pd.MultiIndex.from_product(
    [model_order, edss_order, confidence_order],
    names=["model_display", "PRED_EDSS_group", "confidence_bracket"],
)

summary = (
    summary
    .set_index(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["n"] = summary["n"].fillna(0).astype(int)
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nHeatmap summary table:")
print(summary)


# =========================
# PLOT
# =========================

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(model_order),
    figsize=(5.2 * len(model_order), 4.8),
    sharey=True,
)

# plt.subplots returns a bare Axes (not an array) for a single panel.
if len(model_order) == 1:
    axes = [axes]

for ax, model in zip(axes, model_order):
    df_m = summary[summary["model_display"] == model].copy()

    # Rows = predicted EDSS range, columns = confidence bracket.
    heatmap_values = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="accuracy_within_0_5_percent",
        )
        .reindex(index=edss_order, columns=confidence_order)
    )

    heatmap_n = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="n",
        )
        .reindex(index=edss_order, columns=confidence_order)
        .fillna(0)
        .astype(int)
    )

    # Cell text: accuracy plus sample size; blank where there is no data.
    annotations = heatmap_values.copy().astype(object)

    for r in edss_order:
        for c in confidence_order:
            value = heatmap_values.loc[r, c]
            n = heatmap_n.loc[r, c]

            if n == 0 or pd.isna(value):
                annotations.loc[r, c] = ""
            else:
                annotations.loc[r, c] = f"{value:.0f}%\nn={n}"

    sns.heatmap(
        heatmap_values,
        ax=ax,
        annot=annotations,
        fmt="",
        cmap="Blues",
        vmin=0,
        vmax=100,
        linewidths=1,
        linecolor="white",
        cbar=False,
        square=False,
    )

    ax.set_title(model, fontsize=12, fontweight="bold")
    ax.set_xlabel("LLM confidence bracket", fontsize=10, fontweight="bold")
    # Only the left-most panel carries the y-axis label.
    ax.set_ylabel("Predicted EDSS range" if ax is axes[0] else "", fontsize=10, fontweight="bold")

    ax.set_xticklabels(confidence_order, rotation=0, fontsize=8)
    ax.set_yticklabels(edss_order, rotation=0, fontsize=9)

# Shared colorbar
mappable = axes[-1].collections[0]
cbar = fig.colorbar(
    mappable,
    ax=axes,
    orientation="vertical",
    fraction=0.025,
    pad=0.02,
)
cbar.set_label("Accuracy within ±0.5 EDSS (%)", fontsize=10, fontweight="bold")

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted severity range",
    fontsize=14,
    fontweight="bold",
    y=1.03,
)

fig.text(
    0.5,
    0.01,
    "Cell color shows accuracy within ±0.5 EDSS; text shows accuracy and number of predictions.",
    ha="center",
    va="bottom",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.05, 0.97, 0.95])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##
# %% Improved heatmap: confidence bracket x predicted EDSS range x accuracy, one model per row

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_heatmap_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting a decimal comma.

    Values are stringified, "," is replaced by ".", and anything that still
    cannot be parsed becomes NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for truthy flag values.

    Matching is case-insensitive and ignores surrounding whitespace. "1.0" is
    accepted as well as "1", so a 0/1 flag column that pandas parsed as float
    (e.g. because it contains missing values) is still recognised. Everything
    else, including missing values, maps to False.
    """
    normalized = s.astype(str).str.strip().str.lower()
    return normalized.isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the result CSV of `iteration` inside `model_dir`, or None.

    Files whose name contains "incremental", "summary" or "all_results" are
    ignored. When several candidates remain, the alphabetically first one is
    used and a warning is printed so the choice is visible.
    """
    excluded_markers = ("incremental", "summary", "all_results")

    files = [
        f for f in sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
        if not any(marker in f.name.lower() for marker in excluded_markers)
    ]

    if not files:
        return None

    if len(files) > 1:
        print(
            f"Warning: {len(files)} result files match iteration {iteration} "
            f"in {model_dir.name}; using {files[0].name}"
        )

    return files[0]


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column, else the directory name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Return the bracket label for a certainty value (NaN if missing or above 100).

    Bounds are lower-inclusive: <70 is "Low", 70 to <80 "Moderate",
    80 to <90 "High", 90 to 100 "Very high".
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Return the EDSS range label for a score, or NaN.

    Values outside 0–10, and values in the gaps between the ranges
    (between 3.5 and 4.0, and between 5.5 and 6.0), yield NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

# Reference table (semicolon-separated). Its positional index is stored as
# `row_index`, which is the key the per-model prediction files are merged on.
gt = pd.read_csv(GT_PATH, sep=";")
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
# Rows without a numeric reference EDSS cannot be scored.
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Directories under RUN_DIR whose names start with one of these prefixes are
# not scanned for model results (presumably outputs of analysis cells such as
# this one, which writes to a "confidence..." folder).
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    # Normalise the join key; rows whose key is not numeric are discarded.
    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: when a flag column exists, keep only rows where
    # it is truthy.
    for flag_col in ("success", "EDSS_is_numeric", "EDSS_in_valid_range"):
        if flag_col in pred.columns:
            pred = pred[to_bool(pred[flag_col])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per letter: the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    if merged.empty:
        print(f"Skipping {model_dir.name}: no rows match the ground truth.")
        continue

    merged["abs_error"] = (
        merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    ).abs()

    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    # Rows that fall outside every bracket / EDSS range are not evaluated.
    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    # Build the output records column-wise instead of looping with iterrows();
    # scalar entries are broadcast over merged.index.
    records = pd.DataFrame(
        {
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": merged["row_index"],
            "GT_EDSS_numeric": merged["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": merged["PRED_EDSS_numeric"],
            "GT_EDSS_group": merged["GT_EDSS_group"],
            "PRED_EDSS_group": merged["PRED_EDSS_group"],
            "certainty_percent": merged["certainty_numeric"],
            "confidence_bracket": merged["confidence_bracket"],
            "abs_error": merged["abs_error"],
            "within_0_5": merged["within_0_5"],
            "result_file": str(result_file),
        },
        index=merged.index,
    )
    long_rows.extend(records.to_dict("records"))


long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# AGGREGATE FOR HEATMAP
# =========================

confidence_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

edss_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep the preferred display order, restricted to models that produced data.
present_models = set(long_df["model_display"].unique())
model_order = [m for m in model_order if m in present_models]

summary = (
    long_df
    .groupby(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        accuracy_within_0_5=("within_0_5", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Every (model, EDSS range, bracket) combination gets a row, even when empty.
full_index = pd.MultiIndex.from_product(
    [model_order, edss_order, confidence_order],
    names=["model_display", "PRED_EDSS_group", "confidence_bracket"],
)

summary = (
    summary
    .set_index(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["n"] = summary["n"].fillna(0).astype(int)
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nHeatmap summary table:")
print(summary)


# =========================
# PLOT - ONE MODEL PER ROW
# =========================

n_models = len(model_order)

fig, axes = plt.subplots(
    nrows=n_models,
    ncols=1,
    figsize=(8.5, 3.1 * n_models),
    sharex=True,
    constrained_layout=False,
)

# plt.subplots returns a bare Axes (not an array) for a single panel.
if n_models == 1:
    axes = [axes]

# Dedicated axes on the right edge of the figure for the single colorbar.
cbar_ax = fig.add_axes([0.92, 0.18, 0.025, 0.65])

for i, (ax, model) in enumerate(zip(axes, model_order)):
    df_m = summary[summary["model_display"] == model].copy()

    # Rows = predicted EDSS range, columns = confidence bracket.
    heatmap_values = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="accuracy_within_0_5_percent",
        )
        .reindex(index=edss_order, columns=confidence_order)
    )

    heatmap_n = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="n",
        )
        .reindex(index=edss_order, columns=confidence_order)
        .fillna(0)
        .astype(int)
    )

    # Cell text: accuracy plus sample size; blank where there is no data.
    annotations = heatmap_values.copy().astype(object)

    for r in edss_order:
        for c in confidence_order:
            value = heatmap_values.loc[r, c]
            n = heatmap_n.loc[r, c]

            if n == 0 or pd.isna(value):
                annotations.loc[r, c] = ""
            else:
                annotations.loc[r, c] = f"{value:.0f}%\nn={n}"

    # Cells without predictions are masked so the axes background shows.
    mask = heatmap_n == 0

    sns.heatmap(
        heatmap_values,
        ax=ax,
        annot=annotations,
        fmt="",
        cmap="Blues",
        vmin=0,
        vmax=100,
        mask=mask,
        linewidths=1,
        linecolor="white",
        # The colorbar is drawn once, by the first panel.
        cbar=(i == 0),
        cbar_ax=cbar_ax if i == 0 else None,
        cbar_kws={"label": "Accuracy within ±0.5 EDSS (%)"},
    )

    # Grey background for empty cells
    ax.set_facecolor("#F2F2F2")

    ax.set_title(model, fontsize=12, fontweight="bold", loc="left", pad=8)
    ax.set_ylabel("Predicted EDSS range", fontsize=10, fontweight="bold")
    ax.set_xlabel("")

    ax.set_yticklabels(edss_order, rotation=0, fontsize=9)

    # Only the bottom panel shows the bracket labels.
    if i == n_models - 1:
        ax.set_xlabel("LLM confidence bracket", fontsize=10, fontweight="bold")
        ax.set_xticklabels(confidence_order, rotation=0, fontsize=9)
    else:
        ax.set_xticklabels([])

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted severity range",
    fontsize=14,
    fontweight="bold",
    y=0.98,
)

fig.text(
    0.5,
    0.03,
    "Cell color shows accuracy within ±0.5 EDSS; text shows accuracy and number of predictions. Empty grey cells indicate no predictions.",
    ha="center",
    va="center",
    fontsize=9,
    color="#555555",
)

plt.subplots_adjust(
    left=0.17,
    right=0.89,
    top=0.92,
    bottom=0.09,
    hspace=0.38,
)

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)


##

# %% Line plot: confidence-stratified EDSS accuracy by model

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_lineplot_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_lineplot_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_lineplot_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting a decimal comma.

    Values are stringified, "," is replaced by ".", and anything that still
    cannot be parsed becomes NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for truthy flag values.

    Matching is case-insensitive and ignores surrounding whitespace.
    "1.0" is accepted as well: a 0/1 column that contains missing values is
    read by pandas as float, so its string form is "1.0", not "1". Without
    this, every row of such a column would be treated as False and dropped.
    """
    text = s.astype(str).str.strip().str.lower()
    return text.isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display label (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in *model_dir*, or None.

    Files whose name contains "incremental", "summary" or "all_results" are
    ignored. The underscore after the iteration number in the glob pattern
    keeps iteration 1 from matching e.g. "iter_10".
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))

    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]

    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Bucket a certainty percentage into one of four labelled brackets.

    Returns NaN for missing values and for values above 100. The labels
    contain a line break because they are used directly as tick labels.
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Predictions are joined back to the ground truth via the positional row index.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

# One DataFrame per model; concatenated after the loop.
long_rows = []

# Sub-folders of RUN_DIR that hold analysis output rather than model results.
EXCLUDED_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(EXCLUDED_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()

    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: keep only rows that pass each flag that exists.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # If a row was predicted more than once, the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        print("No evaluable rows.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)

    merged = merged.dropna(subset=["confidence_bracket"]).copy()

    print(f"Evaluable rows with confidence bracket: {len(merged)}")

    if merged.empty:
        continue

    # "unique_id" carries the "_gt" suffix only when both inputs have the column.
    if "unique_id_gt" in merged.columns:
        unique_id = merged["unique_id_gt"]
    elif "unique_id" in merged.columns:
        unique_id = merged["unique_id"]
    else:
        unique_id = None

    if "inference_time_sec" in merged.columns:
        inference_time = merged["inference_time_sec"]
    else:
        inference_time = np.nan

    # Column-wise assembly (same columns, same order as before) instead of a
    # per-row iterrows() loop.
    model_rows = pd.DataFrame(index=merged.index)
    model_rows["model"] = model_name
    model_rows["model_display"] = model_display
    model_rows["iteration"] = TARGET_ITERATION
    model_rows["row_index"] = merged["row_index"]
    model_rows["unique_id"] = unique_id
    model_rows["GT_EDSS_numeric"] = merged["GT_EDSS_numeric"]
    model_rows["PRED_EDSS_numeric"] = merged["PRED_EDSS_numeric"]
    model_rows["certainty_percent"] = merged["certainty_numeric"]
    model_rows["confidence_bracket"] = merged["confidence_bracket"]
    model_rows["error"] = merged["error"]
    model_rows["abs_error"] = merged["abs_error"]
    model_rows["exact_match"] = merged["exact_match"]
    model_rows["within_0_5"] = merged["within_0_5"]
    model_rows["within_1_0"] = merged["within_1_0"]
    model_rows["inference_time_sec"] = inference_time
    model_rows["result_file"] = str(result_file)

    long_rows.append(model_rows)


long_df = pd.concat(long_rows, ignore_index=True) if long_rows else pd.DataFrame()

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY BY CONFIDENCE BRACKET
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only models that actually produced evaluable rows.
model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        MAE=("abs_error", "mean"),
        median_absolute_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Reindex onto the full model x bracket grid so empty cells appear explicitly.
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty cells have a count of 0, not NaN; keeps n an integer column in the CSV.
summary["n"] = summary["n"].fillna(0).astype(int)
summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-stratified accuracy table:")
print(summary)


# =========================
# LINE PLOT
# =========================

x = np.arange(len(bracket_order))

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

markers = {
    "GPT-OSS-120B": "o",
    "Qwen3.6-27B": "s",
    "Gemma-4-31B-it": "^",
}

fig, ax = plt.subplots(figsize=(9.5, 6))

for model in model_order:
    df_m = (
        summary[summary["model_display"] == model]
        .set_index("confidence_bracket")
        .reindex(bracket_order)
        .reset_index()
    )

    y = df_m["accuracy_within_0_5_percent"].values
    n = df_m["n"].fillna(0).astype(int).values

    ax.plot(
        x,
        y,
        marker=markers.get(model, "o"),
        markersize=8,
        linewidth=2.2,
        color=colors.get(model),
        label=model,
    )

    # Annotate only brackets that contain at least one prediction.
    for xi, yi, ni in zip(x, y, n):
        if pd.notna(yi) and ni > 0:
            ax.text(
                xi,
                yi + 2.2,
                f"{yi:.1f}%\nn={ni}",
                ha="center",
                va="bottom",
                fontsize=8,
                color=colors.get(model),
                fontweight="bold",
            )

ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)

ax.set_ylim(0, 110)
ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")

ax.set_title(
    "Confidence-stratified EDSS accuracy by model",
    fontsize=14,
    fontweight="bold",
    pad=15,
)

ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

ax.text(
    0.5,
    -0.18,
    "Higher values indicate a larger proportion of predictions within ±0.5 EDSS of the reference score.",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.05, 1, 0.92])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##
# %% Line plot: confidence-stratified EDSS accuracy by predicted EDSS range

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_lineplot_by_edss_range_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting a decimal comma.

    Values are stringified, "," is replaced by ".", and anything that still
    cannot be parsed becomes NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for truthy flag values.

    Matching is case-insensitive and ignores surrounding whitespace.
    "1.0" is accepted as well: a 0/1 column that contains missing values is
    read by pandas as float, so its string form is "1.0", not "1". Without
    this, every row of such a column would be treated as False and dropped.
    """
    text = s.astype(str).str.strip().str.lower()
    return text.isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display label (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in *model_dir*, or None.

    Files whose name contains "incremental", "summary" or "all_results" are
    ignored. The underscore after the iteration number in the glob pattern
    keeps iteration 1 from matching e.g. "iter_10".
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))

    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]

    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Bucket a certainty percentage into one of four labelled brackets.

    Returns NaN for missing values and for values above 100. The labels
    contain a line break because they are used directly as tick labels.
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_range(value):
    """Bucket an EDSS score into one of three labelled ranges.

    Returns NaN for missing values and for values outside every range,
    including values that fall between two ranges (e.g. 3.75).
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Predictions are joined back to the ground truth via the positional row index.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

# One DataFrame per model; concatenated after the loop.
long_rows = []

# Sub-folders of RUN_DIR that hold analysis output rather than model results.
EXCLUDED_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(EXCLUDED_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()

    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: keep only rows that pass each flag that exists.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # If a row was predicted more than once, the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        print("No evaluable rows.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range)

    merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy()

    print(f"Evaluable rows with confidence bracket and EDSS range: {len(merged)}")

    if merged.empty:
        continue

    # "unique_id" carries the "_gt" suffix only when both inputs have the column.
    if "unique_id_gt" in merged.columns:
        unique_id = merged["unique_id_gt"]
    elif "unique_id" in merged.columns:
        unique_id = merged["unique_id"]
    else:
        unique_id = None

    if "inference_time_sec" in merged.columns:
        inference_time = merged["inference_time_sec"]
    else:
        inference_time = np.nan

    # Column-wise assembly (same columns, same order as before) instead of a
    # per-row iterrows() loop.
    model_rows = pd.DataFrame(index=merged.index)
    model_rows["model"] = model_name
    model_rows["model_display"] = model_display
    model_rows["iteration"] = TARGET_ITERATION
    model_rows["row_index"] = merged["row_index"]
    model_rows["unique_id"] = unique_id
    model_rows["GT_EDSS_numeric"] = merged["GT_EDSS_numeric"]
    model_rows["PRED_EDSS_numeric"] = merged["PRED_EDSS_numeric"]
    model_rows["predicted_EDSS_range"] = merged["predicted_EDSS_range"]
    model_rows["certainty_percent"] = merged["certainty_numeric"]
    model_rows["confidence_bracket"] = merged["confidence_bracket"]
    model_rows["error"] = merged["error"]
    model_rows["abs_error"] = merged["abs_error"]
    model_rows["exact_match"] = merged["exact_match"]
    model_rows["within_0_5"] = merged["within_0_5"]
    model_rows["within_1_0"] = merged["within_1_0"]
    model_rows["inference_time_sec"] = inference_time
    model_rows["result_file"] = str(result_file)

    long_rows.append(model_rows)


long_df = pd.concat(long_rows, ignore_index=True) if long_rows else pd.DataFrame()

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

range_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only models that actually produced evaluable rows.
model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        MAE=("abs_error", "mean"),
        median_absolute_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Reindex onto the full model x range x bracket grid so empty cells appear explicitly.
full_index = pd.MultiIndex.from_product(
    [model_order, range_order, bracket_order],
    names=["model_display", "predicted_EDSS_range", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty cells have a count of 0, not NaN; keeps n an integer column in the CSV.
summary["n"] = summary["n"].fillna(0).astype(int)
summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-stratified accuracy by predicted EDSS range:")
print(summary)


# =========================
# PLOT: SMALL MULTIPLE LINE PLOT BY EDSS RANGE
# =========================

x = np.arange(len(bracket_order))

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

markers = {
    "GPT-OSS-120B": "o",
    "Qwen3.6-27B": "s",
    "Gemma-4-31B-it": "^",
}

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(range_order),
    figsize=(5.1 * len(range_order), 5.8),
    sharey=True
)

# plt.subplots returns a bare Axes (not an array) for a single panel.
if len(range_order) == 1:
    axes = [axes]

for ax, edss_r in zip(axes, range_order):
    for model in model_order:
        df_m = (
            summary[
                (summary["model_display"] == model)
                & (summary["predicted_EDSS_range"] == edss_r)
            ]
            .set_index("confidence_bracket")
            .reindex(bracket_order)
            .reset_index()
        )

        y = df_m["accuracy_within_0_5_percent"].values
        n = df_m["n"].fillna(0).astype(int).values

        ax.plot(
            x,
            y,
            marker=markers.get(model, "o"),
            markersize=7,
            linewidth=2.0,
            color=colors.get(model),
            label=model,
        )

        # Annotate only cells that contain at least one prediction.
        for xi, yi, ni in zip(x, y, n):
            if pd.notna(yi) and ni > 0:
                ax.text(
                    xi,
                    yi + 2.2,
                    f"{yi:.0f}%\nn={ni}",
                    ha="center",
                    va="bottom",
                    fontsize=7,
                    color=colors.get(model),
                    fontweight="bold",
                )

    ax.set_title(
        f"Predicted EDSS {edss_r}",
        fontsize=12,
        fontweight="bold",
        pad=10,
    )

    ax.set_xticks(x)
    ax.set_xticklabels(bracket_order, fontsize=8)

    ax.set_ylim(0, 112)
    ax.set_yticks(np.arange(0, 101, 10))
    ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

    ax.grid(True, axis="y", linestyle="--", alpha=0.3)
    ax.set_axisbelow(True)

    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

    ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold")

axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")

# Every panel plots the same models, so the first panel's handles serve the shared legend.
handles, labels = axes[0].get_legend_handles_labels()

fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.01),
    ncol=3,
    frameon=False,
)

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted EDSS range",
    fontsize=14,
    fontweight="bold",
    y=1.02,
)

fig.text(
    0.5,
    0.045,
    "Each panel shows predictions within a predicted EDSS severity range. Point labels show accuracy and number of predictions.",
    ha="center",
    va="center",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.08, 1, 0.94])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##
# %% Dot plot: confidence-stratified EDSS accuracy by predicted EDSS range

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_dotplot_by_edss_range_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

# Hide very small cells from the plot but keep them in the CSV.
MIN_N_TO_PLOT = 5

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting a decimal comma.

    Values are stringified, "," is replaced by ".", and anything that still
    cannot be parsed becomes NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean Series that is True for truthy flag values.

    Matching is case-insensitive and ignores surrounding whitespace.
    "1.0" is accepted as well: a 0/1 column that contains missing values is
    read by pandas as float, so its string form is "1.0", not "1". Without
    this, every row of such a column would be treated as False and dropped.
    """
    text = s.astype(str).str.strip().str.lower()
    return text.isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display label (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in *model_dir*, or None.

    Files whose name contains "incremental", "summary" or "all_results" are
    ignored. The underscore after the iteration number in the glob pattern
    keeps iteration 1 from matching e.g. "iter_10".
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of the "model" column, else the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Bucket a certainty percentage into one of four labelled brackets.

    Returns NaN for missing values and for values above 100. The labels
    contain a line break because they are used directly as tick labels.
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_range(value):
    """Bucket an EDSS score into one of three labelled ranges.

    Returns NaN for missing values and for values outside every range,
    including values that fall between two ranges (e.g. 3.75).
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Predictions are joined back to the ground truth via the positional row index.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

# One DataFrame per model; concatenated after the loop.
long_rows = []

# Sub-folders of RUN_DIR that hold analysis output rather than model results.
EXCLUDED_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
    "confidence",
)

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(EXCLUDED_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()

    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional quality flags: keep only rows that pass each flag that exists.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # If a row was predicted more than once, the first occurrence wins.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range)

    merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy()

    if merged.empty:
        continue

    # Column-wise assembly (same columns, same order as before) instead of a
    # per-row iterrows() loop.
    model_rows = pd.DataFrame(index=merged.index)
    model_rows["model"] = model_name
    model_rows["model_display"] = model_display
    model_rows["iteration"] = TARGET_ITERATION
    model_rows["row_index"] = merged["row_index"]
    model_rows["GT_EDSS_numeric"] = merged["GT_EDSS_numeric"]
    model_rows["PRED_EDSS_numeric"] = merged["PRED_EDSS_numeric"]
    model_rows["predicted_EDSS_range"] = merged["predicted_EDSS_range"]
    model_rows["certainty_percent"] = merged["certainty_numeric"]
    model_rows["confidence_bracket"] = merged["confidence_bracket"]
    model_rows["error"] = merged["error"]
    model_rows["abs_error"] = merged["abs_error"]
    model_rows["exact_match"] = merged["exact_match"]
    model_rows["within_0_5"] = merged["within_0_5"]
    model_rows["within_1_0"] = merged["within_1_0"]
    model_rows["result_file"] = str(result_file)

    long_rows.append(model_rows)


long_df = pd.concat(long_rows, ignore_index=True) if long_rows else pd.DataFrame()

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

range_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep only models that actually produced evaluable rows.
model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        MAE=("abs_error", "mean"),
        median_absolute_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Reindex onto the full model x range x bracket grid so empty cells appear explicitly.
full_index = pd.MultiIndex.from_product(
    [model_order, range_order, bracket_order],
    names=["model_display", "predicted_EDSS_range", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty cells have a count of 0, not NaN; keeps n an integer column in the CSV.
summary["n"] = summary["n"].fillna(0).astype(int)
summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nSummary:")
print(summary)


# =========================
# DOT PLOT
# =========================

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

markers = {
    "GPT-OSS-120B": "o",
    "Qwen3.6-27B": "s",
    "Gemma-4-31B-it": "^",
}

x_positions = {
    "Low\n<70%": 0,
    "Moderate\n70–80%": 1,
    "High\n80–90%": 2,
    "Very high\n90–100%": 3,
}

# Horizontal shift per model so markers in the same bracket do not overlap.
model_offsets = {
    "GPT-OSS-120B": -0.18,
    "Qwen3.6-27B": 0.00,
    "Gemma-4-31B-it": 0.18,
}

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(range_order),
    figsize=(14, 5.5),
    sharey=True
)

# plt.subplots returns a bare Axes (not an array) for a single panel.
if len(range_order) == 1:
    axes = [axes]

for ax, edss_r in zip(axes, range_order):
    df_r = summary[summary["predicted_EDSS_range"] == edss_r].copy()

    for model in model_order:
        df_m = df_r[df_r["model_display"] == model].copy()

        for _, row in df_m.iterrows():
            n = row["n"]
            acc = row["accuracy_within_0_5_percent"]
            bracket = row["confidence_bracket"]

            # Skip empty cells and cells too small to give a stable percentage.
            if pd.isna(acc) or n < MIN_N_TO_PLOT:
                continue

            x = x_positions[bracket] + model_offsets.get(model, 0)

            ax.scatter(
                x,
                acc,
                s=45 + n * 2.0,
                color=colors[model],
                marker=markers[model],
                alpha=0.85,
                edgecolor="black",
                linewidth=0.6,
                label=model,
            )

            ax.text(
                x,
                acc + 2.0,
                f"{acc:.0f}%\nn={int(n)}",
                ha="center",
                va="bottom",
                fontsize=7,
                color=colors[model],
                fontweight="bold",
            )

    ax.set_title(
        f"Predicted EDSS {edss_r}",
        fontsize=12,
        fontweight="bold",
        pad=10,
    )

    ax.set_xticks(list(x_positions.values()))
    ax.set_xticklabels(bracket_order, fontsize=8)

    ax.set_ylim(0, 112)
    ax.set_yticks(np.arange(0, 101, 10))
    ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

    ax.grid(True, axis="y", linestyle="--", alpha=0.3)
    ax.set_axisbelow(True)

    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

    ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold")

axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")

# Hand-built legend entries: one fixed-size marker per model, independent of
# the n-scaled scatter markers.
handles = [
    plt.Line2D(
        [0],
        [0],
        marker=markers[model],
        color="w",
        label=model,
        markerfacecolor=colors[model],
        markeredgecolor="black",
        markersize=8,
    )
    for model in model_order
]

fig.legend(
    handles=handles,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=3,
    frameon=False,
)

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted EDSS range",
    fontsize=14,
    fontweight="bold",
    y=1.02,
)

fig.text(
    0.5,
    0.045,
    f"Points show accuracy within ±0.5 EDSS. Point size reflects n. Cells with n < {MIN_N_TO_PLOT} are hidden.",
    ha="center",
    va="center",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.08, 1, 0.94])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)
##
# %% Clean dot plot: confidence-stratified EDSS accuracy by predicted EDSS range

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_dotplot_by_edss_range_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

# Hide very small cells from the plot but keep them in the CSV.
# =========================
# HELPERS
# =========================

# Lower-cased, stripped string forms that count as "true" in result-file flag
# columns. "1.0" is included because a 0/1 column that pandas loaded as float
# stringifies to "1.0"; without it every such row was treated as False.
_TRUE_TOKENS = ["true", "1", "1.0", "yes", "ja"]

# Raw model identifier -> display name used in tables and figure legends.
_MODEL_DISPLAY_NAMES = {
    "gpt-oss-120b": "GPT-OSS-120B",
    "qwen3.6-27b": "Qwen3.6-27B",
    "gemma-4-31B-it": "Gemma-4-31B-it",
}

# Case-insensitive view of the mapping above, built once at import time.
_MODEL_DISPLAY_LOOKUP = {
    raw.casefold(): display for raw, display in _MODEL_DISPLAY_NAMES.items()
}

# Substrings that mark a CSV as something other than a plain per-iteration
# result file.
_EXCLUDED_FILE_MARKERS = ("incremental", "summary", "all_results")

# (exclusive upper bound, label) pairs, checked in ascending order.
_CONFIDENCE_BRACKETS = (
    (70, "Low\n<70%"),
    (80, "Moderate\n70–80%"),
    (90, "High\n80–90%"),
)
_TOP_CONFIDENCE_LABEL = "Very high\n90–100%"

# (inclusive lower bound, inclusive upper bound, label) triples.
_EDSS_RANGES = (
    (0.0, 3.5, "0.0–3.5"),
    (4.0, 5.5, "4.0–5.5"),
    (6.0, 10.0, "6.0–10.0"),
)


def to_num(s):
    """Convert a Series to floats, accepting a decimal comma.

    Values are stringified, stripped, and "," is replaced by "." before
    parsing. Anything that still cannot be parsed becomes NaN.
    """
    as_text = s.astype(str).str.strip().str.replace(",", ".", regex=False)
    return pd.to_numeric(as_text, errors="coerce")


def to_bool(s):
    """Return a boolean Series that is True where *s* holds a truthy token.

    Matching ignores case and surrounding whitespace. Accepted tokens are
    "true", "1", "1.0", "yes" and "ja"; everything else (including missing
    values) maps to False.
    """
    normalized = s.astype(str).str.strip().str.lower()
    return normalized.isin(_TRUE_TOKENS)


def clean_model_name(name):
    """Map a raw model identifier to its display name.

    The lookup ignores case and surrounding whitespace so that a model name
    taken from a directory or a differently-cased "model" column still
    resolves. Unknown names are returned unchanged (as ``str(name)``).
    """
    raw = str(name)
    return _MODEL_DISPLAY_LOOKUP.get(raw.strip().casefold(), raw)


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in *model_dir*, or None.

    Candidates match ``*results_iter_{iteration}_*.csv`` and are considered
    in sorted order. Files whose lower-cased name contains "incremental",
    "summary" or "all_results" are skipped.
    """
    pattern = f"*results_iter_{iteration}_*.csv"
    for candidate in sorted(model_dir.glob(pattern)):
        lowered = candidate.name.lower()
        if not any(marker in lowered for marker in _EXCLUDED_FILE_MARKERS):
            return candidate
    return None


def get_model_name(df, model_dir):
    """Return the model name for a result frame.

    Uses the first non-missing value of the "model" column when one exists;
    otherwise falls back to the name of *model_dir*.
    """
    if "model" in df.columns:
        names = df["model"].dropna()
        if not names.empty:
            return str(names.iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Return the confidence-bracket label for a certainty percentage.

    Brackets are <70, 70–<80, 80–<90 and 90–100 (inclusive). Missing values
    and values outside 0–100 yield NaN so that they are dropped downstream
    instead of being counted in a bracket.
    """
    if pd.isna(certainty):
        return np.nan
    # A negative percentage is as invalid as one above 100; previously it was
    # silently labelled "Low".
    if certainty < 0 or certainty > 100:
        return np.nan
    for upper_bound, label in _CONFIDENCE_BRACKETS:
        if certainty < upper_bound:
            return label
    return _TOP_CONFIDENCE_LABEL


def edss_range(value):
    """Return the EDSS range label for *value*, or NaN.

    The ranges are 0.0–3.5, 4.0–5.5 and 6.0–10.0, all inclusive. Missing
    values and values that fall in none of the ranges (including the gaps
    between them) yield NaN.
    """
    if pd.isna(value):
        return np.nan
    for lower, upper, label in _EDSS_RANGES:
        if lower <= value <= upper:
            return label
    return np.nan


def size_from_n(n):
    """
    Convert n to marker size.

    The size grows linearly: a base of 35 plus 4.0 per observation.
    """
    return 35 + (n * 4.0)
pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) + + pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt.merge( + pred, + on="row_index", + how="inner", + suffixes=("_gt", "_pred") + ) + + if merged.empty: + continue + + merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] + merged["abs_error"] = merged["error"].abs() + + merged["exact_match"] = merged["abs_error"] == 0 + merged["within_0_5"] = merged["abs_error"] <= 0.5 + merged["within_1_0"] = merged["abs_error"] <= 1.0 + + merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) + merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range) + + merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy() + + for _, row in merged.iterrows(): + long_rows.append({ + "model": model_name, + "model_display": model_display, + "iteration": TARGET_ITERATION, + "row_index": row["row_index"], + "GT_EDSS_numeric": row["GT_EDSS_numeric"], + "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], + "predicted_EDSS_range": row["predicted_EDSS_range"], + "certainty_percent": row["certainty_numeric"], + "confidence_bracket": row["confidence_bracket"], + "error": row["error"], + "abs_error": row["abs_error"], + "exact_match": row["exact_match"], + "within_0_5": row["within_0_5"], + "within_1_0": row["within_1_0"], + "result_file": str(result_file), + }) + + +long_df = pd.DataFrame(long_rows) + +if long_df.empty: + raise ValueError("No evaluable rows found.") + +long_df.to_csv(OUTPUT_LONG, index=False) + + +# ========================= +# SUMMARY +# ========================= + +bracket_order = [ + "Low\n<70%", + "Moderate\n70–80%", + "High\n80–90%", + "Very high\n90–100%", +] + +range_order = [ + 
"0.0–3.5", + "4.0–5.5", + "6.0–10.0", +] + +model_order = [ + "GPT-OSS-120B", + "Qwen3.6-27B", + "Gemma-4-31B-it", +] + +model_order = [ + m for m in model_order + if m in long_df["model_display"].unique() +] + +summary = ( + long_df + .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"]) + .agg( + n=("within_0_5", "count"), + exact_accuracy=("exact_match", "mean"), + accuracy_within_0_5=("within_0_5", "mean"), + accuracy_within_1_0=("within_1_0", "mean"), + MAE=("abs_error", "mean"), + median_absolute_error=("abs_error", "median"), + mean_confidence=("certainty_percent", "mean"), + ) + .reset_index() +) + +full_index = pd.MultiIndex.from_product( + [model_order, range_order, bracket_order], + names=["model_display", "predicted_EDSS_range", "confidence_bracket"] +) + +summary = ( + summary + .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"]) + .reindex(full_index) + .reset_index() +) + +summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100 +summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100 +summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100 +summary["shown_in_plot"] = summary["n"].fillna(0) >= MIN_N_TO_PLOT + +summary.to_csv(OUTPUT_TABLE, index=False) + +print("\nSummary:") +print(summary) + + +# ========================= +# CLEAN DOT PLOT +# ========================= + +colors = { + "GPT-OSS-120B": "#1F77B4", + "Qwen3.6-27B": "#FF7F0E", + "Gemma-4-31B-it": "#2CA02C", +} + +markers = { + "GPT-OSS-120B": "o", + "Qwen3.6-27B": "s", + "Gemma-4-31B-it": "^", +} + +x_positions = { + "Low\n<70%": 0, + "Moderate\n70–80%": 1, + "High\n80–90%": 2, + "Very high\n90–100%": 3, +} + +model_offsets = { + "GPT-OSS-120B": -0.18, + "Qwen3.6-27B": 0.00, + "Gemma-4-31B-it": 0.18, +} + +fig, axes = plt.subplots( + nrows=1, + ncols=len(range_order), + figsize=(14, 5.3), + sharey=True +) + +if len(range_order) == 1: + axes = [axes] + +for ax, edss_r in zip(axes, 
range_order): + df_r = summary[ + (summary["predicted_EDSS_range"] == edss_r) + & (summary["shown_in_plot"]) + ].copy() + + for model in model_order: + df_m = df_r[df_r["model_display"] == model].copy() + + for _, row in df_m.iterrows(): + n = int(row["n"]) + acc = row["accuracy_within_0_5_percent"] + bracket = row["confidence_bracket"] + + if pd.isna(acc): + continue + + x = x_positions[bracket] + model_offsets.get(model, 0) + + ax.scatter( + x, + acc, + s=size_from_n(n), + color=colors[model], + marker=markers[model], + alpha=0.85, + edgecolor="black", + linewidth=0.6, + ) + + ax.set_title( + f"Predicted EDSS {edss_r}", + fontsize=12, + fontweight="bold", + pad=10, + ) + + ax.set_xticks(list(x_positions.values())) + ax.set_xticklabels(bracket_order, fontsize=8) + + ax.set_ylim(0, 105) + ax.set_yticks(np.arange(0, 101, 20)) + ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 20)]) + + ax.grid(True, axis="y", linestyle="--", alpha=0.3) + ax.set_axisbelow(True) + + for spine in ["top", "right"]: + ax.spines[spine].set_visible(False) + + ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold") + +axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold") + + +# ========================= +# LEGENDS +# ========================= + +model_handles = [ + plt.Line2D( + [0], + [0], + marker=markers[model], + color="w", + label=model, + markerfacecolor=colors[model], + markeredgecolor="black", + markersize=8, + ) + for model in model_order +] + +fig.legend( + handles=model_handles, + loc="lower center", + bbox_to_anchor=(0.43, -0.01), + ncol=3, + frameon=False, + title="Model", +) + +size_values = [10, 50, 100, 200] +max_n = int(summary["n"].fillna(0).max()) +size_values = [n for n in size_values if n <= max_n] + +if size_values: + size_handles = [ + plt.scatter( + [], + [], + s=size_from_n(n), + color="lightgray", + edgecolor="black", + alpha=0.85, + label=f"n={n}", + ) + for n in size_values + ] + + fig.legend( + 
handles=size_handles, + loc="lower center", + bbox_to_anchor=(0.78, -0.01), + ncol=len(size_handles), + frameon=False, + title="Point size", + ) + +fig.suptitle( + "Confidence-stratified EDSS accuracy by predicted EDSS range", + fontsize=14, + fontweight="bold", + y=1.02, +) + +fig.text( + 0.5, + 0.045, + f"Points show accuracy within ±0.5 EDSS. Point size reflects n. Groups with n < {MIN_N_TO_PLOT} are omitted from the figure.", + ha="center", + va="center", + fontsize=9, + color="#555555", +) + +plt.tight_layout(rect=[0, 0.10, 1, 0.94]) + +plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight") +plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight") + +plt.show() + +print("\nSaved:") +print(OUTPUT_SVG) +print(OUTPUT_PNG) +print(OUTPUT_TABLE) +print(OUTPUT_LONG) +## + + +# %% Confidence x predicted EDSS range table + + +from pathlib import Path +import pandas as pd +import numpy as np + + +# ========================= +# CONFIGURATION +# ========================= + +GT_PATH = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" + "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +) + +RUN_DIR = Path( + "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" +) + +TARGET_ITERATION = 1 + +OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_table_iter_{TARGET_ITERATION}" +OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + +OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv" +OUTPUT_SUMMARY = OUTPUT_DIR / f"confidence_accuracy_summary_iter_{TARGET_ITERATION}.csv" +OUTPUT_WIDE_CSV = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.csv" +OUTPUT_WIDE_MD = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.md" +OUTPUT_WIDE_XLSX = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.xlsx" + +GT_EDSS_COL = "EDSS" +PRED_EDSS_COL = "EDSS_numeric" +PRED_EDSS_FALLBACK_COL = "EDSS" +CERTAINTY_COL = "certainty_percent" + + +# 
# =========================
# HELPERS
# =========================

# Lower-cased, stripped string forms that count as "true" in result-file flag
# columns. "1.0" is included because a 0/1 column that pandas loaded as float
# stringifies to "1.0"; without it every such row was treated as False.
_TRUE_TOKENS = ["true", "1", "1.0", "yes", "ja"]

# Raw model identifier -> display name used in the output tables.
_MODEL_DISPLAY_NAMES = {
    "gpt-oss-120b": "GPT-OSS-120B",
    "qwen3.6-27b": "Qwen3.6-27B",
    "gemma-4-31B-it": "Gemma-4-31B-it",
}

# Case-insensitive view of the mapping above, built once at import time.
_MODEL_DISPLAY_LOOKUP = {
    raw.casefold(): display for raw, display in _MODEL_DISPLAY_NAMES.items()
}

# Substrings that mark a CSV as something other than a plain per-iteration
# result file.
_EXCLUDED_FILE_MARKERS = ("incremental", "summary", "all_results")

# (exclusive upper bound, label) pairs, checked in ascending order.
_CONFIDENCE_BRACKETS = (
    (70, "Low (<70%)"),
    (80, "Moderate (70–80%)"),
    (90, "High (80–90%)"),
)
_TOP_CONFIDENCE_LABEL = "Very high (90–100%)"

# (inclusive lower bound, inclusive upper bound, label) triples.
_EDSS_RANGES = (
    (0.0, 3.5, "0.0–3.5"),
    (4.0, 5.5, "4.0–5.5"),
    (6.0, 10.0, "6.0–10.0"),
)


def to_num(s):
    """Convert a Series to floats, accepting a decimal comma.

    Values are stringified, stripped, and "," is replaced by "." before
    parsing. Anything that still cannot be parsed becomes NaN.
    """
    as_text = s.astype(str).str.strip().str.replace(",", ".", regex=False)
    return pd.to_numeric(as_text, errors="coerce")


def to_bool(s):
    """Return a boolean Series that is True where *s* holds a truthy token.

    Matching ignores case and surrounding whitespace. Accepted tokens are
    "true", "1", "1.0", "yes" and "ja"; everything else (including missing
    values) maps to False.
    """
    normalized = s.astype(str).str.strip().str.lower()
    return normalized.isin(_TRUE_TOKENS)


def clean_model_name(name):
    """Map a raw model identifier to its display name.

    The lookup ignores case and surrounding whitespace so that a model name
    taken from a directory or a differently-cased "model" column still
    resolves. Unknown names are returned unchanged (as ``str(name)``).
    """
    raw = str(name)
    return _MODEL_DISPLAY_LOOKUP.get(raw.strip().casefold(), raw)


def find_iter_file(model_dir, iteration):
    """Return the first per-iteration result CSV in *model_dir*, or None.

    Candidates match ``*results_iter_{iteration}_*.csv`` and are considered
    in sorted order. Files whose lower-cased name contains "incremental",
    "summary" or "all_results" are skipped.
    """
    pattern = f"*results_iter_{iteration}_*.csv"
    for candidate in sorted(model_dir.glob(pattern)):
        lowered = candidate.name.lower()
        if not any(marker in lowered for marker in _EXCLUDED_FILE_MARKERS):
            return candidate
    return None


def get_model_name(df, model_dir):
    """Return the model name for a result frame.

    Uses the first non-missing value of the "model" column when one exists;
    otherwise falls back to the name of *model_dir*.
    """
    if "model" in df.columns:
        names = df["model"].dropna()
        if not names.empty:
            return str(names.iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Return the confidence-bracket label for a certainty percentage.

    Brackets are <70, 70–<80, 80–<90 and 90–100 (inclusive). Missing values
    and values outside 0–100 yield NaN so that they are dropped downstream
    instead of being counted in a bracket.
    """
    if pd.isna(certainty):
        return np.nan
    # A negative percentage is as invalid as one above 100; previously it was
    # silently labelled "Low".
    if certainty < 0 or certainty > 100:
        return np.nan
    for upper_bound, label in _CONFIDENCE_BRACKETS:
        if certainty < upper_bound:
            return label
    return _TOP_CONFIDENCE_LABEL


def edss_range_with_missing(value):
    """Return the EDSS range label for *value*, never NaN.

    Missing values are labelled "Missing EDSS". Values inside one of the
    inclusive ranges 0.0–3.5, 4.0–5.5 or 6.0–10.0 get that range's label.
    Everything else, including values in the gaps between ranges, is labelled
    "Invalid EDSS".
    """
    if pd.isna(value):
        return "Missing EDSS"
    for lower, upper, label in _EDSS_RANGES:
        if lower <= value <= upper:
            return label
    return "Invalid EDSS"


def format_cell(acc, n):
    """Format one table cell as "<accuracy>% (n=<count>)".

    Returns "—" when the count is missing or zero, and "NA (n=<count>)" when
    there are observations but no accuracy value. Accuracy is shown with one
    decimal place.
    """
    if pd.isna(n) or int(n) == 0:
        return "—"
    count = int(n)
    if pd.isna(acc):
        return f"NA (n={count})"
    return f"{acc:.1f}% (n={count})"
DATA +# ========================= + +long_rows = [] + +model_dirs = [ + p for p in sorted(RUN_DIR.iterdir()) + if p.is_dir() + and not p.name.startswith("confusion") + and not p.name.startswith("functional_system") + and not p.name.startswith("repeated_run") + and not p.name.startswith("edss_error_distribution") + and not p.name.startswith("edss_threshold_metrics") + and not p.name.startswith("edss_severity_group_metrics") + and not p.name.startswith("structured_output_validity") + and not p.name.startswith("confidence") +] + +for model_dir in model_dirs: + result_file = find_iter_file(model_dir, TARGET_ITERATION) + + if result_file is None: + print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}") + continue + + pred_raw = pd.read_csv(result_file, sep=",") + + if "row_index" not in pred_raw.columns: + print(f"Skipping {model_dir.name}: no row_index column.") + continue + + if CERTAINTY_COL not in pred_raw.columns: + print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.") + continue + + model_name = get_model_name(pred_raw, model_dir) + model_display = clean_model_name(model_name) + + pred = pred_raw.copy() + + pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") + pred = pred.dropna(subset=["row_index"]).copy() + pred["row_index"] = pred["row_index"].astype(int) + + if "success" in pred.columns: + pred = pred[to_bool(pred["success"])].copy() + + pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL + + pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) + pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) + + # Keep missing EDSS predictions, but require confidence. 
+ pred = pred.dropna(subset=["certainty_numeric"]).copy() + pred = pred.drop_duplicates("row_index", keep="first").copy() + + merged = gt.merge( + pred, + on="row_index", + how="inner", + suffixes=("_gt", "_pred") + ) + + if merged.empty: + continue + + merged["has_numeric_prediction"] = merged["PRED_EDSS_numeric"].notna() + merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range_with_missing) + merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) + + merged = merged.dropna(subset=["confidence_bracket"]).copy() + + merged["abs_error"] = np.where( + merged["has_numeric_prediction"], + (merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]).abs(), + np.nan + ) + + # Missing EDSS counts as not within ±0.5. + merged["within_0_5"] = np.where( + merged["has_numeric_prediction"], + merged["abs_error"] <= 0.5, + False + ) + + for _, row in merged.iterrows(): + long_rows.append({ + "model": model_name, + "model_display": model_display, + "iteration": TARGET_ITERATION, + "row_index": row["row_index"], + "GT_EDSS_numeric": row["GT_EDSS_numeric"], + "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], + "has_numeric_prediction": row["has_numeric_prediction"], + "predicted_EDSS_range": row["predicted_EDSS_range"], + "certainty_percent": row["certainty_numeric"], + "confidence_bracket": row["confidence_bracket"], + "abs_error": row["abs_error"], + "within_0_5": row["within_0_5"], + "result_file": str(result_file), + }) + + +long_df = pd.DataFrame(long_rows) + +if long_df.empty: + raise ValueError("No evaluable rows found.") + +long_df.to_csv(OUTPUT_LONG, index=False) + + +# ========================= +# SUMMARY TABLE +# ========================= + +confidence_order = [ + "Low (<70%)", + "Moderate (70–80%)", + "High (80–90%)", + "Very high (90–100%)", +] + +range_order = [ + "Missing EDSS", + "0.0–3.5", + "4.0–5.5", + "6.0–10.0", + "Invalid EDSS", +] + +model_order = [ + "GPT-OSS-120B", + "Qwen3.6-27B", + "Gemma-4-31B-it", +] + 
+model_order = [ + m for m in model_order + if m in long_df["model_display"].unique() +] + +range_order = [ + r for r in range_order + if r in long_df["predicted_EDSS_range"].unique() +] + +summary = ( + long_df + .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"]) + .agg( + n=("within_0_5", "count"), + accuracy_within_0_5_percent=("within_0_5", lambda x: x.mean() * 100), + n_numeric_predictions=("has_numeric_prediction", "sum"), + mean_abs_error=("abs_error", "mean"), + median_abs_error=("abs_error", "median"), + ) + .reset_index() +) + +full_index = pd.MultiIndex.from_product( + [model_order, range_order, confidence_order], + names=["model_display", "predicted_EDSS_range", "confidence_bracket"] +) + +summary = ( + summary + .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"]) + .reindex(full_index) + .reset_index() +) + +summary["n"] = summary["n"].fillna(0).astype(int) +summary["n_numeric_predictions"] = summary["n_numeric_predictions"].fillna(0).astype(int) + +summary.to_csv(OUTPUT_SUMMARY, index=False) + + +# ========================= +# WIDE TABLE FOR PAPER +# ========================= + +summary["cell"] = summary.apply( + lambda row: format_cell( + row["accuracy_within_0_5_percent"], + row["n"] + ), + axis=1 +) + +wide = ( + summary + .pivot_table( + index=["model_display", "predicted_EDSS_range"], + columns="confidence_bracket", + values="cell", + aggfunc="first" + ) + .reindex(index=pd.MultiIndex.from_product( + [model_order, range_order], + names=["model_display", "predicted_EDSS_range"] + )) + .reindex(columns=confidence_order) + .reset_index() +) + +wide = wide.rename(columns={ + "model_display": "Model", + "predicted_EDSS_range": "Predicted EDSS range", +}) + +wide.to_csv(OUTPUT_WIDE_CSV, index=False) + +with open(OUTPUT_WIDE_MD, "w", encoding="utf-8") as f: + f.write(wide.to_markdown(index=False)) + f.write("\n") + +wide.to_excel(OUTPUT_WIDE_XLSX, index=False) + + +# ========================= +# PRINT 
OUTPUT +# ========================= + +pd.set_option("display.max_columns", None) +pd.set_option("display.width", 220) +pd.set_option("display.max_colwidth", None) + +print("\nWide confidence accuracy table:") +print(wide.to_markdown(index=False)) + +print("\nSaved:") +print(OUTPUT_LONG) +print(OUTPUT_SUMMARY) +print(OUTPUT_WIDE_CSV) +print(OUTPUT_WIDE_MD) +print(OUTPUT_WIDE_XLSX) +## + + +# %% name +# ========================= +# ALTERNATIVE WIDE TABLE FOR PAPER +# Rows: Predicted EDSS range + Confidence bracket +# Columns: Models +# ========================= + +summary["cell"] = summary.apply( + lambda row: format_cell( + row["accuracy_within_0_5_percent"], + row["n"] + ), + axis=1 +) + +model_as_columns = ( + summary + .pivot_table( + index=["predicted_EDSS_range", "confidence_bracket"], + columns="model_display", + values="cell", + aggfunc="first" + ) + .reindex( + index=pd.MultiIndex.from_product( + [range_order, confidence_order], + names=["predicted_EDSS_range", "confidence_bracket"] + ) + ) + .reindex(columns=model_order) + .reset_index() +) + +model_as_columns = model_as_columns.rename(columns={ + "predicted_EDSS_range": "Predicted EDSS range", + "confidence_bracket": "Confidence bracket", +}) + +OUTPUT_MODEL_COLUMNS_CSV = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.csv" +OUTPUT_MODEL_COLUMNS_MD = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.md" +OUTPUT_MODEL_COLUMNS_XLSX = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.xlsx" + +model_as_columns.to_csv(OUTPUT_MODEL_COLUMNS_CSV, index=False) + +with open(OUTPUT_MODEL_COLUMNS_MD, "w", encoding="utf-8") as f: + f.write(model_as_columns.to_markdown(index=False)) + f.write("\n") + +model_as_columns.to_excel(OUTPUT_MODEL_COLUMNS_XLSX, index=False) + +print("\nModel-as-columns confidence accuracy table:") +print(model_as_columns.to_markdown(index=False)) + +print("\nSaved alternative table:") 
+print(OUTPUT_MODEL_COLUMNS_CSV) +print(OUTPUT_MODEL_COLUMNS_MD) +print(OUTPUT_MODEL_COLUMNS_XLSX) +## + +