# %% Scatter
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# Replace comma with dot so values written with a decimal comma ("3,5") parse
df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.')
df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.')

# Convert to float; unparseable entries become NaN instead of raising
df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce')
df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce')

# Drop rows where either column is NaN
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Create scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue')

# Add labels and title
plt.xlabel('GT_EDSS')
plt.ylabel('LLM_Results')
plt.title('Comparison of GT_EDSS vs LLM_Results')

# Diagonal reference line (perfect prediction). It runs up to the larger
# maximum of BOTH columns so over-predictions above the largest GT value are
# still covered. Series.max() on empty data gives NaN (no line drawn) instead
# of the ValueError that builtin max() over an empty Series raises.
diag_max = max(df_clean['GT_EDSS'].max(), df_clean['LLM_Results'].max())
plt.plot([0, diag_max], [0, diag_max], color='red', linestyle='--', label='Perfect Prediction')
plt.legend()

# Show plot
plt.grid(True)
plt.tight_layout()
plt.show()
##

# %% Bland-Altman
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# Replace comma with dot for numeric conversion in GT_EDSS and LLM_Results
df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.')
df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.')

# Convert to float; unparseable entries become NaN
df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce')
df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce')

# Drop rows where either column is NaN
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Create Bland-Altman plot (differences are GT_EDSS - LLM_Results)
f, ax = plt.subplots(1, figsize=(8, 5))
sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax)

# Add labels and title
ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results')
ax.set_xlabel('Mean of GT_EDSS and LLM_Results')
ax.set_ylabel('Difference between GT_EDSS and LLM_Results')

# Display Bland-Altman plot
plt.tight_layout()
plt.show()

# Print the statistics behind the plot. np.std defaults to ddof=0
# (population SD), the same estimator statsmodels' mean_diff_plot uses for
# its limit lines, so the printed limits match the plotted ones.
diffs = df_clean['GT_EDSS'] - df_clean['LLM_Results']
mean_diff = np.mean(diffs)
std_diff = np.std(diffs)
print(f"Mean difference: {mean_diff:.3f}")
print(f"Standard deviation of differences: {std_diff:.3f}")
print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]")
##

# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Convert to float; unparseable entries become NaN
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows where either column is NaN. .copy() makes df_clean an independent
# frame, so adding the category columns below does not write into a view of df.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# One-point-wide EDSS bins shown in the matrix (the '10+' overflow bin is not)
EDSS_LABELS = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']


def categorize_edss(value):
    """Map a numeric EDSS value to its bin label.

    Bins are right-closed: anything <= 1.0 (including negatives) is '0-1',
    1.5 and 2.0 are '1-2', ..., 10.0 is '9-10'. Values above 10 map to '10+'.
    NaN is returned unchanged as np.nan.
    """
    if pd.isna(value):
        return np.nan
    for upper, label in zip(range(1, 11), EDSS_LABELS):
        if value <= upper:
            return label
    return '10+'


# Create categorical versions
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)

# Remove any NaN categories
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Create confusion matrix. labels= selects a subset, so rows where either side
# is '10+' are not counted here (they still appear in the classification report).
cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], labels=EDSS_LABELS)
outside = ~(df_clean['GT.EDSS_cat'].isin(EDSS_LABELS) & df_clean['result.EDSS_cat'].isin(EDSS_LABELS))
if outside.any():
    print(f"Note: {outside.sum()} row(s) have an EDSS above 10 and are not counted in the confusion matrix.")

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS)
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))

# Print raw counts
print("\nConfusion Matrix (Raw Counts):")
print(cm)
##

# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Convert to float; unparseable entries become NaN
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows where either column is NaN (.copy(): see the previous cell)
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# Redefined so this cell also runs on its own
EDSS_LABELS = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']


def categorize_edss(value):
    """Map a numeric EDSS value to its right-closed 1-point bin label ('10+' above 10)."""
    if pd.isna(value):
        return np.nan
    for upper, label in zip(range(1, 11), EDSS_LABELS):
        if value <= upper:
            return label
    return '10+'


# Create categorical versions
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)

# Remove any NaN categories
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Create confusion matrix ('10+' rows are excluded by labels=, see note print)
cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], labels=EDSS_LABELS)
outside = ~(df_clean['GT.EDSS_cat'].isin(EDSS_LABELS) & df_clean['result.EDSS_cat'].isin(EDSS_LABELS))
if outside.any():
    print(f"Note: {outside.sum()} row(s) have an EDSS above 10 and are not counted in the confusion matrix.")

# Plot confusion matrix
plt.figure(figsize=(10, 8))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS)

# Label the colour bar so the scale explains itself
cbar = ax.collections[0].colorbar
cbar.set_label('Number of Cases', rotation=270, labelpad=20)

plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
#plt.title('Confusion Matrix: Ground truth EDSS vs inferred EDSS (Categorized 0-10)')
plt.tight_layout()
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))

# Print raw counts
print("\nConfusion Matrix (Raw Counts):")
print(cm)
##

# %% Classification
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Check data structure
print("Data shape:", df.shape)
print("First few rows:")
print(df.head())
print("\nColumn names:")
for col in df.columns:
    print(f" {col}")


def safe_bool_convert(series):
    '''Convert a flag column of mixed spellings to a bool-dtype Series.

    Values are stringified, stripped and lower-cased before lookup. Integer
    flags that pandas parsed as floats (because the column has gaps) show up
    as '1.0'/'0.0' after astype(str), so those spellings are accepted too.
    Anything unrecognised, including missing values ('nan'), becomes False
    (flip the default below if True is preferred).
    '''
    series_str = series.astype(str).str.strip().str.lower()
    bool_map = {
        'true': True, '1': True, '1.0': True, 'yes': True, 'y': True,
        'false': False, '0': False, '0.0': False, 'no': False, 'n': False,
    }
    # .get(..., False) avoids the object-dtype fillna/downcast path entirely
    return series_str.map(lambda s: bool_map.get(s, False)).astype(bool)


# Convert columns safely
for flag_col in ('result.klassifizierbar', 'GT.klassifizierbar'):
    if flag_col in df.columns:
        print(f"\n{flag_col} column info:")
        print(df[flag_col].head(10))
        print("Unique values:", df[flag_col].unique())
        df[flag_col] = safe_bool_convert(df[flag_col])
        print("After conversion:")
        print(df[flag_col].value_counts())

has_both = 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns

# Create bar chart showing only True values for klassifizierbar
if has_both:
    # Get counts for True values only
    llm_true_count = df['result.klassifizierbar'].sum()
    gt_true_count = df['GT.klassifizierbar'].sum()

    fig, ax = plt.subplots(figsize=(8, 6))
    x = np.arange(2)
    width = 0.35
    # One bar per tick, centred on it so the value labels below sit on top
    ax.bar(x[0], llm_true_count, width, label='LLM', color='skyblue', alpha=0.8)
    ax.bar(x[1], gt_true_count, width, label='GT', color='lightcoral', alpha=0.8)

    # Add value labels on bars
    ax.annotate(f'{llm_true_count}', xy=(x[0], llm_true_count), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom')
    ax.annotate(f'{gt_true_count}', xy=(x[1], gt_true_count), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom')

    ax.set_xlabel('Classification Status (klassifizierbar)')
    ax.set_ylabel('Count')
    ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"')
    ax.set_xticks(x)
    ax.set_xticklabels(['LLM', 'GT'])
    ax.legend()
    plt.tight_layout()
    plt.show()

# Create confusion matrix if both columns exist
if has_both:
    try:
        # Ensure both columns are boolean
        llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool)
        gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool)

        # Fixed labels keep the matrix 2x2 (matching the tick labels) even
        # when only one class occurs in the data.
        cm = confusion_matrix(gt_bool, llm_bool, labels=[False, True])

        # Plot confusion matrix
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['False ', 'True '], yticklabels=['False', 'True '], ax=ax)
        ax.set_xlabel('LLM Predictions ')
        ax.set_ylabel('GT Labels ')
        ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"')
        plt.tight_layout()
        plt.show()

        print("Confusion Matrix:")
        print(cm)
    except Exception as e:
        print(f"Error creating confusion matrix: {e}")

# Show final data info. DataFrame.info() prints itself and returns None, so
# it is not wrapped in print(); only the flag columns that exist are shown.
print("\nFinal DataFrame info:")
flag_cols = [c for c in ('result.klassifizierbar', 'GT.klassifizierbar') if c in df.columns]
df[flag_cols].info()
##

# %% Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load your data from TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/results/join_results_unique.tsv'
df = pd.read_csv(file_path, sep='\t')

# Replace comma with dot for numeric conversion in GT.EDSS and result.EDSS
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')

# Convert to float; unparseable entries become NaN
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Drop rows where either column is NaN; .copy() so the category column added
# below is written to an independent frame, not a view of df.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# 1. DEFINE CATEGORY ORDER
# This ensures the X-axis is numerically logical (0-1 comes before 1-2)
category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+']


def categorize_edss(value):
    """Map an EDSS value to its right-closed 1-point bin label ('10+' above 10).

    Defined here too so this cell runs on its own; previously it relied on the
    definition from the confusion-matrix cells.
    """
    if pd.isna(value):
        return np.nan
    for upper, label in zip(range(1, 11), category_order):
        if value <= upper:
            return label
    return '10+'


# Convert the column to a Categorical type with the specific order
df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss), categories=category_order, ordered=True)

plt.figure(figsize=(14, 8))

# 2. ADD HUE FOR LEGEND
# Using the x variable as hue as well lets Seaborn build the legend automatically
sns.boxplot(
    data=df_clean,
    x='GT.EDSS_cat',
    y='result.EDSS',
    hue='GT.EDSS_cat',
    palette='viridis',
    linewidth=1.5,
    legend=True,  # hue duplicates x, so force the legend on
)

# 3. CUSTOMIZE PLOT
plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20)
plt.xlabel('Ground Truth EDSS Category', fontsize=14)
plt.ylabel('LLM Predicted EDSS', fontsize=14)

# Move legend to the side or top
plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
##

# %% Postprocessing Column names
import pandas as pd

# Read the TSV file (it is overwritten in place below)
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Map German ground-truth column names to the English GT.* names.
# NOTE(review): 'Sensibiliät' looks like a misspelling of 'Sensibilität'; it
# must match the TSV header exactly, so confirm before changing it.
column_mapping = {
    'EDSS': 'GT.EDSS',
    'klassifizierbar': 'GT.klassifizierbar',
    'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS',
    'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS',
    'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS',
    'Sensibiliät': 'GT.SENSORY_FUNCTIONS',
    'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS',
    'Ambulation': 'GT.AMBULATION',
    'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS',
    'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS',
}

# Record which source columns are present BEFORE renaming. After rename()
# the old names are gone, so checking afterwards would never report anything.
renamed = {old: new for old, new in column_mapping.items() if old in df.columns}

# Rename columns (missing keys are ignored, so a second run is a no-op)
df = df.rename(columns=column_mapping)

# Save the modified dataframe back to the same TSV file
df.to_csv(file_path, sep='\t', index=False)

print("Columns have been successfully renamed!")
print("Renamed columns:")
for old_name, new_name in renamed.items():
    print(f" {old_name} -> {new_name}")
##

# %% Styled table
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import dataframe_image as dfi  # NOTE(review): not used by this cell's visible code

# Load data
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')

# 1. Identify all GT and result columns
gt_columns = [col for col in df.columns if col.startswith('GT.')]
result_columns = [col for col in df.columns if col.startswith('result.')]

print("GT Columns found:", gt_columns)
print("Result Columns found:", result_columns)
# 2. Create proper mapping between GT and result columns
# The two sides may spell a subcategory differently (spaces vs underscores,
# letter case), so several exact spellings are tried in priority order before
# falling back to a looser comparison.


def _keep_word_chars(text):
    """Return *text* with every character except letters, digits, '_' and ' ' removed."""
    return ''.join(ch for ch in text if ch.isalnum() or ch in ['_', ' '])


column_mapping = {}
for gt_col in gt_columns:
    stem = gt_col.replace('GT.', '')

    # Exact candidates, tried in this order; the first one present wins
    spellings = (
        stem,                    # as-is
        stem.replace(" ", "_"),  # spaces -> underscores
        stem.replace("_", " "),  # underscores -> spaces
        stem.replace(" ", ""),   # spaces removed
        stem.replace("_", ""),   # underscores removed
        stem.lower(),
        stem.upper(),
    )
    exact_hit = next(
        (f'result.{s}' for s in spellings if f'result.{s}' in result_columns),
        None,
    )
    if exact_hit is not None:
        column_mapping[gt_col] = exact_hit
        continue

    # Fallback: ignore punctuation and case, take the first result column that agrees
    wanted = _keep_word_chars(stem).lower()
    loose_hit = next(
        (rc for rc in result_columns
         if _keep_word_chars(rc.replace('result.', '')).lower() == wanted),
        None,
    )
    if loose_hit is not None:
        column_mapping[gt_col] = loose_hit

print("Column mapping:", column_mapping)
# 3. Vectorized agreement computation using the column mapping above


def _to_float(series):
    """Coerce *series* to float, accepting decimal commas in text ('1,5').

    Numeric and boolean columns go straight to pd.to_numeric, as before.
    Only string elements get the comma replaced. Anything unparseable
    becomes NaN.
    """
    if not pd.api.types.is_numeric_dtype(series):
        series = series.map(lambda v: v.replace(',', '.') if isinstance(v, str) else v)
    return pd.to_numeric(series, errors='coerce').astype(float)


data_list = []
for gt_col, result_col in column_mapping.items():
    print(f"Processing {gt_col} vs {result_col}")

    s1 = _to_float(df[gt_col])
    s2 = _to_float(df[result_col])

    # A row "matches" when both values exist and differ by at most 0.5
    diff = (s1 - s2).abs()
    matches = (diff <= 0.5).sum()

    # Denominator: rows where both sides are numeric
    valid_count = diff.notna().sum()
    # With no comparable rows the agreement is undefined; NaN (shown as an
    # empty cell) is more honest than reporting 0 %.
    percentage = (matches / valid_count) * 100 if valid_count > 0 else np.nan

    # Extract clean base name for display
    base_name = gt_col.replace('GT.', '')
    data_list.append({
        'GT': base_name,
        'Match %': round(percentage, 1),
    })

# 4. Prepare Data (explicit columns keep this valid even if nothing matched)
match_df = pd.DataFrame(data_list, columns=['GT', 'Match %'])

# Clean up labels: Replace underscores with spaces and capitalize
match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title()
match_df = match_df.sort_values('Match %', ascending=False)
# 5. Create a "Beautiful" Table using Seaborn Heatmap
def create_luxury_table(df, output_file="edss_agreement.png"):
    """Render the agreement table as a single-column heatmap and save it.

    Args:
        df: DataFrame with a 'GT' label column and a 'Match %' value column.
        output_file: Path of the image written at 300 dpi.

    Returns:
        The matplotlib Figure, so callers can export it in further formats.
    """
    # Set the aesthetic style (note: sns.set_theme changes global rcParams)
    sns.set_theme(style="white", font="sans-serif")

    # Prepare data for heatmap
    plot_data = df.set_index('GT')[['Match %']]

    # Height grows with the number of rows; the floor keeps an empty or very
    # short table from producing a zero-height figure.
    fig, ax = plt.subplots(figsize=(8, max(len(df) * 0.6, 2)))

    # Diverging map from red (hue 15) through a light centre to green (hue 135)
    cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True)

    sns.heatmap(
        plot_data,
        annot=True,
        fmt=".1f",
        cmap=cmap,
        center=85,           # 85 % sits at the light midpoint of the colour map
        vmin=50, vmax=100,   # Range of the gradient
        linewidths=2,
        linecolor='white',
        cbar=False,          # Remove color bar for a "table" look
        annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"},
        ax=ax,
    )

    # Style the axes so the heatmap reads as a table
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.xaxis.tick_top()  # Move the column header to the top
    ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50')
    ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0)

    # Add a thin border around the plot
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('#ecf0f1')

    ax.set_title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50')

    # Add a subtle footer
    fig.text(0.5, 0.0, "Tolerance: ±0.5 points", wrap=True,
             horizontalalignment='center', fontsize=10, color='gray', style='italic')

    # Save with high resolution
    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Beautiful table saved as {output_file}")
    return fig


# Execute (the figure stays current, so the SVG export below reuses it)
create_luxury_table(match_df)
# 6. Save as SVG (writes the current figure, i.e. the agreement table)
plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
print("Successfully saved agreement_table.svg")

# Show plot if running in a GUI environment
plt.show()
##

# %% Time Plot
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Load the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Numeric inference times only: text entries (including decimal commas,
# handled as in the other cells) are parsed, invalid ones and NaN are dropped.
inference_times = pd.to_numeric(
    df['inference_time_sec'].astype(str).str.replace(',', '.', regex=False),
    errors='coerce',
).dropna()
if inference_times.empty:
    raise ValueError("No numeric values in column 'inference_time_sec'")

# Calculate statistics (Series.std uses the sample SD, ddof=1)
mean_time = inference_times.mean()
std_time = inference_times.std()
median_time = np.median(inference_times)

fig, ax = plt.subplots(figsize=(10, 6))

# Histogram with 1-second-wide bins covering the whole data range
min_time = int(inference_times.min())
max_time = int(inference_times.max()) + 1
bins = np.arange(min_time, max_time + 1, 1)

# Counts per bin (not probability density)
_, bins, _ = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7, edgecolor='black', linewidth=0.5)

# Gaussian fit, scaled from a density to counts: pdf * N * bin width
x = np.linspace(inference_times.min(), inference_times.max(), 100)
gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * (bins[1] - bins[0])
ax.plot(x, gaussian_counts, color='red', linewidth=2, label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)')

# Vertical lines for mean and median
ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s')
ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s')

# ±1 standard deviation around the mean
ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'+1σ = {mean_time + std_time:.1f}s')
ax.axvline(mean_time - std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'-1σ = {mean_time - std_time:.1f}s')

ax.set_xlabel('Inference Time (seconds)')
ax.set_ylabel('Frequency')
ax.set_title('Inference Time Distribution with Gaussian Fit')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
##

# %% Dashboard
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from matplotlib.gridspec import GridSpec


def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Parse *s* as numbers, accepting both '1.5' and '1,5'; invalid values become NaN."""
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")


# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Rename columns to remove 'result.' prefix and replace spaces
column_mapping = {
    col: col.replace('result.', '').replace(' ', '_')
    for col in df.columns
    if col.startswith('result.')
}
df = df.rename(columns=column_mapping)

# Parse MedDatum; unparseable dates become NaT.
# NOTE(review): no format/dayfirst is given, so pandas infers the format;
# confirm this is right for the dates stored in the TSV.
df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce')

# Patient
patient_id = 'd13e4aa3'
patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy()
if patient_data.empty:
    raise ValueError(f"No data found for patient: {patient_id}")

# Functional systems + EDSS: (column, panel title)
edss_col, edss_title = ('GT.EDSS', 'EDSS')
functional_systems = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'),
    ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'),
    ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'),
    ('GT.SENSORY_FUNCTIONS', 'Sensory'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'),
    ('GT.AMBULATION', 'Ambulation'),
    ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'),
]

# y-axis max rules
ymax_by_col = {
    'GT.PYRAMIDAL_FUNCTIONS': 6,
    'GT.SENSORY_FUNCTIONS': 6,
    'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6,
    'GT.VISUAL_OPTIC_FUNCTIONS': 6,
    'GT.CEREBELLAR_FUNCTIONS': 5,
    'GT.CEREBRAL_FUNCTIONS': 5,
    'GT.BRAINSTEM_FUNCTIONS': 5,
    'GT.EDSS': 10,
}
default_ymax = 6 # ---------- Build shared "event dates" ticks ---------- cols_for_dates = [edss_col] + [c for c, _ in functional_systems] event_dates = [] for c in cols_for_dates: if c in patient_data.columns: y = to_numeric_comma(patient_data[c]) # <-- changed x = patient_data['MedDatum'] tmp = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]) event_dates.extend(tmp["x"].tolist()) event_dates = sorted(pd.Series(event_dates).drop_duplicates().tolist()) max_ticks = 8 if len(event_dates) > max_ticks: idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) event_dates = [event_dates[i] for i in idx] # ---------- A4 figure ---------- fig = plt.figure(figsize=(11.69, 8.27)) gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) def style_time_axis(ax, show_labels=True): ax.set_xticks(event_dates) ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) if not show_labels: ax.tick_params(labelbottom=False) # ---------- EDSS main plot ---------- ax_main = fig.add_subplot(gs[0, :]) if edss_col in patient_data.columns: y = to_numeric_comma(patient_data[edss_col]) # <-- changed x = patient_data['MedDatum'] plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") ax_main.set_title(edss_title, fontsize=14, fontweight='bold') ax_main.set_ylabel("Score") ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) ax_main.grid(True, alpha=0.3) if not plot_df.empty: ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') else: ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') else: ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') ax_main.set_ylim(0, ymax_by_col.get(edss_col, 10)) ax_main.grid(True, alpha=0.3) style_time_axis(ax_main) # ---------- Small aligned plots ---------- small_axes = [] for k, (col, title) in enumerate(functional_systems): r 
= 1 + (k // 4) c = (k % 4) ax = fig.add_subplot(gs[r, c], sharex=ax_main) small_axes.append(ax) ymax = ymax_by_col.get(col, default_ymax) ax.set_title(title, fontsize=10) ax.set_ylabel("Score") ax.set_ylim(0, ymax) ax.grid(True, alpha=0.3) if col in patient_data.columns: y = to_numeric_comma(patient_data[col]) # <-- changed x = patient_data['MedDatum'] plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") if not plot_df.empty: ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue') else: ax.set_title(f"{title} (no data)", fontsize=10) else: ax.set_title(f"{title} (missing)", fontsize=10) style_time_axis(ax) # Hide x tick labels on first row of small plots for ax in small_axes[:4]: ax.tick_params(labelbottom=False) plt.tight_layout() fig.subplots_adjust(hspace=0.7) plt.show() ## <<<<<<< Updated upstream ======= # %% Dashboard Angepasst import pandas as pd import matplotlib.pyplot as plt import matplotlib.dates as mdates import numpy as np from matplotlib.gridspec import GridSpec def to_numeric_comma(s: pd.Series) -> pd.Series: # accepts 1.5 and 1,5 return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce") # Load the data file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' df = pd.read_csv(file_path, sep='\t') # Rename columns to remove 'result.' 
prefix and replace spaces column_mapping = {} for col in df.columns: if col.startswith('result.'): new_name = col.replace('result.', '').replace(' ', '_') column_mapping[col] = new_name df = df.rename(columns=column_mapping) # Parse MedDatum safely df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce') # Patient patient_id = '3d942c60' patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy() if patient_data.empty: raise ValueError(f"No data found for patient: {patient_id}") # Functional systems + EDSS edss_col, edss_title = ('GT.EDSS', 'EDSS') functional_systems = [ ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'), ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'), ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'), ('GT.SENSORY_FUNCTIONS', 'Sensory'), ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'), ('GT.AMBULATION', 'Ambulation'), ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'), ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'), ] # y-axis max rules ymax_by_col = { 'GT.PYRAMIDAL_FUNCTIONS': 6, 'GT.SENSORY_FUNCTIONS': 6, 'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6, 'GT.VISUAL_OPTIC_FUNCTIONS': 6, 'GT.CEREBELLAR_FUNCTIONS': 5, 'GT.CEREBRAL_FUNCTIONS': 5, 'GT.BRAINSTEM_FUNCTIONS': 5, 'GT.EDSS': 10, } default_ymax = 6 # ---------- Build shared visit dates ticks ---------- # Use ALL patient visit dates, not only dates with valid numeric values event_dates = sorted(patient_data['MedDatum'].dropna().drop_duplicates().tolist()) max_ticks = 8 if len(event_dates) > max_ticks: idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) event_dates = [event_dates[i] for i in idx] # Base timeline for plotting: one row per patient visit date timeline = ( patient_data[['MedDatum']] .dropna() .drop_duplicates() .sort_values('MedDatum') .rename(columns={'MedDatum': 'x'}) ) # ---------- A4 figure ---------- fig = plt.figure(figsize=(11.69, 8.27)) gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) def style_time_axis(ax, 
show_labels=True): ax.set_xticks(event_dates) ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) if not show_labels: ax.tick_params(labelbottom=False) def get_plot_df(patient_data, col): """ Keep all visit dates. Missing values stay NaN so matplotlib draws gaps instead of zeros. """ tmp = patient_data[['MedDatum', col]].copy() tmp = tmp.rename(columns={'MedDatum': 'x', col: 'raw_y'}) tmp['y'] = to_numeric_comma(tmp['raw_y']) # aggregate if multiple rows exist on same date tmp = tmp.groupby('x', as_index=False)['y'].max() # merge onto full timeline so all dates remain visible plot_df = timeline.merge(tmp, on='x', how='left').sort_values('x') return plot_df # ---------- EDSS main plot ---------- ax_main = fig.add_subplot(gs[0, :]) ax_main.set_title(edss_title, fontsize=14, fontweight='bold') ax_main.set_ylabel("Score") ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) ax_main.grid(True, alpha=0.3) if edss_col in patient_data.columns: plot_df = get_plot_df(patient_data, edss_col) if plot_df['y'].notna().any(): # NaNs create visible gaps in the line ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') else: ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') else: ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') style_time_axis(ax_main) # ---------- Small aligned plots ---------- small_axes = [] for k, (col, title) in enumerate(functional_systems): r = 1 + (k // 4) c = (k % 4) ax = fig.add_subplot(gs[r, c], sharex=ax_main) small_axes.append(ax) ymax = ymax_by_col.get(col, default_ymax) ax.set_title(title, fontsize=10) ax.set_ylabel("Score") ax.set_ylim(0, ymax) ax.grid(True, alpha=0.3) if col in patient_data.columns: plot_df = get_plot_df(patient_data, col) if plot_df['y'].notna().any(): # NaNs remain in y -> line breaks where data is missing ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, 
color='tab:blue') else: ax.set_title(f"{title} (no numeric data)", fontsize=10) else: ax.set_title(f"{title} (missing)", fontsize=10) style_time_axis(ax) # Hide x tick labels on first row of small plots for ax in small_axes[:4]: ax.tick_params(labelbottom=False) plt.tight_layout() fig.subplots_adjust(hspace=0.7) plt.show() ## >>>>>>> Stashed changes # %% Table import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from datetime import datetime import numpy as np # Load the data file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' df = pd.read_csv(file_path, sep='\t') # Convert MedDatum to datetime df['MedDatum'] = pd.to_datetime(df['MedDatum']) # Check what columns actually exist in the dataset print("Available columns:") print(df.columns.tolist()) print("\nFirst few rows:") print(df.head()) # Check data types print("\nData types:") print(df.dtypes) # Hardcode specific patient names patient_names = ['6ccda8c6'] # Define the functional systems (columns to plot) functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual'] # Create subplots num_plots = len(functional_systems) num_cols = 2 num_rows = (num_plots + num_cols - 1) // num_cols fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False) if num_plots == 1: axes = [axes] elif num_rows == 1: axes = axes else: axes = axes.flatten() # Plot for the hardcoded patient for i, system in enumerate(functional_systems): # Filter data for this specific patient patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum') # Check if patient data exists if patient_data.empty: print(f"No data found for patient: {patient_names[0]}") axes[i].set_title(f'Functional System: {system} (No data)') axes[i].set_ylabel('Score') continue # Check if the system column exists if system in patient_data.columns: # Plot only valid data (non-null values) valid_data = 
patient_data.dropna(subset=[system]) if not valid_data.empty: # Ensure MedDatum is properly formatted for plotting axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system) axes[i].set_ylabel('Score') axes[i].set_title(f'Functional System: {system}') axes[i].grid(True, alpha=0.3) axes[i].legend() else: axes[i].set_title(f'Functional System: {system} (No valid data)') axes[i].set_ylabel('Score') else: # Try to find similar column names found_column = None for col in df.columns: if system.lower() in col.lower(): found_column = col break if found_column: valid_data = patient_data.dropna(subset=[found_column]) if not valid_data.empty: axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column) axes[i].set_ylabel('Score') axes[i].set_title(f'Functional System: {system} (found as: {found_column})') axes[i].grid(True, alpha=0.3) axes[i].legend() else: axes[i].set_title(f'Functional System: {system} (No valid data)') axes[i].set_ylabel('Score') else: axes[i].set_title(f'Functional System: {system} (Column not found)') axes[i].set_ylabel('Score') # Hide empty subplots for i in range(len(functional_systems), len(axes)): axes[i].set_visible(False) # Set x-axis label for the last row only for i in range(len(functional_systems)): if i >= len(axes) - num_cols: # Last row axes[i].set_xlabel('Date') # Format x-axis dates for ax in axes: if ax.get_lines(): # Only format if there are lines to plot ax.tick_params(axis='x', rotation=45) ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d')) # Automatically adjust layout plt.tight_layout() plt.show() ## # %% Histogram Fig1 import pandas as pd import matplotlib.pyplot as plt import matplotlib.font_manager as fm import json import os def create_visit_frequency_plot( file_path, output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data', output_filename='visit_frequency_distribution.svg', fontsize=10, color_scheme_path='colors.json' 
): """ Creates a publication-ready bar chart of patient visit frequency. Args: file_path (str): Path to the input TSV file. output_dir (str): Directory to save the output SVG file. output_filename (str): Name of the output SVG file. fontsize (int): Font size for all text elements (labels, title). color_scheme_path (str): Path to the JSON file containing the color palette. """ # --- 1. Load Data and Color Scheme --- try: df = pd.read_csv(file_path, sep='\t') print("Data loaded successfully.") # Sort data for easier visual comparison df = df.sort_values(by='Visits Count') except FileNotFoundError: print(f"Error: The file was not found at {file_path}") return try: with open(color_scheme_path, 'r') as f: colors = json.load(f) # Select a blue from the sequential palette for the bars bar_color = colors['sequential']['blues'][-2] # A saturated blue except FileNotFoundError: print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.") bar_color = '#2171b5' # A common matplotlib blue # --- 2. Set up the Plot with Scientific Style --- plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height # Set the font to Arial arial_font = fm.FontProperties(family='Arial', size=fontsize) plt.rcParams['font.family'] = 'Arial' plt.rcParams['font.size'] = fontsize # --- 3. Create the Bar Chart --- ax = plt.gca() bars = plt.bar( x=df['Visits Count'], height=df['Unique Patients'], color=bar_color, edgecolor='black', linewidth=0.5, # Minimum line thickness width=0.7 ) # --- NEW: Explicitly set x-ticks and labels to ensure all are shown --- # Get the unique visit counts to use as tick labels visit_counts = df['Visits Count'].unique() # Set the x-ticks to be at the center of each bar ax.set_xticks(visit_counts) # Set the x-tick labels to be the visit counts, using the specified font ax.set_xticklabels(visit_counts, fontproperties=arial_font) # --- END OF NEW SECTION --- # --- 4. 
Customize Axes and Layout (Nature style) --- # Display only left and bottom axes ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) # Turn off axis ticks (the marks, not the labels) plt.tick_params(axis='both', which='both', length=0) # Remove grid lines plt.grid(False) # Set background to white (no shading) ax.set_facecolor('white') plt.gcf().set_facecolor('white') # --- 5. Add Labels and Title --- plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10) plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10) plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20) # --- 6. Add y-axis values on top of each bar --- # This adds the count of unique patients directly above each bar. ax.bar_label(bars, fmt='%d', padding=3) # --- 7. Export the Figure --- # Ensure the output directory exists os.makedirs(output_dir, exist_ok=True) full_output_path = os.path.join(output_dir, output_filename) plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight') print(f"\nFigure saved as '{full_output_path}'") # --- 8. 
# --- 8. (Optional) Display the Plot ---
# plt.show()  # left disabled: the function only writes the SVG to disk


# --- Main execution ---
if __name__ == '__main__':
    # Visit-frequency table read by create_visit_frequency_plot
    # (columns 'Visits Count' and 'Unique Patients').
    input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'
    # 10 pt text per the figure guidelines; every other option keeps its default.
    create_visit_frequency_plot(file_path=input_file, fontsize=10)
##
# %% Scatter Plot functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'  # Arial for every text element, per the guidelines

data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'                    # input table
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'  # color map output
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'                   # figure output

# --- 1. Load the Dataset ---
# A missing input file is reported and then re-raised, so the cell stops here.
try:
    df = pd.read_csv(data_path, sep='\t')
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    raise
else:
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
# --- 2. Define Functional Systems and Create Color Mapping ---
# Each entry pairs a ground-truth column with the matching LLM-result column.
# Note the naming difference: GT columns use underscores, result columns use spaces.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]

# Short system keys used for colors and legend labels, e.g. 'GT.AMBULATION' -> 'AMBULATION'
system_names = [gt_column.split('.')[1] for gt_column, _result_column in functional_systems_to_plot]

# One color per system, in the same order as functional_systems_to_plot:
# four blues followed by four warm accents.
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC',  # Purple
]
color_map = {name: hex_color for name, hex_color in zip(system_names, colors)}

# Persist the mapping as JSON (presumably so other figures can reuse the same colors).
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# --- 4. Reshape Data for Plotting ---
# Both steps share a single parse per system: each column pair is converted
# once, and the same paired rows feed the agreement rate and the scatter data.


def _to_numeric_comma(series):
    """Convert a column to numbers, accepting decimal commas ('3,5' -> 3.5).

    Entries that still cannot be parsed become NaN. Plain
    ``pd.to_numeric(..., errors='coerce')`` would turn every comma-decimal
    string into NaN and silently drop those rows from the comparison.
    """
    cleaned = series.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(cleaned, errors='coerce')


agreement_percentages = {}
legend_labels = {}
plot_data = []

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    gt_numeric = _to_numeric_comma(df[gt_col])
    res_numeric = _to_numeric_comma(df[res_col])

    # Keep only rows where both sides parsed; these are compared and plotted.
    temp_df = pd.DataFrame({
        'system': system_name,
        'ground_truth': gt_numeric,
        'inference': res_numeric,
    }).dropna()

    # Exact-match rate in percent; 0 when the system has no comparable rows.
    if len(temp_df) > 0:
        agreement = (temp_df['ground_truth'] == temp_df['inference']).mean() * 100
    else:
        agreement = 0
    agreement_percentages[system_name] = agreement

    # "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions (xx.x%)"
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"

    plot_data.append(temp_df)

# One long table: system / ground_truth / inference
plot_df = pd.concat(plot_data, ignore_index=True)

if plot_df.empty:
    print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
else:
    print(f"Prepared plot data with {len(plot_df)} data points.")
# --- 5. Create the Scatter Plot ---
plt.figure(figsize=(10, 8))

# Draw systems in the configured order. groupby('system') would sort them
# alphabetically, so the legend would no longer follow the palette order.
for system in system_names:
    group = plot_df[plot_df['system'] == system]
    if group.empty:
        continue  # nothing parsed for this system: keep it out of the legend
    plt.scatter(
        group['ground_truth'],
        group['inference'],
        label=legend_labels[system],
        color=color_map[system],
        alpha=0.7,
        s=30
    )

# Diagonal y = x marks perfect agreement. It spans the range of BOTH axes,
# so it is not cut short when inferences fall outside the ground-truth range.
if not plot_df.empty:
    axis_min = min(plot_df['ground_truth'].min(), plot_df['inference'].min())
    axis_max = max(plot_df['ground_truth'].max(), plot_df['inference'].max())
    plt.plot(
        [axis_min, axis_max],
        [axis_min, axis_max],
        color='black',
        linestyle='--',
        linewidth=0.8,
        alpha=0.7
    )

# --- 6. Apply Styling and Labels ---
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('LLM Inference', fontsize=12)
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)

# Scientific style: only left/bottom spines, no tick marks, no grid
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)
ax.grid(False)

plt.legend(title='Functional System', frameon=False, fontsize=10)

# --- 7. Save and Display the Figure ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")

plt.show()
##
# %% Confusion Matrix functional systems
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np
import matplotlib.colors as mcolors

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'

# --- 1. Load the Dataset ---
df = pd.read_csv(data_path, sep='\t')
# --- 2. Define Functional Systems and Colors ---
# (ground-truth column, LLM-result column) for each functional system
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]
system_names = [gt_column.split('.')[1] for gt_column, _ in functional_systems_to_plot]
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
color_map = {name: hex_color for name, hex_color in zip(system_names, colors)}

# --- 3. Categorization Function ---
# Ten unit-wide bins: '0-1', '1-2', ..., '9-10'
categories = [f'{lower}-{lower + 1}' for lower in range(10)]
category_to_index = {label: position for position, label in enumerate(categories)}
n_categories = len(categories)


def categorize_edss(value):
    """Map a score to its unit-wide bin label.

    Bins are closed on the right: 1.0 -> '0-1', 1.5 -> '1-2', 10 -> '9-10'.
    Values <= 0 fall into '0-1'; values above 10 fall into '9-10'.
    NaN/None returns NaN. Strings are passed through ``float`` and so must
    already use a decimal point.

    NOTE(review): for integer functional-system scores this puts 0 and 1 into
    the same bin, so a 0-vs-1 disagreement lands on the diagonal — confirm
    this binning is intended for FS scores.
    """
    if pd.isna(value):
        return np.nan
    score = float(value)
    if score <= 0:
        return categories[0]
    # The -0.001 offset makes each upper bound inclusive (e.g. 2.0 -> index 1).
    bin_index = int(min(score, 10) - 0.001)
    return categories[min(bin_index, len(categories) - 1)]
# --- 4. Prepare Mixed Color Matrix with Saturation ---
def _to_numeric_comma(series):
    """Convert a column to numbers, accepting decimal commas ('3,5' -> 3.5).

    Entries that still cannot be parsed become NaN. Plain
    ``pd.to_numeric(..., errors='coerce')`` would turn every comma-decimal
    string into NaN and silently drop those rows from the matrix.
    """
    cleaned = series.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(cleaned, errors='coerce')


# cell_system_counts[gt_bin, result_bin, s] = number of rows of system s whose
# ground truth falls in gt_bin and whose LLM result falls in result_bin.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))

for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Comma-aware conversion; rows where either side does not parse are dropped.
    valid_df = pd.DataFrame({
        'gt': _to_numeric_comma(df[gt_col]),
        'res': _to_numeric_comma(df[res_col]),
    }).dropna()

    for gt_value, res_value in zip(valid_df['gt'], valid_df['res']):
        gt_cat = categorize_edss(gt_value)
        res_cat = categorize_edss(res_value)
        if gt_cat in category_to_index and res_cat in category_to_index:
            cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1

# RGBA image (n_categories x n_categories x 4):
#   RGB   = count-weighted mix of the contributing systems' colors
#   alpha = density of the cell relative to the busiest cell
rgba_matrix = np.zeros((n_categories, n_categories, 4))
total_counts = np.sum(cell_system_counts, axis=2)
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
system_rgbs = np.array([mcolors.to_rgb(color_map[s_name]) for s_name in system_names])

for i in range(n_categories):
    for j in range(n_categories):
        count_sum = total_counts[i, j]
        if count_sum > 0:
            weights = cell_system_counts[i, j] / count_sum
            rgba_matrix[i, j, :3] = weights @ system_rgbs
            # Square-root scale keeps sparse cells visible; floor at 0.1 so none vanish
            alpha = np.sqrt(count_sum / max_count)
            rgba_matrix[i, j, 3] = max(0.1, alpha)
        else:
            # Empty cells are fully transparent white
            rgba_matrix[i, j] = [1, 1, 1, 0]
# --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))

# origin='upper' puts the '0-1' category in the top row, the usual confusion-matrix
# layout (origin='lower' would put it at the bottom instead).
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')

# Write the total count into every non-empty cell. White text is used only when the
# cell is both dark and mostly opaque; otherwise the white page shows through.
for i, j in np.argwhere(total_counts > 0):
    red, green, blue, alpha = rgba_matrix[i, j]
    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    is_dark_cell = luminance < 0.5 and alpha > 0.5
    ax.text(j, i, int(total_counts[i, j]),
            ha="center", va="center",
            color="white" if is_dark_cell else "black",
            fontsize=10, fontweight='bold')

# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)

tick_positions = np.arange(n_categories)
ax.set_xticks(tick_positions)
ax.set_xticklabels(categories)
ax.set_yticks(tick_positions)
ax.set_yticklabels(categories)

# Frameless matrix
for spine in ax.spines.values():
    spine.set_visible(False)

# One solid swatch per system, placed to the right of the matrix
handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)

plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##
# %% Difference Plot Functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np

# --- Configuration ---
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'

# Define the path to your data file
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# Define the path to save the color mapping JSON file
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Output for this cell's diverging error-direction figure. It must differ from the
# scatter-plot cell's 'edss_functional_systems_comparison.svg'; reusing that name
# overwrote the scatter figure whenever both cells were run.
figure_save_path = 'project/visuals/edss_functional_systems_error_direction.svg'

# --- 1. Load the Dataset ---
try:
    df = pd.read_csv(data_path, sep='\t')
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    # Report, then stop the cell
    raise

# --- 2. Define Functional Systems and Create Color Mapping ---
# (ground-truth column, LLM-result column) for each functional system
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]

# Extract system names for color mapping and legend
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]

# Qualitative palette: four blues followed by four warm accents
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC'   # Purple
]
color_map = dict(zip(system_names, colors))

os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")

# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# (Not referenced by the rest of this cell; kept in step with the scatter cell.)


def _to_numeric_comma(series):
    """Convert a column to numbers, accepting decimal commas ('3,5' -> 3.5).

    Matches the comma handling of ``safe_parse`` used for the error data
    below; plain ``pd.to_numeric(..., errors='coerce')`` would turn
    comma-decimal strings into NaN and drop those rows.
    """
    cleaned = series.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(cleaned, errors='coerce')


agreement_percentages = {}
legend_labels = {}
for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    # Compare only rows where both sides parsed
    paired = pd.DataFrame({
        'gt': _to_numeric_comma(df[gt_col]),
        'res': _to_numeric_comma(df[res_col]),
    }).dropna()

    if len(paired) > 0:
        agreement = (paired['gt'] == paired['res']).mean() * 100
    else:
        agreement = 0  # no comparable rows
    agreement_percentages[system_name] = agreement

    # "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions (xx.x%)"
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
# --- 4. Robustly Prepare Error Data for Boxplot ---
def safe_parse(s):
    """Return *s* as a float, accepting decimal commas (e.g. '3,5' -> 3.5).

    NaN/None and text that still does not parse give NaN. ints and floats
    (bools included, as ints) are converted directly; any other value is
    stringified, its comma swapped for a dot, and surrounding whitespace
    stripped before parsing.
    """
    if pd.isna(s):
        return np.nan
    if not isinstance(s, (int, float)):
        s = str(s).replace(',', '.').strip()
    try:
        return float(s)
    except ValueError:
        return np.nan


def _system_error_frame(gt_col, res_col):
    """Signed errors (result - ground truth) for one system; unparseable rows dropped."""
    gt_values = df[gt_col].apply(safe_parse)
    res_values = df[res_col].apply(safe_parse)
    return pd.DataFrame({
        'system': gt_col.split('.')[1],
        'error': res_values - gt_values,
    }).dropna()


plot_data = [_system_error_frame(gt_col, res_col) for gt_col, res_col in functional_systems_to_plot]
plot_df = pd.concat(plot_data, ignore_index=True)

if plot_df.empty:
    print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")
else:
    print(f"✅ Prepared error data with {len(plot_df)} data points.")
    # Per-system summary, to eyeball the parsed errors
    print("\n📌 Sample errors by system:")
    for name, grp in plot_df.groupby('system'):
        errs = grp['error']
        print(f" {name:25s}: n={len(grp)}, mean err = {errs.mean():+.2f}, min = {errs.min():+.2f}, max = {errs.max():+.2f}")

# Fix the system order to the configured one for everything downstream
ordered_systems = [gt_col.split('.')[1] for gt_col, _ in functional_systems_to_plot]
plot_df['system'] = pd.Categorical(plot_df['system'], categories=ordered_systems, ordered=True)
# --- 5. Prepare Data for Diverging Stacked Bar Plot ---
print("\n📊 Preparing diverging stacked bar plot data...")


def categorize_error(err):
    """Label a signed error (result - ground truth) by its direction.

    Returns 'missing' for NaN, 'underestimate' for err < 0,
    'overestimate' for err > 0 and 'match' for err == 0.
    """
    if pd.isna(err):
        return 'missing'
    elif err < 0:
        return 'underestimate'
    elif err > 0:
        return 'overestimate'
    else:
        return 'match'


# Add a category column (only on finite errors)
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
# Group on plain strings so the counts do not depend on pandas' `observed`
# default for Categorical keys; the reindex below restores every system.
plot_df_clean['system'] = plot_df_clean['system'].astype(str)
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)

# Rows = systems in configured order, columns = error direction.
# fill_value=0 on both axes: a system without any parsed error gets empty bars
# instead of NaN counts (which could break the x-limits and print "3.0"-style labels).
category_counts = (
    plot_df_clean
    .groupby(['system', 'category'])
    .size()
    .unstack(fill_value=0)
    .reindex(index=system_names, columns=['underestimate', 'match', 'overestimate'], fill_value=0)
    .astype(int)
)

underestimate_counts = category_counts['underestimate']
match_counts = category_counts['match']
overestimate_counts = category_counts['overestimate']

# Diverging layout, one row per system:
#   - exact matches are centred on x = 0 (from -match/2 to +match/2),
#   - underestimates extend to the LEFT of the match segment,
#   - overestimates extend to the RIGHT of the match segment.
left_counts = underestimate_counts
right_counts = overestimate_counts
half_match = match_counts / 2
under_start = -(half_match + left_counts)  # left edge of each underestimate segment
over_end = half_match + right_counts       # right edge of each overestimate segment

# Symmetric x-range reaching the furthest bar end on either side (at least 1)
max_bar = max(float((-under_start).max()), float(over_end.max()), 1)
plot_range = (-max_bar, max_bar)

n_systems = len(system_names)
positions = np.arange(n_systems)  # y position of each system's row

# --- 6. Create Diverging Stacked Bar Plot ---
plt.figure(figsize=(12, 7))

colors = {
    'underestimate': '#E74C3C',  # Red (left)
    'match': '#2ECC71',          # Green (center)
    'overestimate': '#F39C12'    # Orange (right)
}
bar_style = dict(height=0.6, edgecolor='black', linewidth=0.5, alpha=0.9)

# All three segments share the system's row; `left` sets where each one starts.
bars_left = plt.barh(positions, left_counts.values, left=under_start.values,
                     color=colors['underestimate'], label='Underestimate', **bar_style)
bars_match = plt.barh(positions, match_counts.values, left=-half_match.values,
                      color=colors['match'], label='Exact Match', **bar_style)
bars_right = plt.barh(positions, right_counts.values, left=half_match.values,
                      color=colors['overestimate'], label='Overestimate', **bar_style)

# --- 7. Styling & Labels ---
from matplotlib.patches import Patch
from matplotlib.ticker import FuncFormatter

ax = plt.gca()

# Centre reference (no-error midpoint), drawn behind the bars
ax.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8, zorder=0)

# Symmetric x-axis with room for the count labels. Tick labels show magnitudes,
# because the left half also represents (positive) counts.
margin = max_bar * 0.15
ax.set_xlim(plot_range[0] - margin, plot_range[1] + margin)
ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f"{abs(value):g}"))
ax.tick_params(axis='x', labelrotation=0, labelsize=10)
ax.set_xlabel('Count', fontsize=12)

# One row per system; e.g. 'BOWEL_AND_BLADDER_FUNCTIONS' -> 'Bowel & Bladder Functions'
ax.set_yticks(positions)
ax.set_yticklabels([name.replace('_AND_', ' & ').replace('_', ' ').title() for name in system_names],
                   fontsize=10)
ax.set_ylabel('Functional System', fontsize=12)

ax.set_title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)',
             fontsize=13, pad=15)

# Clean axes: keep only the bottom spine, grid only along x, behind the bars
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
ax.xaxis.grid(True, linestyle=':', alpha=0.5)
ax.set_axisbelow(True)

legend_elements = [
    Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'),
    Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'),
    Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate')
]
ax.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)

# Count labels: just beyond the outer end of each error segment, and centred on
# the match segment. Black text; the former white labels sat on the white background.
label_pad = max_bar * 0.02
for y, n_under, n_match, n_over in zip(positions, left_counts, match_counts, right_counts):
    half = n_match / 2
    if n_under > 0:
        ax.text(-(half + n_under) - label_pad, y, str(n_under), va='center', ha='right',
                fontsize=9, color='black', fontweight='bold')
    if n_over > 0:
        ax.text(half + n_over + label_pad, y, str(n_over), va='center', ha='left',
                fontsize=9, color='black', fontweight='bold')
    if n_match > 0:
        ax.text(0, y, str(n_match), va='center', ha='center', fontsize=8, color='black')

plt.tight_layout()
# --- 8. Save & Show ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"✅ Diverging bar plot saved to {figure_save_path}")
plt.show()
##
# %% Difference Plot Gemini
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg'

# --- 1. Process Error Data with Magnitude Breakdown ---
# Relies on functional_systems_to_plot, df and safe_parse from the previous cell.
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]

plot_list = []
for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing (comma decimals, unparseable -> NaN); error = result - ground truth
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    error = (res - gt).dropna()

    # Magnitude bins that together cover every error exactly once:
    #   0  |  0 < |error| < 2  ("±1")  |  |error| >= 2  ("±2+")
    # Testing equality with ±1 would skip half-point errors (e.g. -0.5, 1.5) while
    # still counting them in `total`, so the percentages would not add up to 100.
    matches = int((error == 0).sum())
    u_1 = int(((error < 0) & (error > -2)).sum())
    u_2plus = int((error <= -2).sum())
    o_1 = int(((error > 0) & (error < 2)).sum())
    o_2plus = int((error >= 2).sum())

    total = len(error)
    divisor = max(total, 1)  # avoid dividing by zero for systems without data

    plot_list.append({
        'System': sys_name.replace('_', ' ').title(),
        'Matches': matches,
        'MatchPct': (matches / divisor) * 100,
        'U1': u_1,
        'U2': u_2plus,
        'UnderTotal': u_1 + u_2plus,
        'UnderPct': ((u_1 + u_2plus) / divisor) * 100,
        'O1': o_1,
        'O2': o_2plus,
        'OverTotal': o_1 + o_2plus,
        'OverPct': ((o_1 + o_2plus) / divisor) * 100
    })

stats_df = pd.DataFrame(plot_list)
# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(13, 8))

# Magnitude colors: the darker shade marks the larger error
c_under_dark, c_under_light = '#C0392B', '#E74C3C'  # Dark Red (-2+), Soft Red (-1)
c_over_dark, c_over_light = '#2980B9', '#3498DB'    # Dark Blue (+2+), Soft Blue (+1)

bar_height = 0.6
y_pos = np.arange(len(stats_df))

# Both halves stack the same way: the mild (1) segment touches zero and the severe
# (2+) segment sits outside it, so the two sides mirror each other.
# Under-scored: negative widths, i.e. drawn to the left of zero
ax.barh(y_pos, -stats_df['U1'], bar_height, color=c_under_light, label='Under -1', edgecolor='white')
ax.barh(y_pos, -stats_df['U2'], bar_height, left=-stats_df['U1'], color=c_under_dark, label='Under -2+', edgecolor='white')

# Over-scored: drawn to the right of zero
ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white')
ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white')

# --- 3. Aesthetics & Table Labels ---
# Per-system summary "table" placed just left of the data area
# (text artists do not change the axis limits, so reading them once is enough).
table_x = ax.get_xlim()[0] - 0.5
for i, row in stats_df.iterrows():
    # Mathtext ignores plain spaces ("$\mathbf{Visual Optic}$" renders as
    # "VisualOptic"), so escape them to keep the bold system name readable.
    bold_name = row['System'].replace(' ', r'\ ')
    label_text = (
        f"$\\mathbf{{{bold_name}}}$\n"
        f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
        f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)"
    )
    ax.text(table_x, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.4)

# Formatting
ax.axvline(0, color='black', linewidth=1.2)
ax.set_yticks([])
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold')
# ax.set_title('Directional Error Magnitude (Under vs. Over Scoring)', fontsize=14, pad=35)
# Absolute x-axis labels: the under-scored side uses negative widths, but both
# sides should read as patient counts. A formatter, unlike set_xticklabels on
# auto-generated ticks, stays correct if the ticks change and keeps fractional ticks.
from matplotlib.ticker import FuncFormatter
ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f"{abs(value):g}"))

# Remove spines and add grid
for spine in ['top', 'right', 'left']:
    ax.spines[spine].set_visible(False)
ax.xaxis.grid(True, linestyle='--', alpha=0.3)

# Legend with magnitude info
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2)

plt.tight_layout()

# Save to figure_save_path (set at the top of this cell), like the other figure cells
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##
# %% Functional System Error Boxplots
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_boxplot.svg'

# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth
    error = (res - gt).dropna()

    # Ignore all 0 errors: the boxes describe only the disagreements
    error = error[error != 0]

    # Keep only systems that actually have non-zero data
    if len(error) > 0:
        clean_name = sys_name.replace('_', ' ').title()
        boxplot_data.append(error.values)
        system_labels.append(clean_name)
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(14, 8))

bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False
)
# --- 3. Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

# (key in the boxplot result, properties) applied to every artist of that kind, in order
boxplot_styles = (
    ('boxes', {'facecolor': box_face, 'edgecolor': box_edge, 'linewidth': 1.5}),
    ('whiskers', {'color': whisker_col, 'linewidth': 1.2}),
    ('caps', {'color': whisker_col, 'linewidth': 1.2}),
    ('medians', {'color': median_col, 'linewidth': 2}),
    ('means', {'marker': 'o', 'markerfacecolor': mean_col, 'markeredgecolor': 'black', 'markersize': 6}),
    ('fliers', {'marker': 'o', 'markerfacecolor': flier_face, 'markeredgecolor': flier_edge,
                'alpha': 0.6, 'markersize': 4}),
)
for artist_kind, properties in boxplot_styles:
    for artist in bp[artist_kind]:
        artist.set(**properties)

# Dashed reference line at zero error
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Slanted category labels so long system names do not collide
plt.xticks(rotation=45, ha='right')

# Horizontal grid only; drop the top and right frame lines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
# --- 4. Legend above the plot, outside the axes ---
# Proxy artists that mirror the boxplot styling applied above
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference'),
]
ax.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, 1.02), ncol=3, frameon=False)

# Reserve the top 10% of the figure for that legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Save (creating the output folder when needed), then display
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##
# %% Functional System + EDSS Error Boxplots
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_boxplot.svg'

# Reuses functional_systems_to_plot from the earlier cells: pairs of
# (ground-truth column, result column), e.g.
# ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
# and appends the overall EDSS pair.
all_systems_to_plot = [*functional_systems_to_plot, ('GT.EDSS', 'result.EDSS')]
# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in all_systems_to_plot:
    # Skip safely if a column is missing
    if gt_col not in df.columns or res_col not in df.columns:
        print(f"Skipping missing columns: {gt_col}, {res_col}")
        continue

    sys_name = gt_col.split('.')[1]

    # Robust parsing
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth; rows where either side is missing are dropped
    error = (res - gt).dropna()

    # Ignore all 0 errors (exact matches): the boxes describe only the cases
    # where the prediction disagreed with the ground truth
    error = error[error != 0]

    # Keep only systems that actually have non-zero data
    if len(error) > 0:
        if sys_name == 'EDSS':
            clean_name = 'EDSS'
        else:
            clean_name = sys_name.replace('_', ' ').title()

        boxplot_data.append(error.values)
        system_labels.append(clean_name)
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system or EDSS.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(15, 8))

# Tick labels are applied explicitly below rather than through boxplot's
# `labels=` keyword, which Matplotlib 3.9 deprecated (renamed `tick_labels=`).
# set_xticks + set_xticklabels work on both older and newer Matplotlib.
bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    showmeans=True,
    meanline=False
)
# boxplot places the boxes at positions 1..N
ax.set_xticks(range(1, len(xtick_labels) + 1))
ax.set_xticklabels(xtick_labels)
# --- 3. Styling ---
# Palette; the legend built in the next section refers to these names.
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

# Property table for each artist group returned by ax.boxplot, applied in the
# order boxes -> whiskers -> caps -> medians -> means -> fliers.
element_props = {
    'boxes': dict(facecolor=box_face, edgecolor=box_edge, linewidth=1.5),
    'whiskers': dict(color=whisker_col, linewidth=1.2),
    'caps': dict(color=whisker_col, linewidth=1.2),
    'medians': dict(color=median_col, linewidth=2),
    'means': dict(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6),
    'fliers': dict(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge,
                   alpha=0.6, markersize=4),
}
for element, props in element_props.items():
    for artist in bp[element]:
        artist.set(**props)

# Dashed horizontal marker where the prediction equals the ground truth
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

# Bold axis titles
axis_label_font = dict(fontsize=11, fontweight='bold')
ax.set_xlabel('Functional System / EDSS', **axis_label_font)
ax.set_ylabel('Error (Result - Ground Truth)', **axis_label_font)

# Slanted tick labels keep the long system names readable
plt.xticks(rotation=45, ha='right')

# Faint horizontal grid; drop the top and right frame lines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
# --- 4. Legend above the plot, outside the axes ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference')
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False
)

# Leave room at the top for the legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##

# %% test
# Diagnose: what are the actual differences between result and ground truth?
print("\n🔍 Raw differences (first 5 rows per system):")
for gt_col, res_col in functional_systems_to_plot:
    # Same guard as the plotting cells: a missing column would otherwise raise
    # KeyError and abort the whole diagnostic.
    if gt_col not in df.columns or res_col not in df.columns:
        print(f"Skipping missing columns: {gt_col}, {res_col}")
        continue

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    diff = res - gt

    # `NaN != 0` evaluates to True, so rows where GT or prediction is missing
    # must be excluded explicitly; otherwise they are counted as differences.
    non_zero = (diff.notna() & (diff != 0)).sum()

    # Check if it's due to floating point noise
    abs_diff = diff.abs()
    tiny = (abs_diff > 0) & (abs_diff < 1e-10)

    print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")
##

# %% Functional System Continuous Accuracy Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_continuous_accuracy_boxplot.svg'

# --- Functional systems using your actual column names ---
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]


# --- Robust parser ---
def safe_parse(s):
    """Convert to float, handling comma decimals like '3,5'.

    Numbers are returned as float. Missing values, empty or whitespace-only
    strings and anything else that does not parse return NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Build accuracy data ---
boxplot_data = []
system_labels = []
predicted_counts = []
missing_prediction_counts = []
total_gt_counts = []
mean_accuracies = []

for gt_col, res_col in functional_systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Only rows where ground truth exists
    gt_exists = gt.notna()
    total_gt = gt_exists.sum()

    if total_gt == 0:
        print(f"Skipping {system_name}: no ground-truth values")
        continue

    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]

    # GT exists, but LLM prediction is missing
    missing_count = res_valid.isna().sum()

    # For the boxplot, use rows where both GT and result exist
    both_exist = res_valid.notna()

    if both_exist.sum() == 0:
        print(f"Skipping {system_name}: no predicted values")
        continue

    gt_eval = gt_valid[both_exist]
    res_eval = res_valid[both_exist]

    # Functional system score range.
    # Adjust if your functional systems use another scale.
    score_range = 5

    # Continuous accuracy:
    # exact match = 1.0
    # off by 1 point = 0.8
    # off by 2 points = 0.6
    # etc.
    abs_error = (res_eval - gt_eval).abs()
    accuracy = 1 - (abs_error / score_range)
    accuracy = accuracy.clip(lower=0, upper=1)

    clean_name = system_name.replace('_', ' ').title()

    boxplot_data.append(accuracy.values)
    system_labels.append(clean_name)
    predicted_counts.append(len(gt_eval))
    missing_prediction_counts.append(missing_count)
    total_gt_counts.append(total_gt)
    mean_accuracies.append(accuracy.mean())

    print(
        f"{clean_name}: "
        f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
        f"mean accuracy={accuracy.mean():.1%}"
    )

if not boxplot_data:
    raise ValueError("No valid accuracy data available for plotting.")

# X-axis labels
xtick_labels = [
    f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
    for label, gt_n, pred_n, miss_n in zip(system_labels, total_gt_counts, predicted_counts, missing_prediction_counts)
]

# --- Plot ---
fig, ax = plt.subplots(figsize=(16, 8))

# Tick labels are set explicitly instead of via boxplot's `labels=` keyword,
# which Matplotlib 3.9 deprecated in favour of `tick_labels=`.
bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    showmeans=True,
    meanline=False,
    widths=0.55
)
# boxplot places the boxes at positions 1..N
ax.set_xticks(range(1, len(xtick_labels) + 1))
ax.set_xticklabels(xtick_labels)

# --- Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)

for whisker in bp['whiskers']:
    whisker.set(color=whisker_col, linewidth=1.2)

for cap in bp['caps']:
    cap.set(color=whisker_col, linewidth=1.2)

for median in bp['medians']:
    median.set(color=median_col, linewidth=2)

for mean in bp['means']:
    mean.set(
        marker='o',
        markerfacecolor=mean_col,
        markeredgecolor='black',
        markersize=6
    )

for flier in bp['fliers']:
    flier.set(
        marker='o',
        markerfacecolor=flier_face,
        markeredgecolor=flier_edge,
        alpha=0.6,
        markersize=4
    )

# Mean accuracy label above each box
for i, acc in enumerate(mean_accuracies, start=1):
    ax.text(
        i,
        1.03,
        f"{acc:.1%}",
        ha='center',
        va='bottom',
        fontsize=9,
        fontweight='bold'
    )

# Perfect accuracy reference line
ax.axhline(1, color='black', linewidth=1.2, linestyle='--', alpha=0.7)

# Labels and formatting
ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Continuous Accuracy', fontsize=11, fontweight='bold')
ax.set_ylim(-0.05, 1.10)
ax.set_yticks(np.arange(0, 1.01, 0.1))
ax.set_yticklabels([f"{int(y * 100)}%" for y in np.arange(0, 1.01, 0.1)])

plt.xticks(rotation=45, ha='right')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)

for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# Legend
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR of continuous accuracy'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Perfect accuracy')
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.06),
    ncol=5,
    frameon=False
)

plt.tight_layout(rect=[0, 0, 1, 0.88])

os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##

# %% Functional Systems + EDSS Continuous Accuracy Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_continuous_accuracy_boxplot.svg'

# --- Functional systems + EDSS using your actual column names ---
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),

    # EDSS
    ('GT.EDSS', 'result.EDSS')
]


# --- Robust parser ---
def safe_parse(s):
    """Convert to float, handling comma decimals like '3,5'.

    Numbers are returned as float. Missing values, empty or whitespace-only
    strings and anything else that does not parse return NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Build accuracy data ---
boxplot_data = []
system_labels = []
predicted_counts = []
missing_prediction_counts = []
total_gt_counts = []
mean_accuracies = []

for gt_col, res_col in functional_systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Only rows where ground truth exists
    gt_exists = gt.notna()
    total_gt = gt_exists.sum()

    if total_gt == 0:
        print(f"Skipping {system_name}: no ground-truth values")
        continue

    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]

    # Count cases where GT exists but LLM prediction is missing
    missing_count = res_valid.isna().sum()

    # For the boxplot, use only rows where both GT and prediction exist
    both_exist = res_valid.notna()

    if both_exist.sum() == 0:
        print(f"Skipping {system_name}: no predicted values")
        continue

    gt_eval = gt_valid[both_exist]
    res_eval = res_valid[both_exist]

    # Functional systems are usually scored 0-5.
    # EDSS is usually scored 0-10.
    if system_name == "EDSS":
        score_range = 10
        clean_name = "EDSS"
    else:
        score_range = 5
        clean_name = system_name.replace('_', ' ').title()

    # Continuous accuracy:
    # exact match = 1.0
    # off by 1 point in FS = 0.8
    # off by 1 point in EDSS = 0.9
    abs_error = (res_eval - gt_eval).abs()
    accuracy = 1 - (abs_error / score_range)

    # Keep values between 0 and 1
    accuracy = accuracy.clip(lower=0, upper=1)

    boxplot_data.append(accuracy.values)
    system_labels.append(clean_name)
    predicted_counts.append(len(gt_eval))
    missing_prediction_counts.append(missing_count)
    total_gt_counts.append(total_gt)
    mean_accuracies.append(accuracy.mean())

    print(
        f"{clean_name}: "
        f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
        f"mean accuracy={accuracy.mean():.1%}"
    )

if not boxplot_data:
    raise ValueError("No valid accuracy data available for plotting.")

# --- X-axis labels ---
xtick_labels = [
    f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
    for label, gt_n, pred_n, miss_n in zip(system_labels, total_gt_counts, predicted_counts, missing_prediction_counts)
]

# --- Plot ---
fig, ax = plt.subplots(figsize=(17, 8))

# Tick labels are set explicitly instead of via boxplot's `labels=` keyword,
# which Matplotlib 3.9 deprecated in favour of `tick_labels=`.
bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    showmeans=True,
    meanline=False,
    widths=0.55
)
# boxplot places the boxes at positions 1..N
ax.set_xticks(range(1, len(xtick_labels) + 1))
ax.set_xticklabels(xtick_labels)

# --- Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(
        facecolor=box_face,
        edgecolor=box_edge,
        linewidth=1.5
    )

for whisker in bp['whiskers']:
    whisker.set(
        color=whisker_col,
        linewidth=1.2
    )

for cap in bp['caps']:
    cap.set(
        color=whisker_col,
        linewidth=1.2
    )

for median in bp['medians']:
    median.set(
        color=median_col,
        linewidth=2
    )

for mean in bp['means']:
    mean.set(
        marker='o',
        markerfacecolor=mean_col,
        markeredgecolor='black',
        markersize=6
    )

for flier in bp['fliers']:
    flier.set(
        marker='o',
        markerfacecolor=flier_face,
        markeredgecolor=flier_edge,
        alpha=0.6,
        markersize=4
    )
# --- Mean accuracy labels above each box ---
for i, acc in enumerate(mean_accuracies, start=1):
    ax.text(
        i,
        1.03,
        f"{acc:.1%}",
        ha='center',
        va='bottom',
        fontsize=9,
        fontweight='bold'
    )

# --- Perfect accuracy reference line ---
ax.axhline(
    1,
    color='black',
    linewidth=1.2,
    linestyle='--',
    alpha=0.7
)

# --- Labels and formatting ---
ax.set_xlabel(
    'Functional System / EDSS',
    fontsize=11,
    fontweight='bold'
)
ax.set_ylabel(
    'Continuous Accuracy',
    fontsize=11,
    fontweight='bold'
)
# ax.set_title(
#     'Continuous Accuracy of Functional Systems and EDSS',
#     fontsize=14,
#     fontweight='bold',
#     pad=35
# )

ax.set_ylim(-0.05, 1.10)
yticks = np.arange(0, 1.01, 0.1)
ax.set_yticks(yticks)
ax.set_yticklabels([f"{int(y * 100)}%" for y in yticks])

plt.xticks(rotation=45, ha='right')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)

for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# --- Legend ---
legend_handles = [
    Patch(
        facecolor=box_face,
        edgecolor=box_edge,
        label='IQR of continuous accuracy'
    ),
    Line2D(
        [0], [0],
        color=median_col,
        lw=2,
        label='Median'
    ),
    Line2D(
        [0], [0],
        marker='o',
        color='w',
        markerfacecolor=mean_col,
        markeredgecolor='black',
        markersize=7,
        label='Mean'
    ),
    Line2D(
        [0], [0],
        marker='o',
        color='w',
        markerfacecolor=flier_face,
        markeredgecolor=flier_edge,
        alpha=0.8,
        markersize=6,
        label='Outlier'
    ),
    Line2D(
        [0], [0],
        color='black',
        lw=1.2,
        linestyle='--',
        label='Perfect accuracy'
    )
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.08),
    ncol=5,
    frameon=False
)

# --- Save and show ---
plt.tight_layout(rect=[0, 0, 1, 0.86])

os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##

# %% Functional Systems + EDSS Error Category Stacked Bar Plot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_error_categories.svg'

# --- Functional systems + EDSS using your actual column names ---
systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
    ('GT.EDSS', 'result.EDSS')
]


# --- Robust parser ---
def safe_parse(s):
    """Convert to float, handling comma decimals like '3,5'.

    Numbers are returned as float. Missing values, empty or whitespace-only
    strings and anything else that does not parse return NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Categorize absolute error ---
def categorize_error(abs_error):
    """Map an absolute prediction error to one of the plot's error categories."""
    if abs_error == 0:
        return "Exact"
    elif abs_error <= 0.5:
        return "≤0.5 error"
    elif abs_error <= 1:
        return "≤1 error"
    else:
        return ">1 error"


# --- Prepare data ---
rows = []

for gt_col, res_col in systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]

    if system_name == "EDSS":
        clean_name = "EDSS"
    else:
        clean_name = system_name.replace("_", " ").title()

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Evaluate only cases where ground truth exists
    gt_exists = gt.notna()
    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]

    if len(gt_valid) == 0:
        continue

    # One row per evaluated case
    for gt_value, res_value in zip(gt_valid, res_valid):
        if pd.isna(res_value):
            category = "Missing"
        else:
            abs_error = abs(res_value - gt_value)
            category = categorize_error(abs_error)

        rows.append({
            "system": clean_name,
            "category": category
        })

plot_df = pd.DataFrame(rows)

if plot_df.empty:
    raise ValueError("No valid data available for plotting.")

category_order = [
    "Exact",
    "≤0.5 error",
    "≤1 error",
    ">1 error",
    "Missing"
]

system_order = [
    "Visual Optic Functions",
    "Cerebellar Functions",
    "Brainstem Functions",
    "Sensory Functions",
    "Pyramidal Functions",
    "Ambulation",
    "Cerebral Functions",
    "Bowel And Bladder Functions",
    "EDSS"
]

counts = (
    plot_df
    .groupby(["system", "category"])
    .size()
    .unstack(fill_value=0)
    .reindex(columns=category_order, fill_value=0)
)

# Remove systems that were not available, keeping the display order.
# Reindexing the rows to the full system_order before reindexing the columns
# creates all-NaN rows for unavailable systems. The column reindex then fills
# any *new* category columns with 0 for those rows too, so
# dropna(how="all") no longer removes them. They would then be drawn as empty
# bars, or int(NaN) would raise on their "Missing" count. Selecting only the
# systems that actually have rows avoids both problems.
available_systems = [name for name in system_order if name in counts.index]
counts = counts.loc[available_systems]

# Convert to percentages for easier comparison
percentages = counts.div(counts.sum(axis=1), axis=0) * 100

# --- Plot ---
fig, ax = plt.subplots(figsize=(13, 7))

colors = {
    "Exact": "#2ECC71",
    "≤0.5 error": "#A9DFBF",
    "≤1 error": "#F9E79F",
    ">1 error": "#E67E22",
    "Missing": "#E74C3C"
}

# Running left edge of each horizontal stacked bar
left = np.zeros(len(percentages))

for category in category_order:
    values = percentages[category].values

    ax.barh(
        percentages.index,
        values,
        left=left,
        color=colors[category],
        edgecolor="white",
        linewidth=0.8,
        label=category
    )

    # Add labels only if segment is large enough
    for i, value in enumerate(values):
        if value >= 4:
            ax.text(
                left[i] + value / 2,
                i,
                f"{value:.1f}%",
                ha="center",
                va="center",
                fontsize=8,
                fontweight="bold"
            )

    left += values

# Add total n and missing count at the right side
for i, system in enumerate(percentages.index):
    total_n = int(counts.loc[system].sum())
    missing_n = int(counts.loc[system, "Missing"])

    ax.text(
        101,
        i,
        f"n={total_n}, missing={missing_n}",
        va="center",
        ha="left",
        fontsize=9
    )

# --- Formatting ---
ax.set_xlim(0, 115)
ax.set_xlabel("Percentage of Cases", fontsize=11, fontweight="bold")
ax.set_ylabel("Functional System / EDSS", fontsize=11, fontweight="bold")
# ax.set_title(
#     "Prediction Error Categories by Functional System and EDSS",
#     fontsize=14,
#     fontweight="bold",
#     pad=20
# )

ax.set_xticks(np.arange(0, 101, 10))
ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)])
ax.xaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=5,
    frameon=False
)

plt.tight_layout(rect=[0, 0, 1, 0.92])

os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
plt.show()
##

# %% Confusion matrices for iteration 1 of each model with shared color scale
from pathlib import Path
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

# =========================
# CONFIGURATION
# =========================
GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)
RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

# The output folder follows TARGET_ITERATION. It was previously hard-coded to
# "..._iter_1", so evaluating another iteration silently wrote into the
# iteration-1 folder.
OUTPUT_DIR = RUN_DIR / f"confusion_matrices_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"

# If you want to manually force a color maximum, set this to a number, e.g. 80.
# If None, the script uses the largest cell count across all model confusion matrices.
MANUAL_GLOBAL_VMAX = None EDSS_LABELS = [ r"$0 \leq x \leq 1$", r"$1 < x \leq 2$", r"$2 < x \leq 3$", r"$3 < x \leq 4$", r"$4 < x \leq 5$", r"$5 < x \leq 6$", r"$6 < x \leq 7$", r"$7 < x \leq 8$", r"$8 < x \leq 9$", r"$9 < x \leq 10$", ] # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def safe_name(name): return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name)) def categorize_edss(x): if pd.isna(x): return np.nan if x <= 1: return EDSS_LABELS[0] if x <= 2: return EDSS_LABELS[1] if x <= 3: return EDSS_LABELS[2] if x <= 4: return EDSS_LABELS[3] if x <= 5: return EDSS_LABELS[4] if x <= 6: return EDSS_LABELS[5] if x <= 7: return EDSS_LABELS[6] if x <= 8: return EDSS_LABELS[7] if x <= 9: return EDSS_LABELS[8] if x <= 10: return EDSS_LABELS[9] return np.nan def find_iter_file(model_dir): files = sorted(model_dir.glob(f"*results_iter_{TARGET_ITERATION}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(pred, model_dir): if "model" in pred.columns and pred["model"].notna().any(): return str(pred["model"].dropna().iloc[0]) return model_dir.name # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt["GT_EDSS_cat"] = gt["GT_EDSS_numeric"].apply(categorize_edss) print(f"GT rows: {len(gt)}") print(f"GT numeric EDSS rows: {gt['GT_EDSS_numeric'].notna().sum()}") # ========================= # FIRST PASS: COMPUTE ALL CONFUSION MATRICES # ========================= model_results = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and p.name != OUTPUT_DIR.name ] for model_dir in 
model_dirs: result_file = find_iter_file(model_dir) if result_file is None: print(f"\nNo iteration {TARGET_ITERATION} result CSV found in: {model_dir}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print("Skipping: row_index column missing.") continue model_name = get_model_name(pred_raw, model_dir) safe_model = safe_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) raw_rows = len(pred) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() # For confusion matrix, use only rows where model produced numeric EDSS in valid range. if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred["PRED_EDSS_cat"] = pred["PRED_EDSS_numeric"].apply(categorize_edss) pred = pred.dropna(subset=["PRED_EDSS_numeric", "PRED_EDSS_cat"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) eval_df = merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy() print(f"Raw prediction rows: {raw_rows}") print(f"Prediction rows after filters: {len(pred)}") print(f"Merged rows: {len(merged)}") print(f"Evaluable rows: {len(eval_df)}") if eval_df.empty: print("No evaluable rows. 
Skipping.") continue cm = confusion_matrix( eval_df["GT_EDSS_cat"], eval_df["PRED_EDSS_cat"], labels=EDSS_LABELS ) report = classification_report( eval_df["GT_EDSS_cat"], eval_df["PRED_EDSS_cat"], labels=EDSS_LABELS, zero_division=0 ) cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS) cm_df.index.name = "Ground Truth EDSS" cm_df.columns.name = "LLM Generated EDSS" print("\nClassification Report:") print(report) print("\nConfusion Matrix:") print(cm_df) model_results.append({ "model_name": model_name, "safe_model": safe_model, "model_dir": model_dir, "result_file": result_file, "raw_rows": raw_rows, "pred_rows_after_filters": len(pred), "merged_rows": len(merged), "evaluable_rows": len(eval_df), "cm": cm, "cm_df": cm_df, "report": report, "eval_df": eval_df, }) if not model_results: raise RuntimeError("No confusion matrices were computed. Check paths and result files.") # ========================= # SHARED COLOR SCALE # ========================= if MANUAL_GLOBAL_VMAX is not None: GLOBAL_VMAX = MANUAL_GLOBAL_VMAX else: GLOBAL_VMAX = max(item["cm"].max() for item in model_results) print("\n" + "=" * 100) print(f"Shared heatmap color scale: vmin=0, vmax={GLOBAL_VMAX}") print("=" * 100) # ========================= # SECOND PASS: SAVE PLOTS AND FILES # ========================= summaries = [] for item in model_results: model_name = item["model_name"] safe_model = item["safe_model"] result_file = item["result_file"] cm = item["cm"] cm_df = item["cm_df"] report = item["report"] eval_df = item["eval_df"] svg_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.svg" png_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.png" csv_path = OUTPUT_DIR / f"{safe_model}_confusion_matrix_iter_{TARGET_ITERATION}.csv" report_path = OUTPUT_DIR / f"{safe_model}_classification_report_iter_{TARGET_ITERATION}.txt" merged_path = OUTPUT_DIR / f"{safe_model}_merged_eval_rows_iter_{TARGET_ITERATION}.csv" plt.figure(figsize=(11, 9)) 
ax = sns.heatmap( cm, annot=True, fmt="d", cmap="Blues", vmin=0, vmax=GLOBAL_VMAX, xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS ) ax.collections[0].colorbar.set_label( "Number of Cases", rotation=270, labelpad=20 ) plt.xlabel("LLM Generated EDSS") plt.ylabel("Ground Truth EDSS") plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}") plt.xticks(rotation=45, ha="right") plt.yticks(rotation=0) plt.tight_layout() plt.savefig(svg_path, format="svg", bbox_inches="tight") plt.savefig(png_path, dpi=300, bbox_inches="tight") plt.show() cm_df.to_csv(csv_path) with open(report_path, "w", encoding="utf-8") as f: f.write(f"Model: {model_name}\n") f.write(f"Result file: {result_file}\n") f.write(f"Iteration: {TARGET_ITERATION}\n") f.write(f"Shared color scale vmax: {GLOBAL_VMAX}\n") f.write(f"Raw prediction rows: {item['raw_rows']}\n") f.write(f"Prediction rows after filters: {item['pred_rows_after_filters']}\n") f.write(f"Merged rows: {item['merged_rows']}\n") f.write(f"Evaluable rows: {item['evaluable_rows']}\n\n") f.write("Classification Report:\n") f.write(report) f.write("\n\nConfusion Matrix:\n") f.write(cm_df.to_string()) keep_cols = [ "row_index", "unique_id_gt" if "unique_id_gt" in eval_df.columns else "unique_id", "unique_id_pred" if "unique_id_pred" in eval_df.columns else None, "MedDatum_gt" if "MedDatum_gt" in eval_df.columns else "MedDatum", "MedDatum_pred" if "MedDatum_pred" in eval_df.columns else None, "model", "iteration", "success", "klassifizierbar", "clinical_output_valid", "edss_logic_valid", "GT_EDSS_numeric", "PRED_EDSS_numeric", "GT_EDSS_cat", "PRED_EDSS_cat", "raw_EDSS", "EDSS_numeric", "EDSS_in_valid_range", "certainty_percent", "reason", "inference_time_sec", ] keep_cols = [ c for c in keep_cols if c is not None and c in eval_df.columns ] eval_df[keep_cols].to_csv(merged_path, index=False) print("\nSaved:") print(svg_path) print(png_path) print(csv_path) print(report_path) print(merged_path) summaries.append({ "model": 
model_name, "result_file": str(result_file), "iteration": TARGET_ITERATION, "raw_prediction_rows": item["raw_rows"], "prediction_rows_after_filters": item["pred_rows_after_filters"], "merged_rows": item["merged_rows"], "evaluable_rows": item["evaluable_rows"], "shared_color_vmax": GLOBAL_VMAX, "svg_path": str(svg_path), "png_path": str(png_path), "csv_path": str(csv_path), "report_path": str(report_path), "merged_path": str(merged_path), }) # ========================= # SAVE SUMMARY # ========================= summary_df = pd.DataFrame(summaries) summary_path = OUTPUT_DIR / f"confusion_matrix_summary_iter_{TARGET_ITERATION}.csv" summary_df.to_csv(summary_path, index=False) print("\n" + "=" * 100) print("Done.") print(f"Summary saved to: {summary_path}") print(f"Shared color scale vmax: {GLOBAL_VMAX}") print("=" * 100) ## # %% EDSS metrics across models for new benchmark run from pathlib import Path import pandas as pd import numpy as np from sklearn.metrics import mean_absolute_error, mean_squared_error, cohen_kappa_score from scipy.stats import spearmanr # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_PATH = RUN_DIR / f"edss_metrics_iter_{TARGET_ITERATION}.csv" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def find_iter_file(model_dir, target_iteration): files = sorted(model_dir.glob(f"*results_iter_{target_iteration}_*.csv")) files = [ f for f in files if "incremental" not in 
f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def safe_rate(numerator, denominator): if denominator == 0: return np.nan return numerator / denominator # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) n_total_gt_rows = len(gt) n_gt_numeric = gt["GT_EDSS_numeric"].notna().sum() gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows: {n_total_gt_rows}") print(f"GT numeric EDSS rows: {n_gt_numeric}") # ========================= # EVALUATE MODELS # ========================= rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"\nNo iter_{TARGET_ITERATION} result file found for {model_dir.name}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred_raw = pd.read_csv(result_file, sep=",") raw_prediction_rows = len(pred_raw) if "row_index" not in pred_raw.columns: print("Skipping: row_index missing") continue model_name = ( pred_raw["model"].dropna().iloc[0] if "model" in pred_raw.columns and pred_raw["model"].notna().any() else model_dir.name ) # ------------------------- # Diagnostics before filters # ------------------------- n_success = to_bool(pred_raw["success"]).sum() if "success" in pred_raw.columns else np.nan n_clinical_output_valid = ( to_bool(pred_raw["clinical_output_valid"]).sum() if "clinical_output_valid" in pred_raw.columns else np.nan ) n_edss_logic_valid = ( to_bool(pred_raw["edss_logic_valid"]).sum() if "edss_logic_valid" in pred_raw.columns else np.nan ) n_klassifizierbar_true = ( to_bool(pred_raw["klassifizierbar"]).sum() if "klassifizierbar" in pred_raw.columns else np.nan ) 
n_edss_numeric = ( to_bool(pred_raw["EDSS_is_numeric"]).sum() if "EDSS_is_numeric" in pred_raw.columns else np.nan ) n_edss_valid_range = ( to_bool(pred_raw["EDSS_in_valid_range"]).sum() if "EDSS_in_valid_range" in pred_raw.columns else np.nan ) print("Raw prediction rows:", raw_prediction_rows) print("success=True:", n_success) print("clinical_output_valid=True:", n_clinical_output_valid) print("edss_logic_valid=True:", n_edss_logic_valid) print("klassifizierbar=True:", n_klassifizierbar_true) print("EDSS_is_numeric=True:", n_edss_numeric) print("EDSS_in_valid_range=True:", n_edss_valid_range) print("unique row_index:", pred_raw["row_index"].nunique()) print("GT numeric EDSS:", n_gt_numeric) # ------------------------- # Prepare predictions # ------------------------- pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() # For EDSS score accuracy, use only predictions where model actually gave a numeric EDSS. # This automatically excludes valid abstentions with klassifizierbar=false and EDSS=null. if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() n_after_filtering = len(pred) merged = gt_numeric.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) n_evaluable = len(merged) if n_evaluable == 0: print("No evaluable rows. 
Skipping metrics.") continue # ------------------------- # Metrics # ------------------------- merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] merged["abs_error"] = merged["error"].abs() mae = mean_absolute_error( merged["GT_EDSS_numeric"], merged["PRED_EDSS_numeric"] ) rmse = np.sqrt( mean_squared_error( merged["GT_EDSS_numeric"], merged["PRED_EDSS_numeric"] ) ) median_abs_error = merged["abs_error"].median() mean_signed_error = merged["error"].mean() exact_accuracy_valid_only = (merged["abs_error"] == 0).mean() accuracy_within_05_valid_only = (merged["abs_error"] <= 0.5).mean() accuracy_within_10_valid_only = (merged["abs_error"] <= 1.0).mean() exact_correct_count = int((merged["abs_error"] == 0).sum()) within_05_count = int((merged["abs_error"] <= 0.5).sum()) within_10_count = int((merged["abs_error"] <= 1.0).sum()) # Coverage-adjusted accuracies use all GT numeric rows as denominator. # Missing/abstained/non-numeric predictions count as not correct. exact_accuracy_all_gt_numeric = safe_rate(exact_correct_count, n_gt_numeric) accuracy_within_05_all_gt_numeric = safe_rate(within_05_count, n_gt_numeric) accuracy_within_10_all_gt_numeric = safe_rate(within_10_count, n_gt_numeric) coverage_gt_numeric = safe_rate(n_evaluable, n_gt_numeric) coverage_all_rows = safe_rate(n_evaluable, n_total_gt_rows) if n_evaluable > 1: spearman_rho, spearman_p = spearmanr( merged["GT_EDSS_numeric"], merged["PRED_EDSS_numeric"] ) else: spearman_rho, spearman_p = np.nan, np.nan gt_half_steps = (merged["GT_EDSS_numeric"] * 2).round().astype(int) pred_half_steps = (merged["PRED_EDSS_numeric"] * 2).round().astype(int) quadratic_weighted_kappa = cohen_kappa_score( gt_half_steps, pred_half_steps, weights="quadratic" ) mean_inference_time = ( merged["inference_time_sec"].mean() if "inference_time_sec" in merged.columns else np.nan ) rows.append({ "model": model_name, "result_file": str(result_file), "iteration": TARGET_ITERATION, "n_total_gt_rows": n_total_gt_rows, 
"n_gt_numeric": n_gt_numeric, "raw_prediction_rows": raw_prediction_rows, "n_success": n_success, "success_rate": safe_rate(n_success, raw_prediction_rows), "n_clinical_output_valid": n_clinical_output_valid, "clinical_output_valid_rate": safe_rate(n_clinical_output_valid, raw_prediction_rows), "n_edss_logic_valid": n_edss_logic_valid, "edss_logic_valid_rate": safe_rate(n_edss_logic_valid, raw_prediction_rows), "n_klassifizierbar_true": n_klassifizierbar_true, "klassifizierbar_true_rate": safe_rate(n_klassifizierbar_true, raw_prediction_rows), "n_EDSS_numeric": n_edss_numeric, "EDSS_numeric_rate": safe_rate(n_edss_numeric, raw_prediction_rows), "n_EDSS_valid_range": n_edss_valid_range, "EDSS_valid_range_rate": safe_rate(n_edss_valid_range, raw_prediction_rows), "n_after_filtering": n_after_filtering, "n_evaluable": n_evaluable, "coverage_gt_numeric": coverage_gt_numeric, "coverage_gt_numeric_percent": coverage_gt_numeric * 100, "coverage_all_rows": coverage_all_rows, "coverage_all_rows_percent": coverage_all_rows * 100, "MAE_valid_only": mae, "median_absolute_error_valid_only": median_abs_error, "RMSE_valid_only": rmse, "mean_signed_error_valid_only": mean_signed_error, "exact_accuracy_valid_only": exact_accuracy_valid_only, "accuracy_within_0_5_valid_only": accuracy_within_05_valid_only, "accuracy_within_1_0_valid_only": accuracy_within_10_valid_only, "exact_accuracy_valid_only_percent": exact_accuracy_valid_only * 100, "accuracy_within_0_5_valid_only_percent": accuracy_within_05_valid_only * 100, "accuracy_within_1_0_valid_only_percent": accuracy_within_10_valid_only * 100, "exact_accuracy_all_gt_numeric": exact_accuracy_all_gt_numeric, "accuracy_within_0_5_all_gt_numeric": accuracy_within_05_all_gt_numeric, "accuracy_within_1_0_all_gt_numeric": accuracy_within_10_all_gt_numeric, "exact_accuracy_all_gt_numeric_percent": exact_accuracy_all_gt_numeric * 100, "accuracy_within_0_5_all_gt_numeric_percent": accuracy_within_05_all_gt_numeric * 100, 
"accuracy_within_1_0_all_gt_numeric_percent": accuracy_within_10_all_gt_numeric * 100, "spearman_rho": spearman_rho, "spearman_p": spearman_p, "quadratic_weighted_kappa": quadratic_weighted_kappa, "mean_inference_time_sec": mean_inference_time, }) print("\nMetrics:") print(f"Model: {model_name}") print(f"n_evaluable: {n_evaluable}") print(f"Coverage of GT numeric rows: {coverage_gt_numeric * 100:.1f}%") print(f"MAE: {mae:.3f}") print(f"Median AE: {median_abs_error:.3f}") print(f"RMSE: {rmse:.3f}") print(f"Mean signed error: {mean_signed_error:.3f}") print(f"Exact accuracy valid-only: {exact_accuracy_valid_only * 100:.1f}%") print(f"Accuracy ±0.5 valid-only: {accuracy_within_05_valid_only * 100:.1f}%") print(f"Accuracy ±1.0 valid-only: {accuracy_within_10_valid_only * 100:.1f}%") print(f"Accuracy ±0.5 all GT numeric: {accuracy_within_05_all_gt_numeric * 100:.1f}%") print(f"Spearman rho: {spearman_rho:.3f}") print(f"Quadratic weighted kappa: {quadratic_weighted_kappa:.3f}") print(f"Mean inference time: {mean_inference_time:.3f} sec") # ========================= # SAVE METRICS TABLE # ========================= metrics_df = pd.DataFrame(rows) if not metrics_df.empty: metrics_df = metrics_df.sort_values("MAE_valid_only") pd.set_option("display.max_columns", None) pd.set_option("display.width", 240) print("\n" + "=" * 100) print("EDSS model comparison metrics:") print(metrics_df) metrics_df.to_csv(OUTPUT_PATH, index=False) print(f"\nSaved metrics table to:") print(OUTPUT_PATH) ## # %% Per-patient repeated-run variability across 10 EDSS runs from pathlib import Path import pandas as pd import numpy as np # ========================= # CONFIGURATION # ========================= RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) OUTPUT_DIR = RUN_DIR / "repeated_run_variability" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) N_EXPECTED_RUNS = 10 PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" # Use only valid 
numeric EDSS predictions. # This excludes valid abstentions where klassifizierbar=false and EDSS=null. USE_ONLY_VALID_RANGE_EDSS = True # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def find_iteration_files(model_dir): files = sorted(model_dir.glob("*results_iter_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files def load_model_all_iterations(model_dir): files = find_iteration_files(model_dir) if not files: return pd.DataFrame(), [] dfs = [] for file in files: df = pd.read_csv(file, sep=",") if "iteration" not in df.columns: print(f"Skipping {file}: no iteration column") continue if "row_index" not in df.columns: print(f"Skipping {file}: no row_index column") continue pred_col = PRED_EDSS_COL if PRED_EDSS_COL in df.columns else PRED_EDSS_FALLBACK_COL df = df.copy() df["source_file"] = str(file) df["row_index"] = pd.to_numeric(df["row_index"], errors="coerce") df["iteration"] = pd.to_numeric(df["iteration"], errors="coerce") df["EDSS_prediction"] = to_num(df[pred_col]) df = df.dropna(subset=["row_index", "iteration"]).copy() df["row_index"] = df["row_index"].astype(int) df["iteration"] = df["iteration"].astype(int) if "success" in df.columns: df = df[to_bool(df["success"])].copy() if "EDSS_is_numeric" in df.columns: df = df[to_bool(df["EDSS_is_numeric"])].copy() if USE_ONLY_VALID_RANGE_EDSS and "EDSS_in_valid_range" in df.columns: df = df[to_bool(df["EDSS_in_valid_range"])].copy() df = df.dropna(subset=["EDSS_prediction"]).copy() keep_cols = [ "model", "iteration", "row_index", "row_number_in_run", "unique_id", "MedDatum", "EDSS_prediction", "EDSS_numeric", "EDSS", "EDSS_is_numeric", "EDSS_in_valid_range", "klassifizierbar", "clinical_output_valid", 
"edss_logic_valid", "certainty_percent", "inference_time_sec", "source_file", ] keep_cols = [c for c in keep_cols if c in df.columns] dfs.append(df[keep_cols]) if not dfs: return pd.DataFrame(), files all_df = pd.concat(dfs, ignore_index=True) # If there are duplicate row_index + iteration rows, keep first. all_df = all_df.sort_values(["row_index", "iteration"]) all_df = all_df.drop_duplicates(subset=["row_index", "iteration"], keep="first") return all_df, files def summarize_patient_variability(all_df, model_name): """ One row per patient / row_index. """ grouped = all_df.groupby("row_index") patient_rows = [] for row_index, g in grouped: preds = g["EDSS_prediction"].dropna().astype(float) n_valid_runs = len(preds) if n_valid_runs == 0: continue unique_id = g["unique_id"].dropna().iloc[0] if "unique_id" in g.columns and g["unique_id"].notna().any() else None meddatum = g["MedDatum"].dropna().iloc[0] if "MedDatum" in g.columns and g["MedDatum"].notna().any() else None edss_mean = preds.mean() edss_std = preds.std(ddof=0) # population SD across repeated runs edss_median = preds.median() edss_min = preds.min() edss_max = preds.max() edss_range = edss_max - edss_min n_unique_predictions = preds.nunique() identical_all_available_runs = n_unique_predictions == 1 range_leq_0_5 = edss_range <= 0.5 complete_10_valid_runs = n_valid_runs == N_EXPECTED_RUNS patient_rows.append({ "model": model_name, "row_index": row_index, "unique_id": unique_id, "MedDatum": meddatum, "n_valid_runs": n_valid_runs, "complete_10_valid_runs": complete_10_valid_runs, "EDSS_mean_across_runs": edss_mean, "EDSS_median_across_runs": edss_median, "EDSS_std_across_runs": edss_std, "EDSS_min_across_runs": edss_min, "EDSS_max_across_runs": edss_max, "EDSS_range_across_runs": edss_range, "n_unique_EDSS_predictions": n_unique_predictions, "identical_EDSS_all_available_runs": identical_all_available_runs, "EDSS_range_leq_0_5": range_leq_0_5, "iterations_available": ",".join(map(str, 
sorted(g["iteration"].unique()))), "EDSS_predictions_by_iteration": ";".join( f"{int(row.iteration)}:{row.EDSS_prediction}" for row in g.sort_values("iteration").itertuples() ), }) return pd.DataFrame(patient_rows) def summarize_model_variability(patient_df, all_df, model_name, n_source_files): if patient_df.empty: return { "model": model_name, "n_source_iteration_files": n_source_files, "n_patients_with_at_least_one_valid_prediction": 0, } n_patients = len(patient_df) complete_df = patient_df[patient_df["complete_10_valid_runs"]].copy() summary = { "model": model_name, "n_source_iteration_files": n_source_files, "n_total_prediction_rows_valid": len(all_df), "n_patients_with_at_least_one_valid_prediction": n_patients, "n_patients_with_10_valid_runs": len(complete_df), "patients_with_10_valid_runs_percent": len(complete_df) / n_patients * 100 if n_patients else np.nan, # Main variability metrics across all patients with at least one valid prediction "mean_std_EDSS_across_runs": patient_df["EDSS_std_across_runs"].mean(), "median_std_EDSS_across_runs": patient_df["EDSS_std_across_runs"].median(), "mean_range_EDSS_across_runs": patient_df["EDSS_range_across_runs"].mean(), "median_range_EDSS_across_runs": patient_df["EDSS_range_across_runs"].median(), "percent_identical_EDSS_all_available_runs": patient_df["identical_EDSS_all_available_runs"].mean() * 100, "percent_EDSS_range_leq_0_5": patient_df["EDSS_range_leq_0_5"].mean() * 100, "mean_n_valid_runs_per_patient": patient_df["n_valid_runs"].mean(), "median_n_valid_runs_per_patient": patient_df["n_valid_runs"].median(), "min_n_valid_runs_per_patient": patient_df["n_valid_runs"].min(), "max_n_valid_runs_per_patient": patient_df["n_valid_runs"].max(), } # Same metrics restricted to patients with all 10 valid runs if not complete_df.empty: summary.update({ "mean_std_EDSS_10_valid_runs_only": complete_df["EDSS_std_across_runs"].mean(), "median_std_EDSS_10_valid_runs_only": complete_df["EDSS_std_across_runs"].median(), 
"mean_range_EDSS_10_valid_runs_only": complete_df["EDSS_range_across_runs"].mean(), "median_range_EDSS_10_valid_runs_only": complete_df["EDSS_range_across_runs"].median(), "percent_identical_EDSS_10_valid_runs_only": complete_df["identical_EDSS_all_available_runs"].mean() * 100, "percent_EDSS_range_leq_0_5_10_valid_runs_only": complete_df["EDSS_range_leq_0_5"].mean() * 100, }) else: summary.update({ "mean_std_EDSS_10_valid_runs_only": np.nan, "median_std_EDSS_10_valid_runs_only": np.nan, "mean_range_EDSS_10_valid_runs_only": np.nan, "median_range_EDSS_10_valid_runs_only": np.nan, "percent_identical_EDSS_10_valid_runs_only": np.nan, "percent_EDSS_range_leq_0_5_10_valid_runs_only": np.nan, }) return summary # ========================= # MAIN # ========================= model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and p.name != "repeated_run_variability" ] all_model_summaries = [] for model_dir in model_dirs: print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") all_df, source_files = load_model_all_iterations(model_dir) if all_df.empty: print("No valid prediction data found. 
Skipping.") continue model_name = ( all_df["model"].dropna().iloc[0] if "model" in all_df.columns and all_df["model"].notna().any() else model_dir.name ) print(f"Model name: {model_name}") print(f"Iteration files found: {len(source_files)}") print(f"Valid numeric EDSS prediction rows: {len(all_df)}") print(f"Patients with at least one valid EDSS prediction: {all_df['row_index'].nunique()}") patient_df = summarize_patient_variability(all_df, model_name) model_summary = summarize_model_variability( patient_df=patient_df, all_df=all_df, model_name=model_name, n_source_files=len(source_files), ) all_model_summaries.append(model_summary) safe_model = model_name.replace("/", "_").replace(" ", "_") patient_out = OUTPUT_DIR / f"{safe_model}_per_patient_repeated_run_variability.csv" all_preds_out = OUTPUT_DIR / f"{safe_model}_all_valid_predictions_long.csv" patient_df.to_csv(patient_out, index=False) all_df.to_csv(all_preds_out, index=False) print("\nMain variability metrics:") print(f"Mean SD across runs: {model_summary['mean_std_EDSS_across_runs']:.3f}") print(f"Median SD across runs: {model_summary['median_std_EDSS_across_runs']:.3f}") print(f"Identical EDSS all available runs: {model_summary['percent_identical_EDSS_all_available_runs']:.1f}%") print(f"Range <= 0.5: {model_summary['percent_EDSS_range_leq_0_5']:.1f}%") print(f"Patients with 10 valid runs: {model_summary['n_patients_with_10_valid_runs']}") print("\nSaved:") print(patient_out) print(all_preds_out) summary_df = pd.DataFrame(all_model_summaries) summary_out = OUTPUT_DIR / "repeated_run_variability_summary.csv" summary_df.to_csv(summary_out, index=False) pd.set_option("display.max_columns", None) pd.set_option("display.width", 240) print("\n" + "=" * 100) print("Repeated-run variability summary:") print(summary_df) print("\nSaved summary to:") print(summary_out) ## # %% Functional system performance per domain - iteration 1 from pathlib import Path import pandas as pd import numpy as np from sklearn.metrics 
import mean_absolute_error, mean_squared_error from scipy.stats import spearmanr # ========================= # PATHS # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"functional_system_metrics_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_FULL_TABLE = OUTPUT_DIR / f"functional_system_metrics_full_iter_{TARGET_ITERATION}.csv" OUTPUT_SHORT_TABLE = OUTPUT_DIR / f"functional_system_metrics_short_iter_{TARGET_ITERATION}.csv" # ========================= # FUNCTIONAL SYSTEM MAPPING # ========================= FS_MAP = { "VISUAL_OPTIC_FUNCTIONS": { "display": "Visual/optic functions", "gt": "Sehvermögen", "pred": "numeric_subcat_VISUAL_OPTIC_FUNCTIONS", "fallback": "subcat_VISUAL_OPTIC_FUNCTIONS", "numeric_flag": "subcat_VISUAL_OPTIC_FUNCTIONS_is_numeric", "range_flag": "subcat_VISUAL_OPTIC_FUNCTIONS_in_valid_range", }, "BRAINSTEM_FUNCTIONS": { "display": "Brainstem functions", "gt": "Hirnstamm", "pred": "numeric_subcat_BRAINSTEM_FUNCTIONS", "fallback": "subcat_BRAINSTEM_FUNCTIONS", "numeric_flag": "subcat_BRAINSTEM_FUNCTIONS_is_numeric", "range_flag": "subcat_BRAINSTEM_FUNCTIONS_in_valid_range", }, "PYRAMIDAL_FUNCTIONS": { "display": "Pyramidal functions", "gt": "Pyramidalmotorik", "pred": "numeric_subcat_PYRAMIDAL_FUNCTIONS", "fallback": "subcat_PYRAMIDAL_FUNCTIONS", "numeric_flag": "subcat_PYRAMIDAL_FUNCTIONS_is_numeric", "range_flag": "subcat_PYRAMIDAL_FUNCTIONS_in_valid_range", }, "CEREBELLAR_FUNCTIONS": { "display": "Cerebellar functions", "gt": "Cerebellum", "pred": "numeric_subcat_CEREBELLAR_FUNCTIONS", "fallback": "subcat_CEREBELLAR_FUNCTIONS", "numeric_flag": "subcat_CEREBELLAR_FUNCTIONS_is_numeric", "range_flag": "subcat_CEREBELLAR_FUNCTIONS_in_valid_range", }, 
"SENSORY_FUNCTIONS": { "display": "Sensory functions", "gt": "Sensibiliät", "pred": "numeric_subcat_SENSORY_FUNCTIONS", "fallback": "subcat_SENSORY_FUNCTIONS", "numeric_flag": "subcat_SENSORY_FUNCTIONS_is_numeric", "range_flag": "subcat_SENSORY_FUNCTIONS_in_valid_range", }, "BOWEL_AND_BLADDER_FUNCTIONS": { "display": "Bowel and bladder functions", "gt": "Blasen-_und_Mastdarmfunktion", "pred": "numeric_subcat_BOWEL_AND_BLADDER_FUNCTIONS", "fallback": "subcat_BOWEL_AND_BLADDER_FUNCTIONS", "numeric_flag": "subcat_BOWEL_AND_BLADDER_FUNCTIONS_is_numeric", "range_flag": "subcat_BOWEL_AND_BLADDER_FUNCTIONS_in_valid_range", }, "CEREBRAL_FUNCTIONS": { "display": "Cerebral functions", "gt": "Cerebrale_Funktion", "pred": "numeric_subcat_CEREBRAL_FUNCTIONS", "fallback": "subcat_CEREBRAL_FUNCTIONS", "numeric_flag": "subcat_CEREBRAL_FUNCTIONS_is_numeric", "range_flag": "subcat_CEREBRAL_FUNCTIONS_in_valid_range", }, "AMBULATION": { "display": "Ambulation", "gt": "Ambulation", "pred": "numeric_subcat_AMBULATION", "fallback": "subcat_AMBULATION", "numeric_flag": "subcat_AMBULATION_is_numeric", "range_flag": "subcat_AMBULATION_in_valid_range", }, } # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def rate(n, d): if d == 0: return np.nan return n / d def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, 
sep=";") gt["row_index"] = gt.index print(f"GT rows: {len(gt)}") for fs_key, info in FS_MAP.items(): if info["gt"] not in gt.columns: print(f"WARNING missing GT column: {info['gt']}") else: gt_num = to_num(gt[info["gt"]]) print( f"{info['display']}: " f"GT numeric={gt_num.notna().sum()}, " f"GT non-zero={(gt_num.dropna() != 0).sum()}" ) # ========================= # MAIN ANALYSIS # ========================= rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and p.name not in [ f"functional_system_metrics_iter_{TARGET_ITERATION}", "repeated_run_variability", f"confusion_matrices_iter_{TARGET_ITERATION}", ] and not p.name.startswith("confusion") and not p.name.startswith("functional_system") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"\nNo iteration {TARGET_ITERATION} file found for {model_dir.name}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred = pd.read_csv(result_file, sep=",") if "row_index" not in pred.columns: print("Skipping: no row_index column.") continue model_name = get_model_name(pred, model_dir) pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="left", suffixes=("_gt", "_pred") ) print(f"Model name: {model_name}") print(f"Prediction rows after success filter: {len(pred)}") print(f"Merged rows: {len(merged)}") for fs_key, info in FS_MAP.items(): gt_col = info["gt"] if gt_col not in merged.columns: print(f"Skipping {info['display']}: missing GT column {gt_col}") continue pred_col = info["pred"] if info["pred"] in merged.columns else info["fallback"] if pred_col not in merged.columns: 
print(f"Skipping {info['display']}: missing prediction column") continue temp = merged.copy() temp["GT_value"] = to_num(temp[gt_col]) temp["PRED_value"] = to_num(temp[pred_col]) gt_numeric_df = temp.dropna(subset=["GT_value"]).copy() n_gt_numeric = len(gt_numeric_df) n_nonzero_gt = int((gt_numeric_df["GT_value"] != 0).sum()) percent_nonzero_gt = rate(n_nonzero_gt, n_gt_numeric) * 100 if info["numeric_flag"] in gt_numeric_df.columns: pred_numeric_flag = to_bool(gt_numeric_df[info["numeric_flag"]]) else: pred_numeric_flag = gt_numeric_df["PRED_value"].notna() if info["range_flag"] in gt_numeric_df.columns: pred_range_flag = to_bool(gt_numeric_df[info["range_flag"]]) else: pred_range_flag = gt_numeric_df["PRED_value"].notna() valid_pred = ( pred_numeric_flag & pred_range_flag & gt_numeric_df["PRED_value"].notna() ) n_missing_or_invalid = int((~valid_pred).sum()) percent_missing_or_invalid = rate(n_missing_or_invalid, n_gt_numeric) * 100 eval_df = gt_numeric_df[valid_pred].copy() n_evaluable = len(eval_df) if n_evaluable > 0: eval_df["error"] = eval_df["PRED_value"] - eval_df["GT_value"] eval_df["abs_error"] = eval_df["error"].abs() mae = mean_absolute_error(eval_df["GT_value"], eval_df["PRED_value"]) median_ae = eval_df["abs_error"].median() rmse = np.sqrt(mean_squared_error(eval_df["GT_value"], eval_df["PRED_value"])) exact_acc = (eval_df["abs_error"] == 0).mean() acc_05 = (eval_df["abs_error"] <= 0.5).mean() acc_10 = (eval_df["abs_error"] <= 1.0).mean() enough_variation = ( n_evaluable >= 3 and eval_df["GT_value"].nunique() > 1 and eval_df["PRED_value"].nunique() > 1 ) if enough_variation: spearman_rho, spearman_p = spearmanr( eval_df["GT_value"], eval_df["PRED_value"] ) else: spearman_rho, spearman_p = np.nan, np.nan else: mae = np.nan median_ae = np.nan rmse = np.nan exact_acc = np.nan acc_05 = np.nan acc_10 = np.nan spearman_rho = np.nan spearman_p = np.nan enough_variation = False row = { "model": model_name, "iteration": TARGET_ITERATION, "result_file": 
str(result_file), "functional_system_key": fs_key, "functional_system": info["display"], "n_gt_numeric": n_gt_numeric, "n_evaluable": n_evaluable, "n_nonzero_ground_truth": n_nonzero_gt, "percent_nonzero_ground_truth": percent_nonzero_gt, "n_missing_or_invalid_model_outputs": n_missing_or_invalid, "percent_missing_or_invalid_model_outputs": percent_missing_or_invalid, "MAE": mae, "median_absolute_error": median_ae, "RMSE": rmse, "exact_accuracy": exact_acc, "accuracy_within_0_5": acc_05, "accuracy_within_1_0": acc_10, "exact_accuracy_percent": exact_acc * 100 if pd.notna(exact_acc) else np.nan, "accuracy_within_0_5_percent": acc_05 * 100 if pd.notna(acc_05) else np.nan, "accuracy_within_1_0_percent": acc_10 * 100 if pd.notna(acc_10) else np.nan, "spearman_rho": spearman_rho, "spearman_p": spearman_p, "spearman_calculated": enough_variation, } rows.append(row) print( f"{info['display']}: " f"n_eval={n_evaluable}, " f"non-zero GT={n_nonzero_gt} ({percent_nonzero_gt:.1f}%), " f"MAE={mae:.3f}, " f"±0.5={row['accuracy_within_0_5_percent']:.1f}%, " f"missing/invalid={percent_missing_or_invalid:.1f}%" ) # ========================= # SAVE TABLES # ========================= metrics_df = pd.DataFrame(rows) metrics_df.to_csv(OUTPUT_FULL_TABLE, index=False) short_cols = [ "model", "functional_system", "n_evaluable", "n_nonzero_ground_truth", "percent_nonzero_ground_truth", "MAE", "median_absolute_error", "RMSE", "exact_accuracy_percent", "accuracy_within_0_5_percent", "accuracy_within_1_0_percent", "spearman_rho", "percent_missing_or_invalid_model_outputs", ] short_df = metrics_df[short_cols].copy() short_df.to_csv(OUTPUT_SHORT_TABLE, index=False) pd.set_option("display.max_columns", None) pd.set_option("display.width", 240) print("\n" + "=" * 100) print("Functional system performance table:") print(metrics_df) print("\n" + "=" * 100) print("Short table:") print(short_df) print("\nSaved:") print(OUTPUT_FULL_TABLE) print(OUTPUT_SHORT_TABLE) ## # %% Functional Systems + EDSS 
Error Category Stacked Bar Plot per Model from pathlib import Path import re import pandas as pd import matplotlib.pyplot as plt import numpy as np # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"functional_system_error_category_plots_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) plt.rcParams["font.family"] = "Arial" # ========================= # COLUMN MAPPING # ========================= SYSTEMS_TO_PLOT = [ { "name": "Visual Optic Functions", "gt_col": "Sehvermögen", "pred_col": "numeric_subcat_VISUAL_OPTIC_FUNCTIONS", "pred_fallback_col": "subcat_VISUAL_OPTIC_FUNCTIONS", }, { "name": "Cerebellar Functions", "gt_col": "Cerebellum", "pred_col": "numeric_subcat_CEREBELLAR_FUNCTIONS", "pred_fallback_col": "subcat_CEREBELLAR_FUNCTIONS", }, { "name": "Brainstem Functions", "gt_col": "Hirnstamm", "pred_col": "numeric_subcat_BRAINSTEM_FUNCTIONS", "pred_fallback_col": "subcat_BRAINSTEM_FUNCTIONS", }, { "name": "Sensory Functions", "gt_col": "Sensibiliät", "pred_col": "numeric_subcat_SENSORY_FUNCTIONS", "pred_fallback_col": "subcat_SENSORY_FUNCTIONS", }, { "name": "Pyramidal Functions", "gt_col": "Pyramidalmotorik", "pred_col": "numeric_subcat_PYRAMIDAL_FUNCTIONS", "pred_fallback_col": "subcat_PYRAMIDAL_FUNCTIONS", }, { "name": "Ambulation", "gt_col": "Ambulation", "pred_col": "numeric_subcat_AMBULATION", "pred_fallback_col": "subcat_AMBULATION", }, { "name": "Cerebral Functions", "gt_col": "Cerebrale_Funktion", "pred_col": "numeric_subcat_CEREBRAL_FUNCTIONS", "pred_fallback_col": "subcat_CEREBRAL_FUNCTIONS", }, { "name": "Bowel And Bladder Functions", "gt_col": "Blasen-_und_Mastdarmfunktion", "pred_col": 
"numeric_subcat_BOWEL_AND_BLADDER_FUNCTIONS", "pred_fallback_col": "subcat_BOWEL_AND_BLADDER_FUNCTIONS", }, { "name": "EDSS", "gt_col": "EDSS", "pred_col": "EDSS_numeric", "pred_fallback_col": "EDSS", }, ] SYSTEM_ORDER = [ "Visual Optic Functions", "Cerebellar Functions", "Brainstem Functions", "Sensory Functions", "Pyramidal Functions", "Ambulation", "Cerebral Functions", "Bowel And Bladder Functions", "EDSS", ] CATEGORY_ORDER = [ "Exact", "≤0.5 error", "≤1 error", ">1 error", "Missing/invalid", ] # Blue-based palette for correct predictions COLORS = { "Exact": "#1F77B4", # blue "≤0.5 error": "#9ECAE1", # light blue "≤1 error": "#FDDC7A", # yellow ">1 error": "#F28E2B", # orange "Missing/invalid": "#D62728" # red } # ========================= # HELPERS # ========================= def safe_name(name): return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name)) def safe_parse_series(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def categorize_error(abs_error): if pd.isna(abs_error): return "Missing/invalid" if abs_error == 0: return "Exact" if abs_error <= 0.5: return "≤0.5 error" if abs_error <= 1.0: return "≤1 error" return ">1 error" def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def get_column_after_merge(df, base_col, side): """ After merge with suffixes=('_gt', '_pred'), duplicated columns become: EDSS_gt and EDSS_pred. For non-duplicated GT-only columns, the name remains unchanged. 
""" if base_col in df.columns: return base_col suffixed = f"{base_col}_{side}" if suffixed in df.columns: return suffixed return None def prepare_plot_data(gt, pred): rows = [] merged = gt.merge( pred, on="row_index", how="left", suffixes=("_gt", "_pred") ) for system in SYSTEMS_TO_PLOT: system_name = system["name"] gt_col = get_column_after_merge(merged, system["gt_col"], "gt") pred_col = None if system["pred_col"] in merged.columns: pred_col = system["pred_col"] elif system["pred_fallback_col"] in merged.columns: pred_col = system["pred_fallback_col"] elif f"{system['pred_fallback_col']}_pred" in merged.columns: pred_col = f"{system['pred_fallback_col']}_pred" if gt_col is None: print(f"Skipping {system_name}: GT column not found: {system['gt_col']}") continue if pred_col is None: print(f"Skipping {system_name}: prediction column not found") continue gt_values = safe_parse_series(merged[gt_col]) pred_values = safe_parse_series(merged[pred_col]) # Evaluate only cases where ground truth exists. gt_exists = gt_values.notna() for gt_value, pred_value in zip(gt_values[gt_exists], pred_values[gt_exists]): if pd.isna(pred_value): category = "Missing/invalid" else: abs_error = abs(pred_value - gt_value) category = categorize_error(abs_error) rows.append({ "system": system_name, "category": category, }) plot_df = pd.DataFrame(rows) if plot_df.empty: raise ValueError("No valid data available for plotting.") counts = ( plot_df .groupby(["system", "category"]) .size() .unstack(fill_value=0) ) counts = counts.reindex(index=SYSTEM_ORDER) counts = counts.reindex(columns=CATEGORY_ORDER, fill_value=0) counts = counts.fillna(0) # Remove systems with no available GT rows. 
counts = counts[counts.sum(axis=1) > 0] percentages = counts.div(counts.sum(axis=1), axis=0) * 100 percentages = percentages.fillna(0) return counts, percentages def plot_error_categories(counts, percentages, model_name, output_base): fig, ax = plt.subplots(figsize=(13, 7)) left = np.zeros(len(percentages)) for category in CATEGORY_ORDER: values = percentages[category].values ax.barh( percentages.index, values, left=left, color=COLORS[category], edgecolor="white", linewidth=0.8, label=category, ) for i, value in enumerate(values): if value >= 4: ax.text( left[i] + value / 2, i, f"{value:.1f}%", ha="center", va="center", fontsize=8, fontweight="bold", color="black", ) left += values for i, system in enumerate(percentages.index): total_n = int(counts.loc[system].sum()) if "Missing/invalid" in counts.columns: missing_n = int(counts.loc[system, "Missing/invalid"]) else: missing_n = 0 ax.text( 101, i, f"n={total_n}, missing={missing_n}", va="center", ha="left", fontsize=9, ) ax.set_xlim(0, 115) ax.set_xlabel("Percentage of Cases", fontsize=11, fontweight="bold") ax.set_ylabel("Functional System / EDSS", fontsize=11, fontweight="bold") ax.set_title( f"Prediction Error Categories by Functional System and EDSS\n{model_name}, Iteration {TARGET_ITERATION}", fontsize=13, fontweight="bold", pad=35, ) ax.set_xticks(np.arange(0, 101, 10)) ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)]) ax.xaxis.grid(True, linestyle="--", alpha=0.3) ax.set_axisbelow(True) for spine in ["top", "right", "left"]: ax.spines[spine].set_visible(False) ax.legend( loc="lower center", bbox_to_anchor=(0.5, 1.02), ncol=5, frameon=False, ) plt.tight_layout(rect=[0, 0, 1, 0.90]) svg_path = output_base.with_suffix(".svg") png_path = output_base.with_suffix(".png") plt.savefig(svg_path, format="svg", bbox_inches="tight") plt.savefig(png_path, dpi=300, bbox_inches="tight") plt.show() return svg_path, png_path # ========================= # LOAD GT # ========================= gt = pd.read_csv(GT_PATH, 
sep=";") gt["row_index"] = gt.index print(f"GT rows: {len(gt)}") # ========================= # MAIN # ========================= summary_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred = pd.read_csv(result_file, sep=",") if "row_index" not in pred.columns: print("Skipping: no row_index column.") continue model_name = get_model_name(pred, model_dir) safe_model = safe_name(model_name) pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[ pred["success"] .astype(str) .str.lower() .isin(["true", "1", "yes", "ja"]) ].copy() pred = pred.drop_duplicates("row_index", keep="first").copy() counts, percentages = prepare_plot_data(gt, pred) output_base = OUTPUT_DIR / f"{safe_model}_functional_systems_edss_error_categories_iter_{TARGET_ITERATION}" svg_path, png_path = plot_error_categories( counts=counts, percentages=percentages, model_name=model_name, output_base=output_base, ) counts_path = OUTPUT_DIR / f"{safe_model}_functional_systems_edss_error_category_counts_iter_{TARGET_ITERATION}.csv" percentages_path = OUTPUT_DIR / f"{safe_model}_functional_systems_edss_error_category_percentages_iter_{TARGET_ITERATION}.csv" counts.to_csv(counts_path) percentages.to_csv(percentages_path) print("Saved:") print(svg_path) print(png_path) print(counts_path) print(percentages_path) summary_rows.append({ "model": model_name, "iteration": TARGET_ITERATION, "result_file": 
str(result_file), "svg_path": str(svg_path), "png_path": str(png_path), "counts_path": str(counts_path), "percentages_path": str(percentages_path), }) summary_df = pd.DataFrame(summary_rows) summary_path = OUTPUT_DIR / f"functional_systems_edss_error_category_plot_summary_iter_{TARGET_ITERATION}.csv" summary_df.to_csv(summary_path, index=False) print("\n" + "=" * 100) print("Done.") print(f"Summary saved to: {summary_path}") print("=" * 100) ## # %% EDSS error distribution per model from pathlib import Path import pandas as pd import numpy as np # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"edss_error_distribution_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_PATH = OUTPUT_DIR / f"edss_error_distribution_iter_{TARGET_ITERATION}.csv" OUTPUT_LONG_PATH = OUTPUT_DIR / f"edss_error_distribution_long_iter_{TARGET_ITERATION}.csv" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def rate(n, d): if d == 0: return np.nan return n / d def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return 
model_dir.name def classify_error(abs_error): if pd.isna(abs_error): return "missing_or_invalid" if abs_error == 0: return "exact_match" if abs_error == 0.5: return "error_equal_0_5" if 0.5 < abs_error <= 1.0: return "error_gt_0_5_le_1_0" if abs_error > 1.0: return "error_gt_1_0" return "other" # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy() n_total_gt_rows = len(gt) n_gt_numeric = len(gt_numeric) print(f"GT rows: {n_total_gt_rows}") print(f"GT numeric EDSS rows: {n_gt_numeric}") # ========================= # MAIN ANALYSIS # ========================= summary_rows = [] long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not p.name.startswith("edss_error_distribution") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print("Skipping: no row_index column.") continue model_name = get_model_name(pred_raw, model_dir) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) raw_prediction_rows = len(pred) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = 
pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt_numeric.merge( pred, on="row_index", how="left", suffixes=("_gt", "_pred") ) merged["prediction_available"] = merged["PRED_EDSS_numeric"].notna() eval_df = merged[merged["prediction_available"]].copy() if eval_df.empty: print("No evaluable rows.") continue eval_df["error"] = eval_df["PRED_EDSS_numeric"] - eval_df["GT_EDSS_numeric"] eval_df["abs_error"] = eval_df["error"].abs() eval_df["error_category"] = eval_df["abs_error"].apply(classify_error) n_evaluable = len(eval_df) n_exact = int((eval_df["abs_error"] == 0).sum()) n_error_equal_05 = int((eval_df["abs_error"] == 0.5).sum()) n_error_gt_05_le_10 = int(((eval_df["abs_error"] > 0.5) & (eval_df["abs_error"] <= 1.0)).sum()) n_error_gt_10 = int((eval_df["abs_error"] > 1.0).sum()) n_error_gt_20 = int((eval_df["abs_error"] > 2.0).sum()) max_abs_error = eval_df["abs_error"].max() n_missing_or_invalid_against_gt_numeric = int((~merged["prediction_available"]).sum()) summary_rows.append({ "model": model_name, "iteration": TARGET_ITERATION, "result_file": str(result_file), "n_total_gt_rows": n_total_gt_rows, "n_gt_numeric": n_gt_numeric, "raw_prediction_rows": raw_prediction_rows, "n_evaluable": n_evaluable, "n_missing_or_invalid_against_gt_numeric": n_missing_or_invalid_against_gt_numeric, "exact_match_n": n_exact, "exact_match_percent_valid_only": rate(n_exact, n_evaluable) * 100, "exact_match_percent_all_gt_numeric": rate(n_exact, n_gt_numeric) * 100, "error_equal_0_5_n": n_error_equal_05, "error_equal_0_5_percent_valid_only": rate(n_error_equal_05, n_evaluable) * 100, "error_equal_0_5_percent_all_gt_numeric": rate(n_error_equal_05, n_gt_numeric) * 100, "error_gt_0_5_le_1_0_n": n_error_gt_05_le_10, 
"error_gt_0_5_le_1_0_percent_valid_only": rate(n_error_gt_05_le_10, n_evaluable) * 100, "error_gt_0_5_le_1_0_percent_all_gt_numeric": rate(n_error_gt_05_le_10, n_gt_numeric) * 100, "error_gt_1_0_n": n_error_gt_10, "error_gt_1_0_percent_valid_only": rate(n_error_gt_10, n_evaluable) * 100, "error_gt_1_0_percent_all_gt_numeric": rate(n_error_gt_10, n_gt_numeric) * 100, "error_gt_2_0_n": n_error_gt_20, "error_gt_2_0_percent_valid_only": rate(n_error_gt_20, n_evaluable) * 100, "error_gt_2_0_percent_all_gt_numeric": rate(n_error_gt_20, n_gt_numeric) * 100, "maximum_absolute_error": max_abs_error, }) keep_cols = [ "row_index", "unique_id_gt" if "unique_id_gt" in eval_df.columns else "unique_id", "MedDatum_gt" if "MedDatum_gt" in eval_df.columns else "MedDatum", "model", "iteration", "GT_EDSS_numeric", "PRED_EDSS_numeric", "error", "abs_error", "error_category", "raw_EDSS", "EDSS_numeric", "EDSS_in_valid_range", "klassifizierbar", "clinical_output_valid", "edss_logic_valid", "certainty_percent", "reason", "inference_time_sec", ] keep_cols = [c for c in keep_cols if c in eval_df.columns] for _, row in eval_df[keep_cols].iterrows(): row_dict = row.to_dict() row_dict["model_for_analysis"] = model_name long_rows.append(row_dict) print(f"Model: {model_name}") print(f"n_evaluable: {n_evaluable}") print(f"Exact match: {n_exact} ({rate(n_exact, n_evaluable) * 100:.1f}%)") print(f"Error = 0.5: {n_error_equal_05} ({rate(n_error_equal_05, n_evaluable) * 100:.1f}%)") print(f"Error >0.5 and ≤1.0: {n_error_gt_05_le_10} ({rate(n_error_gt_05_le_10, n_evaluable) * 100:.1f}%)") print(f"Error >1.0: {n_error_gt_10} ({rate(n_error_gt_10, n_evaluable) * 100:.1f}%)") print(f"Error >2.0: {n_error_gt_20} ({rate(n_error_gt_20, n_evaluable) * 100:.1f}%)") print(f"Maximum absolute error: {max_abs_error}") # ========================= # SAVE OUTPUT # ========================= summary_df = pd.DataFrame(summary_rows) long_df = pd.DataFrame(long_rows) summary_df.to_csv(OUTPUT_PATH, index=False) 
long_df.to_csv(OUTPUT_LONG_PATH, index=False)

pd.set_option("display.max_columns", None)
pd.set_option("display.width", 240)

print("\n" + "=" * 100)
print("EDSS error distribution summary:")
print(summary_df)

print("\nSaved:")
print(OUTPUT_PATH)
print(OUTPUT_LONG_PATH)

##
# %% EDSS severity-group performance per model
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix

# =========================
# CONFIGURATION
# =========================
GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)
RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)
TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"edss_severity_group_metrics_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_SUMMARY_PATH = OUTPUT_DIR / f"edss_severity_group_metrics_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG_PATH = OUTPUT_DIR / f"edss_severity_group_predictions_long_iter_{TARGET_ITERATION}.csv"
OUTPUT_CONFUSION_PATH = OUTPUT_DIR / f"edss_severity_group_confusion_matrices_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"

# Labels returned by classify_edss_group, in display order.
SEVERITY_GROUPS = [
    "0.0-3.5",
    "4.0-5.5",
    "6.0-10.0",
]


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable values become NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )


def to_bool(s):
    """True where the value's lower-cased text is "true", "1", "yes" or "ja"."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def rate(n, d):
    """Return n / d, or NaN when the denominator is zero."""
    if d == 0:
        return np.nan
    return n / d


def classify_edss_group(value):
    """Map an EDSS value to "0.0-3.5", "4.0-5.5" or "6.0-10.0".

    Closed intervals; NaN and values outside them (e.g. between 3.5 and 4.0,
    negative, or above 10) map to NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0-3.5"
    if 4.0 <= value <= 5.5:
        return "4.0-5.5"
    if 6.0 <= value <= 10.0:
        return "6.0-10.0"
    return np.nan


def find_iter_file(model_dir, iteration):
    """Return the first (sorted) result CSV for *iteration* in *model_dir*, or None.

    Incremental, summary and all-results files are ignored.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of df["model"], falling back to the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


# =========================
# LOAD GROUND TRUTH
# =========================
gt = pd.read_csv(GT_PATH, sep=";")
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(classify_edss_group)

# Only rows with a numeric reference EDSS inside one of the groups are scored.
gt_numeric = gt.dropna(subset=["GT_EDSS_numeric", "GT_EDSS_group"]).copy()

n_total_gt_rows = len(gt)
n_gt_numeric = len(gt_numeric)

print(f"GT rows: {n_total_gt_rows}")
print(f"GT numeric EDSS rows in severity groups: {n_gt_numeric}")
print("\nGT group counts:")
print(gt_numeric["GT_EDSS_group"].value_counts().reindex(SEVERITY_GROUPS, fill_value=0))


# =========================
# MAIN ANALYSIS
# =========================
# Sub-directories of RUN_DIR written by the analysis cells (not model folders).
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
)

summary_rows = []
long_rows = []
confusion_rows = []

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)
    if result_file is None:
        print(f"\nNo iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")
    if "row_index" not in pred_raw.columns:
        print("Skipping: no row_index column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)
    raw_prediction_rows = len(pred)

    # Apply whichever validity flags the result file provides.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()
    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()
    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL
    if pred_col not in pred.columns:
        print(f"Skipping: no {PRED_EDSS_COL} or {PRED_EDSS_FALLBACK_COL} column.")
        continue

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["PRED_EDSS_group"] = pred["PRED_EDSS_numeric"].apply(classify_edss_group)
    pred = pred.dropna(subset=["PRED_EDSS_numeric", "PRED_EDSS_group"]).copy()
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    # Left join keeps every scorable GT row so missing predictions are counted.
    merged = gt_numeric.merge(
        pred,
        on="row_index",
        how="left",
        suffixes=("_gt", "_pred"),
    )
    merged["prediction_available"] = merged["PRED_EDSS_group"].notna()

    eval_df = merged[merged["prediction_available"]].copy()
    if eval_df.empty:
        print("No evaluable rows.")
        continue

    n_evaluable = len(eval_df)
    n_missing_or_invalid_against_gt_numeric = int((~merged["prediction_available"]).sum())

    print(f"Model: {model_name}")
    print(f"Raw prediction rows: {raw_prediction_rows}")
    print(f"Evaluable rows: {n_evaluable}")
    print(f"Missing/invalid vs GT numeric: {n_missing_or_invalid_against_gt_numeric}")

    # Multiclass confusion matrix across 3 severity groups
    cm = confusion_matrix(
        eval_df["GT_EDSS_group"],
        eval_df["PRED_EDSS_group"],
        labels=SEVERITY_GROUPS,
    )
    cm_df = pd.DataFrame(cm, index=SEVERITY_GROUPS, columns=SEVERITY_GROUPS)
    cm_df.index.name = "Ground truth severity group"
    cm_df.columns.name = "Predicted severity group"

    print("\nSeverity-group confusion matrix:")
    print(cm_df)

    for gt_group in SEVERITY_GROUPS:
        for pred_group in SEVERITY_GROUPS:
            confusion_rows.append({
                "model": model_name,
                "iteration": TARGET_ITERATION,
                "gt_group": gt_group,
                "pred_group": pred_group,
                "count": int(cm_df.loc[gt_group, pred_group]),
            })

    # One-vs-rest sensitivity/specificity for each severity group
    for group in SEVERITY_GROUPS:
        y_true = eval_df["GT_EDSS_group"] == group
        y_pred = eval_df["PRED_EDSS_group"] == group

        # labels=[False, True] fixes the 2x2 layout as [[TN, FP], [FN, TP]].
        tn, fp, fn, tp = confusion_matrix(
            y_true,
            y_pred,
            labels=[False, True],
        ).ravel()

        sensitivity = rate(tp, tp + fn)
        specificity = rate(tn, tn + fp)
        ppv = rate(tp, tp + fp)
        npv = rate(tn, tn + fn)
        accuracy = rate(tp + tn, tp + tn + fp + fn)
        gt_positive_prevalence = rate(tp + fn, n_evaluable)
        predicted_positive_rate = rate(tp + fp, n_evaluable)

        summary_rows.append({
            "model": model_name,
            "iteration": TARGET_ITERATION,
            "result_file": str(result_file),
            "severity_group": group,
            "n_total_gt_rows": n_total_gt_rows,
            "n_gt_numeric_in_groups": n_gt_numeric,
            "raw_prediction_rows": raw_prediction_rows,
            "n_evaluable": n_evaluable,
            "n_missing_or_invalid_against_gt_numeric": n_missing_or_invalid_against_gt_numeric,
            "true_positives": int(tp),
            "true_negatives": int(tn),
            "false_positives": int(fp),
            "false_negatives": int(fn),
            "sensitivity": sensitivity,
            "specificity": specificity,
            "positive_predictive_value": ppv,
            "negative_predictive_value": npv,
            "accuracy": accuracy,
            "sensitivity_percent": sensitivity * 100,
            "specificity_percent": specificity * 100,
            "positive_predictive_value_percent": ppv * 100,
            "negative_predictive_value_percent": npv * 100,
            "accuracy_percent": accuracy * 100,
            "gt_positive_prevalence": gt_positive_prevalence,
            "gt_positive_prevalence_percent": gt_positive_prevalence * 100,
            "predicted_positive_rate": predicted_positive_rate,
            "predicted_positive_rate_percent": predicted_positive_rate * 100,
        })

        print(
            f"Group {group}: "
            f"TP={tp}, TN={tn}, FP={fp}, FN={fn}, "
            f"sensitivity={sensitivity * 100:.1f}%, "
            f"specificity={specificity * 100:.1f}%"
        )

    # Long per-case output
    tmp = eval_df.copy()
    tmp["severity_match"] = tmp["GT_EDSS_group"] == tmp["PRED_EDSS_group"]

    keep_cols = [
        "model",
        "iteration",
        "row_index",
        "unique_id_gt" if "unique_id_gt" in tmp.columns else "unique_id",
        "MedDatum_gt" if "MedDatum_gt" in tmp.columns else "MedDatum",
        "GT_EDSS_numeric",
        "PRED_EDSS_numeric",
        "GT_EDSS_group",
        "PRED_EDSS_group",
        "severity_match",
        "raw_EDSS",
        "EDSS_numeric",
        "EDSS_in_valid_range",
        "klassifizierbar",
        "clinical_output_valid",
        "edss_logic_valid",
        "certainty_percent",
        "reason",
        "inference_time_sec",
    ]
    keep_cols = [c for c in keep_cols if c in tmp.columns]

    # to_dict("records") keeps column dtypes (iterrows would upcast rows).
    long_rows.extend(
        tmp[keep_cols]
        .assign(model_for_analysis=model_name)
        .to_dict("records")
    )


# =========================
# SAVE OUTPUT
# =========================
summary_df = pd.DataFrame(summary_rows)
long_df = pd.DataFrame(long_rows)
confusion_df = pd.DataFrame(confusion_rows)

summary_df.to_csv(OUTPUT_SUMMARY_PATH, index=False)
long_df.to_csv(OUTPUT_LONG_PATH, index=False)
confusion_df.to_csv(OUTPUT_CONFUSION_PATH, index=False)

pd.set_option("display.max_columns", None)
pd.set_option("display.width", 240)

print("\n" + "=" * 100)
print("EDSS severity-group performance summary:")
print(summary_df)

print("\nSaved:")
print(OUTPUT_SUMMARY_PATH)
print(OUTPUT_LONG_PATH)
print(OUTPUT_CONFUSION_PATH)

##
# %% Coverage table: model evaluable predictions vs ground truth
from pathlib import Path

import pandas as pd
import numpy as np

# =========================
# CONFIGURATION
# =========================
GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)
RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)
TARGET_ITERATION = 1
OUTPUT_PATH = RUN_DIR / f"model_coverage_table_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable values become NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )


def to_bool(s):
    """True where the value's lower-cased text is "true", "1", "yes" or "ja"."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def find_iter_file(model_dir, iteration):
    """Return the first (sorted) result CSV for *iteration* in *model_dir*, or None.

    Incremental, summary and all-results files are ignored.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of df["model"], falling back to the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


# =========================
# LOAD GROUND TRUTH
# =========================
gt = pd.read_csv(GT_PATH, sep=";")
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])

total_records = len(gt)
numeric_gt_edss = int(gt["GT_EDSS_numeric"].notna().sum())
gt_numeric = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"Total records: {total_records}")
print(f"Numeric ground-truth EDSS: {numeric_gt_edss}")


# =========================
# MODEL COVERAGE TABLE
# =========================
# Sub-directories of RUN_DIR written by the analysis cells (not model folders).
NON_MODEL_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
)

rows = []

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(NON_MODEL_DIR_PREFIXES)
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)
    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")
    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column")
        continue

    model_name = get_model_name(pred_raw, model_dir)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Apply whichever validity flags the result file provides.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()
    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()
    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL
    if pred_col not in pred.columns:
        print(f"Skipping {model_dir.name}: no {PRED_EDSS_COL} or {PRED_EDSS_FALLBACK_COL} column")
        continue

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred = pred.dropna(subset=["PRED_EDSS_numeric"]).copy()
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    # Inner join: only GT rows with a numeric reference AND a usable prediction.
    merged = gt_numeric.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred"),
    )

    evaluable_predictions = len(merged)
    coverage_numeric_gt = (
        evaluable_predictions / numeric_gt_edss * 100
        if numeric_gt_edss > 0 else np.nan
    )
    coverage_all_records = (
        evaluable_predictions / total_records * 100
        if total_records > 0 else np.nan
    )

    rows.append({
        "Model": model_name,
        "Total records": total_records,
        "Numeric ground-truth EDSS": numeric_gt_edss,
        "Evaluable predictions": evaluable_predictions,
        "Coverage of numeric ground truth (%)": coverage_numeric_gt,
        "Coverage of all records (%)": coverage_all_records,
    })

    print(
        f"{model_name}: "
        f"evaluable={evaluable_predictions}, "
        f"coverage numeric GT={coverage_numeric_gt:.1f}%, "
        f"coverage all={coverage_all_records:.1f}%"
    )

coverage_df = pd.DataFrame(rows)
coverage_df.to_csv(OUTPUT_PATH, index=False)

print("\nCoverage table:")
print(coverage_df)
print("\nSaved to:")
print(OUTPUT_PATH)

##
# %% Dataset exploration table for EDSS project
from pathlib import Path

import pandas as pd
import numpy as np

# =========================
# CONFIGURATION
# =========================
DATA_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)
OUTPUT_DIR = DATA_PATH.parent / "dataset_exploration"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_TABLE_PATH = OUTPUT_DIR / "dataset_column_exploration_table.csv"
OUTPUT_NUMERIC_SUMMARY_PATH = OUTPUT_DIR / "reference_numeric_summary.csv"
OUTPUT_VALUE_COUNTS_PATH = OUTPUT_DIR / "reference_value_counts.csv"
OUTPUT_TEXT_SUMMARY_PATH = OUTPUT_DIR / "model_input_text_summary.csv" # ========================= # COLUMN DEFINITIONS # ========================= COLUMNS_TO_EXPLORE = [ { "output_name": "unique_id", "source_col": "unique_id", "variable_type": "Identifier", "role": "Patient/record linkage", "description": "Pseudonymized unique identifier generated by hashing patient-related identifiers", "used_for_model_input": "No", "used_as_ground_truth": "No", }, { "output_name": "MedDatum", "source_col": "MedDatum", "variable_type": "Date", "role": "Visit metadata", "description": "Date of clinical visit or medical documentation", "used_for_model_input": "No", "used_as_ground_truth": "No", }, { "output_name": "T_Zusammenfassung", "source_col": "T_Zusammenfassung", "variable_type": "Text", "role": "Model input", "description": "Clinical summary section", "used_for_model_input": "Yes", "used_as_ground_truth": "No", }, { "output_name": "Diagnosen", "source_col": "Diagnosen", "variable_type": "Text", "role": "Model input", "description": "Diagnostic information and coded/free-text diagnoses", "used_for_model_input": "Yes", "used_as_ground_truth": "No", }, { "output_name": "T_KlinBef", "source_col": "T_KlinBef", "variable_type": "Text", "role": "Model input", "description": "Clinical examination findings", "used_for_model_input": "Yes", "used_as_ground_truth": "No", }, { "output_name": "T_Befunde", "source_col": "T_Befunde", "variable_type": "Text", "role": "Model input", "description": "Additional findings and reports", "used_for_model_input": "Yes", "used_as_ground_truth": "No", }, { "output_name": "EDSS_reference", "source_col": "EDSS", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted reference EDSS score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "VISUAL_OPTIC_FUNCTIONS_reference", "source_col": "Sehvermögen", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted 
visual/optic functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "BRAINSTEM_FUNCTIONS_reference", "source_col": "Hirnstamm", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted brainstem functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "PYRAMIDAL_FUNCTIONS_reference", "source_col": "Pyramidalmotorik", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted pyramidal functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "CEREBELLAR_FUNCTIONS_reference", "source_col": "Cerebellum", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted cerebellar functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "SENSORY_FUNCTIONS_reference", "source_col": "Sensibiliät", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted sensory functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "BOWEL_AND_BLADDER_FUNCTIONS_reference", "source_col": "Blasen-_und_Mastdarmfunktion", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted bowel and bladder functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "CEREBRAL_FUNCTIONS_reference", "source_col": "Cerebrale_Funktion", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted cerebral functional system score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, { "output_name": "AMBULATION_reference", "source_col": "Ambulation", "variable_type": "Numeric", "role": "Ground truth", "description": "Manually extracted ambulation score", "used_for_model_input": "No", "used_as_ground_truth": "Yes", }, ] TEXT_COLUMNS = [ 
"T_Zusammenfassung", "Diagnosen", "T_KlinBef", "T_Befunde",
]

# Subset of COLUMNS_TO_EXPLORE whose "variable_type" is "Numeric"; the numeric
# summary and value-count sections below iterate over this list.
NUMERIC_REFERENCE_COLUMNS = [
    item for item in COLUMNS_TO_EXPLORE if item["variable_type"] == "Numeric"
]

# =========================
# HELPERS
# =========================

def to_num(series):
    """Convert *series* to numbers, accepting ',' as the decimal separator.

    Values are cast to str first; anything pd.to_numeric cannot parse
    (including missing values) ends up as NaN because errors="coerce".
    """
    return pd.to_numeric(
        series.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def is_non_missing_value(series):
    """
    Return a boolean mask that is True where *series* holds a real value.

    Treated as missing: NaN, the empty string, whitespace-only strings, and
    the literals 'nan', 'none', 'null', 'na', 'n/a' (compared
    case-insensitively after stripping surrounding whitespace).
    """
    s = series.copy()
    missing = s.isna()

    s_str = s.astype(str).str.strip()
    missing = (
        missing
        | (s_str == "")
        | (s_str.str.lower().isin(["nan", "none", "null", "na", "n/a"]))
    )

    return ~missing


def text_length_stats(series):
    """Return mean/median/min/max character length of the non-missing values.

    Missingness is decided by is_non_missing_value; all four statistics are
    NaN when no value is present.
    """
    non_missing = is_non_missing_value(series)
    lengths = series[non_missing].astype(str).str.len()

    if len(lengths) == 0:
        return {
            "mean_text_length_chars": np.nan,
            "median_text_length_chars": np.nan,
            "min_text_length_chars": np.nan,
            "max_text_length_chars": np.nan,
        }

    return {
        "mean_text_length_chars": lengths.mean(),
        "median_text_length_chars": lengths.median(),
        "min_text_length_chars": lengths.min(),
        "max_text_length_chars": lengths.max(),
    }


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(DATA_PATH, sep=";")
total_n = len(df)

print(f"Loaded rows: {total_n}")
print(f"Loaded columns: {len(df.columns)}")

# =========================
# MAIN EXPLORATION TABLE
# =========================

# One row per configured column: completeness plus the study metadata
# carried in COLUMNS_TO_EXPLORE.
rows = []

for item in COLUMNS_TO_EXPLORE:
    source_col = item["source_col"]

    if source_col not in df.columns:
        # Column absent from the file: report zero coverage and flag it.
        non_missing_n = 0
        non_missing_pct = 0.0
        status = "missing_column"
    else:
        if item["variable_type"] == "Numeric":
            # Numeric columns only count values that parse as numbers.
            numeric_values = to_num(df[source_col])
            non_missing_n = int(numeric_values.notna().sum())
        else:
            non_missing_n = int(is_non_missing_value(df[source_col]).sum())

        non_missing_pct = non_missing_n / total_n * 100 if total_n > 0 else np.nan
        status = "ok"

    rows.append({
        "Variable / Column": item["output_name"],
        "Source column": source_col,
        "Variable type": item["variable_type"],
        "Role in study": item["role"],
        "Description": item["description"],
        "Non-missing n / total": f"{non_missing_n} / {total_n}",
        "Non-missing n": non_missing_n,
        "Total n": total_n,
        "Non-missing %": round(non_missing_pct, 1),
        "Used for model input": item["used_for_model_input"],
        "Used as ground truth": item["used_as_ground_truth"],
        "Status": status,
    })

exploration_df = pd.DataFrame(rows)
exploration_df.to_csv(OUTPUT_TABLE_PATH, index=False)

# =========================
# NUMERIC REFERENCE SUMMARY
# =========================

numeric_rows = []

for item in NUMERIC_REFERENCE_COLUMNS:
    source_col = item["source_col"]

    # Absent columns are skipped here; the main table already flags them
    # as "missing_column".
    if source_col not in df.columns:
        continue

    values = to_num(df[source_col]).dropna()

    if values.empty:
        numeric_rows.append({
            "Variable / Column": item["output_name"],
            "Source column": source_col,
            "n": 0,
            "non_zero_n": 0,
            "non_zero_percent": np.nan,
            "mean": np.nan,
            "median": np.nan,
            "std": np.nan,
            "min": np.nan,
            "max": np.nan,
        })
        continue

    non_zero_n = int((values != 0).sum())

    numeric_rows.append({
        "Variable / Column": item["output_name"],
        "Source column": source_col,
        "n": int(values.notna().sum()),
        "non_zero_n": non_zero_n,
        "non_zero_percent": non_zero_n / len(values) * 100,
        "mean": values.mean(),
        "median": values.median(),
        "std": values.std(),
        "min": values.min(),
        "max": values.max(),
    })

numeric_summary_df = pd.DataFrame(numeric_rows)
numeric_summary_df.to_csv(OUTPUT_NUMERIC_SUMMARY_PATH, index=False)

# =========================
# VALUE COUNTS FOR REFERENCE NUMERIC COLUMNS
# =========================

# Long table: one row per (column, distinct numeric value), sorted by value.
value_count_rows = []

for item in NUMERIC_REFERENCE_COLUMNS:
    source_col = item["source_col"]

    if source_col not in df.columns:
        continue

    values = to_num(df[source_col]).dropna()

    counts = (
        values
        .value_counts()
        .sort_index()
    )

    for value, count in counts.items():
        value_count_rows.append({
            "Variable / Column": item["output_name"],
            "Source column": source_col,
            "value": value,
            "count": int(count),
            "percent_of_non_missing": count / len(values) * 100 if len(values) > 0 else np.nan,
            "percent_of_total": count / total_n * 100 if total_n > 0 else np.nan,
        })

value_counts_df = pd.DataFrame(value_count_rows)
value_counts_df.to_csv(OUTPUT_VALUE_COUNTS_PATH, index=False)

# =========================
# TEXT INPUT SUMMARY
# =========================

text_rows = []

for col in TEXT_COLUMNS:
    if col not in df.columns:
        # Keep a placeholder row so every configured text column is reported.
        text_rows.append({
            "column": col,
            "non_missing_n": 0,
            "non_missing_percent": 0.0,
            "mean_text_length_chars": np.nan,
            "median_text_length_chars": np.nan,
            "min_text_length_chars": np.nan,
            "max_text_length_chars": np.nan,
        })
        continue

    non_missing = is_non_missing_value(df[col])
    stats = text_length_stats(df[col])

    text_rows.append({
        "column": col,
        "non_missing_n": int(non_missing.sum()),
        "non_missing_percent": non_missing.sum() / total_n * 100 if total_n > 0 else np.nan,
        **stats,
    })

text_summary_df = pd.DataFrame(text_rows)
text_summary_df.to_csv(OUTPUT_TEXT_SUMMARY_PATH, index=False)

# =========================
# PRINT RESULTS
# =========================

pd.set_option("display.max_columns", None)
pd.set_option("display.width", 220)

print("\n" + "=" * 100)
print("Dataset column exploration table:")
print(exploration_df)

print("\n" + "=" * 100)
print("Numeric reference summary:")
print(numeric_summary_df)

print("\n" + "=" * 100)
print("Text input summary:")
print(text_summary_df)

print("\nSaved:")
print(OUTPUT_TABLE_PATH)
print(OUTPUT_NUMERIC_SUMMARY_PATH)
print(OUTPUT_VALUE_COUNTS_PATH)
print(OUTPUT_TEXT_SUMMARY_PATH)

##
# %% Dataset characteristics table

from pathlib import Path
import pandas as pd
import numpy as np

# =========================
# CONFIGURATION
# =========================

DATA_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

# All outputs of this cell go next to the input file.
OUTPUT_DIR = DATA_PATH.parent / "dataset_exploration"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_CSV = OUTPUT_DIR / "dataset_characteristics_table.csv"
OUTPUT_MD = OUTPUT_DIR / "dataset_characteristics_table.md"
OUTPUT_PATIENT_COUNTS = OUTPUT_DIR / "patient_record_counts.csv"

PATIENT_ID_COL = "unique_id"
DATE_COL = "MedDatum"
EDSS_COL = "EDSS"

# =========================
# HELPERS
# =========================

def to_num(series):
    """Convert *series* to numbers (',' accepted as decimal separator; unparseable -> NaN)."""
    return pd.to_numeric(
        series.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def parse_dates(series):
    """
    Parse *series* to datetimes, preferring day-first order for ambiguous dates.

    Unparseable values become NaT (errors="coerce"); dayfirst=True makes a
    date such as 03.04.2020 read as 3 April.
    """
    parsed = pd.to_datetime(series, errors="coerce", dayfirst=True)
    return parsed


def fmt_int(x):
    """Format *x* as an integer string, or "NA" when missing."""
    if pd.isna(x):
        return "NA"
    return f"{int(x)}"


def fmt_float(x, digits=1):
    """Format *x* with *digits* decimals, or "NA" when missing."""
    if pd.isna(x):
        return "NA"
    return f"{float(x):.{digits}f}"


def fmt_percent(n, d, digits=1):
    """Format n / d as a percentage string; "NA" when the denominator is 0."""
    if d == 0:
        return "NA"
    return f"{(n / d * 100):.{digits}f}%"


def fmt_n_total_percent(n, total):
    """Format as "n / total, p%"."""
    return f"{int(n)} / {int(total)}, {fmt_percent(n, total)}"


def fmt_range(min_value, max_value, digits=1):
    """Format "min–max" (en dash) with *digits* decimals; "NA" if either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{float(min_value):.{digits}f}–{float(max_value):.{digits}f}"


def fmt_record_range(min_value, max_value):
    """Format an integer "min–max" range; "NA" if either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{int(min_value)}–{int(max_value)}"


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(DATA_PATH, sep=";")
total_records = len(df)

if PATIENT_ID_COL not in df.columns:
    raise ValueError(f"Missing patient ID column: {PATIENT_ID_COL}")

if DATE_COL not in df.columns:
    raise ValueError(f"Missing date column: {DATE_COL}")

if EDSS_COL not in df.columns:
    raise ValueError(f"Missing EDSS column: {EDSS_COL}")

# =========================
# BASIC COUNTS
# =========================

unique_patients = df[PATIENT_ID_COL].nunique(dropna=True)

dates = parse_dates(df[DATE_COL])
valid_dates = dates.dropna()

# Documentation period is reported at year granularity from parseable dates.
if len(valid_dates) > 0:
    documentation_start_year = int(valid_dates.min().year)
    documentation_end_year = int(valid_dates.max().year)
    documentation_period = f"{documentation_start_year}–{documentation_end_year}"
else:
    documentation_period = "NA"

edss_numeric = to_num(df[EDSS_COL])
records_with_numeric_edss = int(edss_numeric.notna().sum())
records_without_numeric_edss = int(total_records - records_with_numeric_edss)

# =========================
# PATIENT RECORD COUNTS
# =========================

# Rows per patient; rows with a missing patient ID are not grouped (dropna=True).
patient_counts = (
    df.groupby(PATIENT_ID_COL, dropna=True)
    .size()
    .reset_index(name="record_count")
    .sort_values("record_count", ascending=False)
)

patient_counts.to_csv(OUTPUT_PATIENT_COUNTS, index=False)

median_records_per_patient = patient_counts["record_count"].median()
min_records_per_patient = patient_counts["record_count"].min()
max_records_per_patient = patient_counts["record_count"].max()

patients_with_one_record = int((patient_counts["record_count"] == 1).sum())
patients_with_multiple_records = int((patient_counts["record_count"] > 1).sum())

# =========================
# DUPLICATE VISIT EXPLORATION
# =========================
# This estimates duplicate visits using patient ID + documentation date.
# If a patient has multiple rows on the same MedDatum, rows beyond the first are counted as duplicate records.
duplicate_subset = df.copy()
duplicate_subset["_parsed_MedDatum"] = dates

# DataFrame.duplicated treats NaN/NaT as equal values, so two rows of the
# same patient that both lack a parseable MedDatum (or two rows without a
# unique_id) would otherwise be reported as a duplicate visit. Only rows with
# both keys present can be meaningfully compared.
comparable_rows = (
    duplicate_subset[PATIENT_ID_COL].notna()
    & duplicate_subset["_parsed_MedDatum"].notna()
)

duplicate_rows_mask = comparable_rows & duplicate_subset.duplicated(
    subset=[PATIENT_ID_COL, "_parsed_MedDatum"],
    keep="first"
)

records_excluded_as_duplicates = int(duplicate_rows_mask.sum())
duplicate_patients = duplicate_subset.loc[
    duplicate_rows_mask,
    PATIENT_ID_COL
].nunique(dropna=True)

if records_excluded_as_duplicates == 0:
    duplicate_text = "0"
else:
    duplicate_text = (
        f"{records_excluded_as_duplicates} visits from "
        f"{duplicate_patients} patients"
    )

# =========================
# EDSS SUMMARY
# =========================

# Descriptive statistics over records whose EDSS parsed as a number.
edss_valid = edss_numeric.dropna()

if len(edss_valid) > 0:
    median_edss = edss_valid.median()
    q1_edss = edss_valid.quantile(0.25)
    q3_edss = edss_valid.quantile(0.75)
    min_edss = edss_valid.min()
    max_edss = edss_valid.max()
else:
    median_edss = np.nan
    q1_edss = np.nan
    q3_edss = np.nan
    min_edss = np.nan
    max_edss = np.nan

# =========================
# BUILD CHARACTERISTICS TABLE
# =========================

rows = [
    {
        "Characteristic": "Total clinical records",
        "Value": fmt_int(total_records),
    },
    {
        "Characteristic": "Unique patients",
        "Value": fmt_int(unique_patients),
    },
    {
        "Characteristic": "Documentation period",
        "Value": documentation_period,
    },
    {
        "Characteristic": "Records excluded as duplicates",
        "Value": duplicate_text,
    },
    {
        "Characteristic": "Records with numeric reference EDSS",
        "Value": fmt_n_total_percent(records_with_numeric_edss, total_records),
    },
    {
        "Characteristic": "Records without numeric reference EDSS",
        "Value": fmt_n_total_percent(records_without_numeric_edss, total_records),
    },
    {
        "Characteristic": "Median records per patient",
        "Value": fmt_float(median_records_per_patient, digits=1),
    },
    {
        "Characteristic": "Range of records per patient",
        "Value": fmt_record_range(min_records_per_patient, max_records_per_patient),
    },
    {
        "Characteristic": "Patients with one record",
        "Value": fmt_n_total_percent(patients_with_one_record, unique_patients),
    },
    {
        "Characteristic": "Patients with multiple records",
        "Value": fmt_n_total_percent(patients_with_multiple_records, unique_patients),
    },
    {
        "Characteristic": "Median reference EDSS",
        "Value": fmt_float(median_edss, digits=1),
    },
    {
        "Characteristic": "IQR reference EDSS",
        "Value": fmt_range(q1_edss, q3_edss, digits=1),
    },
    {
        "Characteristic": "Minimum–maximum reference EDSS",
        "Value": fmt_range(min_edss, max_edss, digits=1),
    },
]

characteristics_df = pd.DataFrame(rows)

# =========================
# SAVE OUTPUT
# =========================

characteristics_df.to_csv(OUTPUT_CSV, index=False)

with open(OUTPUT_MD, "w", encoding="utf-8") as f:
    f.write(characteristics_df.to_markdown(index=False))
    f.write("\n")

# =========================
# PRINT OUTPUT
# =========================

pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", 160)

print("\nDataset characteristics table:")
print(characteristics_df.to_markdown(index=False))

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_MD)
print(OUTPUT_PATIENT_COUNTS)

print("\nDuplicate estimate note:")
print(
    "Duplicates were estimated as repeated rows with the same unique_id and MedDatum. "
    "Rows without a unique_id or a parseable MedDatum are not counted. "
    "If you already removed duplicates before this file, this value may be 0."
)

##
# %% Dataset characteristics table with visit-count distribution

from pathlib import Path
import pandas as pd
import numpy as np

# =========================
# CONFIGURATION
# =========================

DATA_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

OUTPUT_DIR = DATA_PATH.parent / "dataset_exploration"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_CSV = OUTPUT_DIR / "dataset_characteristics_table.csv"
OUTPUT_MD = OUTPUT_DIR / "dataset_characteristics_table.md"
OUTPUT_PATIENT_COUNTS = OUTPUT_DIR / "patient_record_counts.csv"
OUTPUT_VISIT_DISTRIBUTION = OUTPUT_DIR / "patient_visit_count_distribution.csv"

PATIENT_ID_COL = "unique_id"
DATE_COL = "MedDatum"
EDSS_COL = "EDSS"

# =========================
# HELPERS
# =========================

def to_num(series):
    """Convert *series* to numbers (',' accepted as decimal separator; unparseable -> NaN)."""
    return pd.to_numeric(
        series.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def parse_dates(series):
    """
    Parse *series* to datetimes, preferring day-first order for ambiguous dates.

    Unparseable values become NaT (errors="coerce").
    """
    return pd.to_datetime(series, errors="coerce", dayfirst=True)


def fmt_int(x):
    """Format *x* as an integer string, or "NA" when missing."""
    if pd.isna(x):
        return "NA"
    return f"{int(x)}"


def fmt_float(x, digits=1):
    """Format *x* with *digits* decimals, or "NA" when missing."""
    if pd.isna(x):
        return "NA"
    return f"{float(x):.{digits}f}"


def fmt_percent(n, d, digits=1):
    """Format n / d as a percentage string; "NA" when the denominator is 0."""
    if d == 0:
        return "NA"
    return f"{(n / d * 100):.{digits}f}%"


def fmt_n_total_percent(n, total):
    """Format as "n / total, p%"."""
    return f"{int(n)} / {int(total)}, {fmt_percent(n, total)}"


def fmt_range(min_value, max_value, digits=1):
    """Format "min–max" (en dash) with *digits* decimals; "NA" if either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{float(min_value):.{digits}f}–{float(max_value):.{digits}f}"


def fmt_record_range(min_value, max_value):
    """Format an integer "min–max" range; "NA" if either end is missing."""
    if pd.isna(min_value) or pd.isna(max_value):
        return "NA"
    return f"{int(min_value)}–{int(max_value)}"


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(DATA_PATH, sep=";")
total_records = len(df)

if PATIENT_ID_COL not in df.columns:
    raise ValueError(f"Missing patient ID column: {PATIENT_ID_COL}")

if DATE_COL not in df.columns:
    raise ValueError(f"Missing date column: {DATE_COL}")

if EDSS_COL not in df.columns:
    raise ValueError(f"Missing EDSS column: {EDSS_COL}")

# =========================
# BASIC COUNTS
# =========================

unique_patients = df[PATIENT_ID_COL].nunique(dropna=True)

# In this dataset, each row is a clinical record / visit.
total_visits = total_records

dates = parse_dates(df[DATE_COL])
valid_dates = dates.dropna()

# Documentation period is reported at year granularity from parseable dates.
if len(valid_dates) > 0:
    documentation_start_year = int(valid_dates.min().year)
    documentation_end_year = int(valid_dates.max().year)
    documentation_period = f"{documentation_start_year}–{documentation_end_year}"
else:
    documentation_period = "NA"

edss_numeric = to_num(df[EDSS_COL])
records_with_numeric_edss = int(edss_numeric.notna().sum())
records_without_numeric_edss = int(total_records - records_with_numeric_edss)

# =========================
# PATIENT RECORD / VISIT COUNTS
# =========================

# Rows per patient; rows with a missing patient ID are not grouped (dropna=True).
patient_counts = (
    df.groupby(PATIENT_ID_COL, dropna=True)
    .size()
    .reset_index(name="record_count")
    .sort_values("record_count", ascending=False)
)

patient_counts.to_csv(OUTPUT_PATIENT_COUNTS, index=False)

median_records_per_patient = patient_counts["record_count"].median()
min_records_per_patient = patient_counts["record_count"].min()
max_records_per_patient = patient_counts["record_count"].max()

patients_with_one_record = int((patient_counts["record_count"] == 1).sum())
patients_with_multiple_records = int((patient_counts["record_count"] > 1).sum())

# Exact patient counts for 2..7 records; anything above 7 is pooled into ">7".
patients_with_n_records = {
    n: int((patient_counts["record_count"] == n).sum())
    for n in range(2, 8)
}
patients_with_more_than_7_records = int((patient_counts["record_count"] > 7).sum())

# (label, number of patients) for each bucket of the distribution, in the
# order the rows are written to OUTPUT_VISIT_DISTRIBUTION.
visit_count_buckets = (
    [("1", patients_with_one_record)]
    + [(str(n), patients_with_n_records[n]) for n in range(2, 8)]
    + [(">7", patients_with_more_than_7_records)]
)

visit_distribution_rows = [
    {
        "records_per_patient": label,
        "patients_n": n_patients,
        "total_patients": unique_patients,
        "patients_percent": n_patients / unique_patients * 100 if unique_patients else np.nan,
    }
    for label, n_patients in visit_count_buckets
]

visit_distribution_df = pd.DataFrame(visit_distribution_rows)
visit_distribution_df.to_csv(OUTPUT_VISIT_DISTRIBUTION, index=False)

# =========================
# DUPLICATE VISIT EXPLORATION
# =========================
# This estimates duplicate visits using patient ID + documentation date.
# If a patient has multiple rows on the same MedDatum, rows beyond the first are counted as duplicate records.

duplicate_subset = df.copy()
duplicate_subset["_parsed_MedDatum"] = dates

# DataFrame.duplicated treats NaN/NaT as equal values, so rows lacking a
# parseable MedDatum (or a unique_id) would otherwise be counted as duplicate
# visits of each other. Only rows with both keys present are compared.
comparable_rows = (
    duplicate_subset[PATIENT_ID_COL].notna()
    & duplicate_subset["_parsed_MedDatum"].notna()
)

duplicate_rows_mask = comparable_rows & duplicate_subset.duplicated(
    subset=[PATIENT_ID_COL, "_parsed_MedDatum"],
    keep="first"
)

records_excluded_as_duplicates = int(duplicate_rows_mask.sum())
duplicate_patients = duplicate_subset.loc[
    duplicate_rows_mask,
    PATIENT_ID_COL
].nunique(dropna=True)

if records_excluded_as_duplicates == 0:
    duplicate_text = "0"
else:
    duplicate_text = (
        f"{records_excluded_as_duplicates} visits from "
        f"{duplicate_patients} patients"
    )

# =========================
# EDSS SUMMARY
# =========================

edss_valid = edss_numeric.dropna()

if len(edss_valid) > 0:
    median_edss = edss_valid.median()
    q1_edss = edss_valid.quantile(0.25)
    q3_edss = edss_valid.quantile(0.75)
    min_edss = edss_valid.min()
    max_edss = edss_valid.max()
else:
    median_edss = np.nan
    q1_edss = np.nan
    q3_edss = np.nan
    min_edss = np.nan
    max_edss = np.nan

# =========================
# BUILD CHARACTERISTICS TABLE
# =========================

rows = [
    {
        "Characteristic": "Total clinical records",
        "Value": fmt_int(total_records),
    },
    {
        "Characteristic": "Total visits",
        "Value": fmt_int(total_visits),
    },
    {
        "Characteristic": "Unique patients",
        "Value": fmt_int(unique_patients),
    },
    {
        "Characteristic": "Documentation period",
        "Value": documentation_period,
    },
    {
        "Characteristic": "Records excluded as duplicates",
        "Value": duplicate_text,
    },
    {
        "Characteristic": "Records with numeric reference EDSS",
        "Value": fmt_n_total_percent(records_with_numeric_edss, total_records),
    },
    {
        "Characteristic": "Records without numeric reference EDSS",
        "Value": fmt_n_total_percent(records_without_numeric_edss, total_records),
    },
    {
        "Characteristic": "Median records per patient",
        "Value": fmt_float(median_records_per_patient, digits=1),
    },
    {
        "Characteristic": "Range of records per patient",
        "Value": fmt_record_range(min_records_per_patient, max_records_per_patient),
    },
    {
        "Characteristic": "Patients with one record",
        "Value": fmt_n_total_percent(patients_with_one_record, unique_patients),
    },
    {
        "Characteristic": "Patients with multiple records",
        "Value": fmt_n_total_percent(patients_with_multiple_records, unique_patients),
    },
]

for n in range(2, 8):
    rows.append({
        "Characteristic": f"Patients with {n} records",
        "Value": fmt_n_total_percent(patients_with_n_records[n], unique_patients),
    })

rows.append({
    "Characteristic": "Patients with >7 records",
    "Value": fmt_n_total_percent(patients_with_more_than_7_records, unique_patients),
})

rows.extend([
    {
        "Characteristic": "Median reference EDSS",
        "Value": fmt_float(median_edss, digits=1),
    },
    {
        "Characteristic": "IQR reference EDSS",
        "Value": fmt_range(q1_edss, q3_edss, digits=1),
    },
    {
        "Characteristic": "Minimum–maximum reference EDSS",
        "Value": fmt_range(min_edss, max_edss, digits=1),
    },
])

characteristics_df = pd.DataFrame(rows)

# =========================
# SAVE OUTPUT
# =========================

characteristics_df.to_csv(OUTPUT_CSV, index=False)

with open(OUTPUT_MD, "w", encoding="utf-8") as f:
    f.write(characteristics_df.to_markdown(index=False))
    f.write("\n")

# =========================
# PRINT OUTPUT
# =========================

pd.set_option("display.max_colwidth", None)
pd.set_option("display.width", 180)

print("\nDataset characteristics table:")
print(characteristics_df.to_markdown(index=False))

print("\nPatient visit-count distribution:")
print(visit_distribution_df.to_markdown(index=False))

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_MD)
print(OUTPUT_PATIENT_COUNTS)
print(OUTPUT_VISIT_DISTRIBUTION)

print("\nDuplicate estimate note:")
print(
    "Duplicates were estimated as repeated rows with the same unique_id and MedDatum. "
    "If you already removed duplicates before this file, this value may be 0."
)

##
# %% Structured-output validity bar chart grouped by metric

from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# =========================
# CONFIGURATION
# =========================

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

OUTPUT_DIR = RUN_DIR / "structured_output_validity_figure"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "structured_output_validity_bar_chart_grouped_by_metric.svg"
OUTPUT_PNG = OUTPUT_DIR / "structured_output_validity_bar_chart_grouped_by_metric.png"
OUTPUT_CSV = OUTPUT_DIR / "structured_output_validity_table_grouped_by_metric.csv"

# Sub-directories of RUN_DIR that hold analysis outputs rather than a model's
# results; str.startswith accepts a tuple of prefixes.
EXCLUDED_DIR_PREFIXES = (
    "confusion",
    "functional_system",
    "repeated_run",
    "edss_error_distribution",
    "edss_threshold_metrics",
    "edss_severity_group_metrics",
    "structured_output_validity",
)

plt.rcParams["font.family"] = "Arial"

# =========================
# HELPERS
# =========================

def find_summary_file(model_dir):
    """Return the alphabetically first ``*_summary_*.csv`` in *model_dir*, or None."""
    files = sorted(model_dir.glob("*_summary_*.csv"))
    return files[0] if files else None


def percent_from_rate(value):
    """Return *value* as a percentage.

    Values <= 1.0 are treated as fractions and multiplied by 100; larger
    values are returned unchanged. NaN stays NaN.
    NOTE(review): a value already stored as a percentage <= 1 (e.g. 0.8 %)
    would be scaled to 80 % - confirm the summaries always store fractions.
    """
    if pd.isna(value):
        return np.nan
    value = float(value)
    if value <= 1.0:
        return value * 100
    return value


def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    name = str(name)
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(name, name)


# =========================
# LOAD SUMMARY DATA
# =========================

rows = []

model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir() and not p.name.startswith(EXCLUDED_DIR_PREFIXES)
]

for model_dir in model_dirs:
    summary_file = find_summary_file(model_dir)

    if summary_file is None:
        print(f"No summary file found in {model_dir}")
        continue

    df = pd.read_csv(summary_file)

    if df.empty:
        print(f"Empty summary file: {summary_file}")
        continue

    # Only the first row of each summary file is used.
    row = df.iloc[0]

    model = row.get("model", model_dir.name)
    model_display = clean_model_name(model)

    success_rate = percent_from_rate(row.get("success_rate", np.nan))

    # Older summaries name the clinical-validity column differently.
    if "clinical_output_valid_rate" in df.columns:
        clinical_output_valid_rate = percent_from_rate(
            row.get("clinical_output_valid_rate", np.nan)
        )
    else:
        clinical_output_valid_rate = percent_from_rate(
            row.get("clinical_range_valid_rate", np.nan)
        )

    edss_valid_range_rate = percent_from_rate(
        row.get("EDSS_valid_range_rate", np.nan)
    )

    rows.append({
        "model": model,
        "model_display": model_display,
        "Success rate": success_rate,
        "Clinical-output validity": clinical_output_valid_rate,
        "EDSS valid-range rate": edss_valid_range_rate,
        "summary_file": str(summary_file),
    })

validity_df = pd.DataFrame(rows)

if validity_df.empty:
    raise ValueError("No model summary data found.")

# Preferred model order. Any other model found in RUN_DIR is appended
# (sorted) so that pd.Categorical does not silently turn it into NaN, which
# would later be plotted under the label "nan".
model_order = ["GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it"]
extra_models = sorted(set(validity_df["model_display"]) - set(model_order))

validity_df["model_display"] = pd.Categorical(
    validity_df["model_display"],
    categories=model_order + extra_models,
    ordered=True
)

validity_df = validity_df.sort_values("model_display").reset_index(drop=True)

validity_df.to_csv(OUTPUT_CSV, index=False)

print("\nStructured-output validity table:")
print(validity_df)

# =========================
# PLOT
# =========================

metrics = [
    "Success rate",
    "Clinical-output validity",
    "EDSS valid-range rate",
]

models = validity_df["model_display"].astype(str).tolist()

x = np.arange(len(metrics))
n_models = len(models)
bar_width = 0.22

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(10, 6))

for i, model in enumerate(models):
    values = [
        validity_df.loc[validity_df["model_display"].astype(str) == model, metric].iloc[0]
        for metric in metrics
    ]

    # Centre the group of bars on each metric's x position.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        label=model,
        color=colors.get(model, None),
        edgecolor="white",
        linewidth=0.8,
    )

    for bar, value in zip(bars, values):
        if pd.notna(value):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1,
                f"{value:.1f}%",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                rotation=0,
            )

ax.set_xticks(x)
ax.set_xticklabels(metrics, fontsize=10)

# Headroom above 100 % leaves space for the value labels.
ax.set_ylim(0, 110)
ax.set_ylabel("Percentage of responses", fontsize=11, fontweight="bold")
ax.set_xlabel("Structured-output metric", fontsize=11, fontweight="bold")

#ax.set_title(
#    "Structured-output validity by metric and model",
#    fontsize=13,
#    fontweight="bold",
#    pad=15,
#)

ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.92])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")

plt.show()

print("\nSaved:")
print(OUTPUT_CSV)
print(OUTPUT_SVG)
print(OUTPUT_PNG)

##
# %% EDSS severity-group confusion heatmaps per model

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# =========================
# CONFIGURATION
# =========================

INPUT_LONG_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
    "run_20260528_103942/edss_severity_group_metrics_iter_1/"
    "edss_severity_group_predictions_long_iter_1.csv"
)

OUTPUT_DIR = INPUT_LONG_PATH.parent / "severity_group_heatmaps"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Options:
# "count"       -> cell values are raw counts
# "row_percent" -> cell values are percentages within each ground-truth row
PLOT_MODE = "row_percent"
# PLOT_MODE = "count"

GROUP_ORDER = [
    "0.0-3.5",
    "4.0-5.5",
    "6.0-10.0",
]

# Display labels for GROUP_ORDER, same order (en dash instead of hyphen).
GROUP_LABELS = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

plt.rcParams["font.family"] = "Arial"

# =========================
# HELPERS
# =========================

def safe_model_name(name):
    """Make *name* filesystem-friendly by replacing '/', ' ' and ':' with '_'."""
    return (
        str(name)
        .replace("/", "_")
        .replace(" ", "_")
        .replace(":", "_")
    )


def make_confusion_table(df_model):
    """Cross-tabulate ground-truth vs predicted severity group.

    The result is forced to the full GROUP_ORDER x GROUP_ORDER grid, with
    groups that never occur filled with 0.
    """
    cm = pd.crosstab(
        df_model["GT_EDSS_group"],
        df_model["PRED_EDSS_group"],
        dropna=False
    )

    cm = cm.reindex(index=GROUP_ORDER, columns=GROUP_ORDER, fill_value=0)
    return cm


def row_percent_table(cm):
    """Convert counts to percentages of each ground-truth row (empty rows -> 0)."""
    row_sums = cm.sum(axis=1).replace(0, np.nan)
    pct = cm.div(row_sums, axis=0) * 100
    return pct.fillna(0)


def plot_heatmap(cm_counts, model_name, plot_mode):
    """Draw, save (SVG + PNG) and show one confusion heatmap.

    Args:
        cm_counts: count matrix from make_confusion_table.
        model_name: used in the title and the output file names.
        plot_mode: "count" or "row_percent".

    Returns:
        (svg_path, png_path) of the saved figure.

    Raises:
        ValueError: for an unknown *plot_mode*.
    """
    if plot_mode == "count":
        plot_data = cm_counts.copy()
        annot = plot_data.astype(int).astype(str)
        fmt = ""
        cbar_label = "Number of cases"
        title_suffix = "Counts"
        vmax = None

    elif plot_mode == "row_percent":
        plot_data = row_percent_table(cm_counts)
        # Column-wise Series.map instead of DataFrame.applymap, which is
        # deprecated since pandas 2.1; the resulting labels are identical.
        annot = plot_data.apply(lambda col: col.map(lambda x: f"{x:.1f}%"))
        fmt = ""
        cbar_label = "Row percentage"
        title_suffix = "Row percentages"
        vmax = 100

    else:
        raise ValueError(f"Unknown PLOT_MODE: {plot_mode}")

    fig, ax = plt.subplots(figsize=(7, 6))

    # annot holds pre-formatted strings, hence fmt="".
    sns.heatmap(
        plot_data,
        annot=annot,
        fmt=fmt,
        cmap="Blues",
        vmin=0,
        vmax=vmax,
        xticklabels=GROUP_LABELS,
        yticklabels=GROUP_LABELS,
        linewidths=0.8,
        linecolor="white",
        square=True,
        cbar_kws={"label": cbar_label},
        ax=ax,
    )

    ax.set_xlabel("Predicted EDSS severity group", fontsize=11, fontweight="bold")
    ax.set_ylabel("Ground-truth EDSS severity group", fontsize=11, fontweight="bold")

    ax.set_title(
        f"EDSS Severity-Group Confusion Matrix\n{model_name} | {title_suffix}",
        fontsize=13,
        fontweight="bold",
        pad=15,
    )

    plt.xticks(rotation=0)
    plt.yticks(rotation=0)

    plt.tight_layout()

    safe_name = safe_model_name(model_name)

    svg_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_heatmap_{plot_mode}.svg"
    png_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_heatmap_{plot_mode}.png"

    plt.savefig(svg_path, format="svg", bbox_inches="tight")
    plt.savefig(png_path, dpi=300, bbox_inches="tight")

    plt.show()

    # This function runs once per model; pyplot keeps every figure alive
    # until it is explicitly closed, so release it here.
    plt.close(fig)

    return svg_path, png_path


# =========================
# LOAD DATA
# =========================

df = pd.read_csv(INPUT_LONG_PATH)

required_cols = [
    "GT_EDSS_group",
    "PRED_EDSS_group",
]

for col in required_cols:
    if col not in df.columns:
        raise ValueError(f"Missing required column: {col}")

if "model_for_analysis" in df.columns:
    model_col = "model_for_analysis"
elif "model" in df.columns:
    model_col = "model"
else:
    raise ValueError("No model column found. Expected 'model_for_analysis' or 'model'.")

# Rows without a ground-truth or predicted group are left out of every matrix.
df = df.dropna(subset=["GT_EDSS_group", "PRED_EDSS_group"]).copy()

print(f"Loaded rows: {len(df)}")
print(f"Models: {sorted(df[model_col].dropna().unique())}")

# =========================
# CREATE HEATMAPS
# =========================

summary_rows = []

for model_name, df_model in df.groupby(model_col):
    print("\n" + "=" * 80)
    print(f"Model: {model_name}")
    print(f"Rows: {len(df_model)}")

    cm_counts = make_confusion_table(df_model)
    cm_row_pct = row_percent_table(cm_counts)

    print("\nCount matrix:")
    print(cm_counts)

    print("\nRow percentage matrix:")
    print(cm_row_pct.round(1))

    svg_path, png_path = plot_heatmap(
        cm_counts=cm_counts,
        model_name=model_name,
        plot_mode=PLOT_MODE,
    )

    safe_name = safe_model_name(model_name)

    counts_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_counts.csv"
    row_pct_path = OUTPUT_DIR / f"{safe_name}_severity_group_confusion_row_percent.csv"

    cm_counts.to_csv(counts_path)
    cm_row_pct.to_csv(row_pct_path)

    summary_rows.append({
        "model": model_name,
        "plot_mode": PLOT_MODE,
        "n_rows": len(df_model),
        "svg_path": str(svg_path),
        "png_path": str(png_path),
        "counts_path": str(counts_path),
        "row_percent_path": str(row_pct_path),
    })

    print("\nSaved:")
    print(svg_path)
    print(png_path)
    print(counts_path)
    print(row_pct_path)

# =========================
# SAVE SUMMARY
# =========================

summary_df = pd.DataFrame(summary_rows)
summary_path = OUTPUT_DIR / f"severity_group_heatmap_summary_{PLOT_MODE}.csv"
summary_df.to_csv(summary_path, index=False)

print("\n" + "=" * 80)
print("Done.")
print(f"Summary saved to: {summary_path}")

##
# %% Grouped bar chart of patient-level EDSS range across 10 runs

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =========================
# CONFIGURATION
# =========================

INPUT_FILES = [
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gemma-4-31B-it_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "gpt-oss-120b_all_valid_predictions_long.csv"
    ),
    Path(
        "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
        "run_20260528_103942/repeated_run_variability/"
        "qwen3.6-27b_all_valid_predictions_long.csv"
    ),
]

OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "patient_level_edss_range_grouped_bar.svg"
OUTPUT_PNG = OUTPUT_DIR / "patient_level_edss_range_grouped_bar.png"
OUTPUT_PATIENT_RANGE_CSV = OUTPUT_DIR / "patient_level_edss_range_by_model.csv"
OUTPUT_GROUPED_CSV = OUTPUT_DIR / "patient_level_edss_range_grouped_counts.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# Choose whether to include all patients with at least one valid run,
# or only patients with all 10 valid runs.
USE_ONLY_COMPLETE_10_RUNS = False
# USE_ONLY_COMPLETE_10_RUNS = True

plt.rcParams["font.family"] = "Arial"

# =========================
# HELPERS
# =========================

def clean_model_name(name):
    """Map a raw model identifier to its display name (unknown names pass through)."""
    name = str(name)
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(name, name)


def to_num(s):
    """Convert *s* to numbers (',' accepted as decimal separator; unparseable -> NaN)."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def categorize_edss_range(value):
    """
    Categorize patient-level EDSS range across repeated runs.

    Buckets: "0", "0.5" (0 < range <= 0.5), ">0.5–1.0", ">1.0–2.0", ">2.0";
    NaN input returns NaN.
    """
    if pd.isna(value):
        return np.nan
    if value == 0:
        return "0"
    if value <= 0.5:
        return "0.5"
    if value <= 1.0:
        return ">0.5–1.0"
    if value <= 2.0:
        return ">1.0–2.0"
    return ">2.0"


# =========================
# LOAD AND COMBINE DATA
# =========================

dfs = []

for path in INPUT_FILES:
    if not path.exists():
        print(f"Skipping missing file: {path}")
        continue

    df = pd.read_csv(path)

    required_cols = ["model", "row_index", EDSS_COL]
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"Missing column '{col}' in {path}")

    df = df.copy()
    df["EDSS_prediction_numeric"] = to_num(df[EDSS_COL])
    # Runs whose prediction is not numeric do not count as valid runs.
    df = df.dropna(subset=["EDSS_prediction_numeric"]).copy()

    dfs.append(df)

if not dfs:
    raise ValueError("No input data loaded.")

all_df = pd.concat(dfs, ignore_index=True)
all_df["model_display"] = all_df["model"].apply(clean_model_name)

print(f"Loaded valid prediction rows: {len(all_df)}")
print("\nRows per model:")
print(all_df["model_display"].value_counts())

# =========================
# PATIENT-LEVEL RANGE
# =========================

# One group per (model, record); unique_id is added to the key when present.
group_cols = ["model", "model_display", "row_index"]
if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

patient_range_df = (
    all_df
    .groupby(group_cols, dropna=False)
    .agg(
        n_valid_runs=("EDSS_prediction_numeric", "count"),
        EDSS_min=("EDSS_prediction_numeric", "min"),
        EDSS_max=("EDSS_prediction_numeric", "max"),
        EDSS_mean=("EDSS_prediction_numeric", "mean"),
        EDSS_median=("EDSS_prediction_numeric", "median"),
        EDSS_std=("EDSS_prediction_numeric", lambda x: x.std(ddof=0)),
    )
    .reset_index()
)

patient_range_df["EDSS_range"] = (
    patient_range_df["EDSS_max"] - patient_range_df["EDSS_min"]
)

patient_range_df["complete_10_valid_runs"] = (
    patient_range_df["n_valid_runs"] == N_EXPECTED_RUNS
)

patient_range_df["EDSS_range_category"] = patient_range_df["EDSS_range"].apply(
    categorize_edss_range
)

patient_range_df.to_csv(OUTPUT_PATIENT_RANGE_CSV, index=False)

# =========================
# OPTIONAL FILTER
# =========================

plot_df = patient_range_df.copy()

if USE_ONLY_COMPLETE_10_RUNS:
    plot_df = plot_df[plot_df["complete_10_valid_runs"]].copy()

if plot_df.empty:
    raise ValueError("No patient-level data available after filtering.")

# =========================
# GROUPED COUNTS AND PERCENTAGES
# =========================

range_order = [
    "0",
    "0.5",
    ">0.5–1.0",
    ">1.0–2.0",
    ">2.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Keep the preferred order for the models actually present, then append any
# other model found in the data (sorted) instead of silently dropping it.
present_models = set(plot_df["model_display"].unique())
model_order = [m for m in model_order if m in present_models] + sorted(
    present_models - set(model_order)
)

counts = (
    plot_df
    .groupby(["EDSS_range_category", "model_display"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=range_order, columns=model_order, fill_value=0)
)

# Share of each range category within its model; a model with no patients
# gets NaN instead of a division by zero.
model_totals = counts.sum()
percentages = counts.div(model_totals.replace(0, np.nan), axis=1) * 100

# Save combined counts and percentages
combined_out = []

for range_cat in range_order:
    for model in model_order:
        combined_out.append({
            "EDSS_range_category": range_cat,
            "model": model,
            "count": int(counts.loc[range_cat, model]),
            "percent_within_model": percentages.loc[range_cat, model],
            "total_patients_for_model": int(counts[model].sum()),
            "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
        })

combined_df = pd.DataFrame(combined_out)
combined_df.to_csv(OUTPUT_GROUPED_CSV, index=False)

print("\nCounts:")
print(counts)
print("\nPercentages within model:")
print(percentages.round(1))


# ---------------------------------------------------------------------------
# PLOT: one group of bars per EDSS-range category, one bar per model
# ---------------------------------------------------------------------------

x = np.arange(len(range_order))
n_models = len(model_order)
bar_width = 0.22

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(10, 6))

# Horizontal shift of each model's bar so every group is centred on its tick.
bar_offsets = (np.arange(n_models) - (n_models - 1) / 2) * bar_width

for i, model in enumerate(model_order):
    values = percentages[model].values
    bars = ax.bar(
        x + bar_offsets[i],
        values,
        width=bar_width,
        label=model,
        color=colors.get(model),
        edgecolor="white",
        linewidth=0.8,
    )

    # Print the value just above every bar that reaches at least 2 %.
    labelled = (
        (bar, value)
        for bar, value in zip(bars, values)
        if pd.notna(value) and value >= 2
    )
    for bar, value in labelled:
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1,
            f"{value:.1f}%",
            ha="center",
            va="bottom",
            fontsize=8,
            fontweight="bold",
        )

ax.set_xticks(x)
ax.set_xticklabels(range_order, fontsize=10)

axis_label_style = {"fontsize": 11, "fontweight": "bold"}
ax.set_ylabel("Patients (%)", **axis_label_style)
ax.set_xlabel("Patient-level EDSS range across repeated runs", **axis_label_style)

if USE_ONLY_COMPLETE_10_RUNS:
    title_suffix = "patients with all 10 valid runs"
else:
    title_suffix = "patients with available valid runs"

# Title kept commented out for reference:
# ax.set_title(
#     f"Repeated-run stability of EDSS predictions\n{title_suffix}",
#     fontsize=13,
#     fontweight="bold",
#     pad=15,
# )

ax.set_ylim(0, max(100, np.nanmax(percentages.values) + 10))
percent_ticks = np.arange(0, 101, 10)
ax.set_yticks(percent_ticks)
ax.set_yticklabels([f"{tick}%" for tick in percent_ticks])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for side in ("top", "right"):
    ax.spines[side].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

# Per-model patient counts as a single line of text in the header area
# (axes coordinates, above the plotting area, near the legend).
n_text = " | ".join(
    f"{model}: n={int(counts[model].sum())}" for model in model_order
)
ax.text(
    0.5,
    1.08,
    n_text,
    transform=ax.transAxes,
    ha="center",
    va="bottom",
    fontsize=9,
)

plt.tight_layout(rect=[0, 0, 1, 0.90])
plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
for saved_path in (
    OUTPUT_SVG,
    OUTPUT_PNG,
    OUTPUT_PATIENT_RANGE_CSV,
    OUTPUT_GROUPED_CSV,
):
    print(saved_path)

##
# %% Simple stability figure: stable / minor variation / unstable

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# -------------------------------------------------------------------------
# CONFIGURATION
# -------------------------------------------------------------------------

# Directory holding the per-model long-format prediction files.
_VARIABILITY_DIR = (
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
    "run_20260528_103942/repeated_run_variability/"
)

INPUT_FILES = [
    Path(_VARIABILITY_DIR + "gemma-4-31B-it_all_valid_predictions_long.csv"),
    Path(_VARIABILITY_DIR + "gpt-oss-120b_all_valid_predictions_long.csv"),
    Path(_VARIABILITY_DIR + "qwen3.6-27b_all_valid_predictions_long.csv"),
]

# Outputs go to a "stability_figures" folder next to the first input file.
OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "simple_edss_stability_stacked_bar.svg"
OUTPUT_PNG = OUTPUT_DIR / "simple_edss_stability_stacked_bar.png"
OUTPUT_CSV = OUTPUT_DIR / "simple_edss_stability_table.csv"
OUTPUT_PATIENT_LEVEL_CSV = OUTPUT_DIR / "simple_edss_stability_patient_level.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# If True, only patients with all 10 valid predictions are included.
# If False, patients with at least 2 valid predictions are included.
USE_ONLY_COMPLETE_10_RUNS = False

plt.rcParams["font.family"] = "Arial"


# -------------------------------------------------------------------------
# HELPERS
# -------------------------------------------------------------------------

def clean_model_name(name):
    """Return the display label for a raw model id; unknown ids come back as str."""
    key = str(name)
    return {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }.get(key, key)


def to_num(s):
    """Parse a Series as numbers, accepting decimal commas; failures become NaN."""
    as_text = s.astype(str).str.replace(",", ".", regex=False)
    return pd.to_numeric(as_text, errors="coerce")


def classify_stability(edss_range):
    """
    Map a patient's EDSS range (max - min across runs) to a stability label.

    0 -> "Identical across runs", up to 0.5 -> "Range ≤0.5",
    anything larger -> "Range >0.5"; a missing range gives NaN.
    """
    if pd.isna(edss_range):
        return np.nan
    if edss_range == 0:
        return "Identical across runs"
    return "Range ≤0.5" if edss_range <= 0.5 else "Range >0.5"


def _load_valid_predictions(csv_path):
    """Read one model's long-format file, keeping only rows with a numeric EDSS prediction."""
    frame = pd.read_csv(csv_path)
    frame["EDSS_prediction_numeric"] = to_num(frame[EDSS_COL])
    frame = frame.dropna(subset=["EDSS_prediction_numeric"]).copy()
    frame["model_display"] = frame["model"].apply(clean_model_name)
    return frame


# -------------------------------------------------------------------------
# LOAD DATA
# -------------------------------------------------------------------------

dfs = [_load_valid_predictions(path) for path in INPUT_FILES]
all_df = pd.concat(dfs, ignore_index=True)


# -------------------------------------------------------------------------
# PATIENT-LEVEL RANGE
# -------------------------------------------------------------------------

group_cols = ["model", "model_display", "row_index"]
if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

_PRED_COL = "EDSS_prediction_numeric"
patient_df = (
    all_df.groupby(group_cols, dropna=False)
    .agg(
        n_valid_runs=(_PRED_COL, "count"),
        edss_min=(_PRED_COL, "min"),
        edss_max=(_PRED_COL, "max"),
    )
    .reset_index()
)

patient_df["edss_range"] = patient_df["edss_max"] - patient_df["edss_min"]
patient_df["complete_10_valid_runs"] = patient_df["n_valid_runs"].eq(N_EXPECTED_RUNS)

# Need at least 2 runs to measure variability.
patient_df = patient_df.loc[patient_df["n_valid_runs"] >= 2].copy()

if USE_ONLY_COMPLETE_10_RUNS:
    patient_df = patient_df.loc[patient_df["complete_10_valid_runs"]].copy()

patient_df["stability_category"] = patient_df["edss_range"].apply(classify_stability)
patient_df.to_csv(OUTPUT_PATIENT_LEVEL_CSV, index=False)


# -------------------------------------------------------------------------
# SUMMARY TABLE
# -------------------------------------------------------------------------

category_order = [
    "Identical across runs",
    "Range ≤0.5",
    "Range >0.5",
]

# Fixed display order, restricted to models that actually have patients.
present_models = set(patient_df["model_display"].unique())
model_order = [
    m for m in ("GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it")
    if m in present_models
]

# Rows: model, columns: stability category; missing combinations are 0.
counts = (
    patient_df
    .groupby(["model_display", "stability_category"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=model_order, columns=category_order, fill_value=0)
)

row_totals = counts.sum(axis=1)
percentages = (counts.div(row_totals, axis=0) * 100).fillna(0)

summary_df = pd.DataFrame([
    {
        "model": model,
        "stability_category": category,
        "count": int(counts.loc[model, category]),
        "percent": percentages.loc[model, category],
        "total_patients": int(row_totals[model]),
        "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
    }
    for model in model_order
    for category in category_order
])
summary_df.to_csv(OUTPUT_CSV, index=False)

print("\nCounts:")
print(counts)
print("\nPercentages:")
print(percentages.round(1))


# -------------------------------------------------------------------------
# PLOT: one horizontal stacked bar per model
# -------------------------------------------------------------------------

colors = {
    "Identical across runs": "#1F77B4",
    "Range ≤0.5": "#9ECAE1",
    "Range >0.5": "#F28E2B",
}

fig, ax = plt.subplots(figsize=(10, 5))

# Running left edge of each model's stacked bar (one entry per model row).
left = np.zeros(len(model_order))

for category in category_order:
    values = percentages[category].values
    ax.barh(
        model_order,
        values,
        left=left,
        color=colors[category],
        edgecolor="white",
        linewidth=0.8,
        label=category,
    )

    # Label only segments wide enough (>= 5 %) to hold the text.
    for row, (start, value) in enumerate(zip(left, values)):
        if value >= 5:
            ax.text(
                start + value / 2,
                row,
                f"{value:.1f}%",
                ha="center",
                va="center",
                fontsize=9,
                fontweight="bold",
            )

    left += values

# Patient count per model, right of the 100 % mark.
for row, model in enumerate(model_order):
    ax.text(
        101,
        row,
        f"n={int(row_totals[model])}",
        va="center",
        ha="left",
        fontsize=9,
    )

ax.set_xlim(0, 110)
ax.set_xlabel("Patients (%)", fontsize=11, fontweight="bold")
ax.set_ylabel("Model", fontsize=11, fontweight="bold")

if USE_ONLY_COMPLETE_10_RUNS:
    title_suffix = "patients with all 10 valid runs"
else:
    title_suffix = "patients with at least 2 valid runs"

# Title kept commented out for reference:
# ax.set_title(
#     f"Repeated-run stability of EDSS predictions\n{title_suffix}",
#     fontsize=13,
#     fontweight="bold",
#     pad=15,
# )

percent_ticks = np.arange(0, 101, 10)
ax.set_xticks(percent_ticks)
ax.set_xticklabels([f"{tick}%" for tick in percent_ticks])

ax.xaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for side in ("top", "right", "left"):
    ax.spines[side].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.90])
plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
for saved_path in (OUTPUT_SVG, OUTPUT_PNG, OUTPUT_CSV, OUTPUT_PATIENT_LEVEL_CSV):
    print(saved_path)

##
# %% Fancy simple stability figure: rounded horizontal stacked bars

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import FancyBboxPatch

# -------------------------------------------------------------------------
# CONFIGURATION
# -------------------------------------------------------------------------

# Directory holding the per-model long-format prediction files.
_VARIABILITY_DIR = (
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/"
    "run_20260528_103942/repeated_run_variability/"
)

INPUT_FILES = [
    Path(_VARIABILITY_DIR + "gemma-4-31B-it_all_valid_predictions_long.csv"),
    Path(_VARIABILITY_DIR + "gpt-oss-120b_all_valid_predictions_long.csv"),
    Path(_VARIABILITY_DIR + "qwen3.6-27b_all_valid_predictions_long.csv"),
]

OUTPUT_DIR = INPUT_FILES[0].parent / "stability_figures"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / "fancy_simple_edss_stability.svg"
OUTPUT_PNG = OUTPUT_DIR / "fancy_simple_edss_stability.png"
OUTPUT_CSV = OUTPUT_DIR / "fancy_simple_edss_stability_table.csv"
OUTPUT_PATIENT_LEVEL_CSV = OUTPUT_DIR / "fancy_simple_edss_stability_patient_level.csv"

EDSS_COL = "EDSS_prediction"
N_EXPECTED_RUNS = 10

# False = include patients with at least 2 valid runs.
# True = only patients with all 10 valid runs.
USE_ONLY_COMPLETE_10_RUNS = False

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def clean_model_name(name):
    """Return the display label for a raw model id (unknown ids pass through as str)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable values become NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce",
    )


def classify_stability(edss_range):
    """
    Map a patient's EDSS range (max - min across runs) to a stability label.

    0 -> "Identical", up to 0.5 -> "Minor variation", larger -> "Unstable";
    a missing range gives NaN.
    """
    if pd.isna(edss_range):
        return np.nan
    if edss_range == 0:
        return "Identical"
    if edss_range <= 0.5:
        return "Minor variation"
    return "Unstable"


def rounded_barh(ax, y, left, width, height, color, radius=0.16):
    """
    Draw a rounded horizontal bar segment and return the added patch.

    Parameters
    ----------
    ax : matplotlib Axes to draw on (the patch is added with ``add_patch``).
    y : vertical centre of the segment.
    left, width : horizontal start and length of the segment.
    height : full vertical extent of the segment.
    color : face colour.
    radius : requested corner size, in the same units as width/height.

    The corner size is capped at half the segment's width and height.
    A larger corner makes the rounded outline overshoot and fold back on
    itself for very thin segments (e.g. a single patient out of ~400, about 0.25 %).
    """
    corner = max(0.0, float(min(radius, width / 2, height / 2)))
    patch = FancyBboxPatch(
        (left, y - height / 2),
        width,
        height,
        boxstyle=f"round,pad=0,rounding_size={corner}",
        linewidth=0,
        facecolor=color,
    )
    ax.add_patch(patch)
    return patch


# =========================
# LOAD DATA
# =========================

dfs = []

for path in INPUT_FILES:
    df = pd.read_csv(path)

    # Keep only rows whose prediction parses as a number.
    df["EDSS_prediction_numeric"] = to_num(df[EDSS_COL])
    df = df.dropna(subset=["EDSS_prediction_numeric"]).copy()
    df["model_display"] = df["model"].apply(clean_model_name)

    dfs.append(df)

all_df = pd.concat(dfs, ignore_index=True)


# =========================
# PATIENT-LEVEL RANGE
# =========================

group_cols = ["model", "model_display", "row_index"]
if "unique_id" in all_df.columns:
    group_cols.append("unique_id")

patient_df = (
    all_df
    .groupby(group_cols, dropna=False)
    .agg(
        n_valid_runs=("EDSS_prediction_numeric", "count"),
        edss_min=("EDSS_prediction_numeric", "min"),
        edss_max=("EDSS_prediction_numeric", "max"),
    )
    .reset_index()
)

patient_df["edss_range"] = patient_df["edss_max"] - patient_df["edss_min"]
patient_df["complete_10_valid_runs"] = patient_df["n_valid_runs"] == N_EXPECTED_RUNS

# Need at least 2 repeated predictions to measure stability.
patient_df = patient_df.loc[patient_df["n_valid_runs"] >= 2].copy()

if USE_ONLY_COMPLETE_10_RUNS:
    patient_df = patient_df.loc[patient_df["complete_10_valid_runs"]].copy()

patient_df["stability_category"] = patient_df["edss_range"].apply(classify_stability)
patient_df.to_csv(OUTPUT_PATIENT_LEVEL_CSV, index=False)


# -------------------------------------------------------------------------
# SUMMARY TABLE
# -------------------------------------------------------------------------

category_order = [
    "Identical",
    "Minor variation",
    "Unstable",
]

# Fixed display order, restricted to models that actually have patients.
present_models = set(patient_df["model_display"].unique())
model_order = [
    m for m in ("GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it")
    if m in present_models
]

# Rows: model, columns: stability category; missing combinations are 0.
counts = (
    patient_df
    .groupby(["model_display", "stability_category"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=model_order, columns=category_order, fill_value=0)
)

row_totals = counts.sum(axis=1)
percentages = (counts.div(row_totals, axis=0) * 100).fillna(0)

summary_df = pd.DataFrame([
    {
        "model": model,
        "stability_category": category,
        "count": int(counts.loc[model, category]),
        "percent": percentages.loc[model, category],
        "total_patients": int(row_totals[model]),
        "complete_10_runs_only": USE_ONLY_COMPLETE_10_RUNS,
    }
    for model in model_order
    for category in category_order
])
summary_df.to_csv(OUTPUT_CSV, index=False)

print("\nPercentages:")
print(percentages.round(1))


# -------------------------------------------------------------------------
# FANCY PLOT: rounded stacked segments, one row per model
# -------------------------------------------------------------------------

colors = {
    "Identical": "#0B4F8A",
    "Minor variation": "#7DB9DE",
    "Unstable": "#F28E2B",
}

# Categories on dark fills get white percentage labels; the rest get black.
_WHITE_LABEL_CATEGORIES = ("Identical", "Unstable")

fig, ax = plt.subplots(figsize=(10.5, 5.3))

bar_height = 0.48
y_positions = np.arange(len(model_order))

for row, model in enumerate(model_order):
    shares = percentages.loc[model]
    start = 0

    for category in category_order:
        value = shares[category]
        if value > 0:
            rounded_barh(
                ax=ax,
                y=row,
                left=start,
                width=value,
                height=bar_height,
                color=colors[category],
                radius=0.13,
            )
            # Only segments of at least 6 % are wide enough for a label.
            if value >= 6:
                ax.text(
                    start + value / 2,
                    row,
                    f"{value:.1f}%",
                    ha="center",
                    va="center",
                    fontsize=10,
                    fontweight="bold",
                    color="white" if category in _WHITE_LABEL_CATEGORIES else "black",
                )
        start += value

    # Patient count to the right of the 100 % mark.
    ax.text(
        103,
        row,
        f"n={int(row_totals[model])}",
        va="center",
        ha="left",
        fontsize=10,
        color="#333333",
    )

    # Left-margin summary: share of patients whose range is <= 0.5
    # (identical + minor variation), drawn slightly below the bar centre.
    within_half_point = shares["Identical"] + shares["Minor variation"]
    ax.text(
        -3,
        row - 0.33,
        f"{within_half_point:.1f}% ≤0.5 range",
        va="center",
        ha="right",
        fontsize=9,
        color="#444444",
    )

# Y-axis model labels
ax.set_yticks(y_positions)
ax.set_yticklabels(model_order, fontsize=11, fontweight="bold")

# Explicit limits leave room for the left-margin summary and the n= labels.
ax.set_xlim(-18, 112)
ax.set_ylim(-0.8, len(model_order) - 0.2)

ax.set_xlabel("Patients (%)", fontsize=11, fontweight="bold")

ax.set_title(
    "Repeated-run stability of EDSS predictions",
    fontsize=15,
    fontweight="bold",
    pad=18,
)

subtitle = (
    "Patient-level EDSS range across repeated model runs "
    "(identical, minor variation, or unstable)"
)
ax.text(
    0.5,
    1.02,
    subtitle,
    transform=ax.transAxes,
    ha="center",
    va="bottom",
    fontsize=10,
    color="#555555",
)

# X-axis formatting
percent_ticks = np.arange(0, 101, 20)
ax.set_xticks(percent_ticks)
ax.set_xticklabels([f"{tick}%" for tick in percent_ticks])
ax.xaxis.grid(True, linestyle="--", alpha=0.25)
ax.set_axisbelow(True)

# Clean style: no frame, no tick marks.
for side in ("top", "right", "left", "bottom"):
    ax.spines[side].set_visible(False)

for axis_name in ("y", "x"):
    ax.tick_params(axis=axis_name, length=0)

# Legend built from proxy rectangles, since the rounded patches carry no label.
legend_handles = [
    plt.Rectangle((0, 0), 1, 1, color=colors[category])
    for category in category_order
]

ax.legend(
    legend_handles,
    [
        "Identical across runs",
        "Range ≤0.5",
        "Range >0.5",
    ],
    loc="lower center",
    bbox_to_anchor=(0.5, -0.18),
    ncol=3,
    frameon=False,
    fontsize=10,
)

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_CSV) print(OUTPUT_PATIENT_LEVEL_CSV) ## # %% Functional system heatmap: MAE by functional system and model from pathlib import Path import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns # ========================= # CONFIGURATION # ========================= RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) INPUT_METRICS_PATH = ( RUN_DIR / "functional_system_metrics_iter_1" / "functional_system_metrics_short_iter_1.csv" ) OUTPUT_DIR = RUN_DIR / "functional_system_heatmaps" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_SVG = OUTPUT_DIR / "functional_system_mae_heatmap.svg" OUTPUT_PNG = OUTPUT_DIR / "functional_system_mae_heatmap.png" OUTPUT_CSV = OUTPUT_DIR / "functional_system_mae_heatmap_table.csv" plt.rcParams["font.family"] = "Arial" # ========================= # SETTINGS # ========================= MODEL_ORDER = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] FUNCTIONAL_SYSTEM_ORDER = [ "Visual/optic functions", "Brainstem functions", "Pyramidal functions", "Cerebellar functions", "Sensory functions", "Bowel and bladder functions", "Cerebral functions", "Ambulation", ] MODEL_NAME_MAP = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } # ========================= # LOAD DATA # ========================= df = pd.read_csv(INPUT_METRICS_PATH) required_cols = ["model", "functional_system", "MAE"] for col in required_cols: if col not in df.columns: raise ValueError(f"Missing required column: {col}") df = df.copy() df["model_display"] = df["model"].map(MODEL_NAME_MAP).fillna(df["model"]) df["functional_system"] = df["functional_system"].replace({ "Visual/optic functions": "Visual/optic functions", "Brainstem functions": "Brainstem functions", "Pyramidal functions": "Pyramidal functions", "Cerebellar functions": "Cerebellar functions", "Sensory functions": "Sensory functions", "Bowel and bladder 
functions": "Bowel and bladder functions", "Cerebral functions": "Cerebral functions", "Ambulation": "Ambulation", }) df["MAE"] = pd.to_numeric(df["MAE"], errors="coerce") # ========================= # PIVOT TABLE # ========================= heatmap_df = ( df .pivot_table( index="functional_system", columns="model_display", values="MAE", aggfunc="mean" ) .reindex(index=FUNCTIONAL_SYSTEM_ORDER, columns=MODEL_ORDER) ) heatmap_df.to_csv(OUTPUT_CSV) print("\nMAE heatmap table:") print(heatmap_df) # ========================= # PLOT # ========================= fig, ax = plt.subplots(figsize=(8, 6.5)) sns.heatmap( heatmap_df, annot=True, fmt=".2f", cmap="Blues_r", # lower MAE appears darker/better linewidths=0.8, linecolor="white", cbar_kws={"label": "Mean absolute error"}, ax=ax, ) ax.set_xlabel("Model", fontsize=11, fontweight="bold") ax.set_ylabel("Functional system", fontsize=11, fontweight="bold") ax.set_title( "Functional system performance by model\nMean absolute error", fontsize=13, fontweight="bold", pad=15, ) plt.xticks(rotation=30, ha="right") plt.yticks(rotation=0) plt.tight_layout() plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight") plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight") plt.show() print("\nSaved:") print(OUTPUT_CSV) print(OUTPUT_SVG) print(OUTPUT_PNG) ## # %% Confidence bracket vs EDSS error grouped by model from pathlib import Path import re import pandas as pd import numpy as np import matplotlib.pyplot as plt from scipy.stats import pearsonr # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"confidence_error_analysis_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_SVG = OUTPUT_DIR / 
f"confidence_bracket_mae_grouped_iter_{TARGET_ITERATION}.svg" OUTPUT_PNG = OUTPUT_DIR / f"confidence_bracket_mae_grouped_iter_{TARGET_ITERATION}.png" OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_mae_table_iter_{TARGET_ITERATION}.csv" OUTPUT_LONG = OUTPUT_DIR / f"confidence_error_long_iter_{TARGET_ITERATION}.csv" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" CERTAINTY_COL = "certainty_percent" ADD_TREND_LINES = True plt.rcParams["font.family"] = "Arial" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def safe_name(name): return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(name)) def clean_model_name(name): replacements = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } return replacements.get(str(name), str(name)) def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def confidence_bracket(certainty): if pd.isna(certainty): return np.nan if certainty < 70: return "Low (<70%)" if certainty < 80: return "Moderate (70–80%)" if certainty < 90: return "High (80–90%)" if certainty <= 100: return "Very High (90–100%)" return np.nan def confidence_midpoint(bracket): midpoint_map = { "Low (<70%)": 65, "Moderate (70–80%)": 75, "High (80–90%)": 85, "Very High (90–100%)": 95, } return midpoint_map.get(bracket, np.nan) def sem(series): values = pd.to_numeric(series, errors="coerce").dropna() if len(values) <= 1: 
return 0.0 return values.std(ddof=1) / np.sqrt(len(values)) # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows with numeric EDSS: {len(gt)}") # ========================= # LOAD MODEL PREDICTIONS AND BUILD LONG ERROR DATA # ========================= long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not p.name.startswith("edss_error_distribution") and not p.name.startswith("edss_threshold_metrics") and not p.name.startswith("edss_severity_group_metrics") and not p.name.startswith("structured_output_validity") and not p.name.startswith("confidence_error_analysis") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue print("\n" + "=" * 100) print(f"Model folder: {model_dir.name}") print(f"Result file: {result_file}") pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print("Skipping: no row_index column.") continue if CERTAINTY_COL not in pred_raw.columns: print(f"Skipping: no {CERTAINTY_COL} column.") continue model_name = get_model_name(pred_raw, model_dir) model_display = clean_model_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = 
pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) if merged.empty: print("No evaluable rows.") continue merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] merged["abs_error"] = merged["error"].abs() merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) merged["confidence_midpoint"] = merged["confidence_bracket"].apply(confidence_midpoint) merged = merged.dropna(subset=["confidence_bracket"]).copy() print(f"Evaluable rows with confidence bracket: {len(merged)}") for _, row in merged.iterrows(): long_rows.append({ "model": model_name, "model_display": model_display, "iteration": TARGET_ITERATION, "row_index": row["row_index"], "unique_id": row.get("unique_id_gt", row.get("unique_id", None)), "GT_EDSS_numeric": row["GT_EDSS_numeric"], "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], "certainty_percent": row["certainty_numeric"], "confidence_bracket": row["confidence_bracket"], "confidence_midpoint": row["confidence_midpoint"], "error": row["error"], "abs_error": row["abs_error"], "inference_time_sec": row.get("inference_time_sec", np.nan), "result_file": str(result_file), }) long_df = pd.DataFrame(long_rows) if long_df.empty: raise ValueError("No evaluable rows found.") long_df.to_csv(OUTPUT_LONG, index=False) # ========================= # SUMMARY BY MODEL AND CONFIDENCE BRACKET # ========================= bracket_order = [ "Low (<70%)", "Moderate (70–80%)", "High (80–90%)", "Very High (90–100%)", ] model_order = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] model_order = [ m for m in model_order if m in 
long_df["model_display"].unique() ] summary = ( long_df .groupby(["model_display", "confidence_bracket"], observed=False) .agg( n=("abs_error", "count"), MAE=("abs_error", "mean"), median_abs_error=("abs_error", "median"), SEM=("abs_error", sem), mean_certainty=("certainty_percent", "mean"), ) .reset_index() ) # Ensure full model x bracket grid exists full_index = pd.MultiIndex.from_product( [model_order, bracket_order], names=["model_display", "confidence_bracket"] ) summary = ( summary .set_index(["model_display", "confidence_bracket"]) .reindex(full_index) .reset_index() ) summary["confidence_midpoint"] = summary["confidence_bracket"].apply(confidence_midpoint) summary.to_csv(OUTPUT_TABLE, index=False) print("\nConfidence-bracket MAE table:") print(summary) # ========================= # CORRELATION PER MODEL # ========================= corr_text = {} for model in model_order: df_m = long_df[long_df["model_display"] == model].copy() if len(df_m) >= 3 and df_m["certainty_percent"].nunique() > 1 and df_m["abs_error"].nunique() > 1: r, p = pearsonr(df_m["certainty_percent"], df_m["abs_error"]) corr_text[model] = f"r={r:.2f}, p={p:.2g}, n={len(df_m)}" else: corr_text[model] = f"r=NA, n={len(df_m)}" # ========================= # PLOT # ========================= colors = { "GPT-OSS-120B": "#1F77B4", "Qwen3.6-27B": "#FF7F0E", "Gemma-4-31B-it": "#2CA02C", } x = np.arange(len(bracket_order)) n_models = len(model_order) bar_width = 0.22 fig, ax = plt.subplots(figsize=(12, 7)) for i, model in enumerate(model_order): df_m = summary[summary["model_display"] == model].copy() df_m = df_m.set_index("confidence_bracket").reindex(bracket_order).reset_index() values = df_m["MAE"].values errors = df_m["SEM"].fillna(0).values ns = df_m["n"].fillna(0).astype(int).values offset = (i - (n_models - 1) / 2) * bar_width bars = ax.bar( x + offset, values, width=bar_width, yerr=errors, capsize=4, color=colors.get(model, None), edgecolor="black", linewidth=0.6, alpha=0.85, label=model, ) for 
bar, value, n in zip(bars, values, ns): if pd.notna(value) and n > 0: ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.035, f"{value:.2f}\nn={n}", ha="center", va="bottom", fontsize=8, fontweight="bold", ) if ADD_TREND_LINES: valid = df_m.dropna(subset=["MAE", "confidence_midpoint"]) if len(valid) >= 2: ax.plot( x + offset, df_m["MAE"].values, linestyle="--", linewidth=1.5, color=colors.get(model, None), alpha=0.9, ) ax.set_xticks(x) ax.set_xticklabels(bracket_order, fontsize=10) ax.set_ylabel("Mean absolute EDSS error", fontsize=11, fontweight="bold") ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold") ax.set_title( "EDSS prediction error across LLM confidence brackets", fontsize=14, fontweight="bold", pad=15, ) ax.yaxis.grid(True, linestyle="--", alpha=0.3) ax.set_axisbelow(True) for spine in ["top", "right"]: ax.spines[spine].set_visible(False) ax.legend( loc="upper right", frameon=True, title="Model", ) # Add correlation text box corr_lines = ["Pearson correlation: confidence vs absolute error"] for model in model_order: corr_lines.append(f"{model}: {corr_text[model]}") ax.text( 0.02, 0.98, "\n".join(corr_lines), transform=ax.transAxes, ha="left", va="top", fontsize=9, bbox=dict( boxstyle="round,pad=0.4", facecolor="white", edgecolor="#999999", alpha=0.9, ), ) # Add metric explanation ax.text( 0.98, 0.02, "Bars: MAE\nError bars: SEM\nDashed lines: bracket trend", transform=ax.transAxes, ha="right", va="bottom", fontsize=9, bbox=dict( boxstyle="round,pad=0.4", facecolor="white", edgecolor="#CCCCCC", alpha=0.9, ), ) plt.tight_layout() plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight") plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight") plt.show() print("\nSaved:") print(OUTPUT_SVG) print(OUTPUT_PNG) print(OUTPUT_TABLE) print(OUTPUT_LONG) ## # %% Confidence bracket vs clinically acceptable EDSS accuracy grouped by model from pathlib import Path import re import pandas as pd import numpy as np import matplotlib.pyplot 
as plt # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_analysis_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_SVG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.svg" OUTPUT_PNG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.png" OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_accuracy_table_iter_{TARGET_ITERATION}.csv" OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" CERTAINTY_COL = "certainty_percent" plt.rcParams["font.family"] = "Arial" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def clean_model_name(name): replacements = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } return replacements.get(str(name), str(name)) def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def confidence_bracket(certainty): if pd.isna(certainty): return np.nan if certainty < 70: return "Low\n<70%" if 
certainty < 80: return "Moderate\n70–80%" if certainty < 90: return "High\n80–90%" if certainty <= 100: return "Very high\n90–100%" return np.nan # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows with numeric EDSS: {len(gt)}") # ========================= # BUILD LONG DATA # ========================= long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not p.name.startswith("edss_error_distribution") and not p.name.startswith("edss_threshold_metrics") and not p.name.startswith("edss_severity_group_metrics") and not p.name.startswith("structured_output_validity") and not p.name.startswith("confidence") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print(f"Skipping {model_dir.name}: no row_index column.") continue if CERTAINTY_COL not in pred_raw.columns: print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.") continue model_name = get_model_name(pred_raw, model_dir) model_display = clean_model_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col 
= PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) if merged.empty: continue merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] merged["abs_error"] = merged["error"].abs() merged["within_0_5"] = merged["abs_error"] <= 0.5 merged["within_1_0"] = merged["abs_error"] <= 1.0 merged["exact_match"] = merged["abs_error"] == 0 merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) merged = merged.dropna(subset=["confidence_bracket"]).copy() for _, row in merged.iterrows(): long_rows.append({ "model": model_name, "model_display": model_display, "iteration": TARGET_ITERATION, "row_index": row["row_index"], "GT_EDSS_numeric": row["GT_EDSS_numeric"], "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], "certainty_percent": row["certainty_numeric"], "confidence_bracket": row["confidence_bracket"], "error": row["error"], "abs_error": row["abs_error"], "exact_match": row["exact_match"], "within_0_5": row["within_0_5"], "within_1_0": row["within_1_0"], "result_file": str(result_file), }) long_df = pd.DataFrame(long_rows) if long_df.empty: raise ValueError("No evaluable rows found.") long_df.to_csv(OUTPUT_LONG, index=False) # ========================= # SUMMARY # ========================= bracket_order = [ "Low\n<70%", "Moderate\n70–80%", "High\n80–90%", "Very high\n90–100%", ] model_order = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] model_order = [ m for m in model_order if m in long_df["model_display"].unique() ] summary = ( long_df .groupby(["model_display", "confidence_bracket"]) .agg( n=("within_0_5", "count"), exact_accuracy=("exact_match", "mean"), 
accuracy_within_0_5=("within_0_5", "mean"), accuracy_within_1_0=("within_1_0", "mean"), mean_abs_error=("abs_error", "mean"), median_abs_error=("abs_error", "median"), mean_confidence=("certainty_percent", "mean"), ) .reset_index() ) full_index = pd.MultiIndex.from_product( [model_order, bracket_order], names=["model_display", "confidence_bracket"] ) summary = ( summary .set_index(["model_display", "confidence_bracket"]) .reindex(full_index) .reset_index() ) summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100 summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100 summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100 summary.to_csv(OUTPUT_TABLE, index=False) print("\nConfidence-bracket accuracy table:") print(summary) # ========================= # PLOT # ========================= x = np.arange(len(bracket_order)) n_models = len(model_order) bar_width = 0.22 colors = { "GPT-OSS-120B": "#1F77B4", "Qwen3.6-27B": "#FF7F0E", "Gemma-4-31B-it": "#2CA02C", } fig, ax = plt.subplots(figsize=(11, 6.5)) for i, model in enumerate(model_order): df_m = ( summary[summary["model_display"] == model] .set_index("confidence_bracket") .reindex(bracket_order) .reset_index() ) values = df_m["accuracy_within_0_5_percent"].values ns = df_m["n"].fillna(0).astype(int).values offset = (i - (n_models - 1) / 2) * bar_width bars = ax.bar( x + offset, values, width=bar_width, color=colors.get(model), edgecolor="white", linewidth=0.8, label=model, ) for bar, value, n in zip(bars, values, ns): if pd.notna(value) and n > 0: ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 1.2, f"{value:.1f}%\nn={n}", ha="center", va="bottom", fontsize=8, fontweight="bold", ) ax.set_xticks(x) ax.set_xticklabels(bracket_order, fontsize=10) ax.set_ylim(0, 110) ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold") ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold") #ax.set_title( # "Accuracy of 
#     EDSS predictions by confidence bracket",
#     fontsize=14,
#     fontweight="bold",
#     pad=15,
# )
ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])
ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

ax.text(
    0.5,
    -0.18,
    "Higher bars indicate better calibration: high-confidence predictions are more often clinically close to the reference EDSS.",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.05, 1, 0.92])
plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)

##
# %% Confidence bracket accuracy + predicted EDSS range distribution
# For one benchmark iteration, join each model's EDSS predictions with the
# ground truth, bucket them by the model's certainty_percent value, and report
# (1) accuracy within ±0.5 EDSS per confidence bracket and (2) the mix of
# predicted EDSS ranges inside each confidence bracket.
from pathlib import Path
import re  # currently unused in this cell

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_analysis_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_ACCURACY_SVG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.svg"
OUTPUT_ACCURACY_PNG = OUTPUT_DIR / f"confidence_bracket_accuracy_within_0_5_iter_{TARGET_ITERATION}.png"
OUTPUT_RANGE_SVG = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_iter_{TARGET_ITERATION}.svg"
OUTPUT_RANGE_PNG = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_bracket_accuracy_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_RANGE_TABLE = OUTPUT_DIR / f"confidence_bracket_predicted_edss_range_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable -> NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Interpret a Series as booleans: "true", "1", "yes", "ja" (any case) are True."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def clean_model_name(name):
    """Return the display label for a raw model identifier (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first (lexicographically sorted) result CSV for `iteration`, or None.

    Files whose names contain "incremental", "summary" or "all_results" are ignored.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Use the first non-null value of df["model"], else the directory name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to a bracket label; NaN or > 100 gives NaN."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Map an EDSS value to a coarse severity-range label.

    The bins are contiguous: [0, 4) -> "0.0–3.5", [4, 6) -> "4.0–5.5",
    [6, 10] -> "6.0–10.0". Values on the usual 0.5-step grid map exactly as
    the labels suggest; off-grid values (e.g. 3.75 or 5.8) are still assigned
    a group instead of becoming NaN, which previously caused such predictions
    to be dropped from every statistic. NaN and values outside 0–10 give NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value < 4.0:
        return "0.0–3.5"
    if 4.0 <= value < 6.0:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# row_index is the position in the GT file, taken before any rows are dropped;
# it is the join key for the prediction files (assumed to use the same numbering).
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Model result folders; analysis-output folders in RUN_DIR are skipped by prefix.
model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith((
        "confusion",
        "functional_system",
        "repeated_run",
        "edss_error_distribution",
        "edss_threshold_metrics",
        "edss_severity_group_metrics",
        "structured_output_validity",
        "confidence",
    ))
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional status columns: when present, keep only rows flagged truthy.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # One prediction per letter: keep the first occurrence of each row_index.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()
    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    # With contiguous EDSS bins only out-of-range values (< 0 or > 10) or
    # certainties above 100 are removed here.
    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    for _, row in merged.iterrows():
        long_rows.append({
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": row["row_index"],
            "GT_EDSS_numeric": row["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": row["PRED_EDSS_numeric"],
            "GT_EDSS_group": row["GT_EDSS_group"],
            "PRED_EDSS_group": row["PRED_EDSS_group"],
            "certainty_percent": row["certainty_numeric"],
            "confidence_bracket": row["confidence_bracket"],
            "error": row["error"],
            "abs_error": row["abs_error"],
            "exact_match": row["exact_match"],
            "within_0_5": row["within_0_5"],
            "within_1_0": row["within_1_0"],
            "result_file": str(result_file),
        })

long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# SUMMARY: ACCURACY BY CONFIDENCE
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

# Only plot the known models that actually produced evaluable rows.
model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Reindex to every (model, bracket) pair so empty brackets appear as NaN rows.
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)


# =========================
# FIGURE 1: ACCURACY WITHIN ±0.5 BY CONFIDENCE
# =========================

x = np.arange(len(bracket_order))
n_models = len(model_order)
bar_width = 0.22

model_colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

fig, ax = plt.subplots(figsize=(11, 6.5))

for i, model in enumerate(model_order):
    df_m = (
        summary[summary["model_display"] == model]
        .set_index("confidence_bracket")
        .reindex(bracket_order)
        .reset_index()
    )

    values = df_m["accuracy_within_0_5_percent"].values
    ns = df_m["n"].fillna(0).astype(int).values
    # Centre the group of n_models bars on each bracket tick.
    offset = (i - (n_models - 1) / 2) * bar_width

    bars = ax.bar(
        x + offset,
        values,
        width=bar_width,
        color=model_colors.get(model),
        edgecolor="white",
        linewidth=0.8,
        label=model,
    )

    for bar, value, n in zip(bars, values, ns):
        if pd.notna(value) and n > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 1.2,
                f"{value:.1f}%\nn={n}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )

ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)
ax.set_ylim(0, 110)
ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")
ax.set_title(
    "Accuracy of EDSS predictions by confidence bracket",
    fontsize=14,
    fontweight="bold",
    pad=15,
)
ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])
ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

plt.tight_layout(rect=[0, 0.03, 1, 0.92])
plt.savefig(OUTPUT_ACCURACY_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_ACCURACY_PNG, dpi=300, bbox_inches="tight")
plt.show()


# =========================
# SUMMARY: PREDICTED EDSS RANGE BY CONFIDENCE
# =========================

range_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

range_colors = {
    "0.0–3.5": "#9ECAE1",
    "4.0–5.5": "#FDDC7A",
    "6.0–10.0": "#F28E2B",
}

# One row per (model, confidence bracket, predicted EDSS range) with the count
# and its share of all predictions in that bracket (NaN when the bracket is empty).
range_rows = []

for model in model_order:
    df_m = long_df[long_df["model_display"] == model].copy()

    for bracket in bracket_order:
        df_b = df_m[df_m["confidence_bracket"] == bracket].copy()
        total = len(df_b)

        for edss_range in range_order:
            count = int((df_b["PRED_EDSS_group"] == edss_range).sum())
            percent = count / total * 100 if total > 0 else np.nan

            range_rows.append({
                "model": model,
                "confidence_bracket": bracket,
                "predicted_EDSS_range": edss_range,
                "count": count,
                "total_in_confidence_bracket": total,
                "percent": percent,
            })

range_df = pd.DataFrame(range_rows)
range_df.to_csv(OUTPUT_RANGE_TABLE, index=False)


# =========================
# FIGURE 2: PREDICTED EDSS RANGE BY CONFIDENCE
# =========================

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(model_order),
    figsize=(5 * len(model_order), 5.5),
    sharey=True
)

# plt.subplots returns a bare Axes (not an array) when there is one column.
if len(model_order) == 1:
    axes = [axes]

for ax, model in zip(axes, model_order):
    df_m = range_df[range_df["model"] == model].copy()

    # Running top of each stacked column; used as the next segment's bottom.
    left = np.zeros(len(bracket_order))

    for edss_range in range_order:
        values = []

        for bracket in bracket_order:
            pct = df_m.loc[
                (df_m["confidence_bracket"] == bracket)
                & (df_m["predicted_EDSS_range"] == edss_range),
                "percent"
            ]

            # Empty brackets carry NaN percent; draw them as zero-height segments.
            if len(pct) == 0:
                values.append(0)
            else:
                values.append(pct.iloc[0] if pd.notna(pct.iloc[0]) else 0)

        ax.bar(
            bracket_order,
            values,
            bottom=left,
            color=range_colors[edss_range],
            edgecolor="white",
            linewidth=0.8,
            label=edss_range,
        )

        # Only label segments tall enough to hold the text.
        for i, value in enumerate(values):
            if value >= 8:
                ax.text(
                    i,
                    left[i] + value / 2,
                    f"{value:.0f}%",
                    ha="center",
                    va="center",
                    fontsize=8,
                    fontweight="bold",
                )

        left += np.array(values)

    # n labels above bars
    for i, bracket in enumerate(bracket_order):
        total = df_m.loc[
            df_m["confidence_bracket"] == bracket,
            "total_in_confidence_bracket"
        ]

        total_n = int(total.iloc[0]) if len(total) > 0 and pd.notna(total.iloc[0]) else 0

        ax.text(
            i,
            102,
            f"n={total_n}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    ax.set_title(model, fontsize=12, fontweight="bold")
    ax.set_xlabel("Confidence bracket", fontsize=10, fontweight="bold")
    ax.set_ylim(0, 110)
    # The string x-values make this a categorical axis whose ticks are already
    # labelled with bracket_order; restyle them instead of calling
    # set_xticklabels, which warns when the ticks are not fixed.
    ax.tick_params(axis="x", labelrotation=0, labelsize=8)
    ax.yaxis.grid(True, linestyle="--", alpha=0.25)
    ax.set_axisbelow(True)

    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

axes[0].set_ylabel("Predicted EDSS range (%)", fontsize=11, fontweight="bold")

handles, labels = axes[-1].get_legend_handles_labels()

fig.legend(
    handles,
    labels,
    title="Predicted EDSS range",
    loc="lower center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=3,
    frameon=False,
)

fig.suptitle(
    "Predicted EDSS range within each confidence bracket",
    fontsize=14,
    fontweight="bold",
    y=1.03,
)

plt.tight_layout(rect=[0, 0.07, 1, 0.96])
plt.savefig(OUTPUT_RANGE_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_RANGE_PNG, dpi=300, bbox_inches="tight")
plt.show()


# =========================
# DONE
# =========================

print("\nSaved:")
print(OUTPUT_ACCURACY_SVG)
print(OUTPUT_ACCURACY_PNG)
print(OUTPUT_RANGE_SVG)
print(OUTPUT_RANGE_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_RANGE_TABLE)
print(OUTPUT_LONG)

##
# %% Heatmap: confidence bracket x predicted EDSS range x accuracy, one panel per model
# For one benchmark iteration, show per model a heatmap of accuracy within
# ±0.5 EDSS, with predicted EDSS range as rows and confidence bracket as columns.
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_heatmap_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable -> NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Interpret a Series as booleans: "true", "1", "yes", "ja" (any case) are True."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def clean_model_name(name):
    """Return the display label for a raw model identifier (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first (lexicographically sorted) result CSV for `iteration`, or None.

    Files whose names contain "incremental", "summary" or "all_results" are ignored.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Use the first non-null value of df["model"], else the directory name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to a bracket label; NaN or > 100 gives NaN."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Map an EDSS value to a coarse severity-range label.

    The bins are contiguous: [0, 4) -> "0.0–3.5", [4, 6) -> "4.0–5.5",
    [6, 10] -> "6.0–10.0". Off-grid values (e.g. 3.75) are still grouped
    instead of becoming NaN and being dropped from the heatmap.
    NaN and values outside 0–10 give NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value < 4.0:
        return "0.0–3.5"
    if 4.0 <= value < 6.0:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Positional row number in the GT file (before dropping rows); join key for predictions.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")


# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Model result folders; analysis-output folders in RUN_DIR are skipped by prefix.
model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith((
        "confusion",
        "functional_system",
        "repeated_run",
        "edss_error_distribution",
        "edss_threshold_metrics",
        "edss_severity_group_metrics",
        "structured_output_validity",
        "confidence",
    ))
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Optional status columns: when present, keep only rows flagged truthy.
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        continue

    merged["abs_error"] = (
        merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    ).abs()

    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    for _, row in merged.iterrows():
        long_rows.append({
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": row["row_index"],
            "GT_EDSS_numeric": row["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": row["PRED_EDSS_numeric"],
            "GT_EDSS_group": row["GT_EDSS_group"],
            "PRED_EDSS_group": row["PRED_EDSS_group"],
            "certainty_percent": row["certainty_numeric"],
            "confidence_bracket": row["confidence_bracket"],
            "abs_error": row["abs_error"],
            "within_0_5": row["within_0_5"],
            "result_file": str(result_file),
        })

long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)


# =========================
# AGGREGATE FOR HEATMAP
# =========================

confidence_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

edss_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

model_order = [
    m for m in model_order
    if m in long_df["model_display"].unique()
]

summary = (
    long_df
    .groupby(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        accuracy_within_0_5=("within_0_5", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Every (model, EDSS range, bracket) cell exists; empty cells get n = 0 and NaN accuracy.
full_index = pd.MultiIndex.from_product(
    [model_order, edss_order, confidence_order],
    names=["model_display", "PRED_EDSS_group", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["n"] = summary["n"].fillna(0).astype(int)
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nHeatmap summary table:")
print(summary)


# =========================
# PLOT
# =========================

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(model_order),
    figsize=(5.2 * len(model_order), 4.8),
    sharey=True
)

if len(model_order) == 1:
    axes = [axes]

for ax, model in zip(axes, model_order):
    df_m = summary[summary["model_display"] == model].copy()

    heatmap_values = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="accuracy_within_0_5_percent"
        )
        .reindex(index=edss_order, columns=confidence_order)
    )

    heatmap_n = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="n"
        )
        .reindex(index=edss_order, columns=confidence_order)
        .fillna(0)
        .astype(int)
    )

    # Cell text: "<accuracy>%\nn=<count>", blank for empty cells.
    annotations = heatmap_values.copy().astype(object)

    for r in edss_order:
        for c in confidence_order:
            value = heatmap_values.loc[r, c]
            n = heatmap_n.loc[r, c]

            if n == 0 or pd.isna(value):
                annotations.loc[r, c] = ""
            else:
                annotations.loc[r, c] = f"{value:.0f}%\nn={n}"

    sns.heatmap(
        heatmap_values,
        ax=ax,
        annot=annotations,
        fmt="",
        cmap="Blues",
        vmin=0,
        vmax=100,
        linewidths=1,
        linecolor="white",
        cbar=False,
        square=False,
    )

    ax.set_title(model, fontsize=12, fontweight="bold")
    ax.set_xlabel("LLM confidence bracket", fontsize=10, fontweight="bold")
    ax.set_ylabel("Predicted EDSS range" if ax == axes[0] else "", fontsize=10, fontweight="bold")

    ax.set_xticklabels(confidence_order, rotation=0, fontsize=8)
    ax.set_yticklabels(edss_order, rotation=0, fontsize=9)

# Shared colorbar, built from the last panel's mesh (all panels use vmin=0, vmax=100).
mappable = axes[-1].collections[0]

cbar = fig.colorbar(
    mappable,
    ax=axes,
    orientation="vertical",
    fraction=0.025,
    pad=0.02,
)

cbar.set_label("Accuracy within ±0.5 EDSS (%)", fontsize=10, fontweight="bold")

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted severity range",
    fontsize=14,
    fontweight="bold",
    y=1.03,
)

fig.text(
    0.5,
    0.01,
    "Cell color shows accuracy within ±0.5 EDSS; text shows accuracy and number of predictions.",
    ha="center",
    va="bottom",
    fontsize=9,
    color="#555555",
)

# NOTE(review): tight_layout may warn here because the colorbar axes created
# via fig.colorbar(ax=axes) is not a regular subplot.
plt.tight_layout(rect=[0, 0.05, 0.97, 0.95])
plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)

##
# %% Improved heatmap: confidence bracket x predicted EDSS range x accuracy, one model per row
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_heatmap_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_edss_range_accuracy_heatmap_vertical_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"


# =========================
# HELPERS
# =========================

def to_num(s):
    """Convert a Series to numbers, accepting decimal commas; unparsable -> NaN."""
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Interpret a Series as booleans: "true", "1", "yes", "ja" (any case) are True."""
    return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"])


def clean_model_name(name):
    """Return the display label for a raw model identifier (unknown names pass through)."""
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31B-it": "Gemma-4-31B-it",
    }
    return replacements.get(str(name), str(name))


def find_iter_file(model_dir, iteration):
    """Return the first (lexicographically sorted) result CSV for `iteration`, or None.

    Files whose names contain "incremental", "summary" or "all_results" are ignored.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Use the first non-null value of df["model"], else the directory name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Map a certainty percentage to a bracket label; NaN or > 100 gives NaN."""
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_group(value):
    """Map an EDSS value to a coarse severity-range label.

    The bins are contiguous: [0, 4) -> "0.0–3.5", [4, 6) -> "4.0–5.5",
    [6, 10] -> "6.0–10.0". Off-grid values (e.g. 3.75) are still grouped
    instead of becoming NaN and being dropped from the heatmap.
    NaN and values outside 0–10 give NaN.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value < 4.0:
        return "0.0–3.5"
    if 4.0 <= value < 6.0:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# Positional row number in the GT file (before dropping rows); join key for predictions.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
# `edss_group` is defined earlier in this cell (outside this view); rows whose
# EDSS it cannot map are handled by the dropna calls below.
gt["GT_EDSS_group"] = gt["GT_EDSS_numeric"].apply(edss_group)
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")

# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Every sub-folder of RUN_DIR is treated as one model's output folder, except the
# analysis/output folders written by the evaluation cells (excluded by prefix).
model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith("confusion")
    and not p.name.startswith("functional_system")
    and not p.name.startswith("repeated_run")
    and not p.name.startswith("edss_error_distribution")
    and not p.name.startswith("edss_threshold_metrics")
    and not p.name.startswith("edss_severity_group_metrics")
    and not p.name.startswith("structured_output_validity")
    and not p.name.startswith("confidence")
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Keep only rows the benchmark flagged as usable (when those flag columns exist).
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per GT row: keep the first occurrence.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        continue

    merged["abs_error"] = (
        merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    ).abs()

    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["PRED_EDSS_group"] = merged["PRED_EDSS_numeric"].apply(edss_group)

    merged = merged.dropna(subset=["confidence_bracket", "PRED_EDSS_group"]).copy()

    for _, row in merged.iterrows():
        long_rows.append({
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": row["row_index"],
            "GT_EDSS_numeric": row["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": row["PRED_EDSS_numeric"],
            "GT_EDSS_group": row["GT_EDSS_group"],
            "PRED_EDSS_group": row["PRED_EDSS_group"],
            "certainty_percent": row["certainty_numeric"],
            "confidence_bracket": row["confidence_bracket"],
            "abs_error": row["abs_error"],
            "within_0_5": row["within_0_5"],
            "result_file": str(result_file),
        })

long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)

# =========================
# AGGREGATE FOR HEATMAP
# =========================

confidence_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

edss_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

found_models = sorted(long_df["model_display"].unique())

model_order = [
    m for m in model_order
    if m in found_models
]

# Models whose display name has no slot in model_order are not plotted; say so
# instead of dropping them silently.
unplotted_models = [m for m in found_models if m not in model_order]
if unplotted_models:
    print(f"Models not in model_order (not plotted): {unplotted_models}")

# plt.subplots(nrows=0) below would fail with an opaque error.
if not model_order:
    raise ValueError(
        f"None of the expected models were found in the evaluable data; found: {found_models}"
    )

summary = (
    long_df
    .groupby(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        accuracy_within_0_5=("within_0_5", "mean"),
        mean_abs_error=("abs_error", "mean"),
        median_abs_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Reindex onto the full model x range x bracket grid so empty cells are explicit
# rows (n = 0, NaN metrics) rather than missing.
full_index = pd.MultiIndex.from_product(
    [model_order, edss_order, confidence_order],
    names=["model_display", "PRED_EDSS_group", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "PRED_EDSS_group", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

summary["n"] = summary["n"].fillna(0).astype(int)
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nHeatmap summary table:")
print(summary)

# =========================
# PLOT - ONE MODEL PER ROW
# =========================

n_models = len(model_order)

fig, axes = plt.subplots(
    nrows=n_models,
    ncols=1,
    figsize=(8.5, 3.1 * n_models),
    sharex=True,
    constrained_layout=False
)

if n_models == 1:
    axes = [axes]

# One shared colorbar axis, filled by the first heatmap only.
cbar_ax = fig.add_axes([0.92, 0.18, 0.025, 0.65])

for i, (ax, model) in enumerate(zip(axes, model_order)):
    df_m = summary[summary["model_display"] == model].copy()

    heatmap_values = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="accuracy_within_0_5_percent"
        )
        .reindex(index=edss_order, columns=confidence_order)
    )

    heatmap_n = (
        df_m
        .pivot(
            index="PRED_EDSS_group",
            columns="confidence_bracket",
            values="n"
        )
        .reindex(index=edss_order, columns=confidence_order)
        .fillna(0)
        .astype(int)
    )

    annotations = heatmap_values.copy().astype(object)

    for r in edss_order:
        for c in confidence_order:
            value = heatmap_values.loc[r, c]
            n = heatmap_n.loc[r, c]

            if n == 0 or pd.isna(value):
                annotations.loc[r, c] = ""
            else:
                annotations.loc[r, c] = f"{value:.0f}%\nn={n}"

    # Cells with no predictions are not drawn; the grey facecolor set below shows through.
    mask = heatmap_n == 0

    sns.heatmap(
        heatmap_values,
        ax=ax,
        annot=annotations,
        fmt="",
        cmap="Blues",
        vmin=0,
        vmax=100,
        mask=mask,
        linewidths=1,
        linecolor="white",
        cbar=(i == 0),
        cbar_ax=cbar_ax if i == 0 else None,
        cbar_kws={"label": "Accuracy within ±0.5 EDSS (%)"},
    )

    # Grey background for empty cells
    ax.set_facecolor("#F2F2F2")

    ax.set_title(model, fontsize=12, fontweight="bold", loc="left", pad=8)
    ax.set_ylabel("Predicted EDSS range", fontsize=10, fontweight="bold")
    ax.set_xlabel("")

    ax.set_yticklabels(edss_order, rotation=0, fontsize=9)

    # Only the bottom panel carries the x-axis label and tick labels.
    if i == n_models - 1:
        ax.set_xlabel("LLM confidence bracket", fontsize=10, fontweight="bold")
        ax.set_xticklabels(confidence_order, rotation=0, fontsize=9)
    else:
        ax.set_xticklabels([])

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted severity range",
    fontsize=14,
    fontweight="bold",
    y=0.98,
)

fig.text(
    0.5,
    0.03,
    "Cell color shows accuracy within ±0.5 EDSS; text shows accuracy and number of predictions. Empty grey cells indicate no predictions.",
    ha="center",
    va="center",
    fontsize=9,
    color="#555555",
)

plt.subplots_adjust(
    left=0.17,
    right=0.89,
    top=0.92,
    bottom=0.09,
    hspace=0.38
)

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)

##

# %% Line plot: confidence-stratified EDSS accuracy by model

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_lineplot_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_lineplot_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_lineplot_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_stratified_edss_accuracy_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"

# =========================
# HELPERS
# =========================


def to_num(s):
    """Convert a Series to floats, accepting a comma as decimal separator.

    Unparseable values become NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean mask: True for "true"/"1"/"1.0"/"yes"/"ja" (case-insensitive).

    Surrounding whitespace is ignored. "1.0" covers 0/1 flag columns that pandas
    read as float because they contain missing values. Anything else, including
    missing values, maps to False.
    """
    return s.astype(str).str.strip().str.lower().isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display name.

    The lookup is case-insensitive; unknown names are returned unchanged (as str).
    """
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31b-it": "Gemma-4-31B-it",
    }
    name = str(name)
    return replacements.get(name.lower(), name)


def find_iter_file(model_dir, iteration):
    """Return the first (sorted) result CSV for `iteration` in `model_dir`, or None.

    Incremental, summary and all-results files matching the same glob are skipped.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of df["model"], falling back to the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Bin a certainty percentage into a plot label.

    <70 -> Low, [70, 80) -> Moderate, [80, 90) -> High, [90, 100] -> Very high.
    NaN and values above 100 return NaN (dropped downstream); negative values
    fall into the Low bracket.
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# 0-based row position of the GT file; assumes the benchmark's `row_index`
# column uses the same positions — TODO confirm.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")

# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Model output folders = all sub-folders except the analysis/output folders.
model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith("confusion")
    and not p.name.startswith("functional_system")
    and not p.name.startswith("repeated_run")
    and not p.name.startswith("edss_error_distribution")
    and not p.name.startswith("edss_threshold_metrics")
    and not p.name.startswith("edss_severity_group_metrics")
    and not p.name.startswith("structured_output_validity")
    and not p.name.startswith("confidence")
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Keep only rows the benchmark flagged as usable (when those flag columns exist).
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per GT row: keep the first occurrence.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        print("No evaluable rows.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)

    merged = merged.dropna(subset=["confidence_bracket"]).copy()

    print(f"Evaluable rows with confidence bracket: {len(merged)}")

    for _, row in merged.iterrows():
        long_rows.append({
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": row["row_index"],
            "unique_id": row.get("unique_id_gt", row.get("unique_id", None)),
            "GT_EDSS_numeric": row["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": row["PRED_EDSS_numeric"],
            "certainty_percent": row["certainty_numeric"],
            "confidence_bracket": row["confidence_bracket"],
            "error": row["error"],
            "abs_error": row["abs_error"],
            "exact_match": row["exact_match"],
            "within_0_5": row["within_0_5"],
            "within_1_0": row["within_1_0"],
            "inference_time_sec": row.get("inference_time_sec", np.nan),
            "result_file": str(result_file),
        })

long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)

# =========================
# SUMMARY BY CONFIDENCE BRACKET
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

found_models = sorted(long_df["model_display"].unique())

model_order = [
    m for m in model_order
    if m in found_models
]

unplotted_models = [m for m in found_models if m not in model_order]
if unplotted_models:
    print(f"Models not in model_order (not plotted): {unplotted_models}")

# Otherwise an empty figure would be saved silently.
if not model_order:
    raise ValueError(
        f"None of the expected models were found in the evaluable data; found: {found_models}"
    )

summary = (
    long_df
    .groupby(["model_display", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        MAE=("abs_error", "mean"),
        median_absolute_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Full model x bracket grid so empty brackets appear explicitly in the table.
full_index = pd.MultiIndex.from_product(
    [model_order, bracket_order],
    names=["model_display", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty brackets have zero predictions, not an unknown count.
summary["n"] = summary["n"].fillna(0).astype(int)

summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-stratified accuracy table:")
print(summary)

# =========================
# LINE PLOT
# =========================

x = np.arange(len(bracket_order))

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

markers = {
    "GPT-OSS-120B": "o",
    "Qwen3.6-27B": "s",
    "Gemma-4-31B-it": "^",
}

fig, ax = plt.subplots(figsize=(9.5, 6))

for model in model_order:
    df_m = (
        summary[summary["model_display"] == model]
        .set_index("confidence_bracket")
        .reindex(bracket_order)
        .reset_index()
    )

    y = df_m["accuracy_within_0_5_percent"].values
    n = df_m["n"].fillna(0).astype(int).values

    # NaN accuracies (empty brackets) leave gaps in the line.
    ax.plot(
        x,
        y,
        marker=markers.get(model, "o"),
        markersize=8,
        linewidth=2.2,
        color=colors.get(model),
        label=model,
    )

    for xi, yi, ni in zip(x, y, n):
        if pd.notna(yi) and ni > 0:
            ax.text(
                xi,
                yi + 2.2,
                f"{yi:.1f}%\nn={ni}",
                ha="center",
                va="bottom",
                fontsize=8,
                color=colors.get(model),
                fontweight="bold",
            )

ax.set_xticks(x)
ax.set_xticklabels(bracket_order, fontsize=10)

ax.set_ylim(0, 110)
ax.set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")
ax.set_xlabel("LLM confidence bracket", fontsize=11, fontweight="bold")

ax.set_title(
    "Confidence-stratified EDSS accuracy by model",
    fontsize=14,
    fontweight="bold",
    pad=15,
)

ax.set_yticks(np.arange(0, 101, 10))
ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

ax.yaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

ax.text(
    0.5,
    -0.18,
    "Higher values indicate a larger proportion of predictions within ±0.5 EDSS of the reference score.",
    transform=ax.transAxes,
    ha="center",
    va="top",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.05, 1, 0.92])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)

##

# %% Line plot: confidence-stratified EDSS accuracy by predicted EDSS range

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_lineplot_by_edss_range_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

plt.rcParams["font.family"] = "Arial"

# =========================
# HELPERS
# =========================


def to_num(s):
    """Convert a Series to floats, accepting a comma as decimal separator.

    Unparseable values become NaN.
    """
    return pd.to_numeric(
        s.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def to_bool(s):
    """Return a boolean mask: True for "true"/"1"/"1.0"/"yes"/"ja" (case-insensitive).

    Surrounding whitespace is ignored. "1.0" covers 0/1 flag columns that pandas
    read as float because they contain missing values. Anything else, including
    missing values, maps to False.
    """
    return s.astype(str).str.strip().str.lower().isin(["true", "1", "1.0", "yes", "ja"])


def clean_model_name(name):
    """Map a raw model identifier to its display name.

    The lookup is case-insensitive; unknown names are returned unchanged (as str).
    """
    replacements = {
        "gpt-oss-120b": "GPT-OSS-120B",
        "qwen3.6-27b": "Qwen3.6-27B",
        "gemma-4-31b-it": "Gemma-4-31B-it",
    }
    name = str(name)
    return replacements.get(name.lower(), name)


def find_iter_file(model_dir, iteration):
    """Return the first (sorted) result CSV for `iteration` in `model_dir`, or None.

    Incremental, summary and all-results files matching the same glob are skipped.
    """
    files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv"))
    files = [
        f for f in files
        if "incremental" not in f.name.lower()
        and "summary" not in f.name.lower()
        and "all_results" not in f.name.lower()
    ]
    return files[0] if files else None


def get_model_name(df, model_dir):
    """Return the first non-null value of df["model"], falling back to the folder name."""
    if "model" in df.columns and df["model"].notna().any():
        return str(df["model"].dropna().iloc[0])
    return model_dir.name


def confidence_bracket(certainty):
    """Bin a certainty percentage into a plot label.

    <70 -> Low, [70, 80) -> Moderate, [80, 90) -> High, [90, 100] -> Very high.
    NaN and values above 100 return NaN (dropped downstream); negative values
    fall into the Low bracket.
    """
    if pd.isna(certainty):
        return np.nan
    if certainty < 70:
        return "Low\n<70%"
    if certainty < 80:
        return "Moderate\n70–80%"
    if certainty < 90:
        return "High\n80–90%"
    if certainty <= 100:
        return "Very high\n90–100%"
    return np.nan


def edss_range(value):
    """Map a numeric EDSS to one of three closed severity ranges.

    NOTE: values strictly between ranges (e.g. 3.75 or 5.8) and values outside
    [0, 10] return NaN and are therefore dropped downstream.
    """
    if pd.isna(value):
        return np.nan
    if 0.0 <= value <= 3.5:
        return "0.0–3.5"
    if 4.0 <= value <= 5.5:
        return "4.0–5.5"
    if 6.0 <= value <= 10.0:
        return "6.0–10.0"
    return np.nan


# =========================
# LOAD GROUND TRUTH
# =========================

gt = pd.read_csv(GT_PATH, sep=";")
# 0-based row position of the GT file; assumes the benchmark's `row_index`
# column uses the same positions — TODO confirm.
gt["row_index"] = gt.index
gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL])
gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy()

print(f"GT rows with numeric EDSS: {len(gt)}")

# =========================
# BUILD LONG DATA
# =========================

long_rows = []

# Model output folders = all sub-folders except the analysis/output folders.
model_dirs = [
    p for p in sorted(RUN_DIR.iterdir())
    if p.is_dir()
    and not p.name.startswith("confusion")
    and not p.name.startswith("functional_system")
    and not p.name.startswith("repeated_run")
    and not p.name.startswith("edss_error_distribution")
    and not p.name.startswith("edss_threshold_metrics")
    and not p.name.startswith("edss_severity_group_metrics")
    and not p.name.startswith("structured_output_validity")
    and not p.name.startswith("confidence")
]

for model_dir in model_dirs:
    result_file = find_iter_file(model_dir, TARGET_ITERATION)

    if result_file is None:
        print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}")
        continue

    print("\n" + "=" * 100)
    print(f"Model folder: {model_dir.name}")
    print(f"Result file: {result_file}")

    pred_raw = pd.read_csv(result_file, sep=",")

    if "row_index" not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no row_index column.")
        continue

    if CERTAINTY_COL not in pred_raw.columns:
        print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.")
        continue

    model_name = get_model_name(pred_raw, model_dir)
    model_display = clean_model_name(model_name)

    pred = pred_raw.copy()
    pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce")
    pred = pred.dropna(subset=["row_index"]).copy()
    pred["row_index"] = pred["row_index"].astype(int)

    # Keep only rows the benchmark flagged as usable (when those flag columns exist).
    if "success" in pred.columns:
        pred = pred[to_bool(pred["success"])].copy()

    if "EDSS_is_numeric" in pred.columns:
        pred = pred[to_bool(pred["EDSS_is_numeric"])].copy()

    if "EDSS_in_valid_range" in pred.columns:
        pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy()

    pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL

    pred["PRED_EDSS_numeric"] = to_num(pred[pred_col])
    pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL])

    pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy()
    # At most one prediction per GT row: keep the first occurrence.
    pred = pred.drop_duplicates("row_index", keep="first").copy()

    merged = gt.merge(
        pred,
        on="row_index",
        how="inner",
        suffixes=("_gt", "_pred")
    )

    if merged.empty:
        print("No evaluable rows.")
        continue

    merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]
    merged["abs_error"] = merged["error"].abs()

    merged["exact_match"] = merged["abs_error"] == 0
    merged["within_0_5"] = merged["abs_error"] <= 0.5
    merged["within_1_0"] = merged["abs_error"] <= 1.0

    merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket)
    merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range)

    merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy()

    print(f"Evaluable rows with confidence bracket and EDSS range: {len(merged)}")

    for _, row in merged.iterrows():
        long_rows.append({
            "model": model_name,
            "model_display": model_display,
            "iteration": TARGET_ITERATION,
            "row_index": row["row_index"],
            "unique_id": row.get("unique_id_gt", row.get("unique_id", None)),
            "GT_EDSS_numeric": row["GT_EDSS_numeric"],
            "PRED_EDSS_numeric": row["PRED_EDSS_numeric"],
            "predicted_EDSS_range": row["predicted_EDSS_range"],
            "certainty_percent": row["certainty_numeric"],
            "confidence_bracket": row["confidence_bracket"],
            "error": row["error"],
            "abs_error": row["abs_error"],
            "exact_match": row["exact_match"],
            "within_0_5": row["within_0_5"],
            "within_1_0": row["within_1_0"],
            "inference_time_sec": row.get("inference_time_sec", np.nan),
            "result_file": str(result_file),
        })

long_df = pd.DataFrame(long_rows)

if long_df.empty:
    raise ValueError("No evaluable rows found.")

long_df.to_csv(OUTPUT_LONG, index=False)

# =========================
# SUMMARY
# =========================

bracket_order = [
    "Low\n<70%",
    "Moderate\n70–80%",
    "High\n80–90%",
    "Very high\n90–100%",
]

range_order = [
    "0.0–3.5",
    "4.0–5.5",
    "6.0–10.0",
]

model_order = [
    "GPT-OSS-120B",
    "Qwen3.6-27B",
    "Gemma-4-31B-it",
]

found_models = sorted(long_df["model_display"].unique())

model_order = [
    m for m in model_order
    if m in found_models
]

unplotted_models = [m for m in found_models if m not in model_order]
if unplotted_models:
    print(f"Models not in model_order (not plotted): {unplotted_models}")

# Otherwise empty panels would be saved silently.
if not model_order:
    raise ValueError(
        f"None of the expected models were found in the evaluable data; found: {found_models}"
    )

summary = (
    long_df
    .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .agg(
        n=("within_0_5", "count"),
        exact_accuracy=("exact_match", "mean"),
        accuracy_within_0_5=("within_0_5", "mean"),
        accuracy_within_1_0=("within_1_0", "mean"),
        MAE=("abs_error", "mean"),
        median_absolute_error=("abs_error", "median"),
        mean_confidence=("certainty_percent", "mean"),
    )
    .reset_index()
)

# Full model x range x bracket grid so empty cells appear explicitly in the table.
full_index = pd.MultiIndex.from_product(
    [model_order, range_order, bracket_order],
    names=["model_display", "predicted_EDSS_range", "confidence_bracket"]
)

summary = (
    summary
    .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"])
    .reindex(full_index)
    .reset_index()
)

# Empty cells have zero predictions, not an unknown count.
summary["n"] = summary["n"].fillna(0).astype(int)

summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100
summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100
summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100

summary.to_csv(OUTPUT_TABLE, index=False)

print("\nConfidence-stratified accuracy by predicted EDSS range:")
print(summary)

# =========================
# PLOT: SMALL MULTIPLE LINE PLOT BY EDSS RANGE
# =========================

x = np.arange(len(bracket_order))

colors = {
    "GPT-OSS-120B": "#1F77B4",
    "Qwen3.6-27B": "#FF7F0E",
    "Gemma-4-31B-it": "#2CA02C",
}

markers = {
    "GPT-OSS-120B": "o",
    "Qwen3.6-27B": "s",
    "Gemma-4-31B-it": "^",
}

fig, axes = plt.subplots(
    nrows=1,
    ncols=len(range_order),
    figsize=(5.1 * len(range_order), 5.8),
    sharey=True
)

if len(range_order) == 1:
    axes = [axes]

for ax, edss_r in zip(axes, range_order):
    for model in model_order:
        df_m = (
            summary[
                (summary["model_display"] == model)
                & (summary["predicted_EDSS_range"] == edss_r)
            ]
            .set_index("confidence_bracket")
            .reindex(bracket_order)
            .reset_index()
        )

        y = df_m["accuracy_within_0_5_percent"].values
        n = df_m["n"].fillna(0).astype(int).values

        ax.plot(
            x,
            y,
            marker=markers.get(model, "o"),
            markersize=7,
            linewidth=2.0,
            color=colors.get(model),
            label=model,
        )

        for xi, yi, ni in zip(x, y, n):
            if pd.notna(yi) and ni > 0:
                ax.text(
                    xi,
                    yi + 2.2,
                    f"{yi:.0f}%\nn={ni}",
                    ha="center",
                    va="bottom",
                    fontsize=7,
                    color=colors.get(model),
                    fontweight="bold",
                )

    ax.set_title(
        f"Predicted EDSS {edss_r}",
        fontsize=12,
        fontweight="bold",
        pad=10,
    )

    ax.set_xticks(x)
    ax.set_xticklabels(bracket_order, fontsize=8)

    ax.set_ylim(0, 112)
    ax.set_yticks(np.arange(0, 101, 10))
    ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)])

    ax.grid(True, axis="y", linestyle="--", alpha=0.3)
    ax.set_axisbelow(True)

    for spine in ["top", "right"]:
        ax.spines[spine].set_visible(False)

    ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold")

axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold")

# One shared figure legend, built from the first panel's line handles.
handles, labels = axes[0].get_legend_handles_labels()

fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.01),
    ncol=3,
    frameon=False,
)

fig.suptitle(
    "Confidence-stratified EDSS accuracy by predicted EDSS range",
    fontsize=14,
    fontweight="bold",
    y=1.02,
)

fig.text(
    0.5,
    0.045,
    "Each panel shows predictions within a predicted EDSS severity range. Point labels show accuracy and number of predictions.",
    ha="center",
    va="center",
    fontsize=9,
    color="#555555",
)

plt.tight_layout(rect=[0, 0.08, 1, 0.94])

plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight")
plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight")
plt.show()

print("\nSaved:")
print(OUTPUT_SVG)
print(OUTPUT_PNG)
print(OUTPUT_TABLE)
print(OUTPUT_LONG)

##

# %% Dot plot: confidence-stratified EDSS accuracy by predicted EDSS range

from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# =========================
# CONFIGURATION
# =========================

GT_PATH = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/"
    "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
)

RUN_DIR = Path(
    "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942"
)

TARGET_ITERATION = 1

OUTPUT_DIR = RUN_DIR / f"confidence_dotplot_by_edss_range_iter_{TARGET_ITERATION}"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_iter_{TARGET_ITERATION}.svg"
OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_iter_{TARGET_ITERATION}.png"
OUTPUT_TABLE = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_long_iter_{TARGET_ITERATION}.csv"

GT_EDSS_COL = "EDSS"
PRED_EDSS_COL = "EDSS_numeric"
PRED_EDSS_FALLBACK_COL = "EDSS"
CERTAINTY_COL = "certainty_percent"

# Hide very small cells from the plot but keep them in the CSV.
MIN_N_TO_PLOT = 5 plt.rcParams["font.family"] = "Arial" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def clean_model_name(name): replacements = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } return replacements.get(str(name), str(name)) def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def confidence_bracket(certainty): if pd.isna(certainty): return np.nan if certainty < 70: return "Low\n<70%" if certainty < 80: return "Moderate\n70–80%" if certainty < 90: return "High\n80–90%" if certainty <= 100: return "Very high\n90–100%" return np.nan def edss_range(value): if pd.isna(value): return np.nan if 0.0 <= value <= 3.5: return "0.0–3.5" if 4.0 <= value <= 5.5: return "4.0–5.5" if 6.0 <= value <= 10.0: return "6.0–10.0" return np.nan # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows with numeric EDSS: {len(gt)}") # ========================= # BUILD LONG DATA # ========================= long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not 
p.name.startswith("edss_error_distribution") and not p.name.startswith("edss_threshold_metrics") and not p.name.startswith("edss_severity_group_metrics") and not p.name.startswith("structured_output_validity") and not p.name.startswith("confidence") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print(f"Skipping {model_dir.name}: no row_index column.") continue if CERTAINTY_COL not in pred_raw.columns: print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.") continue model_name = get_model_name(pred_raw, model_dir) model_display = clean_model_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) if merged.empty: continue merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] merged["abs_error"] = merged["error"].abs() merged["exact_match"] = merged["abs_error"] == 0 merged["within_0_5"] = merged["abs_error"] <= 0.5 merged["within_1_0"] = merged["abs_error"] <= 1.0 
merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range) merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy() for _, row in merged.iterrows(): long_rows.append({ "model": model_name, "model_display": model_display, "iteration": TARGET_ITERATION, "row_index": row["row_index"], "GT_EDSS_numeric": row["GT_EDSS_numeric"], "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], "predicted_EDSS_range": row["predicted_EDSS_range"], "certainty_percent": row["certainty_numeric"], "confidence_bracket": row["confidence_bracket"], "error": row["error"], "abs_error": row["abs_error"], "exact_match": row["exact_match"], "within_0_5": row["within_0_5"], "within_1_0": row["within_1_0"], "result_file": str(result_file), }) long_df = pd.DataFrame(long_rows) if long_df.empty: raise ValueError("No evaluable rows found.") long_df.to_csv(OUTPUT_LONG, index=False) # ========================= # SUMMARY # ========================= bracket_order = [ "Low\n<70%", "Moderate\n70–80%", "High\n80–90%", "Very high\n90–100%", ] range_order = [ "0.0–3.5", "4.0–5.5", "6.0–10.0", ] model_order = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] model_order = [ m for m in model_order if m in long_df["model_display"].unique() ] summary = ( long_df .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"]) .agg( n=("within_0_5", "count"), exact_accuracy=("exact_match", "mean"), accuracy_within_0_5=("within_0_5", "mean"), accuracy_within_1_0=("within_1_0", "mean"), MAE=("abs_error", "mean"), median_absolute_error=("abs_error", "median"), mean_confidence=("certainty_percent", "mean"), ) .reset_index() ) full_index = pd.MultiIndex.from_product( [model_order, range_order, bracket_order], names=["model_display", "predicted_EDSS_range", "confidence_bracket"] ) summary = ( summary .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"]) 
.reindex(full_index) .reset_index() ) summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100 summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100 summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100 summary.to_csv(OUTPUT_TABLE, index=False) print("\nSummary:") print(summary) # ========================= # DOT PLOT # ========================= colors = { "GPT-OSS-120B": "#1F77B4", "Qwen3.6-27B": "#FF7F0E", "Gemma-4-31B-it": "#2CA02C", } markers = { "GPT-OSS-120B": "o", "Qwen3.6-27B": "s", "Gemma-4-31B-it": "^", } x_positions = { "Low\n<70%": 0, "Moderate\n70–80%": 1, "High\n80–90%": 2, "Very high\n90–100%": 3, } model_offsets = { "GPT-OSS-120B": -0.18, "Qwen3.6-27B": 0.00, "Gemma-4-31B-it": 0.18, } fig, axes = plt.subplots( nrows=1, ncols=len(range_order), figsize=(14, 5.5), sharey=True ) if len(range_order) == 1: axes = [axes] for ax, edss_r in zip(axes, range_order): df_r = summary[summary["predicted_EDSS_range"] == edss_r].copy() for model in model_order: df_m = df_r[df_r["model_display"] == model].copy() for _, row in df_m.iterrows(): n = row["n"] acc = row["accuracy_within_0_5_percent"] bracket = row["confidence_bracket"] if pd.isna(acc) or n < MIN_N_TO_PLOT: continue x = x_positions[bracket] + model_offsets.get(model, 0) ax.scatter( x, acc, s=45 + n * 2.0, color=colors[model], marker=markers[model], alpha=0.85, edgecolor="black", linewidth=0.6, label=model, ) ax.text( x, acc + 2.0, f"{acc:.0f}%\nn={int(n)}", ha="center", va="bottom", fontsize=7, color=colors[model], fontweight="bold", ) ax.set_title( f"Predicted EDSS {edss_r}", fontsize=12, fontweight="bold", pad=10, ) ax.set_xticks(list(x_positions.values())) ax.set_xticklabels(bracket_order, fontsize=8) ax.set_ylim(0, 112) ax.set_yticks(np.arange(0, 101, 10)) ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 10)]) ax.grid(True, axis="y", linestyle="--", alpha=0.3) ax.set_axisbelow(True) for spine in ["top", "right"]: 
ax.spines[spine].set_visible(False) ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold") axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold") handles = [ plt.Line2D( [0], [0], marker=markers[model], color="w", label=model, markerfacecolor=colors[model], markeredgecolor="black", markersize=8, ) for model in model_order ] fig.legend( handles=handles, loc="lower center", bbox_to_anchor=(0.5, -0.02), ncol=3, frameon=False, ) fig.suptitle( "Confidence-stratified EDSS accuracy by predicted EDSS range", fontsize=14, fontweight="bold", y=1.02, ) fig.text( 0.5, 0.045, f"Points show accuracy within ±0.5 EDSS. Point size reflects n. Cells with n < {MIN_N_TO_PLOT} are hidden.", ha="center", va="center", fontsize=9, color="#555555", ) plt.tight_layout(rect=[0, 0.08, 1, 0.94]) plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight") plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight") plt.show() print("\nSaved:") print(OUTPUT_SVG) print(OUTPUT_PNG) print(OUTPUT_TABLE) print(OUTPUT_LONG) ## # %% Clean dot plot: confidence-stratified EDSS accuracy by predicted EDSS range from pathlib import Path import pandas as pd import numpy as np import matplotlib.pyplot as plt # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"confidence_dotplot_by_edss_range_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_SVG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_iter_{TARGET_ITERATION}.svg" OUTPUT_PNG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_iter_{TARGET_ITERATION}.png" OUTPUT_TABLE = OUTPUT_DIR / 
f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_table_iter_{TARGET_ITERATION}.csv" OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_by_predicted_edss_range_dotplot_clean_long_iter_{TARGET_ITERATION}.csv" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" CERTAINTY_COL = "certainty_percent" # Hide very small cells from the plot but keep them in the CSV. MIN_N_TO_PLOT = 5 plt.rcParams["font.family"] = "Arial" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def clean_model_name(name): replacements = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } return replacements.get(str(name), str(name)) def find_iter_file(model_dir, iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def confidence_bracket(certainty): if pd.isna(certainty): return np.nan if certainty < 70: return "Low\n<70%" if certainty < 80: return "Moderate\n70–80%" if certainty < 90: return "High\n80–90%" if certainty <= 100: return "Very high\n90–100%" return np.nan def edss_range(value): if pd.isna(value): return np.nan if 0.0 <= value <= 3.5: return "0.0–3.5" if 4.0 <= value <= 5.5: return "4.0–5.5" if 6.0 <= value <= 10.0: return "6.0–10.0" return np.nan def size_from_n(n): """ Convert n to marker size. 
""" return 35 + (n * 4.0) # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows with numeric EDSS: {len(gt)}") # ========================= # BUILD LONG DATA # ========================= long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not p.name.startswith("edss_error_distribution") and not p.name.startswith("edss_threshold_metrics") and not p.name.startswith("edss_severity_group_metrics") and not p.name.startswith("structured_output_validity") and not p.name.startswith("confidence") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"No iteration {TARGET_ITERATION} result file found for {model_dir.name}") continue pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print(f"Skipping {model_dir.name}: no row_index column.") continue if CERTAINTY_COL not in pred_raw.columns: print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.") continue model_name = get_model_name(pred_raw, model_dir) model_display = clean_model_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() if "EDSS_is_numeric" in pred.columns: pred = pred[to_bool(pred["EDSS_is_numeric"])].copy() if "EDSS_in_valid_range" in pred.columns: pred = pred[to_bool(pred["EDSS_in_valid_range"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = 
to_num(pred[pred_col]) pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) pred = pred.dropna(subset=["PRED_EDSS_numeric", "certainty_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) if merged.empty: continue merged["error"] = merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"] merged["abs_error"] = merged["error"].abs() merged["exact_match"] = merged["abs_error"] == 0 merged["within_0_5"] = merged["abs_error"] <= 0.5 merged["within_1_0"] = merged["abs_error"] <= 1.0 merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range) merged = merged.dropna(subset=["confidence_bracket", "predicted_EDSS_range"]).copy() for _, row in merged.iterrows(): long_rows.append({ "model": model_name, "model_display": model_display, "iteration": TARGET_ITERATION, "row_index": row["row_index"], "GT_EDSS_numeric": row["GT_EDSS_numeric"], "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], "predicted_EDSS_range": row["predicted_EDSS_range"], "certainty_percent": row["certainty_numeric"], "confidence_bracket": row["confidence_bracket"], "error": row["error"], "abs_error": row["abs_error"], "exact_match": row["exact_match"], "within_0_5": row["within_0_5"], "within_1_0": row["within_1_0"], "result_file": str(result_file), }) long_df = pd.DataFrame(long_rows) if long_df.empty: raise ValueError("No evaluable rows found.") long_df.to_csv(OUTPUT_LONG, index=False) # ========================= # SUMMARY # ========================= bracket_order = [ "Low\n<70%", "Moderate\n70–80%", "High\n80–90%", "Very high\n90–100%", ] range_order = [ "0.0–3.5", "4.0–5.5", "6.0–10.0", ] model_order = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] model_order = [ m for m in model_order if m in long_df["model_display"].unique() ] summary = ( long_df .groupby(["model_display", 
"predicted_EDSS_range", "confidence_bracket"]) .agg( n=("within_0_5", "count"), exact_accuracy=("exact_match", "mean"), accuracy_within_0_5=("within_0_5", "mean"), accuracy_within_1_0=("within_1_0", "mean"), MAE=("abs_error", "mean"), median_absolute_error=("abs_error", "median"), mean_confidence=("certainty_percent", "mean"), ) .reset_index() ) full_index = pd.MultiIndex.from_product( [model_order, range_order, bracket_order], names=["model_display", "predicted_EDSS_range", "confidence_bracket"] ) summary = ( summary .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"]) .reindex(full_index) .reset_index() ) summary["exact_accuracy_percent"] = summary["exact_accuracy"] * 100 summary["accuracy_within_0_5_percent"] = summary["accuracy_within_0_5"] * 100 summary["accuracy_within_1_0_percent"] = summary["accuracy_within_1_0"] * 100 summary["shown_in_plot"] = summary["n"].fillna(0) >= MIN_N_TO_PLOT summary.to_csv(OUTPUT_TABLE, index=False) print("\nSummary:") print(summary) # ========================= # CLEAN DOT PLOT # ========================= colors = { "GPT-OSS-120B": "#1F77B4", "Qwen3.6-27B": "#FF7F0E", "Gemma-4-31B-it": "#2CA02C", } markers = { "GPT-OSS-120B": "o", "Qwen3.6-27B": "s", "Gemma-4-31B-it": "^", } x_positions = { "Low\n<70%": 0, "Moderate\n70–80%": 1, "High\n80–90%": 2, "Very high\n90–100%": 3, } model_offsets = { "GPT-OSS-120B": -0.18, "Qwen3.6-27B": 0.00, "Gemma-4-31B-it": 0.18, } fig, axes = plt.subplots( nrows=1, ncols=len(range_order), figsize=(14, 5.3), sharey=True ) if len(range_order) == 1: axes = [axes] for ax, edss_r in zip(axes, range_order): df_r = summary[ (summary["predicted_EDSS_range"] == edss_r) & (summary["shown_in_plot"]) ].copy() for model in model_order: df_m = df_r[df_r["model_display"] == model].copy() for _, row in df_m.iterrows(): n = int(row["n"]) acc = row["accuracy_within_0_5_percent"] bracket = row["confidence_bracket"] if pd.isna(acc): continue x = x_positions[bracket] + model_offsets.get(model, 0) 
ax.scatter( x, acc, s=size_from_n(n), color=colors[model], marker=markers[model], alpha=0.85, edgecolor="black", linewidth=0.6, ) ax.set_title( f"Predicted EDSS {edss_r}", fontsize=12, fontweight="bold", pad=10, ) ax.set_xticks(list(x_positions.values())) ax.set_xticklabels(bracket_order, fontsize=8) ax.set_ylim(0, 105) ax.set_yticks(np.arange(0, 101, 20)) ax.set_yticklabels([f"{y}%" for y in np.arange(0, 101, 20)]) ax.grid(True, axis="y", linestyle="--", alpha=0.3) ax.set_axisbelow(True) for spine in ["top", "right"]: ax.spines[spine].set_visible(False) ax.set_xlabel("LLM confidence bracket", fontsize=9, fontweight="bold") axes[0].set_ylabel("Predictions within ±0.5 EDSS (%)", fontsize=11, fontweight="bold") # ========================= # LEGENDS # ========================= model_handles = [ plt.Line2D( [0], [0], marker=markers[model], color="w", label=model, markerfacecolor=colors[model], markeredgecolor="black", markersize=8, ) for model in model_order ] fig.legend( handles=model_handles, loc="lower center", bbox_to_anchor=(0.43, -0.01), ncol=3, frameon=False, title="Model", ) size_values = [10, 50, 100, 200] max_n = int(summary["n"].fillna(0).max()) size_values = [n for n in size_values if n <= max_n] if size_values: size_handles = [ plt.scatter( [], [], s=size_from_n(n), color="lightgray", edgecolor="black", alpha=0.85, label=f"n={n}", ) for n in size_values ] fig.legend( handles=size_handles, loc="lower center", bbox_to_anchor=(0.78, -0.01), ncol=len(size_handles), frameon=False, title="Point size", ) fig.suptitle( "Confidence-stratified EDSS accuracy by predicted EDSS range", fontsize=14, fontweight="bold", y=1.02, ) fig.text( 0.5, 0.045, f"Points show accuracy within ±0.5 EDSS. Point size reflects n. 
Groups with n < {MIN_N_TO_PLOT} are omitted from the figure.", ha="center", va="center", fontsize=9, color="#555555", ) plt.tight_layout(rect=[0, 0.10, 1, 0.94]) plt.savefig(OUTPUT_SVG, format="svg", bbox_inches="tight") plt.savefig(OUTPUT_PNG, dpi=300, bbox_inches="tight") plt.show() print("\nSaved:") print(OUTPUT_SVG) print(OUTPUT_PNG) print(OUTPUT_TABLE) print(OUTPUT_LONG) ## # %% Confidence x predicted EDSS range table from pathlib import Path import pandas as pd import numpy as np # ========================= # CONFIGURATION # ========================= GT_PATH = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/data/processed/" "MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv" ) RUN_DIR = Path( "/home/shahin/Lab/Doktorarbeit/Barcelona/results/benchmark_runs/run_20260528_103942" ) TARGET_ITERATION = 1 OUTPUT_DIR = RUN_DIR / f"confidence_accuracy_table_iter_{TARGET_ITERATION}" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) OUTPUT_LONG = OUTPUT_DIR / f"confidence_accuracy_long_iter_{TARGET_ITERATION}.csv" OUTPUT_SUMMARY = OUTPUT_DIR / f"confidence_accuracy_summary_iter_{TARGET_ITERATION}.csv" OUTPUT_WIDE_CSV = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.csv" OUTPUT_WIDE_MD = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.md" OUTPUT_WIDE_XLSX = OUTPUT_DIR / f"confidence_accuracy_wide_table_iter_{TARGET_ITERATION}.xlsx" GT_EDSS_COL = "EDSS" PRED_EDSS_COL = "EDSS_numeric" PRED_EDSS_FALLBACK_COL = "EDSS" CERTAINTY_COL = "certainty_percent" # ========================= # HELPERS # ========================= def to_num(s): return pd.to_numeric( s.astype(str).str.replace(",", ".", regex=False), errors="coerce" ) def to_bool(s): return s.astype(str).str.lower().isin(["true", "1", "yes", "ja"]) def clean_model_name(name): replacements = { "gpt-oss-120b": "GPT-OSS-120B", "qwen3.6-27b": "Qwen3.6-27B", "gemma-4-31B-it": "Gemma-4-31B-it", } return replacements.get(str(name), str(name)) def find_iter_file(model_dir, 
iteration): files = sorted(model_dir.glob(f"*results_iter_{iteration}_*.csv")) files = [ f for f in files if "incremental" not in f.name.lower() and "summary" not in f.name.lower() and "all_results" not in f.name.lower() ] return files[0] if files else None def get_model_name(df, model_dir): if "model" in df.columns and df["model"].notna().any(): return str(df["model"].dropna().iloc[0]) return model_dir.name def confidence_bracket(certainty): if pd.isna(certainty): return np.nan if certainty < 70: return "Low (<70%)" if certainty < 80: return "Moderate (70–80%)" if certainty < 90: return "High (80–90%)" if certainty <= 100: return "Very high (90–100%)" return np.nan def edss_range_with_missing(value): if pd.isna(value): return "Missing EDSS" if 0.0 <= value <= 3.5: return "0.0–3.5" if 4.0 <= value <= 5.5: return "4.0–5.5" if 6.0 <= value <= 10.0: return "6.0–10.0" return "Invalid EDSS" def format_cell(acc, n): if pd.isna(n) or int(n) == 0: return "—" if pd.isna(acc): return f"NA (n={int(n)})" return f"{acc:.1f}% (n={int(n)})" # ========================= # LOAD GROUND TRUTH # ========================= gt = pd.read_csv(GT_PATH, sep=";") gt["row_index"] = gt.index gt["GT_EDSS_numeric"] = to_num(gt[GT_EDSS_COL]) gt = gt.dropna(subset=["GT_EDSS_numeric"]).copy() print(f"GT rows with numeric EDSS: {len(gt)}") # ========================= # BUILD LONG DATA # ========================= long_rows = [] model_dirs = [ p for p in sorted(RUN_DIR.iterdir()) if p.is_dir() and not p.name.startswith("confusion") and not p.name.startswith("functional_system") and not p.name.startswith("repeated_run") and not p.name.startswith("edss_error_distribution") and not p.name.startswith("edss_threshold_metrics") and not p.name.startswith("edss_severity_group_metrics") and not p.name.startswith("structured_output_validity") and not p.name.startswith("confidence") ] for model_dir in model_dirs: result_file = find_iter_file(model_dir, TARGET_ITERATION) if result_file is None: print(f"No iteration 
{TARGET_ITERATION} result file found for {model_dir.name}") continue pred_raw = pd.read_csv(result_file, sep=",") if "row_index" not in pred_raw.columns: print(f"Skipping {model_dir.name}: no row_index column.") continue if CERTAINTY_COL not in pred_raw.columns: print(f"Skipping {model_dir.name}: no {CERTAINTY_COL} column.") continue model_name = get_model_name(pred_raw, model_dir) model_display = clean_model_name(model_name) pred = pred_raw.copy() pred["row_index"] = pd.to_numeric(pred["row_index"], errors="coerce") pred = pred.dropna(subset=["row_index"]).copy() pred["row_index"] = pred["row_index"].astype(int) if "success" in pred.columns: pred = pred[to_bool(pred["success"])].copy() pred_col = PRED_EDSS_COL if PRED_EDSS_COL in pred.columns else PRED_EDSS_FALLBACK_COL pred["PRED_EDSS_numeric"] = to_num(pred[pred_col]) pred["certainty_numeric"] = to_num(pred[CERTAINTY_COL]) # Keep missing EDSS predictions, but require confidence. pred = pred.dropna(subset=["certainty_numeric"]).copy() pred = pred.drop_duplicates("row_index", keep="first").copy() merged = gt.merge( pred, on="row_index", how="inner", suffixes=("_gt", "_pred") ) if merged.empty: continue merged["has_numeric_prediction"] = merged["PRED_EDSS_numeric"].notna() merged["predicted_EDSS_range"] = merged["PRED_EDSS_numeric"].apply(edss_range_with_missing) merged["confidence_bracket"] = merged["certainty_numeric"].apply(confidence_bracket) merged = merged.dropna(subset=["confidence_bracket"]).copy() merged["abs_error"] = np.where( merged["has_numeric_prediction"], (merged["PRED_EDSS_numeric"] - merged["GT_EDSS_numeric"]).abs(), np.nan ) # Missing EDSS counts as not within ±0.5. 
merged["within_0_5"] = np.where( merged["has_numeric_prediction"], merged["abs_error"] <= 0.5, False ) for _, row in merged.iterrows(): long_rows.append({ "model": model_name, "model_display": model_display, "iteration": TARGET_ITERATION, "row_index": row["row_index"], "GT_EDSS_numeric": row["GT_EDSS_numeric"], "PRED_EDSS_numeric": row["PRED_EDSS_numeric"], "has_numeric_prediction": row["has_numeric_prediction"], "predicted_EDSS_range": row["predicted_EDSS_range"], "certainty_percent": row["certainty_numeric"], "confidence_bracket": row["confidence_bracket"], "abs_error": row["abs_error"], "within_0_5": row["within_0_5"], "result_file": str(result_file), }) long_df = pd.DataFrame(long_rows) if long_df.empty: raise ValueError("No evaluable rows found.") long_df.to_csv(OUTPUT_LONG, index=False) # ========================= # SUMMARY TABLE # ========================= confidence_order = [ "Low (<70%)", "Moderate (70–80%)", "High (80–90%)", "Very high (90–100%)", ] range_order = [ "Missing EDSS", "0.0–3.5", "4.0–5.5", "6.0–10.0", "Invalid EDSS", ] model_order = [ "GPT-OSS-120B", "Qwen3.6-27B", "Gemma-4-31B-it", ] model_order = [ m for m in model_order if m in long_df["model_display"].unique() ] range_order = [ r for r in range_order if r in long_df["predicted_EDSS_range"].unique() ] summary = ( long_df .groupby(["model_display", "predicted_EDSS_range", "confidence_bracket"]) .agg( n=("within_0_5", "count"), accuracy_within_0_5_percent=("within_0_5", lambda x: x.mean() * 100), n_numeric_predictions=("has_numeric_prediction", "sum"), mean_abs_error=("abs_error", "mean"), median_abs_error=("abs_error", "median"), ) .reset_index() ) full_index = pd.MultiIndex.from_product( [model_order, range_order, confidence_order], names=["model_display", "predicted_EDSS_range", "confidence_bracket"] ) summary = ( summary .set_index(["model_display", "predicted_EDSS_range", "confidence_bracket"]) .reindex(full_index) .reset_index() ) summary["n"] = summary["n"].fillna(0).astype(int) 
summary["n_numeric_predictions"] = (
    summary["n_numeric_predictions"].fillna(0).astype(int)
)
summary.to_csv(OUTPUT_SUMMARY, index=False)

# =========================
# WIDE TABLE FOR PAPER
# Rows: Model + Predicted EDSS range
# Columns: Confidence bracket
# =========================


def _cell_text(row):
    """Format one summary row's accuracy and count through format_cell."""
    return format_cell(row["accuracy_within_0_5_percent"], row["n"])


def _export_table(table, csv_path, md_path, xlsx_path):
    """Save ``table`` without its index as CSV, Markdown and Excel (in that order).

    The Markdown file ends with a newline. ``to_markdown`` needs the optional
    ``tabulate`` package and ``to_excel`` an Excel engine such as ``openpyxl``.
    """
    table.to_csv(csv_path, index=False)
    with open(md_path, "w", encoding="utf-8") as md_file:
        md_file.write(table.to_markdown(index=False))
        md_file.write("\n")
    table.to_excel(xlsx_path, index=False)


summary["cell"] = summary.apply(_cell_text, axis=1)

wide_row_index = pd.MultiIndex.from_product(
    [model_order, range_order],
    names=["model_display", "predicted_EDSS_range"],
)

wide = (
    summary.pivot_table(
        index=["model_display", "predicted_EDSS_range"],
        columns="confidence_bracket",
        values="cell",
        aggfunc="first",
    )
    .reindex(index=wide_row_index)
    .reindex(columns=confidence_order)
    .reset_index()
    .rename(columns={
        "model_display": "Model",
        "predicted_EDSS_range": "Predicted EDSS range",
    })
)

_export_table(wide, OUTPUT_WIDE_CSV, OUTPUT_WIDE_MD, OUTPUT_WIDE_XLSX)

# =========================
# PRINT OUTPUT
# =========================

pd.set_option("display.max_columns", None)
pd.set_option("display.width", 220)
pd.set_option("display.max_colwidth", None)

print("\nWide confidence accuracy table:")
print(wide.to_markdown(index=False))

print("\nSaved:")
for saved_path in (
    OUTPUT_LONG,
    OUTPUT_SUMMARY,
    OUTPUT_WIDE_CSV,
    OUTPUT_WIDE_MD,
    OUTPUT_WIDE_XLSX,
):
    print(saved_path)

##

# %% Alternative wide table: models as columns

# =========================
# ALTERNATIVE WIDE TABLE FOR PAPER
# Rows: Predicted EDSS range + Confidence bracket
# Columns: Models
# =========================
# Uses summary, format_cell, range_order, confidence_order, model_order,
# OUTPUT_DIR and TARGET_ITERATION from the previous cell.

summary["cell"] = summary.apply(_cell_text, axis=1)

model_columns_row_index = pd.MultiIndex.from_product(
    [range_order, confidence_order],
    names=["predicted_EDSS_range", "confidence_bracket"],
)

model_as_columns = (
    summary.pivot_table(
        index=["predicted_EDSS_range", "confidence_bracket"],
        columns="model_display",
        values="cell",
        aggfunc="first",
    )
    .reindex(index=model_columns_row_index)
    .reindex(columns=model_order)
    .reset_index()
    .rename(columns={
        "predicted_EDSS_range": "Predicted EDSS range",
        "confidence_bracket": "Confidence bracket",
    })
)

OUTPUT_MODEL_COLUMNS_CSV = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.csv"
OUTPUT_MODEL_COLUMNS_MD = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.md"
OUTPUT_MODEL_COLUMNS_XLSX = OUTPUT_DIR / f"confidence_accuracy_model_columns_table_iter_{TARGET_ITERATION}.xlsx"

_export_table(
    model_as_columns,
    OUTPUT_MODEL_COLUMNS_CSV,
    OUTPUT_MODEL_COLUMNS_MD,
    OUTPUT_MODEL_COLUMNS_XLSX,
)

print("\nModel-as-columns confidence accuracy table:")
print(model_as_columns.to_markdown(index=False))

print("\nSaved alternative table:")
for saved_path in (
    OUTPUT_MODEL_COLUMNS_CSV,
    OUTPUT_MODEL_COLUMNS_MD,
    OUTPUT_MODEL_COLUMNS_XLSX,
):
    print(saved_path)

##