diff --git a/.gitignore b/.gitignore index d8c01b5..0160b64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,29 @@ +<<<<<<< HEAD # Ignore all contents of these directories !**/*.py /Data/ /attach/ /results/ /enarcelona/ +======= +# 1. Broad Ignores +/Data/* +/attach/* +/results/* +/enarcelona/* +>>>>>>> Certainty .env +__pycache__/ +*.pyc +*.csv +======= +/reference/ +*.svg +>>>>>>> Stashed changes +# 2. Ignore virtual environments COMPLETELY +# This must come BEFORE the unignore rule +env*/ + +# 3. The "Unignore" rule (Whitelisting) +# We only unignore .py files that aren't already blocked by the rules above +!**/*.py diff --git a/.gitignore.orig b/.gitignore.orig new file mode 100644 index 0000000..0160b64 --- /dev/null +++ b/.gitignore.orig @@ -0,0 +1,29 @@ +<<<<<<< HEAD +# Ignore all contents of these directories +!**/*.py +/Data/ +/attach/ +/results/ +/enarcelona/ +======= +# 1. Broad Ignores +/Data/* +/attach/* +/results/* +/enarcelona/* +>>>>>>> Certainty +.env +__pycache__/ +*.pyc +*.csv +======= +/reference/ +*.svg +>>>>>>> Stashed changes +# 2. Ignore virtual environments COMPLETELY +# This must come BEFORE the unignore rule +env*/ + +# 3. 
The "Unignore" rule (Whitelisting) +# We only unignore .py files that aren't already blocked by the rules above +!**/*.py diff --git a/app.py b/app.py index 0f70ca9..def7c48 100644 --- a/app.py +++ b/app.py @@ -214,3 +214,8 @@ if __name__ == "__main__": print(f"Results saved to {output_json}") ## + + +# %% name +eXXXXXXXX +## diff --git a/audit.py b/audit.py new file mode 100644 index 0000000..7534646 --- /dev/null +++ b/audit.py @@ -0,0 +1,2682 @@ +# %% Confirm EDSS missing +import pandas as pd +import numpy as np + +def clean_series(s): + return s.astype(str).str.strip().str.lower() + +def gt_edss_audit(ground_truth_path, edss_col="EDSS"): + df_gt = pd.read_csv(ground_truth_path, sep=';') + + # normalize keys + df_gt['unique_id'] = clean_series(df_gt['unique_id']) + df_gt['MedDatum'] = clean_series(df_gt['MedDatum']) + df_gt['key'] = df_gt['unique_id'] + "_" + df_gt['MedDatum'] + + print("GT rows:", len(df_gt)) + print("GT unique keys:", df_gt['key'].nunique()) + + # IMPORTANT: parse EDSS robustly (German decimal commas etc.) + if edss_col in df_gt.columns: + edss_raw = df_gt[edss_col] + edss_num = pd.to_numeric( + edss_raw.astype(str).str.replace(",", ".", regex=False).str.strip(), + errors="coerce" + ) + df_gt["_edss_num"] = edss_num + + print(f"GT missing EDSS look (numeric-coerce): {df_gt['_edss_num'].isna().sum()}") + print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['_edss_num'].isna(), 'key'].nunique()}") + + # duplicates on key + dup = df_gt['key'].duplicated(keep=False) + print("GT duplicate-key rows:", dup.sum()) + if dup.any(): + # how many duplicate keys exist? + print("GT duplicate keys:", df_gt.loc[dup, 'key'].nunique()) + # of duplicate-key rows, how many have missing EDSS? 
+ print("Duplicate-key rows with missing EDSS:", df_gt.loc[dup, "_edss_num"].isna().sum()) + + # show the worst offenders + print("\nTop duplicate keys (by count):") + print(df_gt.loc[dup, 'key'].value_counts().head(10)) + else: + print(f"EDSS column '{edss_col}' not found in GT columns:", df_gt.columns.tolist()) + + return df_gt + +df_gt = gt_edss_audit("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", edss_col="EDSS") + +## + + + + +# %% trace missing ones + +import json, glob, os +import pandas as pd + +def load_preds(json_dir_path): + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + file_name = os.path.basename(file_path) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + "unique_id": str(res.get("unique_id")).strip().lower(), + "MedDatum": str(res.get("MedDatum")).strip().lower(), + "file": file_name + }) + df_pred = pd.DataFrame(all_preds) + df_pred["key"] = df_pred["unique_id"] + "_" + df_pred["MedDatum"] + return df_pred + +df_pred = load_preds("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration") +print("Pred rows:", len(df_pred)) +print("Pred unique keys:", df_pred["key"].nunique()) + +# Suppose df_gt was returned from step 1 and has _edss_num + key +missing_gt_keys = set(df_gt.loc[df_gt["_edss_num"].isna(), "key"]) + +df_pred["gt_key_missing_edss"] = df_pred["key"].isin(missing_gt_keys) + +print("Pred rows whose GT key has missing EDSS:", df_pred["gt_key_missing_edss"].sum()) +print("Unique keys (among preds) whose GT EDSS missing:", df_pred.loc[df_pred["gt_key_missing_edss"], "key"].nunique()) + +print("\nTop files contributing to missing-GT-EDSS rows:") +print(df_pred.loc[df_pred["gt_key_missing_edss"], "file"].value_counts().head(20)) + +print("\nTop keys replicated in predictions (why count inflates):") +print(df_pred.loc[df_pred["gt_key_missing_edss"], 
"key"].value_counts().head(20)) + + +## + + +# %% verify + +merged = df_pred.merge( + df_gt[["key", "_edss_num"]], # use the numeric-coerced GT EDSS + on="key", + how="left", + validate="many_to_one" # will ERROR if GT has duplicate keys (GOOD!) +) + +print("Merged rows:", len(merged)) +print("Merged missing GT EDSS:", merged["_edss_num"].isna().sum()) + + +## + + +# %% 1json (rewritten with robust parsing + detailed data log) +import pandas as pd +import numpy as np +import json +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_single_json_error_analysis_with_log( + json_file_path, + ground_truth_path, + edss_gt_col="EDSS", + min_bin_count=5, +): + def norm_str(x): + # normalize identifiers and dates consistently + return str(x).strip().lower() + + def parse_edss(x): + # robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc. + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + print("\n" + "="*80) + print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)") + print("="*80) + print(f"JSON: {json_file_path}") + print(f"GT: {ground_truth_path}") + + # ------------------------------------------------------------------ + # 1) Load Ground Truth + # ------------------------------------------------------------------ + df_gt = pd.read_csv(ground_truth_path, sep=";") + + required_gt_cols = {"unique_id", "MedDatum", edss_gt_col} + missing_cols = required_gt_cols - set(df_gt.columns) + if missing_cols: + raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + + # Robust EDSS parsing (important!) 
+ df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss) + + # GT logs + print("\n--- GT LOG ---") + print(f"GT rows: {len(df_gt)}") + print(f"GT unique keys: {df_gt['key'].nunique()}") + gt_dup = df_gt["key"].duplicated(keep=False).sum() + print(f"GT duplicate-key rows: {gt_dup}") + print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}") + print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}") + + if gt_dup > 0: + print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:") + print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10)) + + # ------------------------------------------------------------------ + # 2) Load Predictions from the specific JSON + # ------------------------------------------------------------------ + with open(json_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + total_entries = len(data) + success_entries = sum(1 for e in data if e.get("success")) + + all_preds = [] + skipped = { + "not_success": 0, + "missing_uid_or_date": 0, + "missing_edss": 0, + "missing_conf": 0, + } + + for entry in data: + if not entry.get("success"): + skipped["not_success"] += 1 + continue + + res = entry.get("result", {}) + uid = res.get("unique_id") + md = res.get("MedDatum") + + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + skipped["missing_uid_or_date"] += 1 + continue + + edss_pred = parse_edss(res.get("EDSS")) + conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce") + + if pd.isna(edss_pred): + skipped["missing_edss"] += 1 + if pd.isna(conf): + skipped["missing_conf"] += 1 + + all_preds.append({ + "unique_id": norm_str(uid), + "MedDatum": norm_str(md), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": edss_pred, + "confidence": conf, + }) + + df_pred = pd.DataFrame(all_preds) + + # Pred logs + print("\n--- PRED LOG ---") + print(f"JSON total entries: {total_entries}") + 
print(f"JSON success entries: {success_entries}") + print(f"Pred rows loaded (success + has keys): {len(df_pred)}") + if len(df_pred) == 0: + print("[ERROR] No usable prediction rows found. Nothing to plot.") + return + + print(f"Pred unique keys: {df_pred['key'].nunique()}") + print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}") + print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}") + print("Skipped counts:", skipped) + + # Are keys duplicated within this JSON? (often yes if multiple notes map to same key) + key_counts = df_pred["key"].value_counts() + dup_pred_rows = (key_counts > 1).sum() + max_rep = int(key_counts.max()) + print(f"Keys with >1 prediction in this JSON: {dup_pred_rows}") + print(f"Max repetitions of a single key in this JSON: {max_rep}") + if max_rep > 1: + print("Top repeated keys in this JSON:") + print(key_counts.head(10)) + + # ------------------------------------------------------------------ + # 3) Merge (and diagnose why rows drop) + # ------------------------------------------------------------------ + # Diagnose how many pred keys exist in GT + gt_key_set = set(df_gt["key"]) + df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set) + not_in_gt = df_pred.loc[~df_pred["key_in_gt"]] + + print("\n--- KEY MATCH LOG ---") + print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}") + print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}") + if len(not_in_gt) > 0: + print("[WARNING] Some prediction keys are not present in GT. First 10:") + print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10)) + + # Now merge; we expect GT is one-to-many with pred (many_to_one) + # If GT had duplicates, validate would raise. 
+ df_merged = df_pred.merge( + df_gt[["key", "EDSS_gt"]], + on="key", + how="inner", + validate="many_to_one" + ) + + print("\n--- MERGE LOG ---") + print(f"Merged rows (inner join): {len(df_merged)}") + print(f"Merged unique keys: {df_merged['key'].nunique()}") + print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}") + print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}") + print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}") + + # How many rows will be removed by dropna() in your old code? + # Old code did .dropna() on ALL columns, which can remove rows for missing confidence too. + rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]) + print("\n--- FILTER LOG (what will be used for stats/plot) ---") + print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}") + if len(rows_complete) == 0: + print("[ERROR] No complete rows after filtering. Nothing to plot.") + return + + # Compute abs error + rows_complete = rows_complete.copy() + rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs() + + # ------------------------------------------------------------------ + # 4) Binning + stats (with guardrails) + # ------------------------------------------------------------------ + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + + # Confidence outside bins becomes NaN; log it + rows_complete["conf_bin"] = pd.cut(rows_complete["confidence"], bins=bins, labels=labels, include_lowest=True) + conf_outside = rows_complete["conf_bin"].isna().sum() + print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}") + if conf_outside > 0: + print("Example confidences outside bins:") + print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list()) + + df_plot = rows_complete.dropna(subset=["conf_bin"]) + stats = ( + 
df_plot.groupby("conf_bin", observed=True)["abs_error"] + .agg(mean="mean", std="std", count="count") + .reindex(labels) + .reset_index() + ) + + print("\n--- BIN STATS ---") + print(stats) + + # Warn about low counts + low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]] + if not low_bins.empty: + print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:") + print(low_bins) + + # ------------------------------------------------------------------ + # 5) Plot + # ------------------------------------------------------------------ + plt.figure(figsize=(13, 8)) + colors = sns.color_palette("Blues", n_colors=len(labels)) + + # Replace NaNs in mean for plotting bars (empty bins) + means = stats["mean"].to_numpy() + counts = stats["count"].fillna(0).astype(int).to_numpy() + stds = stats["std"].to_numpy() + + # For bins with no data, bar height 0 (and no errorbar) + means_plot = np.nan_to_num(means, nan=0.0) + + bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85) + + # Error bars only where count>1 and std is not NaN + sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan) + plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5) + + # Trend line only if at least 2 non-empty bins + valid_idx = np.where(~np.isnan(means))[0] + if len(valid_idx) >= 2: + x_idx = np.arange(len(labels)) + z = np.polyfit(valid_idx, means[valid_idx], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5) + trend_label = "Trend Line" + else: + trend_label = "Trend Line (insufficient bins)" + print("\n[INFO] Not enough non-empty bins to fit a trend line.") + + # Data labels + for i, bar in enumerate(bars): + n_count = int(counts[i]) + mae_val = means[i] + if np.isnan(mae_val) or n_count == 0: + txt = "empty" + y = 0.02 + else: + txt = f"MAE: {mae_val:.2f}\nn={n_count}" + y = bar.get_height() + 
0.04 + plt.text( + bar.get_x() + bar.get_width()/2, + y, + txt, + ha="center", + va="bottom", + fontweight="bold", + fontsize=10 + ) + + # Legend + legend_elements = [ + Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"), + Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"), + Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"), + Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"), + Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label), + Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"), + Patch(color="none", label="Metric: Mean Absolute Error (MAE)") + ] + plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend") + +# plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30) + plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12) + plt.xlabel("LLM Confidence Bracket", fontsize=12) + plt.grid(axis="y", linestyle=":", alpha=0.5) + + ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0 + plt.ylim(0, max(0.5, float(ymax) + 0.6)) + plt.tight_layout() + plt.show() + + print("\n" + "="*80) + print("DONE") + print("="*80) + + +# --- RUN --- +json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" +gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" + +plot_single_json_error_analysis_with_log(json_path, gt_path) + + + +## + +# %% 1json (rewritten with robust parsing + detailed data log + Pearson r in plot) +import pandas as pd +import numpy as np +import json +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D +from scipy.stats import pearsonr + +def plot_single_json_error_analysis_with_log( + json_file_path, + 
ground_truth_path, + edss_gt_col="EDSS", + min_bin_count=5, +): + def norm_str(x): + # normalize identifiers and dates consistently + return str(x).strip().lower() + + def parse_edss(x): + # robust numeric parse: handles "3,5" as 3.5, blanks, "nan", etc. + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + print("\n" + "="*80) + print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)") + print("="*80) + print(f"JSON: {json_file_path}") + print(f"GT: {ground_truth_path}") + + # ------------------------------------------------------------------ + # 1) Load Ground Truth + # ------------------------------------------------------------------ + df_gt = pd.read_csv(ground_truth_path, sep=";") + + required_gt_cols = {"unique_id", "MedDatum", edss_gt_col} + missing_cols = required_gt_cols - set(df_gt.columns) + if missing_cols: + raise ValueError(f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + + # Robust EDSS parsing + df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss) + + # GT logs + print("\n--- GT LOG ---") + print(f"GT rows: {len(df_gt)}") + print(f"GT unique keys: {df_gt['key'].nunique()}") + gt_dup = df_gt["key"].duplicated(keep=False).sum() + print(f"GT duplicate-key rows: {gt_dup}") + print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}") + print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}") + + if gt_dup > 0: + print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. 
Example duplicate keys:") + print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10)) + + # ------------------------------------------------------------------ + # 2) Load Predictions from the specific JSON + # ------------------------------------------------------------------ + with open(json_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + total_entries = len(data) + success_entries = sum(1 for e in data if e.get("success")) + + all_preds = [] + skipped = { + "not_success": 0, + "missing_uid_or_date": 0, + "missing_edss": 0, + "missing_conf": 0, + } + + for entry in data: + if not entry.get("success"): + skipped["not_success"] += 1 + continue + + res = entry.get("result", {}) + uid = res.get("unique_id") + md = res.get("MedDatum") + + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + skipped["missing_uid_or_date"] += 1 + continue + + edss_pred = parse_edss(res.get("EDSS")) + conf = pd.to_numeric(res.get("certainty_percent"), errors="coerce") + + if pd.isna(edss_pred): + skipped["missing_edss"] += 1 + if pd.isna(conf): + skipped["missing_conf"] += 1 + + all_preds.append({ + "unique_id": norm_str(uid), + "MedDatum": norm_str(md), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": edss_pred, + "confidence": conf, + }) + + df_pred = pd.DataFrame(all_preds) + + # Pred logs + print("\n--- PRED LOG ---") + print(f"JSON total entries: {total_entries}") + print(f"JSON success entries: {success_entries}") + print(f"Pred rows loaded (success + has keys): {len(df_pred)}") + if len(df_pred) == 0: + print("[ERROR] No usable prediction rows found. 
Nothing to plot.") + return + + print(f"Pred unique keys: {df_pred['key'].nunique()}") + print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}") + print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}") + print("Skipped counts:", skipped) + + key_counts = df_pred["key"].value_counts() + dup_pred_rows = (key_counts > 1).sum() + max_rep = int(key_counts.max()) + print(f"Keys with >1 prediction in this JSON: {dup_pred_rows}") + print(f"Max repetitions of a single key in this JSON: {max_rep}") + if max_rep > 1: + print("Top repeated keys in this JSON:") + print(key_counts.head(10)) + + # ------------------------------------------------------------------ + # 3) Merge + # ------------------------------------------------------------------ + gt_key_set = set(df_gt["key"]) + df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set) + not_in_gt = df_pred.loc[~df_pred["key_in_gt"]] + + print("\n--- KEY MATCH LOG ---") + print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}") + print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}") + if len(not_in_gt) > 0: + print("[WARNING] Some prediction keys are not present in GT. 
First 10:") + print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10)) + + df_merged = df_pred.merge( + df_gt[["key", "EDSS_gt"]], + on="key", + how="inner", + validate="many_to_one" + ) + + print("\n--- MERGE LOG ---") + print(f"Merged rows (inner join): {len(df_merged)}") + print(f"Merged unique keys: {df_merged['key'].nunique()}") + print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}") + print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}") + print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}") + + rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]) + print("\n--- FILTER LOG (what will be used for stats/plot) ---") + print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}") + if len(rows_complete) == 0: + print("[ERROR] No complete rows after filtering. Nothing to plot.") + return + + rows_complete = rows_complete.copy() + rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs() + + # ------------------------------------------------------------------ + # 4) Pearson correlation on row-level data + # ------------------------------------------------------------------ + corr_df = rows_complete.dropna(subset=["confidence", "abs_error"]).copy() + + if len(corr_df) >= 2 and corr_df["confidence"].nunique() > 1 and corr_df["abs_error"].nunique() > 1: + r_value, p_value = pearsonr(corr_df["confidence"], corr_df["abs_error"]) + corr_text = f"Pearson r = {r_value:.3f}\np = {p_value:.3g}\nn = {len(corr_df)}" + else: + r_value, p_value = np.nan, np.nan + corr_text = f"Pearson r = NA\np = NA\nn = {len(corr_df)}" + + print("\n--- CORRELATION LOG ---") + print(corr_text.replace("\n", " | ")) + + # ------------------------------------------------------------------ + # 5) Binning + stats + # ------------------------------------------------------------------ + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate 
(70-80%)", "High (80-90%)", "Very High (90-100%)"] + + rows_complete["conf_bin"] = pd.cut(rows_complete["confidence"], bins=bins, labels=labels, include_lowest=True) + conf_outside = rows_complete["conf_bin"].isna().sum() + print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}") + if conf_outside > 0: + print("Example confidences outside bins:") + print(rows_complete.loc[rows_complete["conf_bin"].isna(), "confidence"].head(20).to_list()) + + df_plot = rows_complete.dropna(subset=["conf_bin"]) + stats = ( + df_plot.groupby("conf_bin", observed=True)["abs_error"] + .agg(mean="mean", std="std", count="count") + .reindex(labels) + .reset_index() + ) + + print("\n--- BIN STATS ---") + print(stats) + + low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]] + if not low_bins.empty: + print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:") + print(low_bins) + + # ------------------------------------------------------------------ + # 6) Plot + # ------------------------------------------------------------------ + plt.figure(figsize=(13, 8)) + colors = sns.color_palette("Blues", n_colors=len(labels)) + + means = stats["mean"].to_numpy() + counts = stats["count"].fillna(0).astype(int).to_numpy() + stds = stats["std"].to_numpy() + means_plot = np.nan_to_num(means, nan=0.0) + + bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85) + + sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan) + plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5) + + valid_idx = np.where(~np.isnan(means))[0] + if len(valid_idx) >= 2: + x_idx = np.arange(len(labels)) + z = np.polyfit(valid_idx, means[valid_idx], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5) + trend_label = "Trend Line" + else: + trend_label = "Trend Line (insufficient bins)" + 
print("\n[INFO] Not enough non-empty bins to fit a trend line.") + + # Data labels + for i, bar in enumerate(bars): + n_count = int(counts[i]) + mae_val = means[i] + if np.isnan(mae_val) or n_count == 0: + txt = "empty" + y = 0.02 + else: + txt = f"MAE: {mae_val:.2f}\nn={n_count}" + y = bar.get_height() + 0.04 + plt.text( + bar.get_x() + bar.get_width()/2, + y, + txt, + ha="center", + va="bottom", + fontweight="bold", + fontsize=10 + ) + + # Pearson correlation text box inside plot + ax = plt.gca() + ax.text( + 0.02, 0.98, + corr_text, + transform=ax.transAxes, + ha="left", + va="top", + fontsize=11, + zorder=10, + bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="gray", alpha=0.95) + ) + # Legend + legend_elements = [ + Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"), + Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"), + Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"), + Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"), + Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label), + Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Std Error (SEM)"), + Patch(color="none", label="Metric: Mean Absolute Error (MAE)") + ] + plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend") + +# plt.title("Validation: Confidence vs. 
Error Magnitude (Single JSON)", fontsize=15, pad=30) + plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12) + plt.xlabel("LLM Confidence Bracket", fontsize=12) + plt.grid(axis="y", linestyle=":", alpha=0.5) + + ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0 + plt.ylim(0, max(0.5, float(ymax) + 0.6)) + plt.tight_layout() + plt.show() + + print("\n" + "="*80) + print("DONE") + print("="*80) + + +# --- RUN --- +json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" +gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" + +plot_single_json_error_analysis_with_log(json_path, gt_path) +## + +# %% Certainty vs Delta (rewritten with robust parsing + detailed data loss logs) +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_confidence_vs_abs_error_with_log( + json_dir_path, + ground_truth_path, + edss_gt_col="EDSS", + min_bin_count=5, + include_lowest=True, +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + # robust numeric parse: handles comma decimals and empty tokens + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + print("\n" + "="*90) + print("CERTAINTY vs ABS ERROR (ALL JSONs) — WITH DATA LOSS LOG") + print("="*90) + print(f"JSON DIR: {json_dir_path}") + print(f"GT FILE: {ground_truth_path}") + + # ------------------------------------------------------------------ + # 1) Load GT + # ------------------------------------------------------------------ + df_gt = pd.read_csv(ground_truth_path, sep=";") + required_gt_cols = {"unique_id", "MedDatum", edss_gt_col} + missing_cols = 
required_gt_cols - set(df_gt.columns) + if missing_cols: + raise ValueError(f"GT missing columns: {missing_cols}. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(parse_edss) + + # GT logs + print("\n--- GT LOG ---") + print(f"GT rows: {len(df_gt)}") + print(f"GT unique keys: {df_gt['key'].nunique()}") + gt_dup_rows = df_gt["key"].duplicated(keep=False).sum() + print(f"GT duplicate-key rows: {gt_dup_rows}") + print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}") + print(f"GT missing EDSS unique keys: {df_gt.loc[df_gt['EDSS_gt'].isna(), 'key'].nunique()}") + if gt_dup_rows > 0: + print("\n[WARNING] GT has duplicate keys; merge can explode rows. Top duplicate keys:") + print(df_gt.loc[df_gt["key"].duplicated(keep=False), "key"].value_counts().head(10)) + + gt_key_set = set(df_gt["key"]) + + # ------------------------------------------------------------------ + # 2) Load predictions from all JSON files (with per-file logs) + # ------------------------------------------------------------------ + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + all_preds = [] + per_file_summary = [] + + total_entries_all = 0 + total_success_all = 0 + skipped_all = {"not_success": 0, "missing_uid_or_date": 0} + + for file_path in json_files: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + total_entries = len(data) + success_entries = sum(1 for e in data if e.get("success")) + + total_entries_all += total_entries + total_success_all += success_entries + + skipped = {"not_success": 0, "missing_uid_or_date": 0} + loaded_rows = 0 + + for entry in data: + if not entry.get("success"): + skipped["not_success"] += 1 + continue + res = 
entry.get("result", {}) + uid = res.get("unique_id") + md = res.get("MedDatum") + + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + skipped["missing_uid_or_date"] += 1 + continue + + all_preds.append({ + "file": os.path.basename(file_path), + "unique_id": norm_str(uid), + "MedDatum": norm_str(md), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + "confidence": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + loaded_rows += 1 + + skipped_all["not_success"] += skipped["not_success"] + skipped_all["missing_uid_or_date"] += skipped["missing_uid_or_date"] + + per_file_summary.append({ + "file": os.path.basename(file_path), + "entries_total": total_entries, + "entries_success": success_entries, + "pred_rows_loaded": loaded_rows, + "skipped_not_success": skipped["not_success"], + "skipped_missing_uid_or_date": skipped["missing_uid_or_date"], + }) + + df_pred = pd.DataFrame(all_preds) + df_file = pd.DataFrame(per_file_summary) + + # PRED logs + print("\n--- PRED LOG (ALL FILES) ---") + print(f"JSON files found: {len(json_files)}") + print(f"Total JSON entries: {total_entries_all}") + print(f"Total success entries:{total_success_all}") + print(f"Pred rows loaded (success + has keys): {len(df_pred)}") + if len(df_pred) == 0: + print("[ERROR] No usable prediction rows found. 
Nothing to plot.") + return + + print(f"Pred unique keys (across all files): {df_pred['key'].nunique()}") + print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}") + print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}") + print("Skipped totals:", skipped_all) + + # show per-file quick check (useful when one iteration is broken) + print("\nPer-file loaded rows (head):") + print(df_file.sort_values("file").head(10)) + + # ------------------------------------------------------------------ + # 3) Key match log (pred -> GT) + # ------------------------------------------------------------------ + df_pred["key_in_gt"] = df_pred["key"].isin(gt_key_set) + not_in_gt = df_pred.loc[~df_pred["key_in_gt"]] + + print("\n--- KEY MATCH LOG ---") + print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}") + print(f"Pred rows with key NOT in GT: {len(not_in_gt)}") + if len(not_in_gt) > 0: + print("[WARNING] Example keys not found in GT (first 10):") + print(not_in_gt[["file", "unique_id", "MedDatum", "key"]].head(10)) + print("\n[WARNING] Files contributing most to key-mismatch:") + print(not_in_gt["file"].value_counts().head(10)) + + # ------------------------------------------------------------------ + # 4) Merge (no dropna yet) + detailed data loss accounting + # ------------------------------------------------------------------ + df_merged = df_pred.merge( + df_gt[["key", "EDSS_gt"]], + on="key", + how="inner", + validate="many_to_one" # catches GT duplicates + ) + + print("\n--- MERGE LOG ---") + print(f"Merged rows (inner join): {len(df_merged)}") + print(f"Merged unique keys: {df_merged['key'].nunique()}") + + # Now quantify what you lose at each filter stage + n0 = len(df_merged) + + miss_gt = df_merged["EDSS_gt"].isna() + miss_pred = df_merged["EDSS_pred"].isna() + miss_conf = df_merged["confidence"].isna() + + print("\n--- MISSINGNESS IN MERGED ---") + print(f"Missing GT EDSS: {miss_gt.sum()}") + 
print(f"Missing Pred EDSS: {miss_pred.sum()}") + print(f"Missing Confidence: {miss_conf.sum()}") + + # IMPORTANT: your old code used .dropna() with no subset => drops if ANY column is NaN. + # We'll replicate the intended logic explicitly and log counts. + df_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]) + n1 = len(df_complete) + print("\n--- FILTER LOG ---") + print(f"Rows before filtering: {n0}") + print(f"Rows after requiring EDSS_gt, EDSS_pred, confidence: {n1}") + print(f"Rows lost due to missing required fields: {n0 - n1}") + + # Break down why rows were lost (overlap-aware) + lost_mask = df_merged[["EDSS_gt", "EDSS_pred", "confidence"]].isna().any(axis=1) + lost = df_merged.loc[lost_mask].copy() + if len(lost) > 0: + lost_reason = ( + (lost["EDSS_gt"].isna()).astype(int).map({1:"GT",0:""}) + + (lost["EDSS_pred"].isna()).astype(int).map({1:"+PRED",0:""}) + + (lost["confidence"].isna()).astype(int).map({1:"+CONF",0:""}) + ) + lost["loss_reason"] = lost_reason.str.replace(r"^\+", "", regex=True).replace("", "UNKNOWN") + print("\nTop loss reasons (overlap-aware):") + print(lost["loss_reason"].value_counts().head(10)) + + print("\nFiles contributing most to lost rows:") + print(lost["file"].value_counts().head(10)) + + if len(df_complete) == 0: + print("[ERROR] No complete rows left after filtering. 
Nothing to plot.") + return + + # ------------------------------------------------------------------ + # 5) Abs error + binning + # ------------------------------------------------------------------ + df_complete = df_complete.copy() + df_complete["abs_error"] = (df_complete["EDSS_pred"] - df_complete["EDSS_gt"]).abs() + + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + + df_complete["conf_bin"] = pd.cut( + df_complete["confidence"], + bins=bins, + labels=labels, + include_lowest=include_lowest + ) + + conf_outside = df_complete["conf_bin"].isna().sum() + print("\n--- BINNING LOG ---") + print(f"Rows with confidence outside bin edges / invalid: {conf_outside}") + if conf_outside > 0: + print("Example out-of-bin confidences:") + print(df_complete.loc[df_complete["conf_bin"].isna(), "confidence"].head(20).to_list()) + + df_plot = df_complete.dropna(subset=["conf_bin"]) + print(f"Rows kept for bin stats/plot (after dropping out-of-bin): {len(df_plot)}") + print(f"Rows lost due to out-of-bin confidence: {len(df_complete) - len(df_plot)}") + + stats = ( + df_plot.groupby("conf_bin", observed=True)["abs_error"] + .agg(mean="mean", std="std", count="count") + .reindex(labels) + .reset_index() + ) + + print("\n--- BIN STATS ---") + print(stats) + + low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]] + if not low_bins.empty: + print(f"\n[WARNING] Some bins have < {min_bin_count} rows (unstable SEM/trend):") + print(low_bins) + + # ------------------------------------------------------------------ + # 6) Plot + # ------------------------------------------------------------------ + plt.figure(figsize=(12, 8)) + colors = sns.color_palette("Blues", n_colors=len(labels)) + + means = stats["mean"].to_numpy() + counts = stats["count"].fillna(0).astype(int).to_numpy() + stds = stats["std"].to_numpy() + + means_plot = np.nan_to_num(means, nan=0.0) + bars = plt.bar(labels, 
means_plot, color=colors, edgecolor="black", linewidth=1.2) + + sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan) + plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=6, elinewidth=1.5) + + # Trend line only if >=2 non-empty bins + valid_idx = np.where(~np.isnan(means))[0] + if len(valid_idx) >= 2: + x_idx = np.arange(len(labels)) + z = np.polyfit(valid_idx, means[valid_idx], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=2.5) + trend_label = "Correlation Trend" + else: + trend_label = "Correlation Trend (insufficient bins)" + print("\n[INFO] Not enough non-empty bins to fit a trend line.") + + # Bar annotations (MAE + n) + for i, bar in enumerate(bars): + n = int(counts[i]) + m = means[i] + if n == 0 or np.isnan(m): + txt = "empty" + y = 0.02 + else: + txt = f"MAE: {m:.2f}\nn={n}" + y = bar.get_height() + 0.05 + plt.text(bar.get_x() + bar.get_width()/2, y, txt, ha="center", fontweight="bold") + + legend_elements = [ + Patch(facecolor=colors[0], edgecolor="black", label=f"Bin 1: {labels[0]}"), + Patch(facecolor=colors[1], edgecolor="black", label=f"Bin 2: {labels[1]}"), + Patch(facecolor=colors[2], edgecolor="black", label=f"Bin 3: {labels[2]}"), + Patch(facecolor=colors[3], edgecolor="black", label=f"Bin 4: {labels[3]}"), + Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10, label="Standard Error (SEM)"), + Line2D([0], [0], color="#e74c3c", linestyle="--", lw=2.5, label=trend_label), + Patch(color="none", label="Metric: Mean Absolute Error (MAE)") + ] + plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, fontsize=10, title="Legend") + + plt.title("Validation: Inverse Correlation of Confidence vs. 
Error Magnitude", fontsize=15, pad=20) + plt.ylabel("Mean Absolute Error (Δ EDSS Points)", fontsize=12) + plt.xlabel("LLM Confidence Bracket", fontsize=12) + plt.grid(axis="y", linestyle=":", alpha=0.5) + + ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0 + plt.ylim(0, max(0.5, float(ymax) + 0.6)) + plt.tight_layout() + plt.show() + + print("\n" + "="*90) + print("DONE") + print("="*90) + + +# Example run: +plot_confidence_vs_abs_error_with_log("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv") + + +## + + +# %% Empirical Confidence +# Empirical stability confidence (from 10 runs) + LLM certainty_percent as secondary signal +# - Reads all JSONs in a folder (your 10 iterations) +# - Aggregates by key = unique_id + MedDatum +# - Computes: +# * EDSS_mean, EDSS_std, EDSS_iqr, mode/share +# * empirical_conf_0_100 (based on stability) +# * llm_conf_mean_0_100 (mean certainty_percent) +# * combined_conf_0_100 (weighted blend) +# - Optional: merges GT EDSS and computes abs error on the aggregated prediction + +import os, glob, json +import numpy as np +import pandas as pd + +def build_empirical_confidence_table( + json_dir_path: str, + ground_truth_path: str | None = None, + gt_sep: str = ";", + gt_edss_col: str = "EDSS", + w_empirical: float = 0.7, # weight for empirical stability + w_llm: float = 0.3, # weight for LLM self-reported confidence + tol_mode: float = 0.5, # tolerance to treat EDSS as "same" (EDSS often in 0.5 steps) + min_runs_expected: int = 10, +): + # ----------------------------- + # Helpers + # ----------------------------- + def norm_str(x): + return str(x).strip().lower() + + def parse_number(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + def robust_iqr(x: pd.Series): + x = x.dropna() + if len(x) == 0: + return np.nan + 
return float(x.quantile(0.75) - x.quantile(0.25)) + + def stability_to_confidence(std_val: float) -> float: + """ + Map EDSS variability across runs to a 0..100 confidence. + EDSS is typically on 0.5 steps. A natural scale: + std ~= 0.0 -> ~100 + std ~= 0.25 -> ~75-90 + std ~= 0.5 -> ~50-70 + std >= 1.0 -> low + Use a smooth exponential mapping. + """ + if np.isnan(std_val): + return np.nan + # scale parameter: std=0.5 -> exp(-1)=0.367 -> ~36.7 + scale = 0.5 + conf = 100.0 * np.exp(-(std_val / scale)) + # clamp + return float(np.clip(conf, 0.0, 100.0)) + + def mode_share_with_tolerance(values: np.ndarray, tol: float) -> tuple[float, float]: + """ + Compute a 'mode' under tolerance: pick the cluster center (median) and count + how many values fall within +/- tol. Return (mode_center, share). + This is robust to tiny float differences. + """ + vals = values[~np.isnan(values)] + if len(vals) == 0: + return (np.nan, np.nan) + center = float(np.median(vals)) + share = float(np.mean(np.abs(vals - center) <= tol)) + return (center, share) + + # ----------------------------- + # Load predictions from all JSONs + # ----------------------------- + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + rows = [] + per_file = [] + total_entries_all = 0 + total_success_all = 0 + skipped_all = {"not_success": 0, "missing_uid_or_date": 0} + + for fp in json_files: + with open(fp, "r", encoding="utf-8") as f: + data = json.load(f) + + total_entries = len(data) + success_entries = sum(1 for e in data if e.get("success")) + total_entries_all += total_entries + total_success_all += success_entries + + skipped = {"not_success": 0, "missing_uid_or_date": 0} + loaded = 0 + + for entry in data: + if not entry.get("success"): + skipped["not_success"] += 1 + continue + + res = entry.get("result", {}) + uid = res.get("unique_id") + md = res.get("MedDatum") + + if uid is None or md 
is None or str(uid).strip() == "" or str(md).strip() == "": + skipped["missing_uid_or_date"] += 1 + continue + + edss = parse_number(res.get("EDSS")) + conf = parse_number(res.get("certainty_percent")) + it = res.get("iteration", None) + + rows.append({ + "file": os.path.basename(fp), + "iteration": it, + "unique_id": norm_str(uid), + "MedDatum": norm_str(md), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": edss, + "llm_conf": conf, + }) + loaded += 1 + + skipped_all["not_success"] += skipped["not_success"] + skipped_all["missing_uid_or_date"] += skipped["missing_uid_or_date"] + + per_file.append({ + "file": os.path.basename(fp), + "entries_total": total_entries, + "entries_success": success_entries, + "rows_loaded": loaded, + "skipped_not_success": skipped["not_success"], + "skipped_missing_uid_or_date": skipped["missing_uid_or_date"], + }) + + df_pred = pd.DataFrame(rows) + df_file = pd.DataFrame(per_file) + + # ----------------------------- + # Logs: ingestion + # ----------------------------- + print("\n" + "="*90) + print("EMPIRICAL CONFIDENCE (10-RUN STABILITY) + LLM CONFIDENCE (SECONDARY)") + print("="*90) + print(f"JSON DIR: {json_dir_path}") + print(f"JSON files: {len(json_files)}") + print("\n--- INGEST LOG ---") + print(f"Total JSON entries: {total_entries_all}") + print(f"Total success entries:{total_success_all}") + print(f"Pred rows loaded: {len(df_pred)}") + print(f"Unique keys in preds: {df_pred['key'].nunique() if len(df_pred) else 0}") + print(f"Missing EDSS_pred: {df_pred['EDSS_pred'].isna().sum() if len(df_pred) else 0}") + print(f"Missing llm_conf: {df_pred['llm_conf'].isna().sum() if len(df_pred) else 0}") + print("Skipped totals:", skipped_all) + + print("\nPer-file summary (top 10 by name):") + print(df_file.sort_values("file").head(10)) + + # ----------------------------- + # Aggregate by key (empirical stability) + # ----------------------------- + if len(df_pred) == 0: + print("[ERROR] No usable prediction rows.") + return 
None + + # how many runs per key (expect ~10) + runs_per_key = df_pred.groupby("key")["EDSS_pred"].size().rename("n_rows").reset_index() + print("\n--- RUNS PER KEY LOG ---") + print(f"Keys with at least 1 row: {len(runs_per_key)}") + print("Distribution of rows per key (value_counts):") + print(runs_per_key["n_rows"].value_counts().sort_index()) + + # Aggregate stats + def agg_block(g: pd.DataFrame): + ed = g["EDSS_pred"].to_numpy(dtype=float) + ll = g["llm_conf"].to_numpy(dtype=float) + + n_rows = len(g) + n_edss = int(np.sum(~np.isnan(ed))) + n_llm = int(np.sum(~np.isnan(ll))) + + ed_mean = float(np.nanmean(ed)) if n_edss else np.nan + ed_std = float(np.nanstd(ed, ddof=1)) if n_edss >= 2 else (0.0 if n_edss == 1 else np.nan) + ed_iqr = robust_iqr(pd.Series(ed)) + mode_center, mode_share = mode_share_with_tolerance(ed, tol=tol_mode) + + llm_mean = float(np.nanmean(ll)) if n_llm else np.nan + llm_std = float(np.nanstd(ll, ddof=1)) if n_llm >= 2 else (0.0 if n_llm == 1 else np.nan) + + emp_conf = stability_to_confidence(ed_std) if not np.isnan(ed_std) else np.nan + + # Combined confidence (weighted). If one side missing, fall back to the other. 
+ if np.isnan(emp_conf) and np.isnan(llm_mean): + comb = np.nan + elif np.isnan(emp_conf): + comb = llm_mean + elif np.isnan(llm_mean): + comb = emp_conf + else: + comb = w_empirical * emp_conf + w_llm * llm_mean + + return pd.Series({ + "unique_id": g["unique_id"].iloc[0], + "MedDatum": g["MedDatum"].iloc[0], + "n_rows": n_rows, + "n_edss": n_edss, + "n_llm_conf":n_llm, + "EDSS_mean": ed_mean, + "EDSS_std": ed_std, + "EDSS_iqr": ed_iqr, + "EDSS_mode_center": mode_center, + "EDSS_mode_share": mode_share, # fraction within ±tol_mode of median center + "llm_conf_mean": llm_mean, + "llm_conf_std": llm_std, + "empirical_conf_0_100": emp_conf, + "combined_conf_0_100": float(np.clip(comb, 0.0, 100.0)) if not np.isnan(comb) else np.nan, + }) + + df_agg = df_pred.groupby("key", as_index=False).apply(agg_block) + # groupby+apply returns a multiindex sometimes depending on pandas version + if isinstance(df_agg.index, pd.MultiIndex): + df_agg = df_agg.reset_index(drop=True) + + # Logs: aggregation + losses + print("\n--- AGGREGATION LOG ---") + print(f"Aggregated keys: {len(df_agg)}") + print(f"Keys with EDSS in >=1 run: {(df_agg['n_edss'] >= 1).sum()}") + print(f"Keys with EDSS in >=2 runs (std meaningful): {(df_agg['n_edss'] >= 2).sum()}") + print(f"Keys missing EDSS in all runs: {(df_agg['n_edss'] == 0).sum()}") + print(f"Keys missing llm_conf in all runs: {(df_agg['n_llm_conf'] == 0).sum()}") + + # Expected runs check + if min_runs_expected is not None: + print(f"\nKeys with < {min_runs_expected} rows (potential missing iterations):") + print(df_agg.loc[df_agg["n_rows"] < min_runs_expected, ["key", "n_rows"]].sort_values("n_rows").head(20)) + + # ----------------------------- + # Optional: merge GT and compute error on aggregated EDSS_mean + # ----------------------------- + if ground_truth_path is not None: + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + need = {"unique_id", "MedDatum", gt_edss_col} + miss = need - set(df_gt.columns) + if miss: + raise 
ValueError(f"GT missing columns: {miss}. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].apply(parse_number) + + print("\n--- GT MERGE LOG ---") + print(f"GT rows: {len(df_gt)} | GT unique keys: {df_gt['key'].nunique()}") + print(f"GT missing EDSS (numeric): {df_gt['EDSS_gt'].isna().sum()}") + + df_final = df_agg.merge(df_gt[["key", "EDSS_gt"]], on="key", how="left", validate="one_to_one") + + print(f"Aggregated keys with GT match: {df_final['EDSS_gt'].notna().sum()} / {len(df_final)}") + print(f"Aggregated keys missing GT EDSS: {df_final['EDSS_gt'].isna().sum()}") + + df_final["abs_error_mean"] = (df_final["EDSS_mean"] - df_final["EDSS_gt"]).abs() + + # How many keys usable for evaluation? + usable = df_final.dropna(subset=["EDSS_mean", "EDSS_gt"]) + print("\n--- EVAL LOG (AGGREGATED) ---") + print(f"Keys with both EDSS_mean and EDSS_gt: {len(usable)}") + if len(usable) > 0: + print(f"MAE on EDSS_mean vs GT: {usable['abs_error_mean'].mean():.3f}") + print(f"Median abs error: {usable['abs_error_mean'].median():.3f}") + + return df_final + + return df_agg + + +# Example usage: +df = build_empirical_confidence_table(json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", w_empirical=0.7, w_llm=0.3, tol_mode=0.5,min_runs_expected=10,) +df.to_csv("empirical_confidence_table.csv", index=False) + +## + + + + +# %% Executive Boxplot +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +def plot_exec_boxplots(df, min_bin_size_warn=10): + """ + Two side-by-side boxplots: + - Left: abs_error_mean grouped by empirical_conf_0_100 quantile bins + - Right: abs_error_mean grouped by llm_conf_mean quantile bins + + Adds: + - 
Robust qcut labeling (handles ties; bins may be < 4) + - Data logs + per-bin summary table printed + - Clear legend explaining each panel and what box elements mean + """ + need_cols = ["abs_error_mean", "empirical_conf_0_100", "llm_conf_mean"] + missing = [c for c in need_cols if c not in df.columns] + if missing: + raise ValueError(f"Missing columns in df: {missing}. Available: {df.columns.tolist()}") + + d = df[need_cols].copy() + + # ----------------------------- + # Data logs: survivorship + # ----------------------------- + d_emp = d.dropna(subset=["abs_error_mean", "empirical_conf_0_100"]).copy() + d_llm = d.dropna(subset=["abs_error_mean", "llm_conf_mean"]).copy() + + print("\n" + "="*90) + print("EXECUTIVE BOXPLOTS — DATA LOG + SUMMARY") + print("="*90) + print(f"Total rows in df: {len(df)}") + print(f"Rows for empirical plot: {len(d_emp)} (dropped {len(df) - len(d_emp)})") + print(f"Rows for LLM plot: {len(d_llm)} (dropped {len(df) - len(d_llm)})") + + if len(d_emp) == 0 or len(d_llm) == 0: + print("[ERROR] Not enough data after dropping NaNs to build both plots.") + return + + # ----------------------------- + # Robust quantile binning (handles ties) + # ----------------------------- + # Empirical + emp_bins = pd.qcut(d_emp["empirical_conf_0_100"], q=4, duplicates="drop") + k_emp = emp_bins.cat.categories.size + emp_labels = [f"Q{i+1}" for i in range(k_emp)] + d_emp["emp_q"] = pd.qcut(d_emp["empirical_conf_0_100"], q=4, duplicates="drop", labels=emp_labels) + + # LLM + llm_bins = pd.qcut(d_llm["llm_conf_mean"], q=4, duplicates="drop") + k_llm = llm_bins.cat.categories.size + llm_labels = [f"Q{i+1}" for i in range(k_llm)] + d_llm["llm_q"] = pd.qcut(d_llm["llm_conf_mean"], q=4, duplicates="drop", labels=llm_labels) + + # Print bin edges (so you can discuss exact thresholds) + print("\n--- BIN EDGES (actual ranges) ---") + print("Empirical confidence bins:") + for i, interval in enumerate(emp_bins.cat.categories): + print(f" {emp_labels[i]}: {interval}") + 
print("LLM confidence bins:") + for i, interval in enumerate(llm_bins.cat.categories): + print(f" {llm_labels[i]}: {interval}") + + # ----------------------------- + # Summary tables (per bin) + # ----------------------------- + def summarize_bins(df_in, bin_col, conf_col, label): + g = df_in.groupby(bin_col, observed=True).agg( + n=("abs_error_mean", "size"), + mae_mean=("abs_error_mean", "mean"), + mae_median=("abs_error_mean", "median"), + mae_q25=("abs_error_mean", lambda x: x.quantile(0.25)), + mae_q75=("abs_error_mean", lambda x: x.quantile(0.75)), + conf_mean=(conf_col, "mean"), + conf_median=(conf_col, "median"), + ).reset_index().rename(columns={bin_col: "bin"}) + g["panel"] = label + return g[["panel", "bin", "n", "mae_mean", "mae_median", "mae_q25", "mae_q75", "conf_mean", "conf_median"]] + + summary_emp = summarize_bins(d_emp, "emp_q", "empirical_conf_0_100", "Empirical") + summary_llm = summarize_bins(d_llm, "llm_q", "llm_conf_mean", "LLM") + + print("\n--- SUMMARY TABLE: Empirical confidence quartiles (or fewer if ties) ---") + print(summary_emp.to_string(index=False, float_format=lambda x: f"{x:.3f}")) + + print("\n--- SUMMARY TABLE: LLM confidence quartiles (or fewer if ties) ---") + print(summary_llm.to_string(index=False, float_format=lambda x: f"{x:.3f}")) + + # Warn about small bins + small_emp = summary_emp.loc[summary_emp["n"] < min_bin_size_warn, ["bin", "n"]] + small_llm = summary_llm.loc[summary_llm["n"] < min_bin_size_warn, ["bin", "n"]] + if not small_emp.empty or not small_llm.empty: + print(f"\n[WARNING] Some bins have < {min_bin_size_warn} points; compare them cautiously.") + if not small_emp.empty: + print(" Empirical small bins:") + print(small_emp.to_string(index=False)) + if not small_llm.empty: + print(" LLM small bins:") + print(small_llm.to_string(index=False)) + + # ----------------------------- + # Prepare data for boxplots + # ----------------------------- + emp_cats = list(d_emp["emp_q"].cat.categories) + llm_cats = 
list(d_llm["llm_q"].cat.categories) + + emp_groups = [d_emp.loc[d_emp["emp_q"] == q, "abs_error_mean"].values for q in emp_cats] + llm_groups = [d_llm.loc[d_llm["llm_q"] == q, "abs_error_mean"].values for q in llm_cats] + + # ----------------------------- + # Plot + # ----------------------------- + fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True) + + bp0 = axes[0].boxplot(emp_groups, labels=emp_cats, showfliers=False, patch_artist=True) + bp1 = axes[1].boxplot(llm_groups, labels=llm_cats, showfliers=False, patch_artist=True) + + # Make panels visually distinct but still simple (no extra clutter) + for patch in bp0["boxes"]: + patch.set_alpha(0.6) + for patch in bp1["boxes"]: + patch.set_alpha(0.6) + + axes[0].set_title("Error by Empirical Confidence (quantile bins)") + axes[0].set_xlabel("Empirical confidence bin") + axes[0].set_ylabel("Absolute Error (|EDSS_mean − EDSS_gt|)") + + axes[1].set_title("Error by LLM Confidence (quantile bins)") + axes[1].set_xlabel("LLM confidence bin") + + for ax in axes: + ax.grid(axis="y", linestyle=":", alpha=0.5) + + # ----------------------------- + # Legend (simple, but useful) + # ----------------------------- + legend_elements = [ + Patch(facecolor="white", edgecolor="black", label="Box = IQR (25%–75%)"), + Patch(facecolor="white", edgecolor="black", label="Center line = median"), + Patch(facecolor="white", edgecolor="black", label="Whiskers = typical range (no outliers shown)"), + Patch(facecolor="white", edgecolor="white", label="Left panel: empirical stability bins"), + Patch(facecolor="white", edgecolor="white", label="Right panel: LLM self-reported bins"), + ] + fig.legend(handles=legend_elements, loc="upper center", ncol=3, frameon=True) + + plt.tight_layout(rect=[0, 0, 1, 0.90]) + plt.show() + + print("\n" + "="*90) + print("DONE") + print("="*90) + + +# Example (complete): +df_final = build_empirical_confidence_table( + json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", + 
ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", + w_empirical=0.7, + w_llm=0.3, + tol_mode=0.5, + min_runs_expected=10, +) +plot_exec_boxplots(df_final) + + +## + + + +# %% Scatter + +import os, json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +def scatter_abs_error_by_conf_bins_single_json( + json_file_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + # ---- Load GT + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---- Load preds from JSON + with open(json_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + rows = [] + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + + rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + "confidence": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_pred = pd.DataFrame(rows) + + # ---- Merge + filter + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]).copy() + df["abs_error"] = (df["EDSS_pred"] - df["EDSS_gt"]).abs() + + # ---- Bin confidence into 4 categories + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High 
(90-100%)"] + df["conf_bin"] = pd.cut(df["confidence"], bins=bins, labels=labels, include_lowest=True) + df = df.dropna(subset=["conf_bin"]).copy() + + # ---- Logs + print("\n--- BIN COUNTS (points plotted) ---") + print(df["conf_bin"].value_counts().reindex(labels).fillna(0).astype(int)) + print(f"Total points plotted: {len(df)}") + + # ---- Scatter (categorical x with jitter) + x_map = {lab: i for i, lab in enumerate(labels)} + x = df["conf_bin"].map(x_map).astype(float).to_numpy() + jitter = np.random.uniform(-0.12, 0.12, size=len(df)) + xj = x + jitter + + plt.figure(figsize=(12, 6)) + plt.scatter(xj, df["abs_error"].to_numpy(), alpha=0.55) + plt.xticks(range(len(labels)), labels) + plt.xlabel("certainty_percent category (Iteration 1)") + plt.ylabel("Absolute Error (|EDSS_pred − EDSS_gt|)") + plt.title("Absolute Error vs LLM Confidence Category (Single JSON)") + plt.grid(axis="y", linestyle=":", alpha=0.5) + plt.tight_layout() + plt.show() + +# --- RUN --- +scatter_abs_error_by_conf_bins_single_json( + json_file_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json", + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", +) + +## + + + + +# %% Boxplot2 + +# Boxplot + light jittered points +# - Single JSON (iteration 1) +# - X: confidence bin (<70, 70-80, 80-90, 90-100) +# - Y: absolute error +# - Legend includes n per bin + +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.patches import Patch + +def boxplot_with_jitter_abs_error_by_conf_bins_single_json( + json_file_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + jitter_width=0.12, + point_alpha=0.25, + show_outliers=False, +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return 
np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + # ---- Load GT + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---- Load preds from JSON + with open(json_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + rows = [] + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + "confidence": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_pred = pd.DataFrame(rows) + + # ---- Merge + filter + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]).copy() + df["abs_error"] = (df["EDSS_pred"] - df["EDSS_gt"]).abs() + + # ---- Bin confidence + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + df["conf_bin"] = pd.cut(df["confidence"], bins=bins, labels=labels, include_lowest=True) + df = df.dropna(subset=["conf_bin"]).copy() + + # ---- Prepare per-bin arrays + bin_arrays = [df.loc[df["conf_bin"] == lab, "abs_error"].to_numpy() for lab in labels] + n_counts = [len(a) for a in bin_arrays] + + # ---- Plot + fig, ax = plt.subplots(figsize=(12, 6)) + + # Boxplot (no fliers by default to reduce clutter) + bp = ax.boxplot( + bin_arrays, + labels=labels, + 
showfliers=show_outliers, + patch_artist=True, + widths=0.55, + ) + + # Light fill for boxes (no explicit color choices required) + for b in bp["boxes"]: + b.set_alpha(0.35) + + # Jittered points on top + for i, arr in enumerate(bin_arrays, start=1): + if len(arr) == 0: + continue + x = np.full(len(arr), i, dtype=float) + x += np.random.uniform(-jitter_width, jitter_width, size=len(arr)) + ax.scatter(x, arr, alpha=point_alpha, s=18) + + ax.set_title("Absolute Error by LLM Confidence Bin (Iteration 1)") + ax.set_xlabel("certainty_percent category") + ax.set_ylabel("Absolute Error (|EDSS_pred − EDSS_gt|)") + ax.grid(axis="y", linestyle=":", alpha=0.5) + + # Legend showing n per bin + legend_handles = [ + Patch(facecolor="white", edgecolor="black", label=f"{lab}: n={n}") + for lab, n in zip(labels, n_counts) + ] + ax.legend(handles=legend_handles, title="Bin counts", loc="upper right", frameon=True) + + plt.tight_layout() + plt.show() + + # Print counts too (useful for discussion) + print("\n--- BIN COUNTS (points plotted) ---") + for lab, n in zip(labels, n_counts): + print(f"{lab:>18}: n={n}") + print(f"Total points plotted: {sum(n_counts)}") + + +# Example run: +boxplot_with_jitter_abs_error_by_conf_bins_single_json( + json_file_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json", + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" +) + +## + + + + +# %% Boxplot3 + + +# Boxplot + jitter with SIGNED error (direction) +# - Y-axis: signed error = EDSS_pred - EDSS_gt (negative = underestimation, positive = overestimation) +# - Also prints per-bin summary (n, mean signed error, median, MAE) + +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def boxplot_with_jitter_signed_error_by_conf_bins_single_json( + json_file_path, + 
ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + jitter_width=0.12, + point_alpha=0.25, + show_outliers=False, +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + # ---- Load GT + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---- Load preds from JSON + with open(json_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + + rows = [] + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + "confidence": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_pred = pd.DataFrame(rows) + + # ---- Merge + filter + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]).copy() + + # SIGNED ERROR (direction) + df["signed_error"] = df["EDSS_pred"] - df["EDSS_gt"] + df["abs_error"] = df["signed_error"].abs() + + # ---- Bin confidence + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + df["conf_bin"] = pd.cut(df["confidence"], bins=bins, labels=labels, 
include_lowest=True) + df = df.dropna(subset=["conf_bin"]).copy() + + # ---- Prepare arrays + bin_arrays = [df.loc[df["conf_bin"] == lab, "signed_error"].to_numpy() for lab in labels] + n_counts = [len(a) for a in bin_arrays] + + # ---- Plot + fig, ax = plt.subplots(figsize=(12, 6)) + + bp = ax.boxplot( + bin_arrays, + labels=labels, + showfliers=show_outliers, + patch_artist=True, + widths=0.55, + ) + + for b in bp["boxes"]: + b.set_alpha(0.35) + + # Jittered points + for i, arr in enumerate(bin_arrays, start=1): + if len(arr) == 0: + continue + x = np.full(len(arr), i, dtype=float) + x += np.random.uniform(-jitter_width, jitter_width, size=len(arr)) + ax.scatter(x, arr, alpha=point_alpha, s=18) + + # Zero line to show over/under clearly + ax.axhline(0, linewidth=1.5, linestyle="--") + + ax.set_title("Signed Error by LLM Confidence Bin (Iteration 1)") + ax.set_xlabel("certainty_percent category") + ax.set_ylabel("Signed Error (EDSS_pred − EDSS_gt)") + ax.grid(axis="y", linestyle=":", alpha=0.5) + + # Legend with n per bin + zero-line meaning + legend_handles = [ + Patch(facecolor="white", edgecolor="black", label=f"{lab}: n={n}") + for lab, n in zip(labels, n_counts) + ] + legend_handles.append(Line2D([0], [0], linestyle="--", color="black", label="0 = unbiased (over/under split)")) + ax.legend(handles=legend_handles, title="Bin counts", loc="upper right", frameon=True) + + plt.tight_layout() + plt.show() + + # ---- Print per-bin summary to discuss + print("\n--- PER-BIN SUMMARY (points plotted) ---") + for lab in labels: + sub = df.loc[df["conf_bin"] == lab] + n = len(sub) + if n == 0: + print(f"{lab:>18}: n=0") + continue + print( + f"{lab:>18}: n={n:3d} | " + f"mean signed={sub['signed_error'].mean(): .3f} | " + f"median signed={sub['signed_error'].median(): .3f} | " + f"MAE={sub['abs_error'].mean(): .3f}" + ) + print(f"Total points plotted: {len(df)}") + + +# Example run: +boxplot_with_jitter_signed_error_by_conf_bins_single_json( + 
json_file_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json", + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" +) +## + + + +# %% jitter and violin 10x10 + +# Violin + jitter (all JSONs in folder), with signed error +# - X: confidence bins (<70, 70-80, 80-90, 90-100) +# - Y: signed error = EDSS_pred - EDSS_gt (direction) +# - Prints bin counts (n) and puts n into the legend + +import os, glob, json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def violin_jitter_signed_error_all_jsons( + json_dir_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + jitter_width=0.12, + point_alpha=0.20, + point_size=10, + violin_inner="quartile", # 'quartile', 'box', 'stick', or None +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + # ---- Load GT + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. 
Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---- Load preds from ALL JSONs + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + rows = [] + for fp in json_files: + with open(fp, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "file": os.path.basename(fp), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + "confidence": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_pred = pd.DataFrame(rows) + + # ---- Merge + filter + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]).copy() + df["signed_error"] = df["EDSS_pred"] - df["EDSS_gt"] + + # ---- Bin confidence + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + df["conf_bin"] = pd.cut(df["confidence"], bins=bins, labels=labels, include_lowest=True) + df = df.dropna(subset=["conf_bin"]).copy() + + # ---- Counts + log + counts = df["conf_bin"].value_counts().reindex(labels).fillna(0).astype(int) + print("\n--- BIN COUNTS (all JSONs) ---") + for lab in labels: + print(f"{lab:>18}: n={counts[lab]}") + print(f"Total points plotted: {len(df)}") + print(f"JSON files: {len(json_files)}") + + # Ensure ordering for seaborn + df["conf_bin"] = pd.Categorical(df["conf_bin"], categories=labels, 
ordered=True) + + # ---- Plot + plt.figure(figsize=(12, 6)) + + # Violin (density) + sns.violinplot( + data=df, + x="conf_bin", + y="signed_error", + order=labels, + inner=violin_inner, + cut=0 + ) + + # Jittered points (manual jitter to keep it consistent and fast) + x_map = {lab: i for i, lab in enumerate(labels)} + x = df["conf_bin"].map(x_map).astype(float).to_numpy() + xj = x + np.random.uniform(-jitter_width, jitter_width, size=len(df)) + plt.scatter(xj, df["signed_error"].to_numpy(), alpha=point_alpha, s=point_size) + + # Zero line (over/under split) + plt.axhline(0, linestyle="--", linewidth=1.5) + + plt.xticks(range(len(labels)), labels) + plt.xlabel("certainty_percent category (all iterations)") + plt.ylabel("Signed Error (EDSS_pred − EDSS_gt)") + plt.title("Signed Error vs LLM Confidence Category — Violin + Jitter (All JSONs)") + plt.grid(axis="y", linestyle=":", alpha=0.5) + + # Legend with n per bin + legend_handles = [ + Patch(facecolor="white", edgecolor="black", label=f"{lab}: n={int(counts[lab])}") + for lab in labels + ] + legend_handles.append(Line2D([0], [0], linestyle="--", color="black", label="0 = unbiased (over/under split)")) + plt.legend(handles=legend_handles, title="Bin counts", loc="upper right", frameon=True) + + plt.tight_layout() + plt.show() + + +# Example run: +violin_jitter_signed_error_all_jsons( + json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" +) + +## + + + +# %% jitter and violin 10x1 + + +# Adjusted: Violin + jitter (ALL JSONs for points) but X-bins come ONLY from JSON #1 (reference) +# Fixes: +# 1) Legend has colors matching bins +# 2) Legend placed OUTSIDE plot area +# 3) X-axis binning uses certainty_percent from JSON1 (by key), then all iterations' points inherit that bin + +import os, glob, json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches 
import Patch +from matplotlib.lines import Line2D + +def violin_jitter_signed_error_all_jsons_xbins_from_json1( + json_dir_path, + json1_file_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + jitter_width=0.12, + point_alpha=0.18, + point_size=10, + violin_inner="quartile", # 'quartile', 'box', 'stick', or None +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + # ---------------------------- + # Load GT + # ---------------------------- + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---------------------------- + # Load JSON1 and build reference bins by KEY + # ---------------------------- + with open(json1_file_path, "r", encoding="utf-8") as f: + data1 = json.load(f) + + ref_rows = [] + for entry in data1: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + ref_rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "confidence_ref": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_ref = pd.DataFrame(ref_rows) + + # If JSON1 has duplicates for a key (unlikely, but safe), take the first non-null confidence + df_ref = (df_ref.sort_values("confidence_ref") + .groupby("key", as_index=False)["confidence_ref"] + .apply(lambda s: 
s.dropna().iloc[0] if s.dropna().any() else np.nan)) + if isinstance(df_ref.index, pd.MultiIndex): + df_ref = df_ref.reset_index(drop=True) + + # Confidence bins + bins = [0, 70, 80, 90, 100] + labels = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"] + df_ref["conf_bin_ref"] = pd.cut(df_ref["confidence_ref"], bins=bins, labels=labels, include_lowest=True) + df_ref = df_ref.dropna(subset=["conf_bin_ref"]).copy() + + # ---------------------------- + # Load ALL JSONs (all points) + # ---------------------------- + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + rows = [] + for fp in json_files: + with open(fp, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "file": os.path.basename(fp), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + }) + + df_pred = pd.DataFrame(rows) + + # ---------------------------- + # Merge: preds + GT + reference bins (from JSON1) + # ---------------------------- + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.merge(df_ref[["key", "conf_bin_ref"]], on="key", how="inner", validate="many_to_one") + + # filter for plotting + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "conf_bin_ref"]).copy() + df["signed_error"] = df["EDSS_pred"] - df["EDSS_gt"] + + # ordering + df["conf_bin_ref"] = pd.Categorical(df["conf_bin_ref"], categories=labels, ordered=True) + + # ---------------------------- + # Logs + counts + # ---------------------------- + counts = df["conf_bin_ref"].value_counts().reindex(labels).fillna(0).astype(int) + + print("\n--- BIN COUNTS 
(ALL JSON points, binned by JSON1 confidence) ---") + for lab in labels: + print(f"{lab:>18}: n={int(counts[lab])}") + print(f"Total points plotted: {len(df)}") + print(f"JSON files used for points: {len(json_files)}") + print(f"Reference JSON1 bins derived from: {os.path.basename(json1_file_path)}") + print(f"Keys in reference (after binning & non-null): {df_ref['key'].nunique()}") + + # ---------------------------- + # Colors + legend patches + # ---------------------------- + palette = sns.color_palette("Blues", n_colors=len(labels)) + bin_colors = {lab: palette[i] for i, lab in enumerate(labels)} + + legend_handles = [ + Patch(facecolor=bin_colors[lab], edgecolor="black", label=f"{lab}: n={int(counts[lab])}") + for lab in labels + ] + legend_handles.append(Line2D([0], [0], linestyle="--", color="black", label="0 = unbiased (over/under split)")) + + # ---------------------------- + # Plot (legend outside) + # ---------------------------- + fig, ax = plt.subplots(figsize=(12.5, 6)) + + sns.violinplot( + data=df, + x="conf_bin_ref", + y="signed_error", + order=labels, + inner=violin_inner, + cut=0, + palette=[bin_colors[l] for l in labels], + ax=ax, + ) + + # jittered points (manual jitter) + x_map = {lab: i for i, lab in enumerate(labels)} + x = df["conf_bin_ref"].map(x_map).astype(float).to_numpy() + xj = x + np.random.uniform(-jitter_width, jitter_width, size=len(df)) + ax.scatter(xj, df["signed_error"].to_numpy(), alpha=point_alpha, s=point_size) + + ax.axhline(0, linestyle="--", linewidth=1.5) + + ax.set_xlabel("certainty_percent category (from JSON 1 as reference)") + ax.set_ylabel("Signed Error (EDSS_pred − EDSS_gt)") + ax.set_title("Signed Error vs LLM Confidence Category — Violin + Jitter (All JSONs)\nBinned by JSON 1 certainty_percent") + ax.grid(axis="y", linestyle=":", alpha=0.5) + + # Legend outside (right) + ax.legend( + handles=legend_handles, + title="Bin counts", + loc="center left", + bbox_to_anchor=(1.02, 0.5), + frameon=True + ) + + 
plt.tight_layout() + plt.show() + + +# Example run: +json1_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" +violin_jitter_signed_error_all_jsons_xbins_from_json1( + json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", + json1_file_path=json1_path, + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv") + +## + + +# %% Coorelation + +# Correlation plot (RAW certainty_percent) vs error +# - Uses ALL JSONs as points +# - Uses JSON1 certainty_percent as the x-value reference (per key) +# - Y can be abs_error or signed_error (choose with y_mode) +# - Prints Spearman + Pearson correlations +# - Adds a simple linear trend line + +import os, glob, json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +def correlation_scatter_raw_certainty_json1_reference( + json_dir_path, + json1_file_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + y_mode="abs", # "abs" or "signed" + point_alpha=0.18, + point_size=12, +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + def rankdata(a): + # Average-rank for ties (Spearman needs ranks) + s = pd.Series(a) + return s.rank(method="average").to_numpy() + + # ---------------------------- + # Load GT + # ---------------------------- + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. 
Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---------------------------- + # Load JSON1 reference certainty_percent (per key) + # ---------------------------- + with open(json1_file_path, "r", encoding="utf-8") as f: + data1 = json.load(f) + + ref_rows = [] + for entry in data1: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + ref_rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "certainty_ref": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_ref = pd.DataFrame(ref_rows) + + # Deduplicate keys if needed: take first non-null certainty + df_ref = (df_ref.dropna(subset=["certainty_ref"]) + .groupby("key", as_index=False)["certainty_ref"] + .first()) + + # ---------------------------- + # Load ALL JSON predictions (points) + # ---------------------------- + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + rows = [] + for fp in json_files: + with open(fp, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "file": os.path.basename(fp), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + }) + + df_pred = pd.DataFrame(rows) + + # ---------------------------- + # Merge: preds + GT + JSON1 reference certainty + # 
---------------------------- + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.merge(df_ref[["key", "certainty_ref"]], on="key", how="inner", validate="many_to_one") + + # Filter needed fields + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "certainty_ref"]).copy() + + df["signed_error"] = df["EDSS_pred"] - df["EDSS_gt"] + df["abs_error"] = df["signed_error"].abs() + y_col = "abs_error" if y_mode == "abs" else "signed_error" + + # ---------------------------- + # Logs + # ---------------------------- + print("\n" + "="*90) + print("CORRELATION: RAW certainty_percent (JSON1 reference) vs ERROR (ALL JSON points)") + print("="*90) + print(f"JSON DIR (points): {json_dir_path} | files: {len(json_files)}") + print(f"JSON1 reference: {os.path.basename(json1_file_path)}") + print(f"Points available after merge+filter: {len(df)}") + print(f"Unique keys in plot: {df['key'].nunique()}") + print(f"Y mode: {y_mode} ({y_col})") + + # ---------------------------- + # Correlations (Pearson + Spearman) + # ---------------------------- + x = df["certainty_ref"].to_numpy(dtype=float) + y = df[y_col].to_numpy(dtype=float) + + # Pearson + pearson = np.corrcoef(x, y)[0, 1] if len(df) >= 2 else np.nan + + # Spearman = Pearson corr of ranks + rx = rankdata(x) + ry = rankdata(y) + spearman = np.corrcoef(rx, ry)[0, 1] if len(df) >= 2 else np.nan + + print(f"\nPearson r: {pearson:.4f}") + print(f"Spearman ρ: {spearman:.4f}") + + # ---------------------------- + # Trend line (simple linear fit) + # ---------------------------- + # Fit y = a*x + b + if len(df) >= 2: + a, b = np.polyfit(x, y, 1) + else: + a, b = np.nan, np.nan + + # ---------------------------- + # Plot + # ---------------------------- + plt.figure(figsize=(12, 6)) + plt.scatter(x, y, alpha=point_alpha, s=point_size) + + # trend line across full x-range + if np.isfinite(a) and np.isfinite(b): + xs = np.linspace(np.nanmin(x), np.nanmax(x), 200) + plt.plot(xs, a * xs + b, 
linestyle="--", linewidth=2) + + plt.xlabel("certainty_percent (from JSON 1, per key)") + ylabel = "Absolute Error |EDSS_pred − EDSS_gt|" if y_mode == "abs" else "Signed Error (EDSS_pred − EDSS_gt)" + plt.ylabel(ylabel) + plt.title(f"Correlation of JSON1 certainty_percent vs {y_col} (All iterations)\n" + f"Pearson r={pearson:.3f} | Spearman ρ={spearman:.3f}") + plt.grid(linestyle=":", alpha=0.5) + plt.tight_layout() + plt.show() + + +# Example run: +json1_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" +correlation_scatter_raw_certainty_json1_reference( + json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", + json1_file_path=json1_path, + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", + y_mode="abs" # or "signed" +) +## + +# %% Correlation adjusted + +# Correlation scatter (RAW certainty_percent from JSON1) vs error (all JSON points) +# Adds: +# 1) Legend (points, trend line) + Pearson/Spearman shown in legend and title +# 2) Trend line color set to high-contrast (black by default) +# 3) Density coloring: dots colored by local point density (bluer = more cases) + colorbar + +import os, glob, json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.colors import LogNorm, PowerNorm + +def correlation_scatter_raw_certainty_json1_reference( + json_dir_path, + json1_file_path, + ground_truth_path, + gt_sep=";", + gt_edss_col="EDSS", + y_mode="abs", # "abs" or "signed" + point_alpha=0.85, # higher alpha works better with density coloring + point_size=14, + trend_color="black", # high-contrast line + save_svg_path=None, + dpi=300 +): + def norm_str(x): + return str(x).strip().lower() + + def parse_edss(x): + if x is None: + return np.nan + s = str(x).strip() + if s == "" or s.lower() in {"nan", "none", "null"}: + return np.nan + s = 
s.replace(",", ".") + return pd.to_numeric(s, errors="coerce") + + def rankdata(a): + return pd.Series(a).rank(method="average").to_numpy() + + # ---------------------------- + # Load GT + # ---------------------------- + df_gt = pd.read_csv(ground_truth_path, sep=gt_sep) + for col in ["unique_id", "MedDatum", gt_edss_col]: + if col not in df_gt.columns: + raise ValueError(f"GT missing column '{col}'. Available: {df_gt.columns.tolist()}") + + df_gt["unique_id"] = df_gt["unique_id"].map(norm_str) + df_gt["MedDatum"] = df_gt["MedDatum"].map(norm_str) + df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"] + df_gt["EDSS_gt"] = df_gt[gt_edss_col].map(parse_edss) + + # ---------------------------- + # Load JSON1 reference certainty_percent (per key) + # ---------------------------- + with open(json1_file_path, "r", encoding="utf-8") as f: + data1 = json.load(f) + + ref_rows = [] + for entry in data1: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or str(uid).strip() == "" or str(md).strip() == "": + continue + ref_rows.append({ + "key": norm_str(uid) + "_" + norm_str(md), + "certainty_ref": pd.to_numeric(res.get("certainty_percent"), errors="coerce"), + }) + + df_ref = pd.DataFrame(ref_rows) + df_ref = (df_ref.dropna(subset=["certainty_ref"]) + .groupby("key", as_index=False)["certainty_ref"] + .first()) + + # ---------------------------- + # Load ALL JSON predictions (points) + # ---------------------------- + json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json"))) + if not json_files: + raise FileNotFoundError(f"No JSON files found in: {json_dir_path}") + + rows = [] + for fp in json_files: + with open(fp, "r", encoding="utf-8") as f: + data = json.load(f) + for entry in data: + if not entry.get("success"): + continue + res = entry.get("result", {}) + uid, md = res.get("unique_id"), res.get("MedDatum") + if uid is None or md is None or 
str(uid).strip() == "" or str(md).strip() == "": + continue + rows.append({ + "file": os.path.basename(fp), + "key": norm_str(uid) + "_" + norm_str(md), + "EDSS_pred": parse_edss(res.get("EDSS")), + }) + + df_pred = pd.DataFrame(rows) + + # ---------------------------- + # Merge: preds + GT + JSON1 reference certainty + # ---------------------------- + df = df_pred.merge(df_gt[["key", "EDSS_gt"]], on="key", how="inner", validate="many_to_one") + df = df.merge(df_ref[["key", "certainty_ref"]], on="key", how="inner", validate="many_to_one") + df = df.dropna(subset=["EDSS_gt", "EDSS_pred", "certainty_ref"]).copy() + + df["signed_error"] = df["EDSS_pred"] - df["EDSS_gt"] + df["abs_error"] = df["signed_error"].abs() + y_col = "abs_error" if y_mode == "abs" else "signed_error" + + # ---------------------------- + # Correlations + # ---------------------------- + x = df["certainty_ref"].to_numpy(dtype=float) + y = df[y_col].to_numpy(dtype=float) + + pearson = np.corrcoef(x, y)[0, 1] if len(df) >= 2 else np.nan + rx, ry = rankdata(x), rankdata(y) + spearman = np.corrcoef(rx, ry)[0, 1] if len(df) >= 2 else np.nan + + # ---------------------------- + # Trend line (linear fit) + # ---------------------------- + if len(df) >= 2: + a, b = np.polyfit(x, y, 1) + else: + a, b = np.nan, np.nan + + # ---------------------------- + # Density coloring (2D histogram bin counts) + # "how blue" = how many points are around that location + # ---------------------------- + # Choose binning resolution (balanced for ~thousands of points) + x_bins = 50 + y_bins = 50 + + # Compute bin index per point + x_edges = np.linspace(np.nanmin(x), np.nanmax(x), x_bins + 1) + y_edges = np.linspace(np.nanmin(y), np.nanmax(y), y_bins + 1) + + xi = np.clip(np.digitize(x, x_edges) - 1, 0, x_bins - 1) + yi = np.clip(np.digitize(y, y_edges) - 1, 0, y_bins - 1) + + # 2D counts + counts2d = np.zeros((x_bins, y_bins), dtype=int) + for i in range(len(x)): + counts2d[xi[i], yi[i]] += 1 + + # density per point = 
count of its bin + density = np.array([counts2d[xi[i], yi[i]] for i in range(len(x))], dtype=float) + + # Plot low density first, high density last (so dense points are visible) + order = np.argsort(density) + x_o, y_o, d_o = x[order], y[order], density[order] + + + + +# ... keep everything above the "Plot" section identical ... + + # ---------------------------- + # Plot (IMPROVED COLORS) + # ---------------------------- + fig, ax = plt.subplots(figsize=(12.5, 6)) + + # Option A (recommended): logarithmic color scaling + # Add +1 to avoid log(0) + d_plot = d_o + 1 + + # clip vmax so one extreme bin doesn't wash everything out + vmax = np.percentile(d_plot, 99) # try 95 or 99 depending on your data + norm = LogNorm(vmin=1, vmax=max(2, vmax)) + + sc = ax.scatter( + x_o, y_o, + c=d_plot, + cmap="Blues", + norm=norm, + s=point_size, + alpha=point_alpha, + linewidths=0 + ) + + # Trend line (black) + if np.isfinite(a) and np.isfinite(b): + xs = np.linspace(np.nanmin(x), np.nanmax(x), 200) + ax.plot(xs, a * xs + b, linestyle="--", linewidth=2.5, color=trend_color) + + ax.set_xlabel("certainty percent") + ax.set_ylabel("Absolute Error" if y_mode == "abs" else "Signed Error (EDSS_pred − EDSS_gt)") +# ax.set_title( +# f"Correlation: JSON1 certainty_percent vs {y_col} (All iterations)\n" +# f"Pearson r={pearson:.3f} | Spearman ρ={spearman:.3f}" +# ) + ax.grid(linestyle=":", alpha=0.5) + + # Colorbar + cbar = plt.colorbar(sc, ax=ax) + cbar.set_label("Local density (count of cases in bin, log-scaled)") + + # Legend + legend_items = [ + Line2D([0], [0], marker="o", linestyle="None", color="navy", + label=f"Data points (n={len(df)})"), + Line2D([0], [0], linestyle="--", color=trend_color, linewidth=2.5, + label=f"Linear trend (Pearson r={pearson:.3f})"), + ] + ax.legend(handles=legend_items, loc="upper right", frameon=True, title="Legend") + + plt.tight_layout() + # Save as SVG (optional) + if save_svg_path: + fig.savefig(save_svg_path, format="svg", bbox_inches="tight", 
dpi=dpi) + print(f"[SAVED] {save_svg_path}") + + plt.show() + +json1_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" + +correlation_scatter_raw_certainty_json1_reference( + json_dir_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration", + json1_file_path=json1_path, + ground_truth_path="/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv", + y_mode="abs", + # save_svg_path="/home/shahin/Lab/Doktorarbeit/Barcelona/results/corr_json1_abs_error.svg" +) + +## + + + diff --git a/certainty.py b/certainty.py new file mode 100644 index 0000000..cfc11be --- /dev/null +++ b/certainty.py @@ -0,0 +1,600 @@ + +# %% API call1 +#import time +#import json +#import os +#from datetime import datetime +#import pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +## Load environment variables +#load_dotenv() +# +## === CONFIGURATION === +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +#MODEL_NAME = "GPT-OSS-120B" +#HEALTH_URL = f"{OPENAI_BASE_URL}/health" # Placeholder - actual health check would need to be implemented +#CHAT_URL = f"{OPENAI_BASE_URL}/chat/completions" +# +## File paths +#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" +##GRAMMAR_FILE = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/just_edss_schema.gbnf" +# +## Initialize OpenAI client +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +## Read EDSS instructions from file +#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f: +# EDSS_INSTRUCTIONS = f.read().strip() +## === RUN INFERENCE 2 === +#def run_inference(patient_text): +# prompt = f''' +# Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded 
Disability Status Scale) aus klinischen Berichten zu extrahieren. +#### Regeln für die Ausgabe: +#1. **Reason**: Erstelle eine prägnante Zusammenfassung (max. 400 Zeichen) der Befunde auf **DEUTSCH**, die zur Einstufung führen. +#2. **klassifizierbar**: +# - Setze dies auf **true**, wenn ein EDSS-Wert identifiziert, berechnet oder basierend auf den klinischen Hinweisen plausibel geschätzt werden kann. +# - Setze dies auf **false**, NUR wenn die Daten absolut unzureichend oder so widersprüchlich sind, dass keinerlei Einstufung möglich ist. +#3. **EDSS**: +# - Dieses Feld ist **VERPFLICHTEND**, wenn "klassifizierbar" auf true steht. +# - Es muss eine Zahl zwischen 0.0 und 10.0 sein. +# - Versuche stets, den EDSS-Wert so präzise wie möglich zu bestimmen, auch wenn die Datenlage dünn ist (nutze verfügbare Informationen zu Gehstrecke und Funktionssystemen). +# - Dieses Feld **DARF NICHT ERSCHEINEN**, wenn "klassifizierbar" auf false steht. +# +#### Einschränkungen: +#- Erfinde keine Fakten, aber nutze klinische Herleitungen aus dem Bericht, um den EDSS zu bestimmen. +#- Priorisiere die Vergabe eines EDSS-Wertes gegenüber der Markierung als nicht klassifizierbar. +#- Halte dich strikt an die JSON-Struktur. +# +#EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +#''' +# start_time = time.time() +# +# try: +# # Make API call using OpenAI client +# response = client.chat.completions.create( +# messages=[ +# { +# "role": "system", +# "content": "You extract EDSS scores. You prioritize providing a score even if data is partial, by using clinical inference." 
+# }, +# { +# "role": "user", +# "content": prompt +# } +# ], +# model=MODEL_NAME, +# max_tokens=2048, +# temperature=0.0, +# response_format={"type": "json_object"} +# ) +# +# # Extract content from response +# content = response.choices[0].message.content +# +# # Parse the JSON response +# parsed = json.loads(content) +# +# inference_time = time.time() - start_time +# +# return { +# "success": True, +# "result": parsed, +# "inference_time_sec": inference_time +# } +# +# except Exception as e: +# print(f"Inference error: {e}") +# return { +# "success": False, +# "error": str(e), +# "inference_time_sec": -1 +# } +## === BUILD PATIENT TEXT === +#def build_patient_text(row): +# return ( +# str(row["T_Zusammenfassung"]) + "\n" + +# str(row["Diagnosen"]) + "\n" + +# str(row["T_KlinBef"]) + "\n" + +# str(row["T_Befunde"]) + "\n" +# ) +# +#if __name__ == "__main__": +# # Read CSV file ONLY inside main block +# df = pd.read_csv(INPUT_CSV, sep=';') +# results = [] +# +# # Process each row +# for idx, row in df.iterrows(): +# print(f"Processing row {idx + 1}/{len(df)}") +# try: +# patient_text = build_patient_text(row) +# result = run_inference(patient_text) +# +# # Add unique_id and MedDatum to result for tracking +# result["unique_id"] = row.get("unique_id", f"row_{idx}") +# result["MedDatum"] = row.get("MedDatum", None) +# +# results.append(result) +# print(json.dumps(result, indent=2)) +# except Exception as e: +# print(f"Error processing row {idx}: {e}") +# results.append({ +# "success": False, +# "error": str(e), +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None) +# }) +# +# # Save results to a JSON file +# output_json = INPUT_CSV.replace(".csv", "_results_Nisch.json") +# with open(output_json, 'w') as f: +# json.dump(results, f, indent=2) +# print(f"Results saved to {output_json}") +## + + + +# %% API call1 - Enhanced with certainty scoring +#import time +#import json +#import os +#from datetime import datetime +#import 
pandas as pd +#from openai import OpenAI +#from dotenv import load_dotenv +# +## Load environment variables +#load_dotenv() +# +## === CONFIGURATION === +#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +#OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +#MODEL_NAME = "GPT-OSS-120B" +# +## File paths +#INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Test.csv" +#EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" +# +## Initialize OpenAI client +#client = OpenAI( +# api_key=OPENAI_API_KEY, +# base_url=OPENAI_BASE_URL +#) +# +## Read EDSS instructions from file +#with open(EDSS_INSTRUCTIONS_PATH, 'r') as f: +# EDSS_INSTRUCTIONS = f.read().strip() +# +## === PROMPT WITH CERTAINTY REQUEST === +#def build_prompt(patient_text): +# return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren. +# +#### Deine Aufgabe: +#1. Analysiere den Patientenbericht und extrahiere: +# - Den Gesamt-EDSS-Score (0.0–10.0) +# - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl) +#2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein. +# +#### Struktur der JSON-Ausgabe (VERPFLICHTEND): +#Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter. +# +#{{ +# "reason": "Kernaussage zur EDSS-Begründung (max. 
400 Zeichen, auf Deutsch).", +# "klassifizierbar": true/false, +# "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true)", +# "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl)", +# "subcategories": {{ +# "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, +# "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0 +# }} +#}} +# +#### Regeln: +#- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest. +#- **klassifizierbar**: +# - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind. +# - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist. +#- **EDSS**: +# - **VERPFLICHTEND**, wenn `klassifizierbar=true`. +# - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`. +#- **certainty_percent**: +# - **Immer present** — Ganzzahl (0–100), basierend auf: +# - Klarheit und Vollständigkeit der Berichtsangaben, +# - Stichhaltigkeit der Schlussfolgerung (inkl. Inferenz), +# - Konsistenz zwischen den Unterkategorien. +#- **subcategories**: +# - **Immer present** — **alle 8 Unterkategorien** müssen enthalten sein. +# - Jeder Wert ist entweder: +# - `null` (wenn keine ausreichende Information vorliegt), **oder** +# - eine Zahl ≤ jeweiliger Obergrenze (z.B. Ambulation ≤ 10.0). 
+# - Wenn die Unterkategorie plausibel inferiert werden kann (auch indirekt), gib einen sinnvollen Wert ab. +# - Beispiel: Wenn „Gang mit Krückstock auf ebenem Boden bis 200 m“ steht, setze `AMBULATION: 5.5`. +# +#### EDSS-Bewertungsrichtlinien: +#{EDSS_INSTRUCTIONS} +# +#Patientenbericht: +#{patient_text} +#''' +# +## === INFERENCE FUNCTION === +#def run_inference(patient_text): +# prompt = build_prompt(patient_text) +# +# start_time = time.time() +# +# try: +# response = client.chat.completions.create( +# messages=[ +# {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."} +# ] + [ +# {"role": "user", "content": prompt} +# ], +# model=MODEL_NAME, +# max_tokens=2048, +# temperature=0.1, # Slightly higher for more natural certainty estimation (still low for reliability) +# response_format={"type": "json_object"} +# ) +# +# content = response.choices[0].message.content +# +# # Parse and validate JSON +# try: +# parsed = json.loads(content) +# except json.JSONDecodeError as e: +# print(f"⚠️ JSON parsing failed: {e}") +# print("Raw response:", content[:500]) +# raise ValueError("Model did not return valid JSON") +# +# # Enforce required keys +# if "certainty_percent" not in parsed: +# print("⚠️ Missing 'certainty_percent' in output! 
Force-adding fallback.") +# parsed["certainty_percent"] = 0 # fallback +# elif not isinstance(parsed["certainty_percent"], (int, float)): +# parsed["certainty_percent"] = int(parsed["certainty_percent"]) +# +# # Clamp certainty to [0, 100] +# pct = parsed["certainty_percent"] +# parsed["certainty_percent"] =max(0, min(100, int(pct))) +# +# # Enforce EDSS rules: if not classifiable → remove EDSS +# if not parsed.get("klassifizierbar", False): +# if "EDSS" in parsed: +# del parsed["EDSS"] # per spec, must not appear if not classifiable +# else: +# if "EDSS" not in parsed: +# print("⚠️ 'klassifizierbar' is true but EDSS missing — adding fallback.") +# parsed["EDSS"] = 7.0 # last-resort fallback +# +# inference_time = time.time() - start_time +# +# return { +# "success": True, +# "result": parsed, +# "inference_time_sec": inference_time +# } +# +# except Exception as e: +# print(f"❌ Inference error: {e}") +# return { +# "success": False, +# "error": str(e), +# "inference_time_sec": -1, +# "result": None # no structured output +# } +# +## === BUILD PATIENT TEXT === +#def build_patient_text(row): +# return ( +# str(row.get("T_Zusammenfassung", "")) + "\n" + +# str(row.get("Diagnosen", "")) + "\n" + +# str(row.get("T_KlinBef", "")) + "\n" + +# str(row.get("T_Befunde", "")) +# ) +# +#if __name__ == "__main__": +# # Load data +# df = pd.read_csv(INPUT_CSV, sep=';') +# results = [] +# +# # Optional: limit for testing +# # df = df.head(3) +# +# print(f"Processing {len(df)} rows...") +# for idx, row in df.iterrows(): +# print(f"\n— Row {idx + 1}/{len(df)} —") +# try: +# patient_text = build_patient_text(row) +# result = run_inference(patient_text) +# +# # Attach metadata +# result["unique_id"] = row.get("unique_id", f"row_{idx}") +# result["MedDatum"] = row.get("MedDatum", None) +# +# results.append(result) +# +# # Print summary +# if result["success"]: +# res = result["result"] +# edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A" +# print(f"✅ Result → 
EDSS={edss}, certainty={res.get('certainty_percent', 'N/A')}%") +# print(f" Reason: {res.get('reason', 'N/A')[:100]}…") +# else: +# print(f"❌ Failed: {result.get('error', 'Unknown error')[:100]}") +# +# except Exception as e: +# print(f"⚠️ Error processing row {idx}: {e}") +# results.append({ +# "success": False, +# "error": str(e), +# "unique_id": row.get("unique_id", f"row_{idx}"), +# "MedDatum": row.get("MedDatum", None), +# "result": None +# }) +# +# # Save results +# output_json = INPUT_CSV.replace(".csv", "_results_Nisch_certainty.json") +# with open(output_json, 'w', encoding='utf-8') as f: +# json.dump(results, f, indent=2, ensure_ascii=False) +# print(f"\n✅ Saved results to: {output_json}") +# +## + + +# %% API call - Multi-iteration EDSS + certainty extraction + +import time +import json +import os +from datetime import datetime +import pandas as pd +from openai import OpenAI +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +# === CONFIGURATION === +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL") +MODEL_NAME = "GPT-OSS-120B" + +# File paths +INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" +EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt" + +# Iteration settings +NUM_ITERATIONS = 20 +STOP_ON_FIRST_ERROR = False # Set to True for debugging + +# Initialize OpenAI client +client = OpenAI( + api_key=OPENAI_API_KEY, + base_url=OPENAI_BASE_URL +) + +# Read EDSS instructions from file +with open(EDSS_INSTRUCTIONS_PATH, 'r') as f: + EDSS_INSTRUCTIONS = f.read().strip() + +# === PROMPT (unchanged from before) === +def build_prompt(patient_text): + return f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale), alle Unterkategorien und die Bewertungssicherheit aus klinischen Berichten zu extrahieren. 
+ +### Deine Aufgabe: +1. Analysiere den Patientenbericht und extrahiere: + - Den Gesamt-EDSS-Score (0.0–10.0) + - Alle 8 EDSS-Unterkategorien (mit jeweils eigener Maximalpunktzahl) +2. Schätze für jede Entscheidung die Sicherheit als Ganzzahl von 0–100 % ein. + +### Struktur der JSON-Ausgabe (VERPFLICHTEND): +Gib NUR gültiges JSON zurück — kein Markdown, kein Text davor/dahinter. + +{{ + "reason": "Kernaussage zur EDSS-Begründung (max. 400 Zeichen, auf Deutsch).", + "klassifizierbar": true/false, + "EDSS": null ODER Zahl zwischen 0.0 und 10.0 (nur wenn klassifizierbar=true)", + "certainty_percent": 0 ODER Zahl zwischen 0 und 100 (Ganzzahl)", + "subcategories": {{ + "VISUAL_OPTIC_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "BRAINSTEM_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "PYRAMIDAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "CEREBELLAR_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "SENSORY_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "BOWEL_AND_BLADDER_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "CEREBRAL_FUNCTIONS": null ODER Zahl zwischen 0.0 und 6.0, + "AMBULATION": null ODER Zahl zwischen 0.0 und 10.0 + }} +}} + +### Regeln: +- **reason**: Kurze, prägnante Begründung (auf Deutsch, max. 400 Zeichen), warum du den EDSS-Wert und die Unterkategorien so bewertest. +- **klassifizierbar**: + - `true`, wenn EDSS und mindestens die wichtigsten Unterkategorien *eindeutig ableitbar* oder *plausibel inferierbar* sind. + - `false`, **nur**, wenn keine relevanten Daten vorliegen, oder diese so widersprüchlich/inkonsistent sind, dass keine vernünftige Einschätzung möglich ist. +- **EDSS**: + - **VERPFLICHTEND**, wenn `klassifizierbar=true`. + - Zahl zwischen 0.0 und 10.0 (z.B. 3.0, 5.5). Darf **nicht** erscheinen, wenn `klassifizierbar=false`. +- **certainty_percent**: + - **Immer present** — Ganzzahl (0–100), basierend auf: + - Klarheit und Vollständigkeit der Berichtsangaben, + - Stichhaltigkeit der Schlussfolgerung (inkl. 
def _coerce_certainty(raw):
    """Return *raw* as an int clamped to 0..100, or 0 when it is not a usable number."""
    if isinstance(raw, str):
        # Models sometimes answer "85%" or "85,0" despite the integer instruction.
        raw = raw.strip().rstrip("%").strip().replace(",", ".")
    try:
        value = float(raw)
    except (TypeError, ValueError):
        print(f"⚠️ Unusable 'certainty_percent' value {raw!r} — using 0.")
        return 0
    if value != value:  # NaN is the only float that differs from itself
        print("⚠️ 'certainty_percent' is NaN — using 0.")
        return 0
    # Clamp before int() so that +/-inf cannot raise OverflowError.
    return int(max(0.0, min(100.0, value)))


def run_inference(patient_text):
    """Query the model for one report and normalise the JSON it returns.

    The prompt comes from ``build_prompt(patient_text)`` and is sent through
    the module-level ``client`` using ``MODEL_NAME``.

    Post-processing applied to the parsed reply:
      * ``certainty_percent`` is always present afterwards, as an int in
        0..100 (0 when missing or unreadable).
      * ``EDSS`` is removed when ``klassifizierbar`` is falsy.
      * When ``klassifizierbar`` is truthy but ``EDSS`` is missing or null,
        the reply is downgraded to ``klassifizierbar = False``. No score is
        invented.

    Args:
        patient_text: Report text to embed into the prompt.

    Returns:
        dict: ``{"success": True, "result": <dict>, "inference_time_sec": <float>}``
        on success, otherwise ``{"success": False, "error": <str>,
        "inference_time_sec": -1, "result": None}``. Exceptions are caught
        and reported through the failure dict instead of being raised.
    """
    prompt = build_prompt(patient_text)
    start_time = time.time()

    try:
        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "Du gibst EXKLUSIV gültiges JSON zurück — keine weiteren Erklärungen."},
                {"role": "user", "content": prompt},
            ],
            model=MODEL_NAME,
            max_tokens=2048,
            temperature=0.1,
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content

        # Parse and validate JSON
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError as e:
            print(f"⚠️ JSON parsing failed: {e}")
            print("Raw response:", content[:500])
            raise ValueError("Model did not return valid JSON") from e
        if not isinstance(parsed, dict):
            # A bare list/number is valid JSON but has none of the expected keys.
            raise ValueError("Model did not return a JSON object")

        # Enforce required keys
        if "certainty_percent" not in parsed:
            print("⚠️ Missing 'certainty_percent' in output! Force-adding fallback.")
            parsed["certainty_percent"] = 0
        else:
            parsed["certainty_percent"] = _coerce_certainty(parsed["certainty_percent"])

        # Enforce EDSS rules
        if not parsed.get("klassifizierbar", False):
            parsed.pop("EDSS", None)
        elif parsed.get("EDSS") is None:
            # A made-up score would be indistinguishable from a real
            # prediction in every later accuracy evaluation.
            print("⚠️ 'klassifizierbar' is true but EDSS missing — marking as not classifiable.")
            parsed["klassifizierbar"] = False
            parsed.pop("EDSS", None)

        return {
            "success": True,
            "result": parsed,
            "inference_time_sec": time.time() - start_time,
        }

    except Exception as e:  # boundary: one bad reply must not stop the batch
        print(f"❌ Inference error: {e}")
        return {
            "success": False,
            "error": str(e),
            "inference_time_sec": -1,
            "result": None,
        }
def build_patient_text(row):
    """Join the four free-text report columns of one record into one text.

    Args:
        row: Record offering ``.get(key, default)``, e.g. a ``pandas.Series``
            from ``DataFrame.iterrows()`` or a plain ``dict``.

    Returns:
        str: The values of ``T_Zusammenfassung``, ``Diagnosen``, ``T_KlinBef``
        and ``T_Befunde``, in that order, separated by newlines. A column
        that is absent, ``None`` or NaN contributes an empty string, so the
        literal words "nan"/"None" never reach the prompt.
    """
    columns = ("T_Zusammenfassung", "Diagnosen", "T_KlinBef", "T_Befunde")
    parts = []
    for column in columns:
        value = row.get(column, "")
        # Empty CSV cells arrive as NaN; only test scalars so that an
        # unexpected list-like value does not make pd.isna return an array.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            value = ""
        parts.append(str(value))
    return "\n".join(parts)
# === MAIN LOOP (multi-iteration) ===
if __name__ == "__main__":
    # Load the input once; every iteration queries the same records again.
    df = pd.read_csv(INPUT_CSV, sep=';')
    total_rows = len(df)
    print(f"Loaded {total_rows} patient records.")

    aborted = False
    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"\n{'='*60}")
        print(f"🔄 ITERATION {iteration}/{NUM_ITERATIONS}")
        print(f"{'='*60}")

        iteration_results = []
        start_iter = time.time()

        # `position` counts rows for the progress line; `idx` is the index
        # label and only equals the position for a default RangeIndex.
        for position, (idx, row) in enumerate(df.iterrows(), start=1):
            print(f"\rRow {position}/{total_rows} | Iter {iteration}", end='', flush=True)
            try:
                patient_text = build_patient_text(row)
                result = run_inference(patient_text)

                # Attach metadata: inside "result" for successes (the analysis
                # scripts read it from there), at top level for failures.
                if result["success"]:
                    res = result["result"].copy()  # avoid mutating the parsed reply
                    res["iteration"] = iteration
                    res["unique_id"] = row.get("unique_id", f"row_{idx}")
                    res["MedDatum"] = row.get("MedDatum", None)
                    result["result"] = res
                else:
                    result["iteration"] = iteration
                    result["unique_id"] = row.get("unique_id", f"row_{idx}")
                    result["MedDatum"] = row.get("MedDatum", None)

                if result["success"]:
                    res = result["result"]
                    edss = res.get("EDSS", "N/A") if res.get("klassifizierbar") else "N/A"
                    print(f" ✅ EDSS={edss}, cert={res.get('certainty_percent', '?')}%")
                else:
                    print(f" ❌ {result.get('error', 'Unknown')}")

            except Exception as e:  # boundary: keep the batch running
                print(f"\n⚠️ Row {idx} failed: {e}")
                result = {
                    "success": False,
                    "error": str(e),
                    "iteration": iteration,
                    "unique_id": row.get("unique_id", f"row_{idx}"),
                    "MedDatum": row.get("MedDatum", None),
                    "result": None
                }

            iteration_results.append(result)

            # run_inference reports failures through "success" instead of
            # raising, so the debug switch has to look at the result.
            if STOP_ON_FIRST_ERROR and not result["success"]:
                aborted = True
                break

        # Save per-iteration results (also the partial ones of an aborted run)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = INPUT_CSV.replace(".csv", f"_results_iter_{iteration}_{timestamp}.json")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(iteration_results, f, indent=2, ensure_ascii=False)
        print(f"\n✅ Iteration {iteration} complete. Saved to: {output_path}")

        elapsed = time.time() - start_iter
        per_row = elapsed / total_rows if total_rows else 0.0  # empty CSV
        print(f"⏱️ Iteration {iteration} took {elapsed:.1f}s ({per_row:.1f}s/row)")

        if aborted:
            break

    if aborted:
        print(f"\n⚠️ Stopped after the first error (STOP_ON_FIRST_ERROR) in iteration {iteration}.")
    else:
        print(f"\n🎉 All {NUM_ITERATIONS} iterations completed!")
def plot_edss_distribution_per_iteration(json_dir_path):
    """Plot, per result file, how many predicted EDSS values fall in each band.

    Every ``*.json`` file in *json_dir_path* is treated as one iteration.
    Files are ordered with a natural sort of their paths and labelled
    "Iter 1", "Iter 2", … by that position. Only entries with a truthy
    ``"success"`` are counted. The stacked bar chart shows the bands
    '0-1' … '9-10'; 'Unknown' (no usable EDSS) and '10+' are not drawn.

    Args:
        json_dir_path: Directory containing the per-iteration result files.

    Returns:
        pandas.DataFrame: Counts with one row per iteration label and one
        column per band that occurs in the data. When nothing can be plotted
        the frame has no columns and no figure is shown.
    """
    band_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']

    def as_float(value):
        # EDSS may arrive as a number, a numeric string or null.
        if isinstance(value, str):
            value = value.strip().replace(",", ".")
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("nan")

    def categorize_edss(value):
        if pd.isna(value):
            return 'Unknown'
        # Band k covers (k-1, k]; everything up to 1.0 lands in '0-1'.
        for upper, label in enumerate(band_labels, start=1):
            if value <= upper:
                return label
        return '10+'

    def natural_key(string_):
        # "iter_2" sorts before "iter_10".
        return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]

    json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json")), key=natural_key)
    iter_order = [f"Iter {i + 1}" for i in range(len(json_files))]

    all_records = []
    for iter_label, file_path in zip(iter_order, json_files):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error reading {file_path}: {e}")
            continue
        if not isinstance(data, list):
            print(f"Error reading {file_path}: expected a JSON list of results")
            continue
        for entry in data:
            if not isinstance(entry, dict) or not entry.get("success"):
                continue
            result = entry.get("result")
            edss = as_float(result.get("EDSS")) if isinstance(result, dict) else float("nan")
            all_records.append({
                'Iteration': iter_label,
                'Category': categorize_edss(edss),
            })

    if not all_records:
        print(f"No successful results found in {json_dir_path}; nothing to plot.")
        return pd.DataFrame(index=iter_order)

    df = pd.DataFrame(all_records)

    # Frequency table: iterations as rows (natural order), bands as columns
    # (clinical order). Iterations without any counted entry become 0 rows.
    dist_table = pd.crosstab(df['Iteration'], df['Category']).reindex(iter_order)
    available_labels = [label for label in band_labels if label in dist_table.columns]
    dist_table = dist_table[available_labels].fillna(0).astype(int)

    if not available_labels:
        print("No EDSS value falls into the 0-10 bands; nothing to plot.")
        return dist_table

    dist_table.plot(kind='bar', stacked=True, figsize=(14, 8), colormap='viridis', edgecolor='white')

    plt.title('Distribution of Predicted EDSS Categories per Iteration', fontsize=15, pad=20)
    plt.xlabel('JSON Iteration File', fontsize=12)
    plt.ylabel('Number of Cases (Count)', fontsize=12)
    plt.xticks(rotation=0)

    # Legend outside the axes, on the right.
    plt.legend(title="EDSS Category", bbox_to_anchor=(1.05, 1), loc='upper left')

    # Total count above each stacked bar.
    for i, (_, row) in enumerate(dist_table.iterrows()):
        total = row.sum()
        if total > 0:
            plt.text(i, total + 2, f'Total: {int(total)}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()

    return dist_table
def generate_edss_distribution_csv(json_dir_path, output_filename='edss_distribution_summary.csv'):
    """Count predicted EDSS values per band and iteration and write them as CSV.

    Every ``*.json`` file in *json_dir_path* is treated as one iteration.
    Files are ordered with a natural sort of their paths and labelled
    "Iter 1", "Iter 2", … by that position. Only entries with a truthy
    ``"success"`` are counted. Columns are the bands '0-1' … '9-10' that
    occur in the data; 'Unknown' (no usable EDSS) and '10+' are left out.

    Args:
        json_dir_path: Directory containing the per-iteration result files.
        output_filename: Path of the CSV file to write.

    Returns:
        pandas.DataFrame: Integer counts, one row per iteration plus a final
        'Total Sum' row. When no successful entry exists, an empty frame is
        returned and no file is written.
    """
    band_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']

    def as_float(value):
        # EDSS may arrive as a number, a numeric string or null.
        if isinstance(value, str):
            value = value.strip().replace(",", ".")
        try:
            return float(value)
        except (TypeError, ValueError):
            return float("nan")

    def categorize_edss(value):
        if pd.isna(value):
            return 'Unknown'
        # Band k covers (k-1, k]; everything up to 1.0 lands in '0-1'.
        for upper, label in enumerate(band_labels, start=1):
            if value <= upper:
                return label
        return '10+'

    def natural_key(string_):
        # "iter_2" sorts before "iter_10".
        return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]

    json_files = sorted(glob.glob(os.path.join(json_dir_path, "*.json")), key=natural_key)
    iter_order = [f"Iter {i + 1}" for i in range(len(json_files))]

    all_records = []
    for iter_label, file_path in zip(iter_order, json_files):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error reading {file_path}: {e}")
            continue
        if not isinstance(data, list):
            print(f"Error reading {file_path}: expected a JSON list of results")
            continue
        for entry in data:
            if not isinstance(entry, dict) or not entry.get("success"):
                continue
            result = entry.get("result")
            edss = as_float(result.get("EDSS")) if isinstance(result, dict) else float("nan")
            all_records.append({
                'Iteration': iter_label,
                'Category': categorize_edss(edss),
            })

    if not all_records:
        print(f"No successful results found in {json_dir_path}; nothing written.")
        return pd.DataFrame(index=iter_order)

    df = pd.DataFrame(all_records)

    # Frequency table: iterations as rows (natural order), bands as columns
    # (clinical order). Missing combinations count as 0.
    dist_table = pd.crosstab(df['Iteration'], df['Category']).reindex(iter_order)
    available_labels = [label for label in band_labels if label in dist_table.columns]
    dist_table = dist_table[available_labels].fillna(0).astype(int)

    # Column-wise totals across all iterations.
    if available_labels:
        dist_table.loc['Total Sum'] = dist_table.sum()

    dist_table.to_csv(output_filename)
    print(f"Table successfully saved to: {output_filename}")

    return dist_table
def categorize_edss(value):
    """Map a numeric EDSS value to its one-point band label.

    Band ``k`` covers the half-open interval ``(k-1, k]``; every value up to
    and including 1.0 is reported as '0-1' and everything above 10.0 as
    '10+'. Missing values (anything ``pd.isna`` accepts) yield ``np.nan``.
    """
    if pd.isna(value):
        return np.nan

    band_labels = ('0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10')
    # The upper bound of the k-th band is simply k.
    for upper_bound, label in enumerate(band_labels, start=1):
        if value <= float(upper_bound):
            return label
    return '10+'
def plot_categorized_edss(json_dir_path, ground_truth_path):
    """Plot a confusion matrix of banded EDSS: ground truth vs. LLM prediction.

    Ground truth is read from a ``;``-separated CSV with the columns
    ``unique_id``, ``MedDatum`` and ``EDSS``. Predictions are collected from
    all ``*.json`` files in *json_dir_path* (entries with a truthy
    ``"success"``). Both sides are joined on the stripped
    ``unique_id``/``MedDatum`` pair, rows lacking either EDSS value are
    dropped, and the remaining values are banded with ``categorize_edss``.

    Args:
        json_dir_path: Directory containing the result files.
        ground_truth_path: Path of the ground-truth CSV.

    Returns:
        None. Shows the figure; prints a note and returns early when no row
        can be compared.
    """
    key_cols = ['unique_id', 'MedDatum']
    fixed_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']

    # 1. Load ground truth
    df_gt = pd.read_csv(ground_truth_path, sep=';')
    for col in key_cols:
        df_gt[col] = df_gt[col].astype(str).str.strip()
    # Accept German decimal commas ("3,5"); a plain to_numeric would turn
    # them into NaN and the rows would silently vanish from the matrix.
    df_gt['EDSS'] = pd.to_numeric(
        df_gt['EDSS'].astype(str).str.replace(",", ".", regex=False).str.strip(),
        errors='coerce',
    )

    # 2. Collect predictions from all JSON files
    all_preds = []
    for file_path in glob.glob(os.path.join(json_dir_path, "*.json")):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error reading {file_path}: {e}")
            continue
        for entry in data:
            if not isinstance(entry, dict) or not entry.get("success"):
                continue
            res = entry.get("result")
            if not isinstance(res, dict):
                continue  # one malformed entry must not discard the file
            all_preds.append({
                'unique_id': str(res.get('unique_id')).strip(),
                'MedDatum': str(res.get('MedDatum')).strip(),
                'edss_pred': res.get('EDSS'),
            })

    # Explicit columns keep the frame usable when no prediction was found.
    df_pred = pd.DataFrame(all_preds, columns=key_cols + ['edss_pred'])
    df_pred['edss_pred'] = pd.to_numeric(df_pred['edss_pred'], errors='coerce')

    # 3. Merge and categorize
    df_merged = pd.merge(
        df_gt[key_cols + ['EDSS']],
        df_pred,
        on=key_cols,
        how='inner',
    ).dropna(subset=['EDSS', 'edss_pred'])

    print(f"Verification: Total matches in Confusion Matrix: {len(df_merged)}")
    if df_merged.empty:
        print("No comparable rows; confusion matrix not drawn.")
        return

    y_true = df_merged['EDSS'].apply(categorize_edss)
    y_pred = df_merged['edss_pred'].apply(categorize_edss)

    # 4./5. Fixed label list so that bands without data still get a row/column
    cm = confusion_matrix(y_true, y_pred, labels=fixed_labels)

    # 6. Plot (y axis: ground truth, x axis: LLM prediction)
    fig, ax = plt.subplots(figsize=(10, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=fixed_labels)
    disp.plot(cmap=plt.cm.Blues, values_format='d', ax=ax)

    plt.title('Categorized EDSS: Ground Truth vs LLM Prediction')
    plt.ylabel('Ground Truth EDSS')
    plt.xlabel('LLM Prediction')
    plt.show()
def plot_categorized_edss(json_dir_path, ground_truth_path):
    """Print and plot the banded EDSS confusion matrix with merge diagnostics.

    Ground truth is read from a ``;``-separated CSV with the columns
    ``unique_id``, ``MedDatum`` and ``EDSS``. Predictions are collected from
    all ``*.json`` files in *json_dir_path* (entries with a truthy
    ``"success"``). Keys are stripped and lower-cased on both sides before
    the inner join. The function prints how many merged rows lack a
    ground-truth or predicted EDSS, drops those rows, prints the raw matrix
    and shows the figure.

    Args:
        json_dir_path: Directory containing the result files.
        ground_truth_path: Path of the ground-truth CSV.

    Returns:
        None. Returns early, without a figure, when no row can be compared.
    """
    key_cols = ['unique_id', 'MedDatum']
    fixed_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']

    # 1. Load ground truth with normalised keys
    df_gt = pd.read_csv(ground_truth_path, sep=';')
    for col in key_cols:
        df_gt[col] = df_gt[col].astype(str).str.strip().str.lower()
    # Accept German decimal commas ("3,5"); a plain to_numeric would turn
    # them into NaN and inflate the "missing ground truth" count below.
    df_gt['EDSS'] = pd.to_numeric(
        df_gt['EDSS'].astype(str).str.replace(",", ".", regex=False).str.strip(),
        errors='coerce',
    )

    # 2. Load all predictions from the JSON files
    all_preds = []
    for file_path in glob.glob(os.path.join(json_dir_path, "*.json")):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error reading {file_path}: {e}")
            continue
        for entry in data:
            # Only 'success': true entries carry a usable result.
            if not isinstance(entry, dict) or not entry.get("success"):
                continue
            res = entry.get("result")
            if not isinstance(res, dict):
                continue  # one malformed entry must not discard the file
            all_preds.append({
                'unique_id': str(res.get('unique_id')).strip().lower(),
                'MedDatum': str(res.get('MedDatum')).strip().lower(),
                'edss_pred': res.get('EDSS'),
            })

    # Explicit columns keep the frame usable when no prediction was found.
    df_pred = pd.DataFrame(all_preds, columns=key_cols + ['edss_pred'])
    df_pred['edss_pred'] = pd.to_numeric(df_pred['edss_pred'], errors='coerce')

    # 3. Merge: one ground-truth row matches one prediction per iteration
    df_merged = pd.merge(
        df_gt[key_cols + ['EDSS']],
        df_pred,
        on=key_cols,
        how='inner',
    )

    nan_in_gt = df_merged['EDSS'].isna().sum()
    nan_in_pred = df_merged['edss_pred'].isna().sum()

    print("-" * 40)
    print(f"TOTAL MERGED ROWS: {len(df_merged)}")
    print(f"Rows with missing Ground Truth EDSS: {nan_in_gt}")
    print(f"Rows with missing Prediction EDSS: {nan_in_pred}")
    print("-" * 40)

    # Only rows with a value on both sides can enter the matrix.
    df_final = df_merged.dropna(subset=['EDSS', 'edss_pred']).copy()
    print(f"FINAL ROWS FOR CONFUSION MATRIX: {len(df_final)}")
    print("-" * 40)
    if df_final.empty:
        print("No comparable rows; confusion matrix not drawn.")
        return

    # 4. Categorize for the matrix
    y_true = df_final['EDSS'].apply(categorize_edss)
    y_pred = df_final['edss_pred'].apply(categorize_edss)

    # 5. Generate and print the raw matrix
    cm = confusion_matrix(y_true, y_pred, labels=fixed_labels)
    cm_df = pd.DataFrame(cm, index=[f"True_{l}" for l in fixed_labels],
                         columns=[f"Pred_{l}" for l in fixed_labels])
    print("\nRAW CONFUSION MATRIX (Rows=True, Cols=Pred):")
    print(cm_df)

    # 6. Plotting
    fig, ax = plt.subplots(figsize=(12, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=fixed_labels)
    # values_format='d' shows whole numbers instead of scientific notation.
    disp.plot(cmap=plt.cm.Blues, values_format='d', ax=ax)

    # n counts predictions (case x iteration); the number of distinct cases
    # is taken from the data instead of being hard-coded.
    n_cases = len(df_final[key_cols].drop_duplicates())
    plt.title(f'EDSS Confusion Matrix\n(n={len(df_final)} predictions across {n_cases} cases)', fontsize=14)
    plt.ylabel('Ground Truth (Clinician)')
    plt.xlabel('LLM Prediction')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
def plot_subcategory_analysis(json_dir_path, ground_truth_path):
    """Plot accuracy and mean absolute deviation per EDSS subcategory.

    Ground truth is read from a ``;``-separated CSV; predictions come from
    the ``"subcategories"`` object of every successful entry in the
    ``*.json`` files of *json_dir_path*. Both sides are joined on
    ``unique_id``/``MedDatum``. For each subcategory, rows with a numeric
    value on both sides are compared: accuracy is the share of exact
    matches in percent, deviation the mean absolute difference.

    Args:
        json_dir_path: Directory containing the result files.
        ground_truth_path: Path of the ground-truth CSV.

    Returns:
        None. Shows a bar (accuracy) plus line (deviation) chart; prints a
        note and returns early when no subcategory can be evaluated.
    """
    # 1. Column mapping (JSON key : CSV column). The CSV names must match the
    #    file header exactly, including the spelling "Sensibiliät".
    mapping = {
        "VISUAL_OPTIC_FUNCTIONS": "Sehvermögen",
        "BRAINSTEM_FUNCTIONS": "Hirnstamm",
        "PYRAMIDAL_FUNCTIONS": "Pyramidalmotorik",
        "CEREBELLAR_FUNCTIONS": "Cerebellum",
        "SENSORY_FUNCTIONS": "Sensibiliät",
        "BOWEL_AND_BLADDER_FUNCTIONS": "Blasen-_und_Mastdarmfunktion",
        "CEREBRAL_FUNCTIONS": "Cerebrale_Funktion",
        "AMBULATION": "Ambulation"
    }

    def to_number(series):
        # Accept German decimal commas ("3,5"); anything else becomes NaN.
        cleaned = series.astype(str).str.replace(",", ".", regex=False).str.strip()
        return pd.to_numeric(cleaned, errors='coerce')

    # 2. Load ground truth
    df_gt = pd.read_csv(ground_truth_path, sep=';')
    df_gt['unique_id'] = df_gt['unique_id'].astype(str)
    df_gt['MedDatum'] = df_gt['MedDatum'].astype(str)

    # 3. Load predictions including subcategories
    all_preds = []
    for file_path in glob.glob(os.path.join(json_dir_path, "*.json")):
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for entry in data:
            if not entry.get("success"):
                continue
            res = entry.get("result")
            if not isinstance(res, dict):
                continue
            # The model may answer "subcategories": null.
            subcategories = res.get('subcategories')
            if not isinstance(subcategories, dict):
                subcategories = {}
            row = {
                'unique_id': str(res.get('unique_id')),
                'MedDatum': str(res.get('MedDatum'))
            }
            for json_key in mapping:
                row[json_key] = subcategories.get(json_key)
            all_preds.append(row)

    # Explicit columns keep the merge valid when no prediction was found.
    df_pred = pd.DataFrame(all_preds, columns=['unique_id', 'MedDatum', *mapping])

    # 4. Merge
    df_merged = pd.merge(df_gt, df_pred, on=['unique_id', 'MedDatum'], suffixes=('_gt', '_llm'))

    # 5. Calculate metrics
    results = []
    for json_key, csv_col in mapping.items():
        if csv_col not in df_merged.columns:
            print(f"Column '{csv_col}' not found in ground truth; skipping {json_key}.")
            continue
        true_vals = to_number(df_merged[csv_col])
        pred_vals = to_number(df_merged[json_key])

        # Compare only rows where this subcategory is known on both sides.
        mask = true_vals.notna() & pred_vals.notna()
        y_t = true_vals[mask]
        y_p = pred_vals[mask]

        if len(y_t) > 0:
            results.append({
                'Subcategory': csv_col,
                'Accuracy': (y_t == y_p).mean() * 100,
                'Deviation': np.abs(y_t - y_p).mean(),  # mean absolute error
            })

    if not results:
        print("No subcategory has comparable values; nothing to plot.")
        return

    stats_df = pd.DataFrame(results).sort_values('Accuracy', ascending=False)

    # 6. Plotting
    fig, ax1 = plt.subplots(figsize=(14, 7))

    # Bar chart for accuracy
    bars = ax1.bar(stats_df['Subcategory'], stats_df['Accuracy'],
                   color='#3498db', alpha=0.8, label='Accuracy (%)')
    ax1.set_ylabel('Accuracy (%)', color='#2980b9', fontsize=12, fontweight='bold')
    ax1.set_ylim(0, 115)  # extra head room for the bar labels
    ax1.grid(axis='y', linestyle='--', alpha=0.7)

    plt.xticks(rotation=30, ha='right', fontsize=10)

    # Line chart for deviation on a second y axis
    ax2 = ax1.twinx()
    ax2.plot(stats_df['Subcategory'], stats_df['Deviation'],
             color='#e74c3c', marker='o', linewidth=2.5, markersize=8,
             label='Mean Abs. Deviation (Score Points)')
    ax2.set_ylabel('Mean Absolute Deviation', color='#c0392b', fontsize=12, fontweight='bold')

    # Leave head room above the line; fall back to 1 when every deviation
    # is 0 so the axis does not collapse to an empty range.
    max_deviation = stats_df['Deviation'].max()
    ax2.set_ylim(0, max_deviation * 1.5 if max_deviation > 0 else 1)

    # One combined legend for both axes, placed above the plot.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2,
               loc='upper center', bbox_to_anchor=(0.5, 1.12),
               ncol=2, frameon=False, fontsize=11)

    # Percentage label on top of each bar
    for bar in bars:
        height = bar.get_height()
        ax1.annotate(f'{height:.1f}%',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 5), textcoords="offset points",
                     ha='center', va='bottom', fontweight='bold', color='#2c3e50')

    plt.tight_layout()
    plt.show()
Deviation (Score Points)') + ax2.set_ylabel('Mean Absolute Deviation', color='#c0392b', fontsize=12, fontweight='bold') + + # Adjust ax2 limit to avoid overlap with accuracy text + ax2.set_ylim(0, max(stats_df['Deviation']) * 1.5 if not stats_df['Deviation'].empty else 1) + +# plt.title('Subcategory Performance: Accuracy vs. Mean Deviation', fontsize=14, pad=20) + + # --- THE FIX: Better Legend Placement --- + # Combine legends from both axes and place them above the plot + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, + loc='upper center', bbox_to_anchor=(0.5, 1.12), + ncol=2, frameon=False, fontsize=11) + + # Add percentage labels on top of bars + for bar in bars: + height = bar.get_height() + ax1.annotate(f'{height:.1f}%', + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 5), textcoords="offset points", + ha='center', va='bottom', fontweight='bold', color='#2c3e50') + + plt.tight_layout() + plt.show() +## + +# %% Certainty +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt + +def categorize_edss(value): + if pd.isna(value): return np.nan + elif value <= 1.0: return '0-1' + elif value <= 2.0: return '1-2' + elif value <= 3.0: return '2-3' + elif value <= 4.0: return '3-4' + elif value <= 5.0: return '4-5' + elif value <= 6.0: return '5-6' + elif value <= 7.0: return '6-7' + elif value <= 8.0: return '7-8' + elif value <= 9.0: return '8-9' + elif value <= 10.0: return '9-10' + else: return '10+' + +def plot_certainty_vs_accuracy_by_category(json_dir_path, ground_truth_path): + # 1. 
Data Loading & Merging + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str) + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str) + df_gt['EDSS_true'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')), + 'MedDatum': str(res.get('MedDatum')), + 'EDSS_pred': res.get('EDSS'), + 'certainty': res.get('certainty_percent') + }) + + df_pred = pd.DataFrame(all_preds) + df_pred['EDSS_pred'] = pd.to_numeric(df_pred['EDSS_pred'], errors='coerce') + + df = pd.merge(df_gt[['unique_id', 'MedDatum', 'EDSS_true']], + df_pred, on=['unique_id', 'MedDatum']).dropna() + + # 2. Process Metrics + df['gt_category'] = df['EDSS_true'].apply(categorize_edss) + df['is_correct'] = (df['EDSS_true'].round(1) == df['EDSS_pred'].round(1)) + + fixed_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'] + + # Calculate Mean Certainty and Mean Accuracy per category + stats = df.groupby('gt_category').agg({ + 'is_correct': 'mean', + 'certainty': 'mean', + 'unique_id': 'count' + }).reindex(fixed_labels) + + stats['accuracy_percent'] = stats['is_correct'] * 100 + stats = stats.fillna(0) + + # 3. Plotting + x = np.arange(len(fixed_labels)) + width = 0.35 # Width of the bars + + fig, ax = plt.subplots(figsize=(14, 8)) + + # Plotting both bars side-by-side + rects1 = ax.bar(x - width/2, stats['accuracy_percent'], width, + label='Actual Accuracy (%)', color='#2ecc71', alpha=0.8) + rects2 = ax.bar(x + width/2, stats['certainty'], width, + label='LLM Avg. Certainty (%)', color='#e67e22', alpha=0.8) + + # Add text labels, titles and custom x-axis tick labels, etc. 
+ ax.set_ylabel('Percentage (%)', fontsize=12) + ax.set_xlabel('Ground Truth EDSS Category', fontsize=12) +# ax.set_title('Comparison: LLM Confidence (Certainty) vs. Real Accuracy per EDSS Range', fontsize=15, pad=25) + ax.set_xticks(x) + ax.set_xticklabels(fixed_labels) + ax.set_ylim(0, 115) + ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.08), ncol=2, frameon=False) + ax.grid(axis='y', linestyle=':', alpha=0.5) + + # Helper function to label bar heights + def autolabel(rects): + for rect in rects: + height = rect.get_height() + if height > 0: + ax.annotate(f'{height:.0f}%', + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), textcoords="offset points", + ha='center', va='bottom', fontsize=9, fontweight='bold') + + autolabel(rects1) + autolabel(rects2) + + # Add sample size (n) at the bottom + for i, count in enumerate(stats['unique_id']): + ax.text(i, 2, f'n={int(count)}', ha='center', va='bottom', fontsize=10, color='white', fontweight='bold') + + plt.tight_layout() + plt.show() + +## + + + +# %% Boxplot +import pandas as pd +import numpy as np +import json +import glob +import os +import re +import matplotlib.pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Patch + +def natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)] + +def plot_edss_boxplot(json_dir_path, ground_truth_path): + # 1. Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + gt_values = pd.to_numeric(df_gt['EDSS'], errors='coerce').dropna().tolist() + + # 2. 
Load Iterations + json_files = glob.glob(os.path.join(json_dir_path, "*.json")) + json_files.sort(key=natural_key) + + plot_data = [gt_values] + labels = ['Ground Truth'] + + for i, file_path in enumerate(json_files): + iteration_values = [] + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + val = entry["result"].get("EDSS") + if val is not None: + iteration_values.append(float(val)) + plot_data.append(iteration_values) + labels.append(f"Iter {i+1}") + + # 3. Plotting Configuration + plt.figure(figsize=(14, 8)) + + # Define colors + gt_color = '#ff9999' # Soft Red + iter_color = '#66b3ff' # Soft Blue + + # Create the boxplot + bplot = plt.boxplot(plot_data, labels=labels, patch_artist=True, + notch=False, + medianprops={'color': 'black', 'linewidth': 2}, + flierprops={'marker': 'o', 'markerfacecolor': 'gray', 'markersize': 5, 'alpha': 0.5}, + showmeans=True, + meanprops={"marker":"D", "markerfacecolor":"white", "markeredgecolor":"black", "markersize": 6}) + + # 4. Fill boxes with colors + colors = [gt_color] + [iter_color] * (len(plot_data) - 1) + for patch, color in zip(bplot['boxes'], colors): + patch.set_facecolor(color) + + # 5. CONSTRUCT THE COMPLETE LEGEND + legend_elements = [ + Patch(facecolor=gt_color, edgecolor='black', label='Ground Truth'), + Patch(facecolor=iter_color, edgecolor='black', label='LLM Iterations (1-10)'), + Line2D([0], [0], color='black', lw=2, label='Median'), + Line2D([0], [0], marker='D', color='w', label='Mean Score', + markerfacecolor='white', markeredgecolor='black', markersize=8), + Line2D([0], [0], marker='o', color='w', label='Outliers', + markerfacecolor='gray', markersize=6, alpha=0.5) + ] + + plt.legend(handles=legend_elements, loc='upper right', frameon=True, shadow=True, title="Legend") + + # Formatting + plt.title('Distribution of EDSS Scores: Ground Truth vs. 
10 LLM Iterations', fontsize=16, pad=20) + plt.ylabel('EDSS Score (0-10)', fontsize=12) + plt.xlabel('Data Source', fontsize=12) + plt.grid(axis='y', linestyle='--', alpha=0.4) + plt.ylim(-0.5, 10.5) + plt.xticks(rotation=45) + + plt.tight_layout() + plt.show() +## + +# %% Audit + + +import pandas as pd +import numpy as np +import json +import glob +import os + +def audit_matches(json_dir_path, ground_truth_path): + # 1. Load GT + df_gt = pd.read_csv(ground_truth_path, sep=';') + + # 2. Advanced Normalization + def clean_series(s): + return s.astype(str).str.strip().str.lower() + + df_gt['unique_id'] = clean_series(df_gt['unique_id']) + df_gt['MedDatum'] = clean_series(df_gt['MedDatum']) + + # 3. Load Predictions + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + file_name = os.path.basename(file_path) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'file': file_name + }) + + df_pred = pd.DataFrame(all_preds) + + # 4. 
Find the "Ghost" entries (In JSON but not in GT) + # Create a 'key' column for easy comparison + df_gt['key'] = df_gt['unique_id'] + "_" + df_gt['MedDatum'] + df_pred['key'] = df_pred['unique_id'] + "_" + df_pred['MedDatum'] + + gt_keys = set(df_gt['key']) + df_pred['is_matched'] = df_pred['key'].isin(gt_keys) + + unmatched_summary = df_pred[df_pred['is_matched'] == False] + + print("--- AUDIT RESULTS ---") + print(f"Total rows in JSON: {len(df_pred)}") + print(f"Rows that matched GT: {df_pred['is_matched'].sum()}") + print(f"Rows that FAILED to match: {len(unmatched_summary)}") + + if not unmatched_summary.empty: + print("\nFirst 10 Unmatched Entries (check these against your CSV):") + print(unmatched_summary[['unique_id', 'MedDatum', 'file']].head(10)) + + # Breakdown by file - see if specific JSON files are broken + print("\nFailure count per JSON file:") + print(unmatched_summary['file'].value_counts()) + +audit_matches('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') + +## + + + + +# %% Cinfidence accuracy correlation + +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt + +def categorize_edss(value): + if pd.isna(value): return np.nan + elif value <= 1.0: return '0-1' + elif value <= 2.0: return '1-2' + elif value <= 3.0: return '2-3' + elif value <= 4.0: return '3-4' + elif value <= 5.0: return '4-5' + elif value <= 6.0: return '6-7' + elif value <= 7.0: return '7-8' + elif value <= 8.0: return '8-9' + elif value <= 9.0: return '9-10' + else: return '10+' + +def plot_binned_calibration(json_dir_path, ground_truth_path): + # 1. 
Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['gt_cat'] = pd.to_numeric(df_gt['EDSS'], errors='coerce').apply(categorize_edss) + + # 2. Load Predictions + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'pred_cat': categorize_edss(res.get('EDSS')), + 'confidence': res.get('certainty_percent') + }) + + df_pred = pd.DataFrame(all_preds) + df_merged = pd.merge(df_pred, df_gt[['unique_id', 'MedDatum', 'gt_cat']], + on=['unique_id', 'MedDatum'], how='inner') + + # Define correctness + df_merged['is_correct'] = (df_merged['pred_cat'] == df_merged['gt_cat']).astype(int) + + # 3. Create Confidence Bins (e.g., 0-60, 60-70, 70-80, 80-90, 90-100) + bins = [0, 60, 70, 80, 90, 100] + labels = ['<60%', '60-70%', '70-80%', '80-90%', '90-100%'] + df_merged['conf_bin'] = pd.cut(df_merged['confidence'], bins=bins, labels=labels) + + # Calculate average accuracy per bin + calibration_stats = df_merged.groupby('conf_bin')['is_correct'].agg(['mean', 'count']).reset_index() + + # 4. 
Plotting + plt.figure(figsize=(10, 6)) + + # Bar chart for actual accuracy + bars = plt.bar(calibration_stats['conf_bin'], calibration_stats['mean'], + color='skyblue', edgecolor='navy', alpha=0.7, label='Actual Accuracy') + + # Add the "Perfect Calibration" line + # (If confidence is 95%, accuracy should be 0.95) + expected_x = np.arange(len(labels)) + expected_y = [0.3, 0.65, 0.75, 0.85, 0.95] # Midpoints of the bins for visual reference + plt.plot(expected_x, expected_y, color='red', marker='o', linestyle='--', + linewidth=2, label='Perfect Calibration (Theoretical)') + + # 5. Add text labels on top of bars to show sample size (how many cases in that bin) + for i, bar in enumerate(bars): + yval = bar.get_height() + count = calibration_stats.loc[i, 'count'] + plt.text(bar.get_x() + bar.get_width()/2, yval + 0.02, + f'Acc: {yval:.1%}\n(n={count})', ha='center', va='bottom', fontsize=9) + + # Legend and Labels + plt.title('Model Calibration: Does Confidence Match Accuracy?', fontsize=14, pad=15) + plt.xlabel('LLM Confidence Score Bin', fontsize=12) + plt.ylabel('Actual Accuracy (Correct Category %)', fontsize=12) + plt.ylim(0, 1.1) + plt.grid(axis='y', linestyle=':', alpha=0.5) + + # Adding a clean, informative legend + plt.legend(loc='upper left', frameon=True, shadow=True) + + plt.tight_layout() + plt.show() +## + + + +# %% Confidence comparison + +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.lines import Line2D +from matplotlib.patches import Patch + +def plot_edss_confidence_comparison(json_dir_path, ground_truth_path): + # 1. Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['EDSS_gt'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + # 2. 
Load Predictions from all JSONs + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + try: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'EDSS_pred': pd.to_numeric(res.get('EDSS'), errors='coerce'), + 'confidence': pd.to_numeric(res.get('certainty_percent'), errors='coerce') + }) + except Exception as e: + print(f"Skipping {file_path}: {e}") + + df_pred = pd.DataFrame(all_preds) + + # 3. Merge and Clean + df_merged = pd.merge(df_pred, df_gt[['unique_id', 'MedDatum', 'EDSS_gt']], + on=['unique_id', 'MedDatum'], how='inner') + df_plot = df_merged.dropna(subset=['EDSS_pred', 'EDSS_gt', 'confidence']).copy() + + # 4. Bin Confidence (X-Axis Categories) + # We group confidence into bins to create a readable boxplot + bins = [0, 60, 70, 80, 90, 100] + labels = ['<60%', '60-70%', '70-80%', '80-90%', '90-100%'] + df_plot['conf_bin'] = pd.cut(df_plot['confidence'], bins=bins, labels=labels) + + # 5. Plotting + plt.figure(figsize=(14, 8)) + + # A. Boxplot: Shows the distribution of LLM PREDICTIONS + sns.boxplot(data=df_plot, x='conf_bin', y='EDSS_pred', + color='#3498db', width=0.5, showfliers=False, + boxprops=dict(alpha=0.4, edgecolor='navy')) + + # B. Stripplot (Dots): Shows individual GROUND TRUTH scores + # We add jitter so dots don't hide each other + sns.stripplot(data=df_plot, x='conf_bin', y='EDSS_gt', + color='#e74c3c', alpha=0.4, jitter=0.2, size=5) + + # 6. 
Create a CLEAR Legend + legend_elements = [ + Patch(facecolor='#3498db', edgecolor='navy', alpha=0.4, + label='LLM Predictions (Box = Distribution)'), + Line2D([0], [0], marker='o', color='w', label='Ground Truth (Dots = Clinician Scores)', + markerfacecolor='#e74c3c', markersize=8, alpha=0.6), + Line2D([0], [0], color='black', lw=2, label='Median Predicted EDSS') + ] + plt.legend(handles=legend_elements, loc='upper left', frameon=True, shadow=True, title="Legend") + + # Final Labels + plt.title('Comparison of EDSS Scores Across Confidence Levels', fontsize=16, pad=20) + plt.xlabel('LLM Certainty Score (%)', fontsize=12) + plt.ylabel('EDSS Score (0-10)', fontsize=12) + plt.ylim(-0.5, 10.5) + plt.yticks(np.arange(0, 11, 1)) + plt.grid(axis='y', linestyle='--', alpha=0.3) + + plt.tight_layout() + plt.show() + + +## + + + +# %% EDSS vs Boxplot + +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch + +def categorize_edss(value): + if pd.isna(value): return np.nan + elif value <= 1.0: return '0-1' + elif value <= 2.0: return '1-2' + elif value <= 3.0: return '2-3' + elif value <= 4.0: return '3-4' + elif value <= 5.0: return '4-5' + elif value <= 6.0: return '5-6' + elif value <= 7.0: return '6-7' + elif value <= 8.0: return '7-8' + elif value <= 9.0: return '8-9' + elif value <= 10.0: return '9-10' + else: return '10+' + +def plot_edss_vs_confidence_boxplot(json_dir_path): + # 1. 
Load all Predictions + all_preds = [] + json_files = glob.glob(os.path.join(json_dir_path, "*.json")) + + for file_path in json_files: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + edss_val = pd.to_numeric(res.get('EDSS'), errors='coerce') + conf_val = pd.to_numeric(res.get('certainty_percent'), errors='coerce') + + if not pd.isna(edss_val) and not pd.isna(conf_val): + all_preds.append({ + 'edss_cat': categorize_edss(edss_val), + 'confidence': conf_val + }) + + df = pd.DataFrame(all_preds) + + # 2. Sort categories correctly for the x-axis + cat_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'] + df['edss_cat'] = pd.Categorical(df['edss_cat'], categories=cat_order, ordered=True) + + # 3. Plotting + plt.figure(figsize=(14, 8)) + + # Create Boxplot + sns.boxplot(data=df, x='edss_cat', y='confidence', + palette="Blues", width=0.6, showfliers=False) + + # Add Stripplot (Dots) to show density of cases + sns.stripplot(data=df, x='edss_cat', y='confidence', + color='black', alpha=0.15, jitter=0.2, size=3) + + # 4. 
Legend and Labels + # Since boxplot color is clear, we add a legend for the components + legend_elements = [ + Patch(facecolor='#6da7d1', label='Confidence Distribution (IQR)'), + plt.Line2D([0], [0], color='black', marker='o', linestyle='', + markersize=4, alpha=0.4, label='Individual Predictions') + ] + plt.legend(handles=legend_elements, loc='lower left', frameon=True) + + plt.title('LLM Confidence Levels Across Clinical EDSS Categories', fontsize=16, pad=20) + plt.xlabel('Predicted EDSS Category (Clinical Severity)', fontsize=12) + plt.ylabel('Confidence Score (%)', fontsize=12) + plt.ylim(0, 105) + plt.grid(axis='y', linestyle='--', alpha=0.3) + + plt.tight_layout() + plt.show() +## + + + + +# %% Correlation Boxplot +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from sklearn.metrics import cohen_kappa_score + +def categorize_edss(value): + """Standardized clinical categorization.""" + if pd.isna(value): return np.nan + elif value <= 1.0: return '0-1' + elif value <= 2.0: return '1-2' + elif value <= 3.0: return '2-3' + elif value <= 4.0: return '3-4' + elif value <= 5.0: return '4-5' + elif value <= 6.0: return '5-6' + elif value <= 7.0: return '6-7' + elif value <= 8.0: return '7-8' + elif value <= 9.0: return '8-9' + elif value <= 10.0: return '9-10' + else: return '10+' + +def plot_categorical_vs_categorical(json_dir_path, ground_truth_path): + # 1. Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['gt_cat'] = pd.to_numeric(df_gt['EDSS'], errors='coerce').apply(categorize_edss) + + # 2. 
Load Predictions + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'pred_cat': categorize_edss(pd.to_numeric(res.get('EDSS'), errors='coerce')) + }) + + df_pred = pd.DataFrame(all_preds) + + # 3. Merge + df_merged = pd.merge(df_pred, df_gt[['unique_id', 'MedDatum', 'gt_cat']], + on=['unique_id', 'MedDatum'], how='inner').dropna() + + # 4. Set Order and Numeric Mapping for Plotting + cat_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'] + cat_map = {cat: i for i, cat in enumerate(cat_order)} + + df_merged['gt_idx'] = df_merged['gt_cat'].map(cat_map) + df_merged['pred_idx'] = df_merged['pred_cat'].map(cat_map) + + # Calculate Cohen's Kappa (Standard for categorical agreement) + kappa = cohen_kappa_score(df_merged['gt_cat'], df_merged['pred_cat'], weights='linear') + + # 5. Plotting + plt.figure(figsize=(14, 8)) + + # BOXPLOT: Distribution of Predicted Categories relative to Ground Truth + sns.boxplot(data=df_merged, x='gt_cat', y='pred_idx', + palette="rocket", width=0.6, showfliers=False, boxprops=dict(alpha=0.5)) + + # STRIPPLOT: Individual counts + sns.stripplot(data=df_merged, x='gt_cat', y='pred_idx', + color='black', alpha=0.1, jitter=0.3, size=4) + + # DIAGONAL REFERENCE: Perfect category match + plt.plot([0, 9], [0, 9], color='red', linestyle='--', linewidth=2) + + # 6. 
Formatting Legend & Axes + plt.yticks(ticks=range(len(cat_order)), labels=cat_order) + + legend_elements = [ + Patch(facecolor='#ae3e50', alpha=0.5, label='Predicted Category Spread'), + plt.Line2D([0], [0], color='red', linestyle='--', label='Perfect Category Agreement'), + plt.Line2D([0], [0], color='black', marker='o', linestyle='', markersize=4, alpha=0.3, label='Iteration Matches'), + Patch(color='none', label=f'Linear Weighted Kappa: {kappa:.3f}') + ] + plt.legend(handles=legend_elements, loc='upper left', frameon=True, shadow=True, title="Agreement Metrics") + + plt.title('Categorical Agreement: Ground Truth vs. LLM Prediction', fontsize=16, pad=20) + plt.xlabel('Ground Truth Category (Clinician)', fontsize=12) + plt.ylabel('LLM Predicted Category', fontsize=12) + plt.grid(axis='both', linestyle=':', alpha=0.4) + + plt.tight_layout() + plt.show() +## + + + +# %% rainplot +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_error_distribution_by_confidence(json_dir_path, ground_truth_path): + # 1. Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['EDSS_gt'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + # 2. 
Load Predictions + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'EDSS_pred': pd.to_numeric(res.get('EDSS'), errors='coerce'), + 'confidence': pd.to_numeric(res.get('certainty_percent'), errors='coerce') + }) + + df_merged = pd.merge(pd.DataFrame(all_preds), df_gt[['unique_id', 'MedDatum', 'EDSS_gt']], + on=['unique_id', 'MedDatum'], how='inner').dropna() + + # 3. Calculate Error + df_merged['error'] = df_merged['EDSS_pred'] - df_merged['EDSS_gt'] + + # 4. Bin Confidence + bins = [0, 70, 80, 90, 100] + labels = ['Low (<70%)', 'Moderate (70-80%)', 'High (80-90%)', 'Very High (90-100%)'] + df_merged['conf_bin'] = pd.cut(df_merged['confidence'], bins=bins, labels=labels) + + # Calculate counts for labels + counts = df_merged['conf_bin'].value_counts().reindex(labels) + new_labels = [f"{l}\n(n={int(counts[l])})" for l in labels] + + # 5. Plotting + plt.figure(figsize=(13, 8)) + + # Using a sequential color palette (Light blue to Dark blue) + palette_colors = sns.color_palette("Blues", n_colors=len(labels)) + + vplot = sns.violinplot(data=df_merged, x='conf_bin', y='error', inner="quartile", + palette=palette_colors, cut=0) + + # Reference line at 0 + plt.axhline(0, color='#d9534f', linestyle='--', linewidth=2.5) + + # 6. 
UPDATED LEGEND WITH CORRECT COLORS + legend_elements = [ + # Legend items for the color gradient + Patch(facecolor=palette_colors[0], label='Confidence: <70%'), + Patch(facecolor=palette_colors[1], label='Confidence: 70-80%'), + Patch(facecolor=palette_colors[2], label='Confidence: 80-90%'), + Patch(facecolor=palette_colors[3], label='Confidence: 90-100%'), + # Legend items for the symbols + Line2D([0], [0], color='black', linestyle=':', label='Quartile Lines (25th, 50th, 75th)'), + Line2D([0], [0], color='#d9534f', linestyle='--', lw=2.5, label='Zero Error (Perfect Match)') + ] + + plt.legend(handles=legend_elements, loc='upper left', frameon=True, shadow=True, title="Legend & Confidence Gradient") + + # Formatting + plt.title('Error Magnitude vs. LLM Confidence Levels', fontsize=16, pad=20) + plt.xlabel('LLM Certainty Group', fontsize=12) + plt.ylabel('Prediction Delta (EDSS_pred - EDSS_gt)', fontsize=12) + plt.xticks(ticks=range(len(labels)), labels=new_labels) + plt.grid(axis='y', linestyle=':', alpha=0.5) + + plt.tight_layout() + plt.show() + +# plot_error_distribution_by_confidence('jsons_folder/', 'ground_truth.csv') +## + + + +# %% Certainty vs Delta +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_confidence_vs_abs_error_refined(json_dir_path, ground_truth_path): + # 1. 
Load and Merge Data + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['EDSS_gt'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'EDSS_pred': pd.to_numeric(res.get('EDSS'), errors='coerce'), + 'confidence': pd.to_numeric(res.get('certainty_percent'), errors='coerce') + }) + + df_merged = pd.merge(pd.DataFrame(all_preds), df_gt[['unique_id', 'MedDatum', 'EDSS_gt']], + on=['unique_id', 'MedDatum'], how='inner').dropna() + + # 2. Calculate Absolute Delta + df_merged['abs_error'] = (df_merged['EDSS_pred'] - df_merged['EDSS_gt']).abs() + + # 3. Binning + bins = [0, 70, 80, 90, 100] + labels = ['Low (<70%)', 'Moderate (70-80%)', 'High (80-90%)', 'Very High (90-100%)'] + df_merged['conf_bin'] = pd.cut(df_merged['confidence'], bins=bins, labels=labels) + + stats = df_merged.groupby('conf_bin', observed=True)['abs_error'].agg(['mean', 'std', 'count']).reset_index() + + # 4. 
Plotting + plt.figure(figsize=(12, 8)) + # Sequential palette: light to dark + colors = sns.color_palette("Blues", n_colors=len(labels)) + + bars = plt.bar(stats['conf_bin'], stats['mean'], color=colors, edgecolor='black', linewidth=1.2) + + # Standard Error Bars + plt.errorbar(stats['conf_bin'], stats['mean'], + yerr=stats['std']/np.sqrt(stats['count']), + fmt='none', c='black', capsize=6, elinewidth=1.5) + + # Trend Line (Linear Fit) + x_idx = np.arange(len(labels)) + z = np.polyfit(x_idx, stats['mean'], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=2.5) + + # 5. THE COMPLETE LEGEND + # We create a specific handle for every single thing on the chart + legend_elements = [ + # Explicit color mapping for bins + Patch(facecolor=colors[0], edgecolor='black', label=f'Bin 1: {labels[0]}'), + Patch(facecolor=colors[1], edgecolor='black', label=f'Bin 2: {labels[1]}'), + Patch(facecolor=colors[2], edgecolor='black', label=f'Bin 3: {labels[2]}'), + Patch(facecolor=colors[3], edgecolor='black', label=f'Bin 4: {labels[3]}'), + # Statistical components + Line2D([0], [0], color='black', marker='_', linestyle='None', markersize=10, label='Standard Error (SEM)'), + Line2D([0], [0], color='#e74c3c', linestyle='--', lw=2.5, label='Correlation Trend (Inverse Rel.)'), + # Metric definition + Patch(color='none', label='Metric: Mean Absolute Error (MAE)') + ] + + plt.legend(handles=legend_elements, loc='upper right', frameon=True, + shadow=True, fontsize=10, title="Legend") + + # Final Labels & Clean-up + plt.title('Validation: Inverse Correlation of Confidence vs. 
Error Magnitude', fontsize=15, pad=20) + plt.ylabel('Mean Absolute Error (Δ EDSS Points)', fontsize=12) + plt.xlabel('LLM Confidence Bracket', fontsize=12) + + # Text annotations for MAE on bars + for i, bar in enumerate(bars): + plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05, + f'MAE: {stats.loc[i, "mean"]:.2f}', ha='center', fontweight='bold') + + plt.grid(axis='y', linestyle=':', alpha=0.5) + plt.tight_layout() + plt.show() +## + + + +# %% name +import pandas as pd +import numpy as np +import json +import glob +import os +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_final_thesis_error_chart(json_dir_path, ground_truth_path): + # 1. Load Ground Truth & Predictions + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['EDSS_gt'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + all_preds = [] + for file_path in glob.glob(os.path.join(json_dir_path, "*.json")): + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'EDSS_pred': pd.to_numeric(res.get('EDSS'), errors='coerce'), + 'confidence': pd.to_numeric(res.get('certainty_percent'), errors='coerce') + }) + + df_merged = pd.merge(pd.DataFrame(all_preds), df_gt[['unique_id', 'MedDatum', 'EDSS_gt']], + on=['unique_id', 'MedDatum'], how='inner').dropna() + + # 2. Metric Calculation + df_merged['abs_error'] = (df_merged['EDSS_pred'] - df_merged['EDSS_gt']).abs() + + # 3. 
Binning & Stats + bins = [0, 70, 80, 90, 100] + labels = ['Low (<70%)', 'Moderate (70-80%)', 'High (80-90%)', 'Very High (90-100%)'] + df_merged['conf_bin'] = pd.cut(df_merged['confidence'], bins=bins, labels=labels) + + stats = df_merged.groupby('conf_bin', observed=True)['abs_error'].agg(['mean', 'std', 'count']).reset_index() + + # 4. Plotting + plt.figure(figsize=(13, 8)) + colors = sns.color_palette("Blues", n_colors=len(labels)) + + # BARS (MAE) + bars = plt.bar(stats['conf_bin'], stats['mean'], color=colors, edgecolor='black', alpha=0.85) + + # ERROR BARS (Standard Error of the Mean) + plt.errorbar(stats['conf_bin'], stats['mean'], + yerr=stats['std']/np.sqrt(stats['count']), + fmt='none', c='black', capsize=8, elinewidth=1.5) + + # CORRELATION TREND LINE + x_idx = np.arange(len(labels)) + z = np.polyfit(x_idx, stats['mean'], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5) + + # 5. DATA LABELS (n and MAE) + for i, bar in enumerate(bars): + n_count = int(stats.loc[i, 'count']) + mae_val = stats.loc[i, 'mean'] + plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.04, + f'MAE: {mae_val:.2f}\nn={n_count}', + ha='center', va='bottom', fontweight='bold', fontsize=10) + + # 6. 
THE COMPLETE LEGEND + legend_elements = [ + Patch(facecolor=colors[0], edgecolor='black', label=f'Bin 1: {labels[0]}'), + Patch(facecolor=colors[1], edgecolor='black', label=f'Bin 2: {labels[1]}'), + Patch(facecolor=colors[2], edgecolor='black', label=f'Bin 3: {labels[2]}'), + Patch(facecolor=colors[3], edgecolor='black', label=f'Bin 4: {labels[3]}'), + Line2D([0], [0], color='#e74c3c', linestyle='--', lw=3, label='Correlation Trend (Inverse Relationship)'), + Line2D([0], [0], color='black', marker='_', linestyle='None', markersize=10, label='Standard Error (SEM)'), + Patch(color='none', label='Metric: Mean Absolute Error (MAE)') + ] + plt.legend(handles=legend_elements, loc='upper right', frameon=True, shadow=True, title="Chart Components") + + # Formatting + plt.title('Clinical Validation: LLM Certainty vs. Prediction Accuracy', fontsize=16, pad=30) + plt.ylabel('Mean Absolute Error (EDSS Points)', fontsize=12) + plt.xlabel('LLM Confidence Bracket', fontsize=12) + plt.grid(axis='y', linestyle=':', alpha=0.5) + plt.ylim(0, stats['mean'].max() + 0.6) # Add room for labels + + plt.tight_layout() + plt.show() + +# plot_final_thesis_error_chart('jsons_folder/', 'gt.csv') +## + + + +# %% 1json +import pandas as pd +import numpy as np +import json +import matplotlib.pyplot as plt +import seaborn as sns +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +def plot_single_json_error_analysis(json_file_path, ground_truth_path): + # 1. Load Ground Truth + df_gt = pd.read_csv(ground_truth_path, sep=';') + df_gt['unique_id'] = df_gt['unique_id'].astype(str).str.strip().str.lower() + df_gt['MedDatum'] = df_gt['MedDatum'].astype(str).str.strip().str.lower() + df_gt['EDSS_gt'] = pd.to_numeric(df_gt['EDSS'], errors='coerce') + + # 2. 
Load the Specific JSON + all_preds = [] + with open(json_file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + for entry in data: + if entry.get("success"): + res = entry["result"] + all_preds.append({ + 'unique_id': str(res.get('unique_id')).strip().lower(), + 'MedDatum': str(res.get('MedDatum')).strip().lower(), + 'EDSS_pred': pd.to_numeric(res.get('EDSS'), errors='coerce'), + 'confidence': pd.to_numeric(res.get('certainty_percent'), errors='coerce') + }) + + df_pred = pd.DataFrame(all_preds) + + # 3. Merge and Calculate Absolute Error + df_merged = pd.merge(df_pred, df_gt[['unique_id', 'MedDatum', 'EDSS_gt']], + on=['unique_id', 'MedDatum'], how='inner').dropna() + + df_merged['abs_error'] = (df_merged['EDSS_pred'] - df_merged['EDSS_gt']).abs() + + # 4. Binning and Statistics + bins = [0, 70, 80, 90, 100] + labels = ['Low (<70%)', 'Moderate (70-80%)', 'High (80-90%)', 'Very High (90-100%)'] + df_merged['conf_bin'] = pd.cut(df_merged['confidence'], bins=bins, labels=labels) + + stats = df_merged.groupby('conf_bin', observed=True)['abs_error'].agg(['mean', 'std', 'count']).reset_index() + + # 5. Plotting + plt.figure(figsize=(13, 8)) + colors = sns.color_palette("Blues", n_colors=len(labels)) + + # BARS (MAE) + bars = plt.bar(stats['conf_bin'], stats['mean'], color=colors, edgecolor='black', alpha=0.85) + + # ERROR BARS (SEM) + plt.errorbar(stats['conf_bin'], stats['mean'], + yerr=stats['std']/np.sqrt(stats['count']), + fmt='none', c='black', capsize=8, elinewidth=1.5) + + # CORRELATION TREND LINE + x_idx = np.arange(len(labels)) + z = np.polyfit(x_idx, stats['mean'], 1) + p = np.poly1d(z) + plt.plot(x_idx, p(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5) + + # 6. 
DATA LABELS (n and MAE) + for i, bar in enumerate(bars): + n_count = int(stats.loc[i, 'count']) + mae_val = stats.loc[i, 'mean'] + plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.04, + f'MAE: {mae_val:.2f}\nn={n_count}', + ha='center', va='bottom', fontweight='bold', fontsize=10) + + # 7. COMPREHENSIVE LEGEND + legend_elements = [ + Patch(facecolor=colors[0], edgecolor='black', label=f'Bin 1: {labels[0]}'), + Patch(facecolor=colors[1], edgecolor='black', label=f'Bin 2: {labels[1]}'), + Patch(facecolor=colors[2], edgecolor='black', label=f'Bin 3: {labels[2]}'), + Patch(facecolor=colors[3], edgecolor='black', label=f'Bin 4: {labels[3]}'), + Line2D([0], [0], color='#e74c3c', linestyle='--', lw=3, label='Inverse Trend Line'), + Line2D([0], [0], color='black', marker='_', linestyle='None', markersize=10, label='Std Error (SEM)'), + Patch(color='none', label='Metric: Mean Absolute Error (MAE)') + ] + plt.legend(handles=legend_elements, loc='upper right', frameon=True, shadow=True, title="Legend") + + # Final Styling + plt.title('Validation: Confidence vs. 
Error Magnitude (Iteration 1 Only)', fontsize=15, pad=30) + plt.ylabel('Mean Absolute Error (EDSS Points)', fontsize=12) + plt.xlabel('LLM Confidence Bracket', fontsize=12) + plt.grid(axis='y', linestyle=':', alpha=0.5) + plt.ylim(0, stats['mean'].max() + 0.6) + + plt.tight_layout() + plt.show() + +# --- RUN THE PLOT --- +json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json" +gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv" + +plot_single_json_error_analysis(json_path, gt_path) +## + + + + +# %% Usage + +# --- Usage --- +#plot_categorized_edss('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', +# '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') + +#plot_subcategory_analysis('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +#plot_certainty_vs_accuracy_by_category('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') + + +#plot_edss_boxplot('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +#plot_binned_calibration('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') + +#plot_edss_vs_confidence_boxplot('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration') +#plot_gt_vs_llm_boxplot('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +#plot_categorical_vs_categorical('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +#plot_error_distribution_by_confidence('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') 
+#plot_confidence_vs_abs_error_refined('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +#plot_confidence_vs_abs_error_with_counts('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') +plot_final_thesis_error_chart('/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration', '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv') + +## diff --git a/figure1.py b/figure1.py index ee51e62..9f4706b 100644 --- a/figure1.py +++ b/figure1.py @@ -263,3 +263,120 @@ plt.legend(frameon=False, loc='upper center', bbox_to_anchor=(0.5, -0.05)) plt.tight_layout() plt.show() ## + + + + +# %% name +import matplotlib.pyplot as plt + +# Data +data = { + 'Visit': [9, 8, 7, 6, 5, 4, 3, 2, 1], + 'patient_count': [2, 3, 3, 6, 13, 17, 28, 24, 32] +} + +# Create figure and axis +fig, ax = plt.subplots(figsize=(10, 6)) + +# Plot the bar chart +bars = ax.bar(data['Visit'], data['patient_count'], color='darkblue', label='Patients by Visit Count') + +# Add labels and title +ax.set_xlabel('Visit Number (from last to first)', fontsize=12) +ax.set_ylabel('Number of Patients', fontsize=12) +ax.set_title('Patient Visits by Visit Number', fontsize=14) + +# Invert x-axis to show Visit 9 on the left (descending order) if desired, but keep natural order (1–9 left to right) +# For descending order (9→1 from left to right), we'd need to reverse: +# Visit = data['Visit'][::-1], patient_count = data['patient_count'][::-1] +# But standard practice is ascending (1 to 9), so we'll sort accordingly: +# Let's sort by Visit to ensure left-to-right: 1,2,...,9 + +# Actually, your current Visit list is [9,8,...,1], which is descending. 
+# Let's sort by Visit for intuitive left-to-right increasing order: +sorted_indices = sorted(range(len(data['Visit'])), key=lambda i: data['Visit'][i]) +visit_sorted = [data['Visit'][i] for i in sorted_indices] +count_sorted = [data['patient_count'][i] for i in sorted_indices] + +# Re-plot with sorted x-axis: +ax.clear() +bars = ax.bar(visit_sorted, count_sorted, color='darkblue', label='Patients by Visit Count') + +# Re-apply labels, etc. +ax.set_xlabel('Number of Visits', fontsize=12) +ax.set_ylabel('Number of Unique Patients', fontsize=12) +#ax.set_title('Number of Patients by Visit Number', fontsize=14) + +# Add legend +ax.legend() + +# Improve layout and grid +ax.grid(axis='y', linestyle='--', alpha=0.7) +plt.xticks(visit_sorted) # Ensure all integer visit numbers are shown + +# Show the plot +plt.tight_layout() +plt.show() + +## + +# %% Patientjourney Bubble chart +import matplotlib.pyplot as plt +import numpy as np + +import matplotlib as mpl + +mpl.rcParams["font.family"] = "DejaVu Sans" # or "Arial", "Calibri", "Times New Roman", ... 
+mpl.rcParams["font.size"] = 12 # default size for text +mpl.rcParams["axes.titlesize"] = 14 +mpl.rcParams["axes.titleweight"] = "bold" + + +# Data (your counts) +visits = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) +patient_count = np.array([32, 24, 28, 17, 13, 6, 3, 3, 2]) + +# "Remaining" = patients with >= that many visits (cumulative from the right) +remaining = np.array([patient_count[i:].sum() for i in range(len(patient_count))]) + +# --- Plot --- +fig, ax = plt.subplots(figsize=(12, 3)) + +y = 0.0 # all bubbles on one horizontal line + +# Horizontal line +ax.hlines(y, visits.min() - 0.4, visits.max() + 0.4, color="#1f77b4", linewidth=3) + +# Bubble sizes (scale as needed) +# (Matplotlib scatter uses area in points^2) +sizes = patient_count * 35 # tweak this multiplier if you want bigger/smaller bubbles + +ax.scatter(visits, np.full_like(visits, y), s=sizes, color="#1f77b4", zorder=3) + +# Title +#ax.set_title("Patient Journey by Visit Count", fontsize=14, pad=18) + +# Top labels: "1 visits", "2 visits", ... 
+for x in visits: + label = f"{x} visit" if x == 1 else f"{x} visits" + ax.text(x, y + 0.18, label, ha="center", va="bottom", fontsize=10) + +# Bottom labels: "X patients" and "Y remaining" +for x, pc, rem in zip(visits, patient_count, remaining): + ax.text(x, y - 0.20, f"{pc} patients", ha="center", va="top", fontsize=9) + ax.text(x, y - 0.32, f"{rem} remaining", ha="center", va="top", fontsize=9) + +# Cosmetics: remove axes, keep spacing nice +ax.set_xlim(visits.min() - 0.6, visits.max() + 0.6) +ax.set_ylim(-0.5, 0.35) +ax.set_xticks([]) +ax.set_yticks([]) +for spine in ax.spines.values(): + spine.set_visible(False) + +plt.tight_layout() +plt.show() +plt.savefig("patient_journey.svg", format="svg", bbox_inches="tight") +## + diff --git a/show_plots.py b/show_plots.py new file mode 100644 index 0000000..4332b98 --- /dev/null +++ b/show_plots.py @@ -0,0 +1,2320 @@ +# %% Scatter +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results +df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') +df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') +df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) + +# Create scatter plot +plt.figure(figsize=(8, 6)) +plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue') + +# Add labels and title +plt.xlabel('GT_EDSS') +plt.ylabel('LLM_Results') +plt.title('Comparison of GT_EDSS vs LLM_Results') + +# 
Optional: Add a diagonal line for reference (perfect prediction) +plt.plot([0, max(df_clean['GT_EDSS'])], [0, max(df_clean['GT_EDSS'])], color='red', linestyle='--', label='Perfect Prediction') +plt.legend() + +# Show plot +plt.grid(True) +plt.tight_layout() +plt.show() + +## + + +# %% Bland0-altman + +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import statsmodels.api as sm + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results +df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') +df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') +df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) + +# Create Bland-Altman plot +f, ax = plt.subplots(1, figsize=(8, 5)) +sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax) + +# Add labels and title +ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results') +ax.set_xlabel('Mean of GT_EDSS and LLM_Results') +ax.set_ylabel('Difference between GT_EDSS and LLM_Results') + +# Display Bland-Altman plot +plt.tight_layout() +plt.show() + +# Print some statistics +mean_diff = np.mean(df_clean['GT_EDSS'] - df_clean['LLM_Results']) +std_diff = np.std(df_clean['GT_EDSS'] - df_clean['LLM_Results']) +print(f"Mean difference: {mean_diff:.3f}") +print(f"Standard deviation of differences: {std_diff:.3f}") +print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]") + +## + + + +# %% Confusion matrix +import pandas as pd 
+import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') +df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# For confusion matrix, we need to categorize the values +# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) +def categorize_edss(value): + if pd.isna(value): + return np.nan + elif value <= 1.0: + return '0-1' + elif value <= 2.0: + return '1-2' + elif value <= 3.0: + return '2-3' + elif value <= 4.0: + return '3-4' + elif value <= 5.0: + return '4-5' + elif value <= 6.0: + return '5-6' + elif value <= 7.0: + return '6-7' + elif value <= 8.0: + return '7-8' + elif value <= 9.0: + return '8-9' + elif value <= 10.0: + return '9-10' + else: + return '10+' + +# Create categorical versions +df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) +df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) + +# Remove any NaN categories +df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) + +# Create confusion matrix +cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], + labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Plot confusion matrix +plt.figure(figsize=(10, 8)) +sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['0-1', '1-2', '2-3', 
'3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], + yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) +#plt.title('Confusion Matrix: Ground truth EDSS vs interferred EDSS (Categorized 0-10)') +plt.xlabel('LLM Generated EDSS') +plt.ylabel('Ground Truth EDSS') +plt.tight_layout() +plt.show() + +# Print classification report +print("Classification Report:") +print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) + +# Print raw counts +print("\nConfusion Matrix (Raw Counts):") +print(cm) + +## + + +# %% Confusion matrix +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') +df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# For confusion matrix, we need to categorize the values +# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) +def categorize_edss(value): + if pd.isna(value): + return np.nan + elif value <= 1.0: + return '0-1' + elif value <= 2.0: + return '1-2' + elif value <= 3.0: + return '2-3' + elif value <= 4.0: + return '3-4' + elif value <= 5.0: + return '4-5' + elif value <= 6.0: + return '5-6' + elif value <= 7.0: + return '6-7' + elif value <= 8.0: + return '7-8' + elif value <= 9.0: + return '8-9' + elif value <= 10.0: + return '9-10' + else: + return '10+' + +# Create 
categorical versions +df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) +df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) + +# Remove any NaN categories +df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) + +# Create confusion matrix +cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], + labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Plot confusion matrix +plt.figure(figsize=(10, 8)) +ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], + yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Add legend text above the color bar +# Get the colorbar object +cbar = ax.collections[0].colorbar +# Add text above the colorbar +cbar.set_label('Number of Cases', rotation=270, labelpad=20) + +plt.xlabel('LLM Generated EDSS') +plt.ylabel('Ground Truth EDSS') +#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)') +plt.tight_layout() +plt.show() + +# Print classification report +print("Classification Report:") +print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) + +# Print raw counts +print("\nConfusion Matrix (Raw Counts):") +print(cm) + + +## + + + + +# %% Classification +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import confusion_matrix +import numpy as np + +# Load your data from TSV file +file_path ='/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' + +df = pd.read_csv(file_path, sep='\t') + +# Check data structure +print("Data shape:", df.shape) +print("First few rows:") +print(df.head()) +print("\nColumn names:") +for col in df.columns: + print(f" {col}") + +# Function to safely convert to boolean +def safe_bool_convert(series): + '''Safely convert series to boolean, handling various input 
formats''' + # Convert to string first, then to boolean + series_str = series.astype(str).str.strip().str.lower() + + # Handle different true/false representations + bool_map = { + 'true': True, '1': True, 'yes': True, 'y': True, + 'false': False, '0': False, 'no': False, 'n': False + } + + converted = series_str.map(bool_map) + + # Handle remaining NaN values + converted = converted.fillna(False) # or True, depending on your preference + + return converted + +# Convert columns safely +if 'result.klassifizierbar' in df.columns: + print("\nresult.klassifizierbar column info:") + print(df['result.klassifizierbar'].head(10)) + print("Unique values:", df['result.klassifizierbar'].unique()) + + df['result.klassifizierbar'] = safe_bool_convert(df['result.klassifizierbar']) + print("After conversion:") + print(df['result.klassifizierbar'].value_counts()) + +if 'GT.klassifizierbar' in df.columns: + print("\nGT.klassifizierbar column info:") + print(df['GT.klassifizierbar'].head(10)) + print("Unique values:", df['GT.klassifizierbar'].unique()) + + df['GT.klassifizierbar'] = safe_bool_convert(df['GT.klassifizierbar']) + print("After conversion:") + print(df['GT.klassifizierbar'].value_counts()) + +# Create bar chart showing only True values for klassifizierbar +if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: + # Get counts for True values only + llm_true_count = df['result.klassifizierbar'].sum() + gt_true_count = df['GT.klassifizierbar'].sum() + + # Plot using matplotlib directly + fig, ax = plt.subplots(figsize=(8, 6)) + + x = np.arange(2) + width = 0.35 + + bars1 = ax.bar(x[0] - width/2, llm_true_count, width, label='LLM', color='skyblue', alpha=0.8) + bars2 = ax.bar(x[1] + width/2, gt_true_count, width, label='GT', color='lightcoral', alpha=0.8) + + # Add value labels on bars + ax.annotate(f'{llm_true_count}', + xy=(x[0], llm_true_count), + xytext=(0, 3), + textcoords="offset points", + ha='center', va='bottom') + + 
ax.annotate(f'{gt_true_count}', + xy=(x[1], gt_true_count), + xytext=(0, 3), + textcoords="offset points", + ha='center', va='bottom') + + ax.set_xlabel('Classification Status (klassifizierbar)') + ax.set_ylabel('Count') + ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"') + ax.set_xticks(x) + ax.set_xticklabels(['LLM', 'GT']) + ax.legend() + + plt.tight_layout() + plt.show() + +# Create confusion matrix if both columns exist +if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: + try: + # Ensure both columns are boolean + llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool) + gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool) + + cm = confusion_matrix(gt_bool, llm_bool) + + # Plot confusion matrix + fig, ax = plt.subplots(figsize=(8, 6)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['False ', 'True '], + yticklabels=['False', 'True '], + ax=ax) + ax.set_xlabel('LLM Predictions ') + ax.set_ylabel('GT Labels ') + ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"') + + plt.tight_layout() + plt.show() + + print("Confusion Matrix:") + print(cm) + + except Exception as e: + print(f"Error creating confusion matrix: {e}") + +# Show final data info +print("\nFinal DataFrame info:") +print(df[['result.klassifizierbar', 'GT.klassifizierbar']].info()) + +## + + + + +# %% Boxplot +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/results/join_results_unique.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') 
+df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# 1. DEFINE CATEGORY ORDER +# This ensures the X-axis is numerically logical (0-1 comes before 1-2) +category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+'] + +# Convert the column to a Categorical type with the specific order +df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss), + categories=category_order, + ordered=True) + +plt.figure(figsize=(14, 8)) + +# 2. ADD HUE FOR LEGEND +# Assigning x to 'hue' allows Seaborn to generate a legend automatically +box_plot = sns.boxplot( + data=df_clean, + x='GT.EDSS_cat', + y='result.EDSS', + hue='GT.EDSS_cat', # Added hue + palette='viridis', + linewidth=1.5, + legend=True # Ensure legend is enabled +) + +# 3. CUSTOMIZE PLOT +plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20) +plt.xlabel('Ground Truth EDSS Category', fontsize=14) +plt.ylabel('LLM Predicted EDSS', fontsize=14) + +# Move legend to the side or top +plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left') + +plt.xticks(rotation=45, ha='right', fontsize=10) +plt.grid(True, axis='y', alpha=0.3) +plt.tight_layout() + +plt.show() +## + + +# %% Postproccessing Column names + +import pandas as pd + +# Read the TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Create a mapping dictionary for German to English column names +column_mapping = { + 'EDSS':'GT.EDSS', + 'klassifizierbar': 'GT.klassifizierbar', + 'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS', + 'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS', + 'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS', + 'Sensibiliät': 'GT.SENSORY_FUNCTIONS', + 'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS', + 'Ambulation': 'GT.AMBULATION', + 'Cerebrale_Funktion': 
'GT.CEREBRAL_FUNCTIONS', + 'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS' +} + +# Rename columns +df = df.rename(columns=column_mapping) + +# Save the modified dataframe back to TSV file +df.to_csv(file_path, sep='\t', index=False) + +print("Columns have been successfully renamed!") +print("Renamed columns:") +for old_name, new_name in column_mapping.items(): + if old_name in df.columns: + print(f" {old_name} -> {new_name}") + + +## + + + + +# %% Styled table +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt +import dataframe_image as dfi +# Load data +df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t') + +# 1. Identify all GT and result columns +gt_columns = [col for col in df.columns if col.startswith('GT.')] +result_columns = [col for col in df.columns if col.startswith('result.')] + +print("GT Columns found:", gt_columns) +print("Result Columns found:", result_columns) + +# 2. Create proper mapping between GT and result columns +# Handle various naming conventions (spaces, underscores, etc.) +column_mapping = {} + +for gt_col in gt_columns: + base_name = gt_col.replace('GT.', '') + + # Clean the base name for matching - remove spaces, underscores, etc. 
+ # Try different matching approaches + candidates = [ + f'result.{base_name}', # Exact match + f'result.{base_name.replace(" ", "_")}', # With underscores + f'result.{base_name.replace("_", " ")}', # With spaces + f'result.{base_name.replace(" ", "")}', # No spaces + f'result.{base_name.replace("_", "")}' # No underscores + ] + + # Also try case-insensitive matching + candidates.append(f'result.{base_name.lower()}') + candidates.append(f'result.{base_name.upper()}') + + # Try to find matching result column + matched = False + for candidate in candidates: + if candidate in result_columns: + column_mapping[gt_col] = candidate + matched = True + break + + # If no exact match found, try partial matching + if not matched: + # Try to match by removing special characters and comparing + base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' ']) + for result_col in result_columns: + result_base = result_col.replace('result.', '') + result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' ']) + if base_clean.lower() == result_clean.lower(): + column_mapping[gt_col] = result_col + matched = True + break + +print("Column mapping:", column_mapping) + +# 3. 
Faster, vectorized computation using the corrected mapping +data_list = [] + +for gt_col, result_col in column_mapping.items(): + print(f"Processing {gt_col} vs {result_col}") + + # Convert to numeric, forcing errors to NaN + s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float) + s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float) + + # Calculate matches (abs difference <= 0.5) + diff = np.abs(s1 - s2) + matches = (diff <= 0.5).sum() + + # Determine the denominator (total valid comparisons) + valid_count = diff.notna().sum() + + if valid_count > 0: + percentage = (matches / valid_count) * 100 + else: + percentage = 0 + + # Extract clean base name for display + base_name = gt_col.replace('GT.', '') + + data_list.append({ + 'GT': base_name, + 'Match %': round(percentage, 1) + }) + + + + +# 4. Prepare Data +match_df = pd.DataFrame(data_list) +# Clean up labels: Replace underscores with spaces and capitalize +match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title() +match_df = match_df.sort_values('Match %', ascending=False) + +# 5. 
Create a "Beautiful" Table using Seaborn Heatmap +def create_luxury_table(df, output_file="edss_agreement.png"): + # Set the aesthetic style + sns.set_theme(style="white", font="sans-serif") + + # Prepare data for heatmap + plot_data = df.set_index('GT')[['Match %']] + + # Initialize the figure + # Height is dynamic based on number of rows + fig, ax = plt.subplots(figsize=(8, len(df) * 0.6)) + + # Create a custom diverging color map (Deep Red -> Mustard -> Emerald) + # This looks more professional than standard 'RdYlGn' + cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True) + + # Draw the heatmap + sns.heatmap( + plot_data, + annot=True, + fmt=".1f", + cmap=cmap, + center=85, # Centers the color transition + vmin=50, vmax=100, # Range of the gradient + linewidths=2, + linecolor='white', + cbar=False, # Remove color bar for a "table" look + annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"} + ) + + # Styling the Axes (Turning the heatmap into a table) + ax.set_xlabel("") + ax.set_ylabel("") + ax.xaxis.tick_top() # Move "Match %" label to top + ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50') + ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0) + + # Add a thin border around the plot + for _, spine in ax.spines.items(): + spine.set_visible(True) + spine.set_color('#ecf0f1') + + plt.title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50') + + # Add a subtle footer + plt.figtext(0.5, 0.0, "Tolerance: ±0.5 points", + wrap=True, horizontalalignment='center', fontsize=10, color='gray', style='italic') + + # Save with high resolution + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + print(f"Beautiful table saved as {output_file}") + +# Execute +create_luxury_table(match_df) + + +# Run the function +save_styled_table(match_df) +# 6. 
Save as SVG + +plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight') +print("Successfully saved agreement_table.svg") + +# Show plot if running in a GUI environment +plt.show() +## + + + +# %% Time Plot +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd +from scipy import stats + +# Load the TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Extract the inference_time_sec column +inference_times = df['inference_time_sec'].dropna() # Remove NaN values + +# Calculate statistics +mean_time = inference_times.mean() +std_time = inference_times.std() +median_time = np.median(inference_times) + +# Create the histogram +fig, ax = plt.subplots(figsize=(10, 6)) + +# Create histogram with bins of 1 second width +min_time = int(inference_times.min()) +max_time = int(inference_times.max()) + 1 +bins = np.arange(min_time, max_time + 1, 1) # Bins of 1 second width + +# Create histogram with counts (not probability density) +n, bins, patches = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7, edgecolor='black', linewidth=0.5) + +# Generate Gaussian curve for fit +x = np.linspace(inference_times.min(), inference_times.max(), 100) +# Scale Gaussian to match histogram counts +gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * (bins[1] - bins[0]) + +# Plot Gaussian fit +ax.plot(x, gaussian_counts, color='red', linewidth=2, label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)') + +# Add vertical lines for mean and median +ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s') +ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s') + +# Add standard deviation as vertical lines +ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'+1σ = {mean_time + std_time:.1f}s') 
def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Convert a column of score values to floats, accepting ``1.5`` and ``1,5``.

    Every value is stringified first, then each comma is replaced by a dot
    before numeric conversion. Anything that still does not parse (including
    the text produced by stringifying a missing value) is coerced to NaN.
    """
    as_text = s.astype(str)
    dotted = as_text.str.replace(",", ".", regex=False)
    return pd.to_numeric(dotted, errors="coerce")
def style_time_axis(ax, show_labels=True):
    """Apply the dashboard's shared date-axis styling to *ax*.

    Ticks are placed at the cell-level ``event_dates`` list and rendered as
    ``YYYY-MM`` with small, rotated labels. When *show_labels* is false the
    bottom tick labels are switched off (the ticks themselves stay).
    """
    x_axis = ax.xaxis
    ax.set_xticks(event_dates)
    x_axis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2)
    if show_labels:
        return
    ax.tick_params(labelbottom=False)
# %% Dashboard Angepasst
# Per-patient dashboard: EDSS on a wide top panel and the eight functional
# system scores on a 2x4 grid below, all sharing one time axis. Unlike the
# earlier "Dashboard" cell, every visit date is kept on the axis and missing
# scores stay NaN, so the lines show gaps instead of skipping visits.
#
# NOTE: this cell was committed wrapped in unresolved merge-conflict markers,
# which made the whole file a SyntaxError; the markers are removed and the
# stashed version of the cell is kept.
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from matplotlib.gridspec import GridSpec


def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Parse scores written as ``1.5`` or ``1,5``; non-numeric text becomes NaN."""
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")


# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Strip the leading 'result.' from column names and replace spaces.
# Only the prefix is removed, so a later 'result.' inside a name is preserved.
result_prefix = 'result.'
column_mapping = {
    col: col[len(result_prefix):].replace(' ', '_')
    for col in df.columns
    if col.startswith(result_prefix)
}
df = df.rename(columns=column_mapping)

# Parse MedDatum safely: unparseable dates become NaT instead of raising
df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce')

# Patient
patient_id = '3d942c60'

patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy()
if patient_data.empty:
    raise ValueError(f"No data found for patient: {patient_id}")

# Functional systems + EDSS: (column name, panel title)
edss_col, edss_title = ('GT.EDSS', 'EDSS')

functional_systems = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'),
    ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'),
    ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'),
    ('GT.SENSORY_FUNCTIONS', 'Sensory'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'),
    ('GT.AMBULATION', 'Ambulation'),
    ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'),
]

# y-axis max rules; columns not listed here fall back to default_ymax
ymax_by_col = {
    'GT.PYRAMIDAL_FUNCTIONS': 6,
    'GT.SENSORY_FUNCTIONS': 6,
    'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6,
    'GT.VISUAL_OPTIC_FUNCTIONS': 6,
    'GT.CEREBELLAR_FUNCTIONS': 5,
    'GT.CEREBRAL_FUNCTIONS': 5,
    'GT.BRAINSTEM_FUNCTIONS': 5,
    'GT.EDSS': 10,
}
default_ymax = 6

# ---------- Base timeline: one row per distinct, valid visit date ----------
# Built from ALL patient visits, not only those with a numeric score.
timeline = (
    patient_data[['MedDatum']]
    .dropna()
    .drop_duplicates()
    .sort_values('MedDatum')
    .rename(columns={'MedDatum': 'x'})
)

# ---------- Shared x ticks, taken from the same timeline ----------
event_dates = timeline['x'].tolist()

max_ticks = 8
if len(event_dates) > max_ticks:
    # Thin to max_ticks evenly spaced visits, always keeping first and last.
    idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int)
    event_dates = [event_dates[i] for i in idx]

# ---------- A4 figure ----------
fig = plt.figure(figsize=(11.69, 8.27))
gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35)


def style_time_axis(ax, show_labels=True):
    """Put ticks on the shared visit dates (YYYY-MM); optionally hide their labels."""
    ax.set_xticks(event_dates)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2)
    if not show_labels:
        ax.tick_params(labelbottom=False)


def get_plot_df(patient_data, col):
    """
    Return one row per timeline date with columns ``x`` (date) and ``y`` (score).

    Keep all visit dates.
    Missing values stay NaN so matplotlib draws gaps instead of zeros.
    """
    tmp = patient_data[['MedDatum', col]].copy()
    tmp = tmp.rename(columns={'MedDatum': 'x', col: 'raw_y'})
    tmp['y'] = to_numeric_comma(tmp['raw_y'])

    # aggregate if multiple rows exist on same date (highest score wins)
    tmp = tmp.groupby('x', as_index=False)['y'].max()

    # merge onto full timeline so all dates remain visible
    return timeline.merge(tmp, on='x', how='left').sort_values('x')


# ---------- EDSS main plot ----------
ax_main = fig.add_subplot(gs[0, :])
ax_main.set_title(edss_title, fontsize=14, fontweight='bold')
ax_main.set_ylabel("Score")
ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax))
ax_main.grid(True, alpha=0.3)

if edss_col in patient_data.columns:
    plot_df = get_plot_df(patient_data, edss_col)

    if plot_df['y'].notna().any():
        # NaNs create visible gaps in the line
        ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red')
    else:
        ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold')
else:
    ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold')

style_time_axis(ax_main)

# ---------- Small aligned plots ----------
for k, (col, title) in enumerate(functional_systems):
    grid_row, grid_col = divmod(k, 4)
    ax = fig.add_subplot(gs[1 + grid_row, grid_col], sharex=ax_main)

    ax.set_title(title, fontsize=10)
    ax.set_ylabel("Score")
    ax.set_ylim(0, ymax_by_col.get(col, default_ymax))
    ax.grid(True, alpha=0.3)

    if col in patient_data.columns:
        plot_df = get_plot_df(patient_data, col)

        if plot_df['y'].notna().any():
            # NaNs remain in y -> line breaks where data is missing
            ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue')
        else:
            ax.set_title(f"{title} (no numeric data)", fontsize=10)
    else:
        ax.set_title(f"{title} (missing)", fontsize=10)

    # Only the bottom row of small plots shows date labels.
    style_time_axis(ax, show_labels=(grid_row > 0))

plt.tight_layout()
fig.subplots_adjust(hspace=0.7)
plt.show()
##
patient_data.empty: + print(f"No data found for patient: {patient_names[0]}") + axes[i].set_title(f'Functional System: {system} (No data)') + axes[i].set_ylabel('Score') + continue + + # Check if the system column exists + if system in patient_data.columns: + # Plot only valid data (non-null values) + valid_data = patient_data.dropna(subset=[system]) + + if not valid_data.empty: + # Ensure MedDatum is properly formatted for plotting + axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system) + axes[i].set_ylabel('Score') + axes[i].set_title(f'Functional System: {system}') + axes[i].grid(True, alpha=0.3) + axes[i].legend() + else: + axes[i].set_title(f'Functional System: {system} (No valid data)') + axes[i].set_ylabel('Score') + else: + # Try to find similar column names + found_column = None + for col in df.columns: + if system.lower() in col.lower(): + found_column = col + break + + if found_column: + valid_data = patient_data.dropna(subset=[found_column]) + if not valid_data.empty: + axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column) + axes[i].set_ylabel('Score') + axes[i].set_title(f'Functional System: {system} (found as: {found_column})') + axes[i].grid(True, alpha=0.3) + axes[i].legend() + else: + axes[i].set_title(f'Functional System: {system} (No valid data)') + axes[i].set_ylabel('Score') + else: + axes[i].set_title(f'Functional System: {system} (Column not found)') + axes[i].set_ylabel('Score') + +# Hide empty subplots +for i in range(len(functional_systems), len(axes)): + axes[i].set_visible(False) + +# Set x-axis label for the last row only +for i in range(len(functional_systems)): + if i >= len(axes) - num_cols: # Last row + axes[i].set_xlabel('Date') + +# Format x-axis dates +for ax in axes: + if ax.get_lines(): # Only format if there are lines to plot + ax.tick_params(axis='x', rotation=45) + 
def create_visit_frequency_plot(
    file_path,
    output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data',
    output_filename='visit_frequency_distribution.svg',
    fontsize=10,
    color_scheme_path='colors.json'
):
    """
    Creates a publication-ready bar chart of patient visit frequency.

    The input table must provide the columns 'Visits Count' (bar position) and
    'Unique Patients' (bar height). The figure is written as SVG and left
    open, so a caller may still display it with ``plt.show()``.

    Args:
        file_path (str): Path to the input TSV file.
        output_dir (str): Directory to save the output SVG file (created if missing).
        output_filename (str): Name of the output SVG file.
        fontsize (int): Font size for all text elements (labels, title).
        color_scheme_path (str): Path to the JSON file containing the color palette;
            the bar color is read from ``["sequential"]["blues"][-2]``.

    Returns:
        None. If ``file_path`` does not exist, an error is printed and nothing
        is plotted or saved.

    Raises:
        KeyError: If the input table lacks one of the required columns.
    """
    default_bar_color = '#2171b5'  # A common matplotlib blue

    # --- 1. Load Data and Color Scheme ---
    try:
        df = pd.read_csv(file_path, sep='\t')
    except FileNotFoundError:
        print(f"Error: The file was not found at {file_path}")
        return
    print("Data loaded successfully.")
    # Sort data for easier visual comparison
    df = df.sort_values(by='Visits Count')

    try:
        with open(color_scheme_path, 'r', encoding='utf-8') as f:
            colors = json.load(f)
        # Select a blue from the sequential palette for the bars
        bar_color = colors['sequential']['blues'][-2]
    except FileNotFoundError:
        print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.")
        bar_color = default_bar_color
    except (json.JSONDecodeError, KeyError, IndexError, TypeError) as exc:
        # A palette file that exists but is malformed or lacks the expected
        # keys should not abort the plot when a default color is available.
        print(f"Warning: Could not read a bar color from {color_scheme_path} ({exc!r}). Using default blue.")
        bar_color = default_bar_color

    # --- 2. Set up the Plot with Scientific Style ---
    # Set the font to Arial (note: this changes the global rcParams)
    plt.rcParams['font.family'] = 'Arial'
    plt.rcParams['font.size'] = fontsize
    arial_font = fm.FontProperties(family='Arial', size=fontsize)

    # NOTE(review): matplotlib interprets figsize in inches, so this is
    # 7.94 x 6 in. The original comment called it a 7.94 cm single-column
    # width; confirm the intended physical size before changing the numbers.
    fig, ax = plt.subplots(figsize=(7.94, 6))

    # --- 3. Create the Bar Chart ---
    bars = ax.bar(
        x=df['Visits Count'],
        height=df['Unique Patients'],
        color=bar_color,
        edgecolor='black',
        linewidth=0.5,  # Minimum line thickness
        width=0.7
    )

    # Explicitly place one tick (and label) under every bar so none is skipped
    visit_counts = df['Visits Count'].unique()
    ax.set_xticks(visit_counts)
    ax.set_xticklabels(visit_counts, fontproperties=arial_font)

    # --- 4. Customize Axes and Layout (Nature style) ---
    # Display only left and bottom axes
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Turn off axis ticks (the marks, not the labels)
    ax.tick_params(axis='both', which='both', length=0)

    # Remove grid lines
    ax.grid(False)

    # Set background to white (no shading)
    ax.set_facecolor('white')
    fig.set_facecolor('white')

    # --- 5. Add Labels and Title ---
    ax.set_xlabel('Number of Visits', fontproperties=arial_font, labelpad=10)
    ax.set_ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10)
    ax.set_title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20)

    # --- 6. Add y-axis values on top of each bar ---
    # This adds the count of unique patients directly above each bar.
    ax.bar_label(bars, fmt='%d', padding=3)

    # --- 7. Export the Figure ---
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    full_output_path = os.path.join(output_dir, output_filename)
    fig.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight')
    print(f"\nFigure saved as '{full_output_path}'")

    # --- 8. (Optional) Display the Plot ---
    # plt.show()
Define Functional Systems and Create Color Mapping --- +# List of tuples containing (ground_truth_column, result_column) +functional_systems_to_plot = [ + ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), + ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), + ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), + ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), + ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), + ('GT.AMBULATION', 'result.AMBULATION'), + ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), + ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') +] + +# Extract system names for color mapping and legend +system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] + +# Define a professional color palette (dark blue theme) +# This is a qualitative palette with distinct, accessible colors +colors = [ + '#003366', # Dark Blue + '#336699', # Medium Blue + '#6699CC', # Light Blue + '#99CCFF', # Very Light Blue + '#FF9966', # Coral + '#FF6666', # Light Red + '#CC6699', # Magenta + '#9966CC' # Purple +] + +# Create a dictionary mapping system names to colors +color_map = dict(zip(system_names, colors)) + +# Ensure the directory for the JSON file exists +os.makedirs(os.path.dirname(color_json_path), exist_ok=True) + +# Save the color map to a JSON file +with open(color_json_path, 'w') as f: + json.dump(color_map, f, indent=4) + +print(f"Color mapping saved to {color_json_path}") + +# --- 3. 
Calculate Agreement Percentages and Format Legend Labels --- +agreement_percentages = {} +legend_labels = {} + +for gt_col, res_col in functional_systems_to_plot: + system_name = gt_col.split('.')[1] + + # Convert columns to numeric, setting errors to NaN + gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') + res_numeric = pd.to_numeric(df[res_col], errors='coerce') + + # Ensure we are comparing the same rows + common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index) + gt_data = gt_numeric.loc[common_index] + res_data = res_numeric.loc[common_index] + + # Calculate agreement percentage + if len(gt_data) > 0: + agreement = (gt_data == res_data).mean() * 100 + else: + agreement = 0 # Handle case with no valid data + + agreement_percentages[system_name] = agreement + + # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions") + formatted_name = " ".join(word.capitalize() for word in system_name.split('_')) + legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)" + +# --- 4. Reshape Data for Plotting --- +plot_data = [] +for gt_col, res_col in functional_systems_to_plot: + system_name = gt_col.split('.')[1] + + # Convert columns to numeric, setting errors to NaN + gt_numeric = pd.to_numeric(df[gt_col], errors='coerce') + res_numeric = pd.to_numeric(df[res_col], errors='coerce') + + # Create a temporary DataFrame with the numeric data + temp_df = pd.DataFrame({ + 'system': system_name, + 'ground_truth': gt_numeric, + 'inference': res_numeric + }) + + # Drop rows where either value is NaN, as they cannot be plotted + temp_df = temp_df.dropna() + + plot_data.append(temp_df) + +# Concatenate all the temporary DataFrames into one +plot_df = pd.concat(plot_data, ignore_index=True) + +if plot_df.empty: + print("Warning: No valid numeric data to plot after conversion. The plot will be blank.") +else: + print(f"Prepared plot data with {len(plot_df)} data points.") + +# --- 5. 
Create the Scatter Plot --- +plt.figure(figsize=(10, 8)) + +# Plot each functional system with its assigned color and formatted legend label +for system, group in plot_df.groupby('system'): + plt.scatter( + group['ground_truth'], + group['inference'], + label=legend_labels[system], + color=color_map[system], + alpha=0.7, + s=30 + ) + +# Add a diagonal line representing perfect agreement (y = x) +# This line helps visualize how close the predictions are to the ground truth +if not plot_df.empty: + plt.plot( + [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], + [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()], + color='black', + linestyle='--', + linewidth=0.8, + alpha=0.7 + ) + +# --- 6. Apply Styling and Labels --- +plt.xlabel('Ground Truth', fontsize=12) +plt.ylabel('LLM Inference', fontsize=12) +plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14) + +# Apply scientific visualization styling rules +ax = plt.gca() +ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.tick_params(axis='both', which='both', length=0) # Remove ticks +ax.grid(False) # Remove grid lines +plt.legend(title='Functional System', frameon=False, fontsize=10) + +# --- 7. Save and Display the Figure --- +# Ensure the directory for the figure exists +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) + +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') +print(f"Figure successfully saved to {figure_save_path}") + +# Display the plot +plt.show() +## + + + + +# %% Confusion Matrix functional systems + +import pandas as pd +import matplotlib.pyplot as plt +import json +import os +import numpy as np +import matplotlib.colors as mcolors + +# --- Configuration --- +plt.rcParams['font.family'] = 'Arial' +data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' +figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg' + +# --- 1. 
# --- 3. Categorization Function ---
# Ten unit-wide score bins; category_to_index gives each label's row/column
# position in the confusion matrix.
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
category_to_index = {cat: i for i, cat in enumerate(categories)}
n_categories = len(categories)


def categorize_edss(value):
    """Map a score to its bin label from ``categories``.

    Bins are right-closed: a value ``v`` with ``k < v <= k + 1`` gets the
    label ``'k-(k+1)'``, and 0 joins the first bin. So an integer score of 1
    is labelled '0-1', 2 is '1-2', and so on. Values outside [0, 10] are
    clamped into the first or last bin.

    The bin is computed with an exact ceiling rather than by subtracting a
    small epsilon before truncating, so values just above a boundary
    (e.g. 1.0005) land in the upper bin as the interval definition requires.

    Args:
        value: A number or numeric string; missing values (per ``pd.isna``)
            are passed through.

    Returns:
        The bin label, or ``np.nan`` when *value* is missing.

    Raises:
        ValueError: If *value* is a string that ``float()`` cannot parse.
    """
    if pd.isna(value):
        return np.nan
    # Ensure value is float so string input does not fail in the comparisons
    val = float(value)
    clamped = min(max(val, 0.0), 10.0)
    # ceil(v) - 1 is the right-closed bin index; max(..., 0) puts exactly 0
    # (ceil == 0) into the first bin instead of index -1.
    idx = max(int(np.ceil(clamped)) - 1, 0)
    return categories[idx]
# --- 4. Prepare Mixed Color Matrix with Saturation ---
# cell_system_counts[g, r, s] counts the rows of functional system s whose
# ground-truth score fell in bin g and whose result score fell in bin r.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))

for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Coerce both columns to numbers (avoids string comparisons) and keep
    # only the rows where both values parsed.
    pair = df[[gt_col, res_col]].copy()
    for column in (gt_col, res_col):
        pair[column] = pd.to_numeric(pair[column], errors='coerce')
    pair = pair.dropna()

    for gt_value, res_value in zip(pair[gt_col], pair[res_col]):
        gt_cat = categorize_edss(gt_value)
        res_cat = categorize_edss(res_value)
        if gt_cat in category_to_index and res_cat in category_to_index:
            cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1

# RGBA image matrix (n_categories x n_categories x 4)
rgba_matrix = np.zeros((n_categories, n_categories, 4))

total_counts = cell_system_counts.sum(axis=2)
peak_count = total_counts.max()
max_count = peak_count if peak_count > 0 else 1  # avoid dividing by zero below

# One RGB triple per system, in system_names order
system_rgbs = [np.array(mcolors.to_rgb(color_map[s_name])) for s_name in system_names]

for i, j in np.ndindex(n_categories, n_categories):
    count_sum = total_counts[i, j]
    if not count_sum > 0:
        # Empty cells are white and fully transparent
        rgba_matrix[i, j] = [1, 1, 1, 0]
        continue

    # Cell color = system colors averaged by each system's share of the cell
    mixed_rgb = np.zeros(3)
    for s_idx, system_rgb in enumerate(system_rgbs):
        mixed_rgb += system_rgb * (cell_system_counts[i, j, s_idx] / count_sum)
    rgba_matrix[i, j, :3] = mixed_rgb

    # Alpha channel (saturation effect): square-root scale keeps low counts
    # visible but lighter; floor of 0.1 so a populated cell is never invisible.
    rgba_matrix[i, j, 3] = max(0.1, np.sqrt(count_sum / max_count))
Plotting --- +fig, ax = plt.subplots(figsize=(12, 10)) + +# Show the matrix +# Note: we use origin='lower' if you want 0-1 at the bottom, +# but confusion matrices usually have 0-1 at the top (origin='upper') +im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper') + +# Add count labels +for i in range(n_categories): + for j in range(n_categories): + if total_counts[i, j] > 0: + # Background brightness for text contrast + bg_color = rgba_matrix[i, j, :3] + lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2] + # If alpha is low, background is effectively white, so use black text + text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black" + ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", + color=text_col, fontsize=10, fontweight='bold') + +# --- 6. Styling --- +ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10) +ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10) +ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20) + +ax.set_xticks(np.arange(n_categories)) +ax.set_xticklabels(categories) +ax.set_yticks(np.arange(n_categories)) +ax.set_yticklabels(categories) + +# Remove the frame/spines for a cleaner look +for spine in ax.spines.values(): + spine.set_visible(False) + +# Custom Legend +handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names] +labels = [name.replace('_', ' ').capitalize() for name in system_names] +ax.legend(handles, labels, title='Functional Systems', loc='upper left', + bbox_to_anchor=(1.05, 1), frameon=False) + +plt.tight_layout() +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') +plt.show() + +## + + + + + + + + + + +# %% Difference Plot Functional system + +import pandas as pd +import matplotlib.pyplot as plt +import json +import os +import numpy as np + +# --- Configuration --- 
+# Set the font to Arial for all text in the plot, as per the guidelines +plt.rcParams['font.family'] = 'Arial' + +# Define the path to your data file +data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' + +# Define the path to save the color mapping JSON file +color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' + +# Define the path to save the final figure +figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg' + +# --- 1. Load the Dataset --- +try: + # Load the TSV file + df = pd.read_csv(data_path, sep='\t') + print(f"Successfully loaded data from {data_path}") + print(f"Data shape: {df.shape}") +except FileNotFoundError: + print(f"Error: The file at {data_path} was not found.") + # Exit or handle the error appropriately + raise + +# --- 2. Define Functional Systems and Create Color Mapping --- +# List of tuples containing (ground_truth_column, result_column) +functional_systems_to_plot = [ + ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'), + ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'), + ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'), + ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'), + ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'), + ('GT.AMBULATION', 'result.AMBULATION'), + ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'), + ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS') +] + +# Extract system names for color mapping and legend +system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] + +# Define a professional color palette (dark blue theme) +# This is a qualitative palette with distinct, accessible colors +colors = [ + '#003366', # Dark Blue + '#336699', # Medium Blue + '#6699CC', # Light Blue + '#99CCFF', # Very Light Blue + '#FF9966', # Coral + '#FF6666', # Light Red + '#CC6699', # Magenta + '#9966CC' # Purple +] + +# Create a dictionary mapping system names to 
# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))

# Ensure the directory for the JSON file exists
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)

# Save the color map to a JSON file
with open(color_json_path, 'w', encoding='utf-8') as f:
    json.dump(color_map, f, indent=4)

print(f"Color mapping saved to {color_json_path}")

# --- 3. Calculate Agreement Percentages and Format Legend Labels ---


def _parse_scores(column):
    """Return *column* as floats, accepting decimal commas; unparseable -> NaN.

    The error data built later in this cell is parsed with comma-decimal
    handling. The agreement figures use the same rule here so that both are
    computed over the same set of rows.
    """
    as_text = column.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(as_text, errors='coerce')


agreement_percentages = {}
legend_labels = {}

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    gt_numeric = _parse_scores(df[gt_col])
    res_numeric = _parse_scores(df[res_col])

    # Compare only rows where both ground truth and result are numeric
    both_present = gt_numeric.notna() & res_numeric.notna()
    gt_data = gt_numeric[both_present]
    res_data = res_numeric[both_present]

    # Share of exactly matching scores, in percent
    if len(gt_data) > 0:
        agreement = (gt_data == res_data).mean() * 100
    else:
        agreement = 0  # Handle case with no valid data

    agreement_percentages[system_name] = agreement

    # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
Robustly Prepare Error Data for Boxplot --- + +def safe_parse(s): + '''Convert to float, handling comma decimals (e.g., '3,5' → 3.5)''' + if pd.isna(s): + return np.nan + if isinstance(s, (int, float)): + return float(s) + # Replace comma with dot, then strip whitespace + s_clean = str(s).replace(',', '.').strip() + try: + return float(s_clean) + except ValueError: + return np.nan + +plot_data = [] +for gt_col, res_col in functional_systems_to_plot: + system_name = gt_col.split('.')[1] + + # Parse both columns with robust comma handling + gt_numeric = df[gt_col].apply(safe_parse) + res_numeric = df[res_col].apply(safe_parse) + + # Compute error (only where both are finite) + error = res_numeric - gt_numeric + + # Create temp DataFrame + temp_df = pd.DataFrame({ + 'system': system_name, + 'error': error + }).dropna() # drop rows where either was unparseable + + plot_data.append(temp_df) + +plot_df = pd.concat(plot_data, ignore_index=True) + +if plot_df.empty: + print("⚠️ Warning: No valid numeric error data to plot after robust parsing.") +else: + print(f"✅ Prepared error data with {len(plot_df)} data points.") + # Diagnostic: show a few samples + print("\n📌 Sample errors by system:") + for sys, grp in plot_df.groupby('system'): + print(f" {sys:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}") + +# Ensure categorical ordering +plot_df['system'] = pd.Categorical( + plot_df['system'], + categories=[name.split('.')[1] for name, _ in functional_systems_to_plot], + ordered=True +) + + +# --- 5. 
Prepare Data for Diverging Stacked Bar Plot --- +print("\n📊 Preparing diverging stacked bar plot data...") + +# Define bins for error direction +def categorize_error(err): + if pd.isna(err): + return 'missing' + elif err < 0: + return 'underestimate' + elif err > 0: + return 'overestimate' + else: + return 'match' + +# Add category column (only on finite errors) +plot_df_clean = plot_df[plot_df['error'].notna()].copy() +plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error) + +# Count by system + category +category_counts = ( + plot_df_clean + .groupby(['system', 'category']) + .size() + .unstack(fill_value=0) + .reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0) +) +# Reorder systems +category_counts = category_counts.reindex(system_names) + +# Prepare for diverging plot: +# - Underestimates: plotted to the *left* (negative x) +# - Overestimates: plotted to the *right* (positive x) +# - Matches: centered (no width needed, or as a bar of width 0.2) +underestimate_counts = category_counts['underestimate'] +match_counts = category_counts['match'] +overestimate_counts = category_counts['overestimate'] + +# For diverging: left = -underestimate, right = overestimate +left_counts = underestimate_counts +right_counts = overestimate_counts + +# Compute max absolute bar height (for symmetric x-axis) +max_bar = max(left_counts.max(), right_counts.max(), 1) +plot_range = (-max_bar, max_bar) + +# X-axis positions: 0 = center, left systems to -1, -2, ..., right systems to +1, +2, ... +n_systems = len(system_names) +positions = np.arange(n_systems) +left_positions = -positions - 0.5 # left-aligned underestimates +right_positions = positions + 0.5 # right-aligned overestimates + +# --- 6. 
Create Diverging Stacked Bar Plot --- +plt.figure(figsize=(12, 7)) + +# Colors: diverging palette +colors = { + 'underestimate': '#E74C3C', # Red (left) + 'match': '#2ECC71', # Green (center) + 'overestimate': '#F39C12' # Orange (right) +} + +# Plot underestimates (left side) +bars_left = plt.barh( + left_positions, + left_counts.values, + height=0.8, + left=0, # starts at 0, extends left (since bars are negative width would be wrong; instead use negative values) + color=colors['underestimate'], + edgecolor='black', + linewidth=0.5, + alpha=0.9, + label='Underestimate' +) + +# Plot overestimates (right side) +bars_right = plt.barh( + right_positions, + right_counts.values, + height=0.8, + left=0, + color=colors['overestimate'], + edgecolor='black', + linewidth=0.5, + alpha=0.9, + label='Overestimate' +) + +# Plot matches (center — narrow bar) +# Use a very narrow width (0.2) centered at 0 +plt.barh( + positions, + match_counts.values, + height=0.2, + left=0, # starts at 0, extends right + color=colors['match'], + edgecolor='black', + linewidth=0.5, + alpha=0.9, + label='Exact Match' +) + +# ✨ Better: flip match to be centered symmetrically (left=-match/2, width=match) +# For perfect symmetry: +for i, count in enumerate(match_counts.values): + if count > 0: + plt.barh( + positions[i], + width=count, + left=-count/2, + height=0.25, + color=colors['match'], + edgecolor='black', + linewidth=0.5, + alpha=0.95 + ) + +# --- 7. 
Styling & Labels --- +# Zero reference line +plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8) + +# X-axis: symmetric around 0 +plt.xlim(plot_range[0] - max_bar*0.1, plot_range[1] + max_bar*0.1) +plt.xticks(rotation=0, fontsize=10) +plt.xlabel('Count', fontsize=12) + +# Y-axis: system names at original positions (centered) +plt.yticks(positions, [name.replace('_', '\n').replace('and', '&') for name in system_names], fontsize=10) +plt.ylabel('Functional System', fontsize=12) + +# Title & layout +plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15) + +# Clean axes +ax = plt.gca() +ax.spines['top'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['left'].set_visible(False) # We only need bottom axis +ax.xaxis.set_ticks_position('bottom') +ax.yaxis.set_ticks_position('none') + +# Grid only along x +ax.xaxis.grid(True, linestyle=':', alpha=0.5) + +# Legend +from matplotlib.patches import Patch +legend_elements = [ + Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'), + Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'), + Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate') +] +plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10) + +# Optional: Add counts on bars +for i, (left, right, match) in enumerate(zip(left_counts, right_counts, match_counts)): + if left > 0: + plt.text(-left - max_bar*0.05, left_positions[i], str(left), va='center', ha='right', fontsize=9, color='white', fontweight='bold') + if right > 0: + plt.text(right + max_bar*0.05, right_positions[i], str(right), va='center', ha='left', fontsize=9, color='white', fontweight='bold') + if match > 0: + plt.text(match_counts[i]/2, positions[i], str(match), va='center', ha='center', fontsize=8, color='black') + +plt.tight_layout() + +# --- 8. 
Save & Show --- +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') +print(f"✅ Diverging bar plot saved to {figure_save_path}") + +plt.show() + +## + + + +# %% Difference Plot Gemini +import pandas as pd +import matplotlib.pyplot as plt +import os +import numpy as np + +# --- Configuration & Theme --- +plt.rcParams['font.family'] = 'Arial' +figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg' + +# --- 1. Process Error Data with Magnitude Breakdown --- +system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot] +plot_list = [] + +for gt_col, res_col in functional_systems_to_plot: + sys_name = gt_col.split('.')[1] + + # Robust parsing + gt = df[gt_col].apply(safe_parse) + res = df[res_col].apply(safe_parse) + error = res - gt + + # Granular Counts + matches = (error == 0).sum() + u_1 = (error == -1).sum() + u_2plus = (error <= -2).sum() + o_1 = (error == 1).sum() + o_2plus = (error >= 2).sum() + + total = error.dropna().count() + divisor = max(total, 1) + + plot_list.append({ + 'System': sys_name.replace('_', ' ').title(), + 'Matches': matches, 'MatchPct': (matches / divisor) * 100, + 'U1': u_1, 'U2': u_2plus, 'UnderTotal': u_1 + u_2plus, + 'UnderPct': ((u_1 + u_2plus) / divisor) * 100, + 'O1': o_1, 'O2': o_2plus, 'OverTotal': o_1 + o_2plus, + 'OverPct': ((o_1 + o_2plus) / divisor) * 100 + }) + +stats_df = pd.DataFrame(plot_list) + +# --- 2. 
Plotting --- +fig, ax = plt.subplots(figsize=(13, 8)) + +# Define Magnitude Colors +c_under_dark, c_under_light = '#C0392B', '#E74C3C' # Dark Red (-2+), Soft Red (-1) +c_over_dark, c_over_light = '#2980B9', '#3498DB' # Dark Blue (+2+), Soft Blue (+1) +bar_height = 0.6 +y_pos = np.arange(len(stats_df)) + +# Plot Under-scored (Stacked: -2+ then -1) +ax.barh(y_pos, -stats_df['U2'], bar_height, color=c_under_dark, label='Under -2+', edgecolor='white') +ax.barh(y_pos, -stats_df['U1'], bar_height, left=-stats_df['U2'], color=c_under_light, label='Under -1', edgecolor='white') + +# Plot Over-scored (Stacked: +1 then +2+) +ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white') +ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white') + +# --- 3. Aesthetics & Table Labels --- +for i, row in stats_df.iterrows(): + label_text = ( + f"$\\mathbf{{{row['System']}}}$\n" + f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n" + f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)" + ) + # Position table text to the left + ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.4) + +# Formatting +ax.axvline(0, color='black', linewidth=1.2) +ax.set_yticks([]) +ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold') +#ax.set_title('Directional Error Magnitude (Under vs. 
Over Scoring)', fontsize=14, pad=35) + +# Absolute X-axis labels +ax.set_xticklabels([int(abs(tick)) for tick in ax.get_xticks()]) + +# Remove spines and add grid +for spine in ['top', 'right', 'left']: ax.spines[spine].set_visible(False) +ax.xaxis.grid(True, linestyle='--', alpha=0.3) + +# Legend with magnitude info +ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2) + +plt.tight_layout() +plt.show() +## + + +# %% Functional System Error Boxplots +import pandas as pd +import matplotlib.pyplot as plt +import os +import numpy as np +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +# --- Configuration & Theme --- +plt.rcParams['font.family'] = 'Arial' +figure_save_path = 'project/visuals/functional_systems_boxplot.svg' + +# --- 1. Build error data for boxplots --- +boxplot_data = [] +system_labels = [] +sample_sizes = [] + +for gt_col, res_col in functional_systems_to_plot: + sys_name = gt_col.split('.')[1] + + # Robust parsing + gt = df[gt_col].apply(safe_parse) + res = df[res_col].apply(safe_parse) + + # Error = result - ground truth + error = (res - gt).dropna() + + # Ignore all 0 errors + error = error[error != 0] + + # Keep only systems that actually have non-zero data + if len(error) > 0: + clean_name = sys_name.replace('_', ' ').title() + boxplot_data.append(error.values) + system_labels.append(clean_name) + sample_sizes.append(len(error)) + +# Safety check +if not boxplot_data: + raise ValueError("No valid non-zero error data available for any functional system.") + +# Put n into x-axis labels so it doesn't overlap the plot +xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)] + +# --- 2. Plotting --- +fig, ax = plt.subplots(figsize=(14, 8)) + +bp = ax.boxplot( + boxplot_data, + vert=True, + patch_artist=True, + labels=xtick_labels, + showmeans=True, + meanline=False +) + +# --- 3. 
Styling --- +box_face = '#D6EAF8' +box_edge = '#2980B9' +whisker_col = '#7F8C8D' +median_col = '#C0392B' +mean_col = '#1ABC9C' +flier_face = '#95A5A6' +flier_edge = '#7F8C8D' + +for box in bp['boxes']: + box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5) + +for whisker in bp['whiskers']: + whisker.set(color=whisker_col, linewidth=1.2) + +for cap in bp['caps']: + cap.set(color=whisker_col, linewidth=1.2) + +for median in bp['medians']: + median.set(color=median_col, linewidth=2) + +for mean in bp['means']: + mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6) + +for flier in bp['fliers']: + flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4) + +# Reference line at zero error +ax.axhline(0, color='black', linewidth=1.2, linestyle='--') + +# Labels and formatting +ax.set_xlabel('Functional System', fontsize=11, fontweight='bold') +ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold') + +# Rotate x labels for readability +plt.xticks(rotation=45, ha='right') + +# Grid and spines +ax.yaxis.grid(True, linestyle='--', alpha=0.3) +for spine in ['top', 'right']: + ax.spines[spine].set_visible(False) + +# --- 4. 
Legend above the plot, outside the axes --- +legend_handles = [ + Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'), + Line2D([0], [0], color=median_col, lw=2, label='Median'), + Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col, + markeredgecolor='black', markersize=7, label='Mean'), + Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face, + markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'), + Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference') +] + +ax.legend( + handles=legend_handles, + loc='lower center', + bbox_to_anchor=(0.5, 1.02), + ncol=3, + frameon=False +) + +# Leave room at the top for the legend +plt.tight_layout(rect=[0, 0, 1, 0.90]) + +# Optional save +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') + +plt.show() +## + +<<<<<<< Updated upstream +======= +# %% Functional System + EDSS Error Boxplots +import pandas as pd +import matplotlib.pyplot as plt +import os +import numpy as np +from matplotlib.patches import Patch +from matplotlib.lines import Line2D + +# --- Configuration & Theme --- +plt.rcParams['font.family'] = 'Arial' +figure_save_path = 'project/visuals/functional_systems_edss_boxplot.svg' + +# ------------------------------------------------------------ +# Expect functional_systems_to_plot like: +# [ +# ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL_OPTIC_FUNCTIONS'), +# ... +# ] +# +# Add EDSS here: +# ------------------------------------------------------------ +all_systems_to_plot = list(functional_systems_to_plot) + [ + ('GT.EDSS', 'result.EDSS') +] + +# --- 1. 
Build error data for boxplots --- +boxplot_data = [] +system_labels = [] +sample_sizes = [] + +for gt_col, res_col in all_systems_to_plot: + # Skip safely if a column is missing + if gt_col not in df.columns or res_col not in df.columns: + print(f"Skipping missing columns: {gt_col}, {res_col}") + continue + + sys_name = gt_col.split('.')[1] + + # Robust parsing + gt = df[gt_col].apply(safe_parse) + res = df[res_col].apply(safe_parse) + + # Error = result - ground truth + error = (res - gt).dropna() + + # Ignore all 0 errors + error = error[error != 0] + + # Keep only systems that actually have non-zero data + if len(error) > 0: + if sys_name == 'EDSS': + clean_name = 'EDSS' + else: + clean_name = sys_name.replace('_', ' ').title() + + boxplot_data.append(error.values) + system_labels.append(clean_name) + sample_sizes.append(len(error)) + +# Safety check +if not boxplot_data: + raise ValueError("No valid non-zero error data available for any functional system or EDSS.") + +# Put n into x-axis labels so it doesn't overlap the plot +xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)] + +# --- 2. Plotting --- +fig, ax = plt.subplots(figsize=(15, 8)) + +bp = ax.boxplot( + boxplot_data, + vert=True, + patch_artist=True, + labels=xtick_labels, + showmeans=True, + meanline=False +) + +# --- 3. 
Styling --- +box_face = '#D6EAF8' +box_edge = '#2980B9' +whisker_col = '#7F8C8D' +median_col = '#C0392B' +mean_col = '#1ABC9C' +flier_face = '#95A5A6' +flier_edge = '#7F8C8D' + +for box in bp['boxes']: + box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5) + +for whisker in bp['whiskers']: + whisker.set(color=whisker_col, linewidth=1.2) + +for cap in bp['caps']: + cap.set(color=whisker_col, linewidth=1.2) + +for median in bp['medians']: + median.set(color=median_col, linewidth=2) + +for mean in bp['means']: + mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6) + +for flier in bp['fliers']: + flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4) + +# Reference line at zero error +ax.axhline(0, color='black', linewidth=1.2, linestyle='--') + +# Labels and formatting +ax.set_xlabel('Functional System / EDSS', fontsize=11, fontweight='bold') +ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold') + +# Rotate x labels for readability +plt.xticks(rotation=45, ha='right') + +# Grid and spines +ax.yaxis.grid(True, linestyle='--', alpha=0.3) +for spine in ['top', 'right']: + ax.spines[spine].set_visible(False) + +# --- 4. 
Legend above the plot, outside the axes --- +legend_handles = [ + Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'), + Line2D([0], [0], color=median_col, lw=2, label='Median'), + Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col, + markeredgecolor='black', markersize=7, label='Mean'), + Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face, + markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'), + Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference') +] + +ax.legend( + handles=legend_handles, + loc='lower center', + bbox_to_anchor=(0.5, 1.02), + ncol=3, + frameon=False +) + +# Leave room at the top for the legend +plt.tight_layout(rect=[0, 0, 1, 0.90]) + +# Optional save +os.makedirs(os.path.dirname(figure_save_path), exist_ok=True) +plt.savefig(figure_save_path, format='svg', bbox_inches='tight') + +plt.show() +## +>>>>>>> Stashed changes + +# %% test +# Diagnose: what are the actual differences? 
+print("\n🔍 Raw differences (first 5 rows per system):") +for gt_col, res_col in functional_systems_to_plot: + gt = df[gt_col].apply(safe_parse) + res = df[res_col].apply(safe_parse) + diff = res - gt + non_zero = (diff != 0).sum() + # Check if it's due to floating point noise + abs_diff = diff.abs() + tiny = (abs_diff > 0) & (abs_diff < 1e-10) + print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}") + +## diff --git a/show_plots.py.orig b/show_plots.py.orig new file mode 100644 index 0000000..4332b98 --- /dev/null +++ b/show_plots.py.orig @@ -0,0 +1,2320 @@ +# %% Scatter +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results +df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') +df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') +df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) + +# Create scatter plot +plt.figure(figsize=(8, 6)) +plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue') + +# Add labels and title +plt.xlabel('GT_EDSS') +plt.ylabel('LLM_Results') +plt.title('Comparison of GT_EDSS vs LLM_Results') + +# Optional: Add a diagonal line for reference (perfect prediction) +plt.plot([0, max(df_clean['GT_EDSS'])], [0, max(df_clean['GT_EDSS'])], color='red', linestyle='--', label='Perfect Prediction') +plt.legend() + +# Show plot +plt.grid(True) 
+plt.tight_layout() +plt.show() + +## + + +# %% Bland0-altman + +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +import statsmodels.api as sm + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results +df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.') +df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce') +df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results']) + +# Create Bland-Altman plot +f, ax = plt.subplots(1, figsize=(8, 5)) +sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax) + +# Add labels and title +ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results') +ax.set_xlabel('Mean of GT_EDSS and LLM_Results') +ax.set_ylabel('Difference between GT_EDSS and LLM_Results') + +# Display Bland-Altman plot +plt.tight_layout() +plt.show() + +# Print some statistics +mean_diff = np.mean(df_clean['GT_EDSS'] - df_clean['LLM_Results']) +std_diff = np.std(df_clean['GT_EDSS'] - df_clean['LLM_Results']) +print(f"Mean difference: {mean_diff:.3f}") +print(f"Standard deviation of differences: {std_diff:.3f}") +print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]") + +## + + + +# %% Confusion matrix +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns + +# Load your data from TSV file +file_path = 
'/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') +df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# For confusion matrix, we need to categorize the values +# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) +def categorize_edss(value): + if pd.isna(value): + return np.nan + elif value <= 1.0: + return '0-1' + elif value <= 2.0: + return '1-2' + elif value <= 3.0: + return '2-3' + elif value <= 4.0: + return '3-4' + elif value <= 5.0: + return '4-5' + elif value <= 6.0: + return '5-6' + elif value <= 7.0: + return '6-7' + elif value <= 8.0: + return '7-8' + elif value <= 9.0: + return '8-9' + elif value <= 10.0: + return '9-10' + else: + return '10+' + +# Create categorical versions +df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) +df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) + +# Remove any NaN categories +df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) + +# Create confusion matrix +cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], + labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Plot confusion matrix +plt.figure(figsize=(10, 8)) +sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], + yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) +#plt.title('Confusion Matrix: Ground truth EDSS vs 
interferred EDSS (Categorized 0-10)') +plt.xlabel('LLM Generated EDSS') +plt.ylabel('Ground Truth EDSS') +plt.tight_layout() +plt.show() + +# Print classification report +print("Classification Report:") +print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) + +# Print raw counts +print("\nConfusion Matrix (Raw Counts):") +print(cm) + +## + + +# %% Confusion matrix +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import confusion_matrix, classification_report +import seaborn as sns + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') +df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# For confusion matrix, we need to categorize the values +# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10) +def categorize_edss(value): + if pd.isna(value): + return np.nan + elif value <= 1.0: + return '0-1' + elif value <= 2.0: + return '1-2' + elif value <= 3.0: + return '2-3' + elif value <= 4.0: + return '3-4' + elif value <= 5.0: + return '4-5' + elif value <= 6.0: + return '5-6' + elif value <= 7.0: + return '6-7' + elif value <= 8.0: + return '7-8' + elif value <= 9.0: + return '8-9' + elif value <= 10.0: + return '9-10' + else: + return '10+' + +# Create categorical versions +df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss) +df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss) + +# Remove any NaN 
categories +df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat']) + +# Create confusion matrix +cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], + labels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Plot confusion matrix +plt.figure(figsize=(10, 8)) +ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10'], + yticklabels=['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']) + +# Add legend text above the color bar +# Get the colorbar object +cbar = ax.collections[0].colorbar +# Add text above the colorbar +cbar.set_label('Number of Cases', rotation=270, labelpad=20) + +plt.xlabel('LLM Generated EDSS') +plt.ylabel('Ground Truth EDSS') +#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)') +plt.tight_layout() +plt.show() + +# Print classification report +print("Classification Report:") +print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'])) + +# Print raw counts +print("\nConfusion Matrix (Raw Counts):") +print(cm) + + +## + + + + +# %% Classification +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import confusion_matrix +import numpy as np + +# Load your data from TSV file +file_path ='/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' + +df = pd.read_csv(file_path, sep='\t') + +# Check data structure +print("Data shape:", df.shape) +print("First few rows:") +print(df.head()) +print("\nColumn names:") +for col in df.columns: + print(f" {col}") + +# Function to safely convert to boolean +def safe_bool_convert(series): + '''Safely convert series to boolean, handling various input formats''' + # Convert to string first, then to boolean + series_str = series.astype(str).str.strip().str.lower() + + # Handle different true/false representations + bool_map = { + 'true': 
True, '1': True, 'yes': True, 'y': True, + 'false': False, '0': False, 'no': False, 'n': False + } + + converted = series_str.map(bool_map) + + # Handle remaining NaN values + converted = converted.fillna(False) # or True, depending on your preference + + return converted + +# Convert columns safely +if 'result.klassifizierbar' in df.columns: + print("\nresult.klassifizierbar column info:") + print(df['result.klassifizierbar'].head(10)) + print("Unique values:", df['result.klassifizierbar'].unique()) + + df['result.klassifizierbar'] = safe_bool_convert(df['result.klassifizierbar']) + print("After conversion:") + print(df['result.klassifizierbar'].value_counts()) + +if 'GT.klassifizierbar' in df.columns: + print("\nGT.klassifizierbar column info:") + print(df['GT.klassifizierbar'].head(10)) + print("Unique values:", df['GT.klassifizierbar'].unique()) + + df['GT.klassifizierbar'] = safe_bool_convert(df['GT.klassifizierbar']) + print("After conversion:") + print(df['GT.klassifizierbar'].value_counts()) + +# Create bar chart showing only True values for klassifizierbar +if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: + # Get counts for True values only + llm_true_count = df['result.klassifizierbar'].sum() + gt_true_count = df['GT.klassifizierbar'].sum() + + # Plot using matplotlib directly + fig, ax = plt.subplots(figsize=(8, 6)) + + x = np.arange(2) + width = 0.35 + + bars1 = ax.bar(x[0] - width/2, llm_true_count, width, label='LLM', color='skyblue', alpha=0.8) + bars2 = ax.bar(x[1] + width/2, gt_true_count, width, label='GT', color='lightcoral', alpha=0.8) + + # Add value labels on bars + ax.annotate(f'{llm_true_count}', + xy=(x[0], llm_true_count), + xytext=(0, 3), + textcoords="offset points", + ha='center', va='bottom') + + ax.annotate(f'{gt_true_count}', + xy=(x[1], gt_true_count), + xytext=(0, 3), + textcoords="offset points", + ha='center', va='bottom') + + ax.set_xlabel('Classification Status (klassifizierbar)') + 
ax.set_ylabel('Count') + ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"') + ax.set_xticks(x) + ax.set_xticklabels(['LLM', 'GT']) + ax.legend() + + plt.tight_layout() + plt.show() + +# Create confusion matrix if both columns exist +if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns: + try: + # Ensure both columns are boolean + llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool) + gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool) + + cm = confusion_matrix(gt_bool, llm_bool) + + # Plot confusion matrix + fig, ax = plt.subplots(figsize=(8, 6)) + sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', + xticklabels=['False ', 'True '], + yticklabels=['False', 'True '], + ax=ax) + ax.set_xlabel('LLM Predictions ') + ax.set_ylabel('GT Labels ') + ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"') + + plt.tight_layout() + plt.show() + + print("Confusion Matrix:") + print(cm) + + except Exception as e: + print(f"Error creating confusion matrix: {e}") + +# Show final data info +print("\nFinal DataFrame info:") +print(df[['result.klassifizierbar', 'GT.klassifizierbar']].info()) + +## + + + + +# %% Boxplot +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import numpy as np + +# Load your data from TSV file +file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/results/join_results_unique.tsv' +df = pd.read_csv(file_path, sep='\t') + +# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS +df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.') +df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.') + +# Convert to float (handle invalid entries gracefully) +df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce') +df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce') + +# Drop rows where either column is NaN +df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']) + +# 1. 
# DEFINE CATEGORY ORDER
# This ensures the X-axis is numerically logical (0-1 comes before 1-2)
category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+']

# Work on an explicit copy: df_clean is the result of dropna(), and adding a
# column to such a derived frame triggers pandas' SettingWithCopyWarning.
df_clean = df_clean.copy()

# Convert the column to a Categorical type with the specific order.
# NOTE(review): categorize_edss is not defined in this cell; it is defined in
# the "Confusion Matrix functional systems" cell, which must have been run first.
df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss),
                                         categories=category_order,
                                         ordered=True)

plt.figure(figsize=(14, 8))

# 2. ADD HUE FOR LEGEND
# Assigning x to 'hue' allows Seaborn to generate a legend automatically
box_plot = sns.boxplot(
    data=df_clean,
    x='GT.EDSS_cat',
    y='result.EDSS',
    hue='GT.EDSS_cat',  # Added hue
    palette='viridis',
    linewidth=1.5,
    legend=True  # Ensure legend is enabled
)

# 3. CUSTOMIZE PLOT
plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20)
plt.xlabel('Ground Truth EDSS Category', fontsize=14)
plt.ylabel('LLM Predicted EDSS', fontsize=14)

# Move legend to the side or top
plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.xticks(rotation=45, ha='right', fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()

plt.show()
##


# %% Postproccessing Column names

import pandas as pd

# Read the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Create a mapping dictionary for German to English column names
column_mapping = {
    'EDSS': 'GT.EDSS',
    'klassifizierbar': 'GT.klassifizierbar',
    'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS',
    'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS',
    'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS',
    'Sensibiliät': 'GT.SENSORY_FUNCTIONS',
    'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS',
    'Ambulation': 'GT.AMBULATION',
    'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS',
    'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS'
}

# Rename columns (names missing from the file are silently left alone)
df = df.rename(columns=column_mapping)

# Save the modified dataframe back to
# TSV file
df.to_csv(file_path, sep='\t', index=False)

print("Columns have been successfully renamed!")
print("Renamed columns:")
for old_name, new_name in column_mapping.items():
    # df has already been renamed at this point, so the old name can no longer
    # be present; report the pairs whose new name exists in the frame.
    if new_name in df.columns:
        print(f" {old_name} -> {new_name}")


##


# %% Styled table
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import dataframe_image as dfi
# Load data
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')

# 1. Identify all GT and result columns
gt_columns = [col for col in df.columns if col.startswith('GT.')]
result_columns = [col for col in df.columns if col.startswith('result.')]

print("GT Columns found:", gt_columns)
print("Result Columns found:", result_columns)

# 2. Create proper mapping between GT and result columns
# Handle various naming conventions (spaces, underscores, etc.)
column_mapping = {}

for gt_col in gt_columns:
    base_name = gt_col.replace('GT.', '')

    # Spelling variants of the same name, tried in priority order
    candidates = [
        f'result.{base_name}',                     # Exact match
        f'result.{base_name.replace(" ", "_")}',   # With underscores
        f'result.{base_name.replace("_", " ")}',   # With spaces
        f'result.{base_name.replace(" ", "")}',    # No spaces
        f'result.{base_name.replace("_", "")}',    # No underscores
        f'result.{base_name.lower()}',             # Lower-case
        f'result.{base_name.upper()}',             # Upper-case
    ]

    # First variant that actually exists among the result columns
    match = next((c for c in candidates if c in result_columns), None)

    # Fallback: compare case-insensitively after dropping every character
    # that is not alphanumeric, an underscore or a space
    if match is None:
        base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' '])
        for result_col in result_columns:
            result_base = result_col.replace('result.', '')
            result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' '])
            if base_clean.lower() == result_clean.lower():
                match = result_col
                break

    if match is not None:
        column_mapping[gt_col] = match

print("Column mapping:", column_mapping)

# 3. Faster, vectorized computation using the corrected mapping
data_list = []

for gt_col, result_col in column_mapping.items():
    print(f"Processing {gt_col} vs {result_col}")

    # Convert to numeric, forcing errors to NaN
    s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float)
    s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float)

    # Calculate matches (abs difference <= 0.5); NaN differences compare False
    diff = np.abs(s1 - s2)
    matches = (diff <= 0.5).sum()

    # Denominator: rows where both sides are numeric
    valid_count = diff.notna().sum()

    if valid_count > 0:
        percentage = (matches / valid_count) * 100
    else:
        percentage = 0

    # Extract clean base name for display
    base_name = gt_col.replace('GT.', '')

    data_list.append({
        'GT': base_name,
        'Match %': round(percentage, 1)
    })


# 4. Prepare Data
# Naming the columns keeps the frame well-formed (and the lines below working)
# even when no GT/result pair could be matched and data_list is empty.
match_df = pd.DataFrame(data_list, columns=['GT', 'Match %'])
# Clean up labels: Replace underscores with spaces and capitalize
match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title()
match_df = match_df.sort_values('Match %', ascending=False)

# 5.
# Create a "Beautiful" Table using Seaborn Heatmap
def create_luxury_table(df, output_file="edss_agreement.png"):
    """Render the agreement table as a one-column heatmap and save it.

    Args:
        df: DataFrame with a 'GT' label column and a 'Match %' value column.
        output_file: Path of the image written with ``plt.savefig``.

    Returns:
        None. The figure is left open (it remains the current figure), so the
        caller can save it again in another format or show it.
    """
    if df.empty:
        # A zero-row table would request a zero-height figure.
        print("No agreement data to render; skipping table.")
        return

    # Set the aesthetic style
    sns.set_theme(style="white", font="sans-serif")

    # Prepare data for heatmap
    plot_data = df.set_index('GT')[['Match %']]

    # Height is dynamic based on number of rows
    fig, ax = plt.subplots(figsize=(8, len(df) * 0.6))

    # Custom diverging colour map between hue 15 and hue 135
    cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True)

    # Draw the heatmap on the axes created above (explicit, rather than
    # relying on it being the current axes)
    sns.heatmap(
        plot_data,
        annot=True,
        fmt=".1f",
        cmap=cmap,
        center=85,            # Centers the color transition
        vmin=50, vmax=100,    # Range of the gradient
        linewidths=2,
        linecolor='white',
        cbar=False,           # Remove color bar for a "table" look
        annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"},
        ax=ax,
    )

    # Styling the Axes (Turning the heatmap into a table)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.xaxis.tick_top()  # Move "Match %" label to top
    ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50')
    ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0)

    # Add a thin border around the plot
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('#ecf0f1')

    plt.title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50')

    # Add a subtle footer
    plt.figtext(0.5, 0.0, "Tolerance: ±0.5 points",
                wrap=True, horizontalalignment='center', fontsize=10, color='gray', style='italic')

    # Save with high resolution
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Beautiful table saved as {output_file}")


# Execute
create_luxury_table(match_df)


# Run the function
# NOTE(review): save_styled_table is not defined in this cell. Call it only if
# some other cell has defined it, instead of failing with a NameError.
if 'save_styled_table' in globals():
    save_styled_table(match_df)
# 6.
# Save as SVG
plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
print("Successfully saved agreement_table.svg")

# Show plot if running in a GUI environment
plt.show()
##


# %% Time Plot
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Read the joined results table
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Per-row inference durations, without missing entries
inference_times = df['inference_time_sec'].dropna()

# Summary statistics shown in the legend
mean_time = inference_times.mean()
std_time = inference_times.std()
median_time = np.median(inference_times)

fig, ax = plt.subplots(figsize=(10, 6))

# One-second-wide bins covering the whole observed range
first_edge = int(inference_times.min())
last_edge = int(inference_times.max()) + 1
bins = np.arange(first_edge, last_edge + 1, 1)

# Histogram of raw counts (not a probability density); keep the bin edges
_, bins, _ = ax.hist(
    inference_times,
    bins=bins,
    color='lightblue',
    alpha=0.7,
    edgecolor='black',
    linewidth=0.5,
)

# Normal curve with the sample mean/std, rescaled from density to counts
# (density * number of samples * bin width)
curve_x = np.linspace(inference_times.min(), inference_times.max(), 100)
bin_width = bins[1] - bins[0]
curve_y = stats.norm.pdf(curve_x, mean_time, std_time) * len(inference_times) * bin_width

ax.plot(
    curve_x,
    curve_y,
    color='red',
    linewidth=2,
    label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)',
)

# Vertical markers for mean and median
ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s')
ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s')

# Vertical marker one standard deviation above the mean
ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'+1σ = {mean_time + std_time:.1f}s')
# Vertical marker one standard deviation below the mean
ax.axvline(mean_time - std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'-1σ = {mean_time - std_time:.1f}s')

ax.set_xlabel('Inference Time (seconds)')
ax.set_ylabel('Frequency')
ax.set_title('Inference Time Distribution with Gaussian Fit')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

##

# %% Dashboard
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from matplotlib.gridspec import GridSpec


def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Parse a column as numbers, accepting both '1.5' and '1,5'.

    Values that cannot be parsed become NaN.
    """
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")


# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Rename columns to remove 'result.' prefix and replace spaces
column_mapping = {}
for col in df.columns:
    if col.startswith('result.'):
        new_name = col.replace('result.', '').replace(' ', '_')
        column_mapping[col] = new_name
df = df.rename(columns=column_mapping)

# Parse MedDatum safely (unparseable dates become NaT)
df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce')

# Patient
patient_id = 'd13e4aa3'
patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy()
if patient_data.empty:
    raise ValueError(f"No data found for patient: {patient_id}")

# Functional systems + EDSS
edss_col, edss_title = ('GT.EDSS', 'EDSS')

functional_systems = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'),
    ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'),
    ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'),
    ('GT.SENSORY_FUNCTIONS', 'Sensory'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'),
    ('GT.AMBULATION', 'Ambulation'),
    ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'),
]

# y-axis max rules
ymax_by_col = {
    'GT.PYRAMIDAL_FUNCTIONS': 6,
    'GT.SENSORY_FUNCTIONS': 6,
    'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6,
    'GT.VISUAL_OPTIC_FUNCTIONS': 6,
    'GT.CEREBELLAR_FUNCTIONS': 5,
    'GT.CEREBRAL_FUNCTIONS': 5,
    'GT.BRAINSTEM_FUNCTIONS': 5,
    'GT.EDSS': 10,
}
default_ymax = 6

# ---------- Build shared "event dates" ticks ----------
# Only dates on which at least one plotted column has a numeric value
cols_for_dates = [edss_col] + [c for c, _ in functional_systems]
event_dates = []

for c in cols_for_dates:
    if c in patient_data.columns:
        y = to_numeric_comma(patient_data[c])
        x = patient_data['MedDatum']
        tmp = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"])
        event_dates.extend(tmp["x"].tolist())

event_dates = sorted(pd.Series(event_dates).drop_duplicates().tolist())

# Thin the ticks to at most max_ticks evenly spaced dates
max_ticks = 8
if len(event_dates) > max_ticks:
    idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int)
    event_dates = [event_dates[i] for i in idx]

# ---------- A4 figure ----------
fig = plt.figure(figsize=(11.69, 8.27))
gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35)


def style_time_axis(ax, show_labels=True):
    """Put the shared event-date ticks on *ax*, formatted as year-month."""
    ax.set_xticks(event_dates)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2)
    if not show_labels:
        ax.tick_params(labelbottom=False)


# ---------- EDSS main plot ----------
ax_main = fig.add_subplot(gs[0, :])

if edss_col in patient_data.columns:
    y = to_numeric_comma(patient_data[edss_col])
    x = patient_data['MedDatum']
    plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x")

    ax_main.set_title(edss_title, fontsize=14, fontweight='bold')
    ax_main.set_ylabel("Score")
    ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax))
    ax_main.grid(True, alpha=0.3)

    if not plot_df.empty:
        ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red')
    else:
        ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold')
else:
    ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold')
    ax_main.set_ylim(0, ymax_by_col.get(edss_col, 10))
    ax_main.grid(True, alpha=0.3)

style_time_axis(ax_main)

# ---------- Small aligned plots ----------
small_axes = []
for k, (col, title) in enumerate(functional_systems):
    # 4 plots per row, starting on grid row 1 (row 0 is the EDSS plot)
    r = 1 + (k // 4)
    c = (k % 4)
    ax = fig.add_subplot(gs[r, c], sharex=ax_main)
    small_axes.append(ax)

    ymax = ymax_by_col.get(col, default_ymax)
    ax.set_title(title, fontsize=10)
    ax.set_ylabel("Score")
    ax.set_ylim(0, ymax)
    ax.grid(True, alpha=0.3)

    if col in patient_data.columns:
        y = to_numeric_comma(patient_data[col])
        x = patient_data['MedDatum']
        plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x")

        if not plot_df.empty:
            ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue')
        else:
            ax.set_title(f"{title} (no data)", fontsize=10)
    else:
        ax.set_title(f"{title} (missing)", fontsize=10)

    style_time_axis(ax)

# Hide x tick labels on first row of small plots
for ax in small_axes[:4]:
    ax.tick_params(labelbottom=False)

plt.tight_layout()
fig.subplots_adjust(hspace=0.7)
plt.show()
##

# Merge conflict resolved here: the "Updated upstream" side was empty, so the
# stashed "Dashboard Angepasst" cell is kept and the conflict markers (which
# are syntax errors in a Python file) are dropped.


# %% Dashboard Angepasst
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from matplotlib.gridspec import GridSpec


def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Parse a column as numbers, accepting both '1.5' and '1,5'.

    Values that cannot be parsed become NaN.
    """
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")


# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Rename columns to remove 'result.' prefix and replace spaces
column_mapping = {}
for col in df.columns:
    if col.startswith('result.'):
        new_name = col.replace('result.', '').replace(' ', '_')
        column_mapping[col] = new_name
df = df.rename(columns=column_mapping)

# Parse MedDatum safely (unparseable dates become NaT)
df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce')

# Patient
patient_id = '3d942c60'

patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy()
if patient_data.empty:
    raise ValueError(f"No data found for patient: {patient_id}")

# Functional systems + EDSS
edss_col, edss_title = ('GT.EDSS', 'EDSS')

functional_systems = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'),
    ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'),
    ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'),
    ('GT.SENSORY_FUNCTIONS', 'Sensory'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'),
    ('GT.AMBULATION', 'Ambulation'),
    ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'),
]

# y-axis max rules
ymax_by_col = {
    'GT.PYRAMIDAL_FUNCTIONS': 6,
    'GT.SENSORY_FUNCTIONS': 6,
    'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6,
    'GT.VISUAL_OPTIC_FUNCTIONS': 6,
    'GT.CEREBELLAR_FUNCTIONS': 5,
    'GT.CEREBRAL_FUNCTIONS': 5,
    'GT.BRAINSTEM_FUNCTIONS': 5,
    'GT.EDSS': 10,
}
default_ymax = 6

# ---------- Build shared visit dates ticks ----------
# Use ALL patient visit dates, not only dates with valid numeric values
event_dates = sorted(patient_data['MedDatum'].dropna().drop_duplicates().tolist())

# Thin the ticks to at most max_ticks evenly spaced dates
max_ticks = 8
if len(event_dates) > max_ticks:
    idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int)
    event_dates = [event_dates[i] for i in idx]

# Base timeline for plotting: one row per patient visit date
timeline = (
    patient_data[['MedDatum']]
    .dropna()
    .drop_duplicates()
    .sort_values('MedDatum')
    .rename(columns={'MedDatum': 'x'})
)

# ---------- A4 figure ----------
fig = plt.figure(figsize=(11.69, 8.27))
gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35)


def style_time_axis(ax, show_labels=True):
    """Put the shared visit-date ticks on *ax*, formatted as year-month."""
    ax.set_xticks(event_dates)
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
    ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2)
    if not show_labels:
        ax.tick_params(labelbottom=False)


def get_plot_df(patient_data, col):
    """
    Keep all visit dates.
    Missing values stay NaN so matplotlib draws gaps instead of zeros.

    Returns a frame with columns 'x' (visit date) and 'y' (numeric value of
    *col*), one row per date of the module-level ``timeline``.
    """
    tmp = patient_data[['MedDatum', col]].copy()
    tmp = tmp.rename(columns={'MedDatum': 'x', col: 'raw_y'})
    tmp['y'] = to_numeric_comma(tmp['raw_y'])

    # aggregate if multiple rows exist on same date
    tmp = tmp.groupby('x', as_index=False)['y'].max()

    # merge onto full timeline so all dates remain visible
    plot_df = timeline.merge(tmp, on='x', how='left').sort_values('x')
    return plot_df


# ---------- EDSS main plot ----------
ax_main = fig.add_subplot(gs[0, :])
ax_main.set_title(edss_title, fontsize=14, fontweight='bold')
ax_main.set_ylabel("Score")
ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax))
ax_main.grid(True, alpha=0.3)

if edss_col in patient_data.columns:
    plot_df = get_plot_df(patient_data, edss_col)

    if plot_df['y'].notna().any():
        # NaNs create visible gaps in the line
        ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red')
    else:
        ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold')
else:
    ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold')

style_time_axis(ax_main)

# ---------- Small aligned plots ----------
small_axes = []
for k, (col, title) in enumerate(functional_systems):
    # 4 plots per row, starting on grid row 1 (row 0 is the EDSS plot)
    r = 1 + (k // 4)
    c = (k % 4)
    ax = fig.add_subplot(gs[r, c], sharex=ax_main)
    small_axes.append(ax)

    ymax = ymax_by_col.get(col, default_ymax)
    ax.set_title(title, fontsize=10)
    ax.set_ylabel("Score")
    ax.set_ylim(0, ymax)
    ax.grid(True, alpha=0.3)

    if col in patient_data.columns:
        plot_df = get_plot_df(patient_data, col)

        if plot_df['y'].notna().any():
            # NaNs remain in y -> line breaks where data is missing
            ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue')
        else:
            ax.set_title(f"{title} (no numeric data)", fontsize=10)
    else:
        ax.set_title(f"{title} (missing)", fontsize=10)

    style_time_axis(ax)

# Hide x tick labels on first row of small plots
for ax in small_axes[:4]:
    ax.tick_params(labelbottom=False)

plt.tight_layout()
fig.subplots_adjust(hspace=0.7)
plt.show()


##

# %% Table
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import numpy as np

# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Convert MedDatum to datetime
df['MedDatum'] = pd.to_datetime(df['MedDatum'])

# Check what columns actually exist in the dataset
print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())

# Check data types
print("\nData types:")
print(df.dtypes)

# Hardcode specific patient names
patient_names = ['6ccda8c6']

# Define the functional systems (columns to plot)
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']

# Create subplots
num_plots = len(functional_systems)
num_cols = 2
num_rows = (num_plots + num_cols - 1) // num_cols

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
# Always work with a flat 1-D array of axes, whatever grid shape came back
axes = np.atleast_1d(axes).ravel()

# Filter once: the patient is the same for every subplot
patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')

# Plot for the hardcoded patient
for i, system in enumerate(functional_systems):
    # Check if patient data exists
    if patient_data.empty:
        print(f"No data found for patient: {patient_names[0]}")
        axes[i].set_title(f'Functional System: {system} (No data)')
        axes[i].set_ylabel('Score')
        continue

    # Use the exact column if present, otherwise the first column whose name
    # contains the system name (case-insensitive)
    if system in patient_data.columns:
        column = system
        title = f'Functional System: {system}'
    else:
        column = next((col for col in df.columns if system.lower() in col.lower()), None)
        title = f'Functional System: {system} (found as: {column})'

    if column is None:
        axes[i].set_title(f'Functional System: {system} (Column not found)')
        axes[i].set_ylabel('Score')
        continue

    # Plot only valid data (non-null values)
    valid_data = patient_data.dropna(subset=[column])
    if valid_data.empty:
        axes[i].set_title(f'Functional System: {system} (No valid data)')
        axes[i].set_ylabel('Score')
        continue

    axes[i].plot(valid_data['MedDatum'], valid_data[column], marker='o', linewidth=2, label=column)
    axes[i].set_ylabel('Score')
    axes[i].set_title(title)
    axes[i].grid(True, alpha=0.3)
    axes[i].legend()

# Hide empty subplots
for i in range(len(functional_systems), len(axes)):
    axes[i].set_visible(False)

# Set x-axis label for the last row only
for i in range(len(functional_systems)):
    if i >= len(axes) - num_cols:  # Last row
        axes[i].set_xlabel('Date')

# Format x-axis dates
for ax in axes:
    if ax.get_lines():  # Only format if there are lines to plot
        ax.tick_params(axis='x', rotation=45)
ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d')) + +# Automatically adjust layout +plt.tight_layout() +plt.show() + + + + +## + + + + +# %% Histogram Fig1 +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.font_manager as fm +import json +import os + +def create_visit_frequency_plot( + file_path, + output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data', + output_filename='visit_frequency_distribution.svg', + fontsize=10, + color_scheme_path='colors.json' +): + """ + Creates a publication-ready bar chart of patient visit frequency. + + Args: + file_path (str): Path to the input TSV file. + output_dir (str): Directory to save the output SVG file. + output_filename (str): Name of the output SVG file. + fontsize (int): Font size for all text elements (labels, title). + color_scheme_path (str): Path to the JSON file containing the color palette. + """ + # --- 1. Load Data and Color Scheme --- + try: + df = pd.read_csv(file_path, sep='\t') + print("Data loaded successfully.") + # Sort data for easier visual comparison + df = df.sort_values(by='Visits Count') + except FileNotFoundError: + print(f"Error: The file was not found at {file_path}") + return + + try: + with open(color_scheme_path, 'r') as f: + colors = json.load(f) + # Select a blue from the sequential palette for the bars + bar_color = colors['sequential']['blues'][-2] # A saturated blue + except FileNotFoundError: + print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.") + bar_color = '#2171b5' # A common matplotlib blue + + # --- 2. Set up the Plot with Scientific Style --- + plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height + + # Set the font to Arial + arial_font = fm.FontProperties(family='Arial', size=fontsize) + plt.rcParams['font.family'] = 'Arial' + plt.rcParams['font.size'] = fontsize + + # --- 3. 
Create the Bar Chart --- + ax = plt.gca() + bars = plt.bar( + x=df['Visits Count'], + height=df['Unique Patients'], + color=bar_color, + edgecolor='black', + linewidth=0.5, # Minimum line thickness + width=0.7 + ) + + # --- NEW: Explicitly set x-ticks and labels to ensure all are shown --- + # Get the unique visit counts to use as tick labels + visit_counts = df['Visits Count'].unique() + # Set the x-ticks to be at the center of each bar + ax.set_xticks(visit_counts) + # Set the x-tick labels to be the visit counts, using the specified font + ax.set_xticklabels(visit_counts, fontproperties=arial_font) + # --- END OF NEW SECTION --- + + # --- 4. Customize Axes and Layout (Nature style) --- + # Display only left and bottom axes + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + # Turn off axis ticks (the marks, not the labels) + plt.tick_params(axis='both', which='both', length=0) + + # Remove grid lines + plt.grid(False) + + # Set background to white (no shading) + ax.set_facecolor('white') + plt.gcf().set_facecolor('white') + + # --- 5. Add Labels and Title --- + plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10) + plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10) + plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20) + + # --- 6. Add y-axis values on top of each bar --- + # This adds the count of unique patients directly above each bar. + ax.bar_label(bars, fmt='%d', padding=3) + + # --- 7. Export the Figure --- + # Ensure the output directory exists + os.makedirs(output_dir, exist_ok=True) + + full_output_path = os.path.join(output_dir, output_filename) + plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight') + print(f"\nFigure saved as '{full_output_path}'") + + # --- 8. 
(Optional) Display the Plot --- + # plt.show() + +# --- Main execution --- +if __name__ == '__main__': + # Define the file path + input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv' + + # Call the function to create and save the plot + create_visit_frequency_plot( + file_path=input_file, + fontsize=10 # Using a 10 pt font size as per guidelines + ) + +## + + + +# %% Scatter Plot functional system + +import pandas as pd +import matplotlib.pyplot as plt +import json +import os + +# --- Configuration --- +# Set the font to Arial for all text in the plot, as per the guidelines +plt.rcParams['font.family'] = 'Arial' + +# Define the path to your data file +data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv' + +# Define the path to save the color mapping JSON file +color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json' + +# Define the path to save the final figure +figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg' + +# --- 1. Load the Dataset --- +try: + # Load the TSV file + df = pd.read_csv(data_path, sep='\t') + print(f"Successfully loaded data from {data_path}") + print(f"Data shape: {df.shape}") +except FileNotFoundError: + print(f"Error: The file at {data_path} was not found.") + # Exit or handle the error appropriately + raise + +# --- 2. 
# Define Functional Systems and Create Color Mapping ---
# Each entry pairs a ground-truth column with the matching result column
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]

# System name = the part of the ground-truth column after the 'GT.' prefix;
# used as the key for colours and legend entries
system_names = [pair[0].split('.')[1] for pair in functional_systems_to_plot]

# Qualitative palette, one colour per functional system (same order as above)
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC'   # Purple
]

# system name -> colour
color_map = {system: colour for system, colour in zip(system_names, colors)}

# Persist the mapping as JSON, creating the target directory if needed
json_dir = os.path.dirname(color_json_path)
os.makedirs(json_dir, exist_ok=True)

with open(color_json_path, 'w') as json_file:
    json.dump(color_map, json_file, indent=4)

print(f"Color mapping saved to {color_json_path}")

# --- 3.
# Calculate Agreement Percentages and Format Legend Labels ---
# (done together with step 4, since both need the same numeric conversion)
agreement_percentages = {}
legend_labels = {}

# --- 4. Reshape Data for Plotting ---
plot_data = []

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    # Non-numeric entries become NaN
    gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
    res_numeric = pd.to_numeric(df[res_col], errors='coerce')

    # Long-format rows for this system; rows where either side is NaN are
    # dropped, so both the agreement and the plot use the same paired rows
    temp_df = pd.DataFrame({
        'system': system_name,
        'ground_truth': gt_numeric,
        'inference': res_numeric
    }).dropna()
    plot_data.append(temp_df)

    # Share of paired rows whose two values are exactly equal
    if len(temp_df) > 0:
        agreement = (temp_df['ground_truth'] == temp_df['inference']).mean() * 100
    else:
        agreement = 0  # no valid pairs for this system
    agreement_percentages[system_name] = agreement

    # Legend text, e.g. "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions (87.5%)"
    formatted_name = " ".join(map(str.capitalize, system_name.split('_')))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"

# One frame holding the points of every system
plot_df = pd.concat(plot_data, ignore_index=True)

if not plot_df.empty:
    print(f"Prepared plot data with {len(plot_df)} data points.")
else:
    print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")

# --- 5.
# Create the Scatter Plot ---
plt.figure(figsize=(10, 8))

# One scatter series per functional system, coloured and labelled via the
# mappings prepared earlier
for system, group in plot_df.groupby('system'):
    plt.scatter(
        group['ground_truth'],
        group['inference'],
        label=legend_labels[system],
        color=color_map[system],
        alpha=0.7,
        s=30
    )

# Dashed y = x reference line (perfect agreement), spanning the range of the
# ground-truth values
if not plot_df.empty:
    gt_low = plot_df['ground_truth'].min()
    gt_high = plot_df['ground_truth'].max()
    plt.plot(
        [gt_low, gt_high],
        [gt_low, gt_high],
        color='black',
        linestyle='--',
        linewidth=0.8,
        alpha=0.7
    )

# --- 6. Apply Styling and Labels ---
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('LLM Inference', fontsize=12)
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)

# Minimal frame: no top/right spines, no tick marks, no grid
ax = plt.gca()
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)
ax.grid(False)
plt.legend(title='Functional System', frameon=False, fontsize=10)

# --- 7. Save and Display the Figure ---
# Create the output directory if it does not exist yet
figure_dir = os.path.dirname(figure_save_path)
os.makedirs(figure_dir, exist_ok=True)

plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")

# Display the plot
plt.show()
##


# %% Confusion Matrix functional systems

import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np
import matplotlib.colors as mcolors

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'

# --- 1.
# Load the Dataset ---
df = pd.read_csv(data_path, sep='\t')

# --- 2. Define Functional Systems and Colors ---
# Each entry pairs a ground-truth column with the matching LLM-result column.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]

# The part after "GT." is used as the system's key everywhere below.
system_names = [gt_name.split('.')[1] for gt_name, _ in functional_systems_to_plot]
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
color_map = dict(zip(system_names, colors))

# --- 3. Categorization Function ---
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
category_to_index = {label: position for position, label in enumerate(categories)}
n_categories = len(categories)


def categorize_edss(value):
    """Map a score to its bin label from ``categories``.

    Missing values return ``np.nan``. Scores at or below zero go to the first
    bin; positive scores are capped at 10 and fall into half-open bins
    ``(k, k+1]`` because 0.001 is subtracted before truncation.

    NOTE(review): whole-number scores 0 and 1 therefore share the '0-1' bin;
    confirm that this is intended for the functional-system columns.
    """
    if pd.isna(value):
        return np.nan
    score = float(value)  # float() so numeric strings do not raise TypeError below
    if score <= 0:
        return categories[0]
    bin_index = int(min(score, 10) - 0.001)
    return categories[min(bin_index, len(categories) - 1)]


# --- 4. Prepare Mixed Color Matrix with Saturation ---
# Axis 0 = ground-truth bin, axis 1 = inference bin, axis 2 = functional system.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))

for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Coerce to numbers first; anything unparsable becomes NaN and is ignored.
    gt_scores = pd.to_numeric(df[gt_col], errors='coerce')
    res_scores = pd.to_numeric(df[res_col], errors='coerce')
    both_present = gt_scores.notna() & res_scores.notna()

    for gt_score, res_score in zip(gt_scores[both_present], res_scores[both_present]):
        row_bin = category_to_index.get(categorize_edss(gt_score))
        col_bin = category_to_index.get(categorize_edss(res_score))
        if row_bin is None or col_bin is None:
            continue
        cell_system_counts[row_bin, col_bin, s_idx] += 1

# Create an RGBA image matrix (10x10x4)
rgba_matrix = np.zeros((n_categories, n_categories, 4))

total_counts = cell_system_counts.sum(axis=2)
peak_count = np.max(total_counts)
max_count = peak_count if peak_count > 0 else 1  # avoid dividing by zero below

system_rgbs = [np.array(mcolors.to_rgb(color_map[name])) for name in system_names]

for i, j in np.ndindex(n_categories, n_categories):
    count_sum = total_counts[i, j]
    if count_sum <= 0:
        # Empty cells are white and fully transparent
        rgba_matrix[i, j] = [1, 1, 1, 0]
        continue

    # Colour = average of the system colours, weighted by each system's share of the cell
    mixed_rgb = np.zeros(3)
    for s_idx, system_rgb in enumerate(system_rgbs):
        mixed_rgb += system_rgb * (cell_system_counts[i, j, s_idx] / count_sum)
    rgba_matrix[i, j, :3] = mixed_rgb

    # Opacity follows a square-root scale so sparse cells stay visible;
    # the floor of 0.1 keeps any populated cell from disappearing.
    rgba_matrix[i, j, 3] = max(0.1, np.sqrt(count_sum / max_count))

# --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))

# origin='upper' puts the 0-1 bin at the top, as is usual for confusion matrices
# (origin='lower' would put it at the bottom instead).
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')

# Add count labels
for i, j in np.ndindex(n_categories, n_categories):
    if total_counts[i, j] <= 0:
        continue
    red, green, blue = rgba_matrix[i, j, :3]
    lum = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    # A faint cell is effectively white whatever its hue, so white text is
    # only used when the cell is both dark and mostly opaque.
    dark_and_opaque = lum < 0.5 and rgba_matrix[i, j, 3] > 0.5
    ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
            color="white" if dark_and_opaque else "black", fontsize=10, fontweight='bold')

# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)

tick_positions = np.arange(n_categories)
ax.set_xticks(tick_positions)
ax.set_xticklabels(categories)
ax.set_yticks(tick_positions)
ax.set_yticklabels(categories)

# Remove the frame/spines for a cleaner look
for spine in ax.spines.values():
    spine.set_visible(False)

# Custom Legend: one colour swatch per functional system
handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
          bbox_to_anchor=(1.05, 1), frameon=False)

plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()

##


# %% Difference Plot Functional system

import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np

# --- Configuration ---
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'

# Define the path to your data file
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'

# Define the path to save the color mapping JSON file
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'

# Define the path to save the final figure
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'

# --- 1. Load the Dataset ---
try:
    # Load the TSV file
    df = pd.read_csv(data_path, sep='\t')
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    # Nothing below can run without the table, so re-raise
    raise
print(f"Successfully loaded data from {data_path}")
print(f"Data shape: {df.shape}")

# --- 2. Define Functional Systems and Create Color Mapping ---
# List of tuples containing (ground_truth_column, result_column)
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]

# Extract system names for color mapping and legend
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]

# Define a professional color palette (dark blue theme)
# This is a qualitative palette with distinct, accessible colors
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC'   # Purple
]

# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))

# Ensure the directory for the JSON file exists
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)

# Save the color map to a JSON file (explicit encoding keeps the output
# independent of the platform's default locale)
with open(color_json_path, 'w', encoding='utf-8') as f:
    json.dump(color_map, f, indent=4)

print(f"Color mapping saved to {color_json_path}")


def safe_parse(s):
    """Convert one cell to float, accepting comma decimals (e.g. '3,5' -> 3.5).

    Returns ``np.nan`` for missing values and for text that cannot be read as
    a number. Defined before first use so the agreement figures and the error
    data below go through the same parser.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float)):
        return float(s)
    # Replace comma with dot, then strip whitespace
    s_clean = str(s).replace(',', '.').strip()
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# --- 4. Robustly Prepare Error Data for Boxplot ---
# Both steps share one pass so each column is parsed once and both metrics
# see exactly the same rows. Parsing agreement with pd.to_numeric would turn
# comma decimals into NaN while the error data still counted them.
agreement_percentages = {}
legend_labels = {}
plot_data = []

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    gt_numeric = df[gt_col].apply(safe_parse)
    res_numeric = df[res_col].apply(safe_parse)

    # Agreement is measured only on rows where both sides parsed
    both_valid = gt_numeric.notna() & res_numeric.notna()
    if both_valid.any():
        agreement = (gt_numeric[both_valid] == res_numeric[both_valid]).mean() * 100
    else:
        agreement = 0  # Handle case with no valid data

    agreement_percentages[system_name] = agreement

    # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"

    # Signed error = inference minus ground truth; rows where either side
    # was unparseable become NaN and are dropped
    temp_df = pd.DataFrame({
        'system': system_name,
        'error': res_numeric - gt_numeric
    }).dropna()

    plot_data.append(temp_df)

plot_df = pd.concat(plot_data, ignore_index=True)

if plot_df.empty:
    print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")
else:
    print(f"✅ Prepared error data with {len(plot_df)} data points.")
    # Diagnostic: per-system sample size and error range
    print("\n📌 Sample errors by system:")
    for system, grp in plot_df.groupby('system'):
        print(f" {system:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}")

# Ensure categorical ordering (same order as functional_systems_to_plot)
plot_df['system'] = pd.Categorical(
    plot_df['system'],
    categories=[name.split('.')[1] for name, _ in functional_systems_to_plot],
    ordered=True
)


# --- 5.
# Prepare Data for Diverging Stacked Bar Plot ---
print("\n📊 Preparing diverging stacked bar plot data...")


def categorize_error(err):
    """Name the direction of a signed error (inference minus ground truth).

    Returns 'missing' for NaN/None, 'underestimate' for negative values,
    'overestimate' for positive values and 'match' for exactly zero.
    """
    if pd.isna(err):
        return 'missing'
    if err < 0:
        return 'underestimate'
    if err > 0:
        return 'overestimate'
    return 'match'


# Add category column (only on non-missing errors)
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)

# Count by system + category. Reindexing with fill_value=0 guarantees one row
# per system and one column per direction, with 0 instead of NaN when a
# combination never occurs.
direction_order = ['underestimate', 'match', 'overestimate']
category_counts = (
    pd.crosstab(plot_df_clean['system'].astype(str), plot_df_clean['category'])
    .reindex(index=system_names, columns=direction_order, fill_value=0)
    .astype(int)
)

underestimate_counts = category_counts['underestimate']
match_counts = category_counts['match']
overestimate_counts = category_counts['overestimate']

# For diverging: underestimates grow to the left, overestimates to the right
left_counts = underestimate_counts
right_counts = overestimate_counts

# Layout per system (one row each):
#   [ underestimates | exact matches, centred on x = 0 | overestimates ]
# so every segment starts where its neighbour ends.
half_match = match_counts / 2

# Largest distance any row reaches from the centre (for a symmetric x-axis)
max_bar = max((half_match + left_counts).max(), (half_match + right_counts).max(), 1)
plot_range = (-max_bar, max_bar)

# One y position per system; all three segments of a system share it so the
# y tick labels line up with the bars.
n_systems = len(system_names)
positions = np.arange(n_systems)

# --- 6. Create Diverging Stacked Bar Plot ---
plt.figure(figsize=(12, 7))

# Colors: diverging palette
direction_colors = {
    'underestimate': '#E74C3C',  # Red (left)
    'match': '#2ECC71',          # Green (center)
    'overestimate': '#F39C12'    # Orange (right)
}
segment_style = dict(height=0.8, edgecolor='black', linewidth=0.5, alpha=0.9)

# Underestimates: the bar ends at the left edge of the match segment, so its
# left edge sits `count` further out. (barh widths stay positive; the
# direction comes from `left`.)
bars_left = plt.barh(
    positions,
    left_counts.values,
    left=-(half_match + left_counts).values,
    color=direction_colors['underestimate'],
    label='Underestimate',
    **segment_style
)

# Exact matches: drawn once, centred on zero
plt.barh(
    positions,
    match_counts.values,
    left=-half_match.values,
    color=direction_colors['match'],
    label='Exact Match',
    **segment_style
)

# Overestimates: start at the right edge of the match segment
bars_right = plt.barh(
    positions,
    right_counts.values,
    left=half_match.values,
    color=direction_colors['overestimate'],
    label='Overestimate',
    **segment_style
)

# --- 7. Styling & Labels ---
# Centre reference line, kept behind the bars
plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8, zorder=0)

# X-axis: symmetric around 0 with 10% padding
plt.xlim(plot_range[0] - max_bar*0.1, plot_range[1] + max_bar*0.1)
plt.xticks(rotation=0, fontsize=10)
plt.xlabel('Count', fontsize=12)

# Y-axis: one label per system, one word per line, "AND" shortened to "&".
# (The names are upper-case, so matching lower-case 'and' would never fire.)
plt.yticks(
    positions,
    ['\n'.join('&' if word == 'AND' else word for word in name.split('_')) for name in system_names],
    fontsize=10
)
plt.ylabel('Functional System', fontsize=12)

# Title & layout
plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15)

# Clean axes
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)  # We only need bottom axis
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')

# Grid only along x
ax.xaxis.grid(True, linestyle=':', alpha=0.5)

# Legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=direction_colors['underestimate'], edgecolor='black', label='Underestimate'),
    Patch(facecolor=direction_colors['match'], edgecolor='black', label='Exact Match'),
    Patch(facecolor=direction_colors['overestimate'], edgecolor='black', label='Overestimate')
]
plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)

# Counts, centred inside their own segment so the text always has a
# coloured background (white on red/orange, black on green)
for y, under, match, over in zip(positions, left_counts.values, match_counts.values, right_counts.values):
    half = match / 2
    if under > 0:
        plt.text(-half - under / 2, y, str(int(under)), va='center', ha='center', fontsize=9, color='white', fontweight='bold')
    if over > 0:
        plt.text(half + over / 2, y, str(int(over)), va='center', ha='center', fontsize=9, color='white', fontweight='bold')
    if match > 0:
        plt.text(0, y, str(int(match)), va='center', ha='center', fontsize=8, color='black')

plt.tight_layout()

# --- 8.
# Save & Show ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"✅ Diverging bar plot saved to {figure_save_path}")

plt.show()

##


# %% Difference Plot Gemini
# Relies on `df`, `functional_systems_to_plot` and `safe_parse` from the
# "Difference Plot Functional system" cell.
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.ticker import FuncFormatter

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg'

# --- 1. Process Error Data with Magnitude Breakdown ---
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
plot_list = []

for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    error = res - gt  # NaN wherever either side is missing

    # Granular counts. The buckets are ranges rather than exact values so that
    # a fractional error (e.g. -0.5 or +1.5) is still counted and
    # Match + Under + Over always equals the number of valid rows.
    # NaN fails every comparison and is therefore never counted.
    matches = (error == 0).sum()
    u_1 = ((error < 0) & (error > -2)).sum()
    u_2plus = (error <= -2).sum()
    o_1 = ((error > 0) & (error < 2)).sum()
    o_2plus = (error >= 2).sum()

    total = error.dropna().count()
    divisor = max(total, 1)  # avoid dividing by zero for an empty system

    plot_list.append({
        'System': sys_name.replace('_', ' ').title(),
        'Matches': matches, 'MatchPct': (matches / divisor) * 100,
        'U1': u_1, 'U2': u_2plus, 'UnderTotal': u_1 + u_2plus,
        'UnderPct': ((u_1 + u_2plus) / divisor) * 100,
        'O1': o_1, 'O2': o_2plus, 'OverTotal': o_1 + o_2plus,
        'OverPct': ((o_1 + o_2plus) / divisor) * 100
    })

stats_df = pd.DataFrame(plot_list)

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(13, 8))

# Define Magnitude Colors
c_under_dark, c_under_light = '#C0392B', '#E74C3C'  # Dark Red (-2+), Soft Red (-1)
c_over_dark, c_over_light = '#2980B9', '#3498DB'    # Dark Blue (+2+), Soft Blue (+1)
bar_height = 0.6
y_pos = np.arange(len(stats_df))

# Plot Under-scored (Stacked: -2+ then -1)
ax.barh(y_pos, -stats_df['U2'], bar_height, color=c_under_dark, label='Under -2+', edgecolor='white')
ax.barh(y_pos, -stats_df['U1'], bar_height, left=-stats_df['U2'], color=c_under_light, label='Under -1', edgecolor='white')

# Plot Over-scored (Stacked: +1 then +2+)
ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white')
ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white')

# --- 3. Aesthetics & Table Labels ---
for i, row in stats_df.iterrows():
    # Mathtext ignores plain spaces, so each one is escaped as "\ " to keep
    # the words of the bold system name apart.
    bold_name = row['System'].replace(' ', r'\ ')
    label_text = (
        f"$\\mathbf{{{bold_name}}}$\n"
        f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
        f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)"
    )
    # Position table text to the left
    ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.4)

# Formatting
ax.axvline(0, color='black', linewidth=1.2)
ax.set_yticks([])
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold')
#ax.set_title('Directional Error Magnitude (Under vs. Over Scoring)', fontsize=14, pad=35)

# Absolute X-axis labels: a formatter follows whatever ticks the locator
# picks, instead of freezing label text onto tick positions that may change.
ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f"{abs(value):g}"))

# Remove spines and add grid
for spine in ['top', 'right', 'left']:
    ax.spines[spine].set_visible(False)
ax.xaxis.grid(True, linestyle='--', alpha=0.3)

# Legend with magnitude info
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2)

plt.tight_layout()

# Save like every other figure in this file (the path was defined but unused)
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')

plt.show()
##


# %% Functional System Error Boxplots
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_boxplot.svg'

# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth
    error = (res - gt).dropna()

    # Ignore all 0 errors: the boxes describe only the disagreements
    error = error[error != 0]

    # Keep only systems that actually have non-zero data
    if len(error) > 0:
        clean_name = sys_name.replace('_', ' ').title()
        boxplot_data.append(error.values)
        system_labels.append(clean_name)
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(14, 8))

# NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels`;
# confirm against the installed version before switching.
bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False
)

# --- 3.
# Styling ---
# Palette shared by the artists below and by the legend built afterwards.
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

# One style per group of artists returned by ax.boxplot (keys of `bp`).
part_styles = {
    'boxes': dict(facecolor=box_face, edgecolor=box_edge, linewidth=1.5),
    'whiskers': dict(color=whisker_col, linewidth=1.2),
    'caps': dict(color=whisker_col, linewidth=1.2),
    'medians': dict(color=median_col, linewidth=2),
    'means': dict(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6),
    'fliers': dict(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4),
}
for part_name, style in part_styles.items():
    for artist in bp[part_name]:
        artist.set(**style)

# Dashed reference line at zero error
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

# Axis titles
ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Tilt the x labels so the two-line "name / (n=...)" labels do not collide
plt.xticks(rotation=45, ha='right')

# Light horizontal grid; drop the top and right frame lines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# --- 4.
# Legend above the plot, outside the axes ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference')
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False
)

# Leave room at the top for the legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')

plt.show()
##

# (Merge-conflict markers removed here: the "Updated upstream" side was empty,
# so the stashed cell below is kept as-is. Left in place, the markers make the
# whole module a syntax error.)

# %% Functional System + EDSS Error Boxplots
# Relies on `df`, `functional_systems_to_plot` and `safe_parse` from the
# "Difference Plot Functional system" cell.
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_boxplot.svg'

# ------------------------------------------------------------
# Expect functional_systems_to_plot like:
# [
#   ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
#   ...
# ]
# (result columns use spaces, ground-truth columns use underscores)
#
# Add EDSS here:
# ------------------------------------------------------------
all_systems_to_plot = list(functional_systems_to_plot) + [
    ('GT.EDSS', 'result.EDSS')
]

# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in all_systems_to_plot:
    # Skip safely if a column is missing
    if gt_col not in df.columns or res_col not in df.columns:
        print(f"Skipping missing columns: {gt_col}, {res_col}")
        continue

    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth
    error = (res - gt).dropna()

    # Ignore all 0 errors: the boxes describe only the disagreements
    error = error[error != 0]

    # Keep only systems that actually have non-zero data
    if len(error) > 0:
        # "EDSS" is an acronym and must not be title-cased to "Edss"
        if sys_name == 'EDSS':
            clean_name = 'EDSS'
        else:
            clean_name = sys_name.replace('_', ' ').title()

        boxplot_data.append(error.values)
        system_labels.append(clean_name)
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system or EDSS.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(15, 8))

bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False
)

# --- 3. Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)

for whisker in bp['whiskers']:
    whisker.set(color=whisker_col, linewidth=1.2)

for cap in bp['caps']:
    cap.set(color=whisker_col, linewidth=1.2)

for median in bp['medians']:
    median.set(color=median_col, linewidth=2)

for mean in bp['means']:
    mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6)

for flier in bp['fliers']:
    flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4)

# Reference line at zero error
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

# Labels and formatting
ax.set_xlabel('Functional System / EDSS', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Rotate x labels for readability
plt.xticks(rotation=45, ha='right')

# Grid and spines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# --- 4. Legend above the plot, outside the axes ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference')
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False
)

# Leave room at the top for the legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')

plt.show()
##

# %% test
# Diagnose: what are the actual differences?
print("\n🔍 Raw differences (first 5 rows per system):")
for gt_col, res_col in functional_systems_to_plot:
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    # Drop rows where either side is missing: NaN != 0 is True, so keeping
    # them would report every unparsable row as a "non-zero" difference.
    diff = (res - gt).dropna()
    non_zero = (diff != 0).sum()
    # Check if it's due to floating point noise
    abs_diff = diff.abs()
    tiny = (abs_diff > 0) & (abs_diff < 1e-10)
    print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")

##

# ---------------------------------------------------------------------------
# diff --git a/total_app.py b/total_app.py  (new file mode 100644, index 0000000..b077ebb)
# Everything below belongs to total_app.py, not audit.py.
# ---------------------------------------------------------------------------
import time
import json
import os
from datetime import datetime
import pandas as pd
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# === CONFIGURATION ===
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
MODEL_NAME = "GPT-OSS-120B"
HEALTH_URL = f"{OPENAI_BASE_URL}/health"  # Placeholder - actual health check would need to be implemented
CHAT_URL = f"{OPENAI_BASE_URL}/chat/completions"
# File paths
INPUT_CSV = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
EDSS_INSTRUCTIONS_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/Komplett.txt"
#GRAMMAR_FILE = "/home/shahin/Lab/Doktorarbeit/Barcelona/attach/just_edss_schema.gbnf"
# Initialize OpenAI client
client = OpenAI(
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_BASE_URL
)
# Read EDSS instructions from file. The encoding is explicit because the
# prompt around it is German text with umlauts and the platform default
# codec is not guaranteed to be UTF-8.
# NOTE(review): assumes Komplett.txt is UTF-8 encoded - confirm.
with open(EDSS_INSTRUCTIONS_PATH, 'r', encoding='utf-8') as f:
    EDSS_INSTRUCTIONS = f.read().strip()


# === RUN INFERENCE 2 ===
def run_inference(patient_text, max_retries=3):
    """Ask the model for the EDSS score and all subcategories of one report.

    Args:
        patient_text: Report text inserted at the end of the prompt.
        max_retries: Extra attempts after the first failure; the wait
            between attempts doubles each time (1 s, 2 s, 4 s, ...).

    Returns:
        ``{"success": True, "result": <parsed JSON>, "inference_time_sec": t}``
        where ``t`` is the duration of the successful attempt only, or
        ``{"success": False, "error": <message>, "inference_time_sec": -1}``
        once every attempt has failed. Never raises for API or parse errors.
    """
    # NOTE(review): the indentation inside this prompt was reconstructed;
    # compare against the committed total_app.py before relying on exact bytes.
    prompt = f'''Du bist ein medizinischer Assistent, der spezialisiert darauf ist, EDSS-Scores (Expanded Disability Status Scale) sowie alle Unterkategorien aus klinischen Berichten zu extrahieren.
### Regeln für die Ausgabe:
1. **Reason**: Erstelle eine prägnante Zusammenfassung (max. 400 Zeichen) der Befunde auf **DEUTSCH**, die zur Einstufung führen.
2. **klassifizierbar**:
   - Setze dies auf **true**, wenn ein EDSS-Wert identifiziert, berechnet oder basierend auf den klinischen Hinweisen plausibel geschätzt werden kann.
   - Setze dies auf **false**, NUR wenn die Daten absolut unzureichend oder so widersprüchlich sind, dass keinerlei Einstufung möglich ist.
3. **EDSS**:
   - Dieses Feld ist **VERPFLICHTEND**, wenn "klassifizierbar" auf true steht.
   - Es muss eine Zahl zwischen 0.0 und 10.0 sein.
   - Versuche stets, den EDSS-Wert so präzise wie möglich zu bestimmen, auch wenn die Datenlage dünn ist (nutze verfügbare Informationen zu Gehstrecke und Funktionssystemen).
   - Dieses Feld **DARF NICHT ERSCHEINEN**, wenn "klassifizierbar" auf false steht.
4. **Unterkategorien**:
   - Extrahiere alle folgenden Unterkategorien aus dem Bericht:
     - VISUAL OPTIC FUNCTIONS (max. 6.0)
     - BRAINSTEM FUNCTIONS (max. 6.0)
     - PYRAMIDAL FUNCTIONS (max. 6.0)
     - CEREBELLAR FUNCTIONS (max. 6.0)
     - SENSORY FUNCTIONS (max. 6.0)
     - BOWEL AND BLADDER FUNCTIONS (max. 6.0)
     - CEREBRAL FUNCTIONS (max. 6.0)
     - AMBULATION (max. 10.0)
   - Jede Unterkategorie sollte eine Zahl zwischen 0.0 und der jeweiligen Obergrenze enthalten, wenn sie klassifizierbar ist
   - Wenn eine Unterkategorie nicht klassifizierbar ist, setze den Wert auf null
### Einschränkungen:
- Erfinde keine Fakten, aber nutze klinische Herleitungen aus dem Bericht, um den EDSS und die Unterkategorien zu bestimmen.
- Priorisiere die Vergabe eines EDSS-Wertes gegenüber der Markierung als nicht klassifizierbar.
- Halte dich strikt an die JSON-Struktur.
- Die Unterkategorien müssen immer enthalten sein, auch wenn sie null sind.
EDSS-Bewertungsrichtlinien:
{EDSS_INSTRUCTIONS}
Patientenbericht:
{patient_text}
'''

    last_error = "no attempt was made"
    # max(..., 0) keeps at least one attempt even for a negative max_retries
    total_attempts = max(max_retries, 0) + 1
    for attempt in range(total_attempts):
        # Timed per attempt so failed tries and backoff sleeps do not inflate
        # the reported inference time
        attempt_start = time.time()
        try:
            response = client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You extract EDSS scores and all subcategories. You prioritize providing values even if data is partial, by using clinical inference."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model=MODEL_NAME,
                max_tokens=2048,
                temperature=0.0,
                response_format={"type": "json_object"}
            )
            content = response.choices[0].message.content

            if content is None or content.strip() == "":
                raise ValueError("API returned empty or None response content")

            parsed = json.loads(content)
            return {
                "success": True,
                "result": parsed,
                "inference_time_sec": time.time() - attempt_start
            }

        except Exception as e:
            # Deliberately broad: transport errors, API errors and malformed
            # JSON are all handled by retrying
            last_error = str(e)
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < total_attempts - 1:
                time.sleep(2 ** attempt)  # Exponential backoff

    print("All retries exhausted.")
    return {
        "success": False,
        "error": last_error,
        "inference_time_sec": -1
    }


# === BUILD PATIENT TEXT ===
def build_patient_text(row):
    """Join the four free-text report fields of ``row`` with newlines.

    Missing or NaN fields contribute an empty string; leading and trailing
    whitespace of the joined text is stripped.
    """
    field_names = ("T_Zusammenfassung", "Diagnosen", "T_KlinBef", "T_Befunde")
    parts = []
    for name in field_names:
        value = row.get(name)
        parts.append(str(value) if pd.notna(value) else "")
    return "\n".join(parts).strip()


def _json_safe(value, fallback=None):
    """Return ``value`` as a JSON-serialisable scalar, or ``fallback`` if missing.

    NaN/None become ``fallback`` (NaN is not valid JSON); numpy scalars are
    unwrapped to plain Python numbers via ``.item()``.
    """
    if value is None or (not isinstance(value, str) and pd.isna(value)):
        return fallback
    return value.item() if hasattr(value, "item") else value


if __name__ == "__main__":
    # Read CSV file ONLY inside main block
    df = pd.read_csv(INPUT_CSV, sep=';')
    results = []
    # Process each row
    for position, (idx, row) in enumerate(df.iterrows(), start=1):
        print(f"Processing row {position}/{len(df)}")
        try:
            result = run_inference(build_patient_text(row))
        except Exception as e:
            print(f"Error processing row {idx}: {e}")
            result = {"success": False, "error": str(e)}
        # Add unique_id and MedDatum to result for tracking
        result["unique_id"] = _json_safe(row.get("unique_id"), fallback=f"row_{idx}")
        result["MedDatum"] = _json_safe(row.get("MedDatum"))
        # Exactly one entry per row: the append sits outside the try so a
        # later failure cannot add the same row a second time
        results.append(result)
        print(json.dumps(result, indent=2, ensure_ascii=False))
    # Save results to a JSON file
    output_json = INPUT_CSV.replace(".csv", "_results_total.json")
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to {output_json}")