Added Loop for multiple models.
This commit is contained in:
@@ -10,6 +10,11 @@ __pycache__/
|
||||
=======
|
||||
/reference/
|
||||
*.svg
|
||||
**/*.csv
|
||||
**/*.json*
|
||||
**/*.txt*
|
||||
**/*.png*
|
||||
*.log
|
||||
>>>>>>> Stashed changes
|
||||
# 2. Ignore virtual environments COMPLETELY
|
||||
# This must come BEFORE the unignore rule
|
||||
|
||||
+3157
-156
File diff suppressed because it is too large
Load Diff
+285
@@ -3118,3 +3118,288 @@ plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
|
||||
plt.show()
|
||||
##
|
||||
|
||||
|
||||
# %% Confusion matrix for one EDSS benchmark result file
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
from sklearn.metrics import confusion_matrix, classification_report
|
||||
|
||||
|
||||
# =========================
|
||||
# CONFIGURATION
|
||||
# =========================
|
||||
|
||||
REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
|
||||
|
||||
RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv"
|
||||
|
||||
OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices"
|
||||
|
||||
TARGET_ITERATION = 1
|
||||
|
||||
MERGE_KEY = "unique_id"
|
||||
|
||||
# Ground truth EDSS column in the reference file
|
||||
GT_EDSS_COL = "EDSS"
|
||||
|
||||
# Predicted EDSS column in the result file
|
||||
PRED_EDSS_COL = "EDSS"
|
||||
|
||||
EDSS_LABELS = [
|
||||
"0-1", "1-2", "2-3", "3-4", "4-5",
|
||||
"5-6", "6-7", "7-8", "8-9", "9-10"
|
||||
]
|
||||
|
||||
|
||||
# =========================
|
||||
# HELPERS
|
||||
# =========================
|
||||
|
||||
def safe_filename(name):
|
||||
return (
|
||||
str(name)
|
||||
.replace("/", "_")
|
||||
.replace("\\", "_")
|
||||
.replace(" ", "_")
|
||||
.replace(":", "_")
|
||||
)
|
||||
|
||||
|
||||
def parse_numeric_column(series):
|
||||
return pd.to_numeric(
|
||||
series.astype(str).str.replace(",", ".", regex=False),
|
||||
errors="coerce"
|
||||
)
|
||||
|
||||
|
||||
def categorize_edss(value):
|
||||
if pd.isna(value):
|
||||
return np.nan
|
||||
elif value <= 1.0:
|
||||
return "0-1"
|
||||
elif value <= 2.0:
|
||||
return "1-2"
|
||||
elif value <= 3.0:
|
||||
return "2-3"
|
||||
elif value <= 4.0:
|
||||
return "3-4"
|
||||
elif value <= 5.0:
|
||||
return "4-5"
|
||||
elif value <= 6.0:
|
||||
return "5-6"
|
||||
elif value <= 7.0:
|
||||
return "6-7"
|
||||
elif value <= 8.0:
|
||||
return "7-8"
|
||||
elif value <= 9.0:
|
||||
return "8-9"
|
||||
elif value <= 10.0:
|
||||
return "9-10"
|
||||
else:
|
||||
return np.nan
|
||||
|
||||
|
||||
def load_reference(reference_path):
|
||||
df_ref = pd.read_csv(reference_path, sep=";")
|
||||
|
||||
if MERGE_KEY not in df_ref.columns:
|
||||
raise ValueError(f"Reference file does not contain column: {MERGE_KEY}")
|
||||
|
||||
if GT_EDSS_COL not in df_ref.columns:
|
||||
raise ValueError(f"Reference file does not contain column: {GT_EDSS_COL}")
|
||||
|
||||
df_ref = df_ref.copy()
|
||||
df_ref[MERGE_KEY] = df_ref[MERGE_KEY].astype(str)
|
||||
|
||||
df_ref["GT_EDSS_numeric"] = parse_numeric_column(df_ref[GT_EDSS_COL])
|
||||
df_ref["GT_EDSS_cat"] = df_ref["GT_EDSS_numeric"].apply(categorize_edss)
|
||||
|
||||
return df_ref
|
||||
|
||||
|
||||
def load_result(result_path):
|
||||
df_res = pd.read_csv(result_path, sep=",")
|
||||
|
||||
if MERGE_KEY not in df_res.columns:
|
||||
raise ValueError(f"Result file does not contain column: {MERGE_KEY}")
|
||||
|
||||
if PRED_EDSS_COL not in df_res.columns:
|
||||
raise ValueError(f"Result file does not contain column: {PRED_EDSS_COL}")
|
||||
|
||||
df_res = df_res.copy()
|
||||
df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str)
|
||||
|
||||
if "success" in df_res.columns:
|
||||
df_res = df_res[
|
||||
df_res["success"].astype(str).str.lower().isin(["true", "1", "yes"])
|
||||
]
|
||||
|
||||
if TARGET_ITERATION is not None and "iteration" in df_res.columns:
|
||||
df_res = df_res[df_res["iteration"] == TARGET_ITERATION]
|
||||
|
||||
df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL])
|
||||
df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss)
|
||||
|
||||
return df_res
|
||||
|
||||
|
||||
def get_model_name(df_res, result_path):
|
||||
if "model" in df_res.columns and df_res["model"].notna().any():
|
||||
return str(df_res["model"].dropna().iloc[0])
|
||||
|
||||
return Path(result_path).stem
|
||||
|
||||
|
||||
def plot_confusion_matrix(cm, model_name, output_path):
|
||||
plt.figure(figsize=(10, 8))
|
||||
|
||||
ax = sns.heatmap(
|
||||
cm,
|
||||
annot=True,
|
||||
fmt="d",
|
||||
cmap="Blues",
|
||||
xticklabels=EDSS_LABELS,
|
||||
yticklabels=EDSS_LABELS
|
||||
)
|
||||
|
||||
cbar = ax.collections[0].colorbar
|
||||
cbar.set_label("Number of Cases", rotation=270, labelpad=20)
|
||||
|
||||
plt.xlabel("LLM Generated EDSS")
|
||||
plt.ylabel("Ground Truth EDSS")
|
||||
plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(output_path, dpi=300, bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
|
||||
# =========================
|
||||
# MAIN
|
||||
# =========================
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
output_dir = Path(OUTPUT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print("Loading reference:")
|
||||
print(REFERENCE_PATH)
|
||||
|
||||
df_ref = load_reference(REFERENCE_PATH)
|
||||
|
||||
print(f"Reference rows: {len(df_ref)}")
|
||||
print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}")
|
||||
|
||||
print("\nLoading result:")
|
||||
print(RESULT_PATH)
|
||||
|
||||
df_res = load_result(RESULT_PATH)
|
||||
|
||||
model_name = get_model_name(df_res, RESULT_PATH)
|
||||
safe_model = safe_filename(model_name)
|
||||
|
||||
print(f"Model: {model_name}")
|
||||
print(f"Result rows after filtering: {len(df_res)}")
|
||||
|
||||
before_dedup = len(df_res)
|
||||
df_res = df_res.sort_values(by=[MERGE_KEY]).drop_duplicates(subset=[MERGE_KEY], keep="first")
|
||||
after_dedup = len(df_res)
|
||||
|
||||
if before_dedup != after_dedup:
|
||||
print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}")
|
||||
|
||||
df_merged = df_ref.merge(
|
||||
df_res,
|
||||
on=MERGE_KEY,
|
||||
how="inner",
|
||||
suffixes=("_gt", "_pred")
|
||||
)
|
||||
|
||||
print(f"Merged rows: {len(df_merged)}")
|
||||
|
||||
df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy()
|
||||
|
||||
print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}")
|
||||
|
||||
if df_eval.empty:
|
||||
raise ValueError("No evaluable rows after merging and EDSS filtering.")
|
||||
|
||||
cm = confusion_matrix(
|
||||
df_eval["GT_EDSS_cat"],
|
||||
df_eval["PRED_EDSS_cat"],
|
||||
labels=EDSS_LABELS
|
||||
)
|
||||
|
||||
suffix = f"iter_{TARGET_ITERATION}"
|
||||
|
||||
plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png"
|
||||
cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv"
|
||||
report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt"
|
||||
merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv"
|
||||
|
||||
plot_confusion_matrix(cm, model_name, plot_path)
|
||||
|
||||
cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS)
|
||||
cm_df.index.name = "Ground Truth EDSS"
|
||||
cm_df.columns.name = "LLM Generated EDSS"
|
||||
cm_df.to_csv(cm_csv_path)
|
||||
|
||||
report = classification_report(
|
||||
df_eval["GT_EDSS_cat"],
|
||||
df_eval["PRED_EDSS_cat"],
|
||||
labels=EDSS_LABELS,
|
||||
zero_division=0
|
||||
)
|
||||
|
||||
with open(report_txt_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"Model: {model_name}\n")
|
||||
f.write(f"Result file: {RESULT_PATH}\n")
|
||||
f.write(f"Target iteration: {TARGET_ITERATION}\n")
|
||||
f.write(f"Merged rows: {len(df_merged)}\n")
|
||||
f.write(f"Evaluable rows: {len(df_eval)}\n\n")
|
||||
f.write("Classification Report:\n")
|
||||
f.write(report)
|
||||
f.write("\n\nConfusion Matrix Raw Counts:\n")
|
||||
f.write(cm_df.to_string())
|
||||
|
||||
keep_cols = [
|
||||
MERGE_KEY,
|
||||
"MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum",
|
||||
"GT_EDSS_numeric",
|
||||
"PRED_EDSS_numeric",
|
||||
"GT_EDSS_cat",
|
||||
"PRED_EDSS_cat",
|
||||
"model",
|
||||
"iteration",
|
||||
"success",
|
||||
"inference_time_sec",
|
||||
"certainty_percent",
|
||||
"reason",
|
||||
]
|
||||
|
||||
keep_cols = [col for col in keep_cols if col in df_eval.columns]
|
||||
df_eval[keep_cols].to_csv(merged_csv_path, index=False)
|
||||
|
||||
print("\nClassification Report:")
|
||||
print(report)
|
||||
|
||||
print("\nConfusion Matrix Raw Counts:")
|
||||
print(cm_df)
|
||||
|
||||
print("\nSaved files:")
|
||||
print(f"Plot: {plot_path}")
|
||||
print(f"Confusion matrix: {cm_csv_path}")
|
||||
print(f"Report: {report_txt_path}")
|
||||
print(f"Merged rows: {merged_csv_path}")
|
||||
|
||||
print("\nDone.")
|
||||
##
|
||||
|
||||
|
||||
Reference in New Issue
Block a user