refinement

This commit is contained in:
2026-02-23 15:06:54 +01:00
parent 99862629b8
commit 118e3e63b3
5 changed files with 2431 additions and 415 deletions

View File

@@ -1,415 +0,0 @@
# %% Confirm EDSS missing
import pandas as pd
import numpy as np
def clean_series(s):
    """Normalize a pandas Series for key matching: cast to str, trim, lowercase."""
    as_text = s.astype(str)
    return as_text.str.strip().str.lower()
def gt_edss_audit(ground_truth_path, edss_col="EDSS"):
    """Load the ground-truth CSV and print key/EDSS completeness diagnostics.

    Normalizes `unique_id` and `MedDatum`, builds the composite `key`,
    coerces the EDSS column to numeric (handling German decimal commas),
    and reports missing-value and duplicate-key statistics.

    Returns the augmented DataFrame (adds `key` and, when the EDSS column
    exists, `_edss_num`).
    """
    df_gt = pd.read_csv(ground_truth_path, sep=';')
    # Normalize the join keys exactly like the prediction loader does.
    for key_col in ('unique_id', 'MedDatum'):
        df_gt[key_col] = df_gt[key_col].astype(str).str.strip().str.lower()
    df_gt['key'] = df_gt['unique_id'] + "_" + df_gt['MedDatum']
    print("GT rows:", len(df_gt))
    print("GT unique keys:", df_gt['key'].nunique())
    # Guard clause: nothing else to audit without the EDSS column.
    if edss_col not in df_gt.columns:
        print(f"EDSS column '{edss_col}' not found in GT columns:", df_gt.columns.tolist())
        return df_gt
    # IMPORTANT: parse EDSS robustly (German decimal commas etc.)
    edss_text = df_gt[edss_col].astype(str).str.replace(",", ".", regex=False).str.strip()
    df_gt["_edss_num"] = pd.to_numeric(edss_text, errors="coerce")
    print(f"GT missing EDSS look (numeric-coerce): {df_gt['_edss_num'].isna().sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['_edss_num'].isna(), 'key'].nunique()}")
    # Duplicate composite keys would make downstream merges ambiguous.
    dup = df_gt['key'].duplicated(keep=False)
    print("GT duplicate-key rows:", dup.sum())
    if dup.any():
        print("GT duplicate keys:", df_gt.loc[dup, 'key'].nunique())
        print("Duplicate-key rows with missing EDSS:", df_gt.loc[dup, "_edss_num"].isna().sum())
        print("\nTop duplicate keys (by count):")
        print(df_gt.loc[dup, 'key'].value_counts().head(10))
    return df_gt
# Run the GT audit against the local export (path is machine-specific).
df_gt = gt_edss_audit("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", edss_col="EDSS")
##
# %% trace missing ones
import json, glob, os
import pandas as pd
def load_preds(json_dir_path):
    """Collect normalized (unique_id, MedDatum) keys from all result JSONs in a directory.

    Each *.json file is expected to hold a list of entries; only entries whose
    "success" flag is truthy contribute a row.  Identifiers are stripped and
    lowercased, matching the GT normalization, and combined into `key`.

    Returns a DataFrame with columns: unique_id, MedDatum, file, key
    (empty, but with the full schema, when nothing was loaded).
    """
    all_preds = []
    for file_path in glob.glob(os.path.join(json_dir_path, "*.json")):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        file_name = os.path.basename(file_path)
        for entry in data:
            if entry.get("success"):
                res = entry["result"]
                all_preds.append({
                    "unique_id": str(res.get("unique_id")).strip().lower(),
                    "MedDatum": str(res.get("MedDatum")).strip().lower(),
                    "file": file_name
                })
    # Fix: pin the columns explicitly.  With an empty directory (or zero
    # successful entries) pd.DataFrame([]) has NO columns, and the
    # df_pred["unique_id"] lookup below raised KeyError.
    df_pred = pd.DataFrame(all_preds, columns=["unique_id", "MedDatum", "file"])
    df_pred["key"] = df_pred["unique_id"] + "_" + df_pred["MedDatum"]
    return df_pred
# Trace which prediction rows point at GT rows whose EDSS is missing.
df_pred = load_preds("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration")
print("Pred rows:", len(df_pred))
print("Pred unique keys:", df_pred["key"].nunique())
# Suppose df_gt was returned from step 1 and has _edss_num + key
missing_gt_keys = set(df_gt.loc[df_gt["_edss_num"].isna(), "key"])
# Flag every prediction whose composite key maps to a GT row lacking EDSS.
df_pred["gt_key_missing_edss"] = df_pred["key"].isin(missing_gt_keys)
print("Pred rows whose GT key has missing EDSS:", df_pred["gt_key_missing_edss"].sum())
print("Unique keys (among preds) whose GT EDSS missing:", df_pred.loc[df_pred["gt_key_missing_edss"], "key"].nunique())
print("\nTop files contributing to missing-GT-EDSS rows:")
print(df_pred.loc[df_pred["gt_key_missing_edss"], "file"].value_counts().head(20))
# Repeated keys across files inflate the raw row count versus unique keys.
print("\nTop keys replicated in predictions (why count inflates):")
print(df_pred.loc[df_pred["gt_key_missing_edss"], "key"].value_counts().head(20))
##
# %% verify
# Sanity-check the join: left-merge preds onto GT EDSS, enforcing GT key uniqueness.
merged = df_pred.merge(
    df_gt[["key", "_edss_num"]],  # use the numeric-coerced GT EDSS
    on="key",
    how="left",
    validate="many_to_one"  # will ERROR if GT has duplicate keys (GOOD!)
)
print("Merged rows:", len(merged))
print("Merged missing GT EDSS:", merged["_edss_num"].isna().sum())
##
# %% 1json (rewritten with robust parsing + detailed data log)
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
def plot_single_json_error_analysis_with_log(
    json_file_path,
    ground_truth_path,
    edss_gt_col="EDSS",
    min_bin_count=5,
):
    """Compare one predictions JSON against the ground-truth CSV and plot
    mean absolute EDSS error per LLM-confidence bracket, logging every step.

    Parameters
    ----------
    json_file_path : str
        Path to a single results JSON: a list of entries, each with a
        "success" flag and a "result" dict (unique_id, MedDatum, EDSS,
        certainty_percent).
    ground_truth_path : str
        Semicolon-separated GT CSV; must contain unique_id, MedDatum and
        the column named by `edss_gt_col`.
    edss_gt_col : str
        Name of the EDSS column in the GT file.
    min_bin_count : int
        Confidence bins with fewer rows than this trigger a stability warning.

    Returns None (returns early when no usable rows remain); side effects
    are console logs and a matplotlib figure shown via plt.show().
    """
    def norm_str(x):
        # normalize identifiers and dates consistently
        return str(x).strip().lower()
    def parse_edss(x):
        # robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc.
        if x is None:
            return np.nan
        s = str(x).strip()
        if s == "" or s.lower() in {"nan", "none", "null"}:
            return np.nan
        s = s.replace(",", ".")
        return pd.to_numeric(s, errors="coerce")
    print("\n" + "="*80)
    print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)")
    print("="*80)
    print(f"JSON: {json_file_path}")
    print(f"GT: {ground_truth_path}")
    # ------------------------------------------------------------------
    # 1) Load Ground Truth
    # ------------------------------------------------------------------
    df_gt = pd.read_csv(ground_truth_path, sep=";")
    required_gt_cols = {"unique_id", "MedDatum", edss_gt_col}
    missing_cols = required_gt_cols - set(df_gt.columns)
    if missing_cols:
        raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}")
    df_gt["unique_id"] = df_gt["unique_id"].map(norm_str)
    df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str)
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]
    # Robust EDSS parsing (important!)
    df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss)
    # GT logs
    print("\n--- GT LOG ---")
    print(f"GT rows: {len(df_gt)}")
    print(f"GT unique keys: {df_gt['key'].nunique()}")
    gt_dup = df_gt["key"].duplicated(keep=False).sum()
    print(f"GT duplicate-key rows: {gt_dup}")
    print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}")
    if gt_dup > 0:
        # NOTE(review): the merge in step 3 uses validate="many_to_one", so if
        # duplicate GT keys survive to that point it raises rather than
        # duplicating rows — the warning text below overstates the risk.
        print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:")
        print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10))
    # ------------------------------------------------------------------
    # 2) Load Predictions from the specific JSON
    # ------------------------------------------------------------------
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    total_entries = len(data)
    success_entries = sum(1 for e in data if e.get("success"))
    all_preds = []
    skipped = {
        "not_success": 0,
        "missing_uid_or_date": 0,
        "missing_edss": 0,
        "missing_conf": 0,
    }
    for entry in data:
        if not entry.get("success"):
            skipped["not_success"] += 1
            continue
        res = entry.get("result", {})
        uid = res.get("unique_id")
        md = res.get("MedDatum")
        if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "":
            skipped["missing_uid_or_date"] += 1
            continue
        edss_pred = parse_edss(res.get("EDSS"))
        conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce")
        # Rows with missing EDSS/confidence are counted here but still
        # appended; they are dropped later by the completeness filter.
        if pd.isna(edss_pred):
            skipped["missing_edss"] += 1
        if pd.isna(conf):
            skipped["missing_conf"] += 1
        all_preds.append({
            "unique_id": norm_str(uid),
            "MedDatum": norm_str(md),
            "key": norm_str(uid) + "_" + norm_str(md),
            "EDSS_pred": edss_pred,
            "confidence": conf,
        })
    df_pred = pd.DataFrame(all_preds)
    # Pred logs
    print("\n--- PRED LOG ---")
    print(f"JSON total entries: {total_entries}")
    print(f"JSON success entries: {success_entries}")
    print(f"Pred rows loaded (success + has keys): {len(df_pred)}")
    if len(df_pred) == 0:
        print("[ERROR] No usable prediction rows found. Nothing to plot.")
        return
    print(f"Pred unique keys: {df_pred['key'].nunique()}")
    print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}")
    print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}")
    print("Skipped counts:", skipped)
    # Are keys duplicated within this JSON? (often yes if multiple notes map to same key)
    key_counts = df_pred["key"].value_counts()
    dup_pred_rows = (key_counts > 1).sum()
    max_rep = int(key_counts.max())
    print(f"Keys with >1 prediction in this JSON: {dup_pred_rows}")
    print(f"Max repetitions of a single key in this JSON: {max_rep}")
    if max_rep > 1:
        print("Top repeated keys in this JSON:")
        print(key_counts.head(10))
    # ------------------------------------------------------------------
    # 3) Merge (and diagnose why rows drop)
    # ------------------------------------------------------------------
    # Diagnose how many pred keys exist in GT
    gt_key_set = set(df_gt["key"])
    df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set)
    not_in_gt = df_pred.loc[~df_pred["key_in_gt"]]
    print("\n--- KEY MATCH LOG ---")
    print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}")
    print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}")
    if len(not_in_gt) > 0:
        print("[WARNING] Some prediction keys are not present in GT. First 10:")
        print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10))
    # Now merge; we expect GT is one-to-many with pred (many_to_one)
    # If GT had duplicates, validate would raise.
    df_merged = df_pred.merge(
        df_gt[["key", "EDSS_gt"]],
        on="key",
        how="inner",
        validate="many_to_one"
    )
    print("\n--- MERGE LOG ---")
    print(f"Merged rows (inner join): {len(df_merged)}")
    print(f"Merged unique keys: {df_merged['key'].nunique()}")
    print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}")
    print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}")
    print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}")
    # How many rows will be removed by dropna() in your old code?
    # Old code did .dropna() on ALL columns, which can remove rows for missing confidence too.
    rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"])
    print("\n--- FILTER LOG (what will be used for stats/plot) ---")
    print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}")
    if len(rows_complete) == 0:
        print("[ERROR] No complete rows after filtering. Nothing to plot.")
        return
    # Compute abs error
    rows_complete = rows_complete.copy()
    rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs()
    # ------------------------------------------------------------------
    # 4) Binning + stats (with guardrails)
    # ------------------------------------------------------------------
    bins = [0, 70, 80, 90, 100]
    labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"]
    # Confidence outside bins becomes NaN; log it
    rows_complete["conf_bin"] = pd.cut(rows_complete["confidence"], bins=bins, labels=labels, include_lowest=True)
    conf_outside = rows_complete["conf_bin"].isna().sum()
    print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}")
    if conf_outside > 0:
        print("Example confidences outside bins:")
        print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list())
    df_plot = rows_complete.dropna(subset=["conf_bin"])
    stats = (
        df_plot.groupby("conf_bin", observed=True)["abs_error"]
        .agg(mean="mean", std="std", count="count")
        .reindex(labels)
        .reset_index()
    )
    print("\n--- BIN STATS ---")
    print(stats)
    # Warn about low counts
    low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]]
    if not low_bins.empty:
        print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:")
        print(low_bins)
    # ------------------------------------------------------------------
    # 5) Plot
    # ------------------------------------------------------------------
    plt.figure(figsize=(13, 8))
    colors = sns.color_palette("Blues", n_colors=len(labels))
    # Replace NaNs in mean for plotting bars (empty bins)
    means = stats["mean"].to_numpy()
    counts = stats["count"].fillna(0).astype(int).to_numpy()
    stds = stats["std"].to_numpy()
    # For bins with no data, bar height 0 (and no errorbar)
    means_plot = np.nan_to_num(means, nan=0.0)
    bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85)
    # Error bars only where count>1 and std is not NaN
    sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan)
    plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5)
    # Trend line only if at least 2 non-empty bins
    valid_idx = np.where(~np.isnan(means))[0]
    if len(valid_idx) >= 2:
        x_idx = np.arange(len(labels))
        z = np.polyfit(valid_idx, means[valid_idx], 1)
        p = np.poly1d(z)
        plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5)
        trend_label = "Trend Line"
    else:
        trend_label = "Trend Line (insufficient bins)"
        print("\n[INFO] Not enough non-empty bins to fit a trend line.")
    # Data labels
    for i, bar in enumerate(bars):
        n_count = int(counts[i])
        mae_val = means[i]
        if np.isnan(mae_val) or n_count == 0:
            txt = "empty"
            y = 0.02
        else:
            txt = f"MAE: {mae_val:.2f}\nn={n_count}"
            y = bar.get_height() + 0.04
        plt.text(
            bar.get_x() + bar.get_width()/2,
            y,
            txt,
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10
        )
    # Legend
    legend_elements = [
        Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"),
        Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"),
        Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"),
        Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"),
        Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label),
        Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"),
        Patch(color="none", label="Metric: Mean Absolute Error (MAE)")
    ]
    plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")
    plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
    plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
    plt.xlabel("LLM Confidence Bracket", fontsize=12)
    plt.grid(axis="y", linestyle=":", alpha=0.5)
    ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0
    plt.ylim(0, max(0.5, float(ymax) + 0.6))
    plt.tight_layout()
    plt.show()
    print("\n" + "="*80)
    print("DONE")
    print("="*80)
# --- RUN ---
# Entry point for this cell: analyze one iteration's JSON against the GT CSV.
json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json"
gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv"
plot_single_json_error_analysis_with_log(json_path, gt_path)
##

2371
audit.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -320,3 +320,63 @@ plt.tight_layout()
plt.show()
##
# %% Patientjourney Bubble chart
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
mpl.rcParams["font.family"] = "DejaVu Sans"  # or "Arial", "Calibri", "Times New Roman", ...
mpl.rcParams["font.size"] = 12  # default size for text
mpl.rcParams["axes.titlesize"] = 14
mpl.rcParams["axes.titleweight"] = "bold"
# Data (your counts): patient_count[i] = patients with exactly visits[i] visits
visits = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
patient_count = np.array([32, 24, 28, 17, 13, 6, 3, 3, 2])
# "Remaining" = patients with >= that many visits (cumulative sum from the right)
remaining = patient_count[::-1].cumsum()[::-1]
# --- Plot ---
fig, ax = plt.subplots(figsize=(12, 3))
y = 0.0  # all bubbles on one horizontal line
# Horizontal timeline
ax.hlines(y, visits.min() - 0.4, visits.max() + 0.4, color="#1f77b4", linewidth=3)
# Bubble sizes (Matplotlib scatter uses area in points^2)
sizes = patient_count * 35  # tweak this multiplier if you want bigger/smaller bubbles
# Fix: np.full keeps float dtype; np.full_like(visits, y) inherited the int
# dtype of `visits` and would silently truncate a non-integer y.
ax.scatter(visits, np.full(visits.shape, y), s=sizes, color="#1f77b4", zorder=3)
# Title
#ax.set_title("Patient Journey by Visit Count", fontsize=14, pad=18)
# Top labels: "1 visit", "2 visits", ...
for x in visits:
    label = f"{x} visit" if x == 1 else f"{x} visits"
    ax.text(x, y + 0.18, label, ha="center", va="bottom", fontsize=10)
# Bottom labels: patients at exactly this count, and patients still in the cohort
for x, pc, rem in zip(visits, patient_count, remaining):
    ax.text(x, y - 0.20, f"{pc} patients", ha="center", va="top", fontsize=9)
    ax.text(x, y - 0.32, f"{rem} remaining", ha="center", va="top", fontsize=9)
# Cosmetics: remove axes, keep spacing nice
ax.set_xlim(visits.min() - 0.6, visits.max() + 0.6)
ax.set_ylim(-0.5, 0.35)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)
plt.tight_layout()
# Fix: save BEFORE show -- plt.show() releases the figure in non-interactive
# runs, so saving afterwards wrote an empty SVG.  Save via the figure handle
# so the right figure is exported regardless of the current-figure state.
fig.savefig("patient_journey.svg", format="svg", bbox_inches="tight")
plt.show()
##