show directional errors
Directional Errors of each functional system.
This commit is contained in:
@@ -1397,7 +1397,448 @@ os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||
plt.show()
|
||||
|
||||
#
|
||||
##
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# %% Difference Plot Functional system
|
||||
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
# --- Configuration ---
|
||||
# Set the font to Arial for all text in the plot, as per the guidelines
|
||||
plt.rcParams['font.family'] = 'Arial'
|
||||
|
||||
# Define the path to your data file
|
||||
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||
|
||||
# Define the path to save the color mapping JSON file
|
||||
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
||||
|
||||
# Define the path to save the final figure
|
||||
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'
|
||||
|
||||
# --- 1. Load the Dataset ---
|
||||
try:
|
||||
# Load the TSV file
|
||||
df = pd.read_csv(data_path, sep='\t')
|
||||
print(f"Successfully loaded data from {data_path}")
|
||||
print(f"Data shape: {df.shape}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file at {data_path} was not found.")
|
||||
# Exit or handle the error appropriately
|
||||
raise
|
||||
|
||||
# --- 2. Define Functional Systems and Create Color Mapping ---
|
||||
# List of tuples containing (ground_truth_column, result_column)
|
||||
functional_systems_to_plot = [
|
||||
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||
('GT.AMBULATION', 'result.AMBULATION'),
|
||||
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
|
||||
]
|
||||
|
||||
# Extract system names for color mapping and legend
|
||||
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||
|
||||
# Define a professional color palette (dark blue theme)
|
||||
# This is a qualitative palette with distinct, accessible colors
|
||||
colors = [
|
||||
'#003366', # Dark Blue
|
||||
'#336699', # Medium Blue
|
||||
'#6699CC', # Light Blue
|
||||
'#99CCFF', # Very Light Blue
|
||||
'#FF9966', # Coral
|
||||
'#FF6666', # Light Red
|
||||
'#CC6699', # Magenta
|
||||
'#9966CC' # Purple
|
||||
]
|
||||
|
||||
# Create a dictionary mapping system names to colors
|
||||
color_map = dict(zip(system_names, colors))
|
||||
|
||||
# Ensure the directory for the JSON file exists
|
||||
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
|
||||
|
||||
# Save the color map to a JSON file
|
||||
with open(color_json_path, 'w') as f:
|
||||
json.dump(color_map, f, indent=4)
|
||||
|
||||
print(f"Color mapping saved to {color_json_path}")
|
||||
|
||||
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
|
||||
agreement_percentages = {}
|
||||
legend_labels = {}
|
||||
|
||||
for gt_col, res_col in functional_systems_to_plot:
|
||||
system_name = gt_col.split('.')[1]
|
||||
|
||||
# Convert columns to numeric, setting errors to NaN
|
||||
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
|
||||
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
|
||||
|
||||
# Ensure we are comparing the same rows
|
||||
common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
|
||||
gt_data = gt_numeric.loc[common_index]
|
||||
res_data = res_numeric.loc[common_index]
|
||||
|
||||
# Calculate agreement percentage
|
||||
if len(gt_data) > 0:
|
||||
agreement = (gt_data == res_data).mean() * 100
|
||||
else:
|
||||
agreement = 0 # Handle case with no valid data
|
||||
|
||||
agreement_percentages[system_name] = agreement
|
||||
|
||||
# Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
|
||||
formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
|
||||
legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
|
||||
|
||||
|
||||
# --- 4. Robustly Prepare Error Data for Boxplot ---
|
||||
|
||||
def safe_parse(s):
|
||||
'''Convert to float, handling comma decimals (e.g., '3,5' → 3.5)'''
|
||||
if pd.isna(s):
|
||||
return np.nan
|
||||
if isinstance(s, (int, float)):
|
||||
return float(s)
|
||||
# Replace comma with dot, then strip whitespace
|
||||
s_clean = str(s).replace(',', '.').strip()
|
||||
try:
|
||||
return float(s_clean)
|
||||
except ValueError:
|
||||
return np.nan
|
||||
|
||||
plot_data = []
|
||||
for gt_col, res_col in functional_systems_to_plot:
|
||||
system_name = gt_col.split('.')[1]
|
||||
|
||||
# Parse both columns with robust comma handling
|
||||
gt_numeric = df[gt_col].apply(safe_parse)
|
||||
res_numeric = df[res_col].apply(safe_parse)
|
||||
|
||||
# Compute error (only where both are finite)
|
||||
error = res_numeric - gt_numeric
|
||||
|
||||
# Create temp DataFrame
|
||||
temp_df = pd.DataFrame({
|
||||
'system': system_name,
|
||||
'error': error
|
||||
}).dropna() # drop rows where either was unparseable
|
||||
|
||||
plot_data.append(temp_df)
|
||||
|
||||
plot_df = pd.concat(plot_data, ignore_index=True)
|
||||
|
||||
if plot_df.empty:
|
||||
print("⚠️ Warning: No valid numeric error data to plot after robust parsing.")
|
||||
else:
|
||||
print(f"✅ Prepared error data with {len(plot_df)} data points.")
|
||||
# Diagnostic: show a few samples
|
||||
print("\n📌 Sample errors by system:")
|
||||
for sys, grp in plot_df.groupby('system'):
|
||||
print(f" {sys:25s}: n={len(grp)}, mean err = {grp['error'].mean():+.2f}, min = {grp['error'].min():+.2f}, max = {grp['error'].max():+.2f}")
|
||||
|
||||
# Ensure categorical ordering
|
||||
plot_df['system'] = pd.Categorical(
|
||||
plot_df['system'],
|
||||
categories=[name.split('.')[1] for name, _ in functional_systems_to_plot],
|
||||
ordered=True
|
||||
)
|
||||
|
||||
|
||||
# --- 5. Prepare Data for Diverging Stacked Bar Plot ---
|
||||
print("\n📊 Preparing diverging stacked bar plot data...")
|
||||
|
||||
# Define bins for error direction
|
||||
def categorize_error(err):
|
||||
if pd.isna(err):
|
||||
return 'missing'
|
||||
elif err < 0:
|
||||
return 'underestimate'
|
||||
elif err > 0:
|
||||
return 'overestimate'
|
||||
else:
|
||||
return 'match'
|
||||
|
||||
# Add category column (only on finite errors)
|
||||
plot_df_clean = plot_df[plot_df['error'].notna()].copy()
|
||||
plot_df_clean['category'] = plot_df_clean['error'].apply(categorize_error)
|
||||
|
||||
# Count by system + category
|
||||
category_counts = (
|
||||
plot_df_clean
|
||||
.groupby(['system', 'category'])
|
||||
.size()
|
||||
.unstack(fill_value=0)
|
||||
.reindex(columns=['underestimate', 'match', 'overestimate'], fill_value=0)
|
||||
)
|
||||
# Reorder systems
|
||||
category_counts = category_counts.reindex(system_names)
|
||||
|
||||
# Prepare for diverging plot:
|
||||
# - Underestimates: plotted to the *left* (negative x)
|
||||
# - Overestimates: plotted to the *right* (positive x)
|
||||
# - Matches: centered (no width needed, or as a bar of width 0.2)
|
||||
underestimate_counts = category_counts['underestimate']
|
||||
match_counts = category_counts['match']
|
||||
overestimate_counts = category_counts['overestimate']
|
||||
|
||||
# For diverging: left = -underestimate, right = overestimate
|
||||
left_counts = underestimate_counts
|
||||
right_counts = overestimate_counts
|
||||
|
||||
# Compute max absolute bar height (for symmetric x-axis)
|
||||
max_bar = max(left_counts.max(), right_counts.max(), 1)
|
||||
plot_range = (-max_bar, max_bar)
|
||||
|
||||
# X-axis positions: 0 = center, left systems to -1, -2, ..., right systems to +1, +2, ...
|
||||
n_systems = len(system_names)
|
||||
positions = np.arange(n_systems)
|
||||
left_positions = -positions - 0.5 # left-aligned underestimates
|
||||
right_positions = positions + 0.5 # right-aligned overestimates
|
||||
|
||||
# --- 6. Create Diverging Stacked Bar Plot ---
|
||||
plt.figure(figsize=(12, 7))
|
||||
|
||||
# Colors: diverging palette
|
||||
colors = {
|
||||
'underestimate': '#E74C3C', # Red (left)
|
||||
'match': '#2ECC71', # Green (center)
|
||||
'overestimate': '#F39C12' # Orange (right)
|
||||
}
|
||||
|
||||
# Plot underestimates (left side)
|
||||
bars_left = plt.barh(
|
||||
left_positions,
|
||||
left_counts.values,
|
||||
height=0.8,
|
||||
left=0, # starts at 0, extends left (since bars are negative width would be wrong; instead use negative values)
|
||||
color=colors['underestimate'],
|
||||
edgecolor='black',
|
||||
linewidth=0.5,
|
||||
alpha=0.9,
|
||||
label='Underestimate'
|
||||
)
|
||||
|
||||
# Plot overestimates (right side)
|
||||
bars_right = plt.barh(
|
||||
right_positions,
|
||||
right_counts.values,
|
||||
height=0.8,
|
||||
left=0,
|
||||
color=colors['overestimate'],
|
||||
edgecolor='black',
|
||||
linewidth=0.5,
|
||||
alpha=0.9,
|
||||
label='Overestimate'
|
||||
)
|
||||
|
||||
# Plot matches (center — narrow bar)
|
||||
# Use a very narrow width (0.2) centered at 0
|
||||
plt.barh(
|
||||
positions,
|
||||
match_counts.values,
|
||||
height=0.2,
|
||||
left=0, # starts at 0, extends right
|
||||
color=colors['match'],
|
||||
edgecolor='black',
|
||||
linewidth=0.5,
|
||||
alpha=0.9,
|
||||
label='Exact Match'
|
||||
)
|
||||
|
||||
# ✨ Better: flip match to be centered symmetrically (left=-match/2, width=match)
|
||||
# For perfect symmetry:
|
||||
for i, count in enumerate(match_counts.values):
|
||||
if count > 0:
|
||||
plt.barh(
|
||||
positions[i],
|
||||
width=count,
|
||||
left=-count/2,
|
||||
height=0.25,
|
||||
color=colors['match'],
|
||||
edgecolor='black',
|
||||
linewidth=0.5,
|
||||
alpha=0.95
|
||||
)
|
||||
|
||||
# --- 7. Styling & Labels ---
|
||||
# Zero reference line
|
||||
plt.axvline(x=0, color='black', linestyle='-', linewidth=1.2, alpha=0.8)
|
||||
|
||||
# X-axis: symmetric around 0
|
||||
plt.xlim(plot_range[0] - max_bar*0.1, plot_range[1] + max_bar*0.1)
|
||||
plt.xticks(rotation=0, fontsize=10)
|
||||
plt.xlabel('Count', fontsize=12)
|
||||
|
||||
# Y-axis: system names at original positions (centered)
|
||||
plt.yticks(positions, [name.replace('_', '\n').replace('and', '&') for name in system_names], fontsize=10)
|
||||
plt.ylabel('Functional System', fontsize=12)
|
||||
|
||||
# Title & layout
|
||||
plt.title('Diverging Error Direction by Functional System\n(Red: Underestimation | Green: Exact | Orange: Overestimation)', fontsize=13, pad=15)
|
||||
|
||||
# Clean axes
|
||||
ax = plt.gca()
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.spines['left'].set_visible(False) # We only need bottom axis
|
||||
ax.xaxis.set_ticks_position('bottom')
|
||||
ax.yaxis.set_ticks_position('none')
|
||||
|
||||
# Grid only along x
|
||||
ax.xaxis.grid(True, linestyle=':', alpha=0.5)
|
||||
|
||||
# Legend
|
||||
from matplotlib.patches import Patch
|
||||
legend_elements = [
|
||||
Patch(facecolor=colors['underestimate'], edgecolor='black', label='Underestimate'),
|
||||
Patch(facecolor=colors['match'], edgecolor='black', label='Exact Match'),
|
||||
Patch(facecolor=colors['overestimate'], edgecolor='black', label='Overestimate')
|
||||
]
|
||||
plt.legend(handles=legend_elements, loc='upper right', frameon=False, fontsize=10)
|
||||
|
||||
# Optional: Add counts on bars
|
||||
for i, (left, right, match) in enumerate(zip(left_counts, right_counts, match_counts)):
|
||||
if left > 0:
|
||||
plt.text(-left - max_bar*0.05, left_positions[i], str(left), va='center', ha='right', fontsize=9, color='white', fontweight='bold')
|
||||
if right > 0:
|
||||
plt.text(right + max_bar*0.05, right_positions[i], str(right), va='center', ha='left', fontsize=9, color='white', fontweight='bold')
|
||||
if match > 0:
|
||||
plt.text(match_counts[i]/2, positions[i], str(match), va='center', ha='center', fontsize=8, color='black')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# --- 8. Save & Show ---
|
||||
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||
print(f"✅ Diverging bar plot saved to {figure_save_path}")
|
||||
|
||||
plt.show()
|
||||
|
||||
##
|
||||
|
||||
|
||||
|
||||
# %% Difference Gemini easy
|
||||
|
||||
|
||||
# --- 1. Process Error Data ---
|
||||
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||
plot_list = []
|
||||
|
||||
for gt_col, res_col in functional_systems_to_plot:
|
||||
sys_name = gt_col.split('.')[1]
|
||||
|
||||
# Robust parsing
|
||||
gt = df[gt_col].apply(safe_parse)
|
||||
res = df[res_col].apply(safe_parse)
|
||||
error = res - gt
|
||||
|
||||
# Calculate counts
|
||||
matches = (error == 0).sum()
|
||||
under = (error < 0).sum()
|
||||
over = (error > 0).sum()
|
||||
total = error.dropna().count()
|
||||
|
||||
# Calculate Percentages
|
||||
# Using max(total, 1) to avoid division by zero
|
||||
divisor = max(total, 1)
|
||||
match_pct = (matches / divisor) * 100
|
||||
under_pct = (under / divisor) * 100
|
||||
over_pct = (over / divisor) * 100
|
||||
|
||||
plot_list.append({
|
||||
'System': sys_name.replace('_', ' ').title(),
|
||||
'Matches': matches,
|
||||
'MatchPct': match_pct,
|
||||
'Under': under,
|
||||
'UnderPct': under_pct,
|
||||
'Over': over,
|
||||
'OverPct': over_pct
|
||||
})
|
||||
|
||||
stats_df = pd.DataFrame(plot_list)
|
||||
|
||||
# --- 2. Plotting ---
|
||||
fig, ax = plt.subplots(figsize=(12, 8)) # Slightly taller for multi-line labels
|
||||
|
||||
color_under = '#E74C3C'
|
||||
color_over = '#3498DB'
|
||||
bar_height = 0.6
|
||||
|
||||
y_pos = np.arange(len(stats_df))
|
||||
|
||||
ax.barh(y_pos, -stats_df['Under'], bar_height, label='Under-scored', color=color_under, edgecolor='white', alpha=0.8)
|
||||
ax.barh(y_pos, stats_df['Over'], bar_height, label='Over-scored', color=color_over, edgecolor='white', alpha=0.8)
|
||||
|
||||
# --- 3. Aesthetics & Labels ---
|
||||
|
||||
for i, row in stats_df.iterrows():
|
||||
# Constructing a detailed label for the left side
|
||||
# Matches (Bold) | Under % | Over %
|
||||
label_text = (
|
||||
f"$\mathbf{{{row['System']}}}$\n"
|
||||
f"Matches: {int(row['Matches'])} ({row['MatchPct']:.1f}%)\n"
|
||||
f"Under: {int(row['Under'])} ({row['UnderPct']:.1f}%) | Over: {int(row['Over'])} ({row['OverPct']:.1f}%)"
|
||||
)
|
||||
|
||||
# Position text to the left of the x=0 line
|
||||
ax.text(ax.get_xlim()[0] - 0.5, i, label_text, va='center', ha='right', fontsize=9, color='#333333', linespacing=1.3)
|
||||
|
||||
# Zero line
|
||||
ax.axvline(0, color='black', linewidth=1.2, alpha=0.7)
|
||||
|
||||
# Clean up axes
|
||||
ax.set_yticks([])
|
||||
ax.set_xlabel('Number of Patients with Error', fontsize=11, fontweight='bold', labelpad=10)
|
||||
#ax.set_title('Directional Error Analysis by Functional System', fontsize=14, pad=30)
|
||||
|
||||
# Make X-axis labels absolute
|
||||
ax.set_xticklabels([int(abs(tick)) for tick in ax.get_xticks()])
|
||||
|
||||
# Remove spines
|
||||
for spine in ['top', 'right', 'left']:
|
||||
ax.spines[spine].set_visible(False)
|
||||
|
||||
# Legend
|
||||
ax.legend(loc='upper right', frameon=False, bbox_to_anchor=(1, 1.1))
|
||||
|
||||
# Grid
|
||||
ax.xaxis.grid(True, linestyle='--', alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
##
|
||||
|
||||
# %% test
|
||||
# Diagnose: what are the actual differences?
|
||||
print("\n🔍 Raw differences (first 5 rows per system):")
|
||||
for gt_col, res_col in functional_systems_to_plot:
|
||||
gt = df[gt_col].apply(safe_parse)
|
||||
res = df[res_col].apply(safe_parse)
|
||||
diff = res - gt
|
||||
non_zero = (diff != 0).sum()
|
||||
# Check if it's due to floating point noise
|
||||
abs_diff = diff.abs()
|
||||
tiny = (abs_diff > 0) & (abs_diff < 1e-10)
|
||||
print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")
|
||||
|
||||
##
|
||||
|
||||
Reference in New Issue
Block a user