diff --git a/.gitignore b/.gitignore index 966ab40..4f9c622 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,11 @@ __pycache__/ ======= /reference/ *.svg +**/*.csv +**/*.json* +**/*.txt* +**/*.png* +*.log >>>>>>> Stashed changes # 2. Ignore virtual environments COMPLETELY # This must come BEFORE the unignore rule diff --git a/app.py b/app.py index def7c48..373da9a 100644 --- a/app.py +++ b/app.py @@ -216,6 +216,3 @@ if __name__ == "__main__": -# %% name -eXXXXXXXX -## diff --git a/certainty.py b/certainty.py index cfc11be..fac6998 100644 --- a/certainty.py +++ b/certainty.py @@ -359,166 +359,1563 @@ # %% API call - Multi-iteration EDSS + certainty extraction +# +#import time +#import json +#import os +#from datetime import datetime +#import pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +## Load environment variables +#load_dotenv() +# +## === CONFIGURATION === +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +#MODEL_NAME = "GPT-OSS-120B" +# +## File paths +#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" +# +## Iteration settings +#NUM_ITERATIONS = 20 +#STOP_ON_FIRST_ERROR = False # Set to True for debugging +# +## Initialize OpenAI client +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +## Read EDSS instructions from file +#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f: +# EDSS_INSTRUCTIONS = f.read().strip() +# +## === PROMPT (unchanged from before) === +#def build_prompt(patient_text): +# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren. +# +#### Deine Aufgabe: +#1. 
Analysiere den Patientenbericht und extrahiere: +# - Den Gesamt-EDSS-Score (0.0–10.0) +# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl) +#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein. +# +#### Struktur der JSON-Ausgabe (VERPFLICHTEND): +#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter. +# +#{{ +# "reason": "Kernaussage zur EDSS-Begründung (max. 400 Zeichen, auf Deutsch).", +# "klassifizierbar": true/false, +# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true)", +# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl)", +# "subcategories": {{ +# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0 +# }} +#}} +# +#### Regeln: +#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest. +#- **klassifizierbar**: +# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind. +# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist. +#- **EDSS**: +# - **VERPFLICHTEND**, wenn `klassifizierbar=true`. +# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`. 
+#- **certainty_percent**: +# - **Immer present** — Ganzzahl (0–100), basierend auf: +# - Klarheit und Vollständigkeit der Berichtsangaben, +# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz), +# - Konsistenz zwischen den Unterkategorien. +#- **subcategories**: +# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein. +# - Jeder Wert ist entweder: +# - `null` (wenn keine ausreichende Information vorliegt), **oder** +# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0). +# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab. +# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`. +# +#### EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +#''' +# +## === INFERENCE FUNCTION (unchanged) === +#def run_inference(patient_text): +# prompt = build_prompt(patient_text) +# +# start_time = time.time() +# +# try: +# response = client.chat.completions.create( +# messages=[ +# {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."} +# ] + [ +# {"role": "user", "content": prompt} +# ], +# model=MODEL_NAME, +# max_tokens=2048, +# temperature=0.1, +# response_format={"type": "json_object"} +# ) +# +# content = response.choices[0].message.content +# +# # Parse and validate JSON +# try: +# parsed = json.loads(content) +# except json.JSONDecodeError as e: +# print(f"⚠️ JSON parsing failed: {e}") +# print("Raw response:", content[:500]) +# raise ValueError("Model did not return valid JSON") +# +# # Enforce required keys +# if "certainty_percent" not in parsed: +# print("⚠️ Missing 'certainty_percent' in output! 
Force-adding fallback.") +# parsed["certainty_percent"] = 0 +# elif not isinstance(parsed["certainty_percent"], (int, float)): +# parsed["certainty_percent"] = int(parsed["certainty_percent"]) +# +# # Clamp certainty to [0, 100] +# pct = parsed["certainty_percent"] +# parsed["certainty_percent"] = max(0, min(100, int(pct))) +# +# # Enforce EDSS rules +# if not parsed.get("klassifizierbar", False): +# if "EDSS" in parsed: +# del parsed["EDSS"] +# else: +# if "EDSS" not in parsed: +# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.") +# parsed["EDSS"] = 7.0 +# +# inference_time = time.time() - start_time +# +# return { +# "success": True, +# "result": parsed, +# "inference_time_sec": inference_time +# } +# +# except Exception as e: +# print(f"❌ Inference error: {e}") +# return { +# "success": False, +# "error": str(e), +# "inference_time_sec": -1, +# "result": None +# } +# +## === BUILD PATIENT TEXT === +#def build_patient_text(row): +# return ( +# str(row.get("T_Zusammenfassung", "")) + "\n" + +# str(row.get("Diagnosen", "")) + "\n" + +# str(row.get("T_KlinBef", "")) + "\n" + +# str(row.get("T_Befunde", "")) +# ) +# +## === MAIN LOOP (NEW: MULTI-ITERATION) === +#if __name__ == "__main__": +# # Load data ONCE (to avoid repeated I/O overhead) +# df = pd.read_csv(INPUT_CSV, sep=';') +# total_rows = len(df) +# print(f"Loaded {total_rows} patient records.") +# +# for iteration in range(1, NUM_ITERATIONS + 1): +# print(f"\n{'='*60}") +# print(f"🔄 ITERATION {iteration}/{NUM_ITERATIONS}") +# print(f"{'='*60}") +# +# iteration_results = [] +# start_iter = time.time() +# +# for idx, row in df.iterrows(): +# print(f"\rRow {idx+1}/{total_rows} | Iter {iteration}", end='', flush=True) +# try: +# patient_text = build_patient_text(row) +# result = run_inference(patient_text) +# +# # Attach metadata +# if result["success"]: +# res = result["result"].copy() # avoid mutation +# res["iteration"] = iteration +# res["unique_id"] = row.get("unique_id", 
f"row_{idx}") +# res["MedDatum"] = row.get("MedDatum", None) +# result["result"] = res +# +# else: +# result["iteration"] = iteration +# result["unique_id"] = row.get("unique_id", f"row_{idx}") +# result["MedDatum"] = row.get("MedDatum", None) +# +# iteration_results.append(result) +# +# if result["success"]: +# res = result["result"] +# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" +# print(f" ✅ EDSS={edss}, cert={res.get('certainty_percent', '?')}%") +# else: +# print(f" ❌ {result.get('error', 'Unknown')}") +# +# except Exception as e: +# print(f"\n⚠️ Row {idx} failed: {e}") +# iteration_results.append({ +# "success": False, +# "error": str(e), +# "iteration": iteration, +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None), +# "result": None +# }) +# if STOP_ON_FIRST_ERROR: +# break +# +# # Save per-iteration results +# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") +# output_path = INPUT_CSV.replace(".csv", f"_results_iter_{iteration}_{timestamp}.json") +# with open(output_path, 'w', encoding='utf-8') as f: +# json.dump(iteration_results, f, indent=2, ensure_ascii=False) +# print(f"\n✅ Iteration {iteration} complete. Saved to: {output_path}") +# +# elapsed = time.time() - start_iter +# print(f"⏱️ Iteration {iteration} took {elapsed:.1f}s ({elapsed/total_rows:.1f}s/row)") +# +# print(f"\n🎉 All {NUM_ITERATIONS} iterations completed!") +# +# + +## + + + +# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark +# +#import time +#import json +#import os +#import re +#import threading +#from datetime import datetime +#from pathlib import Path +# +#import pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +#try: +# import psutil +#except ImportError: +# psutil = None +# print("⚠️ psutil is not installed. 
Resource metrics will be limited.") +# print("Install with: pip install psutil") +# +# +## ========================= +## CONFIGURATION +## ========================= +# +#load_dotenv() +# +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +# +#MODEL_NAMES = [ +# # "GPT-OSS-120B", +# "qwen3.6-27b", +# # "gemma-4-31B-it", +#] +# +#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" +# +#RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark" +# +#NUM_ITERATIONS = 20 +#STOP_ON_FIRST_ERROR = False +# +#MAX_TOKENS = 2048 +#TEMPERATURE = 0.1 +# +## Memory sampling interval during one inference call +#RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 +# +# +## ========================= +## CLIENT +## ========================= +# +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +# +## ========================= +## HELPERS +## ========================= +# +#def safe_dir_name(name: str) -> str: +# """ +# Convert model name to a filesystem-safe directory name. +# """ +# name = str(name).strip() +# name = re.sub(r"[^\w\-.]+", "_", name) +# return name[:150] +# +# +#def now_timestamp() -> str: +# return datetime.now().strftime("%Y%m%d_%H%M%S") +# +# +#def get_process(): +# if psutil is None: +# return None +# return psutil.Process(os.getpid()) +# +# +#def get_memory_rss_mb(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# return process.memory_info().rss / (1024 * 1024) +# +# +#def get_cpu_times_sec(process=None): +# if psutil is None: +# return None +# if process is None: +# process = get_process() +# cpu_times = process.cpu_times() +# return cpu_times.user + cpu_times.system +# +# +#class ResourceSampler: +# """ +# Samples process RSS while an inference call is running. 
+# Useful for approximating peak memory usage per request. +# """ +# +# def __init__(self, interval_sec=0.05): +# self.interval_sec = interval_sec +# self.process = get_process() +# self.running = False +# self.thread = None +# self.samples_mb = [] +# +# def start(self): +# if psutil is None: +# return +# +# self.running = True +# self.samples_mb = [] +# self.thread = threading.Thread(target=self._sample_loop, daemon=True) +# self.thread.start() +# +# def stop(self): +# if psutil is None: +# return +# +# self.running = False +# if self.thread is not None: +# self.thread.join(timeout=1.0) +# +# def _sample_loop(self): +# while self.running: +# try: +# rss_mb = get_memory_rss_mb(self.process) +# self.samples_mb.append(rss_mb) +# except Exception: +# pass +# time.sleep(self.interval_sec) +# +# @property +# def peak_rss_mb(self): +# if not self.samples_mb: +# return None +# return max(self.samples_mb) +# +# +## ========================= +## READ INSTRUCTIONS +## ========================= +# +#with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: +# EDSS_INSTRUCTIONS = f.read().strip() +# +# +## ========================= +## PROMPT +## ========================= +# +#def build_prompt(patient_text): +# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren. +# +#### Deine Aufgabe: +#1. Analysiere den Patientenbericht und extrahiere: +# - Den Gesamt-EDSS-Score (0.0–10.0) +# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl) +#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein. +# +#### Struktur der JSON-Ausgabe (VERPFLICHTEND): +#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter. +# +#{{ +# "reason": "Kernaussage zur EDSS-Begründung (max. 
400 Zeichen, auf Deutsch).", +# "klassifizierbar": true/false, +# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true)", +# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl)", +# "subcategories": {{ +# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0 +# }} +#}} +# +#### Regeln: +#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest. +#- **klassifizierbar**: +# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind. +# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist. +#- **EDSS**: +# - **VERPFLICHTEND**, wenn `klassifizierbar=true`. +# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`. +#- **certainty_percent**: +# - **Immer present** — Ganzzahl (0–100), basierend auf: +# - Klarheit und Vollständigkeit der Berichtsangaben, +# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz), +# - Konsistenz zwischen den Unterkategorien. +#- **subcategories**: +# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein. +# - Jeder Wert ist entweder: +# - `null` (wenn keine ausreichende Information vorliegt), **oder** +# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0). 
+# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab. +# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`. +# +#### EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +#''' +# +# +## ========================= +## VALIDATION / NORMALIZATION +## ========================= +# +#def normalize_model_output(parsed): +# """ +# Keeps your existing validation behavior, with a few extra safety checks. +# """ +# +# if not isinstance(parsed, dict): +# raise ValueError("Parsed model output is not a JSON object") +# +# if "certainty_percent" not in parsed: +# print("⚠️ Missing 'certainty_percent' in output. Force-adding fallback.") +# parsed["certainty_percent"] = 0 +# elif not isinstance(parsed["certainty_percent"], (int, float)): +# parsed["certainty_percent"] = int(parsed["certainty_percent"]) +# +# parsed["certainty_percent"] = max(0, min(100, int(parsed["certainty_percent"]))) +# +# if "klassifizierbar" not in parsed: +# parsed["klassifizierbar"] = False +# +# if not parsed.get("klassifizierbar", False): +# parsed.pop("EDSS", None) +# else: +# if "EDSS" not in parsed: +# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.") +# parsed["EDSS"] = 7.0 +# +# required_subcategories = { +# "VISUAL_OPTIC_FUNCTIONS": 6.0, +# "BRAINSTEM_FUNCTIONS": 6.0, +# "PYRAMIDAL_FUNCTIONS": 6.0, +# "CEREBELLAR_FUNCTIONS": 6.0, +# "SENSORY_FUNCTIONS": 6.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": 6.0, +# "CEREBRAL_FUNCTIONS": 6.0, +# "AMBULATION": 10.0, +# } +# +# if "subcategories" not in parsed or not isinstance(parsed["subcategories"], dict): +# parsed["subcategories"] = {} +# +# for key in required_subcategories: +# if key not in parsed["subcategories"]: +# parsed["subcategories"][key] = None +# +# return parsed +# +# +## ========================= +## INFERENCE FUNCTION +## ========================= +# +#def run_inference(patient_text, 
model_name): +# prompt = build_prompt(patient_text) +# +# process = get_process() +# sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) +# +# wall_start = time.perf_counter() +# cpu_start = get_cpu_times_sec(process) +# rss_start_mb = get_memory_rss_mb(process) +# +# sampler.start() +# +# try: +# response = client.chat.completions.create( +# messages=[ +# { +# "role": "system", +# "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen." +# }, +# { +# "role": "user", +# "content": prompt +# } +# ], +# model=model_name, +# max_tokens=MAX_TOKENS, +# temperature=TEMPERATURE, +# response_format={"type": "json_object"} +# ) +# +# content = response.choices[0].message.content +# +# try: +# parsed = json.loads(content) +# except json.JSONDecodeError as e: +# print(f"⚠️ JSON parsing failed: {e}") +# print("Raw response:", content[:500]) +# raise ValueError("Model did not return valid JSON") +# +# parsed = normalize_model_output(parsed) +# +# usage = getattr(response, "usage", None) +# +# if usage is not None: +# prompt_tokens = getattr(usage, "prompt_tokens", None) +# completion_tokens = getattr(usage, "completion_tokens", None) +# total_tokens = getattr(usage, "total_tokens", None) +# else: +# prompt_tokens = None +# completion_tokens = None +# total_tokens = None +# +# success = True +# error = None +# result = parsed +# +# except Exception as e: +# print(f"❌ Inference error: {e}") +# +# success = False +# error = str(e) +# result = None +# prompt_tokens = None +# completion_tokens = None +# total_tokens = None +# +# finally: +# sampler.stop() +# +# wall_end = time.perf_counter() +# cpu_end = get_cpu_times_sec(process) +# rss_end_mb = get_memory_rss_mb(process) +# +# wall_time_sec = wall_end - wall_start +# +# if cpu_start is not None and cpu_end is not None: +# process_cpu_time_sec = cpu_end - cpu_start +# else: +# process_cpu_time_sec = None +# +# if rss_start_mb is not None and rss_end_mb is not None: +# rss_delta_mb = 
rss_end_mb - rss_start_mb +# else: +# rss_delta_mb = None +# +# return { +# "success": success, +# "error": error, +# "result": result, +# +# "model": model_name, +# +# # Main inference timing +# "inference_time_sec": wall_time_sec, +# +# # Resource metrics +# "process_cpu_time_sec": process_cpu_time_sec, +# "rss_before_mb": rss_start_mb, +# "rss_after_mb": rss_end_mb, +# "rss_delta_mb": rss_delta_mb, +# "peak_rss_mb": sampler.peak_rss_mb, +# +# # Token metrics, when available +# "prompt_tokens": prompt_tokens, +# "completion_tokens": completion_tokens, +# "total_tokens": total_tokens, +# } +# +# +## ========================= +## BUILD PATIENT TEXT +## ========================= +# +#def build_patient_text(row): +# return ( +# str(row.get("T_Zusammenfassung", "")) + "\n" + +# str(row.get("Diagnosen", "")) + "\n" + +# str(row.get("T_KlinBef", "")) + "\n" + +# str(row.get("T_Befunde", "")) +# ) +# +# +## ========================= +## FLATTEN RESULTS FOR CSV +## ========================= +# +#def flatten_result(record): +# """ +# Converts one result record to a flat row for CSV export. 
+# """ +# +# flat = { +# "model": record.get("model"), +# "iteration": record.get("iteration"), +# "row_index": record.get("row_index"), +# "unique_id": record.get("unique_id"), +# "MedDatum": record.get("MedDatum"), +# +# "success": record.get("success"), +# "error": record.get("error"), +# +# "inference_time_sec": record.get("inference_time_sec"), +# "process_cpu_time_sec": record.get("process_cpu_time_sec"), +# "rss_before_mb": record.get("rss_before_mb"), +# "rss_after_mb": record.get("rss_after_mb"), +# "rss_delta_mb": record.get("rss_delta_mb"), +# "peak_rss_mb": record.get("peak_rss_mb"), +# +# "prompt_tokens": record.get("prompt_tokens"), +# "completion_tokens": record.get("completion_tokens"), +# "total_tokens": record.get("total_tokens"), +# } +# +# result = record.get("result") +# +# if isinstance(result, dict): +# flat["reason"] = result.get("reason") +# flat["klassifizierbar"] = result.get("klassifizierbar") +# flat["EDSS"] = result.get("EDSS") +# flat["certainty_percent"] = result.get("certainty_percent") +# +# subcats = result.get("subcategories", {}) +# if isinstance(subcats, dict): +# for key, value in subcats.items(): +# flat[f"subcat_{key}"] = value +# +# return flat +# +# +#def summarize_records(records): +# """ +# Creates summary statistics for one model over all iterations. 
+# """ +# +# df = pd.DataFrame([flatten_result(r) for r in records]) +# +# if df.empty: +# return pd.DataFrame() +# +# summary = { +# "model": df["model"].iloc[0] if "model" in df.columns else None, +# "n_records": len(df), +# "n_success": int(df["success"].sum()) if "success" in df.columns else None, +# "n_failed": int((~df["success"]).sum()) if "success" in df.columns else None, +# "success_rate": float(df["success"].mean()) if "success" in df.columns else None, +# } +# +# numeric_cols = [ +# "inference_time_sec", +# "process_cpu_time_sec", +# "rss_delta_mb", +# "peak_rss_mb", +# "prompt_tokens", +# "completion_tokens", +# "total_tokens", +# "certainty_percent", +# "EDSS", +# ] +# +# for col in numeric_cols: +# if col in df.columns: +# values = pd.to_numeric(df[col], errors="coerce") +# summary[f"{col}_mean"] = values.mean() +# summary[f"{col}_median"] = values.median() +# summary[f"{col}_std"] = values.std() +# summary[f"{col}_min"] = values.min() +# summary[f"{col}_max"] = values.max() +# +# return pd.DataFrame([summary]) +# +# +## ========================= +## MAIN LOOP +## ========================= +# +#if __name__ == "__main__": +# +# run_timestamp = now_timestamp() +# +# results_root = Path(RESULTS_ROOT) +# results_root.mkdir(parents=True, exist_ok=True) +# +# run_root = results_root / f"run_{run_timestamp}" +# run_root.mkdir(parents=True, exist_ok=True) +# +# print(f"Results root: {run_root}") +# +# df = pd.read_csv(INPUT_CSV, sep=";") +# total_rows = len(df) +# +# print(f"Loaded {total_rows} patient records.") +# print(f"Models: {MODEL_NAMES}") +# print(f"Iterations per model: {NUM_ITERATIONS}") +# +# all_model_summaries = [] +# +# for model_name in MODEL_NAMES: +# safe_model = safe_dir_name(model_name) +# model_dir = run_root / safe_model +# model_dir.mkdir(parents=True, exist_ok=True) +# +# print(f"\n{'#' * 80}") +# print(f"MODEL: {model_name}") +# print(f"Saving to: {model_dir}") +# print(f"{'#' * 80}") +# +# model_records = [] +# model_start = 
time.perf_counter() +# +# for iteration in range(1, NUM_ITERATIONS + 1): +# print(f"\n{'=' * 60}") +# print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") +# print(f"{'=' * 60}") +# +# iteration_results = [] +# iteration_start = time.perf_counter() +# +# for idx, row in df.iterrows(): +# print( +# f"\rModel={model_name} | Row {idx + 1}/{total_rows} | Iter {iteration}", +# end="", +# flush=True +# ) +# +# try: +# patient_text = build_patient_text(row) +# record = run_inference(patient_text, model_name=model_name) +# +# record["iteration"] = iteration +# record["row_index"] = int(idx) +# record["unique_id"] = row.get("unique_id", f"row_{idx}") +# record["MedDatum"] = row.get("MedDatum", None) +# +# iteration_results.append(record) +# model_records.append(record) +# +# if record["success"]: +# res = record["result"] +# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" +# print( +# f" ✅ EDSS={edss}, " +# f"cert={res.get('certainty_percent', '?')}%, " +# f"time={record['inference_time_sec']:.2f}s" +# ) +# else: +# print(f" ❌ {record.get('error', 'Unknown error')}") +# +# except Exception as e: +# print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") +# +# fallback_record = { +# "success": False, +# "error": str(e), +# "result": None, +# +# "model": model_name, +# "iteration": iteration, +# "row_index": int(idx), +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None), +# +# "inference_time_sec": None, +# "process_cpu_time_sec": None, +# "rss_before_mb": None, +# "rss_after_mb": None, +# "rss_delta_mb": None, +# "peak_rss_mb": None, +# +# "prompt_tokens": None, +# "completion_tokens": None, +# "total_tokens": None, +# } +# +# iteration_results.append(fallback_record) +# model_records.append(fallback_record) +# +# if STOP_ON_FIRST_ERROR: +# break +# +# iteration_elapsed = time.perf_counter() - iteration_start +# +# # Save per-iteration JSON +# iter_json_path = model_dir / 
f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" +# with open(iter_json_path, "w", encoding="utf-8") as f: +# json.dump(iteration_results, f, indent=2, ensure_ascii=False) +# +# # Save per-iteration CSV +# iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" +# iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) +# iter_flat_df.to_csv(iter_csv_path, index=False) +# +# print(f"\n✅ Iteration {iteration} complete.") +# print(f"JSON saved to: {iter_json_path}") +# print(f"CSV saved to: {iter_csv_path}") +# print( +# f"⏱️ Iteration time: {iteration_elapsed:.1f}s " +# f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" +# ) +# +# model_elapsed = time.perf_counter() - model_start +# +# # Save all records for this model +# model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" +# with open(model_json_path, "w", encoding="utf-8") as f: +# json.dump(model_records, f, indent=2, ensure_ascii=False) +# +# model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" +# model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) +# model_flat_df.to_csv(model_csv_path, index=False) +# +# # Save model summary +# model_summary_df = summarize_records(model_records) +# model_summary_df["model_total_wall_time_sec"] = model_elapsed +# model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 +# +# model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" +# model_summary_df.to_csv(model_summary_path, index=False) +# +# all_model_summaries.append(model_summary_df) +# +# print(f"\n🎉 Model completed: {model_name}") +# print(f"All JSON: {model_json_path}") +# print(f"All CSV: {model_csv_path}") +# print(f"Summary: {model_summary_path}") +# print(f"Total model time: {model_elapsed / 60:.2f} min") +# +# # Save combined model summaries +# if all_model_summaries: +# combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) +# 
combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" +# combined_summary_df.to_csv(combined_summary_path, index=False) +# +# print(f"\n📊 Combined summary saved to: {combined_summary_path}") +# +# print(f"\n🎉 All models and all iterations completed!") +# +## + + + + + + +# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark import time import json import os +import re +import threading from datetime import datetime +from pathlib import Path + import pandas as pd from openai import OpenAI from dotenv import load_dotenv -# Load environment variables +try: + import psutil +except ImportError: + psutil = None + print("⚠️ psutil is not installed. Resource metrics will be limited.") + print("Install with: pip install psutil") + + +# ========================= +# CONFIGURATION +# ========================= + load_dotenv() -# === CONFIGURATION === OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") -MODEL_NAME = "GPT-OSS-120B" -# File paths +MODEL_CONFIGS = [ +# { +# "model_name": "qwen3.6-35b-a3b", +# "use_response_format": False, +# "temperature": 0.0, +# "max_tokens": 4096, +# +# # If your backend is vLLM / Qwen chat-template compatible, +# # this may reduce long hidden reasoning and JSON truncation. +# # If your server errors because of extra_body, set this to None. 
+# "extra_body": { +# "chat_template_kwargs": { +# "enable_thinking": False +# } +# }, +# }, + { + "model_name": "gemma-4-31B-it", + "use_response_format": False, + "temperature": 0.0, + "max_tokens": 4096, + "extra_body": None, + }, + # { + # "model_name": "GPT-OSS-120B", + # "use_response_format": True, + # "temperature": 0.0, + # "max_tokens": 4096, + # "extra_body": None, + # }, +] + INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" -# Iteration settings -NUM_ITERATIONS = 20 -STOP_ON_FIRST_ERROR = False # Set to True for debugging +RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark" + +NUM_ITERATIONS = 10 +STOP_ON_FIRST_ERROR = False + +# For testing, set to e.g. 2. +# For full run, set to None. +MAX_ROWS = 2 +# MAX_ROWS = 2 + +MAX_TOKENS = 4096 +TEMPERATURE = 0.0 + +RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 + +SAVE_EVERY_N_ROWS = 1 + +# Retries for invalid JSON / truncated JSON +MAX_JSON_RETRIES = 2 +RETRY_SLEEP_SEC = 2 + + +# ========================= +# CLIENT +# ========================= -# Initialize OpenAI client client = OpenAI( api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL ) -# Read EDSS instructions from file -with open(EDSS_INSTRUCTIONS_PATH, 'r') as f: + +# ========================= +# HELPERS +# ========================= + +def safe_dir_name(name: str) -> str: + name = str(name).strip() + name = re.sub(r"[^\w\-.]+", "_", name) + return name[:150] + + +def now_timestamp() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def get_process(): + if psutil is None: + return None + return psutil.Process(os.getpid()) + + +def get_memory_rss_mb(process=None): + if psutil is None: + return None + if process is None: + process = get_process() + return process.memory_info().rss / (1024 * 1024) + + +def get_cpu_times_sec(process=None): + if psutil is None: + return None 
+ if process is None: + process = get_process() + cpu_times = process.cpu_times() + return cpu_times.user + cpu_times.system + + +class ResourceSampler: + def __init__(self, interval_sec=0.05): + self.interval_sec = interval_sec + self.process = get_process() + self.running = False + self.thread = None + self.samples_mb = [] + + def start(self): + if psutil is None: + return + + self.running = True + self.samples_mb = [] + self.thread = threading.Thread(target=self._sample_loop, daemon=True) + self.thread.start() + + def stop(self): + if psutil is None: + return + + self.running = False + if self.thread is not None: + self.thread.join(timeout=1.0) + + def _sample_loop(self): + while self.running: + try: + rss_mb = get_memory_rss_mb(self.process) + self.samples_mb.append(rss_mb) + except Exception: + pass + time.sleep(self.interval_sec) + + @property + def peak_rss_mb(self): + if not self.samples_mb: + return None + return max(self.samples_mb) + + +# ========================= +# JSON EXTRACTION +# ========================= + +def extract_json_from_text(text): + if text is None: + raise ValueError("Model returned empty content: message.content is None") + + text = str(text).strip() + + if not text: + raise ValueError("Model returned empty content") + + text = ( + text.replace("```json", "") + .replace("```JSON", "") + .replace("```Json", "") + .replace("```", "") + .strip() + ) + + # Direct parse + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + + # Balanced JSON candidates + candidates = [] + stack = [] + start_idx = None + in_string = False + escape = False + + for i, ch in enumerate(text): + if escape: + escape = False + continue + + if ch == "\\": + escape = True + continue + + if ch == '"': + in_string = not in_string + continue + + if in_string: + continue + + if ch == "{": + if not stack: + start_idx = i + stack.append(ch) + + elif ch == "}": + if stack: + stack.pop() + if not stack and 
start_idx is not None: + candidates.append(text[start_idx:i + 1]) + start_idx = None + + valid_objects = [] + + for candidate in candidates: + candidate = candidate.strip() + lowered = candidate.lower() + + invalid_markers = [ + "true/false", + "null or", + "oder zahl", + "0.0-6.0", + "0.0-10.0", + "zahl zwischen", + "...", + ] + + if any(marker in lowered for marker in invalid_markers): + continue + + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + valid_objects.append(parsed) + except json.JSONDecodeError: + continue + + for obj in reversed(valid_objects): + if ( + "klassifizierbar" in obj + and "certainty_percent" in obj + and "subcategories" in obj + ): + return obj + + if valid_objects: + return valid_objects[-1] + + # Check if it looks like truncated JSON + stripped = text.strip() + if stripped.startswith("{") and not stripped.endswith("}"): + raise ValueError( + "Model output looks like truncated JSON. " + f"Raw output starts with: {text[:1000]}" + ) + + raise ValueError( + "No valid JSON object found in model output. 
" + f"Raw output starts with: {text[:1000]}" + ) + + +def extract_message_content(message): + raw_content = getattr(message, "content", None) + + if raw_content is not None: + return raw_content + + msg_dict = None + + try: + msg_dict = message.model_dump() + except Exception: + try: + msg_dict = dict(message) + except Exception: + msg_dict = None + + if not isinstance(msg_dict, dict): + return None + + for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: + value = msg_dict.get(key) + if value: + return value + + possible_content = msg_dict.get("content") + if isinstance(possible_content, list): + parts = [] + for block in possible_content: + if isinstance(block, dict): + if "text" in block: + parts.append(str(block["text"])) + elif "content" in block: + parts.append(str(block["content"])) + if parts: + return "\n".join(parts).strip() + + return None + + +# ========================= +# READ INSTRUCTIONS +# ========================= + +with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: EDSS_INSTRUCTIONS = f.read().strip() -# === PROMPT (unchanged from before) === + +# ========================= +# PROMPT +# ========================= + def build_prompt(patient_text): - return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren. + return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. -### Deine Aufgabe: -1. Analysiere den Patientenbericht und extrahiere: - - Den Gesamt-EDSS-Score (0.0–10.0) - - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl) -2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein. +Extrahiere: +1. Gesamt-EDSS-Score von 0.0 bis 10.0 +2. Alle 8 EDSS-Unterkategorien +3. 
Sicherheit als Ganzzahl von 0 bis 100 -### Struktur der JSON-Ausgabe (VERPFLICHTEND): -Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter. +Antworte ausschließlich mit EINEM validen JSON-Objekt. +Kein Markdown. +Keine Code-Fences. +Kein Text vor oder nach JSON. +Keine Platzhalter. +Kopiere kein Schema. +Das JSON muss exakt diese Schlüssel enthalten: +- reason +- klassifizierbar +- EDSS +- certainty_percent +- subcategories + +Die subcategories müssen exakt diese 8 Schlüssel enthalten: +- VISUAL_OPTIC_FUNCTIONS +- BRAINSTEM_FUNCTIONS +- PYRAMIDAL_FUNCTIONS +- CEREBELLAR_FUNCTIONS +- SENSORY_FUNCTIONS +- BOWEL_AND_BLADDER_FUNCTIONS +- CEREBRAL_FUNCTIONS +- AMBULATION + +Werte: +- klassifizierbar: true oder false +- EDSS: Zahl von 0.0 bis 10.0 oder null +- certainty_percent: Ganzzahl von 0 bis 100 +- Unterkategorien: Zahl oder null +- VISUAL_OPTIC_FUNCTIONS maximal 6.0 +- BRAINSTEM_FUNCTIONS maximal 6.0 +- PYRAMIDAL_FUNCTIONS maximal 6.0 +- CEREBELLAR_FUNCTIONS maximal 6.0 +- SENSORY_FUNCTIONS maximal 6.0 +- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 +- CEREBRAL_FUNCTIONS maximal 6.0 +- AMBULATION maximal 10.0 +- reason: maximal 250 Zeichen, Deutsch + +Wenn klassifizierbar false ist, setze EDSS auf null. + +Valide Beispielausgabe: {{ - "reason": "Kernaussage zur EDSS-Begründung (max. 
400 Zeichen, auf Deutsch).", - "klassifizierbar": true/false, - "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true)", - "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl)", + "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", + "klassifizierbar": true, + "EDSS": 2.0, + "certainty_percent": 90, "subcategories": {{ - "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, - "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0 + "VISUAL_OPTIC_FUNCTIONS": null, + "BRAINSTEM_FUNCTIONS": null, + "PYRAMIDAL_FUNCTIONS": 1.0, + "CEREBELLAR_FUNCTIONS": 1.0, + "SENSORY_FUNCTIONS": 1.0, + "BOWEL_AND_BLADDER_FUNCTIONS": null, + "CEREBRAL_FUNCTIONS": null, + "AMBULATION": 0.0 }} }} -### Regeln: -- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest. -- **klassifizierbar**: - - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind. - - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist. -- **EDSS**: - - **VERPFLICHTEND**, wenn `klassifizierbar=true`. - - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`. -- **certainty_percent**: - - **Immer present** — Ganzzahl (0–100), basierend auf: - - Klarheit und Vollständigkeit der Berichtsangaben, - - Stichhaltigkeit der Schlussfolgerung (inkl. 
Inferenz), - - Konsistenz zwischen den Unterkategorien. -- **subcategories**: - - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein. - - Jeder Wert ist entweder: - - `null` (wenn keine ausreichende Information vorliegt), **oder** - - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0). - - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab. - - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`. - -### EDSS-Bewertungsrichtlinien: +EDSS-Bewertungsrichtlinien: {EDSS_INSTRUCTIONS} Patientenbericht: {patient_text} + +Gib ausschließlich das finale JSON-Objekt zurück. ''' -# === INFERENCE FUNCTION (unchanged) === -def run_inference(patient_text): - prompt = build_prompt(patient_text) - start_time = time.time() +# ========================= +# VALIDATION / NORMALIZATION +# ========================= - try: - response = client.chat.completions.create( - messages=[ - {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."} - ] + [ - {"role": "user", "content": prompt} - ], - model=MODEL_NAME, - max_tokens=2048, - temperature=0.1, - response_format={"type": "json_object"} - ) +def normalize_model_output(parsed): + if not isinstance(parsed, dict): + raise ValueError("Parsed model output is not a JSON object") - content = response.choices[0].message.content - - # Parse and validate JSON + if "certainty_percent" not in parsed: + parsed["certainty_percent"] = 0 + elif not isinstance(parsed["certainty_percent"], (int, float)): try: - parsed = json.loads(content) - except json.JSONDecodeError as e: - print(f"⚠️ JSON parsing failed: {e}") - print("Raw response:", content[:500]) - raise ValueError("Model did not return valid JSON") - - # Enforce required keys - if "certainty_percent" not in parsed: - print("⚠️ Missing 'certainty_percent' in output! 
Force-adding fallback.") - parsed["certainty_percent"] = 0 - elif not isinstance(parsed["certainty_percent"], (int, float)): parsed["certainty_percent"] = int(parsed["certainty_percent"]) + except Exception: + parsed["certainty_percent"] = 0 - # Clamp certainty to [0, 100] - pct = parsed["certainty_percent"] - parsed["certainty_percent"] = max(0, min(100, int(pct))) + parsed["certainty_percent"] = max(0, min(100, int(parsed["certainty_percent"]))) - # Enforce EDSS rules - if not parsed.get("klassifizierbar", False): - if "EDSS" in parsed: - del parsed["EDSS"] + if "klassifizierbar" not in parsed: + parsed["klassifizierbar"] = False + + if not isinstance(parsed["klassifizierbar"], bool): + if str(parsed["klassifizierbar"]).lower() in ["true", "1", "yes", "ja"]: + parsed["klassifizierbar"] = True else: - if "EDSS" not in parsed: - print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.") + parsed["klassifizierbar"] = False + + if not parsed.get("klassifizierbar", False): + parsed["EDSS"] = None + else: + if "EDSS" not in parsed or parsed["EDSS"] is None: + parsed["EDSS"] = 7.0 + else: + try: + parsed["EDSS"] = float(parsed["EDSS"]) + parsed["EDSS"] = max(0.0, min(10.0, parsed["EDSS"])) + except Exception: parsed["EDSS"] = 7.0 - inference_time = time.time() - start_time + required_subcategories = { + "VISUAL_OPTIC_FUNCTIONS": 6.0, + "BRAINSTEM_FUNCTIONS": 6.0, + "PYRAMIDAL_FUNCTIONS": 6.0, + "CEREBELLAR_FUNCTIONS": 6.0, + "SENSORY_FUNCTIONS": 6.0, + "BOWEL_AND_BLADDER_FUNCTIONS": 6.0, + "CEREBRAL_FUNCTIONS": 6.0, + "AMBULATION": 10.0, + } - return { - "success": True, - "result": parsed, - "inference_time_sec": inference_time - } + if "subcategories" not in parsed or not isinstance(parsed["subcategories"], dict): + parsed["subcategories"] = {} + + for key, max_value in required_subcategories.items(): + value = parsed["subcategories"].get(key, None) + + if value is None: + parsed["subcategories"][key] = None + else: + try: + value = float(value) + value 
= max(0.0, min(max_value, value)) + parsed["subcategories"][key] = value + except Exception: + parsed["subcategories"][key] = None + + parsed["subcategories"] = { + key: parsed["subcategories"].get(key, None) + for key in required_subcategories.keys() + } + + if "reason" not in parsed or parsed["reason"] is None: + parsed["reason"] = "" + + parsed["reason"] = str(parsed["reason"])[:250] + + return parsed + + +# ========================= +# API CALL +# ========================= + +def make_chat_completion(model_config, prompt): + model_name = model_config["model_name"] + + kwargs = dict( + messages=[ + { + "role": "system", + "content": ( + "Du bist ein JSON-Generator. " + "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " + "Keine Erklärung. Kein Markdown. Keine Code-Fences. " + "Keine Platzhalter. Kein Schema kopieren. " + "Das JSON muss vollständig geschlossen sein." + ) + }, + { + "role": "user", + "content": prompt + } + ], + model=model_name, + max_tokens=model_config.get("max_tokens", MAX_TOKENS), + temperature=model_config.get("temperature", TEMPERATURE), + ) + + if model_config.get("use_response_format", False): + kwargs["response_format"] = {"type": "json_object"} + + extra_body = model_config.get("extra_body") + if extra_body is not None: + kwargs["extra_body"] = extra_body + + return client.chat.completions.create(**kwargs) + + +# ========================= +# INFERENCE FUNCTION WITH RETRIES +# ========================= + +def run_inference(patient_text, model_config): + model_name = model_config["model_name"] + prompt = build_prompt(patient_text) + + process = get_process() + sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) + + wall_start = time.perf_counter() + cpu_start = get_cpu_times_sec(process) + rss_start_mb = get_memory_rss_mb(process) + + sampler.start() + + raw_content = None + raw_response_debug = None + last_error = None + prompt_tokens = None + completion_tokens = None + total_tokens = None + + try: + for 
attempt in range(1, MAX_JSON_RETRIES + 2): + try: + response = make_chat_completion( + model_config=model_config, + prompt=prompt + ) + + message = response.choices[0].message + raw_content = extract_message_content(message) + + try: + raw_response_debug = response.model_dump() + except Exception: + raw_response_debug = str(response) + + usage = getattr(response, "usage", None) + if usage is not None: + prompt_tokens = getattr(usage, "prompt_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) + total_tokens = getattr(usage, "total_tokens", None) + + parsed = extract_json_from_text(raw_content) + parsed = normalize_model_output(parsed) + + success = True + error = None + result = parsed + break + + except Exception as e: + last_error = str(e) + + if attempt <= MAX_JSON_RETRIES: + print( + f"\n⚠️ JSON failed on attempt {attempt}. " + f"Retrying row. Error: {last_error[:300]}" + ) + time.sleep(RETRY_SLEEP_SEC) + continue + + raise except Exception as e: print(f"❌ Inference error: {e}") - return { - "success": False, - "error": str(e), - "inference_time_sec": -1, - "result": None - } -# === BUILD PATIENT TEXT === + success = False + error = str(e) + result = None + + finally: + sampler.stop() + + wall_end = time.perf_counter() + cpu_end = get_cpu_times_sec(process) + rss_end_mb = get_memory_rss_mb(process) + + wall_time_sec = wall_end - wall_start + + if cpu_start is not None and cpu_end is not None: + process_cpu_time_sec = cpu_end - cpu_start + else: + process_cpu_time_sec = None + + if rss_start_mb is not None and rss_end_mb is not None: + rss_delta_mb = rss_end_mb - rss_start_mb + else: + rss_delta_mb = None + + return { + "success": success, + "error": error, + "result": result, + + "model": model_name, + + "inference_time_sec": wall_time_sec, + + "process_cpu_time_sec": process_cpu_time_sec, + "rss_before_mb": rss_start_mb, + "rss_after_mb": rss_end_mb, + "rss_delta_mb": rss_delta_mb, + "peak_rss_mb": sampler.peak_rss_mb, + + 
"prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + + "raw_content": raw_content if not success else None, + "raw_response_debug": raw_response_debug if not success else None, + "last_error": last_error, + } + + +# ========================= +# BUILD PATIENT TEXT +# ========================= + def build_patient_text(row): return ( str(row.get("T_Zusammenfassung", "")) + "\n" + @@ -527,74 +1924,1678 @@ def build_patient_text(row): str(row.get("T_Befunde", "")) ) -# === MAIN LOOP (NEW: MULTI-ITERATION) === + +# ========================= +# FLATTEN RESULTS FOR CSV +# ========================= + +def flatten_result(record): + flat = { + "model": record.get("model"), + "iteration": record.get("iteration"), + "row_index": record.get("row_index"), + "row_number_in_run": record.get("row_number_in_run"), + "unique_id": record.get("unique_id"), + "MedDatum": record.get("MedDatum"), + + "success": record.get("success"), + "error": record.get("error"), + "last_error": record.get("last_error"), + + "inference_time_sec": record.get("inference_time_sec"), + "process_cpu_time_sec": record.get("process_cpu_time_sec"), + "rss_before_mb": record.get("rss_before_mb"), + "rss_after_mb": record.get("rss_after_mb"), + "rss_delta_mb": record.get("rss_delta_mb"), + "peak_rss_mb": record.get("peak_rss_mb"), + + "prompt_tokens": record.get("prompt_tokens"), + "completion_tokens": record.get("completion_tokens"), + "total_tokens": record.get("total_tokens"), + + "raw_content": record.get("raw_content"), + } + + result = record.get("result") + + if isinstance(result, dict): + flat["reason"] = result.get("reason") + flat["klassifizierbar"] = result.get("klassifizierbar") + flat["EDSS"] = result.get("EDSS") + flat["certainty_percent"] = result.get("certainty_percent") + + subcats = result.get("subcategories", {}) + if isinstance(subcats, dict): + for key, value in subcats.items(): + flat[f"subcat_{key}"] = value + + return flat + + +def 
summarize_records(records): + df = pd.DataFrame([flatten_result(r) for r in records]) + + if df.empty: + return pd.DataFrame() + + success_series = df["success"].fillna(False).astype(bool) + + summary = { + "model": df["model"].iloc[0] if "model" in df.columns else None, + "n_records": len(df), + "n_success": int(success_series.sum()), + "n_failed": int((~success_series).sum()), + "success_rate": float(success_series.mean()), + } + + numeric_cols = [ + "inference_time_sec", + "process_cpu_time_sec", + "rss_delta_mb", + "peak_rss_mb", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "certainty_percent", + "EDSS", + ] + + for col in numeric_cols: + if col in df.columns: + values = pd.to_numeric(df[col], errors="coerce") + summary[f"{col}_mean"] = values.mean() + summary[f"{col}_median"] = values.median() + summary[f"{col}_std"] = values.std() + summary[f"{col}_min"] = values.min() + summary[f"{col}_max"] = values.max() + + return pd.DataFrame([summary]) + + +# ========================= +# INCREMENTAL SAVE HELPERS +# ========================= + +def append_jsonl(path, record): + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + f.flush() + os.fsync(f.fileno()) + + +def append_csv(path, record): + flat = flatten_result(record) + df_one = pd.DataFrame([flat]) + file_exists = Path(path).exists() + df_one.to_csv(path, mode="a", header=not file_exists, index=False) + + +# ========================= +# MAIN LOOP +# ========================= + if __name__ == "__main__": - # Load data ONCE (to avoid repeated I/O overhead) - df = pd.read_csv(INPUT_CSV, sep=';') + + run_timestamp = now_timestamp() + + results_root = Path(RESULTS_ROOT) + results_root.mkdir(parents=True, exist_ok=True) + + run_root = results_root / f"run_{run_timestamp}" + run_root.mkdir(parents=True, exist_ok=True) + + print(f"Results root: {run_root}") + + df = pd.read_csv(INPUT_CSV, sep=";") + + if MAX_ROWS is not None: + df = df.head(MAX_ROWS) 
+ total_rows = len(df) + + model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] + print(f"Loaded {total_rows} patient records.") + print(f"Models: {model_names_for_print}") + print(f"Iterations per model: {NUM_ITERATIONS}") - for iteration in range(1, NUM_ITERATIONS + 1): - print(f"\n{'='*60}") - print(f"🔄 ITERATION {iteration}/{NUM_ITERATIONS}") - print(f"{'='*60}") + all_model_summaries = [] - iteration_results = [] - start_iter = time.time() + for model_config in MODEL_CONFIGS: + model_name = model_config["model_name"] + safe_model = safe_dir_name(model_name) - for idx, row in df.iterrows(): - print(f"\rRow {idx+1}/{total_rows} | Iter {iteration}", end='', flush=True) + model_dir = run_root / safe_model + model_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'#' * 80}") + print(f"MODEL: {model_name}") + print(f"use_response_format: {model_config.get('use_response_format', False)}") + print(f"temperature: {model_config.get('temperature', TEMPERATURE)}") + print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") + print(f"Saving to: {model_dir}") + print(f"{'#' * 80}") + + model_records = [] + model_start = time.perf_counter() + + for iteration in range(1, NUM_ITERATIONS + 1): + print(f"\n{'=' * 60}") + print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") + print(f"{'=' * 60}") + + iteration_results = [] + iteration_start = time.perf_counter() + + incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" + incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" + + print(f"Incremental JSONL: {incremental_jsonl_path}") + print(f"Incremental CSV: {incremental_csv_path}") + + for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): + print( + f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", + end="", + flush=True + ) + + try: + patient_text = build_patient_text(row) + + record = run_inference( + 
patient_text=patient_text, + model_config=model_config + ) + + record["iteration"] = iteration + record["row_index"] = int(idx) + record["row_number_in_run"] = int(loop_i) + record["unique_id"] = row.get("unique_id", f"row_{idx}") + record["MedDatum"] = row.get("MedDatum", None) + + iteration_results.append(record) + model_records.append(record) + + if loop_i % SAVE_EVERY_N_ROWS == 0: + append_jsonl(incremental_jsonl_path, record) + append_csv(incremental_csv_path, record) + + if record["success"]: + res = record["result"] + edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" + print( + f" ✅ EDSS={edss}, " + f"cert={res.get('certainty_percent', '?')}%, " + f"time={record['inference_time_sec']:.2f}s" + ) + else: + print(f" ❌ {record.get('error', 'Unknown error')}") + + except Exception as e: + print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") + + fallback_record = { + "success": False, + "error": str(e), + "last_error": str(e), + "result": None, + + "model": model_name, + "iteration": iteration, + "row_index": int(idx), + "row_number_in_run": int(loop_i), + "unique_id": row.get("unique_id", f"row_{idx}"), + "MedDatum": row.get("MedDatum", None), + + "inference_time_sec": None, + "process_cpu_time_sec": None, + "rss_before_mb": None, + "rss_after_mb": None, + "rss_delta_mb": None, + "peak_rss_mb": None, + + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + + "raw_content": None, + "raw_response_debug": None, + } + + iteration_results.append(fallback_record) + model_records.append(fallback_record) + + append_jsonl(incremental_jsonl_path, fallback_record) + append_csv(incremental_csv_path, fallback_record) + + if STOP_ON_FIRST_ERROR: + break + + iteration_elapsed = time.perf_counter() - iteration_start + + iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" + with open(iter_json_path, "w", encoding="utf-8") as f: + json.dump(iteration_results, f, indent=2, 
ensure_ascii=False) + + iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" + iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) + iter_flat_df.to_csv(iter_csv_path, index=False) + + print(f"\n✅ Iteration {iteration} complete.") + print(f"Incremental JSONL saved to: {incremental_jsonl_path}") + print(f"Incremental CSV saved to: {incremental_csv_path}") + print(f"Final JSON saved to: {iter_json_path}") + print(f"Final CSV saved to: {iter_csv_path}") + print( + f"⏱️ Iteration time: {iteration_elapsed:.1f}s " + f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" + ) + + model_elapsed = time.perf_counter() - model_start + + model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" + with open(model_json_path, "w", encoding="utf-8") as f: + json.dump(model_records, f, indent=2, ensure_ascii=False) + + model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" + model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) + model_flat_df.to_csv(model_csv_path, index=False) + + model_summary_df = summarize_records(model_records) + model_summary_df["model_total_wall_time_sec"] = model_elapsed + model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 + + model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" + model_summary_df.to_csv(model_summary_path, index=False) + + all_model_summaries.append(model_summary_df) + + print(f"\n🎉 Model completed: {model_name}") + print(f"All JSON: {model_json_path}") + print(f"All CSV: {model_csv_path}") + print(f"Summary: {model_summary_path}") + print(f"Total model time: {model_elapsed / 60:.2f} min") + + if all_model_summaries: + combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) + combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" + combined_summary_df.to_csv(combined_summary_path, index=False) + + print(f"\n📊 Combined summary saved to: 
{combined_summary_path}") + + print(f"\n🎉 All models and all iterations completed!") +## + +# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark2 + +import time +import json +import os +import re +import threading +from datetime import datetime +from pathlib import Path + +import pandas as pd +from openai import OpenAI +from dotenv import load_dotenv + +try: + import psutil +except ImportError: + psutil = None + print("⚠️ psutil is not installed. Resource metrics will be limited.") + print("Install with: pip install psutil") + + +# ========================= +# CONFIGURATION +# ========================= + +load_dotenv() + +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") + +MODEL_CONFIGS = [ + { + "model_name": "qwen3.6-27b", + "use_response_format": False, + "temperature": 0.0, + "max_tokens": 4096, + "extra_body": { + "chat_template_kwargs": { + "enable_thinking": False + } + }, + }, + { + "model_name": "gemma-4-31B-it", + "use_response_format": False, + "temperature": 0.0, + "max_tokens": 4096, + "extra_body": None, + }, + # { + # "model_name": "GPT-OSS-120B", + # "use_response_format": True, + # "temperature": 0.0, + # "max_tokens": 4096, + # "extra_body": None, + # }, +] + +INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" + +RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark" + +NUM_ITERATIONS = 10 +STOP_ON_FIRST_ERROR = False + +# For testing, set to e.g. 2. +# For full run, set to None. 
+MAX_ROWS = None +# MAX_ROWS = 2 + +MAX_TOKENS = 4096 +TEMPERATURE = 0.0 + +RESOURCE_SAMPLE_INTERVAL_SEC = 0.05 + +SAVE_EVERY_N_ROWS = 1 + +# Retries for invalid JSON / truncated JSON +MAX_JSON_RETRIES = 2 +RETRY_SLEEP_SEC = 2 + + +# ========================= +# VALID CLINICAL RANGES +# ========================= + +EDSS_MIN = 0.0 +EDSS_MAX = 10.0 + +FUNCTIONAL_SYSTEM_RANGES = { + "VISUAL_OPTIC_FUNCTIONS": (0.0, 6.0), + "BRAINSTEM_FUNCTIONS": (0.0, 6.0), + "PYRAMIDAL_FUNCTIONS": (0.0, 6.0), + "CEREBELLAR_FUNCTIONS": (0.0, 6.0), + "SENSORY_FUNCTIONS": (0.0, 6.0), + "BOWEL_AND_BLADDER_FUNCTIONS": (0.0, 6.0), + "CEREBRAL_FUNCTIONS": (0.0, 6.0), + "AMBULATION": (0.0, 10.0), +} + +REQUIRED_TOP_LEVEL_FIELDS = [ + "reason", + "klassifizierbar", + "EDSS", + "certainty_percent", + "subcategories", +] + + +# ========================= +# CLIENT +# ========================= + +client = OpenAI( + api_key=OPENAI_API_KEY, + base_url=OPENAI_BASE_URL +) + + +# ========================= +# HELPERS +# ========================= + +def safe_dir_name(name: str) -> str: + name = str(name).strip() + name = re.sub(r"[^\w\-.]+", "_", name) + return name[:150] + + +def now_timestamp() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def get_process(): + if psutil is None: + return None + return psutil.Process(os.getpid()) + + +def get_memory_rss_mb(process=None): + if psutil is None: + return None + if process is None: + process = get_process() + return process.memory_info().rss / (1024 * 1024) + + +def get_cpu_times_sec(process=None): + if psutil is None: + return None + if process is None: + process = get_process() + cpu_times = process.cpu_times() + return cpu_times.user + cpu_times.system + + +class ResourceSampler: + def __init__(self, interval_sec=0.05): + self.interval_sec = interval_sec + self.process = get_process() + self.running = False + self.thread = None + self.samples_mb = [] + + def start(self): + if psutil is None: + return + + self.running = True + 
self.samples_mb = [] + self.thread = threading.Thread(target=self._sample_loop, daemon=True) + self.thread.start() + + def stop(self): + if psutil is None: + return + + self.running = False + if self.thread is not None: + self.thread.join(timeout=1.0) + + def _sample_loop(self): + while self.running: try: - patient_text = build_patient_text(row) - result = run_inference(patient_text) + rss_mb = get_memory_rss_mb(self.process) + self.samples_mb.append(rss_mb) + except Exception: + pass + time.sleep(self.interval_sec) - # Attach metadata - if result["success"]: - res = result["result"].copy() # avoid mutation - res["iteration"] = iteration - res["unique_id"] = row.get("unique_id", f"row_{idx}") - res["MedDatum"] = row.get("MedDatum", None) - result["result"] = res + @property + def peak_rss_mb(self): + if not self.samples_mb: + return None + return max(self.samples_mb) - else: - result["iteration"] = iteration - result["unique_id"] = row.get("unique_id", f"row_{idx}") - result["MedDatum"] = row.get("MedDatum", None) - iteration_results.append(result) +# ========================= +# JSON EXTRACTION +# ========================= - if result["success"]: - res = result["result"] - edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" - print(f" ✅ EDSS={edss}, cert={res.get('certainty_percent', '?')}%") - else: - print(f" ❌ {result.get('error', 'Unknown')}") +def extract_json_from_text(text): + if text is None: + raise ValueError("Model returned empty content: message.content is None") + + text = str(text).strip() + + if not text: + raise ValueError("Model returned empty content") + + text = ( + text.replace("```json", "") + .replace("```JSON", "") + .replace("```Json", "") + .replace("```", "") + .strip() + ) + + # Direct parse + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + + # Balanced JSON candidates + candidates = [] + stack = [] + start_idx = None + in_string = False + escape = 
False + + for i, ch in enumerate(text): + if escape: + escape = False + continue + + if ch == "\\": + escape = True + continue + + if ch == '"': + in_string = not in_string + continue + + if in_string: + continue + + if ch == "{": + if not stack: + start_idx = i + stack.append(ch) + + elif ch == "}": + if stack: + stack.pop() + if not stack and start_idx is not None: + candidates.append(text[start_idx:i + 1]) + start_idx = None + + valid_objects = [] + + for candidate in candidates: + candidate = candidate.strip() + lowered = candidate.lower() + + invalid_markers = [ + "true/false", + "null or", + "oder zahl", + "0.0-6.0", + "0.0-10.0", + "zahl zwischen", + "...", + ] + + if any(marker in lowered for marker in invalid_markers): + continue + + try: + parsed = json.loads(candidate) + if isinstance(parsed, dict): + valid_objects.append(parsed) + except json.JSONDecodeError: + continue + + for obj in reversed(valid_objects): + if ( + "klassifizierbar" in obj + and "certainty_percent" in obj + and "subcategories" in obj + ): + return obj + + if valid_objects: + return valid_objects[-1] + + stripped = text.strip() + if stripped.startswith("{") and not stripped.endswith("}"): + raise ValueError( + "Model output looks like truncated JSON. " + f"Raw output starts with: {text[:1000]}" + ) + + raise ValueError( + "No valid JSON object found in model output. 
" + f"Raw output starts with: {text[:1000]}" + ) + + +def extract_message_content(message): + raw_content = getattr(message, "content", None) + + if raw_content is not None: + return raw_content + + msg_dict = None + + try: + msg_dict = message.model_dump() + except Exception: + try: + msg_dict = dict(message) + except Exception: + msg_dict = None + + if not isinstance(msg_dict, dict): + return None + + for key in ["content", "reasoning_content", "reasoning", "text", "output_text"]: + value = msg_dict.get(key) + if value: + return value + + possible_content = msg_dict.get("content") + if isinstance(possible_content, list): + parts = [] + for block in possible_content: + if isinstance(block, dict): + if "text" in block: + parts.append(str(block["text"])) + elif "content" in block: + parts.append(str(block["content"])) + if parts: + return "\n".join(parts).strip() + + return None + + +# ========================= +# READ INSTRUCTIONS +# ========================= + +with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f: + EDSS_INSTRUCTIONS = f.read().strip() + + +# ========================= +# PROMPT +# ========================= + +def build_prompt(patient_text): + return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten. + +Extrahiere: +1. Gesamt-EDSS-Score von 0.0 bis 10.0 +2. Alle 8 EDSS-Unterkategorien +3. Sicherheit als Ganzzahl von 0 bis 100 + +Antworte ausschließlich mit EINEM validen JSON-Objekt. +Kein Markdown. +Keine Code-Fences. +Kein Text vor oder nach JSON. +Keine Platzhalter. +Kopiere kein Schema. 
+ +Das JSON muss exakt diese Schlüssel enthalten: +- reason +- klassifizierbar +- EDSS +- certainty_percent +- subcategories + +Die subcategories müssen exakt diese 8 Schlüssel enthalten: +- VISUAL_OPTIC_FUNCTIONS +- BRAINSTEM_FUNCTIONS +- PYRAMIDAL_FUNCTIONS +- CEREBELLAR_FUNCTIONS +- SENSORY_FUNCTIONS +- BOWEL_AND_BLADDER_FUNCTIONS +- CEREBRAL_FUNCTIONS +- AMBULATION + +Werte: +- klassifizierbar: true oder false +- EDSS: Zahl von 0.0 bis 10.0 oder null +- certainty_percent: Ganzzahl von 0 bis 100 +- Unterkategorien: Zahl oder null +- VISUAL_OPTIC_FUNCTIONS maximal 6.0 +- BRAINSTEM_FUNCTIONS maximal 6.0 +- PYRAMIDAL_FUNCTIONS maximal 6.0 +- CEREBELLAR_FUNCTIONS maximal 6.0 +- SENSORY_FUNCTIONS maximal 6.0 +- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0 +- CEREBRAL_FUNCTIONS maximal 6.0 +- AMBULATION maximal 10.0 +- reason: maximal 250 Zeichen, Deutsch + +Wenn klassifizierbar false ist, setze EDSS auf null. + +Valide Beispielausgabe: +{{ + "reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.", + "klassifizierbar": true, + "EDSS": 2.0, + "certainty_percent": 90, + "subcategories": {{ + "VISUAL_OPTIC_FUNCTIONS": null, + "BRAINSTEM_FUNCTIONS": null, + "PYRAMIDAL_FUNCTIONS": 1.0, + "CEREBELLAR_FUNCTIONS": 1.0, + "SENSORY_FUNCTIONS": 1.0, + "BOWEL_AND_BLADDER_FUNCTIONS": null, + "CEREBRAL_FUNCTIONS": null, + "AMBULATION": 0.0 + }} +}} + +EDSS-Bewertungsrichtlinien: +{EDSS_INSTRUCTIONS} + +Patientenbericht: +{patient_text} + +Gib ausschließlich das finale JSON-Objekt zurück. +''' + + +# ========================= +# VALIDATION, NOT NORMALIZATION +# ========================= + +def parse_float_preserve_raw(value): + """ + Try to parse a value as float without clipping or correcting it. 
+ + Returns: + raw_value: original value exactly as present in parsed JSON + numeric_value: float or None + is_numeric: bool + """ + raw_value = value + + if value is None: + return raw_value, None, False + + if isinstance(value, bool): + return raw_value, None, False + + try: + numeric_value = float(str(value).replace(",", ".")) + return raw_value, numeric_value, True + except Exception: + return raw_value, None, False + + +def is_in_range(value, min_value, max_value): + """ + Range check without clipping. + """ + if value is None: + return False + return min_value <= value <= max_value + + +def validate_model_output(parsed): + """ + Validate parsed model output without repairing/clipping clinical values. + + Important: + - Does NOT clip EDSS. + - Does NOT clip functional system values. + - Does NOT insert default EDSS. + - Does NOT insert default certainty_percent. + - Missing fields are kept as None. + - Adds explicit validity flags for scientific transparency. + """ + + validation = { + "json_parse_success": isinstance(parsed, dict), + "required_fields_present": False, + "required_schema_success": False, + "clinical_range_valid": False, + "certainty_present": False, + + "missing_required_fields": [], + "missing_subcategory_fields": [], + + "EDSS_is_numeric": False, + "EDSS_in_valid_range": False, + } + + if not isinstance(parsed, dict): + return { + "raw_output": parsed, + "validated_output": {}, + "validation": validation, + } + + missing_required = [ + field for field in REQUIRED_TOP_LEVEL_FIELDS + if field not in parsed + ] + + validation["missing_required_fields"] = missing_required + validation["required_fields_present"] = len(missing_required) == 0 + + validated = {} + + validated["reason"] = parsed.get("reason", None) + validated["klassifizierbar"] = parsed.get("klassifizierbar", None) + + raw_certainty = parsed.get("certainty_percent", None) + validated["raw_certainty_percent"] = raw_certainty + validation["certainty_present"] = "certainty_percent" in 
parsed and raw_certainty is not None + + _, certainty_numeric, certainty_is_numeric = parse_float_preserve_raw(raw_certainty) + validated["certainty_percent"] = certainty_numeric if certainty_is_numeric else None + validated["certainty_percent_is_numeric"] = certainty_is_numeric + validated["certainty_percent_in_valid_range"] = ( + is_in_range(certainty_numeric, 0.0, 100.0) + if certainty_is_numeric else False + ) + + raw_edss = parsed.get("EDSS", None) + raw_edss, edss_numeric, edss_is_numeric = parse_float_preserve_raw(raw_edss) + + validated["raw_EDSS"] = raw_edss + validated["EDSS_numeric"] = edss_numeric + validated["EDSS"] = edss_numeric # Backward-compatible; parsed only, not clipped + validated["EDSS_is_numeric"] = edss_is_numeric + validated["EDSS_in_valid_range"] = ( + is_in_range(edss_numeric, EDSS_MIN, EDSS_MAX) + if edss_is_numeric else False + ) + + validation["EDSS_is_numeric"] = validated["EDSS_is_numeric"] + validation["EDSS_in_valid_range"] = validated["EDSS_in_valid_range"] + + raw_subcategories = parsed.get("subcategories", None) + + if isinstance(raw_subcategories, dict): + subcategories = raw_subcategories + else: + subcategories = {} + + validated["subcategories"] = {} + validated["raw_subcategories"] = {} + validated["subcategory_validation"] = {} + + missing_subcats = [] + + for subcat, (min_value, max_value) in FUNCTIONAL_SYSTEM_RANGES.items(): + if subcat not in subcategories: + missing_subcats.append(subcat) + + raw_value = subcategories.get(subcat, None) + raw_value, numeric_value, is_numeric_value = parse_float_preserve_raw(raw_value) + in_valid_range = ( + is_in_range(numeric_value, min_value, max_value) + if is_numeric_value else False + ) + + validated["raw_subcategories"][subcat] = raw_value + validated["subcategories"][subcat] = numeric_value + + validated["subcategory_validation"][subcat] = { + "is_numeric": is_numeric_value, + "in_valid_range": in_valid_range, + "min_allowed": min_value, + "max_allowed": max_value, + } + + 
validation["missing_subcategory_fields"] = missing_subcats + + subcategory_schema_present = len(missing_subcats) == 0 + + all_subcats_numeric = all( + validated["subcategory_validation"][subcat]["is_numeric"] + for subcat in FUNCTIONAL_SYSTEM_RANGES + ) + + all_subcats_in_range = all( + validated["subcategory_validation"][subcat]["in_valid_range"] + for subcat in FUNCTIONAL_SYSTEM_RANGES + ) + + validated["all_functional_systems_numeric"] = all_subcats_numeric + validated["all_functional_systems_in_valid_range"] = all_subcats_in_range + + validation["clinical_range_valid"] = ( + validated["EDSS_in_valid_range"] + and all_subcats_in_range + ) + + validation["required_schema_success"] = ( + validation["required_fields_present"] + and subcategory_schema_present + ) + + return { + "raw_output": parsed, + "validated_output": validated, + "validation": validation, + } + + +# ========================= +# API CALL +# ========================= + +def make_chat_completion(model_config, prompt): + model_name = model_config["model_name"] + + kwargs = dict( + messages=[ + { + "role": "system", + "content": ( + "Du bist ein JSON-Generator. " + "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. " + "Keine Erklärung. Kein Markdown. Keine Code-Fences. " + "Keine Platzhalter. Kein Schema kopieren. " + "Das JSON muss vollständig geschlossen sein." 
+ ) + }, + { + "role": "user", + "content": prompt + } + ], + model=model_name, + max_tokens=model_config.get("max_tokens", MAX_TOKENS), + temperature=model_config.get("temperature", TEMPERATURE), + ) + + if model_config.get("use_response_format", False): + kwargs["response_format"] = {"type": "json_object"} + + extra_body = model_config.get("extra_body") + if extra_body is not None: + kwargs["extra_body"] = extra_body + + return client.chat.completions.create(**kwargs) + + +# ========================= +# INFERENCE FUNCTION WITH RETRIES +# ========================= + +def run_inference(patient_text, model_config): + model_name = model_config["model_name"] + prompt = build_prompt(patient_text) + + process = get_process() + sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC) + + wall_start = time.perf_counter() + cpu_start = get_cpu_times_sec(process) + rss_start_mb = get_memory_rss_mb(process) + + sampler.start() + + raw_content = None + raw_response_debug = None + raw_parsed_output = None + validation = None + last_error = None + + prompt_tokens = None + completion_tokens = None + total_tokens = None + + try: + for attempt in range(1, MAX_JSON_RETRIES + 2): + try: + response = make_chat_completion( + model_config=model_config, + prompt=prompt + ) + + message = response.choices[0].message + raw_content = extract_message_content(message) + + try: + raw_response_debug = response.model_dump() + except Exception: + raw_response_debug = str(response) + + usage = getattr(response, "usage", None) + if usage is not None: + prompt_tokens = getattr(usage, "prompt_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) + total_tokens = getattr(usage, "total_tokens", None) + + parsed = extract_json_from_text(raw_content) + validation_package = validate_model_output(parsed) + + success = True + error = None + + result = validation_package["validated_output"] + validation = validation_package["validation"] + raw_parsed_output = 
validation_package["raw_output"] + + break except Exception as e: - print(f"\n⚠️ Row {idx} failed: {e}") - iteration_results.append({ - "success": False, - "error": str(e), - "iteration": iteration, - "unique_id": row.get("unique_id", f"row_{idx}"), - "MedDatum": row.get("MedDatum", None), - "result": None - }) - if STOP_ON_FIRST_ERROR: - break + last_error = str(e) - # Save per-iteration results - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = INPUT_CSV.replace(".csv", f"_results_iter_{iteration}_{timestamp}.json") - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(iteration_results, f, indent=2, ensure_ascii=False) - print(f"\n✅ Iteration {iteration} complete. Saved to: {output_path}") + if attempt <= MAX_JSON_RETRIES: + print( + f"\n⚠️ JSON failed on attempt {attempt}. " + f"Retrying row. Error: {last_error[:300]}" + ) + time.sleep(RETRY_SLEEP_SEC) + continue - elapsed = time.time() - start_iter - print(f"⏱️ Iteration {iteration} took {elapsed:.1f}s ({elapsed/total_rows:.1f}s/row)") + raise - print(f"\n🎉 All {NUM_ITERATIONS} iterations completed!") + except Exception as e: + print(f"❌ Inference error: {e}") + + success = False + error = str(e) + result = None + raw_parsed_output = None + + validation = { + "json_parse_success": False, + "required_fields_present": False, + "required_schema_success": False, + "clinical_range_valid": False, + "certainty_present": False, + "missing_required_fields": [], + "missing_subcategory_fields": [], + "EDSS_is_numeric": False, + "EDSS_in_valid_range": False, + } + + finally: + sampler.stop() + + wall_end = time.perf_counter() + cpu_end = get_cpu_times_sec(process) + rss_end_mb = get_memory_rss_mb(process) + + wall_time_sec = wall_end - wall_start + + if cpu_start is not None and cpu_end is not None: + process_cpu_time_sec = cpu_end - cpu_start + else: + process_cpu_time_sec = None + + if rss_start_mb is not None and rss_end_mb is not None: + rss_delta_mb = rss_end_mb - rss_start_mb + else: 
+ rss_delta_mb = None + + return { + "success": success, + "error": error, + "result": result, + + "validation": validation, + "raw_parsed_output": raw_parsed_output, + + "model": model_name, + + "inference_time_sec": wall_time_sec, + + "process_cpu_time_sec": process_cpu_time_sec, + "rss_before_mb": rss_start_mb, + "rss_after_mb": rss_end_mb, + "rss_delta_mb": rss_delta_mb, + "peak_rss_mb": sampler.peak_rss_mb, + + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + + # Keeping raw content improves auditability but can make files large. + # To save space, change this to: raw_content if not success else None + "raw_content": raw_content, + "raw_response_debug": raw_response_debug if not success else None, + "last_error": last_error, + } +# ========================= +# BUILD PATIENT TEXT +# ========================= +def build_patient_text(row): + return ( + str(row.get("T_Zusammenfassung", "")) + "\n" + + str(row.get("Diagnosen", "")) + "\n" + + str(row.get("T_KlinBef", "")) + "\n" + + str(row.get("T_Befunde", "")) + ) + + +# ========================= +# FLATTEN RESULTS FOR CSV +# ========================= + +def flatten_result(record): + """ + Flatten one benchmark record for CSV export. 
+ + This preserves: + - raw model values + - parsed numeric values without clipping + - validity flags + - backward-compatible columns where possible + """ + + validation = record.get("validation") or {} + result = record.get("result") or {} + + flat = { + "model": record.get("model"), + "iteration": record.get("iteration"), + "row_index": record.get("row_index"), + "row_number_in_run": record.get("row_number_in_run"), + "unique_id": record.get("unique_id"), + "MedDatum": record.get("MedDatum"), + + "success": record.get("success"), + "error": record.get("error"), + "last_error": record.get("last_error"), + + "json_parse_success": validation.get("json_parse_success"), + "required_fields_present": validation.get("required_fields_present"), + "required_schema_success": validation.get("required_schema_success"), + "clinical_range_valid": validation.get("clinical_range_valid"), + "certainty_present": validation.get("certainty_present"), + + "missing_required_fields": json.dumps( + validation.get("missing_required_fields", []), + ensure_ascii=False + ), + "missing_subcategory_fields": json.dumps( + validation.get("missing_subcategory_fields", []), + ensure_ascii=False + ), + + "inference_time_sec": record.get("inference_time_sec"), + "process_cpu_time_sec": record.get("process_cpu_time_sec"), + "rss_before_mb": record.get("rss_before_mb"), + "rss_after_mb": record.get("rss_after_mb"), + "rss_delta_mb": record.get("rss_delta_mb"), + "peak_rss_mb": record.get("peak_rss_mb"), + + "prompt_tokens": record.get("prompt_tokens"), + "completion_tokens": record.get("completion_tokens"), + "total_tokens": record.get("total_tokens"), + + "raw_content": record.get("raw_content"), + "raw_parsed_output": json.dumps(record.get("raw_parsed_output"), ensure_ascii=False), + + # Backward-compatible fields + "reason": result.get("reason"), + "klassifizierbar": result.get("klassifizierbar"), + + "raw_certainty_percent": result.get("raw_certainty_percent"), + "certainty_percent": 
result.get("certainty_percent"), + "certainty_percent_is_numeric": result.get("certainty_percent_is_numeric"), + "certainty_percent_in_valid_range": result.get("certainty_percent_in_valid_range"), + + # EDSS raw/numeric/validity fields + "raw_EDSS": result.get("raw_EDSS"), + "EDSS_numeric": result.get("EDSS_numeric"), + "EDSS": result.get("EDSS"), # backward-compatible; same as EDSS_numeric, not clipped + "EDSS_is_numeric": result.get("EDSS_is_numeric"), + "EDSS_in_valid_range": result.get("EDSS_in_valid_range"), + + "all_functional_systems_numeric": result.get("all_functional_systems_numeric"), + "all_functional_systems_in_valid_range": result.get("all_functional_systems_in_valid_range"), + } + + raw_subcategories = result.get("raw_subcategories", {}) + numeric_subcategories = result.get("subcategories", {}) + subcat_validation = result.get("subcategory_validation", {}) + + for subcat in FUNCTIONAL_SYSTEM_RANGES: + raw_value = None + numeric_value = None + is_numeric = False + in_valid_range = False + + if isinstance(raw_subcategories, dict): + raw_value = raw_subcategories.get(subcat) + + if isinstance(numeric_subcategories, dict): + numeric_value = numeric_subcategories.get(subcat) + + if isinstance(subcat_validation, dict): + flags = subcat_validation.get(subcat, {}) + if isinstance(flags, dict): + is_numeric = flags.get("is_numeric", False) + in_valid_range = flags.get("in_valid_range", False) + + # New transparent columns + flat[f"raw_subcat_{subcat}"] = raw_value + flat[f"numeric_subcat_{subcat}"] = numeric_value + flat[f"subcat_{subcat}_is_numeric"] = is_numeric + flat[f"subcat_{subcat}_in_valid_range"] = in_valid_range + + # Backward-compatible old column name. + # This is numeric but NOT clipped. + flat[f"subcat_{subcat}"] = numeric_value + + return flat + + +# ========================= +# SUMMARY STATISTICS +# ========================= + +def summarize_records(records): + """ + Create transparent summary statistics per model. 
+ + Separates: + - JSON/schema validity + - numeric parse validity + - clinical range validity + - out-of-range outputs + """ + + df = pd.DataFrame([flatten_result(r) for r in records]) + + if df.empty: + return pd.DataFrame() + + def bool_mean(col): + if col not in df.columns: + return None + return df[col].fillna(False).astype(bool).mean() + + def bool_sum(col): + if col not in df.columns: + return None + return int(df[col].fillna(False).astype(bool).sum()) + + n_records = len(df) + + summary = { + "model": df["model"].iloc[0] if "model" in df.columns else None, + "n_total_responses": n_records, + + "n_success": bool_sum("success"), + "success_rate": bool_mean("success"), + + "n_json_parse_success": bool_sum("json_parse_success"), + "json_parse_success_rate": bool_mean("json_parse_success"), + + "n_required_fields_present": bool_sum("required_fields_present"), + "required_fields_present_rate": bool_mean("required_fields_present"), + + "n_required_schema_success": bool_sum("required_schema_success"), + "required_schema_success_rate": bool_mean("required_schema_success"), + + "n_clinical_range_valid": bool_sum("clinical_range_valid"), + "clinical_range_valid_rate": bool_mean("clinical_range_valid"), + + "n_certainty_present": bool_sum("certainty_present"), + "certainty_present_rate": bool_mean("certainty_present"), + + "n_EDSS_numeric": bool_sum("EDSS_is_numeric"), + "EDSS_numeric_rate": bool_mean("EDSS_is_numeric"), + + "n_EDSS_in_valid_range": bool_sum("EDSS_in_valid_range"), + "EDSS_valid_range_rate": bool_mean("EDSS_in_valid_range"), + } + + # EDSS out-of-range among numeric EDSS outputs + if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: + edss_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) + edss_valid = df["EDSS_in_valid_range"].fillna(False).astype(bool) + edss_out_of_range = edss_numeric & (~edss_valid) + + summary["n_EDSS_out_of_range"] = int(edss_out_of_range.sum()) + summary["EDSS_out_of_range_rate_total"] = 
float(edss_out_of_range.mean()) + summary["EDSS_out_of_range_rate_among_numeric"] = ( + float(edss_out_of_range.sum() / edss_numeric.sum()) + if edss_numeric.sum() > 0 else None + ) + + # Functional system rates + fs_out_of_range_any = pd.Series(False, index=df.index) + fs_valid_all = pd.Series(True, index=df.index) + + for subcat in FUNCTIONAL_SYSTEM_RANGES: + numeric_col = f"subcat_{subcat}_is_numeric" + valid_col = f"subcat_{subcat}_in_valid_range" + + if numeric_col in df.columns: + numeric_series = df[numeric_col].fillna(False).astype(bool) + else: + numeric_series = pd.Series(False, index=df.index) + + if valid_col in df.columns: + valid_series = df[valid_col].fillna(False).astype(bool) + else: + valid_series = pd.Series(False, index=df.index) + + out_of_range_series = numeric_series & (~valid_series) + + summary[f"n_{subcat}_numeric"] = int(numeric_series.sum()) + summary[f"{subcat}_numeric_rate"] = float(numeric_series.mean()) + + summary[f"n_{subcat}_in_valid_range"] = int(valid_series.sum()) + summary[f"{subcat}_valid_range_rate"] = float(valid_series.mean()) + + summary[f"n_{subcat}_out_of_range"] = int(out_of_range_series.sum()) + summary[f"{subcat}_out_of_range_rate_total"] = float(out_of_range_series.mean()) + summary[f"{subcat}_out_of_range_rate_among_numeric"] = ( + float(out_of_range_series.sum() / numeric_series.sum()) + if numeric_series.sum() > 0 else None + ) + + fs_out_of_range_any = fs_out_of_range_any | out_of_range_series + fs_valid_all = fs_valid_all & valid_series + + summary["n_any_functional_system_out_of_range"] = int(fs_out_of_range_any.sum()) + summary["any_functional_system_out_of_range_rate_total"] = float(fs_out_of_range_any.mean()) + + summary["n_all_functional_systems_in_valid_range"] = int(fs_valid_all.sum()) + summary["all_functional_systems_valid_range_rate"] = float(fs_valid_all.mean()) + + numeric_cols = [ + "inference_time_sec", + "process_cpu_time_sec", + "rss_delta_mb", + "peak_rss_mb", + "prompt_tokens", + 
"completion_tokens", + "total_tokens", + "certainty_percent", + "EDSS_numeric", + ] + + for col in numeric_cols: + if col in df.columns: + values = pd.to_numeric(df[col], errors="coerce") + summary[f"{col}_mean"] = values.mean() + summary[f"{col}_median"] = values.median() + summary[f"{col}_std"] = values.std() + summary[f"{col}_min"] = values.min() + summary[f"{col}_max"] = values.max() + + if "EDSS_is_numeric" in df.columns and "EDSS_in_valid_range" in df.columns: + primary_valid_only = ( + df["EDSS_is_numeric"].fillna(False).astype(bool) + & df["EDSS_in_valid_range"].fillna(False).astype(bool) + ) + + sensitivity_all_numeric = df["EDSS_is_numeric"].fillna(False).astype(bool) + + summary["n_primary_valid_only_EDSS"] = int(primary_valid_only.sum()) + summary["primary_valid_only_EDSS_rate"] = float(primary_valid_only.mean()) + + summary["n_sensitivity_all_numeric_EDSS"] = int(sensitivity_all_numeric.sum()) + summary["sensitivity_all_numeric_EDSS_rate"] = float(sensitivity_all_numeric.mean()) + + return pd.DataFrame([summary]) + + +# ========================= +# ANALYSIS DATASET HELPERS +# ========================= + +def create_analysis_datasets(records): + """ + Create two transparent EDSS analysis datasets: + + 1. primary_valid_only: + Only numeric EDSS predictions within the valid clinical range. + + 2. sensitivity_all_numeric: + All numeric EDSS predictions, including out-of-range values. + No clipping is applied. 
+ """ + + df = pd.DataFrame([flatten_result(r) for r in records]) + + if df.empty: + return df.copy(), df.copy() + + primary_valid_only = df[ + df["EDSS_is_numeric"].fillna(False).astype(bool) + & df["EDSS_in_valid_range"].fillna(False).astype(bool) + ].copy() + + sensitivity_all_numeric = df[ + df["EDSS_is_numeric"].fillna(False).astype(bool) + ].copy() + + return primary_valid_only, sensitivity_all_numeric + + +# ========================= +# INCREMENTAL SAVE HELPERS +# ========================= + +def append_jsonl(path, record): + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + f.flush() + os.fsync(f.fileno()) + + +def append_csv(path, record): + flat = flatten_result(record) + df_one = pd.DataFrame([flat]) + file_exists = Path(path).exists() + df_one.to_csv(path, mode="a", header=not file_exists, index=False) + + +# ========================= +# MAIN LOOP +# ========================= + +if __name__ == "__main__": + + run_timestamp = now_timestamp() + + results_root = Path(RESULTS_ROOT) + results_root.mkdir(parents=True, exist_ok=True) + + run_root = results_root / f"run_{run_timestamp}" + run_root.mkdir(parents=True, exist_ok=True) + + print(f"Results root: {run_root}") + + df = pd.read_csv(INPUT_CSV, sep=";") + + if MAX_ROWS is not None: + df = df.head(MAX_ROWS) + + total_rows = len(df) + + model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS] + + print(f"Loaded {total_rows} patient records.") + print(f"Models: {model_names_for_print}") + print(f"Iterations per model: {NUM_ITERATIONS}") + + all_model_summaries = [] + + for model_config in MODEL_CONFIGS: + model_name = model_config["model_name"] + safe_model = safe_dir_name(model_name) + + model_dir = run_root / safe_model + model_dir.mkdir(parents=True, exist_ok=True) + + print(f"\n{'#' * 80}") + print(f"MODEL: {model_name}") + print(f"use_response_format: {model_config.get('use_response_format', False)}") + print(f"temperature: 
{model_config.get('temperature', TEMPERATURE)}") + print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}") + print(f"Saving to: {model_dir}") + print(f"{'#' * 80}") + + model_records = [] + model_start = time.perf_counter() + + for iteration in range(1, NUM_ITERATIONS + 1): + print(f"\n{'=' * 60}") + print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}") + print(f"{'=' * 60}") + + iteration_results = [] + iteration_start = time.perf_counter() + + incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl" + incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv" + + print(f"Incremental JSONL: {incremental_jsonl_path}") + print(f"Incremental CSV: {incremental_csv_path}") + + for loop_i, (idx, row) in enumerate(df.iterrows(), start=1): + print( + f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}", + end="", + flush=True + ) + + try: + patient_text = build_patient_text(row) + + record = run_inference( + patient_text=patient_text, + model_config=model_config + ) + + record["iteration"] = iteration + record["row_index"] = int(idx) + record["row_number_in_run"] = int(loop_i) + record["unique_id"] = row.get("unique_id", f"row_{idx}") + record["MedDatum"] = row.get("MedDatum", None) + + iteration_results.append(record) + model_records.append(record) + + if loop_i % SAVE_EVERY_N_ROWS == 0: + append_jsonl(incremental_jsonl_path, record) + append_csv(incremental_csv_path, record) + + if record["success"]: + res = record["result"] or {} + edss_display = res.get("EDSS_numeric", None) + edss_valid = res.get("EDSS_in_valid_range", False) + + print( + f" ✅ EDSS={edss_display}, " + f"valid_range={edss_valid}, " + f"time={record['inference_time_sec']:.2f}s" + ) + else: + print(f" ❌ {record.get('error', 'Unknown error')}") + + except Exception as e: + print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}") + + fallback_record = { + "success": 
False, + "error": str(e), + "last_error": str(e), + "result": None, + + "validation": { + "json_parse_success": False, + "required_fields_present": False, + "required_schema_success": False, + "clinical_range_valid": False, + "certainty_present": False, + "missing_required_fields": [], + "missing_subcategory_fields": [], + "EDSS_is_numeric": False, + "EDSS_in_valid_range": False, + }, + "raw_parsed_output": None, + + "model": model_name, + "iteration": iteration, + "row_index": int(idx), + "row_number_in_run": int(loop_i), + "unique_id": row.get("unique_id", f"row_{idx}"), + "MedDatum": row.get("MedDatum", None), + + "inference_time_sec": None, + "process_cpu_time_sec": None, + "rss_before_mb": None, + "rss_after_mb": None, + "rss_delta_mb": None, + "peak_rss_mb": None, + + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + + "raw_content": None, + "raw_response_debug": None, + } + + iteration_results.append(fallback_record) + model_records.append(fallback_record) + + append_jsonl(incremental_jsonl_path, fallback_record) + append_csv(incremental_csv_path, fallback_record) + + if STOP_ON_FIRST_ERROR: + break + + iteration_elapsed = time.perf_counter() - iteration_start + + # Final full per-iteration JSON + iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json" + with open(iter_json_path, "w", encoding="utf-8") as f: + json.dump(iteration_results, f, indent=2, ensure_ascii=False) + + # Final full per-iteration CSV + iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv" + iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results]) + iter_flat_df.to_csv(iter_csv_path, index=False) + + # Transparent analysis datasets + primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(iteration_results) + + primary_valid_only_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_primary_valid_only.csv" + 
sensitivity_all_numeric_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_sensitivity_all_numeric.csv" + + primary_valid_only_df.to_csv(primary_valid_only_path, index=False) + sensitivity_all_numeric_df.to_csv(sensitivity_all_numeric_path, index=False) + + print(f"\n✅ Iteration {iteration} complete.") + print(f"Incremental JSONL saved to: {incremental_jsonl_path}") + print(f"Incremental CSV saved to: {incremental_csv_path}") + print(f"Final JSON saved to: {iter_json_path}") + print(f"Final CSV saved to: {iter_csv_path}") + print(f"Primary valid-only CSV saved to: {primary_valid_only_path}") + print(f"Sensitivity all-numeric CSV: {sensitivity_all_numeric_path}") + print( + f"⏱️ Iteration time: {iteration_elapsed:.1f}s " + f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)" + ) + + model_elapsed = time.perf_counter() - model_start + + # Save all records for this model + model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json" + with open(model_json_path, "w", encoding="utf-8") as f: + json.dump(model_records, f, indent=2, ensure_ascii=False) + + model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv" + model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records]) + model_flat_df.to_csv(model_csv_path, index=False) + + # Save model-level analysis datasets + primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(model_records) + + model_primary_valid_only_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_primary_valid_only.csv" + model_sensitivity_all_numeric_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_sensitivity_all_numeric.csv" + + primary_valid_only_df.to_csv(model_primary_valid_only_path, index=False) + sensitivity_all_numeric_df.to_csv(model_sensitivity_all_numeric_path, index=False) + + # Save model summary + model_summary_df = summarize_records(model_records) + model_summary_df["model_total_wall_time_sec"] = model_elapsed + 
model_summary_df["model_total_wall_time_min"] = model_elapsed / 60 + + model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv" + model_summary_df.to_csv(model_summary_path, index=False) + + all_model_summaries.append(model_summary_df) + + print(f"\n🎉 Model completed: {model_name}") + print(f"All JSON: {model_json_path}") + print(f"All CSV: {model_csv_path}") + print(f"All primary valid-only CSV: {model_primary_valid_only_path}") + print(f"All sensitivity all-numeric CSV: {model_sensitivity_all_numeric_path}") + print(f"Summary: {model_summary_path}") + print(f"Total model time: {model_elapsed / 60:.2f} min") + + if all_model_summaries: + combined_summary_df = pd.concat(all_model_summaries, ignore_index=True) + combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv" + combined_summary_df.to_csv(combined_summary_path, index=False) + + print(f"\n📊 Combined summary saved to: {combined_summary_path}") + + print(f"\n🎉 All models and all iterations completed!") ## + diff --git a/show_plots.py b/show_plots.py index e824f95..384cbaf 100644 --- a/show_plots.py +++ b/show_plots.py @@ -3118,3 +3118,288 @@ plt.savefig(figure_save_path, format="svg", bbox_inches="tight") plt.show() ## + +# %% Confusion matrix for one EDSS benchmark result file + +import os +from pathlib import Path + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns + +from sklearn.metrics import confusion_matrix, classification_report + + +# ========================= +# CONFIGURATION +# ========================= + +REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" + +RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv" + +OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices" + +TARGET_ITERATION = 1 + +MERGE_KEY = 
"unique_id"  # right-hand side of the `MERGE_KEY = ` assignment begun on the preceding line

# Ground truth EDSS column in the reference file
GT_EDSS_COL = "EDSS"

# Predicted EDSS column in the result file
PRED_EDSS_COL = "EDSS"

# Bin labels in ascending order; label i covers values up to and including i + 1
# (the first bin additionally includes 0).
EDSS_LABELS = [
    "0-1", "1-2", "2-3", "3-4", "4-5",
    "5-6", "6-7", "7-8", "8-9", "9-10",
]


# =========================
# HELPERS
# =========================

# Characters replaced by "_" when a model name is turned into a file name.
_UNSAFE_FILENAME_CHARS = str.maketrans(dict.fromkeys("/\\ :", "_"))


def safe_filename(name):
    """Return ``str(name)`` with '/', '\\', ' ' and ':' each replaced by '_'."""
    return str(name).translate(_UNSAFE_FILENAME_CHARS)


def parse_numeric_column(series):
    """Convert a Series to numbers, accepting ',' as the decimal separator.

    Every value is stringified first, so entries that cannot be parsed
    (including missing values) come back as NaN instead of raising.
    """
    return pd.to_numeric(
        series.astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )


def categorize_edss(value):
    """Map a numeric EDSS value onto one of ``EDSS_LABELS``.

    Bins are closed on the right: values in [0, 1] map to "0-1", values in
    (1, 2] map to "1-2", and so on up to (9, 10] -> "9-10".

    Returns:
        The bin label, or NaN when *value* is missing or lies outside
        0-10. Negative values are rejected explicitly so that they are not
        counted as "0-1".
    """
    if pd.isna(value) or value < 0.0:
        return np.nan

    # The upper edge of the i-th label (1-based) is simply i.
    for upper_edge, label in enumerate(EDSS_LABELS, start=1):
        if value <= upper_edge:
            return label

    return np.nan


def load_reference(reference_path):
    """Read the ';'-separated reference CSV and add ground-truth EDSS columns.

    Adds ``GT_EDSS_numeric`` (parsed number) and ``GT_EDSS_cat`` (bin label)
    and casts the merge key to ``str`` so it joins cleanly with the results.

    Raises:
        ValueError: if the merge key or the ground-truth EDSS column is absent.
    """
    df_ref = pd.read_csv(reference_path, sep=";")

    if MERGE_KEY not in df_ref.columns:
        raise ValueError(f"Reference file does not contain column: {MERGE_KEY}")

    if GT_EDSS_COL not in df_ref.columns:
        raise ValueError(f"Reference file does not contain column: {GT_EDSS_COL}")

    df_ref = df_ref.copy()
    df_ref[MERGE_KEY] = df_ref[MERGE_KEY].astype(str)

    df_ref["GT_EDSS_numeric"] = parse_numeric_column(df_ref[GT_EDSS_COL])
    df_ref["GT_EDSS_cat"] = df_ref["GT_EDSS_numeric"].apply(categorize_edss)

    return df_ref


def load_result(result_path):
    """Read the ','-separated result CSV and add predicted EDSS columns.

    Rows are restricted to successful calls (when a ``success`` column
    exists) and to ``TARGET_ITERATION`` (when it is set and an ``iteration``
    column exists). Adds ``PRED_EDSS_numeric`` and ``PRED_EDSS_cat``.

    Raises:
        ValueError: if the merge key or the predicted EDSS column is absent.
    """
    df_res = pd.read_csv(result_path, sep=",")

    if MERGE_KEY not in df_res.columns:
        raise ValueError(f"Result file does not contain column: {MERGE_KEY}")

    if PRED_EDSS_COL not in df_res.columns:
        raise ValueError(f"Result file does not contain column: {PRED_EDSS_COL}")

    df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str)

    if "success" in df_res.columns:
        # The flag is compared as text, so both booleans and strings match.
        succeeded = df_res["success"].astype(str).str.lower().isin(["true", "1", "yes"])
        df_res = df_res[succeeded]

    if TARGET_ITERATION is not None and "iteration" in df_res.columns:
        df_res = df_res[df_res["iteration"] == TARGET_ITERATION]

    # Copy AFTER filtering: assigning columns onto a filtered slice would
    # otherwise trigger pandas' SettingWithCopyWarning.
    df_res = df_res.copy()

    df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL])
    df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss)

    return df_res


def get_model_name(df_res, result_path):
    """Return the first non-null ``model`` value, else the result file's stem."""
    if "model" in df_res.columns and df_res["model"].notna().any():
        return str(df_res["model"].dropna().iloc[0])

    return Path(result_path).stem


def plot_confusion_matrix(cm, model_name, output_path):
    """Draw *cm* as an annotated heatmap, save it to *output_path* and show it.

    Rows are ground truth and columns are predictions, both labelled with
    ``EDSS_LABELS``. The title reports the module-level ``TARGET_ITERATION``.
    """
    plt.figure(figsize=(10, 8))

    ax = sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=EDSS_LABELS,
        yticklabels=EDSS_LABELS,
    )

    cbar = ax.collections[0].colorbar
    cbar.set_label("Number of Cases", rotation=270, labelpad=20)

    plt.xlabel("LLM Generated EDSS")
    plt.ylabel("Ground Truth EDSS")
    plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}")

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.show()


# =========================
# MAIN
# =========================

if __name__ == "__main__":

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading reference:")
    print(REFERENCE_PATH)

    df_ref = load_reference(REFERENCE_PATH)

    print(f"Reference rows: {len(df_ref)}")
    print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}")

    print("\nLoading result:")
    print(RESULT_PATH)

    df_res = load_result(RESULT_PATH)

    model_name = get_model_name(df_res, RESULT_PATH)
    safe_model = safe_filename(model_name)

    print(f"Model: {model_name}")
    print(f"Result rows after filtering: {len(df_res)}")

    before_dedup = len(df_res)
    # A stable sort keeps duplicates of one key in file order, so keep="first"
    # retains the same row on every run (the default quicksort is not stable).
    df_res = (
        df_res
        .sort_values(by=[MERGE_KEY], kind="mergesort")
        .drop_duplicates(subset=[MERGE_KEY], keep="first")
    )
    after_dedup = len(df_res)

    if before_dedup != after_dedup:
        print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}")

    df_merged = df_ref.merge(
        df_res,
        on=MERGE_KEY,
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    print(f"Merged rows: {len(df_merged)}")

    # Only rows where both EDSS values fell into a valid bin can be scored.
    df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy()

    print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}")

    if df_eval.empty:
        raise ValueError("No evaluable rows after merging and EDSS filtering.")

    cm = confusion_matrix(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS,
    )

    suffix = f"iter_{TARGET_ITERATION}"

    plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png"
    cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv"
    report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt"
    merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv"

    plot_confusion_matrix(cm, model_name, plot_path)

    cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS)
    cm_df.index.name = "Ground Truth EDSS"
    cm_df.columns.name = "LLM Generated EDSS"
    cm_df.to_csv(cm_csv_path)

    report = classification_report(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS,
        zero_division=0,
    )

    with open(report_txt_path, "w", encoding="utf-8") as f:
        f.write(f"Model: {model_name}\n")
        f.write(f"Result file: {RESULT_PATH}\n")
        f.write(f"Target iteration: {TARGET_ITERATION}\n")
        f.write(f"Merged rows: {len(df_merged)}\n")
        f.write(f"Evaluable rows: {len(df_eval)}\n\n")
        f.write("Classification Report:\n")
        f.write(report)
        f.write("\n\nConfusion Matrix Raw Counts:\n")
        f.write(cm_df.to_string())

    keep_cols = [
        MERGE_KEY,
        # The merge adds the "_gt" suffix only when both files carry the column.
        "MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum",
        "GT_EDSS_numeric",
        "PRED_EDSS_numeric",
        "GT_EDSS_cat",
        "PRED_EDSS_cat",
        "model",
        "iteration",
        "success",
        "inference_time_sec",
        "certainty_percent",
        "reason",
    ]

    # Export only the columns that actually exist in this result file.
    keep_cols = [col for col in keep_cols if col in df_eval.columns]
    df_eval[keep_cols].to_csv(merged_csv_path, index=False)

    print("\nClassification Report:")
    print(report)

    print("\nConfusion Matrix Raw Counts:")
    print(cm_df)

    print("\nSaved files:")
    print(f"Plot: {plot_path}")
    print(f"Confusion matrix: {cm_csv_path}")
    print(f"Report: {report_txt_path}")
    print(f"Merged rows: {merged_csv_path}")

    print("\nDone.")
##