refinement
This commit is contained in:
415
Data/audit.py
415
Data/audit.py
@@ -1,415 +0,0 @@
|
||||
# %% Confirm EDSS missing
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
def clean_series(s):
    """Normalize a pandas Series for key matching: stringify, trim, lower-case."""
    as_text = s.astype(str)
    trimmed = as_text.str.strip()
    return trimmed.str.lower()
|
||||
|
||||
def gt_edss_audit(ground_truth_path, edss_col="EDSS"):
    """Audit the ground-truth CSV: key uniqueness and EDSS missingness.

    Reads the semicolon-separated GT file, normalizes ``unique_id`` and
    ``MedDatum``, builds the merge key, parses EDSS robustly (German decimal
    commas), and prints duplicate-key / missing-EDSS diagnostics.

    Parameters
    ----------
    ground_truth_path : str
        Path to the semicolon-separated GT CSV.
    edss_col : str
        Name of the EDSS column in the GT file.

    Returns
    -------
    pandas.DataFrame
        The GT frame with an added ``key`` column and, when the EDSS column
        exists, a numeric-coerced ``_edss_num`` column (NaN where unparsable).

    Raises
    ------
    ValueError
        If either key column (``unique_id``, ``MedDatum``) is missing.
    """
    df_gt = pd.read_csv(ground_truth_path, sep=';')

    # Fail fast on missing key columns (consistent with
    # plot_single_json_error_analysis_with_log, which validates its GT
    # columns the same way instead of dying with a bare KeyError below).
    missing_cols = {'unique_id', 'MedDatum'} - set(df_gt.columns)
    if missing_cols:
        raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}")

    # normalize keys
    df_gt['unique_id'] = clean_series(df_gt['unique_id'])
    df_gt['MedDatum'] = clean_series(df_gt['MedDatum'])
    df_gt['key'] = df_gt['unique_id'] + "_" + df_gt['MedDatum']

    print("GT rows:", len(df_gt))
    print("GT unique keys:", df_gt['key'].nunique())

    # IMPORTANT: parse EDSS robustly (German decimal commas etc.)
    if edss_col in df_gt.columns:
        edss_raw = df_gt[edss_col]
        edss_num = pd.to_numeric(
            edss_raw.astype(str).str.replace(",", ".", regex=False).str.strip(),
            errors="coerce"
        )
        df_gt["_edss_num"] = edss_num

        print(f"GT missing EDSS look (numeric-coerce): {df_gt['_edss_num'].isna().sum()}")
        print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['_edss_num'].isna(), 'key'].nunique()}")

        # duplicates on key
        dup = df_gt['key'].duplicated(keep=False)
        print("GT duplicate-key rows:", dup.sum())
        if dup.any():
            # how many duplicate keys exist?
            print("GT duplicate keys:", df_gt.loc[dup, 'key'].nunique())
            # of duplicate-key rows, how many have missing EDSS?
            print("Duplicate-key rows with missing EDSS:", df_gt.loc[dup, "_edss_num"].isna().sum())

            # show the worst offenders
            print("\nTop duplicate keys (by count):")
            print(df_gt.loc[dup, 'key'].value_counts().head(10))
    else:
        print(f"EDSS column '{edss_col}' not found in GT columns:", df_gt.columns.tolist())

    return df_gt
|
||||
|
||||
# Run the GT audit once; the returned df_gt (with 'key' and '_edss_num'
# columns) is reused by the later "%% trace missing ones" and "%% verify" cells.
df_gt = gt_edss_audit("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", edss_col="EDSS")
|
||||
|
||||
##
|
||||
|
||||
|
||||
|
||||
|
||||
# %% trace missing ones
|
||||
|
||||
import json, glob, os
|
||||
import pandas as pd
|
||||
|
||||
def load_preds(json_dir_path):
    """Load successful LLM predictions from every ``*.json`` file in a directory.

    Each JSON file is expected to hold a list of entries of the form
    ``{"success": bool, "result": {"unique_id": ..., "MedDatum": ...}}``.
    Only entries flagged ``success`` are kept; ``unique_id`` and ``MedDatum``
    are stripped and lower-cased to match the GT normalization.

    Parameters
    ----------
    json_dir_path : str
        Directory containing the prediction JSON files.

    Returns
    -------
    pandas.DataFrame
        Columns ``unique_id``, ``MedDatum``, ``file``, and
        ``key`` (``"<unique_id>_<MedDatum>"``). Empty — but with the same
        columns — when no usable predictions are found.
    """
    all_preds = []
    for file_path in glob.glob(os.path.join(json_dir_path, "*.json")):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        file_name = os.path.basename(file_path)
        for entry in data:
            if not entry.get("success"):
                continue
            # .get with a default guards against success entries that are
            # missing the "result" payload (entry["result"] would KeyError).
            res = entry.get("result", {})
            all_preds.append({
                "unique_id": str(res.get("unique_id")).strip().lower(),
                "MedDatum": str(res.get("MedDatum")).strip().lower(),
                "file": file_name
            })
    # Explicit columns so the "key" assignment below cannot KeyError when no
    # predictions were found (an empty DataFrame would have no columns at all).
    df_pred = pd.DataFrame(all_preds, columns=["unique_id", "MedDatum", "file"])
    df_pred["key"] = df_pred["unique_id"] + "_" + df_pred["MedDatum"]
    return df_pred
|
||||
|
||||
# Load all per-iteration prediction JSONs and trace which prediction rows map
# to GT keys whose EDSS value is missing.
df_pred = load_preds("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration")
print("Pred rows:", len(df_pred))
print("Pred unique keys:", df_pred["key"].nunique())

# Suppose df_gt was returned from step 1 and has _edss_num + key
missing_gt_keys = set(df_gt.loc[df_gt["_edss_num"].isna(), "key"])

# Flag each prediction row whose GT counterpart has no usable EDSS value.
df_pred["gt_key_missing_edss"] = df_pred["key"].isin(missing_gt_keys)

print("Pred rows whose GT key has missing EDSS:", df_pred["gt_key_missing_edss"].sum())
print("Unique keys (among preds) whose GT EDSS missing:", df_pred.loc[df_pred["gt_key_missing_edss"], "key"].nunique())

# Which source JSON files contribute most of the missing-EDSS rows?
print("\nTop files contributing to missing-GT-EDSS rows:")
print(df_pred.loc[df_pred["gt_key_missing_edss"], "file"].value_counts().head(20))

# The same key predicted in multiple files inflates the row count — show
# the most replicated keys to explain the inflation.
print("\nTop keys replicated in predictions (why count inflates):")
print(df_pred.loc[df_pred["gt_key_missing_edss"], "key"].value_counts().head(20))
|
||||
|
||||
|
||||
##
|
||||
|
||||
|
||||
# %% verify
|
||||
|
||||
# Sanity-check the pred->GT join: a left merge keyed on 'key'. The
# many_to_one validation makes pandas raise MergeError if GT still contains
# duplicate keys, so silent row duplication cannot go unnoticed.
merged = df_pred.merge(
    df_gt[["key", "_edss_num"]],  # use the numeric-coerced GT EDSS
    on="key",
    how="left",
    validate="many_to_one"  # will ERROR if GT has duplicate keys (GOOD!)
)

print("Merged rows:", len(merged))
print("Merged missing GT EDSS:", merged["_edss_num"].isna().sum())
|
||||
|
||||
|
||||
##
|
||||
|
||||
|
||||
# %% 1json (rewritten with robust parsing + detailed data log)
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from matplotlib.patches import Patch
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
def plot_single_json_error_analysis_with_log(
    json_file_path,
    ground_truth_path,
    edss_gt_col="EDSS",
    min_bin_count=5,
):
    """Analyze one prediction JSON against the GT CSV and plot MAE per confidence bin.

    Loads the GT file (semicolon-separated), loads success entries from the
    prediction JSON, joins them on a normalized ``unique_id + "_" + MedDatum``
    key, bins rows by ``certainty_percent`` and bar-plots the mean absolute
    EDSS error per bin with SEM error bars and an optional linear trend line.
    Prints detailed diagnostics (duplicate keys, missing values, skipped
    entries, merge losses) at every stage.

    Parameters
    ----------
    json_file_path : str
        Path to one prediction JSON (list of {"success", "result", ...} entries).
    ground_truth_path : str
        Path to the semicolon-separated GT CSV with unique_id/MedDatum/EDSS.
    edss_gt_col : str
        Name of the EDSS column in the GT file.
    min_bin_count : int
        Bins with fewer rows than this trigger a stability warning.

    Returns
    -------
    None. Returns early (after printing an error) when no usable prediction
    rows or no complete rows survive filtering.

    Raises
    ------
    ValueError
        If the GT file lacks any of the required columns.
    pandas.errors.MergeError
        Via validate="many_to_one" if GT contains duplicate keys.
    """
    def norm_str(x):
        # normalize identifiers and dates consistently
        return str(x).strip().lower()

    def parse_edss(x):
        # robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc.
        if x is None:
            return np.nan
        s = str(x).strip()
        if s == "" or s.lower() in {"nan", "none", "null"}:
            return np.nan
        s = s.replace(",", ".")
        return pd.to_numeric(s, errors="coerce")

    print("\n" + "="*80)
    print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)")
    print("="*80)
    print(f"JSON: {json_file_path}")
    print(f"GT: {ground_truth_path}")

    # ------------------------------------------------------------------
    # 1) Load Ground Truth
    # ------------------------------------------------------------------
    df_gt = pd.read_csv(ground_truth_path, sep=";")

    required_gt_cols = {"unique_id", "MedDatum", edss_gt_col}
    missing_cols = required_gt_cols - set(df_gt.columns)
    if missing_cols:
        raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}")

    df_gt["unique_id"] = df_gt["unique_id"].map(norm_str)
    df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str)
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]

    # Robust EDSS parsing (important!)
    df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss)

    # GT logs
    print("\n--- GT LOG ---")
    print(f"GT rows: {len(df_gt)}")
    print(f"GT unique keys: {df_gt['key'].nunique()}")
    gt_dup = df_gt["key"].duplicated(keep=False).sum()
    print(f"GT duplicate-key rows: {gt_dup}")
    print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}")

    if gt_dup > 0:
        print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:")
        print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10))

    # ------------------------------------------------------------------
    # 2) Load Predictions from the specific JSON
    # ------------------------------------------------------------------
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_entries = len(data)
    success_entries = sum(1 for e in data if e.get("success"))

    all_preds = []
    # NOTE: "missing_edss" / "missing_conf" count such entries but do NOT
    # skip them — those rows are still appended (and only dropped later by
    # the explicit dropna filter before plotting).
    skipped = {
        "not_success": 0,
        "missing_uid_or_date": 0,
        "missing_edss": 0,
        "missing_conf": 0,
    }

    for entry in data:
        if not entry.get("success"):
            skipped["not_success"] += 1
            continue

        res = entry.get("result", {})
        uid = res.get("unique_id")
        md = res.get("MedDatum")

        # Rows without both key parts can never match GT — skip them.
        if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "":
            skipped["missing_uid_or_date"] += 1
            continue

        edss_pred = parse_edss(res.get("EDSS"))
        conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce")

        if pd.isna(edss_pred):
            skipped["missing_edss"] += 1
        if pd.isna(conf):
            skipped["missing_conf"] += 1

        all_preds.append({
            "unique_id": norm_str(uid),
            "MedDatum": norm_str(md),
            "key": norm_str(uid) + "_" + norm_str(md),
            "EDSS_pred": edss_pred,
            "confidence": conf,
        })

    df_pred = pd.DataFrame(all_preds)

    # Pred logs
    print("\n--- PRED LOG ---")
    print(f"JSON total entries: {total_entries}")
    print(f"JSON success entries: {success_entries}")
    print(f"Pred rows loaded (success + has keys): {len(df_pred)}")
    if len(df_pred) == 0:
        print("[ERROR] No usable prediction rows found. Nothing to plot.")
        return

    print(f"Pred unique keys: {df_pred['key'].nunique()}")
    print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}")
    print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}")
    print("Skipped counts:", skipped)

    # Are keys duplicated within this JSON? (often yes if multiple notes map to same key)
    key_counts = df_pred["key"].value_counts()
    dup_pred_rows = (key_counts > 1).sum()
    max_rep = int(key_counts.max())
    print(f"Keys with >1 prediction in this JSON: {dup_pred_rows}")
    print(f"Max repetitions of a single key in this JSON: {max_rep}")
    if max_rep > 1:
        print("Top repeated keys in this JSON:")
        print(key_counts.head(10))

    # ------------------------------------------------------------------
    # 3) Merge (and diagnose why rows drop)
    # ------------------------------------------------------------------
    # Diagnose how many pred keys exist in GT
    gt_key_set = set(df_gt["key"])
    df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set)
    not_in_gt = df_pred.loc[~df_pred["key_in_gt"]]

    print("\n--- KEY MATCH LOG ---")
    print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}")
    print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}")
    if len(not_in_gt) > 0:
        print("[WARNING] Some prediction keys are not present in GT. First 10:")
        print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10))

    # Now merge; we expect GT is one-to-many with pred (many_to_one)
    # If GT had duplicates, validate would raise.
    df_merged = df_pred.merge(
        df_gt[["key", "EDSS_gt"]],
        on="key",
        how="inner",
        validate="many_to_one"
    )

    print("\n--- MERGE LOG ---")
    print(f"Merged rows (inner join): {len(df_merged)}")
    print(f"Merged unique keys: {df_merged['key'].nunique()}")
    print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}")
    print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}")
    print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}")

    # How many rows will be removed by dropna() in your old code?
    # Old code did .dropna() on ALL columns, which can remove rows for missing confidence too.
    rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"])
    print("\n--- FILTER LOG (what will be used for stats/plot) ---")
    print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}")
    if len(rows_complete) == 0:
        print("[ERROR] No complete rows after filtering. Nothing to plot.")
        return

    # Compute abs error (copy avoids SettingWithCopy warnings on the slice)
    rows_complete = rows_complete.copy()
    rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs()

    # ------------------------------------------------------------------
    # 4) Binning + stats (with guardrails)
    # ------------------------------------------------------------------
    bins = [0, 70, 80, 90, 100]
    labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"]

    # Confidence outside bins becomes NaN; log it
    rows_complete["conf_bin"] = pd.cut(rows_complete["confidence"], bins=bins, labels=labels, include_lowest=True)
    conf_outside = rows_complete["conf_bin"].isna().sum()
    print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}")
    if conf_outside > 0:
        print("Example confidences outside bins:")
        print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list())

    df_plot = rows_complete.dropna(subset=["conf_bin"])
    # reindex(labels) keeps empty bins present (as NaN rows) so the plot
    # always shows all four brackets in order.
    stats = (
        df_plot.groupby("conf_bin", observed=True)["abs_error"]
        .agg(mean="mean", std="std", count="count")
        .reindex(labels)
        .reset_index()
    )

    print("\n--- BIN STATS ---")
    print(stats)

    # Warn about low counts
    low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]]
    if not low_bins.empty:
        print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:")
        print(low_bins)

    # ------------------------------------------------------------------
    # 5) Plot
    # ------------------------------------------------------------------
    plt.figure(figsize=(13, 8))
    colors = sns.color_palette("Blues", n_colors=len(labels))

    # Replace NaNs in mean for plotting bars (empty bins)
    means = stats["mean"].to_numpy()
    counts = stats["count"].fillna(0).astype(int).to_numpy()
    stds = stats["std"].to_numpy()

    # For bins with no data, bar height 0 (and no errorbar)
    means_plot = np.nan_to_num(means, nan=0.0)

    bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85)

    # Error bars only where count>1 and std is not NaN
    sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan)
    plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5)

    # Trend line only if at least 2 non-empty bins
    valid_idx = np.where(~np.isnan(means))[0]
    if len(valid_idx) >= 2:
        x_idx = np.arange(len(labels))
        z = np.polyfit(valid_idx, means[valid_idx], 1)
        p = np.poly1d(z)
        plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5)
        trend_label = "Trend Line"
    else:
        trend_label = "Trend Line (insufficient bins)"
        print("\n[INFO] Not enough non-empty bins to fit a trend line.")

    # Data labels
    for i, bar in enumerate(bars):
        n_count = int(counts[i])
        mae_val = means[i]
        if np.isnan(mae_val) or n_count == 0:
            txt = "empty"
            y = 0.02
        else:
            txt = f"MAE: {mae_val:.2f}\nn={n_count}"
            y = bar.get_height() + 0.04
        plt.text(
            bar.get_x() + bar.get_width()/2,
            y,
            txt,
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10
        )

    # Legend
    legend_elements = [
        Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"),
        Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"),
        Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"),
        Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"),
        Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label),
        Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"),
        Patch(color="none", label="Metric: Mean Absolute Error (MAE)")
    ]
    plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")

    plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
    plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
    plt.xlabel("LLM Confidence Bracket", fontsize=12)
    plt.grid(axis="y", linestyle=":", alpha=0.5)

    ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0
    plt.ylim(0, max(0.5, float(ymax) + 0.6))
    plt.tight_layout()
    plt.show()

    print("\n" + "="*80)
    print("DONE")
    print("="*80)
|
||||
|
||||
|
||||
# --- RUN ---
# Single iteration-1 prediction file analyzed against the shared GT CSV.
json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json"
gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv"

plot_single_json_error_analysis_with_log(json_path, gt_path)
|
||||
|
||||
|
||||
|
||||
##
|
||||
60
figure1.py
60
figure1.py
@@ -320,3 +320,63 @@ plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
##
|
||||
|
||||
# %% Patientjourney Bubble chart
# One-row bubble chart of the patient journey: one bubble per visit number,
# bubble area proportional to how many patients had exactly that many visits,
# plus labels for the count and the right-cumulative "remaining" patients.
import matplotlib.pyplot as plt
import numpy as np

import matplotlib as mpl

mpl.rcParams["font.family"] = "DejaVu Sans"  # or "Arial", "Calibri", "Times New Roman", ...
mpl.rcParams["font.size"] = 12               # default size for text
mpl.rcParams["axes.titlesize"] = 14
mpl.rcParams["axes.titleweight"] = "bold"


# Data (your counts)
visits = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
patient_count = np.array([32, 24, 28, 17, 13, 6, 3, 3, 2])

# "Remaining" = patients with >= that many visits (cumulative from the right)
remaining = np.array([patient_count[i:].sum() for i in range(len(patient_count))])

# --- Plot ---
fig, ax = plt.subplots(figsize=(12, 3))

y = 0.0  # all bubbles on one horizontal line

# Horizontal line
ax.hlines(y, visits.min() - 0.4, visits.max() + 0.4, color="#1f77b4", linewidth=3)

# Bubble sizes (scale as needed)
# (Matplotlib scatter uses area in points^2)
sizes = patient_count * 35  # tweak this multiplier if you want bigger/smaller bubbles

# Force float dtype: np.full_like(visits, y) would inherit visits' int dtype
# and silently truncate any non-integer baseline y.
ax.scatter(visits, np.full(visits.shape, y, dtype=float), s=sizes, color="#1f77b4", zorder=3)

# Title
#ax.set_title("Patient Journey by Visit Count", fontsize=14, pad=18)

# Top labels: "1 visits", "2 visits", ...
for x in visits:
    label = f"{x} visit" if x == 1 else f"{x} visits"
    ax.text(x, y + 0.18, label, ha="center", va="bottom", fontsize=10)

# Bottom labels: "X patients" and "Y remaining"
for x, pc, rem in zip(visits, patient_count, remaining):
    ax.text(x, y - 0.20, f"{pc} patients", ha="center", va="top", fontsize=9)
    ax.text(x, y - 0.32, f"{rem} remaining", ha="center", va="top", fontsize=9)

# Cosmetics: remove axes, keep spacing nice
ax.set_xlim(visits.min() - 0.6, visits.max() + 0.6)
ax.set_ylim(-0.5, 0.35)
ax.set_xticks([])
ax.set_yticks([])
for spine in ax.spines.values():
    spine.set_visible(False)

plt.tight_layout()
# Save BEFORE show: plt.show() closes the figure on non-interactive backends,
# so saving afterwards (as before) produced an empty SVG.
fig.savefig("patient_journey.svg", format="svg", bbox_inches="tight")
plt.show()
##
|
||||
##
|
||||
|
||||
|
||||
Reference in New Issue
Block a user