Added Loop for multiple models.

This commit is contained in:
2026-05-16 16:50:33 +02:00
parent f6ec60e685
commit 590f2cd68e
4 changed files with 3447 additions and 159 deletions
+5
View File
@@ -10,6 +10,11 @@ __pycache__/
=======
/reference/
*.svg
**/*.csv
**/*.json*
**/*.txt*
**/*.png*
*.log
>>>>>>> Stashed changes
# 2. Ignore virtual environments COMPLETELY
# This must come BEFORE the unignore rule
-3
View File
@@ -216,6 +216,3 @@ if __name__ == "__main__":
# %% name
eXXXXXXXX
##
+3140 -139
View File
File diff suppressed because it is too large Load Diff
+285
View File
@@ -3118,3 +3118,288 @@ plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
plt.show()
##
# %% Confusion matrix for one EDSS benchmark result file
import os
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
# =========================
# CONFIGURATION
# =========================
REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv"
OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices"
TARGET_ITERATION = 1
MERGE_KEY = "unique_id"
# Ground truth EDSS column in the reference file
GT_EDSS_COL = "EDSS"
# Predicted EDSS column in the result file
PRED_EDSS_COL = "EDSS"
EDSS_LABELS = [
"0-1", "1-2", "2-3", "3-4", "4-5",
"5-6", "6-7", "7-8", "8-9", "9-10"
]
# =========================
# HELPERS
# =========================
def safe_filename(name):
return (
str(name)
.replace("/", "_")
.replace("\\", "_")
.replace(" ", "_")
.replace(":", "_")
)
def parse_numeric_column(series):
return pd.to_numeric(
series.astype(str).str.replace(",", ".", regex=False),
errors="coerce"
)
def categorize_edss(value):
if pd.isna(value):
return np.nan
elif value <= 1.0:
return "0-1"
elif value <= 2.0:
return "1-2"
elif value <= 3.0:
return "2-3"
elif value <= 4.0:
return "3-4"
elif value <= 5.0:
return "4-5"
elif value <= 6.0:
return "5-6"
elif value <= 7.0:
return "6-7"
elif value <= 8.0:
return "7-8"
elif value <= 9.0:
return "8-9"
elif value <= 10.0:
return "9-10"
else:
return np.nan
def load_reference(reference_path):
df_ref = pd.read_csv(reference_path, sep=";")
if MERGE_KEY not in df_ref.columns:
raise ValueError(f"Reference file does not contain column: {MERGE_KEY}")
if GT_EDSS_COL not in df_ref.columns:
raise ValueError(f"Reference file does not contain column: {GT_EDSS_COL}")
df_ref = df_ref.copy()
df_ref[MERGE_KEY] = df_ref[MERGE_KEY].astype(str)
df_ref["GT_EDSS_numeric"] = parse_numeric_column(df_ref[GT_EDSS_COL])
df_ref["GT_EDSS_cat"] = df_ref["GT_EDSS_numeric"].apply(categorize_edss)
return df_ref
def load_result(result_path):
df_res = pd.read_csv(result_path, sep=",")
if MERGE_KEY not in df_res.columns:
raise ValueError(f"Result file does not contain column: {MERGE_KEY}")
if PRED_EDSS_COL not in df_res.columns:
raise ValueError(f"Result file does not contain column: {PRED_EDSS_COL}")
df_res = df_res.copy()
df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str)
if "success" in df_res.columns:
df_res = df_res[
df_res["success"].astype(str).str.lower().isin(["true", "1", "yes"])
]
if TARGET_ITERATION is not None and "iteration" in df_res.columns:
df_res = df_res[df_res["iteration"] == TARGET_ITERATION]
df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL])
df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss)
return df_res
def get_model_name(df_res, result_path):
if "model" in df_res.columns and df_res["model"].notna().any():
return str(df_res["model"].dropna().iloc[0])
return Path(result_path).stem
def plot_confusion_matrix(cm, model_name, output_path):
plt.figure(figsize=(10, 8))
ax = sns.heatmap(
cm,
annot=True,
fmt="d",
cmap="Blues",
xticklabels=EDSS_LABELS,
yticklabels=EDSS_LABELS
)
cbar = ax.collections[0].colorbar
cbar.set_label("Number of Cases", rotation=270, labelpad=20)
plt.xlabel("LLM Generated EDSS")
plt.ylabel("Ground Truth EDSS")
plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}")
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.show()
# =========================
# MAIN
# =========================
if __name__ == "__main__":
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
print("Loading reference:")
print(REFERENCE_PATH)
df_ref = load_reference(REFERENCE_PATH)
print(f"Reference rows: {len(df_ref)}")
print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}")
print("\nLoading result:")
print(RESULT_PATH)
df_res = load_result(RESULT_PATH)
model_name = get_model_name(df_res, RESULT_PATH)
safe_model = safe_filename(model_name)
print(f"Model: {model_name}")
print(f"Result rows after filtering: {len(df_res)}")
before_dedup = len(df_res)
df_res = df_res.sort_values(by=[MERGE_KEY]).drop_duplicates(subset=[MERGE_KEY], keep="first")
after_dedup = len(df_res)
if before_dedup != after_dedup:
print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}")
df_merged = df_ref.merge(
df_res,
on=MERGE_KEY,
how="inner",
suffixes=("_gt", "_pred")
)
print(f"Merged rows: {len(df_merged)}")
df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy()
print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}")
if df_eval.empty:
raise ValueError("No evaluable rows after merging and EDSS filtering.")
cm = confusion_matrix(
df_eval["GT_EDSS_cat"],
df_eval["PRED_EDSS_cat"],
labels=EDSS_LABELS
)
suffix = f"iter_{TARGET_ITERATION}"
plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png"
cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv"
report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt"
merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv"
plot_confusion_matrix(cm, model_name, plot_path)
cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS)
cm_df.index.name = "Ground Truth EDSS"
cm_df.columns.name = "LLM Generated EDSS"
cm_df.to_csv(cm_csv_path)
report = classification_report(
df_eval["GT_EDSS_cat"],
df_eval["PRED_EDSS_cat"],
labels=EDSS_LABELS,
zero_division=0
)
with open(report_txt_path, "w", encoding="utf-8") as f:
f.write(f"Model: {model_name}\n")
f.write(f"Result file: {RESULT_PATH}\n")
f.write(f"Target iteration: {TARGET_ITERATION}\n")
f.write(f"Merged rows: {len(df_merged)}\n")
f.write(f"Evaluable rows: {len(df_eval)}\n\n")
f.write("Classification Report:\n")
f.write(report)
f.write("\n\nConfusion Matrix Raw Counts:\n")
f.write(cm_df.to_string())
keep_cols = [
MERGE_KEY,
"MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum",
"GT_EDSS_numeric",
"PRED_EDSS_numeric",
"GT_EDSS_cat",
"PRED_EDSS_cat",
"model",
"iteration",
"success",
"inference_time_sec",
"certainty_percent",
"reason",
]
keep_cols = [col for col in keep_cols if col in df_eval.columns]
df_eval[keep_cols].to_csv(merged_csv_path, index=False)
print("\nClassification Report:")
print(report)
print("\nConfusion Matrix Raw Counts:")
print(cm_df)
print("\nSaved files:")
print(f"Plot: {plot_path}")
print(f"Confusion matrix: {cm_csv_path}")
print(f"Report: {report_txt_path}")
print(f"Merged rows: {merged_csv_path}")
print("\nDone.")
##