# Source listing: 416 lines, 15 KiB, Python
# %% Confirm EDSS missing
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
def clean_series(s):
    """Normalize a Series to stripped, lower-cased strings.

    Values are cast with ``astype(str)`` first, so NaN becomes ``"nan"``
    and ``None`` becomes ``"none"`` in the result.
    """
    return s.astype(str).str.strip().str.lower()


def gt_edss_audit(ground_truth_path, edss_col="EDSS", sep=";"):
    """Load the ground-truth CSV and log how many EDSS values are missing.

    Adds a ``key`` column (``unique_id + "_" + MedDatum``, both normalized
    with :func:`clean_series`). If ``edss_col`` exists, also adds a numeric
    ``_edss_num`` column in which decimal commas are turned into dots and
    unparseable values become NaN. Missing values and duplicate keys are
    then reported on stdout.

    Parameters
    ----------
    ground_truth_path : str or path-like
        Path to the ground-truth CSV.
    edss_col : str, default "EDSS"
        Name of the EDSS column in the CSV.
    sep : str, default ";"
        Field separator passed to ``pd.read_csv``.

    Returns
    -------
    pandas.DataFrame
        The ground truth with ``key`` added (and ``_edss_num`` when the
        EDSS column is present). When ``edss_col`` is absent, ``_edss_num``
        is NOT created, so code that reads it afterwards will fail.
    """
    df_gt = pd.read_csv(ground_truth_path, sep=sep)

    # normalize keys the same way the prediction side does
    df_gt["unique_id"] = clean_series(df_gt["unique_id"])
    df_gt["MedDatum"] = clean_series(df_gt["MedDatum"])
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]

    print("GT rows:", len(df_gt))
    print("GT unique keys:", df_gt["key"].nunique())

    if edss_col not in df_gt.columns:
        print(f"EDSS column '{edss_col}' not found in GT columns:", df_gt.columns.tolist())
        return df_gt

    # IMPORTANT: parse EDSS robustly (German decimal commas etc.)
    df_gt["_edss_num"] = pd.to_numeric(
        df_gt[edss_col].astype(str).str.replace(",", ".", regex=False).str.strip(),
        errors="coerce",
    )
    missing = df_gt["_edss_num"].isna()

    print(f"GT missing EDSS rows (numeric-coerce): {missing.sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[missing, 'key'].nunique()}")

    # duplicates on key (these would break a many_to_one merge later)
    dup = df_gt["key"].duplicated(keep=False)
    print("GT duplicate-key rows:", dup.sum())
    if dup.any():
        print("GT duplicate keys:", df_gt.loc[dup, "key"].nunique())
        print("Duplicate-key rows with missing EDSS:", (dup & missing).sum())

        # show the worst offenders
        print("\nTop duplicate keys (by count):")
        print(df_gt.loc[dup, "key"].value_counts().head(10))

    return df_gt
|
|
|
|
# Audit the GT once; the following cells reuse ``df_gt`` and read its
# ``key`` and ``_edss_num`` columns.
df_gt = gt_edss_audit(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv",
    edss_col="EDSS",
)
|
|
|
|
##
|
|
|
|
|
|
|
|
|
|
# %% trace missing ones
|
|
|
|
import json, glob, os
|
|
import pandas as pd
|
|
|
|
def load_preds(json_dir_path):
    """Collect prediction keys from every ``*.json`` file in a directory.

    Each file is read with ``json.load`` and iterated as a list of entries.
    Only entries whose ``success`` flag is truthy are kept. ``unique_id`` and
    ``MedDatum`` from the entry's ``result`` are normalized with
    ``str(...).strip().lower()``, so missing values become ``"none"``.

    Parameters
    ----------
    json_dir_path : str or path-like
        Directory containing the prediction JSON files.

    Returns
    -------
    pandas.DataFrame
        One row per successful entry, with columns ``unique_id``,
        ``MedDatum``, ``file`` (basename of the source JSON) and ``key``
        (``unique_id + "_" + MedDatum``). The frame still has these columns
        when nothing was found.
    """
    rows = []
    # sorted(): glob order is arbitrary; this keeps row order reproducible
    for file_path in sorted(glob.glob(os.path.join(json_dir_path, "*.json"))):
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        file_name = os.path.basename(file_path)

        for entry in data:
            if not entry.get("success"):
                continue
            # ``or {}`` also covers a missing key or an explicit "result": null
            res = entry.get("result") or {}
            rows.append({
                "unique_id": str(res.get("unique_id")).strip().lower(),
                "MedDatum": str(res.get("MedDatum")).strip().lower(),
                "file": file_name,
            })

    # explicit columns so an empty result still supports the key concat below
    df_pred = pd.DataFrame(rows, columns=["unique_id", "MedDatum", "file"])
    df_pred["key"] = df_pred["unique_id"] + "_" + df_pred["MedDatum"]
    return df_pred
|
|
|
|
df_pred = load_preds(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration"
)
print("Pred rows:", len(df_pred))
print("Pred unique keys:", df_pred["key"].nunique())

# Relies on ``df_gt`` from the audit cell: it must carry ``key`` and
# ``_edss_num`` (the latter only exists when the GT has an EDSS column).
missing_gt_keys = set(df_gt.loc[df_gt["_edss_num"].isna(), "key"])

df_pred["gt_key_missing_edss"] = df_pred["key"].isin(missing_gt_keys)
_flagged_preds = df_pred[df_pred["gt_key_missing_edss"]]

print("Pred rows whose GT key has missing EDSS:", df_pred["gt_key_missing_edss"].sum())
print("Unique keys (among preds) whose GT EDSS missing:", _flagged_preds["key"].nunique())

print("\nTop files contributing to missing-GT-EDSS rows:")
print(_flagged_preds["file"].value_counts().head(20))

# a single GT key can be predicted many times across files
print("\nTop keys replicated in predictions (why count inflates):")
print(_flagged_preds["key"].value_counts().head(20))
|
|
|
|
|
|
##
|
|
|
|
|
|
# %% verify
|
|
|
|
# Left-join every prediction onto the numeric-coerced GT EDSS.
# validate="many_to_one" makes pandas raise if the GT side has duplicate
# keys, which is intended.
merged = df_pred.merge(df_gt[["key", "_edss_num"]], how="left", on="key", validate="many_to_one")

print("Merged rows:", len(merged))
# NaN here covers both "key not in GT" and "GT EDSS missing/unparseable"
print("Merged missing GT EDSS:", merged["_edss_num"].isna().sum())
|
|
|
|
|
|
##
|
|
|
|
|
|
# %% 1json (rewritten with robust parsing + detailed data log)
|
|
import pandas as pd
|
|
import numpy as np
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
from matplotlib.patches import Patch
|
|
from matplotlib.lines import Line2D
|
|
|
|
def plot_single_json_error_analysis_with_log(
    json_file_path,
    ground_truth_path,
    edss_gt_col="EDSS",
    min_bin_count=5,
):
    """Plot mean absolute EDSS error per LLM-confidence bin for one JSON file.

    Loads the ``;``-separated ground-truth CSV and one prediction JSON, joins
    them on ``unique_id + "_" + MedDatum`` (both stripped and lower-cased) and
    draws a bar chart of the mean absolute EDSS error per confidence bracket,
    with SEM error bars and a linear trend line. Every stage prints a
    diagnostic log explaining how many rows survive.

    Confidence brackets are left-closed, ``[0, 70)``, ``[70, 80)``,
    ``[80, 90)`` and ``[90, 100]``, so each boundary falls in the bracket its
    label names (exactly 70 is "Moderate", exactly 90 is "Very High").
    Confidences outside ``[0, 100]`` are logged and excluded.

    Parameters
    ----------
    json_file_path : str or path-like
        Prediction JSON: a list of entries with ``success`` and ``result``
        (``unique_id``, ``MedDatum``, ``EDSS``, ``certainty_percent``).
    ground_truth_path : str or path-like
        Ground-truth CSV with ``unique_id``, ``MedDatum`` and ``edss_gt_col``.
    edss_gt_col : str, default "EDSS"
        Name of the EDSS column in the ground-truth CSV.
    min_bin_count : int, default 5
        Bins with fewer rows than this trigger a printed warning.

    Returns
    -------
    None
        Returns early, without plotting, when no usable rows remain.

    Raises
    ------
    ValueError
        If the ground truth lacks a required column, or (as pandas'
        ``MergeError``) if it contains duplicate keys.
    """

    def norm_str(x):
        # normalize identifiers and dates consistently
        return str(x).strip().lower()

    def parse_edss(x):
        # robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc.
        if x is None:
            return np.nan
        s = str(x).strip()
        if s == "" or s.lower() in {"nan", "none", "null"}:
            return np.nan
        s = s.replace(",", ".")
        return pd.to_numeric(s, errors="coerce")

    print("\n" + "="*80)
    print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)")
    print("="*80)
    print(f"JSON: {json_file_path}")
    print(f"GT: {ground_truth_path}")

    # ------------------------------------------------------------------
    # 1) Load Ground Truth
    # ------------------------------------------------------------------
    df_gt = pd.read_csv(ground_truth_path, sep=";")

    required_gt_cols = {"unique_id", "MedDatum", edss_gt_col}
    missing_cols = required_gt_cols - set(df_gt.columns)
    if missing_cols:
        raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}")

    df_gt["unique_id"] = df_gt["unique_id"].map(norm_str)
    df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str)
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]

    # Robust EDSS parsing (important!)
    df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss)

    # GT logs
    print("\n--- GT LOG ---")
    print(f"GT rows: {len(df_gt)}")
    print(f"GT unique keys: {df_gt['key'].nunique()}")
    gt_dup = df_gt["key"].duplicated(keep=False).sum()
    print(f"GT duplicate-key rows: {gt_dup}")
    print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}")

    if gt_dup > 0:
        print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:")
        print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10))

    # ------------------------------------------------------------------
    # 2) Load Predictions from the specific JSON
    # ------------------------------------------------------------------
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_entries = len(data)
    success_entries = sum(1 for e in data if e.get("success"))

    all_preds = []
    skipped = {
        "not_success": 0,
        "missing_uid_or_date": 0,
        "missing_edss": 0,
        "missing_conf": 0,
    }

    for entry in data:
        if not entry.get("success"):
            skipped["not_success"] += 1
            continue

        # ``or {}`` also guards against an explicit "result": null
        res = entry.get("result") or {}
        uid = res.get("unique_id")
        md = res.get("MedDatum")

        if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "":
            skipped["missing_uid_or_date"] += 1
            continue

        edss_pred = parse_edss(res.get("EDSS"))
        conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce")

        # Rows with missing EDSS/confidence are kept (and counted) so the log
        # can report them; they are filtered out before computing statistics.
        if pd.isna(edss_pred):
            skipped["missing_edss"] += 1
        if pd.isna(conf):
            skipped["missing_conf"] += 1

        uid_norm, md_norm = norm_str(uid), norm_str(md)
        all_preds.append({
            "unique_id": uid_norm,
            "MedDatum": md_norm,
            "key": uid_norm + "_" + md_norm,
            "EDSS_pred": edss_pred,
            "confidence": conf,
        })

    df_pred = pd.DataFrame(all_preds)

    # Pred logs
    print("\n--- PRED LOG ---")
    print(f"JSON total entries: {total_entries}")
    print(f"JSON success entries: {success_entries}")
    print(f"Pred rows loaded (success + has keys): {len(df_pred)}")
    if len(df_pred) == 0:
        print("[ERROR] No usable prediction rows found. Nothing to plot.")
        return

    print(f"Pred unique keys: {df_pred['key'].nunique()}")
    print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}")
    print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}")
    print("Skipped counts:", skipped)

    # Are keys duplicated within this JSON? (several entries may share a key)
    key_counts = df_pred["key"].value_counts()
    dup_pred_keys = (key_counts > 1).sum()
    max_rep = int(key_counts.max())
    print(f"Keys with >1 prediction in this JSON: {dup_pred_keys}")
    print(f"Max repetitions of a single key in this JSON: {max_rep}")
    if max_rep > 1:
        print("Top repeated keys in this JSON:")
        print(key_counts.head(10))

    # ------------------------------------------------------------------
    # 3) Merge (and diagnose why rows drop)
    # ------------------------------------------------------------------
    # Diagnose how many pred keys exist in GT
    gt_key_set = set(df_gt["key"])
    df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set)
    not_in_gt = df_pred.loc[~df_pred["key_in_gt"]]

    print("\n--- KEY MATCH LOG ---")
    print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}")
    print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}")
    if len(not_in_gt) > 0:
        print("[WARNING] Some prediction keys are not present in GT. First 10:")
        print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10))

    # Many predictions may map to one GT row (many_to_one); if the GT has
    # duplicate keys, validate makes pandas raise instead of silently
    # multiplying rows.
    df_merged = df_pred.merge(
        df_gt[["key", "EDSS_gt"]],
        on="key",
        how="inner",
        validate="many_to_one",
    )

    print("\n--- MERGE LOG ---")
    print(f"Merged rows (inner join): {len(df_merged)}")
    print(f"Merged unique keys: {df_merged['key'].nunique()}")
    print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}")
    print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}")
    print(f"Merged missing confidence: {df_merged['confidence'].isna().sum()}")

    # Drop only on the fields the statistics need (a blanket dropna() would
    # also drop rows for unrelated missing columns).
    rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"])
    print("\n--- FILTER LOG (what will be used for stats/plot) ---")
    print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}")
    if len(rows_complete) == 0:
        print("[ERROR] No complete rows after filtering. Nothing to plot.")
        return

    # Compute abs error
    rows_complete = rows_complete.copy()
    rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs()

    # ------------------------------------------------------------------
    # 4) Binning + stats (with guardrails)
    # ------------------------------------------------------------------
    bin_edges = [0, 70, 80, 90, 100]
    labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"]

    # Left-closed bins so each boundary lands in the bracket its label names.
    # pd.cut's default right-closed bins would put exactly 70 into
    # "Low (<70%)" and exactly 90 into "High (80-90%)".
    # Code -1 marks confidences outside [0, 100]; from_codes maps it to NaN.
    conf_vals = rows_complete["confidence"].to_numpy(dtype=float)
    bin_codes = np.searchsorted(bin_edges[1:-1], conf_vals, side="right")
    in_range = (conf_vals >= bin_edges[0]) & (conf_vals <= bin_edges[-1])
    rows_complete["conf_bin"] = pd.Categorical.from_codes(
        np.where(in_range, bin_codes, -1), categories=labels, ordered=True
    )

    conf_outside = rows_complete["conf_bin"].isna().sum()
    print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}")
    if conf_outside > 0:
        print("Example confidences outside bins:")
        print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list())

    df_plot = rows_complete.dropna(subset=["conf_bin"])
    # reindex(labels) keeps all four bins in order, with NaN stats for empty ones
    stats = (
        df_plot.groupby("conf_bin", observed=True)["abs_error"]
        .agg(mean="mean", std="std", count="count")
        .reindex(labels)
        .reset_index()
    )

    print("\n--- BIN STATS ---")
    print(stats)

    # Warn about low counts
    low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]]
    if not low_bins.empty:
        print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:")
        print(low_bins)

    # ------------------------------------------------------------------
    # 5) Plot
    # ------------------------------------------------------------------
    plt.figure(figsize=(13, 8))
    colors = sns.color_palette("Blues", n_colors=len(labels))

    means = stats["mean"].to_numpy()
    counts = stats["count"].fillna(0).astype(int).to_numpy()
    stds = stats["std"].to_numpy()

    # For bins with no data, bar height 0 (and no errorbar)
    means_plot = np.nan_to_num(means, nan=0.0)

    bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85)

    # Error bars only where count>1 and std is not NaN
    sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan)
    plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5)

    # Trend line only if at least 2 non-empty bins (fit on bin index 0..3)
    valid_idx = np.where(~np.isnan(means))[0]
    if len(valid_idx) >= 2:
        x_idx = np.arange(len(labels))
        z = np.polyfit(valid_idx, means[valid_idx], 1)
        p = np.poly1d(z)
        plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5)
        trend_label = "Trend Line"
    else:
        trend_label = "Trend Line (insufficient bins)"
        print("\n[INFO] Not enough non-empty bins to fit a trend line.")

    # Data labels
    for i, bar in enumerate(bars):
        n_count = int(counts[i])
        mae_val = means[i]
        if np.isnan(mae_val) or n_count == 0:
            txt = "empty"
            y = 0.02
        else:
            txt = f"MAE: {mae_val:.2f}\nn={n_count}"
            y = bar.get_height() + 0.04
        plt.text(
            bar.get_x() + bar.get_width()/2,
            y,
            txt,
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10,
        )

    # Legend: one patch per bin (derived from labels), then the line markers
    legend_elements = [
        Patch(facecolor=color, edgecolor="black", label=f"Bin {i}: {label}")
        for i, (color, label) in enumerate(zip(colors, labels), start=1)
    ]
    legend_elements += [
        Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label),
        Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"),
        Patch(color="none", label="Metric: Mean Absolute Error (MAE)"),
    ]
    plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")

    plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
    plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
    plt.xlabel("LLM Confidence Bracket", fontsize=12)
    plt.grid(axis="y", linestyle=":", alpha=0.5)

    ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0
    plt.ylim(0, max(0.5, float(ymax) + 0.6))
    plt.tight_layout()
    plt.show()

    print("\n" + "="*80)
    print("DONE")
    print("="*80)
|
|
|
|
|
|
# --- RUN ---
# Analyse one iteration's prediction JSON against the GT CSV.
gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv"
json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json"

plot_single_json_error_analysis_with_log(
    json_file_path=json_path,
    ground_truth_path=gt_path,
)
|
|
|
|
|
|
|
|
##
|