# %% 1json (rewritten with robust parsing + detailed data log + Pearson r in plot)
import json

import numpy as np
import pandas as pd
from scipy.stats import pearsonr

# Confidence brackets for the bar chart. pd.cut is used with its default
# right-closed intervals plus include_lowest=True, so the effective bins are
# [0, 70], (70, 80], (80, 90], (90, 100].
_CONF_BIN_EDGES = [0, 70, 80, 90, 100]
_CONF_BIN_LABELS = ["Low (<70%)", "Moderate (70-80%)", "High (80-90%)", "Very High (90-100%)"]

# Textual placeholders that are treated as "no EDSS value".
_MISSING_TOKENS = frozenset({"nan", "none", "null"})

_RULE = "=" * 80


def _norm_str(x):
    """Normalize an identifier or date to a stripped, lower-cased string."""
    return str(x).strip().lower()


def _parse_edss(x):
    """Parse an EDSS value to a number; decimal commas are accepted, junk becomes NaN."""
    if x is None:
        return np.nan
    s = str(x).strip()
    if s == "" or s.lower() in _MISSING_TOKENS:
        return np.nan
    return pd.to_numeric(s.replace(",", "."), errors="coerce")


def _parse_confidence(value):
    """Parse a confidence value to a number; containers and unparsable values become NaN."""
    # pd.to_numeric raises on dicts and returns arrays for lists; a confidence
    # must be a single number, so containers count as "missing".
    if isinstance(value, (list, tuple, set, dict)):
        return np.nan
    return pd.to_numeric(value, errors="coerce")


def _load_ground_truth(ground_truth_path, edss_gt_col):
    """Read the ';'-separated GT CSV, add 'key' and numeric 'EDSS_gt', and log a summary."""
    df_gt = pd.read_csv(ground_truth_path, sep=";")

    required_gt_cols = {"unique_id", "MedDatum", edss_gt_col}
    missing_cols = required_gt_cols - set(df_gt.columns)
    if missing_cols:
        raise ValueError(
            f"GT is missing required columns: {missing_cols}. Available: {df_gt.columns.tolist()}"
        )

    df_gt["unique_id"] = df_gt["unique_id"].map(_norm_str)
    df_gt["MedDatum"] = df_gt["MedDatum"].map(_norm_str)
    df_gt["key"] = df_gt["unique_id"] + "_" + df_gt["MedDatum"]
    df_gt["EDSS_gt"] = df_gt[edss_gt_col].map(_parse_edss)

    gt_missing = df_gt["EDSS_gt"].isna()
    gt_dup_mask = df_gt["key"].duplicated(keep=False)
    gt_dup = gt_dup_mask.sum()

    print("\n--- GT LOG ---")
    print(f"GT rows: {len(df_gt)}")
    print(f"GT unique keys: {df_gt['key'].nunique()}")
    print(f"GT duplicate-key rows: {gt_dup}")
    print(f"GT missing EDSS (numeric): {gt_missing.sum()}")
    print(f"GT missing EDSS unique keys: {df_gt.loc[gt_missing, 'key'].nunique()}")

    if gt_dup > 0:
        print("\n[WARNING] GT has duplicate keys. Merge can duplicate rows. Example duplicate keys:")
        print(df_gt.loc[gt_dup_mask, "key"].value_counts().head(10))

    return df_gt


def _load_predictions(json_file_path):
    """Read the predictions JSON into a DataFrame (possibly empty) and log a summary."""
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total_entries = len(data)
    success_entries = 0
    skipped = {
        "not_success": 0,
        "missing_uid_or_date": 0,
        "missing_edss": 0,
        "missing_conf": 0,
    }
    all_preds = []

    for entry in data:
        # Malformed (non-dict) entries are treated like unsuccessful ones
        # instead of crashing on ``.get``.
        if not (isinstance(entry, dict) and entry.get("success")):
            skipped["not_success"] += 1
            continue
        success_entries += 1

        # ``"result": null`` (or any non-dict result) carries no usable fields.
        res = entry.get("result")
        if not isinstance(res, dict):
            res = {}

        uid = res.get("unique_id")
        md = res.get("MedDatum")
        if uid is None or md is None:
            skipped["missing_uid_or_date"] += 1
            continue
        uid_norm = _norm_str(uid)
        md_norm = _norm_str(md)
        if not uid_norm or not md_norm:
            skipped["missing_uid_or_date"] += 1
            continue

        edss_pred = _parse_edss(res.get("EDSS"))
        conf = _parse_confidence(res.get("certainty_percent"))

        # These two counters are informational: the row is still kept and is
        # only dropped later, when complete rows are selected.
        if pd.isna(edss_pred):
            skipped["missing_edss"] += 1
        if pd.isna(conf):
            skipped["missing_conf"] += 1

        all_preds.append({
            "unique_id": uid_norm,
            "MedDatum": md_norm,
            "key": uid_norm + "_" + md_norm,
            "EDSS_pred": edss_pred,
            "confidence": conf,
        })

    df_pred = pd.DataFrame(all_preds)

    print("\n--- PRED LOG ---")
    print(f"JSON total entries: {total_entries}")
    print(f"JSON success entries: {success_entries}")
    print(f"Pred rows loaded (success + has keys): {len(df_pred)}")
    if len(df_pred) == 0:
        print("[ERROR] No usable prediction rows found. Nothing to plot.")
        return df_pred

    print(f"Pred unique keys: {df_pred['key'].nunique()}")
    print(f"Pred missing EDSS (numeric): {df_pred['EDSS_pred'].isna().sum()}")
    print(f"Pred missing confidence: {df_pred['confidence'].isna().sum()}")
    print("Skipped counts:", skipped)

    key_counts = df_pred["key"].value_counts()
    max_rep = int(key_counts.max())
    print(f"Keys with >1 prediction in this JSON: {(key_counts > 1).sum()}")
    print(f"Max repetitions of a single key in this JSON: {max_rep}")
    if max_rep > 1:
        print("Top repeated keys in this JSON:")
        print(key_counts.head(10))

    return df_pred


def _merge_complete_rows(df_pred, df_gt):
    """Inner-join predictions with GT on 'key' and return rows having all required fields."""
    df_pred["key_in_gt"] = df_pred["key"].isin(set(df_gt["key"]))
    not_in_gt = df_pred.loc[~df_pred["key_in_gt"]]

    print("\n--- KEY MATCH LOG ---")
    print(f"Pred rows with key found in GT: {df_pred['key_in_gt'].sum()} / {len(df_pred)}")
    print(f"Pred rows with key NOT found in GT: {len(not_in_gt)}")
    if len(not_in_gt) > 0:
        print("[WARNING] Some prediction keys are not present in GT. First 10:")
        print(not_in_gt[["unique_id", "MedDatum", "key"]].head(10))

    # validate="many_to_one" makes pandas raise MergeError when GT keys are not
    # unique, so duplicated GT rows cannot silently multiply predictions.
    df_merged = df_pred.merge(
        df_gt[["key", "EDSS_gt"]],
        on="key",
        how="inner",
        validate="many_to_one",
    )

    print("\n--- MERGE LOG ---")
    print(f"Merged rows (inner join): {len(df_merged)}")
    print(f"Merged unique keys: {df_merged['key'].nunique()}")
    print(f"Merged missing GT EDSS: {df_merged['EDSS_gt'].isna().sum()}")
    print(f"Merged missing pred EDSS: {df_merged['EDSS_pred'].isna().sum()}")
    print(f"Merged missing confidence:{df_merged['confidence'].isna().sum()}")

    rows_complete = df_merged.dropna(subset=["EDSS_gt", "EDSS_pred", "confidence"]).copy()
    print("\n--- FILTER LOG (what will be used for stats/plot) ---")
    print(f"Rows with all required fields (EDSS_gt, EDSS_pred, confidence): {len(rows_complete)}")
    return rows_complete


def _pearson_text(rows):
    """Return the multi-line Pearson summary (confidence vs. abs_error) and log it."""
    corr_df = rows.dropna(subset=["confidence", "abs_error"])

    # pearsonr is undefined for fewer than two points or a constant series.
    computable = (
        len(corr_df) >= 2
        and corr_df["confidence"].nunique() > 1
        and corr_df["abs_error"].nunique() > 1
    )
    if computable:
        r_value, p_value = pearsonr(corr_df["confidence"], corr_df["abs_error"])
        corr_text = f"Pearson r = {r_value:.3f}\np = {p_value:.3g}\nn = {len(corr_df)}"
    else:
        corr_text = f"Pearson r = NA\np = NA\nn = {len(corr_df)}"

    print("\n--- CORRELATION LOG ---")
    print(corr_text.replace("\n", " | "))
    return corr_text


def _bin_error_stats(rows, min_bin_count):
    """Bin rows by confidence and return per-bin mean/std/count of abs_error (one row per bin)."""
    labels = _CONF_BIN_LABELS
    rows["conf_bin"] = pd.cut(
        rows["confidence"], bins=_CONF_BIN_EDGES, labels=labels, include_lowest=True
    )
    unbinned = rows["conf_bin"].isna()
    conf_outside = unbinned.sum()
    print(f"Rows with confidence outside [0,100] or outside bin edges: {conf_outside}")
    if conf_outside > 0:
        print("Example confidences outside bins:")
        print(rows.loc[unbinned, "confidence"].head(20).to_list())

    df_plot = rows.dropna(subset=["conf_bin"])
    stats = (
        df_plot.groupby("conf_bin", observed=True)["abs_error"]
        .agg(mean="mean", std="std", count="count")
        .reindex(labels)            # keep empty bins as NaN rows, in display order
        .rename_axis("conf_bin")    # do not rely on reindex preserving the index name
        .reset_index()
    )

    print("\n--- BIN STATS ---")
    print(stats)

    low_bins = stats.loc[stats["count"].fillna(0) < min_bin_count, ["conf_bin", "count"]]
    if not low_bins.empty:
        print(f"\n[WARNING] Some bins have < {min_bin_count} rows; error bars/trend may be unstable:")
        print(low_bins)

    return stats


def _plot_binned_errors(stats, corr_text):
    """Draw the MAE-per-confidence-bin bar chart with SEM bars, trend line and Pearson box."""
    # Imported lazily so the loading/validation paths, which may return early,
    # do not need a plotting stack.
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    labels = _CONF_BIN_LABELS

    plt.figure(figsize=(13, 8))
    colors = sns.color_palette("Blues", n_colors=len(labels))

    means = stats["mean"].to_numpy()
    counts = stats["count"].fillna(0).astype(int).to_numpy()
    stds = stats["std"].to_numpy()
    means_plot = np.nan_to_num(means, nan=0.0)  # empty bins are drawn as zero-height bars

    bars = plt.bar(labels, means_plot, color=colors, edgecolor="black", alpha=0.85)

    # SEM only where a standard deviation exists (n > 1); np.where evaluates the
    # division for every bin, so silence the warnings from empty ones.
    with np.errstate(divide="ignore", invalid="ignore"):
        sem = np.where((counts > 1) & (~np.isnan(stds)), stds / np.sqrt(counts), np.nan)
    plt.errorbar(labels, means_plot, yerr=sem, fmt="none", c="black", capsize=8, elinewidth=1.5)

    # Linear trend over bin positions, fitted on non-empty bins only.
    valid_idx = np.where(~np.isnan(means))[0]
    if len(valid_idx) >= 2:
        x_idx = np.arange(len(labels))
        trend = np.poly1d(np.polyfit(valid_idx, means[valid_idx], 1))
        plt.plot(x_idx, trend(x_idx), color="#e74c3c", linestyle="--", linewidth=3, zorder=5)
        trend_label = "Trend Line"
    else:
        trend_label = "Trend Line (insufficient bins)"
        print("\n[INFO] Not enough non-empty bins to fit a trend line.")

    # Data labels above each bar.
    for bar, mae_val, n_count in zip(bars, means, counts):
        if np.isnan(mae_val) or n_count == 0:
            txt = "empty"
            y = 0.02
        else:
            txt = f"MAE: {mae_val:.2f}\nn={int(n_count)}"
            y = bar.get_height() + 0.04
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            y,
            txt,
            ha="center",
            va="bottom",
            fontweight="bold",
            fontsize=10,
        )

    # Pearson correlation text box inside the plot.
    ax = plt.gca()
    ax.text(
        0.02, 0.98,
        corr_text,
        transform=ax.transAxes,
        ha="left",
        va="top",
        fontsize=11,
        zorder=10,
        bbox=dict(boxstyle="round,pad=0.4", facecolor="white", edgecolor="gray", alpha=0.95),
    )

    # Legend: one patch per bin, then trend line, SEM marker and metric note.
    legend_elements = [
        Patch(facecolor=color, edgecolor="black", label=f"Bin {i}: {label}")
        for i, (color, label) in enumerate(zip(colors, labels), start=1)
    ]
    legend_elements += [
        Line2D([0], [0], color="#e74c3c", linestyle="--", lw=3, label=trend_label),
        Line2D([0], [0], color="black", marker="_", linestyle="None", markersize=10,
               label="Std Error (SEM)"),
        Patch(color="none", label="Metric: Mean Absolute Error (MAE)"),
    ]
    plt.legend(handles=legend_elements, loc="upper right", frameon=True, shadow=True, title="Legend")

    # Title is disabled in this revision:
    # plt.title("Validation: Confidence vs. Error Magnitude (Single JSON)", fontsize=15, pad=30)
    plt.ylabel("Mean Absolute Error (EDSS Points)", fontsize=12)
    plt.xlabel("LLM Confidence Bracket", fontsize=12)
    plt.grid(axis="y", linestyle=":", alpha=0.5)

    ymax = np.nanmax(means) if np.any(~np.isnan(means)) else 0.0
    plt.ylim(0, max(0.5, float(ymax) + 0.6))
    plt.tight_layout()
    plt.show()


def plot_single_json_error_analysis_with_log(
    json_file_path,
    ground_truth_path,
    edss_gt_col="EDSS",
    min_bin_count=5,
):
    """Plot mean absolute EDSS error per confidence bracket for one predictions JSON.

    Predictions are joined to the ground truth on the normalized
    ``unique_id`` + ``"_"`` + ``MedDatum`` key. Every stage (GT load,
    prediction load, key matching, merge, filtering, correlation, binning)
    prints a log section so that data loss is visible.

    Args:
        json_file_path: Path to a JSON file holding a list of entries. Entries
            with a truthy ``"success"`` are used; their ``"result"`` dict is
            read for ``unique_id``, ``MedDatum``, ``EDSS`` and
            ``certainty_percent``. Malformed entries are counted and skipped.
        ground_truth_path: Path to a ``;``-separated CSV with the columns
            ``unique_id``, ``MedDatum`` and ``edss_gt_col``.
        edss_gt_col: Name of the ground-truth EDSS column.
        min_bin_count: Bins with fewer rows than this trigger a warning.

    Returns:
        None. The function prints logs and shows a matplotlib figure; it
        returns early (without plotting) when no usable rows remain.

    Raises:
        ValueError: If the ground-truth CSV lacks a required column.
        pandas.errors.MergeError: If ground-truth keys are not unique
            (the merge is validated as many-to-one).
    """
    print("\n" + _RULE)
    print("SINGLE-JSON ERROR ANALYSIS (WITH LOG)")
    print(_RULE)
    print(f"JSON: {json_file_path}")
    print(f"GT: {ground_truth_path}")

    # 1) Ground truth, 2) predictions
    df_gt = _load_ground_truth(ground_truth_path, edss_gt_col)
    df_pred = _load_predictions(json_file_path)
    if len(df_pred) == 0:
        return

    # 3) Merge and keep complete rows only
    rows_complete = _merge_complete_rows(df_pred, df_gt)
    if len(rows_complete) == 0:
        print("[ERROR] No complete rows after filtering. Nothing to plot.")
        return
    rows_complete["abs_error"] = (rows_complete["EDSS_pred"] - rows_complete["EDSS_gt"]).abs()

    # 4) Row-level Pearson correlation, 5) binning + stats, 6) plot
    corr_text = _pearson_text(rows_complete)
    stats = _bin_error_stats(rows_complete, min_bin_count)
    _plot_binned_errors(stats, corr_text)

    print("\n" + _RULE)
    print("DONE")
    print(_RULE)


# --- RUN ---
# Guarded so that importing this module does not trigger file I/O; when the
# file is executed as a script or as an interactive cell, __name__ is
# "__main__" and the analysis runs as before.
if __name__ == "__main__":
    json_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/iteration/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique_results_iter_1_20260212_020628.json"
    gt_path = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/GT_Numbers.csv"

    plot_single_json_error_analysis_with_log(json_path, gt_path)
##
# %% Functional System Error Boxplots
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# NOTE(review): this cell uses `df`, `functional_systems_to_plot` and
# `safe_parse` without defining them; presumably they come from earlier cells
# of this script -- run those first.

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_boxplot.svg'

# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in functional_systems_to_plot:
    # The system name is the second dot-separated part of the GT column name.
    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth; rows where either side is missing drop out
    error = (res - gt).dropna()

    # Ignore exact matches: the boxes describe the distribution of mistakes only
    error = error[error != 0]

    # Keep only systems that actually have non-zero data
    if len(error) > 0:
        boxplot_data.append(error.values)
        system_labels.append(sys_name.replace('_', ' ').title())
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(14, 8))

bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    showmeans=True,
    meanline=False,
)

# Tick labels are set explicitly instead of through boxplot's `labels=`
# argument, which Matplotlib 3.9 deprecated in favour of `tick_labels`; this
# form works on both older and newer releases. Boxes sit at positions 1..N.
ax.set_xticks(np.arange(1, len(xtick_labels) + 1))
ax.set_xticklabels(xtick_labels, rotation=45, ha='right')

# --- 3. Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)

for whisker in bp['whiskers']:
    whisker.set(color=whisker_col, linewidth=1.2)

for cap in bp['caps']:
    cap.set(color=whisker_col, linewidth=1.2)

for median in bp['medians']:
    median.set(color=median_col, linewidth=2)

for mean in bp['means']:
    mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6)

for flier in bp['fliers']:
    flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge,
              alpha=0.6, markersize=4)

# Reference line at zero error
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

# Labels and formatting
ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Grid and spines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# --- 4. Legend above the plot, outside the axes ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference'),
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

# Leave room at the top for the legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save. os.makedirs('') raises, so only create the directory when the
# save path actually has a directory component.
figure_dir = os.path.dirname(figure_save_path)
if figure_dir:
    os.makedirs(figure_dir, exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')

plt.show()
##