Modifications

This commit is contained in:
2026-04-27 11:52:53 +02:00
parent 816c50e467
commit 90d411f086
3 changed files with 446 additions and 98 deletions
+312 -1
View File
@@ -389,7 +389,7 @@ def plot_single_json_error_analysis_with_log(
]
plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")
plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
# plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
plt.xlabel("LLM Confidence Bracket", fontsize=12)
plt.grid(axis="y", linestyle=":", alpha=0.5)
@@ -414,6 +414,317 @@ plot_single_json_error_analysis_with_log(json_path, gt_path)
##
# %% 1json (rewritten with robust parsing + detailed data log + Pearson r in plot)
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from scipy.stats import pearsonr
def plot_single_json_error_analysis_with_log(
    json_file_path,
    ground_truth_path,
    edss_gt_col="EDSS",
    min_bin_count=5,
):
    """Plot mean absolute EDSS error per confidence bracket for one result JSON.

    The function loads a ground-truth CSV and one prediction JSON, joins them
    on a normalized ``"<unique_id>_<MedDatum>"`` key, prints a detailed log of
    every stage (row counts, duplicates, missing values, unmatched keys) and
    finally shows a bar chart of the mean absolute error per confidence bin
    with SEM error bars, a linear trend line and the row-level Pearson
    correlation between confidence and absolute error.

    Args:
        json_file_path: Path to the prediction JSON. It is iterated as a
            sequence of entries; an entry is used when it is a dict with a
            truthy ``"success"`` and a ``"result"`` dict holding
            ``"unique_id"``, ``"MedDatum"``, ``"EDSS"`` and
            ``"certainty_percent"``.
        ground_truth_path: Path to the ``;``-separated ground-truth CSV. It
            must contain the columns ``unique_id``, ``MedDatum`` and
            ``edss_gt_col``.
        edss_gt_col: Name of the ground-truth EDSS column.
        min_bin_count: Bins with fewer rows than this trigger a warning in
            the log (they are still plotted).

    Returns:
        None. Results are printed and shown via ``plt.show()``. The function
        returns early (without plotting) when no usable prediction rows or no
        complete merged rows exist.

    Raises:
        ValueError: If the ground-truth CSV lacks a required column.
        pandas.errors.MergeError: If ground-truth keys are not unique (the
            merge is validated as many-to-one).
    """

    def norm_str(x):
        # Normalize identifiers and dates consistently so ground-truth and
        # prediction keys compare equal regardless of whitespace/letter case.
        return str(x).strip().lower()

    def parse_edss(x):
        # Robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc.
        if x is None:
            return np.nan
        s = str(x).strip()
        if s == "" or s.lower() in {"nan", "none", "null"}:
            return np.nan
        return pd.to_numeric(s.replace(",", "."), errors="coerce")

    print("\n" + "="*80)
    print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)")
    print("="*80)
    print(f"JSON: {json_file_path}")
    print(f"GT: {ground_truth_path}")

    # ------------------------------------------------------------------
    # 1) Load Ground Truth
    # ------------------------------------------------------------------
    df_gt = pd.read_csv(ground_truth_path, sep=";")
    required_gt_cols = {"unique_id", "MedDatum", edss_gt_col}
    missing_cols = required_gt_cols - set(df_gt.columns)
    if missing_cols:
        raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}")

    df_gt["unique_id"] = df_gt["unique_id"].map(norm_str)
    df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str)
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]

    # Robust EDSS parsing
    df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss)

    # GT logs
    gt_is_dup = df_gt["key"].duplicated(keep=False)
    gt_dup = gt_is_dup.sum()
    print("\n--- GT LOG ---")
    print(f"GT rows: {len(df_gt)}")
    print(f"GT unique keys: {df_gt['key'].nunique()}")
    print(f"GT duplicate-key rows: {gt_dup}")
    print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}")
    if gt_dup > 0:
        print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:")
        print(df_gt.loc[gt_is_dup, "key"].value_counts().head(10))

    # ------------------------------------------------------------------
    # 2) Load Predictions from the specific JSON
    # ------------------------------------------------------------------
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_entries = len(data)
    # Entries that are not dicts cannot carry a success flag; count them as
    # unsuccessful instead of crashing on ``.get``.
    success_entries = sum(
        1 for e in data if isinstance(e, dict) and e.get("success")
    )

    all_preds = []
    # NOTE: "missing_edss" / "missing_conf" rows are only counted here; they
    # are still loaded and get dropped later by the complete-row filter.
    skipped = {
        "not_success": 0,
        "missing_uid_or_date": 0,
        "missing_edss": 0,
        "missing_conf": 0,
    }
    for entry in data:
        if not isinstance(entry, dict) or not entry.get("success"):
            skipped["not_success"] += 1
            continue

        # ``"result": null`` (or any non-dict) must not crash the loader;
        # treat it as a result without identifiers.
        res = entry.get("result")
        if not isinstance(res, dict):
            res = {}

        uid = res.get("unique_id")
        md = res.get("MedDatum")
        if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "":
            skipped["missing_uid_or_date"] += 1
            continue

        edss_pred = parse_edss(res.get("EDSS"))
        conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce")
        if pd.isna(edss_pred):
            skipped["missing_edss"] += 1
        if pd.isna(conf):
            skipped["missing_conf"] += 1

        uid_norm = norm_str(uid)
        md_norm = norm_str(md)
        all_preds.append({
            "unique_id": uid_norm,
            "MedDatum": md_norm,
            "key": uid_norm + "_" + md_norm,
            "EDSS_pred": edss_pred,
            "confidence": conf,
        })

    df_pred = pd.DataFrame(all_preds)

    # Pred logs
    print("\n--- PRED LOG ---")
    print(f"JSON total entries: {total_entries}")
    print(f"JSON success entries: {success_entries}")
    print(f"Pred rows loaded (success + has keys): {len(df_pred)}")
    if df_pred.empty:
        print("[ERROR] No usable prediction rows found. Nothing to plot.")
        # Show why every entry was rejected before bailing out.
        print("Skipped counts:", skipped)
        return

    print(f"Pred unique keys: {df_pred['key'].nunique()}")
    print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}")
    print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}")
    print("Skipped counts:", skipped)

    key_counts = df_pred["key"].value_counts()
    dup_pred_rows = (key_counts > 1).sum()
    max_rep = int(key_counts.max())
    print(f"Keys with >1 prediction in this JSON: {dup_pred_rows}")
    print(f"Max repetitions of a single key in this JSON: {max_rep}")
    if max_rep > 1:
        print("Top repeated keys in this JSON:")
        print(key_counts.head(10))

    # ------------------------------------------------------------------
    # 3) Merge
    # ------------------------------------------------------------------
    gt_key_set = set(df_gt["key"])
    df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set)
    not_in_gt = df_pred.loc[~df_pred["key_in_gt"]]

    print("\n--- KEY MATCH LOG ---")
    print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}")
    print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}")
    if len(not_in_gt) > 0:
        print("[WARNING] Some prediction keys are not present in GT. First 10:")
        print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10))

    # NOTE: validate="many_to_one" makes pandas raise MergeError when GT keys
    # are not unique, so duplicated GT keys (warned about above) stop the
    # analysis here rather than silently duplicating prediction rows.
    df_merged = df_pred.merge(
        df_gt[["key", "EDSS_gt"]],
        on="key",
        how="inner",
        validate="many_to_one",
    )

    print("\n--- MERGE LOG ---")
    print(f"Merged rows (inner join): {len(df_merged)}")
    print(f"Merged unique keys: {df_merged['key'].nunique()}")
    print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}")
    print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}")
    print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}")

    rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"])
    print("\n--- FILTER LOG (what will be used for stats/plot) ---")
    print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}")
    if len(rows_complete) == 0:
        print("[ERROR] No complete rows after filtering. Nothing to plot.")
        return

    # Copy so adding columns does not write into a slice of df_merged.
    rows_complete = rows_complete.copy()
    rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs()

    # ------------------------------------------------------------------
    # 4) Pearson correlation on row-level data
    # ------------------------------------------------------------------
    corr_df = rows_complete.dropna(subset=["confidence", "abs_error"]).copy()
    # pearsonr is only computed with at least two rows and non-constant
    # inputs on both sides; otherwise "NA" is reported.
    if len(corr_df) >= 2 and corr_df["confidence"].nunique() > 1 and corr_df["abs_error"].nunique() > 1:
        r_value, p_value = pearsonr(corr_df["confidence"], corr_df["abs_error"])
        corr_text = f"Pearson r = {r_value:.3f}\np = {p_value:.3g}\nn = {len(corr_df)}"
    else:
        corr_text = f"Pearson r = NA\np = NA\nn = {len(corr_df)}"

    print("\n--- CORRELATION LOG ---")
    print(corr_text.replace("\n", " | "))

    # ------------------------------------------------------------------
    # 5) Binning + stats
    # ------------------------------------------------------------------
    # Right-closed bins: [0, 70], (70, 80], (80, 90], (90, 100].
    bins = [0, 70, 80, 90, 100]
    labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"]
    rows_complete["conf_bin"] = pd.cut(rows_complete["confidence"], bins=bins, labels=labels, include_lowest=True)

    conf_outside = rows_complete["conf_bin"].isna().sum()
    print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}")
    if conf_outside > 0:
        print("Example confidences outside bins:")
        print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list())

    df_plot = rows_complete.dropna(subset=["conf_bin"])
    # reindex(labels) restores empty bins (as NaN rows) so that every bin
    # keeps its fixed position on the x axis.
    stats = (
        df_plot.groupby("conf_bin", observed=True)["abs_error"]
        .agg(mean="mean", std="std", count="count")
        .reindex(labels)
        .reset_index()
    )

    print("\n--- BIN STATS ---")
    print(stats)

    low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]]
    if not low_bins.empty:
        print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:")
        print(low_bins)

    # ------------------------------------------------------------------
    # 6) Plot
    # ------------------------------------------------------------------
    plt.figure(figsize=(13, 8))
    colors = sns.color_palette("Blues", n_colors=len(labels))

    means = stats["mean"].to_numpy()
    counts = stats["count"].fillna(0).astype(int).to_numpy()
    stds = stats["std"].to_numpy()

    # Empty bins are drawn as zero-height bars and labelled "empty" below.
    means_plot = np.nan_to_num(means, nan=0.0)
    bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85)

    # Standard error of the mean; undefined (NaN) for bins with < 2 rows.
    sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan)
    plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5)

    # Linear trend fitted over bin positions of the non-empty bins only.
    valid_idx = np.where(~np.isnan(means))[0]
    if len(valid_idx) >= 2:
        x_idx = np.arange(len(labels))
        trend = np.poly1d(np.polyfit(valid_idx, means[valid_idx], 1))
        plt.plot(x_idx, trend(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5)
        trend_label = "Trend Line"
    else:
        trend_label = "Trend Line (insufficient bins)"
        print("\n[INFO] Not enough non-empty bins to fit a trend line.")

    # Data labels
    for bar, n_raw, mae_val in zip(bars, counts, means):
        n_count = int(n_raw)
        if np.isnan(mae_val) or n_count == 0:
            txt = "empty"
            y = 0.02
        else:
            txt = f"MAE: {mae_val:.2f}\nn={n_count}"
            y = bar.get_height() + 0.04
        plt.text(
            bar.get_x() + bar.get_width()/2,
            y,
            txt,
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10,
        )

    # Pearson correlation text box inside plot
    ax = plt.gca()
    ax.text(
        0.02, 0.98,
        corr_text,
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=11,
        zorder=10,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="gray", alpha=0.95),
    )

    # Legend
    legend_elements = [
        Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"),
        Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"),
        Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"),
        Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"),
        Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label),
        Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"),
        Patch(color="none", label="Metric: Mean Absolute Error (MAE)"),
    ]
    plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")

    # plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
    plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
    plt.xlabel("LLM Confidence Bracket", fontsize=12)
    plt.grid(axis="y", linestyle=":", alpha=0.5)

    # Leave headroom above the tallest bar for its data label.
    ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0
    plt.ylim(0, max(0.5, float(ymax) + 0.6))
    plt.tight_layout()
    plt.show()

    print("\n" + "="*80)
    print("DONE")
    print("="*80)
# --- RUN ---
# Driver for this cell: runs the single-JSON error analysis on one specific
# result file against the ground-truth CSV as soon as these lines execute.
# NOTE(review): both paths are absolute and machine-specific; they must be
# adjusted before the script is run on another machine.
json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json"
gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv"
plot_single_json_error_analysis_with_log(json_path, gt_path)
##
# %% Certainty vs Delta (rewritten with robust parsing + detailed data loss logs)
import pandas as pd