Files
EDSS-calc/Data/show_plots.py
Shahin Ramezanzadeh f4bf37f71c show directional errors
Directional Errors of each functional system.
2026-02-08 01:27:48 +01:00

1845 lines
60 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# %% Scatter
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Load the joined ground-truth / LLM results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# German decimal commas -> dots, then coerce to float (bad cells -> NaN).
df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.')
df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.')
df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce')
df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce')

# Keep only rows where both values are present (pairwise-complete scatter).
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Scatter of ground truth vs LLM prediction.
plt.figure(figsize=(8, 6))
plt.scatter(df_clean['GT_EDSS'], df_clean['LLM_Results'], alpha=0.7, color='blue')
plt.xlabel('GT_EDSS')
plt.ylabel('LLM_Results')
plt.title('Comparison of GT_EDSS vs LLM_Results')

# Identity line y = x over the full data range. FIX: use the max of BOTH
# columns (the original spanned only GT_EDSS, truncating the line
# whenever the LLM over-predicted beyond the GT maximum).
axis_max = max(df_clean['GT_EDSS'].max(), df_clean['LLM_Results'].max())
plt.plot([0, axis_max], [0, axis_max], color='red', linestyle='--', label='Perfect Prediction')
plt.legend()

plt.grid(True)
plt.tight_layout()
plt.show()
##
# %% Bland-Altman
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import statsmodels.api as sm

# Load the joined ground-truth / LLM results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_MS_Briefe_400_with_unique_id_SHA3_explore_cleaned_results+MS_Briefe_400_with_unique_id_SHA3_explore_cleaned.tsv'
df = pd.read_csv(file_path, sep='\t')

# German decimal commas -> dots, then coerce to float (bad cells -> NaN).
df['GT_EDSS'] = df['GT_EDSS'].astype(str).str.replace(',', '.')
df['LLM_Results'] = df['LLM_Results'].astype(str).str.replace(',', '.')
df['GT_EDSS'] = pd.to_numeric(df['GT_EDSS'], errors='coerce')
df['LLM_Results'] = pd.to_numeric(df['LLM_Results'], errors='coerce')

# Keep only pairwise-complete rows.
df_clean = df.dropna(subset=['GT_EDSS', 'LLM_Results'])

# Bland-Altman (mean-difference) plot via statsmodels.
f, ax = plt.subplots(1, figsize=(8, 5))
sm.graphics.mean_diff_plot(df_clean['GT_EDSS'], df_clean['LLM_Results'], ax=ax)
ax.set_title('Bland-Altman Plot: GT_EDSS vs LLM_Results')
ax.set_xlabel('Mean of GT_EDSS and LLM_Results')
ax.set_ylabel('Difference between GT_EDSS and LLM_Results')
plt.tight_layout()
plt.show()

# Agreement statistics. FIX: use the sample standard deviation
# (ddof=1), the standard estimator for Bland-Altman limits of
# agreement; np.std's default ddof=0 slightly understated the limits.
diffs = df_clean['GT_EDSS'] - df_clean['LLM_Results']
mean_diff = np.mean(diffs)
std_diff = np.std(diffs, ddof=1)
print(f"Mean difference: {mean_diff:.3f}")
print(f"Standard deviation of differences: {std_diff:.3f}")
print(f"95% Limits of Agreement: [{mean_diff - 1.96*std_diff:.3f}, {mean_diff + 1.96*std_diff:.3f}]")
##
# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load the joined ground-truth / LLM results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# German decimal commas -> dots so the scores parse as floats.
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Keep only rows with both scores. FIX: .copy() detaches the subset from
# df so the category-column assignments later in this cell do not
# trigger pandas' SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
# For confusion matrix, we need to categorize the values
# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10)
def categorize_edss(value):
    """Return the one-point EDSS bin label for a score.

    Bins are '0-1' through '9-10' with inclusive upper edges (e.g. a
    score of exactly 2.0 falls in '1-2'); scores above 10 map to '10+'
    and NaN propagates as np.nan so it can be dropped downstream.
    """
    if pd.isna(value):
        return np.nan
    bin_labels = ['0-1', '1-2', '2-3', '3-4', '4-5',
                  '5-6', '6-7', '7-8', '8-9', '9-10']
    # Walk the inclusive upper edges 1..10; first edge >= value wins.
    for upper_edge, label in enumerate(bin_labels, start=1):
        if value <= upper_edge:
            return label
    return '10+'
# Create categorical versions. FIX: work on an explicit copy so the new
# column assignments do not trigger pandas' SettingWithCopyWarning
# (df_clean came from a row-filtering dropna and may be a view).
df_clean = df_clean.copy()
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)

# Remove any NaN categories
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Confusion matrix over the ordered bin labels. One shared list instead
# of the original's three duplicated literals; '10+' is excluded here,
# matching the original label set (EDSS is bounded at 10).
bin_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
                      labels=bin_labels)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=bin_labels,
            yticklabels=bin_labels)
plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))

# Print raw counts
print("\nConfusion Matrix (Raw Counts):")
print(cm)
##
# %% Confusion matrix
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Load the joined ground-truth / LLM results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# German decimal commas -> dots so the scores parse as floats.
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# Keep only rows with both scores. FIX: .copy() detaches the subset from
# df so later category-column assignments do not trigger pandas'
# SettingWithCopyWarning.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()
# For confusion matrix, we need to categorize the values
# Let's create categories up to 10 (0-1, 1-2, 2-3, ..., 9-10)
def categorize_edss(value):
    """Bucket an EDSS score into one-point range labels.

    NaN stays np.nan, scores above 10 become '10+', everything else maps
    to '0-1' ... '9-10' with inclusive upper edges (a score of exactly
    2.0 lands in '1-2', matching the chained <= comparisons this
    replaces).
    """
    if pd.isna(value):
        return np.nan
    if value > 10.0:
        return '10+'
    # Ceiling gives the inclusive upper edge; clamp to >= 1 so scores at
    # or below 1 (including 0) land in '0-1'.
    hi = max(int(np.ceil(value)), 1)
    return f'{hi - 1}-{hi}'
# Create categorical versions. FIX: explicit copy avoids pandas'
# SettingWithCopyWarning when adding columns to a filtered frame.
df_clean = df_clean.copy()
df_clean['GT.EDSS_cat'] = df_clean['GT.EDSS'].apply(categorize_edss)
df_clean['result.EDSS_cat'] = df_clean['result.EDSS'].apply(categorize_edss)

# Remove any NaN categories
df_clean = df_clean.dropna(subset=['GT.EDSS_cat', 'result.EDSS_cat'])

# Confusion matrix over the ordered bin labels (single shared list
# instead of three duplicated literals).
bin_labels = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
cm = confusion_matrix(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat'],
                      labels=bin_labels)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                 xticklabels=bin_labels,
                 yticklabels=bin_labels)

# Label the colorbar so the cell counts are self-explanatory.
cbar = ax.collections[0].colorbar
cbar.set_label('Number of Cases', rotation=270, labelpad=20)

plt.xlabel('LLM Generated EDSS')
plt.ylabel('Ground Truth EDSS')
plt.tight_layout()
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(df_clean['GT.EDSS_cat'], df_clean['result.EDSS_cat']))

# Print raw counts
print("\nConfusion Matrix (Raw Counts):")
print(cm)
##
# %% Classification
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np

# Load the joined annotation table (tab-separated).
file_path ='/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Echo the loaded structure so conversion issues below are traceable.
print("Data shape:", df.shape)
print("First few rows:")
print(df.head())
print("\nColumn names:")
for column_name in df.columns:
    print(f" {column_name}")
def safe_bool_convert(series):
    """Coerce a Series of assorted truthy/falsy tokens to booleans.

    Recognizes true/false, 1/0, yes/no, y/n in any case, ignoring
    surrounding whitespace. Anything unrecognized (including missing
    values) maps to False.
    """
    lookup = dict.fromkeys(('true', '1', 'yes', 'y'), True)
    lookup.update(dict.fromkeys(('false', '0', 'no', 'n'), False))
    # Normalize every entry to a lowercase, stripped string before lookup.
    tokens = series.astype(str).str.strip().str.lower()
    # Unknown tokens become NaN via map, then default to False.
    return tokens.map(lookup).fillna(False)
# Convert the LLM and GT "klassifizierbar" flags to booleans, printing
# before/after diagnostics so conversion problems stay visible.
if 'result.klassifizierbar' in df.columns:
    print("\nresult.klassifizierbar column info:")
    print(df['result.klassifizierbar'].head(10))
    print("Unique values:", df['result.klassifizierbar'].unique())
    df['result.klassifizierbar'] = safe_bool_convert(df['result.klassifizierbar'])
    print("After conversion:")
    print(df['result.klassifizierbar'].value_counts())
if 'GT.klassifizierbar' in df.columns:
    print("\nGT.klassifizierbar column info:")
    print(df['GT.klassifizierbar'].head(10))
    print("Unique values:", df['GT.klassifizierbar'].unique())
    df['GT.klassifizierbar'] = safe_bool_convert(df['GT.klassifizierbar'])
    print("After conversion:")
    print(df['GT.klassifizierbar'].value_counts())

# Bar chart of the True counts for LLM vs GT.
if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns:
    # .sum() on a boolean series counts the True entries.
    llm_true_count = df['result.klassifizierbar'].sum()
    gt_true_count = df['GT.klassifizierbar'].sum()

    fig, ax = plt.subplots(figsize=(8, 6))
    x = np.arange(2)
    width = 0.35
    # FIX: the original drew the bars at x[0]-width/2 and x[1]+width/2
    # while the tick labels and count annotations sat at x[0]/x[1], so
    # labels floated beside the bars. Draw each bar at its tick position.
    bars1 = ax.bar(x[0], llm_true_count, width, label='LLM', color='skyblue', alpha=0.8)
    bars2 = ax.bar(x[1], gt_true_count, width, label='GT', color='lightcoral', alpha=0.8)

    # Count labels directly above each bar.
    ax.annotate(f'{llm_true_count}',
                xy=(x[0], llm_true_count),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom')
    ax.annotate(f'{gt_true_count}',
                xy=(x[1], gt_true_count),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom')

    ax.set_xlabel('Classification Status (klassifizierbar)')
    ax.set_ylabel('Count')
    ax.set_title('True Values Comparison: LLM vs GT for "klassifizierbar"')
    ax.set_xticks(x)
    ax.set_xticklabels(['LLM', 'GT'])
    ax.legend()
    plt.tight_layout()
    plt.show()

# Confusion matrix of GT vs LLM classifiability.
if 'result.klassifizierbar' in df.columns and 'GT.klassifizierbar' in df.columns:
    try:
        # Defensive re-coercion in case earlier conversion left NaNs.
        llm_bool = df['result.klassifizierbar'].fillna(False).astype(bool)
        gt_bool = df['GT.klassifizierbar'].fillna(False).astype(bool)
        cm = confusion_matrix(gt_bool, llm_bool)

        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['False ', 'True '],
                    yticklabels=['False', 'True '],
                    ax=ax)
        ax.set_xlabel('LLM Predictions ')
        ax.set_ylabel('GT Labels ')
        ax.set_title('Confusion Matrix: LLM vs GT for "klassifizierbar"')
        plt.tight_layout()
        plt.show()

        print("Confusion Matrix:")
        print(cm)
    except Exception as e:
        print(f"Error creating confusion matrix: {e}")

# Show final data info
print("\nFinal DataFrame info:")
print(df[['result.klassifizierbar', 'GT.klassifizierbar']].info())
##
# %% Boxplot
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load the per-visit results table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/join_results_unique.tsv'
df = pd.read_csv(file_path, sep='\t')

# German decimal commas -> dots, then coerce to float (bad cells -> NaN).
df['GT.EDSS'] = df['GT.EDSS'].astype(str).str.replace(',', '.')
df['result.EDSS'] = df['result.EDSS'].astype(str).str.replace(',', '.')
df['GT.EDSS'] = pd.to_numeric(df['GT.EDSS'], errors='coerce')
df['result.EDSS'] = pd.to_numeric(df['result.EDSS'], errors='coerce')

# FIX: .copy() so the categorical-column assignment below does not
# trigger pandas' SettingWithCopyWarning on a filtered view.
df_clean = df.dropna(subset=['GT.EDSS', 'result.EDSS']).copy()

# 1. DEFINE CATEGORY ORDER
# Fixed ordering keeps the x-axis numerically logical (0-1 before 1-2).
category_order = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10', '10+']
# NOTE: categorize_edss is defined in the "Confusion matrix" cell above;
# this cell depends on that cell having been run first.
df_clean['GT.EDSS_cat'] = pd.Categorical(df_clean['GT.EDSS'].apply(categorize_edss),
                                         categories=category_order,
                                         ordered=True)

plt.figure(figsize=(14, 8))
# 2. ADD HUE FOR LEGEND
# hue mirrors x so Seaborn generates a per-category legend automatically.
box_plot = sns.boxplot(
    data=df_clean,
    x='GT.EDSS_cat',
    y='result.EDSS',
    hue='GT.EDSS_cat',
    palette='viridis',
    linewidth=1.5,
    legend=True
)

# 3. CUSTOMIZE PLOT
plt.title('Distribution of result.EDSS by GT.EDSS Category', fontsize=18, pad=20)
plt.xlabel('Ground Truth EDSS Category', fontsize=14)
plt.ylabel('LLM Predicted EDSS', fontsize=14)
plt.legend(title="EDSS Categories", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
##
# %% Postprocessing column names
import pandas as pd

# Read the TSV file
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# German -> English column names for the ground-truth annotations.
column_mapping = {
    'EDSS':'GT.EDSS',
    'klassifizierbar': 'GT.klassifizierbar',
    'Sehvermögen': 'GT.VISUAL_OPTIC_FUNCTIONS',
    'Cerebellum': 'GT.CEREBELLAR_FUNCTIONS',
    'Hirnstamm': 'GT.BRAINSTEM_FUNCTIONS',
    'Sensibiliät': 'GT.SENSORY_FUNCTIONS',
    'Pyramidalmotorik': 'GT.PYRAMIDAL_FUNCTIONS',
    'Ambulation': 'GT.AMBULATION',
    'Cerebrale_Funktion': 'GT.CEREBRAL_FUNCTIONS',
    'Blasen-_und_Mastdarmfunktion': 'GT.BOWEL_AND_BLADDER_FUNCTIONS'
}

# Rename and write back. NOTE: this overwrites the input file in place.
df = df.rename(columns=column_mapping)
df.to_csv(file_path, sep='\t', index=False)

print("Columns have been successfully renamed!")
print("Renamed columns:")
for old_name, new_name in column_mapping.items():
    # FIX: after the rename the OLD names are gone from df.columns, so
    # the original `old_name in df.columns` check never printed anything.
    # Checking the NEW name reports each mapping that actually applied.
    if new_name in df.columns:
        print(f" {old_name} -> {new_name}")
##
# %% Styled table
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import dataframe_image as dfi  # NOTE(review): imported but unused in this cell

# Load the joined annotation table (tab-separated).
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')

# 1. Identify all ground-truth ('GT.') and LLM-output ('result.') columns.
gt_columns = [col for col in df.columns if col.startswith('GT.')]
result_columns = [col for col in df.columns if col.startswith('result.')]
print("GT Columns found:", gt_columns)
print("Result Columns found:", result_columns)

# 2. Pair each GT column with its result counterpart, tolerating the
#    naming drift between the two prefixes (spaces vs underscores, case).
column_mapping = {}
for gt_col in gt_columns:
    base_name = gt_col.replace('GT.', '')
    # Candidate result-column spellings, tried in order of likelihood.
    candidates = [
        f'result.{base_name}',  # Exact match
        f'result.{base_name.replace(" ", "_")}',  # With underscores
        f'result.{base_name.replace("_", " ")}',  # With spaces
        f'result.{base_name.replace(" ", "")}',  # No spaces
        f'result.{base_name.replace("_", "")}'  # No underscores
    ]
    # Case variants as a last direct attempt.
    candidates.append(f'result.{base_name.lower()}')
    candidates.append(f'result.{base_name.upper()}')
    matched = False
    for candidate in candidates:
        if candidate in result_columns:
            column_mapping[gt_col] = candidate
            matched = True
            break
    # Fallback: strip to alphanumerics/underscore/space and compare
    # case-insensitively against every result column.
    if not matched:
        base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' '])
        for result_col in result_columns:
            result_base = result_col.replace('result.', '')
            result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' '])
            if base_clean.lower() == result_clean.lower():
                column_mapping[gt_col] = result_col
                matched = True
                break
print("Column mapping:", column_mapping)

# 3. Per-pair agreement: fraction of rows where GT and LLM scores differ
#    by at most 0.5 points.
data_list = []
for gt_col, result_col in column_mapping.items():
    print(f"Processing {gt_col} vs {result_col}")
    # Convert to numeric, forcing errors to NaN
    s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float)
    s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float)
    # A "match" is agreement within +-0.5 EDSS points.
    diff = np.abs(s1 - s2)
    matches = (diff <= 0.5).sum()
    # Denominator: rows where both values parsed (NaN diffs excluded).
    valid_count = diff.notna().sum()
    if valid_count > 0:
        percentage = (matches / valid_count) * 100
    else:
        percentage = 0
    # Display name without the 'GT.' prefix.
    base_name = gt_col.replace('GT.', '')
    data_list.append({
        'GT': base_name,
        'Match %': round(percentage, 1)
    })

# 4. Display frame: prettified labels, sorted best-agreement first.
match_df = pd.DataFrame(data_list)
match_df['GT'] = match_df['GT'].str.replace('_', ' ').str.title()
match_df = match_df.sort_values('Match %', ascending=False)
# 5. Create a "Beautiful" Table using Seaborn Heatmap
def create_luxury_table(df, output_file="edss_agreement.png"):
    """Render the agreement table as a styled single-column heatmap PNG.

    Args:
        df: DataFrame with 'GT' (system name) and 'Match %' columns.
        output_file: Path of the PNG written via plt.savefig.
    """
    # Set the aesthetic style
    sns.set_theme(style="white", font="sans-serif")
    # One row per functional system, single 'Match %' value column.
    plot_data = df.set_index('GT')[['Match %']]
    # Figure height scales with the row count so cells stay readable.
    fig, ax = plt.subplots(figsize=(8, len(df) * 0.6))
    # Custom diverging palette (red -> green) instead of stock 'RdYlGn'.
    cmap = sns.diverging_palette(15, 135, s=80, l=55, as_cmap=True)
    # Draw the heatmap; the colorbar is disabled so it reads as a table.
    sns.heatmap(
        plot_data,
        annot=True,
        fmt=".1f",
        cmap=cmap,
        center=85,  # Hinge of the color transition
        vmin=50, vmax=100,  # Range of the gradient
        linewidths=2,
        linecolor='white',
        cbar=False,
        annot_kws={"size": 14, "weight": "bold", "family": "sans-serif"}
    )
    # Style the axes so the heatmap looks like a table, not a chart.
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.xaxis.tick_top()  # Column header on top
    ax.set_xticklabels(['Agreement (%)'], fontsize=14, fontweight='bold', color='#2c3e50')
    ax.tick_params(axis='y', labelsize=12, labelcolor='#2c3e50', length=0)
    # Thin light border around the plot.
    for _, spine in ax.spines.items():
        spine.set_visible(True)
        spine.set_color('#ecf0f1')
    plt.title('EDSS Subcategory Consistency Analysis', fontsize=16, pad=40, fontweight='bold', color='#2c3e50')
    # Footer documenting the matching tolerance used upstream.
    plt.figtext(0.5, 0.0, "Tolerance: ±0.5 points",
                wrap=True, horizontalalignment='center', fontsize=10, color='gray', style='italic')
    # Save at print resolution; the figure stays open for further exports.
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Beautiful table saved as {output_file}")
# Execute: render the styled agreement table and save it as PNG.
# FIX: removed a leftover call to save_styled_table(match_df), a
# function that is defined nowhere in this file and raised NameError.
create_luxury_table(match_df)

# 6. Save as SVG (re-exports the figure create_luxury_table left open).
plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
print("Successfully saved agreement_table.svg")

# Show plot if running in a GUI environment
plt.show()
##
# %% Time Plot
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

# Load the per-document LLM inference timings (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Per-call inference times in seconds; missing values dropped.
inference_times = df['inference_time_sec'].dropna()

# Location/spread statistics (pandas .std() uses ddof=1, i.e. sample SD).
mean_time = inference_times.mean()
std_time = inference_times.std()
median_time = np.median(inference_times)

fig, ax = plt.subplots(figsize=(10, 6))

# 1-second-wide bins spanning the observed range; the upper edge is
# padded so the maximum observation falls inside the last bin.
min_time = int(inference_times.min())
max_time = int(inference_times.max()) + 1
bins = np.arange(min_time, max_time + 1, 1)

# Histogram of raw counts (not a probability density).
n, bins, patches = ax.hist(inference_times, bins=bins, color='lightblue', alpha=0.7, edgecolor='black', linewidth=0.5)

# Gaussian fit, rescaled from a density to expected counts per bin
# (density * N * bin width).
x = np.linspace(inference_times.min(), inference_times.max(), 100)
gaussian_counts = stats.norm.pdf(x, mean_time, std_time) * len(inference_times) * (bins[1] - bins[0])
ax.plot(x, gaussian_counts, color='red', linewidth=2, label=f'Gaussian Fit (μ={mean_time:.1f}s, σ={std_time:.1f}s)')

# Mark mean, median, and ±1 SD for quick visual reference.
ax.axvline(mean_time, color='blue', linestyle='--', linewidth=2, label=f'Mean = {mean_time:.1f}s')
ax.axvline(median_time, color='green', linestyle='--', linewidth=2, label=f'Median = {median_time:.1f}s')
ax.axvline(mean_time + std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'+1σ = {mean_time + std_time:.1f}s')
ax.axvline(mean_time - std_time, color='saddlebrown', linestyle=':', linewidth=1, alpha=0.7, label=f'-1σ = {mean_time - std_time:.1f}s')

ax.set_xlabel('Inference Time (seconds)')
ax.set_ylabel('Frequency')
ax.set_title('Inference Time Distribution with Gaussian Fit')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
##
# %% Dashboard
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import numpy as np

# Load the joined per-visit table (tab-separated).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Strip the 'result.' prefix and replace spaces with underscores so the
# LLM-output columns can be addressed with plain names below.
column_mapping = {}
for col in df.columns:
    if col.startswith('result.'):
        new_name = col.replace('result.', '')
        # Handle spaces in column names (replace with underscores if needed)
        new_name = new_name.replace(' ', '_')
        column_mapping[col] = new_name
df = df.rename(columns=column_mapping)

# Visit date -> datetime for chronological plotting.
df['MedDatum'] = pd.to_datetime(df['MedDatum'])

# Echo the resulting schema for debugging.
print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())

# Single hardcoded (pseudonymized) patient id to plot.
patient_names = ['6b56865d']
# Functional systems to chart; each is matched against the columns
# below, falling back to a case-insensitive substring search.
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']

# Grid layout: 2 columns, as many rows as needed (ceiling division).
num_plots = len(functional_systems)
num_cols = 2
num_rows = (num_plots + num_cols - 1) // num_cols  # Ceiling division
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
# Normalize `axes` to a flat, indexable sequence regardless of grid shape.
if num_plots == 1:
    axes = [axes]
elif num_rows == 1:
    axes = axes  # a single row already yields a 1-D array (no-op)
else:
    axes = axes.flatten()

# One subplot per functional system for the selected patient.
# NOTE(review): patient_data does not depend on `system`, so it could be
# computed once before the loop instead of once per iteration.
for i, system in enumerate(functional_systems):
    patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')
    # Nothing to draw at all for an unknown patient id.
    if patient_data.empty:
        print(f"No data found for patient: {patient_names[0]}")
        continue
    if system in patient_data.columns:
        # Plot only if the column has at least one non-NaN value.
        if not patient_data[system].isna().all():
            axes[i].plot(patient_data['MedDatum'], patient_data[system], marker='o', linewidth=2, label=system)
            axes[i].set_ylabel('Score')
            axes[i].set_title(f'Functional System: {system}')
            axes[i].grid(True, alpha=0.3)
            axes[i].legend()
        else:
            axes[i].set_title(f'Functional System: {system} (No data)')
            axes[i].set_ylabel('Score')
            axes[i].grid(True, alpha=0.3)
    else:
        # Fallback: first column whose name contains the system name
        # (case-insensitive).
        found_column = None
        for col in df.columns:
            if system.lower() in col.lower():
                found_column = col
                break
        if found_column:
            print(f"Found similar column: {found_column}")
            if not patient_data[found_column].isna().all():
                axes[i].plot(patient_data['MedDatum'], patient_data[found_column], marker='o', linewidth=2, label=found_column)
                axes[i].set_ylabel('Score')
                axes[i].set_title(f'Functional System: {system} (found as: {found_column})')
                axes[i].grid(True, alpha=0.3)
                axes[i].legend()
        else:
            axes[i].set_title(f'Functional System: {system} (Column not found)')
            axes[i].set_ylabel('Score')
            axes[i].grid(True, alpha=0.3)

# Blank out any unused grid cells.
for i in range(len(functional_systems), len(axes)):
    axes[i].set_visible(False)

# Only the bottom row gets an x-axis label.
for i in range(len(functional_systems)):
    if i >= len(axes) - num_cols:  # Last row
        axes[i].set_xlabel('Date')

# Uniform date formatting on every subplot's x-axis.
for ax in axes:
    ax.tick_params(axis='x', rotation=45)
    ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))
    ax.xaxis.set_major_locator(plt.matplotlib.dates.MonthLocator())

# Tilt/align the date labels across the whole figure.
plt.gcf().autofmt_xdate()
plt.tight_layout()
plt.show()
##
# %% Table
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import numpy as np

# Load the joined per-visit table (tab-separated). Unlike the Dashboard
# cell above, column names are used as-is (no 'result.' stripping).
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
df = pd.read_csv(file_path, sep='\t')

# Visit date -> datetime for chronological plotting.
df['MedDatum'] = pd.to_datetime(df['MedDatum'])

# Echo schema and dtypes for debugging.
print("Available columns:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head())
print("\nData types:")
print(df.dtypes)

# Single hardcoded (pseudonymized) patient id to plot.
patient_names = ['6ccda8c6']
# Functional systems to chart, matched against columns below with a
# case-insensitive substring fallback.
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']

# Grid layout: 2 columns, as many rows as needed.
num_plots = len(functional_systems)
num_cols = 2
num_rows = (num_plots + num_cols - 1) // num_cols
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
# Normalize `axes` to a flat, indexable sequence regardless of grid shape.
if num_plots == 1:
    axes = [axes]
elif num_rows == 1:
    axes = axes  # a single row already yields a 1-D array (no-op)
else:
    axes = axes.flatten()

# One subplot per functional system for the selected patient.
# NOTE(review): patient_data is loop-invariant; could be hoisted.
for i, system in enumerate(functional_systems):
    patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')
    # Unknown patient id: annotate the subplot and move on.
    if patient_data.empty:
        print(f"No data found for patient: {patient_names[0]}")
        axes[i].set_title(f'Functional System: {system} (No data)')
        axes[i].set_ylabel('Score')
        continue
    if system in patient_data.columns:
        # Plot only rows where this system's value is present.
        valid_data = patient_data.dropna(subset=[system])
        if not valid_data.empty:
            axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system)
            axes[i].set_ylabel('Score')
            axes[i].set_title(f'Functional System: {system}')
            axes[i].grid(True, alpha=0.3)
            axes[i].legend()
        else:
            axes[i].set_title(f'Functional System: {system} (No valid data)')
            axes[i].set_ylabel('Score')
    else:
        # Fallback: first column whose name contains the system name
        # (case-insensitive).
        found_column = None
        for col in df.columns:
            if system.lower() in col.lower():
                found_column = col
                break
        if found_column:
            valid_data = patient_data.dropna(subset=[found_column])
            if not valid_data.empty:
                axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column)
                axes[i].set_ylabel('Score')
                axes[i].set_title(f'Functional System: {system} (found as: {found_column})')
                axes[i].grid(True, alpha=0.3)
                axes[i].legend()
            else:
                axes[i].set_title(f'Functional System: {system} (No valid data)')
                axes[i].set_ylabel('Score')
        else:
            axes[i].set_title(f'Functional System: {system} (Column not found)')
            axes[i].set_ylabel('Score')

# Blank out any unused grid cells.
for i in range(len(functional_systems), len(axes)):
    axes[i].set_visible(False)

# Only the bottom row gets an x-axis label.
for i in range(len(functional_systems)):
    if i >= len(axes) - num_cols:  # Last row
        axes[i].set_xlabel('Date')

# Date formatting, applied only to subplots that actually drew a line.
for ax in axes:
    if ax.get_lines():
        ax.tick_params(axis='x', rotation=45)
        ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))

# Automatically adjust layout
plt.tight_layout()
plt.show()
##
# %% Histogram Fig1
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import json
import os
def create_visit_frequency_plot(
    file_path,
    output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data',
    output_filename='visit_frequency_distribution.svg',
    fontsize=10,
    color_scheme_path='colors.json'
):
    """
    Creates a publication-ready bar chart of patient visit frequency.

    Args:
        file_path (str): Path to the input TSV file; must contain
            'Visits Count' and 'Unique Patients' columns.
        output_dir (str): Directory to save the output SVG file.
        output_filename (str): Name of the output SVG file.
        fontsize (int): Font size for all text elements (labels, title).
        color_scheme_path (str): Path to the JSON file containing the color palette.
    """
    # --- 1. Load Data and Color Scheme ---
    try:
        df = pd.read_csv(file_path, sep='\t')
        print("Data loaded successfully.")
        # Sort by visit count so bars appear in ascending order.
        df = df.sort_values(by='Visits Count')
    except FileNotFoundError:
        print(f"Error: The file was not found at {file_path}")
        return
    try:
        with open(color_scheme_path, 'r') as f:
            colors = json.load(f)
        # Second-darkest blue of the sequential palette.
        # NOTE(review): assumes colors.json has the shape
        # {"sequential": {"blues": [...]}} — confirm against the file.
        bar_color = colors['sequential']['blues'][-2]
    except FileNotFoundError:
        print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.")
        bar_color = '#2171b5'  # A common matplotlib blue
    # --- 2. Set up the Plot with Scientific Style ---
    # NOTE(review): matplotlib figsize is in INCHES; the original comment
    # claimed "7.94 cm" (journal single-column width) — confirm whether a
    # cm-to-inch conversion was intended here.
    plt.figure(figsize=(7.94, 6))
    # Arial everywhere, per the journal guidelines.
    arial_font = fm.FontProperties(family='Arial', size=fontsize)
    plt.rcParams['font.family'] = 'Arial'
    plt.rcParams['font.size'] = fontsize
    # --- 3. Create the Bar Chart ---
    ax = plt.gca()
    bars = plt.bar(
        x=df['Visits Count'],
        height=df['Unique Patients'],
        color=bar_color,
        edgecolor='black',
        linewidth=0.5,  # Minimum line thickness
        width=0.7
    )
    # Explicitly set x-ticks so every visit count gets its own label.
    visit_counts = df['Visits Count'].unique()
    ax.set_xticks(visit_counts)
    ax.set_xticklabels(visit_counts, fontproperties=arial_font)
    # --- 4. Customize Axes and Layout (Nature style) ---
    # Only the left and bottom spines stay visible.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Hide the tick marks themselves; the labels remain.
    plt.tick_params(axis='both', which='both', length=0)
    # No grid lines, plain white background.
    plt.grid(False)
    ax.set_facecolor('white')
    plt.gcf().set_facecolor('white')
    # --- 5. Add Labels and Title ---
    plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10)
    plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10)
    plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20)
    # --- 6. Patient counts printed directly above each bar ---
    ax.bar_label(bars, fmt='%d', padding=3)
    # --- 7. Export the Figure ---
    # Ensure the output directory exists before saving.
    os.makedirs(output_dir, exist_ok=True)
    full_output_path = os.path.join(output_dir, output_filename)
    plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight')
    print(f"\nFigure saved as '{full_output_path}'")
    # --- 8. (Optional) Display the Plot ---
    # plt.show()
# --- Main execution ---
if __name__ == '__main__':
    # Input TSV with 'Visits Count' / 'Unique Patients' columns.
    # NOTE(review): "visit_freuency" looks like a typo carried over from
    # the actual filename on disk — confirm before renaming.
    input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'
    # Create and save the plot.
    create_visit_frequency_plot(
        file_path=input_file,
        fontsize=10  # Using a 10 pt font size as per guidelines
    )
##
# %% Scatter Plot functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os

# --- Configuration ---
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'
# Define the path to your data file (TSV with paired GT.* / result.* columns)
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# Define the path to save the color mapping JSON file
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Define the path to save the final figure
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'

# --- 1. Load the Dataset ---
try:
    # Load the TSV file
    df = pd.read_csv(data_path, sep='\t')
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    # Exit or handle the error appropriately: re-raise so the cell stops here
    # instead of failing later with an undefined `df`.
    raise
# --- 2. Define Functional Systems and Create Color Mapping ---
# List of tuples containing (ground_truth_column, result_column).
# NOTE: GT columns use underscores while result columns use spaces — this
# mirrors the TSV header naming; do not "normalize" one side to the other.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]
# Extract system names for color mapping and legend (the part after 'GT.')
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
# Define a professional color palette (dark blue theme)
# This is a qualitative palette with distinct, accessible colors;
# order matches `functional_systems_to_plot`.
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC'   # Purple
]
# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))
# Ensure the directory for the JSON file exists
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
# Save the color map to a JSON file so other cells/scripts can reuse it
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# For every functional system, compute the share of rows where the LLM
# result exactly equals the ground-truth score, and build a legend label
# of the form "Visual Optic Functions (87.5%)".
agreement_percentages = {}
legend_labels = {}
for gt_col, res_col in functional_systems_to_plot:
    sys_key = gt_col.split('.')[1]
    # Coerce both columns to numbers; unparseable entries become NaN.
    gt_vals = pd.to_numeric(df[gt_col], errors='coerce')
    res_vals = pd.to_numeric(df[res_col], errors='coerce')
    # Compare only rows where BOTH sides are present.
    paired = pd.concat([gt_vals, res_vals], axis=1, keys=['gt', 'res']).dropna()
    if paired.empty:
        pct_agree = 0  # no valid data for this system
    else:
        pct_agree = (paired['gt'] == paired['res']).mean() * 100
    agreement_percentages[sys_key] = pct_agree
    # "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions"
    pretty = " ".join(part.capitalize() for part in sys_key.split('_'))
    legend_labels[sys_key] = f"{pretty} ({pct_agree:.1f}%)"
# --- 4. Reshape Data for Plotting ---
# Build one long-format frame with columns (system, ground_truth, inference),
# keeping only rows where both values parsed as numbers.
frames = []
for gt_col, res_col in functional_systems_to_plot:
    sys_key = gt_col.split('.')[1]
    pair = pd.DataFrame({
        'system': sys_key,
        'ground_truth': pd.to_numeric(df[gt_col], errors='coerce'),
        'inference': pd.to_numeric(df[res_col], errors='coerce'),
    }).dropna()  # rows with a NaN on either side cannot be plotted
    frames.append(pair)
# Concatenate all per-system frames into one
plot_df = pd.concat(frames, ignore_index=True)
if plot_df.empty:
    print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
else:
    print(f"Prepared plot data with {len(plot_df)} data points.")
# --- 5. Create the Scatter Plot ---
plt.figure(figsize=(10, 8))
# Plot each functional system with its assigned color and formatted legend label
for system, group in plot_df.groupby('system'):
    plt.scatter(
        group['ground_truth'],
        group['inference'],
        label=legend_labels[system],  # includes the agreement percentage
        color=color_map[system],
        alpha=0.7,
        s=30
    )
# Add a diagonal line representing perfect agreement (y = x)
# This line helps visualize how close the predictions are to the ground truth
if not plot_df.empty:
    plt.plot(
        [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
        [plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
        color='black',
        linestyle='--',
        linewidth=0.8,
        alpha=0.7
    )
# --- 6. Apply Styling and Labels ---
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('LLM Inference', fontsize=12)
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)
# Apply scientific visualization styling rules (only left/bottom spines,
# no tick marks, no grid)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0)  # Remove ticks
ax.grid(False)  # Remove grid lines
plt.legend(title='Functional System', frameon=False, fontsize=10)
# --- 7. Save and Display the Figure ---
# Ensure the directory for the figure exists
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")
# Display the plot
plt.show()
##
# %% Confusion Matrix functional systems
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np
import matplotlib.colors as mcolors

# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
# --- 1. Load the Dataset ---
df = pd.read_csv(data_path, sep='\t')
# --- 2. Define Functional Systems and Colors ---
# Paired (ground-truth column, LLM-result column); note the underscore vs.
# space difference in the two header conventions.
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]
# System display names and their fixed colors (order matches the list above)
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
color_map = dict(zip(system_names, colors))
# --- 3. Categorization Function ---
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
category_to_index = {cat: i for i, cat in enumerate(categories)}
n_categories = len(categories)

def categorize_edss(value):
    """Map a numeric EDSS score onto one of the category labels above.

    Returns np.nan for missing input. Scores are clamped into [0, 10];
    the small offset sends exact integers into the LOWER bin
    (e.g. 2.0 -> '1-2', 1.5 -> '1-2').
    """
    if pd.isna(value):
        return np.nan
    score = float(value)
    if score <= 0:
        bucket = 0
    else:
        bucket = int(min(max(score, 0), 10) - 0.001)
    return categories[min(bucket, len(categories) - 1)]
# --- 4. Prepare Mixed Color Matrix with Saturation ---
# cell_system_counts[gt_bin, pred_bin, system] = number of visits from that
# system falling into that (ground-truth, prediction) cell.
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
    # Fix: Ensure numeric conversion to avoid string comparison errors
    temp_df = df[[gt_col, res_col]].copy()
    temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
    temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
    valid_df = temp_df.dropna()
    # Tally each valid (GT, result) pair into its category cell
    for _, row in valid_df.iterrows():
        gt_cat = categorize_edss(row[gt_col])
        res_cat = categorize_edss(row[res_col])
        if gt_cat in category_to_index and res_cat in category_to_index:
            cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
# Create an RGBA image matrix (10x10x4): RGB encodes which systems populate
# the cell (count-weighted mixture of the per-system colors), alpha encodes
# how full the cell is relative to the busiest cell.
rgba_matrix = np.zeros((n_categories, n_categories, 4))
total_counts = np.sum(cell_system_counts, axis=2)
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1  # avoid /0 on empty data
for i in range(n_categories):
    for j in range(n_categories):
        count_sum = total_counts[i, j]
        if count_sum > 0:
            # Count-weighted average of the contributing systems' colors
            mixed_rgb = np.zeros(3)
            for s_idx, s_name in enumerate(system_names):
                weight = cell_system_counts[i, j, s_idx] / count_sum
                system_rgb = mcolors.to_rgb(color_map[s_name])
                mixed_rgb += np.array(system_rgb) * weight
            # Set RGB channels
            rgba_matrix[i, j, :3] = mixed_rgb
            # Set Alpha channel (Saturation Effect)
            # Using a square root scale to make lower counts more visible but still "lighter"
            alpha = np.sqrt(count_sum / max_count)
            # Ensure alpha is at least 0.1 so it's not invisible
            rgba_matrix[i, j, 3] = max(0.1, alpha)
        else:
            # Empty cells are white (fully transparent, so the white
            # background shows through)
            rgba_matrix[i, j] = [1, 1, 1, 0]
# --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))
# Show the matrix
# Note: we use origin='lower' if you want 0-1 at the bottom,
# but confusion matrices usually have 0-1 at the top (origin='upper')
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')
# Add count labels to every non-empty cell
for i in range(n_categories):
    for j in range(n_categories):
        if total_counts[i, j] > 0:
            # Background brightness for text contrast
            # (relative luminance, Rec. 709 coefficients)
            bg_color = rgba_matrix[i, j, :3]
            lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2]
            # If alpha is low, background is effectively white, so use black text
            text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black"
            ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
                    color=text_col, fontsize=10, fontweight='bold')
# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)
# One tick per category bin, labeled with the bin name
ax.set_xticks(np.arange(n_categories))
ax.set_xticklabels(categories)
ax.set_yticks(np.arange(n_categories))
ax.set_yticklabels(categories)
# Remove the frame/spines for a cleaner look
for spine in ax.spines.values():
    spine.set_visible(False)
# Custom Legend: one proxy rectangle per functional system
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
          bbox_to_anchor=(1.05, 1), frameon=False)
plt.tight_layout()
# Save as SVG next to the other figures, then display
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
##
# %% Difference Plot Functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np

# --- Configuration ---
# NOTE: this cell repeats the configuration/loading of the scatter-plot cell
# above so it can be run standalone in a notebook-style session.
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'
# Define the path to your data file
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# Define the path to save the color mapping JSON file
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Define the path to save the final figure
# NOTE(review): same path as the scatter cell — running both overwrites the
# scatter figure; confirm whether this plot should get its own filename.
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'
# --- 1. Load the Dataset ---
try:
    # Load the TSV file
    df = pd.read_csv(data_path, sep='\t')
    print(f"Successfully loaded data from {data_path}")
    print(f"Data shape: {df.shape}")
except FileNotFoundError:
    print(f"Error: The file at {data_path} was not found.")
    # Exit or handle the error appropriately
    raise
# --- 2. Define Functional Systems and Create Color Mapping ---
# List of tuples containing (ground_truth_column, result_column).
# GT columns use underscores, result columns use spaces (TSV header naming).
functional_systems_to_plot = [
    ('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
    ('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
    ('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
    ('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
    ('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
    ('GT.AMBULATION', 'result.AMBULATION'),
    ('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
    ('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]
# Extract system names for color mapping and legend
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
# Define a professional color palette (dark blue theme)
# This is a qualitative palette with distinct, accessible colors
colors = [
    '#003366',  # Dark Blue
    '#336699',  # Medium Blue
    '#6699CC',  # Light Blue
    '#99CCFF',  # Very Light Blue
    '#FF9966',  # Coral
    '#FF6666',  # Light Red
    '#CC6699',  # Magenta
    '#9966CC'   # Purple
]
# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))
# Ensure the directory for the JSON file exists
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
# Save the color map to a JSON file (shared with the other cells)
with open(color_json_path, 'w') as f:
    json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
# NOTE: `agreement_percentages` / `legend_labels` are computed here but not
# referenced by the diverging bar plot later in this cell.
agreement_percentages = {}
legend_labels = {}
for gt_col, res_col in functional_systems_to_plot:
    system_name = gt_col.split('.')[1]
    # Convert columns to numeric, setting errors to NaN
    gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
    res_numeric = pd.to_numeric(df[res_col], errors='coerce')
    # Ensure we are comparing the same rows (both sides must be non-NaN)
    common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
    gt_data = gt_numeric.loc[common_index]
    res_data = res_numeric.loc[common_index]
    # Calculate agreement percentage (exact score equality)
    if len(gt_data) > 0:
        agreement = (gt_data == res_data).mean() * 100
    else:
        agreement = 0  # Handle case with no valid data
    agreement_percentages[system_name] = agreement
    # Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
    formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
    legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
# --- 4. Robustly Prepare Error Data for Boxplot ---
def safe_parse(s):
    """Convert *s* to float, accepting comma decimals (e.g. '3,5' -> 3.5).

    Returns np.nan for missing (None/NaN) or unparseable values; numeric
    inputs are passed through as floats.
    """
    if pd.isna(s):
        return np.nan
    if isinstance(s, (int, float)):
        return float(s)
    # Normalise the decimal separator and surrounding whitespace, then parse.
    try:
        return float(str(s).replace(',', '.').strip())
    except ValueError:
        return np.nan
# Build one long-format frame of signed errors (LLM inference minus ground truth).
plot_data = []
for gt_col, res_col in functional_systems_to_plot:
    sys_key = gt_col.split('.')[1]
    # Parse both columns with robust comma/blank handling
    parsed_gt = df[gt_col].apply(safe_parse)
    parsed_res = df[res_col].apply(safe_parse)
    # Signed error; NaN wherever either side failed to parse
    signed_err = parsed_res - parsed_gt
    frame = pd.DataFrame({'system': sys_key, 'error': signed_err}).dropna()
    plot_data.append(frame)
plot_df = pd.concat(plot_data, ignore_index=True)
if plot_df.empty:
    print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")
else:
    print(f"✅ Prepared error data with {len(plot_df)} data points.")
    # Diagnostic: per-system summary of the signed errors
    print("\n📌 Sample errors by system:")
    for sys, grp in plot_df.groupby('system'):
        print(f" {sys:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}")
# Fix the display order of the systems via an ordered categorical.
plot_df['system'] = pd.Categorical(
    plot_df['system'],
    categories=[name.split('.')[1] for name, _ in functional_systems_to_plot],
    ordered=True
)
# --- 5. Prepare Data for Diverging Stacked Bar Plot ---
print("\n📊 Preparing diverging stacked bar plot data...")

def categorize_error(err):
    """Classify a signed error into 'underestimate' (< 0), 'overestimate'
    (> 0), 'match' (== 0) or 'missing' (NaN/None)."""
    if pd.isna(err):
        return 'missing'
    if err < 0:
        return 'underestimate'
    if err > 0:
        return 'overestimate'
    return 'match'
# Add category column (only on finite errors)
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)
# Count by system + category, guaranteeing all three category columns exist
category_counts = (
    plot_df_clean
    .groupby(['system', 'category'])
    .size()
    .unstack(fill_value=0)
    .reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0)
)
# Reorder systems into their canonical display order
category_counts = category_counts.reindex(system_names)
underestimate_counts = category_counts['underestimate']
match_counts = category_counts['match']
overestimate_counts = category_counts['overestimate']
# Aliases used by the bar drawing and by the annotation code further below
left_counts = underestimate_counts
right_counts = overestimate_counts
# Symmetric x-range around zero based on the largest directional count
max_bar = max(left_counts.max(), right_counts.max(), 1)
plot_range = (-max_bar, max_bar)
n_systems = len(system_names)
positions = np.arange(n_systems)
# BUGFIX: all three bar groups must share the y coordinates of the y-tick
# labels (set at `positions` below). The previous offsets
# (-positions - 0.5 and positions + 0.5) drew the bars away from their
# system labels entirely.
left_positions = positions
right_positions = positions
# --- 6. Create Diverging Bar Plot ---
plt.figure(figsize=(12, 7))
# Colors: diverging palette
colors = {
    'underestimate': '#E74C3C',  # Red (left)
    'match': '#2ECC71',          # Green (center)
    'overestimate': '#F39C12'    # Orange (right)
}
# BUGFIX: underestimates must extend to the LEFT of zero, so pass NEGATIVE
# widths. The previous code passed positive widths with left=0, drawing
# them to the right while the count labels were placed at negative x.
bars_left = plt.barh(
    left_positions,
    -left_counts.values,
    height=0.8,
    left=0,
    color=colors['underestimate'],
    edgecolor='black',
    linewidth=0.5,
    alpha=0.9,
    label='Underestimate'
)
# Overestimates extend to the right of zero
bars_right = plt.barh(
    right_positions,
    right_counts.values,
    height=0.8,
    left=0,
    color=colors['overestimate'],
    edgecolor='black',
    linewidth=0.5,
    alpha=0.9,
    label='Overestimate'
)
# Exact matches: a single narrow bar per system, centered symmetrically
# around zero. (BUGFIX: the previous code drew the matches TWICE — once
# extending right from 0 and once centered — stacking duplicate bars.)
for i, count in enumerate(match_counts.values):
    if count > 0:
        plt.barh(
            positions[i],
            width=count,
            left=-count / 2,
            height=0.25,
            color=colors['match'],
            edgecolor='black',
            linewidth=0.5,
            alpha=0.95,
            label='Exact Match' if i == 0 else None
        )
# --- 7. Styling & Labels ---
# Zero reference line separating under- from over-estimation
plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8)
# X-axis: symmetric around 0, with a 10% margin
plt.xlim(plot_range[0] - max_bar * 0.1, plot_range[1] + max_bar * 0.1)
plt.xticks(rotation=0, fontsize=10)
plt.xlabel('Count', fontsize=12)
# Y-axis: one multi-line label per system.
# BUGFIX: the system names are upper-case, so .replace('and', '&') never
# matched anything — replace the upper-case token 'AND' instead.
plt.yticks(positions, [name.replace('_', '\n').replace('AND', '&') for name in system_names], fontsize=10)
plt.ylabel('Functional System', fontsize=12)
# Title & layout
plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15)
# Clean axes: only the bottom spine is kept
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)  # We only need bottom axis
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('none')
# Grid only along x
ax.xaxis.grid(True, linestyle=':', alpha=0.5)
# Legend built from proxy patches (overrides any bar labels)
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'),
    Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'),
    Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate')
]
plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)
# Count annotations next to each bar.
# BUGFIX: these labels sit OUTSIDE the bars on the white figure background,
# so color='white' rendered them invisible — use black.
# BUGFIX: use the loop variable `match` rather than match_counts[i];
# integer [] indexing on a label-indexed Series is deprecated positional
# indexing in pandas 2.x and slated for removal.
for i, (left, right, match) in enumerate(zip(left_counts, right_counts, match_counts)):
    if left > 0:
        plt.text(-left - max_bar * 0.05, left_positions[i], str(left), va='center', ha='right', fontsize=9, color='black', fontweight='bold')
    if right > 0:
        plt.text(right + max_bar * 0.05, right_positions[i], str(right), va='center', ha='left', fontsize=9, color='black', fontweight='bold')
    if match > 0:
        plt.text(match / 2, positions[i], str(match), va='center', ha='center', fontsize=8, color='black')
plt.tight_layout()
# --- 8. Save & Show ---
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"✅ Diverging bar plot saved to {figure_save_path}")
plt.show()
##
# %% Difference Gemini easy
# --- 1. Process Error Data ---
# For each functional system, count exact matches, under- and over-scored
# rows (LLM result minus ground truth) and the corresponding percentages.
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
records = []
for gt_col, res_col in functional_systems_to_plot:
    label = gt_col.split('.')[1].replace('_', ' ').title()
    # Robust parsing; NaN wherever either side is unparseable
    diff = df[res_col].apply(safe_parse) - df[gt_col].apply(safe_parse)
    n_match = (diff == 0).sum()
    n_under = (diff < 0).sum()
    n_over = (diff > 0).sum()
    n_valid = diff.dropna().count()
    # Guard against division by zero when a system has no valid pairs
    denom = max(n_valid, 1)
    records.append({
        'System': label,
        'Matches': n_match,
        'MatchPct': (n_match / denom) * 100,
        'Under': n_under,
        'UnderPct': (n_under / denom) * 100,
        'Over': n_over,
        'OverPct': (n_over / denom) * 100,
    })
stats_df = pd.DataFrame(records)
# --- 2. Plotting ---
fig, ax = plt.subplots(figsize=(12, 8))  # Slightly taller for multi-line labels
color_under = '#E74C3C'
color_over = '#3498DB'
bar_height = 0.6
y_pos = np.arange(len(stats_df))
# Under-scored counts extend left (negative), over-scored extend right
ax.barh(y_pos, -stats_df['Under'], bar_height, label='Under-scored', color=color_under, edgecolor='white', alpha=0.8)
ax.barh(y_pos, stats_df['Over'], bar_height, label='Over-scored', color=color_over, edgecolor='white', alpha=0.8)
# --- 3. Aesthetics & Labels ---
for i, row in stats_df.iterrows():
    # Detailed per-system label on the left: bold system name (mathtext),
    # then match / under / over counts with percentages.
    # FIX: raw string for the mathtext part — '\m' in "$\mathbf" is an
    # invalid escape sequence (SyntaxWarning on Python 3.12+).
    label_text = (
        rf"$\mathbf{{{row['System']}}}$" "\n"
        f"Matches: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
        f"Under: {int(row['Under'])} ({row['UnderPct']:.1f}%) | Over: {int(row['Over'])} ({row['OverPct']:.1f}%)"
    )
    # Position text to the left of the x=0 line
    # NOTE(review): uses the autoscaled xlim at this point — confirm limits
    # are final before this loop if the layout shifts.
    ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.3)
# Zero line
ax.axvline(0, color='black', linewidth=1.2, alpha=0.7)
# Clean up axes
ax.set_yticks([])
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold', labelpad=10)
#ax.set_title('Directional Error Analysis by Functional System', fontsize=14, pad=30)
# Make X-axis labels absolute (left side shows magnitudes, not negatives).
# FIX: pin the tick locations with set_xticks first — calling
# set_xticklabels on auto ticks emits a warning and desynchronizes labels
# if the limits change later.
tick_locs = ax.get_xticks()
ax.set_xticks(tick_locs)
ax.set_xticklabels([int(abs(tick)) for tick in tick_locs])
# Remove spines
for spine in ['top', 'right', 'left']:
    ax.spines[spine].set_visible(False)
# Legend
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1))
# Grid
ax.xaxis.grid(True, linestyle='--', alpha=0.3)
plt.tight_layout()
plt.show()
##
# %% test
# Diagnose: what are the actual differences?
# Per system: how many non-zero differences exist, how many of them are
# mere floating-point noise (< 1e-10), and the largest absolute difference.
print("\n🔍 Raw differences (first 5 rows per system):")
for gt_col, res_col in functional_systems_to_plot:
    delta = df[res_col].apply(safe_parse) - df[gt_col].apply(safe_parse)
    magnitude = delta.abs()
    # NaN != 0 is True, so rows with missing values count as non-zero here
    n_nonzero = (delta != 0).sum()
    n_tiny = ((magnitude > 0) & (magnitude < 1e-10)).sum()
    print(f"{gt_col.split('.')[1]:25s}: non-zero = {n_nonzero:3d}, tiny = {n_tiny:3d}, max abs diff = {magnitude.max():.12f}")