# %% Scatter
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# Values may use a decimal comma ("3,5"): swap it for a dot, then parse.
# errors='coerce' turns anything unparseable into NaN.
for col in ['GT_EDSS', 'LLM_Results']:
    df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', '.', regex=False), errors='coerce')

# Keep only rows where both values parsed
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Scatter plot of ground truth vs. LLM output
plt.figure(figsize=(8, 6))
plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue')
plt.xlabel('GT_EDSS')
plt.ylabel('LLM_Results')
plt.title('Comparison of GT_EDSS vs LLM_Results')

# Reference diagonal (perfect prediction). It spans the larger of the two
# maxima so LLM values above the largest GT value still have a reference;
# skipped when there is no data (max() of an empty column would fail).
if not df_clean.empty:
    upper = max(df_clean['GT_EDSS'].max(), df_clean['LLM_Results'].max())
    plt.plot([0, upper], [0, upper], color='red', linestyle='--', label='Perfect Prediction')
    plt.legend()

plt.grid(True)
plt.tight_layout()
plt.show()


# %% Bland-Altman
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# Decimal comma -> dot, then parse; unparseable entries become NaN
for col in ['GT_EDSS', 'LLM_Results']:
    df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', '.', regex=False), errors='coerce')

# Keep only rows where both values parsed
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Bland-Altman (mean-difference) plot
f, ax = plt.subplots(1, figsize=(8, 5))
sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax)
ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results')
ax.set_xlabel('Mean of GT_EDSS and LLM_Results')
ax.set_ylabel('Difference between GT_EDSS and LLM_Results')
plt.tight_layout()
plt.show()

# Summary statistics of the differences (GT - LLM)
diff = df_clean['GT_EDSS'] - df_clean['LLM_Results']
mean_diff = np.mean(diff)
# NOTE(review): np.std defaults to ddof=0 (population SD); many Bland-Altman
# write-ups use the sample SD (ddof=1) instead. Confirm which is intended.
std_diff = np.std(diff)
print(f"Mean difference: {mean_diff:.3f}")
print(f"Standard deviation of differences: {std_diff:.3f}")
print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]")


# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Decimal comma -> dot, then parse; unparseable entries become NaN
for col in ['GT.EDSS', 'result.EDSS']:
    df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', '.', regex=False), errors='coerce')

# .copy() so the category columns added below go into an independent frame
# rather than a filtered view (avoids pandas' SettingWithCopyWarning)
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# One-point-wide EDSS bins shown in the matrix (0-10)
EDSS_LABELS = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']


def categorize_edss(value):
    """Map an EDSS score to a one-point-wide bin label.

    Bins are right-closed: values <= 1.0 (including negatives) -> '0-1',
    (1, 2] -> '1-2', ..., (9, 10] -> '9-10'; anything above 10 -> '10+'.
    NaN stays NaN.
    """
    if pd.isna(value):
        return np.nan
    for upper in range(1, 11):
        if value <= upper:
            return f'{upper - 1}-{upper}'
    return '10+'


# Categorical versions of both columns
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# '10+' is outside EDSS_LABELS, so confusion_matrix would drop those rows
# without notice; report how many are excluded.
n_outside = (~df_clean['GT.EDSS_cat'].isin(EDSS_LABELS) | ~df_clean['result.EDSS_cat'].isin(EDSS_LABELS)).sum()
if n_outside:
    print(f"Note: {n_outside} row(s) with EDSS > 10 are not shown in the confusion matrix.")

cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], labels=EDSS_LABELS)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS)
plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

print("Classification Report:")
# zero_division=0: bins that never get predicted score 0 without a warning flood
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], zero_division=0))
print("\nConfusion Matrix (Raw Counts):")
print(cm)


# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Decimal comma -> dot, then parse; unparseable entries become NaN
for col in ['GT.EDSS', 'result.EDSS']:
    df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', '.', regex=False), errors='coerce')

# Independent copy so the category columns can be added safely
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

EDSS_LABELS = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']


def categorize_edss(value):
    """Map an EDSS score to a one-point-wide bin label.

    Bins are right-closed: values <= 1.0 (including negatives) -> '0-1',
    (1, 2] -> '1-2', ..., (9, 10] -> '9-10'; anything above 10 -> '10+'.
    NaN stays NaN.
    """
    if pd.isna(value):
        return np.nan
    for upper in range(1, 11):
        if value <= upper:
            return f'{upper - 1}-{upper}'
    return '10+'


df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Rows in '10+' are not part of the 0-10 matrix; say so instead of dropping silently
n_outside = (~df_clean['GT.EDSS_cat'].isin(EDSS_LABELS) | ~df_clean['result.EDSS_cat'].isin(EDSS_LABELS)).sum()
if n_outside:
    print(f"Note: {n_outside} row(s) with EDSS > 10 are not shown in the confusion matrix.")

cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], labels=EDSS_LABELS)

# Plot confusion matrix with a labelled colour bar
plt.figure(figsize=(10, 8))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=EDSS_LABELS, yticklabels=EDSS_LABELS)
# seaborn attaches the colour bar to the heatmap's QuadMesh collection
cbar = ax.collections[0].colorbar
cbar.set_label('Number of Cases', rotation=270, labelpad=20)
plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'], zero_division=0))
print("\nConfusion Matrix (Raw Counts):")
print(cm)


# %% Classification
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Inspect the data structure
print("Data shape:", df.shape)
print("First few rows:")
print(df.head())
print("\nColumn names:")
for col in df.columns:
    print(f" {col}")


def safe_bool_convert(series):
    """Convert a column of mixed true/false spellings to a boolean Series.

    Recognises true/1/yes/y and false/0/no/n (case-insensitive, surrounding
    whitespace ignored), plus the float spellings '1.0'/'0.0' that pandas
    produces for 0/1 columns containing NaN. Anything unrecognised
    (including NaN) becomes False.
    """
    series_str = series.astype(str).str.strip().str.lower()
    bool_map = {
        'true': True, '1': True, '1.0': True, 'yes': True, 'y': True,
        'false': False, '0': False, '0.0': False, 'no': False, 'n': False,
    }
    # Unmapped values come back as NaN; .eq(True) turns them into False and
    # yields a proper bool dtype (no object-dtype fillna downcasting).
    return series_str.map(bool_map).eq(True)


# Convert both classification columns, showing before/after for each
for col in ['result.klassifizierbar', 'GT.klassifizierbar']:
    if col in df.columns:
        print(f"\n{col} column info:")
        print(df[col].head(10))
        print("Unique values:", df[col].unique())
        df[col] = safe_bool_convert(df[col])
        print("After conversion:")
        print(df[col].value_counts())

have_both = 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns

# Bar chart comparing the number of True values
if have_both:
    llm_true_count = int(df['result.klassifizierbar'].sum())
    gt_true_count = int(df['GT.klassifizierbar'].sum())

    fig, ax = plt.subplots(figsize=(8, 6))
    x = np.arange(2)
    width = 0.35
    # Bars are centred on their ticks so they line up with the 'LLM'/'GT'
    # tick labels and with the value annotations below.
    ax.bar(x[0], llm_true_count, width, label='LLM', color='skyblue', alpha=0.8)
    ax.bar(x[1], gt_true_count, width, label='GT', color='lightcoral', alpha=0.8)

    # Value labels just above each bar
    for xpos, count in zip(x, [llm_true_count, gt_true_count]):
        ax.annotate(f'{count}', xy=(xpos, count), xytext=(0, 3),
                    textcoords="offset points", ha='center', va='bottom')

    ax.set_xlabel('Classification Status (klassifizierbar)')
    ax.set_ylabel('Count')
    ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"')
    ax.set_xticks(x)
    ax.set_xticklabels(['LLM', 'GT'])
    ax.legend()
    plt.tight_layout()
    plt.show()

# Confusion matrix of GT vs LLM
if have_both:
    try:
        llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool)
        gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool)

        # labels pins the 2x2 [False, True] layout even when one class is
        # absent, so it always matches the tick labels below.
        cm = confusion_matrix(gt_bool, llm_bool, labels=[False, True])

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['False ', 'True '],
                    yticklabels=['False', 'True '], ax=ax)
        ax.set_xlabel('LLM Predictions ')
        ax.set_ylabel('GT Labels ')
        ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"')
        plt.tight_layout()
        plt.show()

        print("Confusion Matrix:")
        print(cm)
    except Exception as e:
        print(f"Error creating confusion matrix: {e}")

# Final data info. DataFrame.info() prints itself and returns None,
# so it is called directly rather than wrapped in print().
if have_both:
    print("\nFinal DataFrame info:")
    df[['result.klassifizierbar', 'GT.klassifizierbar']].info()


# %% Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load the data from the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/results/join_results_unique.tsv'
df = pd.read_csv(file_path, sep='\t')

# Decimal comma -> dot, then parse; unparseable entries become NaN
for col in ['GT.EDSS', 'result.EDSS']:
    df[col] = pd.to_numeric(df[col].astype(str).str.replace(',', '.', regex=False), errors='coerce')

# Independent copy so the category column can be added safely
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# 1. DEFINE CATEGORY ORDER
# Keeps the x-axis numerically ordered (0-1 before 1-2, ...)
category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+']


def categorize_edss(value):
    """Map an EDSS score to a one-point-wide bin label.

    Defined here as well so this cell runs on its own. Bins are
    right-closed: values <= 1.0 -> '0-1', (1, 2] -> '1-2', ...,
    (9, 10] -> '9-10'; anything above 10 -> '10+'. NaN stays NaN.
    """
    if pd.isna(value):
        return np.nan
    for upper in range(1, 11):
        if value <= upper:
            return f'{upper - 1}-{upper}'
    return '10+'


# Ordered categorical, so seaborn draws the boxes in category_order
df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss),
                                         categories=category_order, ordered=True)

plt.figure(figsize=(14, 8))

# 2. ADD HUE FOR LEGEND
# Using the x variable as hue lets seaborn build a legend automatically
sns.boxplot(
    data=df_clean,
    x='GT.EDSS_cat',
    y='result.EDSS',
    hue='GT.EDSS_cat',
    palette='viridis',
    linewidth=1.5,
    legend=True,
)

# 3. CUSTOMIZE PLOT
plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20)
plt.xlabel('Ground Truth EDSS Category', fontsize=14)
plt.ylabel('LLM Predicted EDSS', fontsize=14)
# Legend outside the axes on the right
plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.show()


# %% Postprocessing column names
import pandas as pd

# Read the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# German -> English ground-truth column names
column_mapping = {
    'EDSS': 'GT.EDSS',
    'klassifizierbar': 'GT.klassifizierbar',
    'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS',
    'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS',
    'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS',
    'Sensibiliät': 'GT.SENSORY_FUNCTIONS',
    'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS',
    'Ambulation': 'GT.AMBULATION',
    'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS',
    'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS',
}

# Record which renames actually apply *before* renaming; afterwards the old
# names are gone from df.columns and can no longer be checked.
applied_renames = {old: new for old, new in column_mapping.items() if old in df.columns}
df = df.rename(columns=column_mapping)

# NOTE: overwrites the input file in place
df.to_csv(file_path, sep='\t', index=False)

print("Columns have been successfully renamed!")
print("Renamed columns:")
for old_name, new_name in applied_renames.items():
    print(f" {old_name} -> {new_name}")


# %% Styled table
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import dataframe_image as dfi

# Load data
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')

# 1. Identify all GT and result columns
gt_columns = [col for col in df.columns if col.startswith('GT.')]
result_columns = [col for col in df.columns if col.startswith('result.')]
print("GT Columns found:", gt_columns)
print("Result Columns found:", result_columns)


def _alnum_key(name):
    """Lower-cased name with everything except letters, digits, '_' and ' ' removed."""
    return ''.join(e for e in name if e.isalnum() or e in ['_', ' ']).lower()


# 2. Map each GT column to its result column, tolerating different spellings
column_mapping = {}
for gt_col in gt_columns:
    base_name = gt_col.replace('GT.', '')
    # Candidate spellings, tried in order
    candidates = [
        f'result.{base_name}',                      # exact
        f'result.{base_name.replace(" ", "_")}',    # spaces -> underscores
        f'result.{base_name.replace("_", " ")}',    # underscores -> spaces
        f'result.{base_name.replace(" ", "")}',     # no spaces
        f'result.{base_name.replace("_", "")}',     # no underscores
        f'result.{base_name.lower()}',
        f'result.{base_name.upper()}',
    ]
    match = next((c for c in candidates if c in result_columns), None)
    if match is None:
        # Fallback: compare after stripping special characters, ignoring case
        base_key = _alnum_key(base_name)
        match = next((rc for rc in result_columns
                      if _alnum_key(rc.replace('result.', '')) == base_key), None)
    if match is not None:
        column_mapping[gt_col] = match

print("Column mapping:", column_mapping)

# 3. Vectorised agreement computation
data_list = []
for gt_col, result_col in column_mapping.items():
    print(f"Processing {gt_col} vs {result_col}")

    # Same decimal-comma handling as the other cells, so values such as
    # "2,5" are compared instead of silently becoming NaN.
    s1 = pd.to_numeric(df[gt_col].astype(str).str.replace(',', '.', regex=False), errors='coerce').astype(float)
    s2 = pd.to_numeric(df[result_col].astype(str).str.replace(',', '.', regex=False), errors='coerce').astype(float)

    # A match is an absolute difference of at most 0.5
    diff = np.abs(s1 - s2)
    matches = (diff <= 0.5).sum()
    # Denominator: rows where both values are numeric
    valid_count = diff.notna().sum()
    percentage = (matches / valid_count) * 100 if valid_count > 0 else 0

    data_list.append({
        'GT': gt_col.replace('GT.', ''),
        'Match %': round(percentage, 1),
    })

# 4. Prepare data: readable labels, best agreement first
match_df = pd.DataFrame(data_list)
match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title()
match_df = match_df.sort_values('Match %', ascending=False)


# 5. Render the table as a seaborn heatmap
def create_luxury_table(df, output_file="edss_agreement.png"):
    """Render a one-column agreement table as a colour-coded heatmap and save it.

    Args:
        df: DataFrame with a 'GT' label column and a 'Match %' column.
        output_file: Path of the raster image to write (300 dpi).

    Returns:
        The matplotlib Figure, so callers can save it in further formats.

    Note: calls ``sns.set_theme``, which changes global plotting style.
    """
    sns.set_theme(style="white", font="sans-serif")

    plot_data = df.set_index('GT')[['Match %']]

    # Figure height grows with the number of rows
    fig, ax = plt.subplots(figsize=(8, len(df) * 0.6))

    # Diverging palette: deep red -> mustard -> emerald
    cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True)

    sns.heatmap(
        plot_data,
        annot=True,
        fmt=".1f",
        cmap=cmap,
        center=85,          # colour transition point
        vmin=50, vmax=100,  # gradient range
        linewidths=2,
        linecolor='white',
        cbar=False,         # no colour bar, for a "table" look
        annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"},
        ax=ax,
    )

    # Style the axes so the heatmap reads like a table
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.xaxis.tick_top()
    ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50')
    ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0)

    # Thin light border around the plot
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_color('#ecf0f1')

    ax.set_title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40,
                 fontweight='bold', color='#2c3e50')

    # Subtle footer
    fig.text(0.5, 0.0, "Tolerance: ±0.5 points", wrap=True, horizontalalignment='center',
             fontsize=10, color='gray', style='italic')

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Beautiful table saved as {output_file}")
    return fig


fig = create_luxury_table(match_df)

# 6. Save the same figure as SVG
fig.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
print("Successfully saved agreement_table.svg")

# Show plot if running in a GUI environment
plt.show()


# %% Time Plot
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Load the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Inference times, NaNs removed
inference_times = df['inference_time_sec'].dropna()

# Summary statistics (Series.std uses the sample SD, ddof=1)
mean_time = inference_times.mean()
std_time = inference_times.std()
median_time = np.median(inference_times)

fig, ax = plt.subplots(figsize=(10, 6))

# Histogram with 1-second-wide bins covering the full data range
min_time = int(inference_times.min())
max_time = int(inference_times.max()) + 1
bins = np.arange(min_time, max_time + 1, 1)
n, bins, patches = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7,
                           edgecolor='black', linewidth=0.5)

# Gaussian with the sample mean/SD, scaled from density to counts
# (density * number of samples * bin width) so it overlays the histogram
x = np.linspace(inference_times.min(), inference_times.max(), 100)
bin_width = bins[1] - bins[0]
gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * bin_width
ax.plot(x, gaussian_counts, color='red', linewidth=2,
        label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)')

# Mean and median markers
ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s')
ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s')

# ±1 standard deviation markers
ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7,
           label=f'+1σ = {mean_time + std_time:.1f}s')
ax.axvline(mean_time - std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7,
           label=f'-1σ = {mean_time - std_time:.1f}s')

ax.set_xlabel('Inference Time (seconds)')
ax.set_ylabel('Frequency')
ax.set_title('Inference Time Distribution with Gaussian Fit')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


# %% Dashboard
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
from matplotlib.gridspec import GridSpec


def to_numeric_comma(s: pd.Series) -> pd.Series:
    """Parse numbers written with a decimal dot or comma ('1.5' / '1,5'); unparseable -> NaN."""
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")


# Load the data
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Strip the 'result.' prefix and replace spaces with underscores in those columns
column_mapping = {
    col: col.replace('result.', '').replace(' ', '_')
    for col in df.columns
    if col.startswith('result.')
}
df = df.rename(columns=column_mapping)

# Parse visit dates; unparseable values become NaT.
# NOTE(review): no dayfirst/format is given — confirm the date format in MedDatum.
df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce')

# Select one patient, ordered by visit date
patient_id = 'd13e4aa3'
patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy()
if patient_data.empty:
    raise ValueError(f"No data found for patient: {patient_id}")

# EDSS plus the functional systems to plot (column, panel title)
edss_col, edss_title = ('GT.EDSS', 'EDSS')
functional_systems = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'),
    ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'),
    ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'),
    ('GT.SENSORY_FUNCTIONS', 'Sensory'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'),
    ('GT.AMBULATION', 'Ambulation'),
    ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'),
]

# Per-column y-axis maxima
ymax_by_col = {
    'GT.PYRAMIDAL_FUNCTIONS': 6,
    'GT.SENSORY_FUNCTIONS': 6,
    'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6,
    'GT.VISUAL_OPTIC_FUNCTIONS': 6,
    'GT.CEREBELLAR_FUNCTIONS': 5,
    'GT.CEREBRAL_FUNCTIONS': 5,
    'GT.BRAINSTEM_FUNCTIONS': 5,
    'GT.EDSS': 10,
}
default_ymax = 6 # ---------- Build shared "event dates" ticks ---------- cols_for_dates = [edss_col] + [c for c, _ in functional_systems] event_dates = [] for c in cols_for_dates: if c in patient_data.columns: y = to_numeric_comma(patient_data[c]) # <-- changed x = patient_data['MedDatum'] tmp = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]) event_dates.extend(tmp["x"].tolist()) event_dates = sorted(pd.Series(event_dates).drop_duplicates().tolist()) max_ticks = 8 if len(event_dates) > max_ticks: idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) event_dates = [event_dates[i] for i in idx] # ---------- A4 figure ---------- fig = plt.figure(figsize=(11.69, 8.27)) gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) def style_time_axis(ax, show_labels=True): ax.set_xticks(event_dates) ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) if not show_labels: ax.tick_params(labelbottom=False) # ---------- EDSS main plot ---------- ax_main = fig.add_subplot(gs[0, :]) if edss_col in patient_data.columns: y = to_numeric_comma(patient_data[edss_col]) # <-- changed x = patient_data['MedDatum'] plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") ax_main.set_title(edss_title, fontsize=14, fontweight='bold') ax_main.set_ylabel("Score") ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) ax_main.grid(True, alpha=0.3) if not plot_df.empty: ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') else: ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') else: ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') ax_main.set_ylim(0, ymax_by_col.get(edss_col, 10)) ax_main.grid(True, alpha=0.3) style_time_axis(ax_main) # ---------- Small aligned plots ---------- small_axes = [] for k, (col, title) in enumerate(functional_systems): r 
= 1 + (k // 4) c = (k % 4) ax = fig.add_subplot(gs[r, c], sharex=ax_main) small_axes.append(ax) ymax = ymax_by_col.get(col, default_ymax) ax.set_title(title, fontsize=10) ax.set_ylabel("Score") ax.set_ylim(0, ymax) ax.grid(True, alpha=0.3) if col in patient_data.columns: y = to_numeric_comma(patient_data[col]) # <-- changed x = patient_data['MedDatum'] plot_df = pd.DataFrame({"x": x, "y": y}).dropna(subset=["x", "y"]).sort_values("x") if not plot_df.empty: ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, color='tab:blue') else: ax.set_title(f"{title} (no data)", fontsize=10) else: ax.set_title(f"{title} (missing)", fontsize=10) style_time_axis(ax) # Hide x tick labels on first row of small plots for ax in small_axes[:4]: ax.tick_params(labelbottom=False) plt.tight_layout() fig.subplots_adjust(hspace=0.7) plt.show() ## <<<<<<< Updated upstream ======= # %% Dashboard Angepasst import pandas as pd import matplotlib.pyplot as plt import matplotlib.dates as mdates import numpy as np from matplotlib.gridspec import GridSpec def to_numeric_comma(s: pd.Series) -> pd.Series: # accepts 1.5 and 1,5 return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce") # Load the data file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' df = pd.read_csv(file_path, sep='\t') # Rename columns to remove 'result.' 
prefix and replace spaces column_mapping = {} for col in df.columns: if col.startswith('result.'): new_name = col.replace('result.', '').replace(' ', '_') column_mapping[col] = new_name df = df.rename(columns=column_mapping) # Parse MedDatum safely df['MedDatum'] = pd.to_datetime(df['MedDatum'], errors='coerce') # Patient patient_id = '3d942c60' patient_data = df[df['unique_id'] == patient_id].sort_values('MedDatum').copy() if patient_data.empty: raise ValueError(f"No data found for patient: {patient_id}") # Functional systems + EDSS edss_col, edss_title = ('GT.EDSS', 'EDSS') functional_systems = [ ('GT.VISUAL_OPTIC_FUNCTIONS', 'Visual / Optic'), ('GT.CEREBELLAR_FUNCTIONS', 'Cerebellar'), ('GT.BRAINSTEM_FUNCTIONS', 'Brainstem'), ('GT.SENSORY_FUNCTIONS', 'Sensory'), ('GT.PYRAMIDAL_FUNCTIONS', 'Pyramidal (Motor)'), ('GT.AMBULATION', 'Ambulation'), ('GT.CEREBRAL_FUNCTIONS', 'Cerebral'), ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'Bowel & Bladder'), ] # y-axis max rules ymax_by_col = { 'GT.PYRAMIDAL_FUNCTIONS': 6, 'GT.SENSORY_FUNCTIONS': 6, 'GT.BOWEL_AND_BLADDER_FUNCTIONS': 6, 'GT.VISUAL_OPTIC_FUNCTIONS': 6, 'GT.CEREBELLAR_FUNCTIONS': 5, 'GT.CEREBRAL_FUNCTIONS': 5, 'GT.BRAINSTEM_FUNCTIONS': 5, 'GT.EDSS': 10, } default_ymax = 6 # ---------- Build shared visit dates ticks ---------- # Use ALL patient visit dates, not only dates with valid numeric values event_dates = sorted(patient_data['MedDatum'].dropna().drop_duplicates().tolist()) max_ticks = 8 if len(event_dates) > max_ticks: idx = np.linspace(0, len(event_dates) - 1, max_ticks, dtype=int) event_dates = [event_dates[i] for i in idx] # Base timeline for plotting: one row per patient visit date timeline = ( patient_data[['MedDatum']] .dropna() .drop_duplicates() .sort_values('MedDatum') .rename(columns={'MedDatum': 'x'}) ) # ---------- A4 figure ---------- fig = plt.figure(figsize=(11.69, 8.27)) gs = GridSpec(nrows=3, ncols=4, figure=fig, height_ratios=[2.0, 1.0, 1.0], hspace=0.5, wspace=0.35) def style_time_axis(ax, 
show_labels=True): ax.set_xticks(event_dates) ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) ax.tick_params(axis='x', rotation=30, labelsize=8, pad=2) if not show_labels: ax.tick_params(labelbottom=False) def get_plot_df(patient_data, col): """ Keep all visit dates. Missing values stay NaN so matplotlib draws gaps instead of zeros. """ tmp = patient_data[['MedDatum', col]].copy() tmp = tmp.rename(columns={'MedDatum': 'x', col: 'raw_y'}) tmp['y'] = to_numeric_comma(tmp['raw_y']) # aggregate if multiple rows exist on same date tmp = tmp.groupby('x', as_index=False)['y'].max() # merge onto full timeline so all dates remain visible plot_df = timeline.merge(tmp, on='x', how='left').sort_values('x') return plot_df # ---------- EDSS main plot ---------- ax_main = fig.add_subplot(gs[0, :]) ax_main.set_title(edss_title, fontsize=14, fontweight='bold') ax_main.set_ylabel("Score") ax_main.set_ylim(0, ymax_by_col.get(edss_col, default_ymax)) ax_main.grid(True, alpha=0.3) if edss_col in patient_data.columns: plot_df = get_plot_df(patient_data, edss_col) if plot_df['y'].notna().any(): # NaNs create visible gaps in the line ax_main.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=3, color='tab:red') else: ax_main.set_title("EDSS (no numeric data)", fontsize=14, fontweight='bold') else: ax_main.set_title("EDSS (missing column GT.EDSS)", fontsize=14, fontweight='bold') style_time_axis(ax_main) # ---------- Small aligned plots ---------- small_axes = [] for k, (col, title) in enumerate(functional_systems): r = 1 + (k // 4) c = (k % 4) ax = fig.add_subplot(gs[r, c], sharex=ax_main) small_axes.append(ax) ymax = ymax_by_col.get(col, default_ymax) ax.set_title(title, fontsize=10) ax.set_ylabel("Score") ax.set_ylim(0, ymax) ax.grid(True, alpha=0.3) if col in patient_data.columns: plot_df = get_plot_df(patient_data, col) if plot_df['y'].notna().any(): # NaNs remain in y -> line breaks where data is missing ax.plot(plot_df["x"], plot_df["y"], marker='o', linewidth=2, 
color='tab:blue') else: ax.set_title(f"{title} (no numeric data)", fontsize=10) else: ax.set_title(f"{title} (missing)", fontsize=10) style_time_axis(ax) # Hide x tick labels on first row of small plots for ax in small_axes[:4]: ax.tick_params(labelbottom=False) plt.tight_layout() fig.subplots_adjust(hspace=0.7) plt.show() ## >>>>>>> Stashed changes # %% Table import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from datetime import datetime import numpy as np # Load the data file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv' df = pd.read_csv(file_path, sep='\t') # Convert MedDatum to datetime df['MedDatum'] = pd.to_datetime(df['MedDatum']) # Check what columns actually exist in the dataset print("Available columns:") print(df.columns.tolist()) print("\nFirst few rows:") print(df.head()) # Check data types print("\nData types:") print(df.dtypes) # Hardcode specific patient names patient_names = ['6ccda8c6'] # Define the functional systems (columns to plot) functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual'] # Create subplots num_plots = len(functional_systems) num_cols = 2 num_rows = (num_plots + num_cols - 1) // num_cols fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False) if num_plots == 1: axes = [axes] elif num_rows == 1: axes = axes else: axes = axes.flatten() # Plot for the hardcoded patient for i, system in enumerate(functional_systems): # Filter data for this specific patient patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum') # Check if patient data exists if patient_data.empty: print(f"No data found for patient: {patient_names[0]}") axes[i].set_title(f'Functional System: {system} (No data)') axes[i].set_ylabel('Score') continue # Check if the system column exists if system in patient_data.columns: # Plot only valid data (non-null values) valid_data = 
patient_data.dropna(subset=[system]) if not valid_data.empty: # Ensure MedDatum is properly formatted for plotting axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system) axes[i].set_ylabel('Score') axes[i].set_title(f'Functional System: {system}') axes[i].grid(True, alpha=0.3) axes[i].legend() else: axes[i].set_title(f'Functional System: {system} (No valid data)') axes[i].set_ylabel('Score') else: # Try to find similar column names found_column = None for col in df.columns: if system.lower() in col.lower(): found_column = col break if found_column: valid_data = patient_data.dropna(subset=[found_column]) if not valid_data.empty: axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column) axes[i].set_ylabel('Score') axes[i].set_title(f'Functional System: {system} (found as: {found_column})') axes[i].grid(True, alpha=0.3) axes[i].legend() else: axes[i].set_title(f'Functional System: {system} (No valid data)') axes[i].set_ylabel('Score') else: axes[i].set_title(f'Functional System: {system} (Column not found)') axes[i].set_ylabel('Score') # Hide empty subplots for i in range(len(functional_systems), len(axes)): axes[i].set_visible(False) # Set x-axis label for the last row only for i in range(len(functional_systems)): if i >= len(axes) - num_cols: # Last row axes[i].set_xlabel('Date') # Format x-axis dates for ax in axes: if ax.get_lines(): # Only format if there are lines to plot ax.tick_params(axis='x', rotation=45) ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d')) # Automatically adjust layout plt.tight_layout() plt.show() ## # %% Histogram Fig1 import pandas as pd import matplotlib.pyplot as plt import matplotlib.font_manager as fm import json import os def create_visit_frequency_plot( file_path, output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data', output_filename='visit_frequency_distribution.svg', fontsize=10, color_scheme_path='colors.json' 
): """ Creates a publication-ready bar chart of patient visit frequency. Args: file_path (str): Path to the input TSV file. output_dir (str): Directory to save the output SVG file. output_filename (str): Name of the output SVG file. fontsize (int): Font size for all text elements (labels, title). color_scheme_path (str): Path to the JSON file containing the color palette. """ # --- 1. Load Data and Color Scheme --- try: df = pd.read_csv(file_path, sep='\t') print("Data loaded successfully.") # Sort data for easier visual comparison df = df.sort_values(by='Visits Count') except FileNotFoundError: print(f"Error: The file was not found at {file_path}") return try: with open(color_scheme_path, 'r') as f: colors = json.load(f) # Select a blue from the sequential palette for the bars bar_color = colors['sequential']['blues'][-2] # A saturated blue except FileNotFoundError: print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.") bar_color = '#2171b5' # A common matplotlib blue # --- 2. Set up the Plot with Scientific Style --- plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height # Set the font to Arial arial_font = fm.FontProperties(family='Arial', size=fontsize) plt.rcParams['font.family'] = 'Arial' plt.rcParams['font.size'] = fontsize # --- 3. Create the Bar Chart --- ax = plt.gca() bars = plt.bar( x=df['Visits Count'], height=df['Unique Patients'], color=bar_color, edgecolor='black', linewidth=0.5, # Minimum line thickness width=0.7 ) # --- NEW: Explicitly set x-ticks and labels to ensure all are shown --- # Get the unique visit counts to use as tick labels visit_counts = df['Visits Count'].unique() # Set the x-ticks to be at the center of each bar ax.set_xticks(visit_counts) # Set the x-tick labels to be the visit counts, using the specified font ax.set_xticklabels(visit_counts, fontproperties=arial_font) # --- END OF NEW SECTION --- # --- 4. 
Customize Axes and Layout (Nature style) --- # Display only left and bottom axes ax.spines['top'].set_visible(False) ax.spines['right'].set_visible(False) # Turn off axis ticks (the marks, not the labels) plt.tick_params(axis='both', which='both', length=0) # Remove grid lines plt.grid(False) # Set background to white (no shading) ax.set_facecolor('white') plt.gcf().set_facecolor('white') # --- 5. Add Labels and Title --- plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10) plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10) plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20) # --- 6. Add y-axis values on top of each bar --- # This adds the count of unique patients directly above each bar. ax.bar_label(bars, fmt='%d', padding=3) # --- 7. Export the Figure --- # Ensure the output directory exists os.makedirs(output_dir, exist_ok=True) full_output_path = os.path.join(output_dir, output_filename) plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight') print(f"\nFigure saved as '{full_output_path}'") # --- 8. 
# (Optional) Display the Plot ---
# plt.show()


# --- Main execution ---
if __name__ == '__main__':
    # Visit-frequency table; create_visit_frequency_plot reads its
    # 'Visits Count' and 'Unique Patients' columns.
    input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'

    # Render and save the figure with 10 pt text, as the figure guidelines require.
    plot_options = {'file_path': input_file, 'fontsize': 10}
    create_visit_frequency_plot(**plot_options)

##
# %% Scatter Plot functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os

# --- Configuration ---
# Arial for every text element of the figure, per the guidelines.
plt.rcParams['font.family'] = 'Arial'

# Tab-separated table holding the ground-truth ("GT.*") and LLM ("result.*") columns.
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# JSON file that receives the system -> colour mapping built below.
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Destination of the scatter figure produced by this cell.
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'

# --- 1. Load the Dataset ---
try:
    df = pd.read_csv(data_path, sep='\t')
except FileNotFoundError:
    # Report the missing file, then let the exception stop the cell.
    print(f"Error: The file at {data_path} was not found.")
    raise
else:
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")

# --- 2.
# Define Functional Systems and Create Color Mapping ---
# Each entry pairs a ground-truth column with the matching LLM result column.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]

# Short system name = text after the "GT." prefix, e.g. "VISUAL_OPTIC_FUNCTIONS".
system_names = [gt_col.split('.')[1] for gt_col, _res_col in functional_systems_to_plot]

# One distinct colour per system: blues first, then warm accents.
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC',  # Purple
]

# system name -> hex colour, paired in list order
color_map = {system: hex_color for system, hex_color in zip(system_names, colors)}

# Persist the mapping so other figures can reuse the same colours.
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")

# --- 3.
# Calculate Agreement Percentages and Format Legend Labels ---
# (the same pass also performs "--- 4. Reshape Data for Plotting ---")


def _to_float_series(series):
    """Return *series* as floats, accepting comma decimals ("3,5" -> 3.5).

    Surrounding whitespace is ignored; anything that still cannot be parsed
    becomes NaN. This mirrors the comma handling that ``safe_parse`` applies
    for the error plots further down, so all figures use the same rows.
    """
    as_text = series.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(as_text, errors='coerce')


agreement_percentages = {}
legend_labels = {}
plot_data = []

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    # Parse both columns once and keep only rows where both values are numeric;
    # the same rows feed the agreement score and the scatter points.
    paired = pd.DataFrame({
        'system': system_name,
        'ground_truth': _to_float_series(df[gt_col]),
        'inference': _to_float_series(df[res_col]),
    }).dropna()

    # Agreement = percentage of exact matches (0 when there are no valid pairs).
    if len(paired) > 0:
        agreement = (paired['ground_truth'] == paired['inference']).mean() * 100
    else:
        agreement = 0
    agreement_percentages[system_name] = agreement

    # "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions (xx.x%)"
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"

    plot_data.append(paired)

# Long-format frame: one row per (system, ground_truth, inference) point.
plot_df = pd.concat(plot_data, ignore_index=True)

if plot_df.empty:
    print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
else:
    print(f"Prepared plot data with {len(plot_df)} data points.")

# --- 5.
# Create the Scatter Plot ---
fig, ax = plt.subplots(figsize=(10, 8))

# One scatter series per system, coloured from color_map and labelled with
# its agreement percentage.
for system_key, points in plot_df.groupby('system'):
    ax.scatter(
        points['ground_truth'],
        points['inference'],
        label=legend_labels[system_key],
        color=color_map[system_key],
        alpha=0.7,
        s=30,
    )

# Dashed y = x reference (perfect agreement) spanning the ground-truth range.
if not plot_df.empty:
    gt_lo = plot_df['ground_truth'].min()
    gt_hi = plot_df['ground_truth'].max()
    ax.plot([gt_lo, gt_hi], [gt_lo, gt_hi],
            color='black', linestyle='--', linewidth=0.8, alpha=0.7)

# --- 6. Apply Styling and Labels ---
ax.set_xlabel('Ground Truth', fontsize=12)
ax.set_ylabel('LLM Inference', fontsize=12)
ax.set_title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)

# Minimal styling: no top/right spines, no tick marks, no grid.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)
ax.grid(False)
ax.legend(title='Functional System', frameon=False, fontsize=10)

# --- 7. Save and Display the Figure ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")
plt.show()

##
# %% Confusion Matrix functional systems
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np
import matplotlib.colors as mcolors

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'

# --- 1. Load the Dataset ---
df = pd.read_csv(data_path, sep='\t')

# --- 2.
# Define Functional Systems and Colors ---
# (ground-truth column, LLM result column) per functional system
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]
system_names = [gt_col.split('.')[1] for gt_col, _ in functional_systems_to_plot]
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
color_map = dict(zip(system_names, colors))

# --- 3. Categorization Function ---
# Ten unit-width bins: '0-1', '1-2', ..., '9-10'
categories = [f'{lo}-{lo + 1}' for lo in range(10)]
category_to_index = {label: idx for idx, label in enumerate(categories)}
n_categories = len(categories)


def categorize_edss(value):
    """Map a score to one of the ten unit-width bins in ``categories``.

    Each bin covers (k, k + 1], with a 0.001 tolerance at the upper edge, so
    1.0 -> '0-1' and 1.5 -> '1-2'. Zero and negative values land in '0-1';
    values above 10 are clipped into '9-10'. Missing values (None/NaN) yield
    ``np.nan``; non-numeric strings raise ``ValueError`` from ``float()``.
    """
    if pd.isna(value):
        return np.nan
    val = float(value)  # also accepts numeric strings and numpy scalars
    if not val > 0:
        return categories[0]
    # Subtracting 0.001 pushes each integer upper bound into the bin below it.
    bin_idx = int(min(val, 10) - 0.001)
    return categories[min(bin_idx, len(categories) - 1)]

# --- 4.
# Prepare Mixed Color Matrix with Saturation ---


def _to_float_series(series):
    """Return *series* as floats, accepting comma decimals ("3,5" -> 3.5).

    Surrounding whitespace is ignored; anything that still cannot be parsed
    becomes NaN (same convention as ``safe_parse`` used by the error plots).
    """
    as_text = series.astype(str).str.replace(',', '.', regex=False).str.strip()
    return pd.to_numeric(as_text, errors='coerce')


# cell_system_counts[g, r, s] = number of rows of system s whose ground truth
# falls in category g and whose LLM result falls in category r.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))

for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Comma-aware numeric conversion; plain pd.to_numeric would turn values
    # such as "3,5" into NaN and silently drop them from the matrix.
    temp_df = df[[gt_col, res_col]].copy()
    temp_df[gt_col] = _to_float_series(temp_df[gt_col])
    temp_df[res_col] = _to_float_series(temp_df[res_col])
    valid_df = temp_df.dropna()

    for gt_value, res_value in zip(valid_df[gt_col], valid_df[res_col]):
        gt_cat = categorize_edss(gt_value)
        res_cat = categorize_edss(res_value)
        if gt_cat in category_to_index and res_cat in category_to_index:
            cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1

# RGBA image (n_categories x n_categories x 4): RGB = count-weighted mix of the
# system colours present in the cell, alpha = relative density.
rgba_matrix = np.zeros((n_categories, n_categories, 4))
total_counts = np.sum(cell_system_counts, axis=2)
# Avoid dividing by zero when no pair could be categorised at all.
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1

# RGB triple of every system, in system_names order (converted once, not per cell).
system_rgbs = np.array([mcolors.to_rgb(color_map[s_name]) for s_name in system_names])

for i in range(n_categories):
    for j in range(n_categories):
        count_sum = total_counts[i, j]
        if count_sum > 0:
            # Each system contributes in proportion to its share of the cell.
            weights = cell_system_counts[i, j] / count_sum
            rgba_matrix[i, j, :3] = weights @ system_rgbs
            # Square-root scaling keeps sparse cells visible but lighter;
            # the 0.1 floor keeps them from disappearing entirely.
            alpha = np.sqrt(count_sum / max_count)
            rgba_matrix[i, j, 3] = max(0.1, alpha)
        else:
            # Empty cells: white with alpha 0, i.e. fully transparent.
            rgba_matrix[i, j] = [1, 1, 1, 0]

# --- 5.
# Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))

# origin='upper' puts the first category ('0-1') in the top row, the usual
# confusion-matrix orientation (origin='lower' would move it to the bottom).
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')

# Annotate every non-empty cell with its total count (row-major order).
for row_idx, col_idx in np.argwhere(total_counts > 0):
    cell_rgb = rgba_matrix[row_idx, col_idx, :3]
    cell_alpha = rgba_matrix[row_idx, col_idx, 3]
    # Relative luminance of the cell colour (Rec. 709 weights).
    lum = 0.2126 * cell_rgb[0] + 0.7152 * cell_rgb[1] + 0.0722 * cell_rgb[2]
    # White text only on dark AND mostly opaque cells; faint cells look
    # almost white, so black text stays readable there.
    text_col = "white" if (lum < 0.5 and cell_alpha > 0.5) else "black"
    ax.text(col_idx, row_idx, int(total_counts[row_idx, col_idx]),
            ha="center", va="center", color=text_col, fontsize=10, fontweight='bold')

# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)

tick_positions = np.arange(n_categories)
ax.set_xticks(tick_positions)
ax.set_xticklabels(categories)
ax.set_yticks(tick_positions)
ax.set_yticklabels(categories)

# Frameless look: hide all four spines.
for spine in ax.spines.values():
    spine.set_visible(False)

# Legend: one colour swatch per functional system.
handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems',
          loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)

plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()

##
# %% Difference Plot Functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np

# --- Configuration ---
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'

# Define the path to your data file
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# Colour mapping JSON (same systems and colours as the scatter cell)
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Output of THIS cell (the diverging error-direction plot). It must differ from
# the scatter cell's 'edss_functional_systems_comparison.svg'; sharing that
# name made running this cell overwrite the scatter figure.
figure_save_path = 'project/visuals/edss_functional_systems_error_direction.svg'

# --- 1. Load the Dataset ---
try:
    df = pd.read_csv(data_path, sep='\t')
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    # Stop the cell: nothing below can run without the data
    raise

# --- 2. Define Functional Systems and Create Color Mapping ---
# List of tuples containing (ground_truth_column, result_column)
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]

# Extract system names for color mapping and legend
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]

# Qualitative palette with distinct, accessible colors (dark blue theme)
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC',  # Purple
]

# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))

# Save the color map to a JSON file (creating its directory if needed)
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")

# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# NOTE(review): agreement_percentages / legend_labels are not used by the plots
# visible in this cell; kept because later code may rely on them.
agreement_percentages = {}
legend_labels = {}

for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]

    # Convert columns to numeric, setting errors to NaN
    gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
    res_numeric = pd.to_numeric(df[res_col], errors='coerce')

    # Compare only rows where both values are numeric
    common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
    gt_data = gt_numeric.loc[common_index]
    res_data = res_numeric.loc[common_index]

    # Percentage of exact matches (0 when there is no valid data)
    if len(gt_data) > 0:
        agreement = (gt_data == res_data).mean() * 100
    else:
        agreement = 0
    agreement_percentages[system_name] = agreement

    # "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions"
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"

# --- 4.
# Robustly Prepare Error Data for Boxplot ---
def safe_parse(s):
    """Return *s* as a float, accepting comma decimals ('3,5' -> 3.5).

    Missing values give ``np.nan``; Python ints/floats are converted directly;
    any other value is stringified, comma-to-dot normalised and stripped, and
    text that still is not a number also gives ``np.nan``.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float)):
        return float(s)
    normalized = str(s).replace(',', '.').strip()
    try:
        parsed = float(normalized)
    except ValueError:
        parsed = np.nan
    return parsed


# One frame per system with the signed error (result - ground truth);
# rows where either side failed to parse are dropped.
plot_data = [
    pd.DataFrame({
        'system': gt_col.split('.')[1],
        'error': df[res_col].apply(safe_parse) - df[gt_col].apply(safe_parse),
    }).dropna()
    for gt_col, res_col in functional_systems_to_plot
]
plot_df = pd.concat(plot_data, ignore_index=True)

if not plot_df.empty:
    print(f"✅ Prepared error data with {len(plot_df)} data points.")
    # Diagnostic: per-system summary of the signed errors
    print("\n📌 Sample errors by system:")
    for sys_label, grp in plot_df.groupby('system'):
        errs = grp['error']
        print(f" {sys_label:25s}: n={len(grp)}, mean err = {errs.mean():+.2f}, min = {errs.min():+.2f}, max = {errs.max():+.2f}")
else:
    print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")

# Fix the system order to that of functional_systems_to_plot.
plot_df['system'] = pd.Categorical(
    plot_df['system'],
    categories=[gt_col.split('.')[1] for gt_col, _ in functional_systems_to_plot],
    ordered=True,
)

# --- 5.
# Prepare Data for Diverging Stacked Bar Plot ---
print("\n📊 Preparing diverging stacked bar plot data...")


def categorize_error(err):
    """Return the direction of a signed error (result - ground truth).

    'underestimate' for err < 0, 'overestimate' for err > 0, 'match' for 0,
    and 'missing' for None/NaN.
    """
    if pd.isna(err):
        return 'missing'
    elif err < 0:
        return 'underestimate'
    elif err > 0:
        return 'overestimate'
    else:
        return 'match'


# Add category column (only on finite errors)
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)

# Rows: systems; columns: underestimate / match / overestimate counts.
category_counts = (
    plot_df_clean
    .groupby(['system', 'category'], observed=False)
    .size()
    .unstack(fill_value=0)
    .reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0)
)
# Put systems in plotting order. A system without any valid rows comes back as
# NaN from reindex, which would break bar widths and labels -> count it as 0.
category_counts = category_counts.reindex(system_names).fillna(0).astype(int)

underestimate_counts = category_counts['underestimate']
match_counts = category_counts['match']
overestimate_counts = category_counts['overestimate']

# Diverging layout: underestimates extend to the LEFT of zero, overestimates
# to the RIGHT, exact matches are a thin bar centred on zero.
left_counts = underestimate_counts
right_counts = overestimate_counts

# Symmetric x-range; half of each centred match bar sticks out on either side,
# so it has to be included as well.
max_bar = max(left_counts.max(), right_counts.max(), match_counts.max() / 2, 1)
plot_range = (-max_bar, max_bar)

# One row per system; all three segments of a system share that row so they
# line up with the system's y tick label.
n_systems = len(system_names)
positions = np.arange(n_systems)

# --- 6. Create Diverging Stacked Bar Plot ---
plt.figure(figsize=(12, 7))

# Colors: diverging palette
colors = {
    'underestimate': '#E74C3C',  # Red (left)
    'match': '#2ECC71',          # Green (centre)
    'overestimate': '#F39C12',   # Orange (right)
}

# Underestimates: negative widths so the bars grow leftwards from zero.
bars_left = plt.barh(
    positions, -left_counts.values, height=0.8, left=0,
    color=colors['underestimate'], edgecolor='black', linewidth=0.5, alpha=0.9,
    label='Underestimate',
)

# Overestimates: positive widths, growing rightwards from zero.
bars_right = plt.barh(
    positions, right_counts.values, height=0.8, left=0,
    color=colors['overestimate'], edgecolor='black', linewidth=0.5, alpha=0.9,
    label='Overestimate',
)

# Exact matches: one thin bar per system spanning -count/2 .. +count/2,
# drawn above the directional bars.
for pos, count in zip(positions, match_counts.values):
    if count > 0:
        plt.barh(
            pos, width=count, left=-count / 2, height=0.25,
            color=colors['match'], edgecolor='black', linewidth=0.5, alpha=0.95,
            zorder=3,
        )

# --- 7. Styling & Labels ---
# Zero reference line
plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8)

# X-axis: symmetric around 0, with a margin for the count labels
plt.xlim(plot_range[0] - max_bar * 0.15, plot_range[1] + max_bar * 0.15)
plt.xticks(rotation=0, fontsize=10)
plt.xlabel('Count', fontsize=12)

# Y-axis: one label per system; "_" becomes a line break and the word "AND"
# becomes "&" (the names are upper-case, so a lower-case 'and' never matches).
ytick_labels = [
    '\n'.join('&' if word == 'AND' else word for word in name.split('_'))
    for name in system_names
]
plt.yticks(positions, ytick_labels, fontsize=10)
plt.ylabel('Functional System', fontsize=12)

# Title & layout
plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15)

# Clean axes
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)  # We only need bottom axis
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
# Underestimates are drawn at negative x but are counts: show |x| on the ticks.
ax.xaxis.set_major_formatter(lambda x, _pos: f"{abs(x):g}")

# Grid only along x
ax.xaxis.grid(True, linestyle=':', alpha=0.5)

# Legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'),
    Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'),
    Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate'),
]
plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)

# Counts: under/over labels sit just outside whichever bar reaches further on
# their side (the centred match bar can be longer) and take the colour of
# their segment; they are on the white background, so white text would vanish.
for pos, left, right, match in zip(positions, left_counts, right_counts, match_counts):
    if left > 0:
        plt.text(-max(left, match / 2) - max_bar * 0.05, pos, str(left),
                 va='center', ha='right', fontsize=9,
                 color=colors['underestimate'], fontweight='bold')
    if right > 0:
        plt.text(max(right, match / 2) + max_bar * 0.05, pos, str(right),
                 va='center', ha='left', fontsize=9,
                 color=colors['overestimate'], fontweight='bold')
    if match > 0:
        # Centre of the match bar, which is centred on zero.
        plt.text(0, pos, str(match), va='center', ha='center',
                 fontsize=8, color='black', zorder=4)

plt.tight_layout()

# --- 8. Save & Show ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"✅ Diverging bar plot saved to {figure_save_path}")
plt.show()

##
# %% Difference Plot Gemini
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_magnitude_focus.svg'

# --- 1. Process Error Data with Magnitude Breakdown ---
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
plot_list = []

for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing (comma decimals allowed); keep rows where both sides parsed.
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    error = (res - gt).dropna()

    # Magnitude buckets are ranges, not exact values, so fractional errors
    # (e.g. -0.5 or +1.5) are counted too and
    # Matches + UnderTotal + OverTotal equals the number of valid pairs:
    #   "-1" / "+1"   buckets: 0 < |error| < 2
    #   "-2+" / "+2+" buckets: |error| >= 2
    matches = (error == 0).sum()
    u_1 = ((error < 0) & (error > -2)).sum()
    u_2plus = (error <= -2).sum()
    o_1 = ((error > 0) & (error < 2)).sum()
    o_2plus = (error >= 2).sum()

    total = error.count()
    divisor = max(total, 1)  # avoid division by zero for empty systems

    plot_list.append({
        'System': sys_name.replace('_', ' ').title(),
        'Matches': matches,
        'MatchPct': (matches / divisor) * 100,
        'U1': u_1,
        'U2': u_2plus,
        'UnderTotal': u_1 + u_2plus,
        'UnderPct': ((u_1 + u_2plus) / divisor) * 100,
        'O1': o_1,
        'O2': o_2plus,
        'OverTotal': o_1 + o_2plus,
        'OverPct': ((o_1 + o_2plus) / divisor) * 100,
    })

stats_df = pd.DataFrame(plot_list)

# --- 2.
# Plotting ---
fig, ax = plt.subplots(figsize=(13, 8))

# Define Magnitude Colors (darker shade = disagreement of 2 points or more)
c_under_dark, c_under_light = '#C0392B', '#E74C3C'  # Dark Red (-2+), Soft Red (-1)
c_over_dark, c_over_light = '#2980B9', '#3498DB'    # Dark Blue (+2+), Soft Blue (+1)

bar_height = 0.6
y_pos = np.arange(len(stats_df))

# Under-scored, left of zero (stacked: -2+ next to zero, then -1 outside it)
ax.barh(y_pos, -stats_df['U2'], bar_height, color=c_under_dark, label='Under -2+', edgecolor='white')
ax.barh(y_pos, -stats_df['U1'], bar_height, left=-stats_df['U2'], color=c_under_light, label='Under -1', edgecolor='white')

# Over-scored, right of zero (stacked: +1 next to zero, then +2+ outside it)
ax.barh(y_pos, stats_df['O1'], bar_height, color=c_over_light, label='Over +1', edgecolor='white')
ax.barh(y_pos, stats_df['O2'], bar_height, left=stats_df['O1'], color=c_over_dark, label='Over +2+', edgecolor='white')

# --- 3. Aesthetics & Table Labels ---
# Per-system summary "table" to the left of the plotting area. The x position
# is read once; ax.text does not change the data limits.
label_x = ax.get_xlim()[0] - 0.5
for i, row in stats_df.iterrows():
    # Mathtext ignores plain spaces ("Visual Optic Functions" would render as
    # "VisualOpticFunctions"), so escape them as "\ " inside \mathbf{}.
    bold_name = row['System'].replace(' ', r'\ ')
    label_text = (
        f"$\\mathbf{{{bold_name}}}$\n"
        f"Match: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
        f"Under: {int(row['UnderTotal'])} ({row['UnderPct']:.1f}%) | Over: {int(row['OverTotal'])} ({row['OverPct']:.1f}%)"
    )
    ax.text(label_x, i, label_text, va='center', ha='right',
            fontsize=9, color='#333333', linespacing=1.4)

# Formatting
ax.axvline(0, color='black', linewidth=1.2)
ax.set_yticks([])
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold')
#ax.set_title('Directional Error Magnitude (Under vs. Over Scoring)', fontsize=14, pad=35)

# Absolute x tick labels: bars left of zero are counts too, so print |x|.
# A formatter (rather than set_xticklabels without fixed ticks) avoids
# Matplotlib's FixedFormatter warning and does not truncate fractional ticks.
ax.xaxis.set_major_formatter(lambda x, _pos: f"{abs(x):g}")

# Remove spines and add grid
for spine in ['top', 'right', 'left']:
    ax.spines[spine].set_visible(False)
ax.xaxis.grid(True, linestyle='--', alpha=0.3)

# Legend with magnitude info
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1), ncol=2)

plt.tight_layout()
# Save to the figure_save_path configured for this cell (previously unused).
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()

##
# %% Functional System Error Boxplots
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_boxplot.svg'

# --- 1. Build error data for boxplots ---
boxplot_data = []
system_labels = []
sample_sizes = []

for gt_col, res_col in functional_systems_to_plot:
    sys_name = gt_col.split('.')[1]

    # Robust parsing (comma decimals allowed)
    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Error = result - ground truth, on rows where both sides parsed
    error = (res - gt).dropna()

    # Only disagreements are plotted: exact matches (error 0) are dropped
    error = error[error != 0]

    # Keep only systems that actually have non-zero errors
    if len(error) > 0:
        clean_name = sys_name.replace('_', ' ').title()
        boxplot_data.append(error.values)
        system_labels.append(clean_name)
        sample_sizes.append(len(error))

# Safety check
if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system.")

# Put n into x-axis labels so it doesn't overlap the plot
xtick_labels = [f"{label}\n(n={n})" for label, n in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(14, 8))
# Vertical boxes are the default orientation.
bp = ax.boxplot(
    boxplot_data,
    patch_artist=True,
    showmeans=True,
    meanline=False,
)
# Tick labels are set after boxplot() instead of via `labels=`, which
# Matplotlib 3.9 deprecated (renamed `tick_labels`); this works on all versions.
ax.set_xticklabels(xtick_labels)

# --- 3.
# Styling ---
# Palette for the boxplot elements; the legend section below reuses it.
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

# Artist group of the boxplot result -> properties applied to every artist in it.
boxplot_element_props = {
    'boxes': dict(facecolor=box_face, edgecolor=box_edge, linewidth=1.5),
    'whiskers': dict(color=whisker_col, linewidth=1.2),
    'caps': dict(color=whisker_col, linewidth=1.2),
    'medians': dict(color=median_col, linewidth=2),
    'means': dict(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6),
    'fliers': dict(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge, alpha=0.6, markersize=4),
}
for element_name, props in boxplot_element_props.items():
    for artist in bp[element_name]:
        artist.set(**props)

# Dashed zero line: above it the result over-scores, below it under-scores.
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Slanted x labels keep the multi-line system names readable.
plt.xticks(rotation=45, ha='right')

# Light horizontal grid; no top/right spines.
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# --- 4.
# Legend above the plot, outside the axes ---
# Proxy artists that describe each boxplot element (not drawn in the axes).
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference'),
]

# Anchor the legend's lower centre just above the top edge of the axes.
ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

# Keep the top 10% of the figure free for the legend.
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()

##
# %% Functional System + EDSS Error Boxplots
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration & Theme ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_boxplot.svg'

# Reuses functional_systems_to_plot from the cells above, whose entries look like
#   ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS')
# (GT columns use underscores, result columns use spaces), and appends the
# overall EDSS pair.
all_systems_to_plot = [*functional_systems_to_plot, ('GT.EDSS', 'result.EDSS')]

# --- 1.
# --- 1. Build error data for boxplots ---
# One entry per system with at least one non-zero error; the three lists stay index-aligned.
boxplot_data, system_labels, sample_sizes = [], [], []

for gt_col, res_col in all_systems_to_plot:
    columns_present = gt_col in df.columns and res_col in df.columns
    if not columns_present:
        print(f"Skipping missing columns: {gt_col}, {res_col}")
        continue

    sys_name = gt_col.split('.')[1]

    # Signed error (result - ground truth) on parsed values; unparseable rows drop out.
    signed_error = (df[res_col].apply(safe_parse) - df[gt_col].apply(safe_parse)).dropna()

    # Exact matches are excluded: only disagreements are plotted.
    signed_error = signed_error[signed_error != 0]
    if signed_error.empty:
        continue

    # str.title() would turn 'EDSS' into 'Edss', so keep the acronym as-is.
    clean_name = 'EDSS' if sys_name == 'EDSS' else sys_name.replace('_', ' ').title()

    boxplot_data.append(signed_error.values)
    system_labels.append(clean_name)
    sample_sizes.append(len(signed_error))

if not boxplot_data:
    raise ValueError("No valid non-zero error data available for any functional system or EDSS.")

# Sample size goes into the tick label (second line) instead of onto the plot area.
xtick_labels = [f"{name}\n(n={count})" for name, count in zip(system_labels, sample_sizes)]

# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(15, 8))

boxplot_options = dict(
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False,
)
bp = ax.boxplot(boxplot_data, **boxplot_options)
# --- 3. Styling ---
# Palette shared by the box artists below and the legend proxies that follow.
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

# (artist group in the ax.boxplot result, properties to apply); groups are styled in order.
styling_plan = [
    ('boxes', {'facecolor': box_face, 'edgecolor': box_edge, 'linewidth': 1.5}),
    ('whiskers', {'color': whisker_col, 'linewidth': 1.2}),
    ('caps', {'color': whisker_col, 'linewidth': 1.2}),
    ('medians', {'color': median_col, 'linewidth': 2}),
    ('means', {'marker': 'o', 'markerfacecolor': mean_col,
               'markeredgecolor': 'black', 'markersize': 6}),
    ('fliers', {'marker': 'o', 'markerfacecolor': flier_face,
                'markeredgecolor': flier_edge, 'alpha': 0.6, 'markersize': 4}),
]
for group_key, props in styling_plan:
    for artist in bp[group_key]:
        artist.set(**props)

# Dashed horizontal line where result == ground truth
ax.axhline(0, color='black', linewidth=1.2, linestyle='--')

# Axis titles
ax.set_xlabel('Functional System / EDSS', fontsize=11, fontweight='bold')
ax.set_ylabel('Error (Result - Ground Truth)', fontsize=11, fontweight='bold')

# Tilt category labels so long system names do not collide
plt.xticks(rotation=45, ha='right')

# Light horizontal grid; hide the top and right frame lines
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
# --- 4. Legend above the plot, outside the axes ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR (25th-75th percentile)'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Zero error reference'),
]

ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.02),
    ncol=3,
    frameon=False,
)

# Leave room at the top for the legend
plt.tight_layout(rect=[0, 0, 1, 0.90])

# Optional save. os.makedirs('') raises, so only create a directory when the path has one.
save_dir = os.path.dirname(figure_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')

plt.show()
##


# %% test
# Diagnose: what are the actual differences?
# For every (GT, result) column pair, report:
# - how many comparable rows disagree;
# - how many of those disagreements are only floating-point noise;
# - how many rows cannot be compared because a value is missing or unparseable.
print("\n🔍 Raw differences per system:")
for gt_col, res_col in functional_systems_to_plot:
    if gt_col not in df.columns or res_col not in df.columns:
        print(f"Skipping missing columns: {gt_col}, {res_col}")
        continue

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)
    diff = res - gt
    abs_diff = diff.abs()

    # NaN != 0 evaluates to True, so (diff != 0) would count missing rows as
    # disagreements. NaN > 0 is False, so this counts only real differences.
    non_zero = (abs_diff > 0).sum()
    missing = diff.isna().sum()
    # Check if it's due to floating point noise
    tiny = (abs_diff > 0) & (abs_diff < 1e-10)
    print(
        f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, "
        f"missing = {missing:3d}, max abs diff = {abs_diff.max():.12f}"
    )
##


# %% Functional System Continuous Accuracy Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_continuous_accuracy_boxplot.svg'

# --- Functional systems using your actual column names ---
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
]


# --- Robust parser ---
def safe_parse(s):
    """Convert a cell value to float, handling comma decimals like '3,5'.

    Numbers pass through as float. NaN/None, empty strings and unparseable
    text all become NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Build accuracy data ---
# Parallel lists, one entry per plotted system.
boxplot_data = []
system_labels = []
predicted_counts = []
missing_prediction_counts = []
total_gt_counts = []
mean_accuracies = []

for gt_col, res_col in functional_systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Only rows where ground truth exists
    gt_exists = gt.notna()
    total_gt = gt_exists.sum()
    if total_gt == 0:
        print(f"Skipping {system_name}: no ground-truth values")
        continue

    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]

    # GT exists, but LLM prediction is missing
    missing_count = res_valid.isna().sum()

    # For the boxplot, use rows where both GT and result exist
    both_exist = res_valid.notna()
    if both_exist.sum() == 0:
        print(f"Skipping {system_name}: no predicted values")
        continue

    gt_eval = gt_valid[both_exist]
    res_eval = res_valid[both_exist]

    # Functional system score range.
    # Adjust if your functional systems use another scale.
    # NOTE(review): a single range of 5 is assumed for every system here;
    # confirm it against the scoring scale in the data. Errors larger than the
    # range are clipped to 0 accuracy below.
    score_range = 5

    # Continuous accuracy:
    # exact match = 1.0
    # off by 1 point = 0.8
    # off by 2 points = 0.6
    # etc.
    abs_error = (res_eval - gt_eval).abs()
    accuracy = 1 - (abs_error / score_range)
    accuracy = accuracy.clip(lower=0, upper=1)

    clean_name = system_name.replace('_', ' ').title()
    mean_accuracy = accuracy.mean()

    boxplot_data.append(accuracy.values)
    system_labels.append(clean_name)
    predicted_counts.append(len(gt_eval))
    missing_prediction_counts.append(missing_count)
    total_gt_counts.append(total_gt)
    mean_accuracies.append(mean_accuracy)

    print(
        f"{clean_name}: "
        f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
        f"mean accuracy={mean_accuracy:.1%}"
    )

if not boxplot_data:
    raise ValueError("No valid accuracy data available for plotting.")

# X-axis labels
xtick_labels = [
    f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
    for label, gt_n, pred_n, miss_n in zip(
        system_labels, total_gt_counts, predicted_counts, missing_prediction_counts
    )
]

# --- Plot ---
fig, ax = plt.subplots(figsize=(16, 8))

bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False,
    widths=0.55,
)

# --- Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)
for whisker in bp['whiskers']:
    whisker.set(color=whisker_col, linewidth=1.2)
for cap in bp['caps']:
    cap.set(color=whisker_col, linewidth=1.2)
for median in bp['medians']:
    median.set(color=median_col, linewidth=2)
for mean in bp['means']:
    mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6)
for flier in bp['fliers']:
    flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge,
              alpha=0.6, markersize=4)

# Mean accuracy label above each box (boxes sit at x = 1..n)
for i, acc in enumerate(mean_accuracies, start=1):
    ax.text(i, 1.03, f"{acc:.1%}", ha='center', va='bottom', fontsize=9, fontweight='bold')

# Perfect accuracy reference line
ax.axhline(1, color='black', linewidth=1.2, linestyle='--', alpha=0.7)

# Labels and formatting
ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
ax.set_ylabel('Continuous Accuracy', fontsize=11, fontweight='bold')
ax.set_ylim(-0.05, 1.10)
yticks = np.linspace(0, 1, 11)
ax.set_yticks(yticks)
# Round to a whole percent; int(y * 100) truncates and would turn 0.2999... into 29.
ax.set_yticklabels([f"{y:.0%}" for y in yticks])
plt.xticks(rotation=45, ha='right')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# Legend
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR of continuous accuracy'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Perfect accuracy'),
]
ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.06),
    ncol=5,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.88])

save_dir = os.path.dirname(figure_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##


# %% Functional Systems + EDSS Continuous Accuracy Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_continuous_accuracy_boxplot.svg'

# --- Functional systems + EDSS using your actual column names ---
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
    # EDSS
    ('GT.EDSS', 'result.EDSS'),
]


# --- Robust parser ---
def safe_parse(s):
    """Convert a cell value to float, handling comma decimals like '3,5'.

    Numbers pass through as float. NaN/None, empty strings and unparseable
    text all become NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Build accuracy data ---
# Parallel lists, one entry per plotted system.
boxplot_data = []
system_labels = []
predicted_counts = []
missing_prediction_counts = []
total_gt_counts = []
mean_accuracies = []

for gt_col, res_col in functional_systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Only rows where ground truth exists
    gt_exists = gt.notna()
    total_gt = gt_exists.sum()
    if total_gt == 0:
        print(f"Skipping {system_name}: no ground-truth values")
        continue

    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]

    # Count cases where GT exists but LLM prediction is missing
    missing_count = res_valid.isna().sum()

    # For the boxplot, use only rows where both GT and prediction exist
    both_exist = res_valid.notna()
    if both_exist.sum() == 0:
        print(f"Skipping {system_name}: no predicted values")
        continue

    gt_eval = gt_valid[both_exist]
    res_eval = res_valid[both_exist]

    # Functional systems are usually scored 0-5.
    # EDSS is usually scored 0-10.
    if system_name == "EDSS":
        score_range = 10
        clean_name = "EDSS"
    else:
        score_range = 5
        clean_name = system_name.replace('_', ' ').title()

    # Continuous accuracy:
    # exact match = 1.0
    # off by 1 point in FS = 0.8
    # off by 1 point in EDSS = 0.9
    abs_error = (res_eval - gt_eval).abs()
    accuracy = 1 - (abs_error / score_range)

    # Keep values between 0 and 1
    accuracy = accuracy.clip(lower=0, upper=1)
    mean_accuracy = accuracy.mean()

    boxplot_data.append(accuracy.values)
    system_labels.append(clean_name)
    predicted_counts.append(len(gt_eval))
    missing_prediction_counts.append(missing_count)
    total_gt_counts.append(total_gt)
    mean_accuracies.append(mean_accuracy)

    print(
        f"{clean_name}: "
        f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
        f"mean accuracy={mean_accuracy:.1%}"
    )

if not boxplot_data:
    raise ValueError("No valid accuracy data available for plotting.")

# --- X-axis labels ---
xtick_labels = [
    f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
    for label, gt_n, pred_n, miss_n in zip(
        system_labels, total_gt_counts, predicted_counts, missing_prediction_counts
    )
]

# --- Plot ---
fig, ax = plt.subplots(figsize=(17, 8))

bp = ax.boxplot(
    boxplot_data,
    vert=True,
    patch_artist=True,
    labels=xtick_labels,
    showmeans=True,
    meanline=False,
    widths=0.55,
)

# --- Styling ---
box_face = '#D6EAF8'
box_edge = '#2980B9'
whisker_col = '#7F8C8D'
median_col = '#C0392B'
mean_col = '#1ABC9C'
flier_face = '#95A5A6'
flier_edge = '#7F8C8D'

for box in bp['boxes']:
    box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)
for whisker in bp['whiskers']:
    whisker.set(color=whisker_col, linewidth=1.2)
for cap in bp['caps']:
    cap.set(color=whisker_col, linewidth=1.2)
for median in bp['medians']:
    median.set(color=median_col, linewidth=2)
for mean in bp['means']:
    mean.set(marker='o', markerfacecolor=mean_col, markeredgecolor='black', markersize=6)
for flier in bp['fliers']:
    flier.set(marker='o', markerfacecolor=flier_face, markeredgecolor=flier_edge,
              alpha=0.6, markersize=4)
# --- Mean accuracy labels above each box (boxes sit at x = 1..n) ---
for i, acc in enumerate(mean_accuracies, start=1):
    ax.text(i, 1.03, f"{acc:.1%}", ha='center', va='bottom', fontsize=9, fontweight='bold')

# --- Perfect accuracy reference line ---
ax.axhline(1, color='black', linewidth=1.2, linestyle='--', alpha=0.7)

# --- Labels and formatting ---
ax.set_xlabel('Functional System / EDSS', fontsize=11, fontweight='bold')
ax.set_ylabel('Continuous Accuracy', fontsize=11, fontweight='bold')
# ax.set_title(
#     'Continuous Accuracy of Functional Systems and EDSS',
#     fontsize=14,
#     fontweight='bold',
#     pad=35
# )
ax.set_ylim(-0.05, 1.10)
yticks = np.linspace(0, 1, 11)
ax.set_yticks(yticks)
# Round to a whole percent; int(y * 100) truncates and would turn 0.2999... into 29.
ax.set_yticklabels([f"{y:.0%}" for y in yticks])
plt.xticks(rotation=45, ha='right')
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
ax.set_axisbelow(True)
for spine in ['top', 'right']:
    ax.spines[spine].set_visible(False)

# --- Legend ---
legend_handles = [
    Patch(facecolor=box_face, edgecolor=box_edge, label='IQR of continuous accuracy'),
    Line2D([0], [0], color=median_col, lw=2, label='Median'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
           markeredgecolor='black', markersize=7, label='Mean'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
           markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
    Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Perfect accuracy'),
]
ax.legend(
    handles=legend_handles,
    loc='lower center',
    bbox_to_anchor=(0.5, 1.08),
    ncol=5,
    frameon=False,
)

# --- Save and show ---
plt.tight_layout(rect=[0, 0, 1, 0.86])
# os.makedirs('') raises, so only create a directory when the path has one.
save_dir = os.path.dirname(figure_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##


# %% Functional Systems + EDSS Error Category Stacked Bar Plot
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np
from matplotlib.patches import Patch

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
figure_save_path = 'project/visuals/functional_systems_edss_error_categories.svg'

# --- Functional systems + EDSS using your actual column names ---
systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
    ('GT.EDSS', 'result.EDSS'),
]


# --- Robust parser ---
def safe_parse(s):
    """Convert a cell value to float, handling comma decimals like '3,5'.

    Numbers pass through as float. NaN/None, empty strings and unparseable
    text all become NaN.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float, np.integer, np.floating)):
        return float(s)
    s_clean = str(s).replace(',', '.').strip()
    if s_clean == "":
        return np.nan
    try:
        return float(s_clean)
    except ValueError:
        return np.nan


# --- Categorize absolute error ---
def categorize_error(abs_error, tol=1e-9):
    """Bucket a non-negative absolute prediction error into a display category.

    Returns "Exact", "≤0.5 error", "≤1 error" or ">1 error". ``tol`` absorbs
    floating-point noise, so a difference of e.g. 1e-15 still counts as exact
    and an error a hair above 0.5 or 1 stays in the lower bucket.
    """
    if abs_error <= tol:
        return "Exact"
    elif abs_error <= 0.5 + tol:
        return "≤0.5 error"
    elif abs_error <= 1 + tol:
        return "≤1 error"
    else:
        return ">1 error"


# --- Prepare data ---
# One row per (system, case) with its error category; "Missing" = GT present, prediction absent.
rows = []

for gt_col, res_col in systems_to_plot:
    if gt_col not in df.columns:
        print(f"Skipping {gt_col}: GT column not found")
        continue
    if res_col not in df.columns:
        print(f"Skipping {res_col}: result column not found")
        continue

    system_name = gt_col.split('.')[1]
    if system_name == "EDSS":
        clean_name = "EDSS"
    else:
        clean_name = system_name.replace("_", " ").title()

    gt = df[gt_col].apply(safe_parse)
    res = df[res_col].apply(safe_parse)

    # Evaluate only cases where ground truth exists
    gt_exists = gt.notna()
    gt_valid = gt[gt_exists]
    res_valid = res[gt_exists]
    if len(gt_valid) == 0:
        continue

    for gt_value, res_value in zip(gt_valid, res_valid):
        if pd.isna(res_value):
            category = "Missing"
        else:
            category = categorize_error(abs(res_value - gt_value))
        rows.append({
            "system": clean_name,
            "category": category,
        })

plot_df = pd.DataFrame(rows)
if plot_df.empty:
    raise ValueError("No valid data available for plotting.")

category_order = [
    "Exact",
    "≤0.5 error",
    "≤1 error",
    ">1 error",
    "Missing",
]

system_order = [
    "Visual Optic Functions",
    "Cerebellar Functions",
    "Brainstem Functions",
    "Sensory Functions",
    "Pyramidal Functions",
    "Ambulation",
    "Cerebral Functions",
    "Bowel And Bladder Functions",
    "EDSS",
]

# Rows follow system_order. Any system produced above but not listed there is
# appended (sorted) instead of being silently dropped. Indexing only systems
# that actually have data means reindex never creates NaN rows; previously
# such rows survived dropna(how="all") once a missing category column was
# zero-filled, and then crashed int() below. Counts stay integer and every
# row has a non-zero total for the percentage conversion.
present_systems = set(plot_df["system"])
row_order = [name for name in system_order if name in present_systems]
row_order += sorted(present_systems - set(row_order))

counts = (
    plot_df
    .groupby(["system", "category"])
    .size()
    .unstack(fill_value=0)
    .reindex(index=row_order, columns=category_order, fill_value=0)
)

# Convert to percentages for easier comparison
percentages = counts.div(counts.sum(axis=1), axis=0) * 100

# --- Plot ---
fig, ax = plt.subplots(figsize=(13, 7))

colors = {
    "Exact": "#2ECC71",
    "≤0.5 error": "#A9DFBF",
    "≤1 error": "#F9E79F",
    ">1 error": "#E67E22",
    "Missing": "#E74C3C",
}

# Running left edge of each horizontal bar as category segments are stacked.
left = np.zeros(len(percentages))

for category in category_order:
    values = percentages[category].values
    ax.barh(
        percentages.index,
        values,
        left=left,
        color=colors[category],
        edgecolor="white",
        linewidth=0.8,
        label=category,
    )

    # Add labels only if segment is large enough
    for i, value in enumerate(values):
        if value >= 4:
            ax.text(
                left[i] + value / 2,
                i,
                f"{value:.1f}%",
                ha="center",
                va="center",
                fontsize=8,
                fontweight="bold",
            )

    left += values

# Add total n and missing count at the right side
for i, system in enumerate(percentages.index):
    total_n = int(counts.loc[system].sum())
    missing_n = int(counts.loc[system, "Missing"])
    ax.text(
        101,
        i,
        f"n={total_n}, missing={missing_n}",
        va="center",
        ha="left",
        fontsize=9,
    )

# --- Formatting ---
ax.set_xlim(0, 115)
ax.set_xlabel("Percentage of Cases", fontsize=11, fontweight="bold")
ax.set_ylabel("Functional System / EDSS", fontsize=11, fontweight="bold")
# ax.set_title(
#     "Prediction Error Categories by Functional System and EDSS",
#     fontsize=14,
#     fontweight="bold",
#     pad=20
# )
ax.set_xticks(np.arange(0, 101, 10))
ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)])
ax.xaxis.grid(True, linestyle="--", alpha=0.3)
ax.set_axisbelow(True)

for spine in ["top", "right", "left"]:
    ax.spines[spine].set_visible(False)

ax.legend(
    loc="lower center",
    bbox_to_anchor=(0.5, 1.02),
    ncol=5,
    frameon=False,
)

plt.tight_layout(rect=[0, 0, 1, 0.92])

# os.makedirs('') raises, so only create a directory when the path has one.
save_dir = os.path.dirname(figure_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
plt.show()
##


# %% Confusion matrix for one EDSS benchmark result file
import os
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report


# =========================
# CONFIGURATION
# =========================

REFERENCE_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/Data/MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_unique.csv"
RESULT_PATH = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/endresults/qwen3.6-35b-a3b_iter_1_20260512_113358_incremental.csv"
OUTPUT_DIR = "/home/shahin/Lab/Doktorarbeit/Barcelona/results_edss_benchmark/confusion_matrices"

# Iteration to evaluate; None keeps all iterations present in the result file.
TARGET_ITERATION = 1
MERGE_KEY = "unique_id"

# Ground truth EDSS column in the reference file
GT_EDSS_COL = "EDSS"

# Predicted EDSS column in the result file
PRED_EDSS_COL = "EDSS"

# One-point EDSS bins, used both for categorization and as confusion-matrix axes.
EDSS_LABELS = [
    "0-1", "1-2", "2-3", "3-4", "4-5",
    "5-6", "6-7", "7-8", "8-9", "9-10"
]

# Lower-cased string values of the result file's "success" column that count
# as success. "1.0" is included because pandas reads a 0/1 column that also
# contains NaN as float.
SUCCESS_VALUES = {"true", "1", "1.0", "yes"}


# =========================
# HELPERS
# =========================

def safe_filename(name):
    """Return *name* as a string with '/', '\\', spaces and ':' replaced by '_'."""
    return (
        str(name)
        .replace("/", "_")
        .replace("\\", "_")
        .replace(" ", "_")
        .replace(":", "_")
    )


def parse_numeric_column(series):
    """Parse a Series to float, accepting comma decimals; unparseable entries become NaN."""
    return pd.to_numeric(
        series.astype(str).str.replace(",", ".", regex=False),
        errors="coerce"
    )


def categorize_edss(value):
    """Map an EDSS score to its bin label from EDSS_LABELS.

    Bins are right-closed:
    - [0, 1] -> "0-1"
    - (1, 2] -> "1-2"
    - ...
    - (9, 10] -> "9-10"

    Returns NaN for missing values and for anything outside [0, 10]. Without
    the lower bound, negative values or -inf would fall into "0-1".
    """
    if pd.isna(value) or value < 0 or value > 10.0:
        return np.nan
    for upper_bound, label in zip(range(1, 11), EDSS_LABELS):
        if value <= upper_bound:
            return label
    return np.nan


def load_reference(reference_path):
    """Load the ';'-separated reference file and add parsed ground-truth EDSS columns.

    Adds "GT_EDSS_numeric" (float) and "GT_EDSS_cat" (bin label or NaN), and
    casts MERGE_KEY to str so it matches the result file.

    Raises:
        ValueError: if MERGE_KEY or GT_EDSS_COL is not a column of the file.
    """
    df_ref = pd.read_csv(reference_path, sep=";")

    if MERGE_KEY not in df_ref.columns:
        raise ValueError(f"Reference file does not contain column: {MERGE_KEY}")
    if GT_EDSS_COL not in df_ref.columns:
        raise ValueError(f"Reference file does not contain column: {GT_EDSS_COL}")

    df_ref = df_ref.copy()
    df_ref[MERGE_KEY] = df_ref[MERGE_KEY].astype(str)
    df_ref["GT_EDSS_numeric"] = parse_numeric_column(df_ref[GT_EDSS_COL])
    df_ref["GT_EDSS_cat"] = df_ref["GT_EDSS_numeric"].apply(categorize_edss)
    return df_ref


def load_result(result_path):
    """Load the ','-separated result file, filter it, and add parsed predicted EDSS columns.

    Filtering steps:
    - keep rows whose "success" value is in SUCCESS_VALUES, if that column exists;
    - keep rows with iteration == TARGET_ITERATION, if TARGET_ITERATION is set
      and an "iteration" column exists.

    Adds "PRED_EDSS_numeric" and "PRED_EDSS_cat", and casts MERGE_KEY to str.

    Raises:
        ValueError: if MERGE_KEY or PRED_EDSS_COL is not a column of the file.
    """
    df_res = pd.read_csv(result_path, sep=",")

    if MERGE_KEY not in df_res.columns:
        raise ValueError(f"Result file does not contain column: {MERGE_KEY}")
    if PRED_EDSS_COL not in df_res.columns:
        raise ValueError(f"Result file does not contain column: {PRED_EDSS_COL}")

    df_res = df_res.copy()
    df_res[MERGE_KEY] = df_res[MERGE_KEY].astype(str)

    if "success" in df_res.columns:
        success_text = df_res["success"].astype(str).str.strip().str.lower()
        df_res = df_res[success_text.isin(SUCCESS_VALUES)]

    if TARGET_ITERATION is not None and "iteration" in df_res.columns:
        df_res = df_res[df_res["iteration"] == TARGET_ITERATION]

    # Boolean filtering can leave a frame pandas treats as a slice; copy before
    # adding columns to avoid SettingWithCopyWarning.
    df_res = df_res.copy()
    df_res["PRED_EDSS_numeric"] = parse_numeric_column(df_res[PRED_EDSS_COL])
    df_res["PRED_EDSS_cat"] = df_res["PRED_EDSS_numeric"].apply(categorize_edss)
    return df_res


def get_model_name(df_res, result_path):
    """Return the first non-null "model" value, falling back to the result file's stem."""
    if "model" in df_res.columns and df_res["model"].notna().any():
        return str(df_res["model"].dropna().iloc[0])
    return Path(result_path).stem


def plot_confusion_matrix(cm, model_name, output_path):
    """Draw *cm* as an annotated heatmap over EDSS_LABELS, save it to *output_path* (300 dpi) and show it."""
    plt.figure(figsize=(10, 8))
    ax = sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=EDSS_LABELS,
        yticklabels=EDSS_LABELS
    )
    cbar = ax.collections[0].colorbar
    cbar.set_label("Number of Cases", rotation=270, labelpad=20)

    plt.xlabel("LLM Generated EDSS")
    plt.ylabel("Ground Truth EDSS")
    plt.title(f"Confusion Matrix: {model_name} | Iteration {TARGET_ITERATION}")
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.show()


# =========================
# MAIN
# =========================

if __name__ == "__main__":
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading reference:")
    print(REFERENCE_PATH)
    df_ref = load_reference(REFERENCE_PATH)
    print(f"Reference rows: {len(df_ref)}")
    print(f"Reference rows with valid GT EDSS: {df_ref['GT_EDSS_numeric'].notna().sum()}")

    print("\nLoading result:")
    print(RESULT_PATH)
    df_res = load_result(RESULT_PATH)

    model_name = get_model_name(df_res, RESULT_PATH)
    safe_model = safe_filename(model_name)
    print(f"Model: {model_name}")
    print(f"Result rows after filtering: {len(df_res)}")

    # Keep one prediction per key. The default quicksort is not stable, so
    # keep="first" would pick an arbitrary duplicate. A stable sort keeps file
    # order among equal keys, so the earliest row in the file wins.
    before_dedup = len(df_res)
    df_res = (
        df_res
        .sort_values(by=[MERGE_KEY], kind="stable")
        .drop_duplicates(subset=[MERGE_KEY], keep="first")
    )
    after_dedup = len(df_res)
    if before_dedup != after_dedup:
        print(f"Deduplicated result rows by {MERGE_KEY}: {before_dedup} -> {after_dedup}")

    df_merged = df_ref.merge(
        df_res,
        on=MERGE_KEY,
        how="inner",
        suffixes=("_gt", "_pred")
    )
    print(f"Merged rows: {len(df_merged)}")

    df_eval = df_merged.dropna(subset=["GT_EDSS_cat", "PRED_EDSS_cat"]).copy()
    print(f"Evaluable rows with valid GT and predicted EDSS: {len(df_eval)}")

    if df_eval.empty:
        raise ValueError("No evaluable rows after merging and EDSS filtering.")

    cm = confusion_matrix(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS
    )

    # Avoid "iter_None" in file names when all iterations are evaluated.
    suffix = f"iter_{TARGET_ITERATION}" if TARGET_ITERATION is not None else "all_iterations"
    plot_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.png"
    cm_csv_path = output_dir / f"{safe_model}_confusion_matrix_{suffix}.csv"
    report_txt_path = output_dir / f"{safe_model}_classification_report_{suffix}.txt"
    merged_csv_path = output_dir / f"{safe_model}_merged_eval_rows_{suffix}.csv"

    plot_confusion_matrix(cm, model_name, plot_path)

    cm_df = pd.DataFrame(cm, index=EDSS_LABELS, columns=EDSS_LABELS)
    cm_df.index.name = "Ground Truth EDSS"
    cm_df.columns.name = "LLM Generated EDSS"
    cm_df.to_csv(cm_csv_path)

    report = classification_report(
        df_eval["GT_EDSS_cat"],
        df_eval["PRED_EDSS_cat"],
        labels=EDSS_LABELS,
        zero_division=0
    )

    with open(report_txt_path, "w", encoding="utf-8") as f:
        f.write(f"Model: {model_name}\n")
        f.write(f"Result file: {RESULT_PATH}\n")
        f.write(f"Target iteration: {TARGET_ITERATION}\n")
        f.write(f"Merged rows: {len(df_merged)}\n")
        f.write(f"Evaluable rows: {len(df_eval)}\n\n")
        f.write("Classification Report:\n")
        f.write(report)
        f.write("\n\nConfusion Matrix Raw Counts:\n")
        f.write(cm_df.to_string())

    # Export only the columns that exist after the merge (suffixes depend on overlap).
    keep_cols = [
        MERGE_KEY,
        "MedDatum_gt" if "MedDatum_gt" in df_eval.columns else "MedDatum",
        "GT_EDSS_numeric",
        "PRED_EDSS_numeric",
        "GT_EDSS_cat",
        "PRED_EDSS_cat",
        "model",
        "iteration",
        "success",
        "inference_time_sec",
        "certainty_percent",
        "reason",
    ]
    keep_cols = [col for col in keep_cols if col in df_eval.columns]
    df_eval[keep_cols].to_csv(merged_csv_path, index=False)

    print("\nClassification Report:")
    print(report)

    print("\nConfusion Matrix Raw Counts:")
    print(cm_df)

    print("\nSaved files:")
    print(f"Plot: {plot_path}")
    print(f"Confusion matrix: {cm_csv_path}")
    print(f"Report: {report_txt_path}")
    print(f"Merged rows: {merged_csv_path}")

    print("\nDone.")
##