3602 lines
114 KiB
Python
3602 lines
114 KiB
Python
|
||
# %% API call1
|
||
#import time
|
||
#import json
|
||
#import os
|
||
#from datetime import datetime
|
||
#import pandas as pd
|
||
#from openai import OpenAI
|
||
#from dotenv import load_dotenv
|
||
#
|
||
## Load environment variables
|
||
#load_dotenv()
|
||
#
|
||
## === CONFIGURATION ===
|
||
#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||
#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
|
||
#MODEL_NAME = "GPT-OSS-120B"
|
||
#HEALTH_URL = f"{OPENAI_BASE_URL}/health" # Placeholder - actual health check would need to be implemented
|
||
#CHAT_URL = f"{OPENAI_BASE_URL}/chat/completions"
|
||
#
|
||
## File paths
|
||
#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
|
||
#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"
|
||
##GRAMMAR_FILE = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/just_edss_schema.gbnf"
|
||
#
|
||
## Initialize OpenAI client
|
||
#client = OpenAI(
|
||
# api_key=OPENAI_API_KEY,
|
||
# base_url=OPENAI_BASE_URL
|
||
#)
|
||
#
|
||
## Read EDSS instructions from file
|
||
#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f:
|
||
# EDSS_INSTRUCTIONS = f.read().strip()
|
||
## === RUN INFERENCE 2 ===
|
||
#def run_inference(patient_text):
|
||
# prompt = f'''
|
||
# Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale) aus klinischen Berichten zu extrahieren.
|
||
#### Regeln für die Ausgabe:
|
||
#1. **Reason**: Erstelle eine prägnante Zusammenfassung (max. 400 Zeichen) der Befunde auf **DEUTSCH**, die zur Einstufung führen.
|
||
#2. **klassifizierbar**:
|
||
# - Setze dies auf **true**, wenn ein EDSS-Wert identifiziert, berechnet oder basierend auf den klinischen Hinweisen plausibel geschätzt werden kann.
|
||
# - Setze dies auf **false**, NUR wenn die Daten absolut unzureichend oder so widersprüchlich sind, dass keinerlei Einstufung möglich ist.
|
||
#3. **EDSS**:
|
||
# - Dieses Feld ist **VERPFLICHTEND**, wenn "klassifizierbar" auf true steht.
|
||
# - Es muss eine Zahl zwischen 0.0 und 10.0 sein.
|
||
# - Versuche stets, den EDSS-Wert so präzise wie möglich zu bestimmen, auch wenn die Datenlage dünn ist (nutze verfügbare Informationen zu Gehstrecke und Funktionssystemen).
|
||
# - Dieses Feld **DARF NICHT ERSCHEINEN**, wenn "klassifizierbar" auf false steht.
|
||
#
|
||
#### Einschränkungen:
|
||
#- Erfinde keine Fakten, aber nutze klinische Herleitungen aus dem Bericht, um den EDSS zu bestimmen.
|
||
#- Priorisiere die Vergabe eines EDSS-Wertes gegenüber der Markierung als nicht klassifizierbar.
|
||
#- Halte dich strikt an die JSON-Struktur.
|
||
#
|
||
#EDSS-Bewertungsrichtlinien:
|
||
#{EDSS_INSTRUCTIONS}
|
||
#
|
||
#Patientenbericht:
|
||
#{patient_text}
|
||
#'''
|
||
# start_time = time.time()
|
||
#
|
||
# try:
|
||
# # Make API call using OpenAI client
|
||
# response = client.chat.completions.create(
|
||
# messages=[
|
||
# {
|
||
# "role": "system",
|
||
# "content": "You extract EDSS scores. You prioritize providing a score even if data is partial, by using clinical inference."
|
||
# },
|
||
# {
|
||
# "role": "user",
|
||
# "content": prompt
|
||
# }
|
||
# ],
|
||
# model=MODEL_NAME,
|
||
# max_tokens=2048,
|
||
# temperature=0.0,
|
||
# response_format={"type": "json_object"}
|
||
# )
|
||
#
|
||
# # Extract content from response
|
||
# content = response.choices[0].message.content
|
||
#
|
||
# # Parse the JSON response
|
||
# parsed = json.loads(content)
|
||
#
|
||
# inference_time = time.time() - start_time
|
||
#
|
||
# return {
|
||
# "success": True,
|
||
# "result": parsed,
|
||
# "inference_time_sec": inference_time
|
||
# }
|
||
#
|
||
# except Exception as e:
|
||
# print(f"Inference error: {e}")
|
||
# return {
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "inference_time_sec": -1
|
||
# }
|
||
## === BUILD PATIENT TEXT ===
|
||
#def build_patient_text(row):
|
||
# return (
|
||
# str(row["T_Zusammenfassung"]) + "\n" +
|
||
# str(row["Diagnosen"]) + "\n" +
|
||
# str(row["T_KlinBef"]) + "\n" +
|
||
# str(row["T_Befunde"]) + "\n"
|
||
# )
|
||
#
|
||
#if __name__ == "__main__":
|
||
# # Read CSV file ONLY inside main block
|
||
# df = pd.read_csv(INPUT_CSV, sep=';')
|
||
# results = []
|
||
#
|
||
# # Process each row
|
||
# for idx, row in df.iterrows():
|
||
# print(f"Processing row {idx + 1}/{len(df)}")
|
||
# try:
|
||
# patient_text = build_patient_text(row)
|
||
# result = run_inference(patient_text)
|
||
#
|
||
# # Add unique_id and MedDatum to result for tracking
|
||
# result["unique_id"] = row.get("unique_id", f"row_{idx}")
|
||
# result["MedDatum"] = row.get("MedDatum", None)
|
||
#
|
||
# results.append(result)
|
||
# print(json.dumps(result, indent=2))
|
||
# except Exception as e:
|
||
# print(f"Error processing row {idx}: {e}")
|
||
# results.append({
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "unique_id": row.get("unique_id", f"row_{idx}"),
|
||
# "MedDatum": row.get("MedDatum", None)
|
||
# })
|
||
#
|
||
# # Save results to a JSON file
|
||
# output_json = INPUT_CSV.replace(".csv", "_results_Nisch.json")
|
||
# with open(output_json, 'w') as f:
|
||
# json.dump(results, f, indent=2)
|
||
# print(f"Results saved to {output_json}")
|
||
##
|
||
|
||
|
||
|
||
# %% API call1 - Enhanced with certainty scoring
|
||
#import time
|
||
#import json
|
||
#import os
|
||
#from datetime import datetime
|
||
#import pandas as pd
|
||
#from openai import OpenAI
|
||
#from dotenv import load_dotenv
|
||
#
|
||
## Load environment variables
|
||
#load_dotenv()
|
||
#
|
||
## === CONFIGURATION ===
|
||
#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||
#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
|
||
#MODEL_NAME = "GPT-OSS-120B"
|
||
#
|
||
## File paths
|
||
#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Test.csv"
|
||
#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"
|
||
#
|
||
## Initialize OpenAI client
|
||
#client = OpenAI(
|
||
# api_key=OPENAI_API_KEY,
|
||
# base_url=OPENAI_BASE_URL
|
||
#)
|
||
#
|
||
## Read EDSS instructions from file
|
||
#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f:
|
||
# EDSS_INSTRUCTIONS = f.read().strip()
|
||
#
|
||
## === PROMPT WITH CERTAINTY REQUEST ===
|
||
#def build_prompt(patient_text):
|
||
# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren.
|
||
#
|
||
#### Deine Aufgabe:
|
||
#1. Analysiere den Patientenbericht und extrahiere:
|
||
# - Den Gesamt-EDSS-Score (0.0–10.0)
|
||
# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl)
|
||
#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein.
|
||
#
|
||
#### Struktur der JSON-Ausgabe (VERPFLICHTEND):
|
||
#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter.
|
||
#
|
||
#{{
|
||
# "reason": "Kernaussage zur EDSS-Begründung (max. 400 Zeichen, auf Deutsch).",
|
||
# "klassifizierbar": true/false,
|
||
# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true),
|
||
# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl),
|
||
# "subcategories": {{
|
||
# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0
|
||
# }}
|
||
#}}
|
||
#
|
||
#### Regeln:
|
||
#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest.
|
||
#- **klassifizierbar**:
|
||
# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind.
|
||
# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist.
|
||
#- **EDSS**:
|
||
# - **VERPFLICHTEND**, wenn `klassifizierbar=true`.
|
||
# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`.
|
||
#- **certainty_percent**:
|
||
# - **Immer present** — Ganzzahl (0–100), basierend auf:
|
||
# - Klarheit und Vollständigkeit der Berichtsangaben,
|
||
# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz),
|
||
# - Konsistenz zwischen den Unterkategorien.
|
||
#- **subcategories**:
|
||
# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein.
|
||
# - Jeder Wert ist entweder:
|
||
# - `null` (wenn keine ausreichende Information vorliegt), **oder**
|
||
# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0).
|
||
# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab.
|
||
# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`.
|
||
#
|
||
#### EDSS-Bewertungsrichtlinien:
|
||
#{EDSS_INSTRUCTIONS}
|
||
#
|
||
#Patientenbericht:
|
||
#{patient_text}
|
||
#'''
|
||
#
|
||
## === INFERENCE FUNCTION ===
|
||
#def run_inference(patient_text):
|
||
# prompt = build_prompt(patient_text)
|
||
#
|
||
# start_time = time.time()
|
||
#
|
||
# try:
|
||
# response = client.chat.completions.create(
|
||
# messages=[
|
||
# {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."}
|
||
# ] + [
|
||
# {"role": "user", "content": prompt}
|
||
# ],
|
||
# model=MODEL_NAME,
|
||
# max_tokens=2048,
|
||
# temperature=0.1, # Slightly higher for more natural certainty estimation (still low for reliability)
|
||
# response_format={"type": "json_object"}
|
||
# )
|
||
#
|
||
# content = response.choices[0].message.content
|
||
#
|
||
# # Parse and validate JSON
|
||
# try:
|
||
# parsed = json.loads(content)
|
||
# except json.JSONDecodeError as e:
|
||
# print(f"⚠️ JSON parsing failed: {e}")
|
||
# print("Raw response:", content[:500])
|
||
# raise ValueError("Model did not return valid JSON")
|
||
#
|
||
# # Enforce required keys
|
||
# if "certainty_percent" not in parsed:
|
||
# print("⚠️ Missing 'certainty_percent' in output! Force-adding fallback.")
|
||
# parsed["certainty_percent"] = 0 # fallback
|
||
# elif not isinstance(parsed["certainty_percent"], (int, float)):
|
||
# parsed["certainty_percent"] = int(parsed["certainty_percent"])
|
||
#
|
||
# # Clamp certainty to [0, 100]
|
||
# pct = parsed["certainty_percent"]
|
||
# parsed["certainty_percent"] =max(0, min(100, int(pct)))
|
||
#
|
||
# # Enforce EDSS rules: if not classifiable → remove EDSS
|
||
# if not parsed.get("klassifizierbar", False):
|
||
# if "EDSS" in parsed:
|
||
# del parsed["EDSS"] # per spec, must not appear if not classifiable
|
||
# else:
|
||
# if "EDSS" not in parsed:
|
||
# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.")
|
||
# parsed["EDSS"] = 7.0 # last-resort fallback; NOTE: fabricated value, not marked as such in the result
|
||
#
|
||
# inference_time = time.time() - start_time
|
||
#
|
||
# return {
|
||
# "success": True,
|
||
# "result": parsed,
|
||
# "inference_time_sec": inference_time
|
||
# }
|
||
#
|
||
# except Exception as e:
|
||
# print(f"❌ Inference error: {e}")
|
||
# return {
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "inference_time_sec": -1,
|
||
# "result": None # no structured output
|
||
# }
|
||
#
|
||
## === BUILD PATIENT TEXT ===
|
||
#def build_patient_text(row):
|
||
# return (
|
||
# str(row.get("T_Zusammenfassung", "")) + "\n" +
|
||
# str(row.get("Diagnosen", "")) + "\n" +
|
||
# str(row.get("T_KlinBef", "")) + "\n" +
|
||
# str(row.get("T_Befunde", ""))
|
||
# )
|
||
#
|
||
#if __name__ == "__main__":
|
||
# # Load data
|
||
# df = pd.read_csv(INPUT_CSV, sep=';')
|
||
# results = []
|
||
#
|
||
# # Optional: limit for testing
|
||
# # df = df.head(3)
|
||
#
|
||
# print(f"Processing {len(df)} rows...")
|
||
# for idx, row in df.iterrows():
|
||
# print(f"\n— Row {idx + 1}/{len(df)} —")
|
||
# try:
|
||
# patient_text = build_patient_text(row)
|
||
# result = run_inference(patient_text)
|
||
#
|
||
# # Attach metadata
|
||
# result["unique_id"] = row.get("unique_id", f"row_{idx}")
|
||
# result["MedDatum"] = row.get("MedDatum", None)
|
||
#
|
||
# results.append(result)
|
||
#
|
||
# # Print summary
|
||
# if result["success"]:
|
||
# res = result["result"]
|
||
# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A"
|
||
# print(f"✅ Result → EDSS={edss}, certainty={res.get('certainty_percent', 'N/A')}%")
|
||
# print(f" Reason: {res.get('reason', 'N/A')[:100]}…")
|
||
# else:
|
||
# print(f"❌ Failed: {result.get('error', 'Unknown error')[:100]}")
|
||
#
|
||
# except Exception as e:
|
||
# print(f"⚠️ Error processing row {idx}: {e}")
|
||
# results.append({
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "unique_id": row.get("unique_id", f"row_{idx}"),
|
||
# "MedDatum": row.get("MedDatum", None),
|
||
# "result": None
|
||
# })
|
||
#
|
||
# # Save results
|
||
# output_json = INPUT_CSV.replace(".csv", "_results_Nisch_certainty.json")
|
||
# with open(output_json, 'w', encoding='utf-8') as f:
|
||
# json.dump(results, f, indent=2, ensure_ascii=False)
|
||
# print(f"\n✅ Saved results to: {output_json}")
|
||
#
|
||
##
|
||
|
||
|
||
# %% API call - Multi-iteration EDSS + certainty extraction
|
||
#
|
||
#import time
|
||
#import json
|
||
#import os
|
||
#from datetime import datetime
|
||
#import pandas as pd
|
||
#from openai import OpenAI
|
||
#from dotenv import load_dotenv
|
||
#
|
||
## Load environment variables
|
||
#load_dotenv()
|
||
#
|
||
## === CONFIGURATION ===
|
||
#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||
#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
|
||
#MODEL_NAME = "GPT-OSS-120B"
|
||
#
|
||
## File paths
|
||
#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
|
||
#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"
|
||
#
|
||
## Iteration settings
|
||
#NUM_ITERATIONS = 20
|
||
#STOP_ON_FIRST_ERROR = False # Set to True for debugging
|
||
#
|
||
## Initialize OpenAI client
|
||
#client = OpenAI(
|
||
# api_key=OPENAI_API_KEY,
|
||
# base_url=OPENAI_BASE_URL
|
||
#)
|
||
#
|
||
## Read EDSS instructions from file
|
||
#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f:
|
||
# EDSS_INSTRUCTIONS = f.read().strip()
|
||
#
|
||
## === PROMPT (unchanged from before) ===
|
||
#def build_prompt(patient_text):
|
||
# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren.
|
||
#
|
||
#### Deine Aufgabe:
|
||
#1. Analysiere den Patientenbericht und extrahiere:
|
||
# - Den Gesamt-EDSS-Score (0.0–10.0)
|
||
# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl)
|
||
#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein.
|
||
#
|
||
#### Struktur der JSON-Ausgabe (VERPFLICHTEND):
|
||
#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter.
|
||
#
|
||
#{{
|
||
# "reason": "Kernaussage zur EDSS-Begründung (max. 400 Zeichen, auf Deutsch).",
|
||
# "klassifizierbar": true/false,
|
||
# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true),
|
||
# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl),
|
||
# "subcategories": {{
|
||
# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0
|
||
# }}
|
||
#}}
|
||
#
|
||
#### Regeln:
|
||
#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest.
|
||
#- **klassifizierbar**:
|
||
# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind.
|
||
# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist.
|
||
#- **EDSS**:
|
||
# - **VERPFLICHTEND**, wenn `klassifizierbar=true`.
|
||
# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`.
|
||
#- **certainty_percent**:
|
||
# - **Immer present** — Ganzzahl (0–100), basierend auf:
|
||
# - Klarheit und Vollständigkeit der Berichtsangaben,
|
||
# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz),
|
||
# - Konsistenz zwischen den Unterkategorien.
|
||
#- **subcategories**:
|
||
# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein.
|
||
# - Jeder Wert ist entweder:
|
||
# - `null` (wenn keine ausreichende Information vorliegt), **oder**
|
||
# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0).
|
||
# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab.
|
||
# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`.
|
||
#
|
||
#### EDSS-Bewertungsrichtlinien:
|
||
#{EDSS_INSTRUCTIONS}
|
||
#
|
||
#Patientenbericht:
|
||
#{patient_text}
|
||
#'''
|
||
#
|
||
## === INFERENCE FUNCTION (unchanged) ===
|
||
#def run_inference(patient_text):
|
||
# prompt = build_prompt(patient_text)
|
||
#
|
||
# start_time = time.time()
|
||
#
|
||
# try:
|
||
# response = client.chat.completions.create(
|
||
# messages=[
|
||
# {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."}
|
||
# ] + [
|
||
# {"role": "user", "content": prompt}
|
||
# ],
|
||
# model=MODEL_NAME,
|
||
# max_tokens=2048,
|
||
# temperature=0.1,
|
||
# response_format={"type": "json_object"}
|
||
# )
|
||
#
|
||
# content = response.choices[0].message.content
|
||
#
|
||
# # Parse and validate JSON
|
||
# try:
|
||
# parsed = json.loads(content)
|
||
# except json.JSONDecodeError as e:
|
||
# print(f"⚠️ JSON parsing failed: {e}")
|
||
# print("Raw response:", content[:500])
|
||
# raise ValueError("Model did not return valid JSON")
|
||
#
|
||
# # Enforce required keys
|
||
# if "certainty_percent" not in parsed:
|
||
# print("⚠️ Missing 'certainty_percent' in output! Force-adding fallback.")
|
||
# parsed["certainty_percent"] = 0
|
||
# elif not isinstance(parsed["certainty_percent"], (int, float)):
|
||
# parsed["certainty_percent"] = int(parsed["certainty_percent"])
|
||
#
|
||
# # Clamp certainty to [0, 100]
|
||
# pct = parsed["certainty_percent"]
|
||
# parsed["certainty_percent"] = max(0, min(100, int(pct)))
|
||
#
|
||
# # Enforce EDSS rules
|
||
# if not parsed.get("klassifizierbar", False):
|
||
# if "EDSS" in parsed:
|
||
# del parsed["EDSS"]
|
||
# else:
|
||
# if "EDSS" not in parsed:
|
||
# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.")
|
||
# parsed["EDSS"] = 7.0
|
||
#
|
||
# inference_time = time.time() - start_time
|
||
#
|
||
# return {
|
||
# "success": True,
|
||
# "result": parsed,
|
||
# "inference_time_sec": inference_time
|
||
# }
|
||
#
|
||
# except Exception as e:
|
||
# print(f"❌ Inference error: {e}")
|
||
# return {
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "inference_time_sec": -1,
|
||
# "result": None
|
||
# }
|
||
#
|
||
## === BUILD PATIENT TEXT ===
|
||
#def build_patient_text(row):
|
||
# return (
|
||
# str(row.get("T_Zusammenfassung", "")) + "\n" +
|
||
# str(row.get("Diagnosen", "")) + "\n" +
|
||
# str(row.get("T_KlinBef", "")) + "\n" +
|
||
# str(row.get("T_Befunde", ""))
|
||
# )
|
||
#
|
||
## === MAIN LOOP (NEW: MULTI-ITERATION) ===
|
||
#if __name__ == "__main__":
|
||
# # Load data ONCE (to avoid repeated I/O overhead)
|
||
# df = pd.read_csv(INPUT_CSV, sep=';')
|
||
# total_rows = len(df)
|
||
# print(f"Loaded {total_rows} patient records.")
|
||
#
|
||
# for iteration in range(1, NUM_ITERATIONS + 1):
|
||
# print(f"\n{'='*60}")
|
||
# print(f"🔄 ITERATION {iteration}/{NUM_ITERATIONS}")
|
||
# print(f"{'='*60}")
|
||
#
|
||
# iteration_results = []
|
||
# start_iter = time.time()
|
||
#
|
||
# for idx, row in df.iterrows():
|
||
# print(f"\rRow {idx+1}/{total_rows} | Iter {iteration}", end='', flush=True)
|
||
# try:
|
||
# patient_text = build_patient_text(row)
|
||
# result = run_inference(patient_text)
|
||
#
|
||
# # Attach metadata
|
||
# if result["success"]:
|
||
# res = result["result"].copy() # avoid mutation
|
||
# res["iteration"] = iteration
|
||
# res["unique_id"] = row.get("unique_id", f"row_{idx}")
|
||
# res["MedDatum"] = row.get("MedDatum", None)
|
||
# result["result"] = res
|
||
#
|
||
# else:
|
||
# result["iteration"] = iteration
|
||
# result["unique_id"] = row.get("unique_id", f"row_{idx}")
|
||
# result["MedDatum"] = row.get("MedDatum", None)
|
||
#
|
||
# iteration_results.append(result)
|
||
#
|
||
# if result["success"]:
|
||
# res = result["result"]
|
||
# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A"
|
||
# print(f" ✅ EDSS={edss}, cert={res.get('certainty_percent', '?')}%")
|
||
# else:
|
||
# print(f" ❌ {result.get('error', 'Unknown')}")
|
||
#
|
||
# except Exception as e:
|
||
# print(f"\n⚠️ Row {idx} failed: {e}")
|
||
# iteration_results.append({
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "iteration": iteration,
|
||
# "unique_id": row.get("unique_id", f"row_{idx}"),
|
||
# "MedDatum": row.get("MedDatum", None),
|
||
# "result": None
|
||
# })
|
||
# if STOP_ON_FIRST_ERROR:
|
||
# break
|
||
#
|
||
# # Save per-iteration results
|
||
# timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
# output_path = INPUT_CSV.replace(".csv", f"_results_iter_{iteration}_{timestamp}.json")
|
||
# with open(output_path, 'w', encoding='utf-8') as f:
|
||
# json.dump(iteration_results, f, indent=2, ensure_ascii=False)
|
||
# print(f"\n✅ Iteration {iteration} complete. Saved to: {output_path}")
|
||
#
|
||
# elapsed = time.time() - start_iter
|
||
# print(f"⏱️ Iteration {iteration} took {elapsed:.1f}s ({elapsed/total_rows:.1f}s/row)")
|
||
#
|
||
# print(f"\n🎉 All {NUM_ITERATIONS} iterations completed!")
|
||
#
|
||
#
|
||
|
||
##
|
||
|
||
|
||
|
||
# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark
|
||
#
|
||
#import time
|
||
#import json
|
||
#import os
|
||
#import re
|
||
#import threading
|
||
#from datetime import datetime
|
||
#from pathlib import Path
|
||
#
|
||
#import pandas as pd
|
||
#from openai import OpenAI
|
||
#from dotenv import load_dotenv
|
||
#
|
||
#try:
|
||
# import psutil
|
||
#except ImportError:
|
||
# psutil = None
|
||
# print("⚠️ psutil is not installed. Resource metrics will be limited.")
|
||
# print("Install with: pip install psutil")
|
||
#
|
||
#
|
||
## =========================
|
||
## CONFIGURATION
|
||
## =========================
|
||
#
|
||
#load_dotenv()
|
||
#
|
||
#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||
#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
|
||
#
|
||
#MODEL_NAMES = [
|
||
# # "GPT-OSS-120B",
|
||
# "qwen3.6-27b",
|
||
# # "gemma-4-31B-it",
|
||
#]
|
||
#
|
||
#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
|
||
#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"
|
||
#
|
||
#RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark"
|
||
#
|
||
#NUM_ITERATIONS = 20
|
||
#STOP_ON_FIRST_ERROR = False
|
||
#
|
||
#MAX_TOKENS = 2048
|
||
#TEMPERATURE = 0.1
|
||
#
|
||
## Memory sampling interval during one inference call
|
||
#RESOURCE_SAMPLE_INTERVAL_SEC = 0.05
|
||
#
|
||
#
|
||
## =========================
|
||
## CLIENT
|
||
## =========================
|
||
#
|
||
#client = OpenAI(
|
||
# api_key=OPENAI_API_KEY,
|
||
# base_url=OPENAI_BASE_URL
|
||
#)
|
||
#
|
||
#
|
||
## =========================
|
||
## HELPERS
|
||
## =========================
|
||
#
|
||
#def safe_dir_name(name: str) -> str:
|
||
# """
|
||
# Convert model name to a filesystem-safe directory name.
|
||
# """
|
||
# name = str(name).strip()
|
||
# name = re.sub(r"[^\w\-.]+", "_", name)
|
||
# return name[:150]
|
||
#
|
||
#
|
||
#def now_timestamp() -> str:
|
||
# return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
#
|
||
#
|
||
#def get_process():
|
||
# if psutil is None:
|
||
# return None
|
||
# return psutil.Process(os.getpid())
|
||
#
|
||
#
|
||
#def get_memory_rss_mb(process=None):
|
||
# if psutil is None:
|
||
# return None
|
||
# if process is None:
|
||
# process = get_process()
|
||
# return process.memory_info().rss / (1024 * 1024)
|
||
#
|
||
#
|
||
#def get_cpu_times_sec(process=None):
|
||
# if psutil is None:
|
||
# return None
|
||
# if process is None:
|
||
# process = get_process()
|
||
# cpu_times = process.cpu_times()
|
||
# return cpu_times.user + cpu_times.system
|
||
#
|
||
#
|
||
#class ResourceSampler:
|
||
# """
|
||
# Samples process RSS while an inference call is running.
|
||
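# Useful for approximating the peak memory usage of this Python (client) process per request.
# NOTE: the sampled process is os.getpid(); memory used by the model server behind the API is not captured.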
|
||
# """
|
||
#
|
||
# def __init__(self, interval_sec=0.05):
|
||
# self.interval_sec = interval_sec
|
||
# self.process = get_process()
|
||
# self.running = False
|
||
# self.thread = None
|
||
# self.samples_mb = []
|
||
#
|
||
# def start(self):
|
||
# if psutil is None:
|
||
# return
|
||
#
|
||
# self.running = True
|
||
# self.samples_mb = []
|
||
# self.thread = threading.Thread(target=self._sample_loop, daemon=True)
|
||
# self.thread.start()
|
||
#
|
||
# def stop(self):
|
||
# if psutil is None:
|
||
# return
|
||
#
|
||
# self.running = False
|
||
# if self.thread is not None:
|
||
# self.thread.join(timeout=1.0)
|
||
#
|
||
# def _sample_loop(self):
|
||
# while self.running:
|
||
# try:
|
||
# rss_mb = get_memory_rss_mb(self.process)
|
||
# self.samples_mb.append(rss_mb)
|
||
# except Exception:
|
||
# pass
|
||
# time.sleep(self.interval_sec)
|
||
#
|
||
# @property
|
||
# def peak_rss_mb(self):
|
||
# if not self.samples_mb:
|
||
# return None
|
||
# return max(self.samples_mb)
|
||
#
|
||
#
|
||
## =========================
|
||
## READ INSTRUCTIONS
|
||
## =========================
|
||
#
|
||
#with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f:
|
||
# EDSS_INSTRUCTIONS = f.read().strip()
|
||
#
|
||
#
|
||
## =========================
|
||
## PROMPT
|
||
## =========================
|
||
#
|
||
#def build_prompt(patient_text):
|
||
# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren.
|
||
#
|
||
#### Deine Aufgabe:
|
||
#1. Analysiere den Patientenbericht und extrahiere:
|
||
# - Den Gesamt-EDSS-Score (0.0–10.0)
|
||
# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl)
|
||
#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein.
|
||
#
|
||
#### Struktur der JSON-Ausgabe (VERPFLICHTEND):
|
||
#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter.
|
||
#
|
||
#{{
|
||
# "reason": "Kernaussage zur EDSS-Begründung (max. 400 Zeichen, auf Deutsch).",
|
||
# "klassifizierbar": true/false,
|
||
# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true),
|
||
# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl),
|
||
# "subcategories": {{
|
||
# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0,
|
||
# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0
|
||
# }}
|
||
#}}
|
||
#
|
||
#### Regeln:
|
||
#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest.
|
||
#- **klassifizierbar**:
|
||
# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind.
|
||
# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist.
|
||
#- **EDSS**:
|
||
# - **VERPFLICHTEND**, wenn `klassifizierbar=true`.
|
||
# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`.
|
||
#- **certainty_percent**:
|
||
# - **Immer present** — Ganzzahl (0–100), basierend auf:
|
||
# - Klarheit und Vollständigkeit der Berichtsangaben,
|
||
# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz),
|
||
# - Konsistenz zwischen den Unterkategorien.
|
||
#- **subcategories**:
|
||
# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein.
|
||
# - Jeder Wert ist entweder:
|
||
# - `null` (wenn keine ausreichende Information vorliegt), **oder**
|
||
# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0).
|
||
# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab.
|
||
# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`.
|
||
#
|
||
#### EDSS-Bewertungsrichtlinien:
|
||
#{EDSS_INSTRUCTIONS}
|
||
#
|
||
#Patientenbericht:
|
||
#{patient_text}
|
||
#'''
|
||
#
|
||
#
|
||
## =========================
|
||
## VALIDATION / NORMALIZATION
|
||
## =========================
|
||
#
|
||
#def normalize_model_output(parsed):
|
||
# """
|
||
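# Validate and normalize the parsed model output in place and return it.
# Raises ValueError when the output is not a JSON object. Defaults a missing
# 'certainty_percent' to 0 and clamps it to an integer in [0, 100]; defaults a
# missing 'klassifizierbar' to False; drops 'EDSS' when not classifiable and
# inserts the fallback 7.0 when classifiable but missing; ensures all eight
# 'subcategories' keys exist (missing ones become None).
# NOTE: the per-subcategory upper bounds listed below are not enforced here.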
|
||
# """
|
||
#
|
||
# if not isinstance(parsed, dict):
|
||
# raise ValueError("Parsed model output is not a JSON object")
|
||
#
|
||
# if "certainty_percent" not in parsed:
|
||
# print("⚠️ Missing 'certainty_percent' in output. Force-adding fallback.")
|
||
# parsed["certainty_percent"] = 0
|
||
# elif not isinstance(parsed["certainty_percent"], (int, float)):
|
||
# parsed["certainty_percent"] = int(parsed["certainty_percent"])
|
||
#
|
||
# parsed["certainty_percent"] = max(0, min(100, int(parsed["certainty_percent"])))
|
||
#
|
||
# if "klassifizierbar" not in parsed:
|
||
# parsed["klassifizierbar"] = False
|
||
#
|
||
# if not parsed.get("klassifizierbar", False):
|
||
# parsed.pop("EDSS", None)
|
||
# else:
|
||
# if "EDSS" not in parsed:
|
||
# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.")
|
||
# parsed["EDSS"] = 7.0
|
||
#
|
||
# required_subcategories = {
|
||
# "VISUAL_OPTIC_FUNCTIONS": 6.0,
|
||
# "BRAINSTEM_FUNCTIONS": 6.0,
|
||
# "PYRAMIDAL_FUNCTIONS": 6.0,
|
||
# "CEREBELLAR_FUNCTIONS": 6.0,
|
||
# "SENSORY_FUNCTIONS": 6.0,
|
||
# "BOWEL_AND_BLADDER_FUNCTIONS": 6.0,
|
||
# "CEREBRAL_FUNCTIONS": 6.0,
|
||
# "AMBULATION": 10.0,
|
||
# }
|
||
#
|
||
# if "subcategories" not in parsed or not isinstance(parsed["subcategories"], dict):
|
||
# parsed["subcategories"] = {}
|
||
#
|
||
# for key in required_subcategories:
|
||
# if key not in parsed["subcategories"]:
|
||
# parsed["subcategories"][key] = None
|
||
#
|
||
# return parsed
|
||
#
|
||
#
|
||
## =========================
|
||
## INFERENCE FUNCTION
|
||
## =========================
|
||
#
|
||
#def run_inference(patient_text, model_name):
|
||
# prompt = build_prompt(patient_text)
|
||
#
|
||
# process = get_process()
|
||
# sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC)
|
||
#
|
||
# wall_start = time.perf_counter()
|
||
# cpu_start = get_cpu_times_sec(process)
|
||
# rss_start_mb = get_memory_rss_mb(process)
|
||
#
|
||
# sampler.start()
|
||
#
|
||
# try:
|
||
# response = client.chat.completions.create(
|
||
# messages=[
|
||
# {
|
||
# "role": "system",
|
||
# "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."
|
||
# },
|
||
# {
|
||
# "role": "user",
|
||
# "content": prompt
|
||
# }
|
||
# ],
|
||
# model=model_name,
|
||
# max_tokens=MAX_TOKENS,
|
||
# temperature=TEMPERATURE,
|
||
# response_format={"type": "json_object"}
|
||
# )
|
||
#
|
||
# content = response.choices[0].message.content
|
||
#
|
||
# try:
|
||
# parsed = json.loads(content)
|
||
# except json.JSONDecodeError as e:
|
||
# print(f"⚠️ JSON parsing failed: {e}")
|
||
# print("Raw response:", content[:500])
|
||
# raise ValueError("Model did not return valid JSON")
|
||
#
|
||
# parsed = normalize_model_output(parsed)
|
||
#
|
||
# usage = getattr(response, "usage", None)
|
||
#
|
||
# if usage is not None:
|
||
# prompt_tokens = getattr(usage, "prompt_tokens", None)
|
||
# completion_tokens = getattr(usage, "completion_tokens", None)
|
||
# total_tokens = getattr(usage, "total_tokens", None)
|
||
# else:
|
||
# prompt_tokens = None
|
||
# completion_tokens = None
|
||
# total_tokens = None
|
||
#
|
||
# success = True
|
||
# error = None
|
||
# result = parsed
|
||
#
|
||
# except Exception as e:
|
||
# print(f"❌ Inference error: {e}")
|
||
#
|
||
# success = False
|
||
# error = str(e)
|
||
# result = None
|
||
# prompt_tokens = None
|
||
# completion_tokens = None
|
||
# total_tokens = None
|
||
#
|
||
# finally:
|
||
# sampler.stop()
|
||
#
|
||
# wall_end = time.perf_counter()
|
||
# cpu_end = get_cpu_times_sec(process)
|
||
# rss_end_mb = get_memory_rss_mb(process)
|
||
#
|
||
# wall_time_sec = wall_end - wall_start
|
||
#
|
||
# if cpu_start is not None and cpu_end is not None:
|
||
# process_cpu_time_sec = cpu_end - cpu_start
|
||
# else:
|
||
# process_cpu_time_sec = None
|
||
#
|
||
# if rss_start_mb is not None and rss_end_mb is not None:
|
||
# rss_delta_mb = rss_end_mb - rss_start_mb
|
||
# else:
|
||
# rss_delta_mb = None
|
||
#
|
||
# return {
|
||
# "success": success,
|
||
# "error": error,
|
||
# "result": result,
|
||
#
|
||
# "model": model_name,
|
||
#
|
||
# # Main inference timing
|
||
# "inference_time_sec": wall_time_sec,
|
||
#
|
||
# # Resource metrics
|
||
# "process_cpu_time_sec": process_cpu_time_sec,
|
||
# "rss_before_mb": rss_start_mb,
|
||
# "rss_after_mb": rss_end_mb,
|
||
# "rss_delta_mb": rss_delta_mb,
|
||
# "peak_rss_mb": sampler.peak_rss_mb,
|
||
#
|
||
# # Token metrics, when available
|
||
# "prompt_tokens": prompt_tokens,
|
||
# "completion_tokens": completion_tokens,
|
||
# "total_tokens": total_tokens,
|
||
# }
|
||
#
|
||
#
|
||
## =========================
|
||
## BUILD PATIENT TEXT
|
||
## =========================
|
||
#
|
||
#def build_patient_text(row):
|
||
# return (
|
||
# str(row.get("T_Zusammenfassung", "")) + "\n" +
|
||
# str(row.get("Diagnosen", "")) + "\n" +
|
||
# str(row.get("T_KlinBef", "")) + "\n" +
|
||
# str(row.get("T_Befunde", ""))
|
||
# )
|
||
#
|
||
#
|
||
## =========================
|
||
## FLATTEN RESULTS FOR CSV
|
||
## =========================
|
||
#
|
||
#def flatten_result(record):
|
||
# """
|
||
# Converts one result record to a flat row for CSV export.
|
||
# """
|
||
#
|
||
# flat = {
|
||
# "model": record.get("model"),
|
||
# "iteration": record.get("iteration"),
|
||
# "row_index": record.get("row_index"),
|
||
# "unique_id": record.get("unique_id"),
|
||
# "MedDatum": record.get("MedDatum"),
|
||
#
|
||
# "success": record.get("success"),
|
||
# "error": record.get("error"),
|
||
#
|
||
# "inference_time_sec": record.get("inference_time_sec"),
|
||
# "process_cpu_time_sec": record.get("process_cpu_time_sec"),
|
||
# "rss_before_mb": record.get("rss_before_mb"),
|
||
# "rss_after_mb": record.get("rss_after_mb"),
|
||
# "rss_delta_mb": record.get("rss_delta_mb"),
|
||
# "peak_rss_mb": record.get("peak_rss_mb"),
|
||
#
|
||
# "prompt_tokens": record.get("prompt_tokens"),
|
||
# "completion_tokens": record.get("completion_tokens"),
|
||
# "total_tokens": record.get("total_tokens"),
|
||
# }
|
||
#
|
||
# result = record.get("result")
|
||
#
|
||
# if isinstance(result, dict):
|
||
# flat["reason"] = result.get("reason")
|
||
# flat["klassifizierbar"] = result.get("klassifizierbar")
|
||
# flat["EDSS"] = result.get("EDSS")
|
||
# flat["certainty_percent"] = result.get("certainty_percent")
|
||
#
|
||
# subcats = result.get("subcategories", {})
|
||
# if isinstance(subcats, dict):
|
||
# for key, value in subcats.items():
|
||
# flat[f"subcat_{key}"] = value
|
||
#
|
||
# return flat
|
||
#
|
||
#
|
||
#def summarize_records(records):
|
||
# """
|
||
# Creates summary statistics for one model over all iterations.
|
||
# """
|
||
#
|
||
# df = pd.DataFrame([flatten_result(r) for r in records])
|
||
#
|
||
# if df.empty:
|
||
# return pd.DataFrame()
|
||
#
|
||
# summary = {
|
||
# "model": df["model"].iloc[0] if "model" in df.columns else None,
|
||
# "n_records": len(df),
|
||
# "n_success": int(df["success"].sum()) if "success" in df.columns else None,
|
||
# "n_failed": int((~df["success"]).sum()) if "success" in df.columns else None,
|
||
# "success_rate": float(df["success"].mean()) if "success" in df.columns else None,
|
||
# }
|
||
#
|
||
# numeric_cols = [
|
||
# "inference_time_sec",
|
||
# "process_cpu_time_sec",
|
||
# "rss_delta_mb",
|
||
# "peak_rss_mb",
|
||
# "prompt_tokens",
|
||
# "completion_tokens",
|
||
# "total_tokens",
|
||
# "certainty_percent",
|
||
# "EDSS",
|
||
# ]
|
||
#
|
||
# for col in numeric_cols:
|
||
# if col in df.columns:
|
||
# values = pd.to_numeric(df[col], errors="coerce")
|
||
# summary[f"{col}_mean"] = values.mean()
|
||
# summary[f"{col}_median"] = values.median()
|
||
# summary[f"{col}_std"] = values.std()
|
||
# summary[f"{col}_min"] = values.min()
|
||
# summary[f"{col}_max"] = values.max()
|
||
#
|
||
# return pd.DataFrame([summary])
|
||
#
|
||
#
|
||
## =========================
|
||
## MAIN LOOP
|
||
## =========================
|
||
#
|
||
#if __name__ == "__main__":
|
||
#
|
||
# run_timestamp = now_timestamp()
|
||
#
|
||
# results_root = Path(RESULTS_ROOT)
|
||
# results_root.mkdir(parents=True, exist_ok=True)
|
||
#
|
||
# run_root = results_root / f"run_{run_timestamp}"
|
||
# run_root.mkdir(parents=True, exist_ok=True)
|
||
#
|
||
# print(f"Results root: {run_root}")
|
||
#
|
||
# df = pd.read_csv(INPUT_CSV, sep=";")
|
||
# total_rows = len(df)
|
||
#
|
||
# print(f"Loaded {total_rows} patient records.")
|
||
# print(f"Models: {MODEL_NAMES}")
|
||
# print(f"Iterations per model: {NUM_ITERATIONS}")
|
||
#
|
||
# all_model_summaries = []
|
||
#
|
||
# for model_name in MODEL_NAMES:
|
||
# safe_model = safe_dir_name(model_name)
|
||
# model_dir = run_root / safe_model
|
||
# model_dir.mkdir(parents=True, exist_ok=True)
|
||
#
|
||
# print(f"\n{'#' * 80}")
|
||
# print(f"MODEL: {model_name}")
|
||
# print(f"Saving to: {model_dir}")
|
||
# print(f"{'#' * 80}")
|
||
#
|
||
# model_records = []
|
||
# model_start = time.perf_counter()
|
||
#
|
||
# for iteration in range(1, NUM_ITERATIONS + 1):
|
||
# print(f"\n{'=' * 60}")
|
||
# print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}")
|
||
# print(f"{'=' * 60}")
|
||
#
|
||
# iteration_results = []
|
||
# iteration_start = time.perf_counter()
|
||
#
|
||
# for idx, row in df.iterrows():
|
||
# print(
|
||
# f"\rModel={model_name} | Row {idx + 1}/{total_rows} | Iter {iteration}",
|
||
# end="",
|
||
# flush=True
|
||
# )
|
||
#
|
||
# try:
|
||
# patient_text = build_patient_text(row)
|
||
# record = run_inference(patient_text, model_name=model_name)
|
||
#
|
||
# record["iteration"] = iteration
|
||
# record["row_index"] = int(idx)
|
||
# record["unique_id"] = row.get("unique_id", f"row_{idx}")
|
||
# record["MedDatum"] = row.get("MedDatum", None)
|
||
#
|
||
# iteration_results.append(record)
|
||
# model_records.append(record)
|
||
#
|
||
# if record["success"]:
|
||
# res = record["result"]
|
||
# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A"
|
||
# print(
|
||
# f" ✅ EDSS={edss}, "
|
||
# f"cert={res.get('certainty_percent', '?')}%, "
|
||
# f"time={record['inference_time_sec']:.2f}s"
|
||
# )
|
||
# else:
|
||
# print(f" ❌ {record.get('error', 'Unknown error')}")
|
||
#
|
||
# except Exception as e:
|
||
# print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}")
|
||
#
|
||
# fallback_record = {
|
||
# "success": False,
|
||
# "error": str(e),
|
||
# "result": None,
|
||
#
|
||
# "model": model_name,
|
||
# "iteration": iteration,
|
||
# "row_index": int(idx),
|
||
# "unique_id": row.get("unique_id", f"row_{idx}"),
|
||
# "MedDatum": row.get("MedDatum", None),
|
||
#
|
||
# "inference_time_sec": None,
|
||
# "process_cpu_time_sec": None,
|
||
# "rss_before_mb": None,
|
||
# "rss_after_mb": None,
|
||
# "rss_delta_mb": None,
|
||
# "peak_rss_mb": None,
|
||
#
|
||
# "prompt_tokens": None,
|
||
# "completion_tokens": None,
|
||
# "total_tokens": None,
|
||
# }
|
||
#
|
||
# iteration_results.append(fallback_record)
|
||
# model_records.append(fallback_record)
|
||
#
|
||
# if STOP_ON_FIRST_ERROR:
|
||
# break
|
||
#
|
||
# iteration_elapsed = time.perf_counter() - iteration_start
|
||
#
|
||
# # Save per-iteration JSON
|
||
# iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json"
|
||
# with open(iter_json_path, "w", encoding="utf-8") as f:
|
||
# json.dump(iteration_results, f, indent=2, ensure_ascii=False)
|
||
#
|
||
# # Save per-iteration CSV
|
||
# iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv"
|
||
# iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results])
|
||
# iter_flat_df.to_csv(iter_csv_path, index=False)
|
||
#
|
||
# print(f"\n✅ Iteration {iteration} complete.")
|
||
# print(f"JSON saved to: {iter_json_path}")
|
||
# print(f"CSV saved to: {iter_csv_path}")
|
||
# print(
|
||
# f"⏱️ Iteration time: {iteration_elapsed:.1f}s "
|
||
# f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)"
|
||
# )
|
||
#
|
||
# model_elapsed = time.perf_counter() - model_start
|
||
#
|
||
# # Save all records for this model
|
||
# model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json"
|
||
# with open(model_json_path, "w", encoding="utf-8") as f:
|
||
# json.dump(model_records, f, indent=2, ensure_ascii=False)
|
||
#
|
||
# model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv"
|
||
# model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records])
|
||
# model_flat_df.to_csv(model_csv_path, index=False)
|
||
#
|
||
# # Save model summary
|
||
# model_summary_df = summarize_records(model_records)
|
||
# model_summary_df["model_total_wall_time_sec"] = model_elapsed
|
||
# model_summary_df["model_total_wall_time_min"] = model_elapsed / 60
|
||
#
|
||
# model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv"
|
||
# model_summary_df.to_csv(model_summary_path, index=False)
|
||
#
|
||
# all_model_summaries.append(model_summary_df)
|
||
#
|
||
# print(f"\n🎉 Model completed: {model_name}")
|
||
# print(f"All JSON: {model_json_path}")
|
||
# print(f"All CSV: {model_csv_path}")
|
||
# print(f"Summary: {model_summary_path}")
|
||
# print(f"Total model time: {model_elapsed / 60:.2f} min")
|
||
#
|
||
# # Save combined model summaries
|
||
# if all_model_summaries:
|
||
# combined_summary_df = pd.concat(all_model_summaries, ignore_index=True)
|
||
# combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv"
|
||
# combined_summary_df.to_csv(combined_summary_path, index=False)
|
||
#
|
||
# print(f"\n📊 Combined summary saved to: {combined_summary_path}")
|
||
#
|
||
# print(f"\n🎉 All models and all iterations completed!")
|
||
#
|
||
##
|
||
|
||
|
||
|
||
|
||
|
||
|
||
# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark
|
||
|
||
import time
|
||
import json
|
||
import os
|
||
import re
|
||
import threading
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import pandas as pd
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
|
||
try:
|
||
import psutil
|
||
except ImportError:
|
||
psutil = None
|
||
print("⚠️ psutil is not installed. Resource metrics will be limited.")
|
||
print("Install with: pip install psutil")
|
||
|
||
|
||
# =========================
|
||
# CONFIGURATION
|
||
# =========================
|
||
|
||
# Read a local .env file (if any) before the environment lookups below.
load_dotenv()

# Credentials and base URL of the OpenAI-compatible endpoint. Both are None
# when the corresponding environment variable is not set.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")

# One dict per model to benchmark. Keys read by make_chat_completion:
#   model_name           - model identifier sent with the request (required)
#   use_response_format  - when true, request {"type": "json_object"} output
#   temperature          - falls back to TEMPERATURE when absent
#   max_tokens           - falls back to MAX_TOKENS when absent
#   extra_body           - forwarded as `extra_body` unless it is None
MODEL_CONFIGS = [
    # {
    #     "model_name": "qwen3.6-35b-a3b",
    #     "use_response_format": False,
    #     "temperature": 0.0,
    #     "max_tokens": 4096,
    #
    #     # If your backend is vLLM / Qwen chat-template compatible,
    #     # this may reduce long hidden reasoning and JSON truncation.
    #     # If your server errors because of extra_body, set this to None.
    #     "extra_body": {
    #         "chat_template_kwargs": {
    #             "enable_thinking": False
    #         }
    #     },
    # },
    {
        "model_name": "gemma-4-31B-it",
        "use_response_format": False,
        "temperature": 0.0,
        "max_tokens": 4096,
        "extra_body": None,
    },
    # {
    #     "model_name": "GPT-OSS-120B",
    #     "use_response_format": True,
    #     "temperature": 0.0,
    #     "max_tokens": 4096,
    #     "extra_body": None,
    # },
]

# Input data (semicolon-separated CSV) and the guideline text embedded in
# every prompt.
INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"

# Parent directory; each run creates a `run_<timestamp>` folder inside it.
RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark"

# Number of full passes over the input rows per model.
NUM_ITERATIONS = 10
# NOTE(review): not referenced in the part of the main loop reviewed here;
# presumably aborts the row loop after the first failure - confirm.
STOP_ON_FIRST_ERROR = False

# For testing, set to e.g. 2.
# For full run, set to None.
# NOTE(review): currently 2, so the main block keeps only the first two CSV
# rows (df.head(MAX_ROWS)); switch to None before a real benchmark run.
MAX_ROWS = 2
# MAX_ROWS = 2

# Defaults used when a model config omits "max_tokens" / "temperature".
MAX_TOKENS = 4096
TEMPERATURE = 0.0

# Polling interval handed to ResourceSampler inside run_inference.
RESOURCE_SAMPLE_INTERVAL_SEC = 0.05

# NOTE(review): usage not visible in the reviewed section; presumably the
# incremental-save cadence of the main loop - confirm.
SAVE_EVERY_N_ROWS = 1

# Retries for invalid JSON / truncated JSON.
# run_inference makes at most MAX_JSON_RETRIES + 1 attempts per row and sleeps
# RETRY_SLEEP_SEC seconds between attempts.
MAX_JSON_RETRIES = 2
RETRY_SLEEP_SEC = 2
|
||
|
||
|
||
# =========================
|
||
# CLIENT
|
||
# =========================
|
||
|
||
# Module-level client for the OpenAI-compatible endpoint; make_chat_completion
# sends every request through it.
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_BASE_URL
)
|
||
|
||
|
||
# =========================
|
||
# HELPERS
|
||
# =========================
|
||
|
||
def safe_dir_name(name: str) -> str:
    """Turn an arbitrary label (e.g. a model name) into a safe directory name.

    Every run of characters other than word characters, ``-`` and ``.`` is
    replaced by a single underscore and the result is capped at 150
    characters.

    Args:
        name: Label to sanitise; converted with ``str()`` first.

    Returns:
        The sanitised name, or ``"unnamed"`` when nothing usable is left.
    """
    cleaned = re.sub(r"[^\w\-.]+", "_", str(name).strip())[:150]

    # "", "." and ".." are not usable as a sub-directory: joined onto a parent
    # path they resolve to the parent itself or to its parent.
    if not cleaned.strip("."):
        return "unnamed"

    return cleaned
|
||
|
||
|
||
def now_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    current = datetime.now()
    return f"{current:%Y%m%d_%H%M%S}"
|
||
|
||
|
||
def get_process():
    """Return a ``psutil.Process`` for this interpreter, or None without psutil."""
    return None if psutil is None else psutil.Process(os.getpid())
|
||
|
||
|
||
def get_memory_rss_mb(process=None):
    """Return the resident set size of ``process`` in MiB.

    Args:
        process: A psutil process handle; the current process is looked up
            via ``get_process()`` when omitted.

    Returns:
        RSS in MiB as a float, or None when psutil is not installed.
    """
    if psutil is None:
        return None

    target = process if process is not None else get_process()
    rss_bytes = target.memory_info().rss
    return rss_bytes / (1024 * 1024)
|
||
|
||
|
||
def get_cpu_times_sec(process=None):
    """Return user + system CPU seconds consumed so far by ``process``.

    Args:
        process: A psutil process handle; the current process is looked up
            via ``get_process()`` when omitted.

    Returns:
        Combined CPU time in seconds, or None when psutil is not installed.
    """
    if psutil is None:
        return None

    target = get_process() if process is None else process
    times = target.cpu_times()
    return times.user + times.system
|
||
|
||
|
||
class ResourceSampler:
    """Record this process's RSS on a background thread between start/stop.

    Every method is a no-op when psutil is not installed; ``peak_rss_mb`` is
    then None.
    """

    def __init__(self, interval_sec=0.05):
        # Pause between two samples, in seconds.
        self.interval_sec = interval_sec
        # psutil handle of the current process (None without psutil).
        self.process = get_process()
        # True between start() and stop(); kept for callers that inspect it.
        self.running = False
        self.thread = None
        # RSS samples in MiB collected since the last start().
        self.samples_mb = []
        # Signals the sampling thread to exit. An Event (rather than a plain
        # flag plus time.sleep) lets stop() wake the thread immediately, so
        # stopping does not add up to one interval to the caller's timing.
        self._stop_event = threading.Event()

    def start(self):
        """Discard old samples and begin sampling in a daemon thread."""
        if psutil is None:
            return

        self._stop_event.clear()
        self.running = True
        self.samples_mb = []
        self.thread = threading.Thread(target=self._sample_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Ask the sampling thread to finish and wait up to one second for it."""
        if psutil is None:
            return

        self.running = False
        self._stop_event.set()
        if self.thread is not None:
            self.thread.join(timeout=1.0)

    def _sample_loop(self):
        """Thread body: append one RSS sample per interval until stopped."""
        while not self._stop_event.is_set():
            try:
                rss_mb = get_memory_rss_mb(self.process)
            except Exception:
                # Sampling is best effort: a failed reading must never break
                # the measured call, so the sample is simply skipped.
                rss_mb = None

            if rss_mb is not None:
                self.samples_mb.append(rss_mb)

            # Sleeps for interval_sec but returns at once when stop() is called.
            self._stop_event.wait(self.interval_sec)

    @property
    def peak_rss_mb(self):
        """Largest RSS sample in MiB, or None when nothing was sampled."""
        if not self.samples_mb:
            return None
        return max(self.samples_mb)
|
||
|
||
|
||
# =========================
|
||
# JSON EXTRACTION
|
||
# =========================
|
||
|
||
def extract_json_from_text(text):
    """Extract the model's JSON object from raw response text.

    Strategy:
      1. Remove Markdown code-fence markers and try to parse the whole text.
      2. Otherwise scan for brace-balanced ``{...}`` spans (ignoring braces
         inside JSON strings), drop spans that look like a copied schema or
         placeholder, and parse the rest.
      3. Prefer the last parsed object that has ``klassifizierbar``,
         ``certainty_percent`` and ``subcategories``; else the last parsed
         object.

    Args:
        text: Raw model output; anything other than None is passed through
            ``str()``.

    Returns:
        The extracted JSON object as a dict.

    Raises:
        ValueError: If the text is None or empty, looks like truncated JSON,
            or contains no parseable JSON object.
    """
    if text is None:
        raise ValueError("Model returned empty content: message.content is None")

    text = str(text).strip()

    if not text:
        raise ValueError("Model returned empty content")

    text = (
        text.replace("```json", "")
        .replace("```JSON", "")
        .replace("```Json", "")
        .replace("```", "")
        .strip()
    )

    # Fast path: the whole response is one JSON object.
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        return parsed

    # Slow path: collect top-level brace-balanced spans.
    candidates = []
    depth = 0
    start_idx = None
    in_string = False
    escape = False

    for i, ch in enumerate(text):
        if escape:
            escape = False
            # Inside a string a backslash escapes any character. Outside a
            # string it may neutralise a quote, but it must not hide a brace:
            # otherwise output such as `\{"a": 1}` loses its opening brace and
            # the object is never found.
            if in_string or ch not in "{}":
                continue

        if ch == "\\":
            escape = True
            continue

        if ch == '"':
            in_string = not in_string
            continue

        if in_string:
            continue

        if ch == "{":
            if depth == 0:
                start_idx = i
            depth += 1
        elif ch == "}" and depth > 0:
            depth -= 1
            if depth == 0 and start_idx is not None:
                candidates.append(text[start_idx:i + 1])
                start_idx = None

    # Fragments that indicate the model echoed a schema/placeholder instead of
    # producing real values.
    invalid_markers = (
        "true/false",
        "null or",
        "oder zahl",
        "0.0-6.0",
        "0.0-10.0",
        "zahl zwischen",
        "...",
    )

    valid_objects = []

    for candidate in candidates:
        candidate = candidate.strip()
        lowered = candidate.lower()

        if any(marker in lowered for marker in invalid_markers):
            continue

        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue

        if isinstance(parsed, dict):
            valid_objects.append(parsed)

    # Prefer the last object that carries the expected answer keys.
    for obj in reversed(valid_objects):
        if (
            "klassifizierbar" in obj
            and "certainty_percent" in obj
            and "subcategories" in obj
        ):
            return obj

    if valid_objects:
        return valid_objects[-1]

    # Opening brace without a closing one usually means the token limit cut
    # the answer off.
    if text.startswith("{") and not text.endswith("}"):
        raise ValueError(
            "Model output looks like truncated JSON. "
            f"Raw output starts with: {text[:1000]}"
        )

    raise ValueError(
        "No valid JSON object found in model output. "
        f"Raw output starts with: {text[:1000]}"
    )
|
||
|
||
|
||
def _join_content_blocks(blocks):
    """Join the ``text``/``content`` entries of dict blocks; None if there are none."""
    parts = []
    for block in blocks:
        if isinstance(block, dict):
            if "text" in block:
                parts.append(str(block["text"]))
            elif "content" in block:
                parts.append(str(block["content"]))
    if parts:
        return "\n".join(parts).strip()
    return None


def extract_message_content(message):
    """Return the textual payload of a chat-completion message.

    Lookup order:
      1. ``message.content`` (a list of content blocks is joined into one
         string).
      2. The dumped message dict (``model_dump()`` or ``dict(message)``),
         first truthy value among ``content``, ``reasoning_content``,
         ``reasoning``, ``text``, ``output_text``.

    Args:
        message: The message object of a completion choice.

    Returns:
        The content (normally a str), or None when nothing usable is found.
    """
    raw_content = getattr(message, "content", None)

    if isinstance(raw_content, list):
        # List-style content must be flattened here; returning the list would
        # later be stringified into its Python repr instead of the text.
        joined = _join_content_blocks(raw_content)
        if joined is not None:
            return joined
    elif raw_content is not None:
        return raw_content

    try:
        msg_dict = message.model_dump()
    except Exception:
        # Not a pydantic-style object; try a plain mapping conversion.
        try:
            msg_dict = dict(message)
        except Exception:
            msg_dict = None

    if not isinstance(msg_dict, dict):
        return None

    for key in ("content", "reasoning_content", "reasoning", "text", "output_text"):
        value = msg_dict.get(key)
        if isinstance(value, list):
            value = _join_content_blocks(value)
        if value:
            return value

    return None
|
||
|
||
|
||
# =========================
|
||
# READ INSTRUCTIONS
|
||
# =========================
|
||
|
||
# Load the EDSS guideline text once at import time; build_prompt inserts it
# verbatim into every prompt. Raises at import if the file is missing.
with open(EDSS_INSTRUCTIONS_PATH, "r", encoding="utf-8") as f:
    EDSS_INSTRUCTIONS = f.read().strip()
|
||
|
||
|
||
# =========================
|
||
# PROMPT
|
||
# =========================
|
||
|
||
def build_prompt(patient_text) -> str:
    """Build the German user prompt for one patient report.

    The prompt lists the required JSON keys and value ranges, shows one valid
    example object, then appends the module-level ``EDSS_INSTRUCTIONS`` and
    the given ``patient_text``. Doubled braces in the template are literal
    braces of the example JSON.

    Args:
        patient_text: Report text inserted under "Patientenbericht:".

    Returns:
        The complete prompt string.
    """
    return f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten.

Extrahiere:
1. Gesamt-EDSS-Score von 0.0 bis 10.0
2. Alle 8 EDSS-Unterkategorien
3. Sicherheit als Ganzzahl von 0 bis 100

Antworte ausschließlich mit EINEM validen JSON-Objekt.
Kein Markdown.
Keine Code-Fences.
Kein Text vor oder nach JSON.
Keine Platzhalter.
Kopiere kein Schema.

Das JSON muss exakt diese Schlüssel enthalten:
- reason
- klassifizierbar
- EDSS
- certainty_percent
- subcategories

Die subcategories müssen exakt diese 8 Schlüssel enthalten:
- VISUAL_OPTIC_FUNCTIONS
- BRAINSTEM_FUNCTIONS
- PYRAMIDAL_FUNCTIONS
- CEREBELLAR_FUNCTIONS
- SENSORY_FUNCTIONS
- BOWEL_AND_BLADDER_FUNCTIONS
- CEREBRAL_FUNCTIONS
- AMBULATION

Werte:
- klassifizierbar: true oder false
- EDSS: Zahl von 0.0 bis 10.0 oder null
- certainty_percent: Ganzzahl von 0 bis 100
- Unterkategorien: Zahl oder null
- VISUAL_OPTIC_FUNCTIONS maximal 6.0
- BRAINSTEM_FUNCTIONS maximal 6.0
- PYRAMIDAL_FUNCTIONS maximal 6.0
- CEREBELLAR_FUNCTIONS maximal 6.0
- SENSORY_FUNCTIONS maximal 6.0
- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0
- CEREBRAL_FUNCTIONS maximal 6.0
- AMBULATION maximal 10.0
- reason: maximal 250 Zeichen, Deutsch

Wenn klassifizierbar false ist, setze EDSS auf null.

Valide Beispielausgabe:
{{
"reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.",
"klassifizierbar": true,
"EDSS": 2.0,
"certainty_percent": 90,
"subcategories": {{
"VISUAL_OPTIC_FUNCTIONS": null,
"BRAINSTEM_FUNCTIONS": null,
"PYRAMIDAL_FUNCTIONS": 1.0,
"CEREBELLAR_FUNCTIONS": 1.0,
"SENSORY_FUNCTIONS": 1.0,
"BOWEL_AND_BLADDER_FUNCTIONS": null,
"CEREBRAL_FUNCTIONS": null,
"AMBULATION": 0.0
}}
}}

EDSS-Bewertungsrichtlinien:
{EDSS_INSTRUCTIONS}

Patientenbericht:
{patient_text}

Gib ausschließlich das finale JSON-Objekt zurück.
'''
|
||
|
||
|
||
# =========================
|
||
# VALIDATION / NORMALIZATION
|
||
# =========================
|
||
|
||
def _to_finite_float(value):
    """Return ``value`` as a finite float, or None if that is not possible."""
    import math  # local import: the module's import block does not provide math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None

    # json.loads accepts NaN/Infinity literals; neither is a usable score.
    return number if math.isfinite(number) else None


def normalize_model_output(parsed):
    """Validate and normalise the JSON object returned by the model, in place.

    Guarantees after the call:
      * ``certainty_percent``: int clamped to 0..100 (0 if missing/invalid).
      * ``klassifizierbar``: bool; the strings "true", "1", "yes", "ja"
        (case-insensitive) count as True, anything else as False.
      * ``EDSS``: None when not classifiable, otherwise a float clamped to
        0.0..10.0, with 7.0 as fallback when missing or unparseable.
      * ``subcategories``: exactly the 8 expected keys, each a float clamped
        to its maximum or None.
      * ``reason``: str of at most 250 characters.

    Args:
        parsed: Decoded model output.

    Returns:
        The same dict object, normalised.

    Raises:
        ValueError: If ``parsed`` is not a dict.
    """
    if not isinstance(parsed, dict):
        raise ValueError("Parsed model output is not a JSON object")

    certainty = _to_finite_float(parsed.get("certainty_percent"))
    if certainty is None:
        parsed["certainty_percent"] = 0
    else:
        parsed["certainty_percent"] = max(0, min(100, int(certainty)))

    classifiable = parsed.get("klassifizierbar", False)
    if not isinstance(classifiable, bool):
        classifiable = str(classifiable).lower() in ["true", "1", "yes", "ja"]
    parsed["klassifizierbar"] = classifiable

    if not classifiable:
        parsed["EDSS"] = None
    else:
        edss = _to_finite_float(parsed.get("EDSS"))
        if edss is None:
            # NOTE(review): 7.0 is a made-up score for answers that claim to
            # be classifiable but carry no usable EDSS; it ends up in the EDSS
            # statistics. Kept for compatibility - confirm it is intended.
            parsed["EDSS"] = 7.0
        else:
            parsed["EDSS"] = max(0.0, min(10.0, edss))

    # Expected subcategory keys with their maximum allowed score.
    required_subcategories = {
        "VISUAL_OPTIC_FUNCTIONS": 6.0,
        "BRAINSTEM_FUNCTIONS": 6.0,
        "PYRAMIDAL_FUNCTIONS": 6.0,
        "CEREBELLAR_FUNCTIONS": 6.0,
        "SENSORY_FUNCTIONS": 6.0,
        "BOWEL_AND_BLADDER_FUNCTIONS": 6.0,
        "CEREBRAL_FUNCTIONS": 6.0,
        "AMBULATION": 10.0,
    }

    raw_subcats = parsed.get("subcategories")
    if not isinstance(raw_subcats, dict):
        raw_subcats = {}

    # Build a fresh dict so unexpected keys are dropped and the key order is
    # always the same.
    normalized_subcats = {}
    for key, max_value in required_subcategories.items():
        score = _to_finite_float(raw_subcats.get(key))
        if score is None:
            normalized_subcats[key] = None
        else:
            normalized_subcats[key] = max(0.0, min(max_value, score))
    parsed["subcategories"] = normalized_subcats

    reason = parsed.get("reason")
    parsed["reason"] = "" if reason is None else str(reason)[:250]

    return parsed
|
||
|
||
|
||
# =========================
|
||
# API CALL
|
||
# =========================
|
||
|
||
def make_chat_completion(model_config, prompt):
    """Send one chat-completion request for ``prompt`` using ``model_config``.

    Args:
        model_config: Dict with ``model_name`` plus optional ``max_tokens``,
            ``temperature``, ``use_response_format`` and ``extra_body``.
        prompt: User message content.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    system_prompt = (
        "Du bist ein JSON-Generator. "
        "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. "
        "Keine Erklärung. Kein Markdown. Keine Code-Fences. "
        "Keine Platzhalter. Kein Schema kopieren. "
        "Das JSON muss vollständig geschlossen sein."
    )

    request = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "model": model_config["model_name"],
        "max_tokens": model_config.get("max_tokens", MAX_TOKENS),
        "temperature": model_config.get("temperature", TEMPERATURE),
    }

    # JSON mode is opt-in per model config.
    if model_config.get("use_response_format", False):
        request["response_format"] = {"type": "json_object"}

    # Backend-specific options are forwarded untouched when configured.
    body_overrides = model_config.get("extra_body")
    if body_overrides is not None:
        request["extra_body"] = body_overrides

    return client.chat.completions.create(**request)
|
||
|
||
|
||
# =========================
|
||
# INFERENCE FUNCTION WITH RETRIES
|
||
# =========================
|
||
|
||
def run_inference(patient_text, model_config):
    """Run one EDSS extraction for a patient text and measure its cost.

    The request is attempted up to ``MAX_JSON_RETRIES + 1`` times; any
    exception raised by the request, JSON extraction or normalisation triggers
    a retry after ``RETRY_SLEEP_SEC`` seconds. Exceptions never propagate: a
    failure is reported through the returned record.

    Args:
        patient_text: Report text passed to ``build_prompt``.
        model_config: Model settings, see ``make_chat_completion``.

    Returns:
        A dict with ``success``, ``error``, ``result`` (normalised model
        output or None), ``model``, ``inference_time_sec`` (wall time over all
        attempts, including retry sleeps), process CPU/RSS metrics, token
        counts, ``raw_content`` / ``raw_response_debug`` (only kept on
        failure), ``last_error`` (message of the most recent failed attempt,
        also set when a later retry succeeded) and ``attempts``. Raw payload
        and token counts always describe the final attempt.
    """
    model_name = model_config["model_name"]
    prompt = build_prompt(patient_text)

    process = get_process()
    sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC)

    wall_start = time.perf_counter()
    cpu_start = get_cpu_times_sec(process)
    rss_start_mb = get_memory_rss_mb(process)

    sampler.start()

    # Outcome defaults, so the record can always be built.
    success = False
    error = None
    result = None
    last_error = None
    attempts_made = 0

    raw_content = None
    raw_response_debug = None
    prompt_tokens = None
    completion_tokens = None
    total_tokens = None

    try:
        for attempt in range(1, MAX_JSON_RETRIES + 2):
            attempts_made = attempt

            # Reset per-attempt state: if this attempt fails before producing
            # a response, the record must not carry the payload or token
            # counts of an earlier attempt next to this attempt's error.
            raw_content = None
            raw_response_debug = None
            prompt_tokens = None
            completion_tokens = None
            total_tokens = None

            try:
                response = make_chat_completion(
                    model_config=model_config,
                    prompt=prompt
                )

                message = response.choices[0].message
                raw_content = extract_message_content(message)

                try:
                    raw_response_debug = response.model_dump()
                except Exception:
                    raw_response_debug = str(response)

                usage = getattr(response, "usage", None)
                if usage is not None:
                    prompt_tokens = getattr(usage, "prompt_tokens", None)
                    completion_tokens = getattr(usage, "completion_tokens", None)
                    total_tokens = getattr(usage, "total_tokens", None)

                parsed = extract_json_from_text(raw_content)
                parsed = normalize_model_output(parsed)

                success = True
                error = None
                result = parsed
                break

            except Exception as e:
                last_error = str(e)

                if attempt <= MAX_JSON_RETRIES:
                    print(
                        f"\n⚠️ JSON failed on attempt {attempt}. "
                        f"Retrying row. Error: {last_error[:300]}"
                    )
                    time.sleep(RETRY_SLEEP_SEC)
                    continue

                # Final attempt failed: hand over to the outer handler.
                raise
        else:
            # Reached only when the range is empty (negative MAX_JSON_RETRIES):
            # the last attempt otherwise always breaks or raises.
            error = "No inference attempt was made (MAX_JSON_RETRIES is negative)"

    except Exception as e:
        print(f"❌ Inference error: {e}")

        success = False
        error = str(e)
        result = None

    finally:
        sampler.stop()

    wall_end = time.perf_counter()
    cpu_end = get_cpu_times_sec(process)
    rss_end_mb = get_memory_rss_mb(process)

    wall_time_sec = wall_end - wall_start

    if cpu_start is not None and cpu_end is not None:
        process_cpu_time_sec = cpu_end - cpu_start
    else:
        process_cpu_time_sec = None

    if rss_start_mb is not None and rss_end_mb is not None:
        rss_delta_mb = rss_end_mb - rss_start_mb
    else:
        rss_delta_mb = None

    return {
        "success": success,
        "error": error,
        "result": result,

        "model": model_name,

        "inference_time_sec": wall_time_sec,

        "process_cpu_time_sec": process_cpu_time_sec,
        "rss_before_mb": rss_start_mb,
        "rss_after_mb": rss_end_mb,
        "rss_delta_mb": rss_delta_mb,
        "peak_rss_mb": sampler.peak_rss_mb,

        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,

        # Raw payloads are only kept for failed rows to limit output size.
        "raw_content": raw_content if not success else None,
        "raw_response_debug": raw_response_debug if not success else None,
        "last_error": last_error,

        "attempts": attempts_made,
    }
|
||
|
||
|
||
# =========================
|
||
# BUILD PATIENT TEXT
|
||
# =========================
|
||
|
||
def _cell_text(value):
    """Return a CSV cell as text, mapping missing values (None/NaN/NA) to ''."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # Non-scalar values have no single missing-ness; stringify them as is.
        pass
    return str(value)


def build_patient_text(row):
    """Concatenate the report fields of one input row into the patient text.

    The fields ``T_Zusammenfassung``, ``Diagnosen``, ``T_KlinBef`` and
    ``T_Befunde`` are joined with newlines in that order. Missing fields and
    empty cells contribute an empty line instead of the literal text "nan" or
    "None", which would otherwise end up in the prompt.

    Args:
        row: Mapping-like row supporting ``.get`` (e.g. a pandas Series).

    Returns:
        The four fields joined by ``"\\n"``.
    """
    fields = ("T_Zusammenfassung", "Diagnosen", "T_KlinBef", "T_Befunde")
    return "\n".join(_cell_text(row.get(field, "")) for field in fields)
|
||
|
||
|
||
# =========================
|
||
# FLATTEN RESULTS FOR CSV
|
||
# =========================
|
||
|
||
def flatten_result(record):
    """Flatten one result record into a single-level dict for CSV export.

    The output always has the same leading columns in the same order - also
    for failed records without a ``result`` - so rows appended one by one to a
    CSV whose header was written only once stay aligned with that header.

    Args:
        record: Result record as produced by ``run_inference`` plus the row
            metadata added by the caller.

    Returns:
        Dict with the record metadata, the result fields ``reason``,
        ``klassifizierbar``, ``EDSS``, ``certainty_percent`` and one
        ``subcat_<NAME>`` entry per subcategory. Absent values are None.
        Subcategory keys outside the standard eight are appended at the end.
    """
    record_fields = (
        "model",
        "iteration",
        "row_index",
        "row_number_in_run",
        "unique_id",
        "MedDatum",
        "success",
        "error",
        "last_error",
        "inference_time_sec",
        "process_cpu_time_sec",
        "rss_before_mb",
        "rss_after_mb",
        "rss_delta_mb",
        "peak_rss_mb",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "raw_content",
    )
    result_fields = ("reason", "klassifizierbar", "EDSS", "certainty_percent")
    standard_subcategories = (
        "VISUAL_OPTIC_FUNCTIONS",
        "BRAINSTEM_FUNCTIONS",
        "PYRAMIDAL_FUNCTIONS",
        "CEREBELLAR_FUNCTIONS",
        "SENSORY_FUNCTIONS",
        "BOWEL_AND_BLADDER_FUNCTIONS",
        "CEREBRAL_FUNCTIONS",
        "AMBULATION",
    )

    flat = {field: record.get(field) for field in record_fields}

    result = record.get("result")
    if not isinstance(result, dict):
        result = {}

    for field in result_fields:
        flat[field] = result.get(field)

    subcats = result.get("subcategories", {})
    if not isinstance(subcats, dict):
        subcats = {}

    for key in standard_subcategories:
        flat[f"subcat_{key}"] = subcats.get(key)

    # Keep any non-standard subcategory the result may carry.
    for key, value in subcats.items():
        if key not in standard_subcategories:
            flat[f"subcat_{key}"] = value

    return flat
|
||
|
||
|
||
def summarize_records(records):
    """Aggregate result records into a one-row summary DataFrame.

    Args:
        records: Iterable of result records (see ``flatten_result``).

    Returns:
        An empty DataFrame when there are no records; otherwise one row with
        the model name, record/success counts, success rate and
        mean/median/std/min/max for each available metric column.
    """
    flat_rows = [flatten_result(rec) for rec in records]
    frame = pd.DataFrame(flat_rows)

    if frame.empty:
        return pd.DataFrame()

    # Missing success flags count as failures.
    succeeded = frame["success"].fillna(False).astype(bool)

    summary = {
        "model": frame["model"].iloc[0] if "model" in frame.columns else None,
        "n_records": len(frame),
        "n_success": int(succeeded.sum()),
        "n_failed": int((~succeeded).sum()),
        "success_rate": float(succeeded.mean()),
    }

    metric_columns = (
        "inference_time_sec",
        "process_cpu_time_sec",
        "rss_delta_mb",
        "peak_rss_mb",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "certainty_percent",
        "EDSS",
    )
    stat_names = ("mean", "median", "std", "min", "max")

    for column in metric_columns:
        if column not in frame.columns:
            continue

        # Non-numeric entries become NaN and are ignored by the statistics.
        series = pd.to_numeric(frame[column], errors="coerce")
        for stat in stat_names:
            summary[f"{column}_{stat}"] = getattr(series, stat)()

    return pd.DataFrame([summary])
|
||
|
||
|
||
# =========================
|
||
# INCREMENTAL SAVE HELPERS
|
||
# =========================
|
||
|
||
def append_jsonl(path, record):
    """Append ``record`` as one JSON line to ``path`` and force it to disk.

    The file is opened in append mode (UTF-8), flushed and fsync'ed so the
    line survives an abrupt termination of the benchmark.
    """
    with open(path, mode="a", encoding="utf-8") as handle:
        serialized = json.dumps(record, ensure_ascii=False)
        handle.write(serialized + "\n")
        handle.flush()
        os.fsync(handle.fileno())
|
||
|
||
|
||
def append_csv(path, record):
    """Append one flattened record to the incremental CSV at ``path``.

    The set of flattened columns can differ between records (for example a
    failed row carries no result fields). Blindly appending such a row under
    a header written for a different column set produces a misaligned CSV,
    so this function keeps the file consistent:

    * missing or empty file: write the row together with its header;
    * row introduces no new columns: reorder it to the existing header and
      append it;
    * row introduces new columns: rewrite the file once with the union of
      columns (existing cells are kept as text, unchanged).

    Args:
        path: Destination CSV path (``str`` or ``Path``).
        record: Raw benchmark record; it is flattened with ``flatten_result``.
    """
    row_df = pd.DataFrame([flatten_result(record)])
    target = Path(path)

    if not target.exists() or target.stat().st_size == 0:
        row_df.to_csv(target, mode="a", header=True, index=False)
        return

    existing_columns = list(pd.read_csv(target, nrows=0).columns)
    added_columns = [col for col in row_df.columns if col not in existing_columns]

    if not added_columns:
        aligned = row_df.reindex(columns=existing_columns)
        aligned.to_csv(target, mode="a", header=False, index=False)
        return

    # New columns appeared: rebuild the file with the widened header. Existing
    # cells are read as plain text so their formatting is not altered, and
    # only truly empty cells are treated as missing.
    existing = pd.read_csv(target, dtype=str, keep_default_na=False, na_values=[""])
    combined = pd.concat([existing, row_df], ignore_index=True, sort=False)

    # Write to a sibling temp file first so a crash cannot truncate the CSV.
    tmp_path = target.with_name(target.name + ".tmp")
    combined.to_csv(tmp_path, index=False)
    tmp_path.replace(target)
|
||
|
||
|
||
# =========================
|
||
# MAIN LOOP
|
||
# =========================
|
||
|
||
if __name__ == "__main__":
    # Benchmark driver: for every configured model, run NUM_ITERATIONS passes
    # over the input CSV, save each record incrementally (JSONL + CSV), then
    # write per-iteration, per-model and combined summary files.

    def _flush_pending(pending, jsonl_path, csv_path):
        """Write all buffered records to the incremental files and clear the buffer."""
        for buffered in pending:
            append_jsonl(jsonl_path, buffered)
            append_csv(csv_path, buffered)
        pending.clear()

    # Guard against 0/negative values, which would break the modulo below.
    save_every = max(int(SAVE_EVERY_N_ROWS), 1)

    run_timestamp = now_timestamp()

    results_root = Path(RESULTS_ROOT)
    results_root.mkdir(parents=True, exist_ok=True)

    run_root = results_root / f"run_{run_timestamp}"
    run_root.mkdir(parents=True, exist_ok=True)

    print(f"Results root: {run_root}")

    df = pd.read_csv(INPUT_CSV, sep=";")

    if MAX_ROWS is not None:
        df = df.head(MAX_ROWS)

    total_rows = len(df)

    model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS]

    print(f"Loaded {total_rows} patient records.")
    print(f"Models: {model_names_for_print}")
    print(f"Iterations per model: {NUM_ITERATIONS}")

    all_model_summaries = []

    for model_config in MODEL_CONFIGS:
        model_name = model_config["model_name"]
        safe_model = safe_dir_name(model_name)

        model_dir = run_root / safe_model
        model_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n{'#' * 80}")
        print(f"MODEL: {model_name}")
        print(f"use_response_format: {model_config.get('use_response_format', False)}")
        print(f"temperature: {model_config.get('temperature', TEMPERATURE)}")
        print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}")
        print(f"Saving to: {model_dir}")
        print(f"{'#' * 80}")

        model_records = []
        model_start = time.perf_counter()

        for iteration in range(1, NUM_ITERATIONS + 1):
            print(f"\n{'=' * 60}")
            print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}")
            print(f"{'=' * 60}")

            iteration_results = []
            # Records not yet written to the incremental files.
            pending_records = []
            iteration_start = time.perf_counter()

            incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl"
            incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv"

            print(f"Incremental JSONL: {incremental_jsonl_path}")
            print(f"Incremental CSV: {incremental_csv_path}")

            for loop_i, (idx, row) in enumerate(df.iterrows(), start=1):
                print(
                    f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}",
                    end="",
                    flush=True
                )

                # Only inference and record annotation are guarded, so a
                # failure while saving or printing cannot record a row twice.
                try:
                    patient_text = build_patient_text(row)

                    record = run_inference(
                        patient_text=patient_text,
                        model_config=model_config
                    )

                    record["iteration"] = iteration
                    record["row_index"] = int(idx)
                    record["row_number_in_run"] = int(loop_i)
                    record["unique_id"] = row.get("unique_id", f"row_{idx}")
                    record["MedDatum"] = row.get("MedDatum", None)
                    row_crashed = False

                except Exception as e:
                    print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}")

                    record = {
                        "success": False,
                        "error": str(e),
                        "last_error": str(e),
                        "result": None,

                        "model": model_name,
                        "iteration": iteration,
                        "row_index": int(idx),
                        "row_number_in_run": int(loop_i),
                        "unique_id": row.get("unique_id", f"row_{idx}"),
                        "MedDatum": row.get("MedDatum", None),

                        "inference_time_sec": None,
                        "process_cpu_time_sec": None,
                        "rss_before_mb": None,
                        "rss_after_mb": None,
                        "rss_delta_mb": None,
                        "peak_rss_mb": None,

                        "prompt_tokens": None,
                        "completion_tokens": None,
                        "total_tokens": None,

                        "raw_content": None,
                        "raw_response_debug": None,
                    }
                    row_crashed = True

                iteration_results.append(record)
                model_records.append(record)
                pending_records.append(record)

                # Flush the whole buffer, not just the current record, so that
                # SAVE_EVERY_N_ROWS > 1 no longer drops the rows in between.
                # Crashed rows are always written immediately.
                if row_crashed or loop_i % save_every == 0:
                    _flush_pending(pending_records, incremental_jsonl_path, incremental_csv_path)

                if row_crashed:
                    if STOP_ON_FIRST_ERROR:
                        break
                    continue

                if record["success"]:
                    res = record["result"]
                    edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A"
                    print(
                        f" ✅ EDSS={edss}, "
                        f"cert={res.get('certainty_percent', '?')}%, "
                        f"time={record['inference_time_sec']:.2f}s"
                    )
                else:
                    print(f" ❌ {record.get('error', 'Unknown error')}")

            # Write whatever is left when the row count is not a multiple of
            # the save interval.
            _flush_pending(pending_records, incremental_jsonl_path, incremental_csv_path)

            iteration_elapsed = time.perf_counter() - iteration_start

            iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json"
            with open(iter_json_path, "w", encoding="utf-8") as f:
                json.dump(iteration_results, f, indent=2, ensure_ascii=False)

            iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv"
            iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results])
            iter_flat_df.to_csv(iter_csv_path, index=False)

            # Average over the rows actually processed (an early stop may have
            # cut the iteration short).
            rows_processed = max(len(iteration_results), 1)

            print(f"\n✅ Iteration {iteration} complete.")
            print(f"Incremental JSONL saved to: {incremental_jsonl_path}")
            print(f"Incremental CSV saved to: {incremental_csv_path}")
            print(f"Final JSON saved to: {iter_json_path}")
            print(f"Final CSV saved to: {iter_csv_path}")
            print(
                f"⏱️ Iteration time: {iteration_elapsed:.1f}s "
                f"({iteration_elapsed / rows_processed:.2f}s/row)"
            )

        model_elapsed = time.perf_counter() - model_start

        model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json"
        with open(model_json_path, "w", encoding="utf-8") as f:
            json.dump(model_records, f, indent=2, ensure_ascii=False)

        model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv"
        model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records])
        model_flat_df.to_csv(model_csv_path, index=False)

        model_summary_df = summarize_records(model_records)
        model_summary_df["model_total_wall_time_sec"] = model_elapsed
        model_summary_df["model_total_wall_time_min"] = model_elapsed / 60

        model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv"
        model_summary_df.to_csv(model_summary_path, index=False)

        all_model_summaries.append(model_summary_df)

        print(f"\n🎉 Model completed: {model_name}")
        print(f"All JSON: {model_json_path}")
        print(f"All CSV: {model_csv_path}")
        print(f"Summary: {model_summary_path}")
        print(f"Total model time: {model_elapsed / 60:.2f} min")

    if all_model_summaries:
        combined_summary_df = pd.concat(all_model_summaries, ignore_index=True)
        combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv"
        combined_summary_df.to_csv(combined_summary_path, index=False)

        print(f"\n📊 Combined summary saved to: {combined_summary_path}")

    print("\n🎉 All models and all iterations completed!")
|
||
##
|
||
|
||
# %% API call - Multi-model, multi-iteration EDSS + timing/resource benchmark2
|
||
|
||
import time
|
||
import json
|
||
import os
|
||
import re
|
||
import threading
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
import pandas as pd
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
|
||
try:
|
||
import psutil
|
||
except ImportError:
|
||
psutil = None
|
||
print("⚠️ psutil is not installed. Resource metrics will be limited.")
|
||
print("Install with: pip install psutil")
|
||
|
||
|
||
# =========================
|
||
# CONFIGURATION
|
||
# =========================
|
||
|
||
# Pull variables from a local .env file into the process environment.
load_dotenv()

# Credentials and endpoint of the OpenAI-compatible server.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL")

# One entry per benchmarked model. "extra_body" is forwarded verbatim to the
# chat-completions call when it is not None.
MODEL_CONFIGS = [
    {
        "model_name": "qwen3.6-27b",
        "use_response_format": False,
        "temperature": 0.0,
        "max_tokens": 4096,
        "extra_body": {
            "chat_template_kwargs": {
                "enable_thinking": False,
            },
        },
    },
    {
        "model_name": "gemma-4-31B-it",
        "use_response_format": False,
        "temperature": 0.0,
        "max_tokens": 4096,
        "extra_body": None,
    },
    # {
    #     "model_name": "GPT-OSS-120B",
    #     "use_response_format": True,
    #     "temperature": 0.0,
    #     "max_tokens": 4096,
    #     "extra_body": None,
    # },
]

# Input data and scoring instructions.
INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"

# Every run creates its own timestamped sub-directory below this root.
RESULTS_ROOT = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark"

NUM_ITERATIONS = 10
STOP_ON_FIRST_ERROR = False

# For testing, set to e.g. 2.
# For full run, set to None.
MAX_ROWS = None
# MAX_ROWS = 2

# Defaults used when a model config does not override them.
MAX_TOKENS = 4096
TEMPERATURE = 0.0

# Sampling period of the background RSS sampler.
RESOURCE_SAMPLE_INTERVAL_SEC = 0.05

# Interval (in rows) for writing to the incremental result files.
SAVE_EVERY_N_ROWS = 1

# Retries for invalid JSON / truncated JSON
MAX_JSON_RETRIES = 2
RETRY_SLEEP_SEC = 2
|
||
|
||
|
||
# =========================
|
||
# VALID CLINICAL RANGES
|
||
# =========================
|
||
|
||
# Inclusive bounds accepted for the overall EDSS score.
EDSS_MIN = 0.0
EDSS_MAX = 10.0

# Inclusive (min, max) bounds per functional system as used by this benchmark.
# All systems share the same range except AMBULATION.
_STANDARD_FS_RANGE = (0.0, 6.0)

FUNCTIONAL_SYSTEM_RANGES = {
    "VISUAL_OPTIC_FUNCTIONS": _STANDARD_FS_RANGE,
    "BRAINSTEM_FUNCTIONS": _STANDARD_FS_RANGE,
    "PYRAMIDAL_FUNCTIONS": _STANDARD_FS_RANGE,
    "CEREBELLAR_FUNCTIONS": _STANDARD_FS_RANGE,
    "SENSORY_FUNCTIONS": _STANDARD_FS_RANGE,
    "BOWEL_AND_BLADDER_FUNCTIONS": _STANDARD_FS_RANGE,
    "CEREBRAL_FUNCTIONS": _STANDARD_FS_RANGE,
    "AMBULATION": (0.0, 10.0),
}

# Keys every model answer must contain at the top level.
REQUIRED_TOP_LEVEL_FIELDS = [
    "reason",
    "klassifizierbar",
    "EDSS",
    "certainty_percent",
    "subcategories",
]
|
||
|
||
|
||
# =========================
|
||
# CLIENT
|
||
# =========================
|
||
|
||
# Single shared API client; key and endpoint come from the environment
# variables read in the configuration section.
client = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
|
||
|
||
|
||
# =========================
|
||
# HELPERS
|
||
# =========================
|
||
|
||
def safe_dir_name(name: str) -> str:
    """Turn an arbitrary model name into a safe directory/file name component.

    Runs of characters other than word characters, ``-`` and ``.`` are
    replaced by a single ``_`` and the result is capped at 150 characters.

    A name that is empty or consists only of dots (``""``, ``"."``, ``".."``)
    would resolve to the parent or current directory when joined to a path,
    so ``"unnamed"`` is returned for it instead.

    Args:
        name: Model name; converted with ``str()`` first.

    Returns:
        A non-empty string that is safe to use as a single path component.
    """
    cleaned = re.sub(r"[^\w\-.]+", "_", str(name).strip())[:150]

    if not cleaned.strip("."):
        return "unnamed"

    return cleaned
|
||
|
||
|
||
def now_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    return f"{datetime.now():%Y%m%d_%H%M%S}"
|
||
|
||
|
||
def get_process():
    """Return a psutil handle for the current process, or None without psutil."""
    return psutil.Process(os.getpid()) if psutil is not None else None
|
||
|
||
|
||
def get_memory_rss_mb(process=None):
    """Return the resident set size of ``process`` in MiB.

    Falls back to the current process when ``process`` is None and returns
    None when psutil is not available.
    """
    if psutil is None:
        return None

    proc = get_process() if process is None else process
    rss_bytes = proc.memory_info().rss
    return rss_bytes / (1024 * 1024)
|
||
|
||
|
||
def get_cpu_times_sec(process=None):
    """Return user + system CPU seconds consumed by ``process``.

    Falls back to the current process when ``process`` is None and returns
    None when psutil is not available.
    """
    if psutil is None:
        return None

    proc = get_process() if process is None else process
    times = proc.cpu_times()
    return times.user + times.system
|
||
|
||
|
||
class ResourceSampler:
    """Background sampler that records this process's RSS at a fixed interval.

    ``start()`` launches a daemon thread that appends one RSS reading (MiB)
    to ``samples_mb`` every ``interval_sec`` seconds until ``stop()`` is
    called. ``start()`` and ``stop()`` do nothing when psutil is unavailable.
    """

    def __init__(self, interval_sec=0.05):
        self.interval_sec = interval_sec
        self.process = get_process()
        self.running = False
        self.thread = None
        self.samples_mb = []

    def start(self):
        """Reset the sample list and start the sampling thread."""
        if psutil is None:
            return

        self.running = True
        self.samples_mb = []
        worker = threading.Thread(target=self._sample_loop, daemon=True)
        self.thread = worker
        worker.start()

    def stop(self):
        """Signal the sampling thread to finish and wait up to one second for it."""
        if psutil is None:
            return

        self.running = False
        worker = self.thread
        if worker is not None:
            worker.join(timeout=1.0)

    def _sample_loop(self):
        """Thread body: take one reading per interval while ``running`` is set."""
        while self.running:
            try:
                self.samples_mb.append(get_memory_rss_mb(self.process))
            except Exception:
                # Best effort: a failed reading must not kill the sampler.
                pass
            time.sleep(self.interval_sec)

    @property
    def peak_rss_mb(self):
        """Largest RSS reading collected so far, or None if there is none."""
        return max(self.samples_mb) if self.samples_mb else None
|
||
|
||
|
||
# =========================
|
||
# JSON EXTRACTION
|
||
# =========================
|
||
|
||
# Substrings that indicate the model echoed the schema/placeholder text
# instead of producing real values. Matched case-insensitively.
_PLACEHOLDER_MARKERS = (
    "true/false",
    "null or",
    "oder zahl",
    "0.0-6.0",
    "0.0-10.0",
    "zahl zwischen",
    "...",
)

# Keys that identify an object as the actual answer rather than a fragment.
_EXPECTED_RESULT_KEYS = ("klassifizierbar", "certainty_percent", "subcategories")


def _strip_code_fences(text):
    """Remove Markdown code-fence markers while keeping the fenced content."""
    for fence in ("```json", "```JSON", "```Json", "```"):
        text = text.replace(fence, "")
    return text.strip()


def _has_placeholder(candidate):
    """Return True if the candidate text contains a schema placeholder marker."""
    lowered = candidate.lower()
    return any(marker in lowered for marker in _PLACEHOLDER_MARKERS)


def _has_expected_keys(obj):
    """Return True if ``obj`` carries the keys that identify a real answer."""
    return all(key in obj for key in _EXPECTED_RESULT_KEYS)


def _balanced_brace_spans(text):
    """Return the top-level ``{...}`` substrings found by a quote-aware scan."""
    spans = []
    depth = 0
    start_idx = None
    in_string = False
    escape = False

    for i, ch in enumerate(text):
        if escape:
            escape = False
            continue

        if ch == "\\":
            escape = True
            continue

        if ch == '"':
            in_string = not in_string
            continue

        if in_string:
            continue

        if ch == "{":
            if depth == 0:
                start_idx = i
            depth += 1
        elif ch == "}" and depth:
            depth -= 1
            if depth == 0 and start_idx is not None:
                spans.append(text[start_idx:i + 1])
                start_idx = None

    return spans


def _decode_objects_at_braces(text):
    """Decode JSON objects by trying the JSON decoder at every ``{``.

    Unlike the brace scan, this is not confused by an unbalanced quote in the
    prose that precedes the object.
    """
    decoder = json.JSONDecoder()
    found = []
    pos = text.find("{")

    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
            continue

        if isinstance(obj, dict) and not _has_placeholder(text[pos:end]):
            found.append(obj)
        # Continue after the decoded object so nested objects are not
        # reported a second time.
        pos = text.find("{", end)

    return found


def extract_json_from_text(text):
    """Extract the answer JSON object from raw model output.

    Strategy, in order:

    1. Strip Markdown code fences and try to parse the whole text.
    2. Scan for balanced top-level ``{...}`` spans, drop those containing
       placeholder markers, and parse the rest. The last object carrying the
       expected answer keys wins; otherwise the last valid object is used.
    3. If the scan found nothing (for example because a stray ``"`` in the
       surrounding prose threw off the quote tracking), try the JSON decoder
       at every ``{`` and accept only an object carrying the expected keys.

    Args:
        text: Raw model output; anything other than None is converted with
            ``str()``.

    Returns:
        The parsed JSON object as a ``dict``.

    Raises:
        ValueError: If the content is None or empty, looks like truncated
            JSON, or contains no usable JSON object.
    """
    if text is None:
        raise ValueError("Model returned empty content: message.content is None")

    text = str(text).strip()

    if not text:
        raise ValueError("Model returned empty content")

    text = _strip_code_fences(text)

    # 1) Direct parse
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(parsed, dict):
            return parsed

    # 2) Balanced JSON candidates
    valid_objects = []

    for candidate in _balanced_brace_spans(text):
        if _has_placeholder(candidate):
            continue
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            valid_objects.append(parsed)

    for obj in reversed(valid_objects):
        if _has_expected_keys(obj):
            return obj

    if valid_objects:
        return valid_objects[-1]

    # 3) Decoder-driven fallback. Only a complete-looking answer is accepted
    #    here, so fragments of a broken outer object still lead to an error.
    for obj in reversed(_decode_objects_at_braces(text)):
        if _has_expected_keys(obj):
            return obj

    stripped = text.strip()
    if stripped.startswith("{") and not stripped.endswith("}"):
        raise ValueError(
            "Model output looks like truncated JSON. "
            f"Raw output starts with: {text[:1000]}"
        )

    raise ValueError(
        "No valid JSON object found in model output. "
        f"Raw output starts with: {text[:1000]}"
    )
|
||
|
||
|
||
def _join_content_blocks(blocks):
    """Join the text of a list of content blocks; return None if nothing usable.

    Plain strings are taken as-is; dict blocks contribute their ``"text"``
    value or, failing that, their ``"content"`` value.
    """
    parts = []
    for block in blocks:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict):
            if "text" in block:
                parts.append(str(block["text"]))
            elif "content" in block:
                parts.append(str(block["content"]))

    joined = "\n".join(parts).strip()
    return joined or None


def extract_message_content(message):
    """Return the textual content of a chat-completion message.

    Order of preference:

    1. ``message.content`` when it is a non-blank string, a list of content
       blocks that yields text, or any other non-None value.
    2. The first non-empty value among ``content``, ``reasoning_content``,
       ``reasoning``, ``text`` and ``output_text`` of the message converted
       to a dict (``model_dump()`` or ``dict()``); list values are joined.

    A list-valued ``content`` is flattened to text instead of being returned
    as a raw list, and a blank ``content`` no longer prevents the fallback
    to the reasoning fields.

    Returns:
        The extracted text, the original blank string if ``content`` was a
        blank string and no fallback produced text, otherwise None.
    """
    raw_content = getattr(message, "content", None)

    if isinstance(raw_content, list):
        flattened = _join_content_blocks(raw_content)
        if flattened is not None:
            return flattened
    elif isinstance(raw_content, str):
        if raw_content.strip():
            return raw_content
    elif raw_content is not None:
        return raw_content

    try:
        msg_dict = message.model_dump()
    except Exception:
        try:
            msg_dict = dict(message)
        except Exception:
            msg_dict = None

    if isinstance(msg_dict, dict):
        for key in ("content", "reasoning_content", "reasoning", "text", "output_text"):
            value = msg_dict.get(key)
            if isinstance(value, list):
                value = _join_content_blocks(value)
            if isinstance(value, str) and not value.strip():
                continue
            if value:
                return value

    # Keep a blank string distinguishable from a missing content field.
    return raw_content if isinstance(raw_content, str) else None
|
||
|
||
|
||
# =========================
|
||
# READ INSTRUCTIONS
|
||
# =========================
|
||
|
||
# Scoring guidelines are loaded once at import time; build_prompt embeds them
# into every prompt.
EDSS_INSTRUCTIONS = Path(EDSS_INSTRUCTIONS_PATH).read_text(encoding="utf-8").strip()
|
||
|
||
|
||
# =========================
|
||
# PROMPT
|
||
# =========================
|
||
|
||
def build_prompt(patient_text):
    """Build the German user prompt for one patient report.

    The prompt lists the required JSON keys and value constraints, shows one
    valid example answer, and embeds the module-level ``EDSS_INSTRUCTIONS``
    followed by ``patient_text``.

    Args:
        patient_text: Report text to classify; inserted verbatim.

    Returns:
        The complete prompt string.
    """
    prompt = f'''Du bist ein medizinischer Assistent für EDSS-Extraktion aus klinischen Berichten.

Extrahiere:
1. Gesamt-EDSS-Score von 0.0 bis 10.0
2. Alle 8 EDSS-Unterkategorien
3. Sicherheit als Ganzzahl von 0 bis 100

Antworte ausschließlich mit EINEM validen JSON-Objekt.
Kein Markdown.
Keine Code-Fences.
Kein Text vor oder nach JSON.
Keine Platzhalter.
Kopiere kein Schema.

Das JSON muss exakt diese Schlüssel enthalten:
- reason
- klassifizierbar
- EDSS
- certainty_percent
- subcategories

Die subcategories müssen exakt diese 8 Schlüssel enthalten:
- VISUAL_OPTIC_FUNCTIONS
- BRAINSTEM_FUNCTIONS
- PYRAMIDAL_FUNCTIONS
- CEREBELLAR_FUNCTIONS
- SENSORY_FUNCTIONS
- BOWEL_AND_BLADDER_FUNCTIONS
- CEREBRAL_FUNCTIONS
- AMBULATION

Werte:
- klassifizierbar: true oder false
- EDSS: Zahl von 0.0 bis 10.0 oder null
- certainty_percent: Ganzzahl von 0 bis 100
- Unterkategorien: Zahl oder null
- VISUAL_OPTIC_FUNCTIONS maximal 6.0
- BRAINSTEM_FUNCTIONS maximal 6.0
- PYRAMIDAL_FUNCTIONS maximal 6.0
- CEREBELLAR_FUNCTIONS maximal 6.0
- SENSORY_FUNCTIONS maximal 6.0
- BOWEL_AND_BLADDER_FUNCTIONS maximal 6.0
- CEREBRAL_FUNCTIONS maximal 6.0
- AMBULATION maximal 10.0
- reason: maximal 250 Zeichen, Deutsch

Wenn klassifizierbar false ist, setze EDSS auf null.

Valide Beispielausgabe:
{{
"reason": "Leichte Einschränkungen mit sicher ableitbarer Gehfähigkeit und geringen funktionellen Defiziten.",
"klassifizierbar": true,
"EDSS": 2.0,
"certainty_percent": 90,
"subcategories": {{
"VISUAL_OPTIC_FUNCTIONS": null,
"BRAINSTEM_FUNCTIONS": null,
"PYRAMIDAL_FUNCTIONS": 1.0,
"CEREBELLAR_FUNCTIONS": 1.0,
"SENSORY_FUNCTIONS": 1.0,
"BOWEL_AND_BLADDER_FUNCTIONS": null,
"CEREBRAL_FUNCTIONS": null,
"AMBULATION": 0.0
}}
}}

EDSS-Bewertungsrichtlinien:
{EDSS_INSTRUCTIONS}

Patientenbericht:
{patient_text}

Gib ausschließlich das finale JSON-Objekt zurück.
'''
    return prompt
|
||
|
||
|
||
# =========================
|
||
# VALIDATION, NOT NORMALIZATION
|
||
# =========================
|
||
|
||
def parse_float_preserve_raw(value):
    """Parse ``value`` as a finite float without clipping or correcting it.

    A decimal comma is accepted (``"2,5"`` parses as ``2.5``). ``None``,
    booleans, unparsable values and non-finite numbers (NaN, +/-Infinity,
    which ``json.loads`` accepts as literals) are reported as not numeric.

    Args:
        value: Value exactly as present in the parsed JSON.

    Returns:
        A tuple ``(raw_value, numeric_value, is_numeric)`` where
        ``raw_value`` is the unchanged input, ``numeric_value`` is a float or
        None, and ``is_numeric`` tells whether parsing succeeded.
    """
    import math  # local import: the module header is outside this block

    raw_value = value

    # bool is a subclass of int; True/False must not be read as 1.0/0.0.
    if value is None or isinstance(value, bool):
        return raw_value, None, False

    try:
        numeric_value = float(str(value).replace(",", "."))
    except (TypeError, ValueError, OverflowError):
        return raw_value, None, False

    # NaN/Infinity are floats but not usable scores.
    if not math.isfinite(numeric_value):
        return raw_value, None, False

    return raw_value, numeric_value, True
|
||
|
||
|
||
def is_in_range(value, min_value, max_value):
    """Return True if ``value`` lies in ``[min_value, max_value]``.

    Both bounds are inclusive, None is never in range, and the value is only
    checked, never clipped.
    """
    return value is not None and min_value <= value <= max_value
|
||
|
||
|
||
def validate_model_output(parsed):
    """Validate a parsed model answer and attach explicit validity flags.

    Nothing is repaired: EDSS and functional-system values are never clipped,
    no defaults are inserted for EDSS or certainty, and missing fields stay
    None. Raw values are kept next to their parsed numeric form.

    NOTE(review): a null functional-system value is counted as not in range,
    so ``clinical_range_valid`` is False whenever any subcategory is null,
    although the prompt permits null. Confirm this is the intended meaning.

    Args:
        parsed: Object returned by the JSON extraction step.

    Returns:
        A dict with ``raw_output`` (the input), ``validated_output`` (parsed
        values plus per-field flags; empty if ``parsed`` is not a dict) and
        ``validation`` (summary flags and lists of missing fields).
    """
    parsed_is_dict = isinstance(parsed, dict)

    validation = {
        "json_parse_success": parsed_is_dict,
        "required_fields_present": False,
        "required_schema_success": False,
        "clinical_range_valid": False,
        "certainty_present": False,

        "missing_required_fields": [],
        "missing_subcategory_fields": [],

        "EDSS_is_numeric": False,
        "EDSS_in_valid_range": False,
    }

    if not parsed_is_dict:
        return {
            "raw_output": parsed,
            "validated_output": {},
            "validation": validation,
        }

    # --- top-level schema -------------------------------------------------
    absent_fields = [name for name in REQUIRED_TOP_LEVEL_FIELDS if name not in parsed]
    validation["missing_required_fields"] = absent_fields
    validation["required_fields_present"] = not absent_fields

    # --- certainty ----------------------------------------------------------
    raw_certainty = parsed.get("certainty_percent")
    validation["certainty_present"] = "certainty_percent" in parsed and raw_certainty is not None
    _, certainty_value, certainty_ok = parse_float_preserve_raw(raw_certainty)

    # --- EDSS ---------------------------------------------------------------
    raw_edss, edss_value, edss_ok = parse_float_preserve_raw(parsed.get("EDSS"))
    edss_in_range = edss_ok and is_in_range(edss_value, EDSS_MIN, EDSS_MAX)

    validated = {
        "reason": parsed.get("reason"),
        "klassifizierbar": parsed.get("klassifizierbar"),

        "raw_certainty_percent": raw_certainty,
        "certainty_percent": certainty_value if certainty_ok else None,
        "certainty_percent_is_numeric": certainty_ok,
        "certainty_percent_in_valid_range": (
            certainty_ok and is_in_range(certainty_value, 0.0, 100.0)
        ),

        "raw_EDSS": raw_edss,
        "EDSS_numeric": edss_value,
        "EDSS": edss_value,  # Backward-compatible; parsed only, not clipped
        "EDSS_is_numeric": edss_ok,
        "EDSS_in_valid_range": edss_in_range,

        "subcategories": {},
        "raw_subcategories": {},
        "subcategory_validation": {},
    }

    validation["EDSS_is_numeric"] = edss_ok
    validation["EDSS_in_valid_range"] = edss_in_range

    # --- functional systems -------------------------------------------------
    given_subcats = parsed.get("subcategories")
    if not isinstance(given_subcats, dict):
        given_subcats = {}

    absent_subcats = []

    for name, (lower, upper) in FUNCTIONAL_SYSTEM_RANGES.items():
        if name not in given_subcats:
            absent_subcats.append(name)

        raw_value, numeric_value, value_ok = parse_float_preserve_raw(given_subcats.get(name))

        validated["raw_subcategories"][name] = raw_value
        validated["subcategories"][name] = numeric_value
        validated["subcategory_validation"][name] = {
            "is_numeric": value_ok,
            "in_valid_range": value_ok and is_in_range(numeric_value, lower, upper),
            "min_allowed": lower,
            "max_allowed": upper,
        }

    validation["missing_subcategory_fields"] = absent_subcats

    per_system_flags = validated["subcategory_validation"]
    every_system_numeric = all(
        per_system_flags[name]["is_numeric"] for name in FUNCTIONAL_SYSTEM_RANGES
    )
    every_system_in_range = all(
        per_system_flags[name]["in_valid_range"] for name in FUNCTIONAL_SYSTEM_RANGES
    )

    validated["all_functional_systems_numeric"] = every_system_numeric
    validated["all_functional_systems_in_valid_range"] = every_system_in_range

    # --- combined flags -----------------------------------------------------
    validation["clinical_range_valid"] = edss_in_range and every_system_in_range
    validation["required_schema_success"] = (
        validation["required_fields_present"] and not absent_subcats
    )

    return {
        "raw_output": parsed,
        "validated_output": validated,
        "validation": validation,
    }
|
||
|
||
|
||
# =========================
|
||
# API CALL
|
||
# =========================
|
||
|
||
def make_chat_completion(model_config, prompt):
    """Send one chat-completion request for ``prompt`` using ``model_config``.

    The request consists of a fixed German system message demanding a single
    JSON object plus the user prompt. ``max_tokens`` and ``temperature`` fall
    back to the module defaults, ``response_format`` is set to JSON-object
    mode when ``use_response_format`` is truthy, and a non-None
    ``extra_body`` is forwarded unchanged.

    Returns:
        The response object returned by ``client.chat.completions.create``.
    """
    system_prompt = (
        "Du bist ein JSON-Generator. "
        "Antworte ausschließlich mit einem einzigen validen JSON-Objekt. "
        "Keine Erklärung. Kein Markdown. Keine Code-Fences. "
        "Keine Platzhalter. Kein Schema kopieren. "
        "Das JSON muss vollständig geschlossen sein."
    )

    request = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "model": model_config["model_name"],
        "max_tokens": model_config.get("max_tokens", MAX_TOKENS),
        "temperature": model_config.get("temperature", TEMPERATURE),
    }

    if model_config.get("use_response_format", False):
        request["response_format"] = {"type": "json_object"}

    body_overrides = model_config.get("extra_body")
    if body_overrides is not None:
        request["extra_body"] = body_overrides

    return client.chat.completions.create(**request)
|
||
|
||
|
||
# =========================
|
||
# INFERENCE FUNCTION WITH RETRIES
|
||
# =========================
|
||
|
||
def run_inference(patient_text, model_config):
    """Run one EDSS extraction with retries and collect timing/resource metrics.

    The request is attempted up to ``MAX_JSON_RETRIES + 1`` times; any
    exception (API error, empty content, unparsable JSON) triggers a retry
    after ``RETRY_SLEEP_SEC`` seconds. The wall-clock, CPU and RSS
    measurements span all attempts, including the retry sleeps.

    Response text, debug dump and token usage always belong to the LAST
    attempt: they are reset at the start of each attempt so that a failure
    record never pairs an error with the response of an earlier attempt.

    Args:
        patient_text: Report text passed to ``build_prompt``.
        model_config: Model configuration dict; must contain ``model_name``.

    Returns:
        A record dict with ``success``, ``error``, ``result`` (validated
        output or None), ``validation``, ``raw_parsed_output``, ``model``,
        timing/resource metrics, token counts, ``raw_content``,
        ``raw_response_debug`` (failures only), ``last_error`` and
        ``n_attempts``.
    """
    model_name = model_config["model_name"]
    prompt = build_prompt(patient_text)

    process = get_process()
    sampler = ResourceSampler(interval_sec=RESOURCE_SAMPLE_INTERVAL_SEC)

    wall_start = time.perf_counter()
    cpu_start = get_cpu_times_sec(process)
    rss_start_mb = get_memory_rss_mb(process)

    sampler.start()

    # Outcome defaults describe a failure so every name is bound even if the
    # retry loop never runs.
    failed_validation = {
        "json_parse_success": False,
        "required_fields_present": False,
        "required_schema_success": False,
        "clinical_range_valid": False,
        "certainty_present": False,
        "missing_required_fields": [],
        "missing_subcategory_fields": [],
        "EDSS_is_numeric": False,
        "EDSS_in_valid_range": False,
    }

    success = False
    error = None
    result = None
    validation = failed_validation
    raw_parsed_output = None
    last_error = None
    n_attempts = 0

    raw_content = None
    raw_response_debug = None
    prompt_tokens = None
    completion_tokens = None
    total_tokens = None

    try:
        for attempt in range(1, MAX_JSON_RETRIES + 2):
            n_attempts = attempt

            # Forget the previous attempt's response before trying again.
            raw_content = None
            raw_response_debug = None
            prompt_tokens = None
            completion_tokens = None
            total_tokens = None

            try:
                response = make_chat_completion(
                    model_config=model_config,
                    prompt=prompt
                )

                message = response.choices[0].message
                raw_content = extract_message_content(message)

                try:
                    raw_response_debug = response.model_dump()
                except Exception:
                    raw_response_debug = str(response)

                usage = getattr(response, "usage", None)
                if usage is not None:
                    prompt_tokens = getattr(usage, "prompt_tokens", None)
                    completion_tokens = getattr(usage, "completion_tokens", None)
                    total_tokens = getattr(usage, "total_tokens", None)

                parsed = extract_json_from_text(raw_content)
                validation_package = validate_model_output(parsed)

                success = True
                error = None

                result = validation_package["validated_output"]
                validation = validation_package["validation"]
                raw_parsed_output = validation_package["raw_output"]

                break

            except Exception as e:
                last_error = str(e)

                if attempt <= MAX_JSON_RETRIES:
                    print(
                        f"\n⚠️ JSON failed on attempt {attempt}. "
                        f"Retrying row. Error: {last_error[:300]}"
                    )
                    time.sleep(RETRY_SLEEP_SEC)
                    continue

                # Out of retries: let the outer handler build the failure.
                raise

    except Exception as e:
        print(f"❌ Inference error: {e}")

        success = False
        error = str(e)
        result = None
        raw_parsed_output = None
        validation = failed_validation

    finally:
        sampler.stop()

    wall_end = time.perf_counter()
    cpu_end = get_cpu_times_sec(process)
    rss_end_mb = get_memory_rss_mb(process)

    wall_time_sec = wall_end - wall_start

    if cpu_start is not None and cpu_end is not None:
        process_cpu_time_sec = cpu_end - cpu_start
    else:
        process_cpu_time_sec = None

    if rss_start_mb is not None and rss_end_mb is not None:
        rss_delta_mb = rss_end_mb - rss_start_mb
    else:
        rss_delta_mb = None

    return {
        "success": success,
        "error": error,
        "result": result,

        "validation": validation,
        "raw_parsed_output": raw_parsed_output,

        "model": model_name,

        "inference_time_sec": wall_time_sec,

        "process_cpu_time_sec": process_cpu_time_sec,
        "rss_before_mb": rss_start_mb,
        "rss_after_mb": rss_end_mb,
        "rss_delta_mb": rss_delta_mb,
        "peak_rss_mb": sampler.peak_rss_mb,

        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,

        # Keeping raw content improves auditability but can make files large.
        # To save space, change this to: raw_content if not success else None
        "raw_content": raw_content,
        "raw_response_debug": raw_response_debug if not success else None,
        "last_error": last_error,
        "n_attempts": n_attempts,
    }
|
||
|
||
|
||
# =========================
|
||
# BUILD PATIENT TEXT
|
||
# =========================
|
||
|
||
def _field_text(value):
    """Return ``value`` as text, mapping None and missing values (NaN/NA) to ""."""
    if value is None:
        return ""
    try:
        if pd.isna(value):
            return ""
    except (TypeError, ValueError):
        # Non-scalar values have no single truth value; keep them as text.
        pass
    return str(value)


def build_patient_text(row):
    """Concatenate the report fields of one input row into the patient text.

    The fields ``T_Zusammenfassung``, ``Diagnosen``, ``T_KlinBef`` and
    ``T_Befunde`` are joined with newlines in that order. Absent keys and
    missing cells contribute an empty line, so an empty CSV cell no longer
    puts the literal text ``"nan"`` into the prompt.

    Args:
        row: Mapping-like row offering ``.get`` (e.g. a ``pandas.Series``).

    Returns:
        The four fields as one newline-separated string.
    """
    field_names = ("T_Zusammenfassung", "Diagnosen", "T_KlinBef", "T_Befunde")
    return "\n".join(_field_text(row.get(name, "")) for name in field_names)
|
||
|
||
|
||
# =========================
|
||
# FLATTEN RESULTS FOR CSV
|
||
# =========================
|
||
|
||
def flatten_result(record):
    """Flatten one benchmark record into a single-level dict for CSV export.

    The flat row keeps raw model values, parsed numeric values (never
    clipped), validity flags, and the older column names ``EDSS`` and
    ``subcat_<NAME>`` for backward compatibility. A missing ``validation``
    or ``result`` is treated as an empty dict.
    """
    validation = record.get("validation") or {}
    result = record.get("result") or {}

    identity_keys = (
        "model",
        "iteration",
        "row_index",
        "row_number_in_run",
        "unique_id",
        "MedDatum",
        "success",
        "error",
        "last_error",
    )
    validation_flag_keys = (
        "json_parse_success",
        "required_fields_present",
        "required_schema_success",
        "clinical_range_valid",
        "certainty_present",
    )
    missing_list_keys = (
        "missing_required_fields",
        "missing_subcategory_fields",
    )
    metric_keys = (
        "inference_time_sec",
        "process_cpu_time_sec",
        "rss_before_mb",
        "rss_after_mb",
        "rss_delta_mb",
        "peak_rss_mb",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "raw_content",
    )
    result_keys = (
        # Backward-compatible fields
        "reason",
        "klassifizierbar",
        "raw_certainty_percent",
        "certainty_percent",
        "certainty_percent_is_numeric",
        "certainty_percent_in_valid_range",
        # EDSS raw/numeric/validity fields
        "raw_EDSS",
        "EDSS_numeric",
        "EDSS",  # backward-compatible; same as EDSS_numeric, not clipped
        "EDSS_is_numeric",
        "EDSS_in_valid_range",
        "all_functional_systems_numeric",
        "all_functional_systems_in_valid_range",
    )

    # Insertion order defines the CSV column order.
    flat = {}
    for key in identity_keys:
        flat[key] = record.get(key)
    for key in validation_flag_keys:
        flat[key] = validation.get(key)
    for key in missing_list_keys:
        # Lists are stored as JSON text so they fit into one CSV cell.
        flat[key] = json.dumps(validation.get(key, []), ensure_ascii=False)
    for key in metric_keys:
        flat[key] = record.get(key)
    flat["raw_parsed_output"] = json.dumps(record.get("raw_parsed_output"), ensure_ascii=False)
    for key in result_keys:
        flat[key] = result.get(key)

    raw_by_system = result.get("raw_subcategories", {})
    numeric_by_system = result.get("subcategories", {})
    flags_by_system = result.get("subcategory_validation", {})

    for subcat in FUNCTIONAL_SYSTEM_RANGES:
        raw_value = raw_by_system.get(subcat) if isinstance(raw_by_system, dict) else None
        numeric_value = (
            numeric_by_system.get(subcat) if isinstance(numeric_by_system, dict) else None
        )

        flags = flags_by_system.get(subcat, {}) if isinstance(flags_by_system, dict) else None
        if isinstance(flags, dict):
            is_numeric = flags.get("is_numeric", False)
            in_valid_range = flags.get("in_valid_range", False)
        else:
            is_numeric = False
            in_valid_range = False

        # New transparent columns
        flat[f"raw_subcat_{subcat}"] = raw_value
        flat[f"numeric_subcat_{subcat}"] = numeric_value
        flat[f"subcat_{subcat}_is_numeric"] = is_numeric
        flat[f"subcat_{subcat}_in_valid_range"] = in_valid_range

        # Backward-compatible old column name.
        # This is numeric but NOT clipped.
        flat[f"subcat_{subcat}"] = numeric_value

    return flat
|
||
|
||
|
||
# =========================
|
||
# SUMMARY STATISTICS
|
||
# =========================
|
||
|
||
def summarize_records(records):
    """
    Build a one-row DataFrame of summary statistics for a list of records.

    Each record is flattened with ``flatten_result``; the model name is taken
    from the first flattened row. The summary keeps these aspects apart:
    - JSON/schema validity
    - numeric parse validity
    - clinical range validity
    - out-of-range outputs

    Returns an empty DataFrame when there are no flattened rows.
    """
    df = pd.DataFrame([flatten_result(rec) for rec in records])
    if df.empty:
        return pd.DataFrame()

    def flag_series(column):
        # Boolean view of a flag column (missing values count as False);
        # None when the column does not exist at all.
        if column not in df.columns:
            return None
        return df[column].fillna(False).astype(bool)

    def flag_series_or_false(column):
        # Same as flag_series, but an absent column becomes an all-False series.
        flags = flag_series(column)
        if flags is None:
            return pd.Series(False, index=df.index)
        return flags

    def share_among(numerator_flags, denominator_flags):
        # Fraction of the denominator rows that are also numerator rows;
        # None when the denominator is empty.
        if denominator_flags.sum() > 0:
            return float(numerator_flags.sum() / denominator_flags.sum())
        return None

    summary = {
        "model": df["model"].iloc[0] if "model" in df.columns else None,
        "n_total_responses": len(df),
    }

    # (flag column, count key, rate key) - emitted in exactly this order.
    headline_metrics = (
        ("success", "n_success", "success_rate"),
        ("json_parse_success", "n_json_parse_success", "json_parse_success_rate"),
        ("required_fields_present", "n_required_fields_present", "required_fields_present_rate"),
        ("required_schema_success", "n_required_schema_success", "required_schema_success_rate"),
        ("clinical_range_valid", "n_clinical_range_valid", "clinical_range_valid_rate"),
        ("certainty_present", "n_certainty_present", "certainty_present_rate"),
        ("EDSS_is_numeric", "n_EDSS_numeric", "EDSS_numeric_rate"),
        ("EDSS_in_valid_range", "n_EDSS_in_valid_range", "EDSS_valid_range_rate"),
    )
    for column, count_key, rate_key in headline_metrics:
        flags = flag_series(column)
        if flags is None:
            summary[count_key] = None
            summary[rate_key] = None
        else:
            summary[count_key] = int(flags.sum())
            summary[rate_key] = flags.mean()

    edss_numeric = flag_series("EDSS_is_numeric")
    edss_valid = flag_series("EDSS_in_valid_range")
    edss_flags_available = edss_numeric is not None and edss_valid is not None

    # EDSS out-of-range among numeric EDSS outputs
    if edss_flags_available:
        edss_outside = edss_numeric & (~edss_valid)
        summary["n_EDSS_out_of_range"] = int(edss_outside.sum())
        summary["EDSS_out_of_range_rate_total"] = float(edss_outside.mean())
        summary["EDSS_out_of_range_rate_among_numeric"] = share_among(edss_outside, edss_numeric)

    # Functional system rates, plus two aggregates across all systems.
    any_fs_outside = pd.Series(False, index=df.index)
    all_fs_valid = pd.Series(True, index=df.index)

    for subcat in FUNCTIONAL_SYSTEM_RANGES:
        fs_numeric = flag_series_or_false(f"subcat_{subcat}_is_numeric")
        fs_valid = flag_series_or_false(f"subcat_{subcat}_in_valid_range")
        fs_outside = fs_numeric & (~fs_valid)

        summary[f"n_{subcat}_numeric"] = int(fs_numeric.sum())
        summary[f"{subcat}_numeric_rate"] = float(fs_numeric.mean())

        summary[f"n_{subcat}_in_valid_range"] = int(fs_valid.sum())
        summary[f"{subcat}_valid_range_rate"] = float(fs_valid.mean())

        summary[f"n_{subcat}_out_of_range"] = int(fs_outside.sum())
        summary[f"{subcat}_out_of_range_rate_total"] = float(fs_outside.mean())
        summary[f"{subcat}_out_of_range_rate_among_numeric"] = share_among(fs_outside, fs_numeric)

        any_fs_outside = any_fs_outside | fs_outside
        all_fs_valid = all_fs_valid & fs_valid

    summary["n_any_functional_system_out_of_range"] = int(any_fs_outside.sum())
    summary["any_functional_system_out_of_range_rate_total"] = float(any_fs_outside.mean())

    summary["n_all_functional_systems_in_valid_range"] = int(all_fs_valid.sum())
    summary["all_functional_systems_valid_range_rate"] = float(all_fs_valid.mean())

    # Descriptive statistics; values that cannot be parsed as numbers become NaN.
    described_columns = (
        "inference_time_sec",
        "process_cpu_time_sec",
        "rss_delta_mb",
        "peak_rss_mb",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "certainty_percent",
        "EDSS_numeric",
    )
    for column in described_columns:
        if column not in df.columns:
            continue
        values = pd.to_numeric(df[column], errors="coerce")
        for stat in ("mean", "median", "std", "min", "max"):
            summary[f"{column}_{stat}"] = getattr(values, stat)()

    # Sizes of the two analysis datasets (valid-only vs. all numeric EDSS).
    if edss_flags_available:
        primary_valid_only = edss_numeric & edss_valid
        sensitivity_all_numeric = edss_numeric

        summary["n_primary_valid_only_EDSS"] = int(primary_valid_only.sum())
        summary["primary_valid_only_EDSS_rate"] = float(primary_valid_only.mean())

        summary["n_sensitivity_all_numeric_EDSS"] = int(sensitivity_all_numeric.sum())
        summary["sensitivity_all_numeric_EDSS_rate"] = float(sensitivity_all_numeric.mean())

    return pd.DataFrame([summary])
|
||
|
||
|
||
# =========================
|
||
# ANALYSIS DATASET HELPERS
|
||
# =========================
|
||
|
||
def create_analysis_datasets(records):
    """
    Split flattened records into two EDSS analysis datasets.

    Returns a tuple ``(primary_valid_only, sensitivity_all_numeric)``:

    1. primary_valid_only:
       rows whose EDSS is flagged numeric AND within the valid range.

    2. sensitivity_all_numeric:
       every row whose EDSS is flagged numeric, including out-of-range
       values. No clipping is applied.

    Both are independent copies; with no flattened rows, two empty copies
    are returned.
    """
    flat_rows = [flatten_result(rec) for rec in records]
    df = pd.DataFrame(flat_rows)

    if df.empty:
        return df.copy(), df.copy()

    # Missing flags are treated as False before building the row masks.
    numeric_mask = df["EDSS_is_numeric"].fillna(False).astype(bool)
    in_range_mask = df["EDSS_in_valid_range"].fillna(False).astype(bool)

    primary_valid_only = df[numeric_mask & in_range_mask].copy()
    sensitivity_all_numeric = df[numeric_mask].copy()

    return primary_valid_only, sensitivity_all_numeric
|
||
|
||
|
||
# =========================
|
||
# INCREMENTAL SAVE HELPERS
|
||
# =========================
|
||
|
||
def append_jsonl(path, record):
    """Append *record* as one JSON line to *path* and force it to disk.

    The record is serialized with ``ensure_ascii=False`` and written as
    UTF-8. The handle is flushed and fsync'ed before it is closed.
    """
    line = json.dumps(record, ensure_ascii=False) + "\n"
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(line)
        handle.flush()
        os.fsync(handle.fileno())
|
||
|
||
|
||
def append_csv(path, record):
    """Append *record*, flattened via ``flatten_result``, as one CSV row.

    The header row is written only when *path* does not exist yet.
    """
    row_frame = pd.DataFrame([flatten_result(record)])
    write_header = not Path(path).exists()
    row_frame.to_csv(path, mode="a", header=write_header, index=False)
|
||
|
||
|
||
# =========================
|
||
# MAIN LOOP
|
||
# =========================
|
||
|
||
if __name__ == "__main__":
    # Driver: for every configured model and iteration, run inference on each
    # input row, save results incrementally, then write per-iteration,
    # per-model and combined result/summary files under a timestamped run
    # directory.

    def _flush_incremental(jsonl_path, csv_path, pending):
        """Append every buffered record to both incremental files, then empty the buffer."""
        for buffered in pending:
            append_jsonl(jsonl_path, buffered)
            append_csv(csv_path, buffered)
        pending.clear()

    run_timestamp = now_timestamp()

    results_root = Path(RESULTS_ROOT)
    results_root.mkdir(parents=True, exist_ok=True)

    run_root = results_root / f"run_{run_timestamp}"
    run_root.mkdir(parents=True, exist_ok=True)

    print(f"Results root: {run_root}")

    df = pd.read_csv(INPUT_CSV, sep=";")

    if MAX_ROWS is not None:
        df = df.head(MAX_ROWS)

    total_rows = len(df)

    model_names_for_print = [m["model_name"] for m in MODEL_CONFIGS]

    print(f"Loaded {total_rows} patient records.")
    print(f"Models: {model_names_for_print}")
    print(f"Iterations per model: {NUM_ITERATIONS}")

    all_model_summaries = []

    for model_config in MODEL_CONFIGS:
        model_name = model_config["model_name"]
        safe_model = safe_dir_name(model_name)

        model_dir = run_root / safe_model
        model_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n{'#' * 80}")
        print(f"MODEL: {model_name}")
        print(f"use_response_format: {model_config.get('use_response_format', False)}")
        print(f"temperature: {model_config.get('temperature', TEMPERATURE)}")
        print(f"max_tokens: {model_config.get('max_tokens', MAX_TOKENS)}")
        print(f"Saving to: {model_dir}")
        print(f"{'#' * 80}")

        model_records = []
        model_start = time.perf_counter()

        for iteration in range(1, NUM_ITERATIONS + 1):
            print(f"\n{'=' * 60}")
            print(f"🔄 MODEL {model_name} | ITERATION {iteration}/{NUM_ITERATIONS}")
            print(f"{'=' * 60}")

            iteration_results = []
            iteration_start = time.perf_counter()

            incremental_jsonl_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.jsonl"
            incremental_csv_path = model_dir / f"{safe_model}_iter_{iteration}_{run_timestamp}_incremental.csv"

            print(f"Incremental JSONL: {incremental_jsonl_path}")
            print(f"Incremental CSV: {incremental_csv_path}")

            # Records produced since the last incremental save. The whole
            # buffer is written at each save point so that rows between save
            # points are not missing from the incremental files when
            # SAVE_EVERY_N_ROWS is greater than 1.
            pending_records = []

            for loop_i, (idx, row) in enumerate(df.iterrows(), start=1):
                print(
                    f"\rModel={model_name} | Row {loop_i}/{total_rows} | Iter {iteration}",
                    end="",
                    flush=True
                )

                try:
                    patient_text = build_patient_text(row)

                    record = run_inference(
                        patient_text=patient_text,
                        model_config=model_config
                    )

                    # Attach run bookkeeping to the inference record.
                    record["iteration"] = iteration
                    record["row_index"] = int(idx)
                    record["row_number_in_run"] = int(loop_i)
                    record["unique_id"] = row.get("unique_id", f"row_{idx}")
                    record["MedDatum"] = row.get("MedDatum", None)

                    iteration_results.append(record)
                    model_records.append(record)
                    pending_records.append(record)

                    if loop_i % SAVE_EVERY_N_ROWS == 0:
                        _flush_incremental(
                            incremental_jsonl_path,
                            incremental_csv_path,
                            pending_records,
                        )

                    if record["success"]:
                        res = record["result"] or {}
                        edss_display = res.get("EDSS_numeric", None)
                        edss_valid = res.get("EDSS_in_valid_range", False)

                        print(
                            f" ✅ EDSS={edss_display}, "
                            f"valid_range={edss_valid}, "
                            f"time={record['inference_time_sec']:.2f}s"
                        )
                    else:
                        print(f" ❌ {record.get('error', 'Unknown error')}")

                except Exception as e:
                    print(f"\n⚠️ Row {idx} failed outside inference wrapper: {e}")

                    # Placeholder with the same top-level keys as a normal
                    # record, so the failed row still appears in every output.
                    fallback_record = {
                        "success": False,
                        "error": str(e),
                        "last_error": str(e),
                        "result": None,

                        "validation": {
                            "json_parse_success": False,
                            "required_fields_present": False,
                            "required_schema_success": False,
                            "clinical_range_valid": False,
                            "certainty_present": False,
                            "missing_required_fields": [],
                            "missing_subcategory_fields": [],
                            "EDSS_is_numeric": False,
                            "EDSS_in_valid_range": False,
                        },
                        "raw_parsed_output": None,

                        "model": model_name,
                        "iteration": iteration,
                        "row_index": int(idx),
                        "row_number_in_run": int(loop_i),
                        "unique_id": row.get("unique_id", f"row_{idx}"),
                        "MedDatum": row.get("MedDatum", None),

                        "inference_time_sec": None,
                        "process_cpu_time_sec": None,
                        "rss_before_mb": None,
                        "rss_after_mb": None,
                        "rss_delta_mb": None,
                        "peak_rss_mb": None,

                        "prompt_tokens": None,
                        "completion_tokens": None,
                        "total_tokens": None,

                        "raw_content": None,
                        "raw_response_debug": None,
                    }

                    iteration_results.append(fallback_record)
                    model_records.append(fallback_record)

                    # Failures are saved immediately; flushing the buffer
                    # first keeps the incremental files in row order.
                    pending_records.append(fallback_record)
                    _flush_incremental(
                        incremental_jsonl_path,
                        incremental_csv_path,
                        pending_records,
                    )

                    if STOP_ON_FIRST_ERROR:
                        break

            # Write whatever is still buffered after the last save point.
            _flush_incremental(
                incremental_jsonl_path,
                incremental_csv_path,
                pending_records,
            )

            iteration_elapsed = time.perf_counter() - iteration_start

            # Final full per-iteration JSON
            iter_json_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.json"
            with open(iter_json_path, "w", encoding="utf-8") as f:
                json.dump(iteration_results, f, indent=2, ensure_ascii=False)

            # Final full per-iteration CSV
            iter_csv_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}.csv"
            iter_flat_df = pd.DataFrame([flatten_result(r) for r in iteration_results])
            iter_flat_df.to_csv(iter_csv_path, index=False)

            # Transparent analysis datasets
            primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(iteration_results)

            primary_valid_only_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_primary_valid_only.csv"
            sensitivity_all_numeric_path = model_dir / f"{safe_model}_results_iter_{iteration}_{run_timestamp}_sensitivity_all_numeric.csv"

            primary_valid_only_df.to_csv(primary_valid_only_path, index=False)
            sensitivity_all_numeric_df.to_csv(sensitivity_all_numeric_path, index=False)

            print(f"\n✅ Iteration {iteration} complete.")
            print(f"Incremental JSONL saved to: {incremental_jsonl_path}")
            print(f"Incremental CSV saved to: {incremental_csv_path}")
            print(f"Final JSON saved to: {iter_json_path}")
            print(f"Final CSV saved to: {iter_csv_path}")
            print(f"Primary valid-only CSV saved to: {primary_valid_only_path}")
            print(f"Sensitivity all-numeric CSV: {sensitivity_all_numeric_path}")
            print(
                f"⏱️ Iteration time: {iteration_elapsed:.1f}s "
                f"({iteration_elapsed / max(total_rows, 1):.2f}s/row)"
            )

        model_elapsed = time.perf_counter() - model_start

        # Save all records for this model
        model_json_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.json"
        with open(model_json_path, "w", encoding="utf-8") as f:
            json.dump(model_records, f, indent=2, ensure_ascii=False)

        model_csv_path = model_dir / f"{safe_model}_all_results_{run_timestamp}.csv"
        model_flat_df = pd.DataFrame([flatten_result(r) for r in model_records])
        model_flat_df.to_csv(model_csv_path, index=False)

        # Save model-level analysis datasets
        primary_valid_only_df, sensitivity_all_numeric_df = create_analysis_datasets(model_records)

        model_primary_valid_only_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_primary_valid_only.csv"
        model_sensitivity_all_numeric_path = model_dir / f"{safe_model}_all_results_{run_timestamp}_sensitivity_all_numeric.csv"

        primary_valid_only_df.to_csv(model_primary_valid_only_path, index=False)
        sensitivity_all_numeric_df.to_csv(model_sensitivity_all_numeric_path, index=False)

        # Save model summary
        model_summary_df = summarize_records(model_records)
        model_summary_df["model_total_wall_time_sec"] = model_elapsed
        model_summary_df["model_total_wall_time_min"] = model_elapsed / 60

        model_summary_path = model_dir / f"{safe_model}_summary_{run_timestamp}.csv"
        model_summary_df.to_csv(model_summary_path, index=False)

        all_model_summaries.append(model_summary_df)

        print(f"\n🎉 Model completed: {model_name}")
        print(f"All JSON: {model_json_path}")
        print(f"All CSV: {model_csv_path}")
        print(f"All primary valid-only CSV: {model_primary_valid_only_path}")
        print(f"All sensitivity all-numeric CSV: {model_sensitivity_all_numeric_path}")
        print(f"Summary: {model_summary_path}")
        print(f"Total model time: {model_elapsed / 60:.2f} min")

    # One combined summary file with a row per model.
    if all_model_summaries:
        combined_summary_df = pd.concat(all_model_summaries, ignore_index=True)
        combined_summary_path = run_root / f"all_models_summary_{run_timestamp}.csv"
        combined_summary_df.to_csv(combined_summary_path, index=False)

        print(f"\n📊 Combined summary saved to: {combined_summary_path}")

    print(f"\n🎉 All models and all iterations completed!")
|
||
##
|
||
|