Added Loop for multiple models.
This commit is contained in:
@@ -10,6 +10,11 @@ __pycache__/
|
|||||||
=======
|
=======
|
||||||
/reference/
|
/reference/
|
||||||
*.svg
|
*.svg
|
||||||
|
**/*.csv
|
||||||
|
**/*.json*
|
||||||
|
**/*.txt*
|
||||||
|
**/*.png*
|
||||||
|
*.log
|
||||||
>>>>>>> Stashed changes
|
>>>>>>> Stashed changes
|
||||||
# 2. Ignore virtual environments COMPLETELY
|
# 2. Ignore virtual environments COMPLETELY
|
||||||
# This must come BEFORE the unignore rule
|
# This must come BEFORE the unignore rule
|
||||||
|
|||||||
+3138
-137
File diff suppressed because it is too large
Load Diff
+285
@@ -3118,3 +3118,288 @@ plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
|
|||||||
plt.show()
|
plt.show()
|
||||||
##
|
##
|
||||||
|
|
||||||
|
|
||||||
|
# %% Confusion matrix for one EDSS benchmark result file
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
from sklearn.metrics import confusion_matrix, classification_report
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# CONFIGURATION
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
# Input files and output location for this evaluation run.
REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv"
OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices"

# Iteration to evaluate; also embedded in the plot title and output file names.
TARGET_ITERATION = 1

# Column used to join reference rows with result rows.
MERGE_KEY = "unique_id"

# Ground truth EDSS column in the reference file
GT_EDSS_COL = "EDSS"

# Predicted EDSS column in the result file
PRED_EDSS_COL = "EDSS"

# Category labels "0-1", "1-2", ..., "9-10", in ascending order.
EDSS_LABELS = [f"{low}-{low + 1}" for low in range(10)]
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# HELPERS
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
def safe_filename(name):
    """Return ``str(name)`` with path-hostile characters replaced by ``_``.

    Forward slashes, backslashes, spaces and colons are each replaced by a
    single underscore; every other character is kept unchanged.
    """
    unsafe_to_underscore = str.maketrans({"/": "_", "\\": "_", " ": "_", ":": "_"})
    return str(name).translate(unsafe_to_underscore)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_numeric_column(series):
    """Convert a column to numbers, accepting a decimal comma.

    Each value is rendered as text, commas are replaced by dots, and the
    result is parsed with ``pd.to_numeric``. Entries that cannot be parsed
    become NaN instead of raising.
    """
    as_text = series.astype(str)
    with_decimal_point = as_text.str.replace(",", ".", regex=False)
    return pd.to_numeric(with_decimal_point, errors="coerce")
|
||||||
|
|
||||||
|
|
||||||
|
def categorize_edss(value):
    """Map a numeric EDSS score to its one-point-wide category label.

    Bins are closed on the right, so a score lying exactly on a boundary is
    placed in the lower bin (``1.0 -> "0-1"``, ``1.5 -> "1-2"``); ``0`` is
    placed in ``"0-1"``.

    Args:
        value: Numeric EDSS score, or a missing value (NaN / None).

    Returns:
        One of the labels ``"0-1"`` ... ``"9-10"``, or ``np.nan`` when the
        value is missing or lies outside the closed range 0 to 10.
    """
    if pd.isna(value):
        return np.nan

    # Reject both sides of the range. Previously only values above 10 were
    # rejected, so a negative score (e.g. a -1 "not found" sentinel) was
    # silently counted in the "0-1" bin.
    if value < 0.0 or value > 10.0:
        return np.nan

    # The upper edge of the bin is the score rounded up; 0 itself belongs to
    # the first bin, hence the floor of 1.
    upper = max(1, int(np.ceil(value)))
    return f"{upper - 1}-{upper}"
|
||||||
|
|
||||||
|
|
||||||
|
def load_reference(reference_path):
    """Load the ground-truth reference table and derive EDSS helper columns.

    The file is read as a semicolon-separated CSV. The merge key is converted
    to strings, and two columns are added: ``GT_EDSS_numeric`` (the parsed
    ground-truth score) and ``GT_EDSS_cat`` (its category label).

    Raises:
        ValueError: If the merge-key column or the ground-truth EDSS column
            is missing from the file.
    """
    frame = pd.read_csv(reference_path, sep=";")

    # Check the merge key first, then the score column.
    for required in (MERGE_KEY, GT_EDSS_COL):
        if required not in frame.columns:
            raise ValueError(f"Reference file does not contain column: {required}")

    # String keys so the later join compares like with like.
    frame = frame.assign(**{MERGE_KEY: frame[MERGE_KEY].astype(str)})

    gt_numeric = parse_numeric_column(frame[GT_EDSS_COL])
    frame["GT_EDSS_numeric"] = gt_numeric
    frame["GT_EDSS_cat"] = gt_numeric.apply(categorize_edss)

    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def load_result(result_path):
    """Load one benchmark result file and derive predicted-EDSS columns.

    The file is read as a comma-separated CSV. Rows are kept only when:

    * the ``success`` column (if present) reads, case-insensitively, as
      ``"true"``, ``"1"`` or ``"yes"``;
    * the ``iteration`` column (if present and ``TARGET_ITERATION`` is not
      ``None``) equals ``TARGET_ITERATION``.

    The merge key is converted to strings, and two columns are added:
    ``PRED_EDSS_numeric`` (the parsed prediction) and ``PRED_EDSS_cat`` (its
    category label).

    Args:
        result_path: Path of the result CSV.

    Returns:
        The filtered DataFrame with the derived columns.

    Raises:
        ValueError: If the merge-key column or the predicted EDSS column is
            missing from the file.
    """
    df_res = pd.read_csv(result_path, sep=",")

    for required in (MERGE_KEY, PRED_EDSS_COL):
        if required not in df_res.columns:
            raise ValueError(f"Result file does not contain column: {required}")

    if "success" in df_res.columns:
        succeeded = df_res["success"].astype(str).str.lower().isin(["true", "1", "yes"])
        df_res = df_res[succeeded]

    if TARGET_ITERATION is not None and "iteration" in df_res.columns:
        df_res = df_res[df_res["iteration"] == TARGET_ITERATION]

    # Take one explicit copy *after* filtering: the column assignments below
    # then write to an independent frame instead of a boolean-indexed slice,
    # which avoids chained-assignment warnings/ambiguity on older pandas.
    df_res = df_res.copy()

    # String keys so the later join compares like with like.
    df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str)
    df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL])
    df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss)

    return df_res
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_name(df_res, result_path):
    """Pick a display name for the model behind a result table.

    Returns the first non-missing entry of the ``model`` column as a string.
    When the column is absent or holds no non-missing value, the result
    file's name without its extension is returned instead.
    """
    if "model" in df_res.columns:
        named = df_res["model"].dropna()
        if not named.empty:
            return str(named.iloc[0])

    # Fall back to the file name.
    return Path(result_path).stem
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(cm, model_name, output_path):
    """Draw a confusion matrix as an annotated heatmap, save it and show it.

    Args:
        cm: Square matrix of counts whose rows and columns follow the order
            of ``EDSS_LABELS``. The y-axis is labelled as ground truth and
            the x-axis as the LLM-generated score.
            NOTE(review): cells are formatted with ``fmt="d"``, so integer
            counts are assumed - confirm if normalised matrices are passed.
        model_name: Name shown in the plot title.
        output_path: Destination file for the saved figure (300 dpi).
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=EDSS_LABELS,
        yticklabels=EDSS_LABELS,
        ax=ax,
    )

    # The heatmap's mesh is the first collection on the axes; its colorbar
    # is the one created by seaborn.
    cbar = ax.collections[0].colorbar
    cbar.set_label("Number of Cases", rotation=270, labelpad=20)

    ax.set_xlabel("LLM Generated EDSS")
    ax.set_ylabel("Ground Truth EDSS")
    ax.set_title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}")

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.show()

    # Release the figure so repeated calls (e.g. one per model) do not
    # accumulate open figures in memory.
    plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# MAIN
|
||||||
|
# =========================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Evaluate one result file against the reference: load both tables, join
    # them on MERGE_KEY, build the confusion matrix and classification report
    # over the EDSS categories, and write plot, matrix, report and the
    # evaluated rows into OUTPUT_DIR.

    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading reference:")
    print(REFERENCE_PATH)

    df_ref = load_reference(REFERENCE_PATH)

    print(f"Reference rows: {len(df_ref)}")
    print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}")

    print("\nLoading result:")
    print(RESULT_PATH)

    df_res = load_result(RESULT_PATH)

    model_name = get_model_name(df_res, RESULT_PATH)
    safe_model = safe_filename(model_name)

    print(f"Model: {model_name}")
    print(f"Result rows after filtering: {len(df_res)}")

    # Keep one prediction per key. The sort must be stable: with the default
    # (unstable) algorithm the relative order of rows sharing a key is not
    # guaranteed, so keep="first" could retain an arbitrary duplicate instead
    # of the first one in file order.
    before_dedup = len(df_res)
    df_res = (
        df_res
        .sort_values(by=[MERGE_KEY], kind="stable")
        .drop_duplicates(subset=[MERGE_KEY], keep="first")
    )
    after_dedup = len(df_res)

    if before_dedup != after_dedup:
        print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}")

    # Columns present in both tables (other than the key) get these suffixes.
    df_merged = df_ref.merge(
        df_res,
        on=MERGE_KEY,
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    print(f"Merged rows: {len(df_merged)}")

    # Only rows where both scores fall into a category can be evaluated.
    df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy()

    print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}")

    if df_eval.empty:
        raise ValueError("No evaluable rows after merging and EDSS filtering.")

    cm = confusion_matrix(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS,
    )

    suffix = f"iter_{TARGET_ITERATION}"

    plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png"
    cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv"
    report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt"
    merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv"

    plot_confusion_matrix(cm, model_name, plot_path)

    cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS)
    cm_df.index.name = "Ground Truth EDSS"
    cm_df.columns.name = "LLM Generated EDSS"
    cm_df.to_csv(cm_csv_path)

    report = classification_report(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS,
        zero_division=0,
    )

    with open(report_txt_path, "w", encoding="utf-8") as f:
        f.write(f"Model: {model_name}\n")
        f.write(f"Result file: {RESULT_PATH}\n")
        f.write(f"Target iteration: {TARGET_ITERATION}\n")
        f.write(f"Merged rows: {len(df_merged)}\n")
        f.write(f"Evaluable rows: {len(df_eval)}\n\n")
        f.write("Classification Report:\n")
        f.write(report)
        f.write("\n\nConfusion Matrix Raw Counts:\n")
        f.write(cm_df.to_string())

    # Candidate columns for the per-row export; any that are absent from the
    # merged table are dropped just below.
    keep_cols = [
        MERGE_KEY,
        "MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum",
        "GT_EDSS_numeric",
        "PRED_EDSS_numeric",
        "GT_EDSS_cat",
        "PRED_EDSS_cat",
        "model",
        "iteration",
        "success",
        "inference_time_sec",
        "certainty_percent",
        "reason",
    ]

    keep_cols = [col for col in keep_cols if col in df_eval.columns]
    df_eval[keep_cols].to_csv(merged_csv_path, index=False)

    print("\nClassification Report:")
    print(report)

    print("\nConfusion Matrix Raw Counts:")
    print(cm_df)

    print("\nSaved files:")
    print(f"Plot: {plot_path}")
    print(f"Confusion matrix: {cm_csv_path}")
    print(f"Report: {report_txt_path}")
    print(f"Merged rows: {merged_csv_path}")

    print("\nDone.")
|
||||||
|
##
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user