isabella box and Error disagreement plot
This commit is contained in:
+803
-3
@@ -2152,8 +2152,6 @@ plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
|||||||
plt.show()
|
plt.show()
|
||||||
##
|
##
|
||||||
|
|
||||||
<<<<<<< Updated upstream
|
|
||||||
=======
|
|
||||||
# %% Functional System + EDSS Error Boxplots
|
# %% Functional System + EDSS Error Boxplots
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -2302,7 +2300,6 @@ plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
|||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
##
|
##
|
||||||
>>>>>>> Stashed changes
|
|
||||||
|
|
||||||
# %% test
|
# %% test
|
||||||
# Diagnose: what are the actual differences?
|
# Diagnose: what are the actual differences?
|
||||||
@@ -2318,3 +2315,806 @@ for gt_col, res_col in functional_systems_to_plot:
|
|||||||
print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")
|
print(f"{gt_col.split('.')[1]:25s}: non-zero = {non_zero:3d}, tiny = {tiny.sum():3d}, max abs diff = {abs_diff.max():.12f}")
|
||||||
|
|
||||||
##
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Functional System Continuous Accuracy Boxplot
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib.patches import Patch
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
|
||||||
|
figure_save_path = 'project/visuals/functional_systems_continuous_accuracy_boxplot.svg'
|
||||||
|
|
||||||
|
# --- Functional systems using your actual column names ---
|
||||||
|
functional_systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- Robust parser ---
|
||||||
|
def safe_parse(s):
|
||||||
|
"""Convert to float, handling comma decimals like '3,5'."""
|
||||||
|
if pd.isna(s):
|
||||||
|
return np.nan
|
||||||
|
if isinstance(s, (int, float, np.integer, np.floating)):
|
||||||
|
return float(s)
|
||||||
|
|
||||||
|
s_clean = str(s).replace(',', '.').strip()
|
||||||
|
|
||||||
|
if s_clean == "":
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
try:
|
||||||
|
return float(s_clean)
|
||||||
|
except ValueError:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
# --- Build accuracy data ---
|
||||||
|
boxplot_data = []
|
||||||
|
system_labels = []
|
||||||
|
predicted_counts = []
|
||||||
|
missing_prediction_counts = []
|
||||||
|
total_gt_counts = []
|
||||||
|
mean_accuracies = []
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
|
||||||
|
if gt_col not in df.columns:
|
||||||
|
print(f"Skipping {gt_col}: GT column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if res_col not in df.columns:
|
||||||
|
print(f"Skipping {res_col}: result column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
|
||||||
|
# Only rows where ground truth exists
|
||||||
|
gt_exists = gt.notna()
|
||||||
|
|
||||||
|
total_gt = gt_exists.sum()
|
||||||
|
|
||||||
|
if total_gt == 0:
|
||||||
|
print(f"Skipping {system_name}: no ground-truth values")
|
||||||
|
continue
|
||||||
|
|
||||||
|
gt_valid = gt[gt_exists]
|
||||||
|
res_valid = res[gt_exists]
|
||||||
|
|
||||||
|
# GT exists, but LLM prediction is missing
|
||||||
|
missing_count = res_valid.isna().sum()
|
||||||
|
|
||||||
|
# For the boxplot, use rows where both GT and result exist
|
||||||
|
both_exist = res_valid.notna()
|
||||||
|
|
||||||
|
if both_exist.sum() == 0:
|
||||||
|
print(f"Skipping {system_name}: no predicted values")
|
||||||
|
continue
|
||||||
|
|
||||||
|
gt_eval = gt_valid[both_exist]
|
||||||
|
res_eval = res_valid[both_exist]
|
||||||
|
|
||||||
|
# Functional system score range.
|
||||||
|
# Adjust if your functional systems use another scale.
|
||||||
|
score_range = 5
|
||||||
|
|
||||||
|
# Continuous accuracy:
|
||||||
|
# exact match = 1.0
|
||||||
|
# off by 1 point = 0.8
|
||||||
|
# off by 2 points = 0.6
|
||||||
|
# etc.
|
||||||
|
abs_error = (res_eval - gt_eval).abs()
|
||||||
|
accuracy = 1 - (abs_error / score_range)
|
||||||
|
accuracy = accuracy.clip(lower=0, upper=1)
|
||||||
|
|
||||||
|
clean_name = system_name.replace('_', ' ').title()
|
||||||
|
|
||||||
|
boxplot_data.append(accuracy.values)
|
||||||
|
system_labels.append(clean_name)
|
||||||
|
predicted_counts.append(len(gt_eval))
|
||||||
|
missing_prediction_counts.append(missing_count)
|
||||||
|
total_gt_counts.append(total_gt)
|
||||||
|
mean_accuracies.append(accuracy.mean())
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{clean_name}: "
|
||||||
|
f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
|
||||||
|
f"mean accuracy={accuracy.mean():.1%}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not boxplot_data:
|
||||||
|
raise ValueError("No valid accuracy data available for plotting.")
|
||||||
|
|
||||||
|
# X-axis labels
|
||||||
|
xtick_labels = [
|
||||||
|
f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
|
||||||
|
for label, gt_n, pred_n, miss_n
|
||||||
|
in zip(system_labels, total_gt_counts, predicted_counts, missing_prediction_counts)
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- Plot ---
|
||||||
|
fig, ax = plt.subplots(figsize=(16, 8))
|
||||||
|
|
||||||
|
bp = ax.boxplot(
|
||||||
|
boxplot_data,
|
||||||
|
vert=True,
|
||||||
|
patch_artist=True,
|
||||||
|
labels=xtick_labels,
|
||||||
|
showmeans=True,
|
||||||
|
meanline=False,
|
||||||
|
widths=0.55
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Styling ---
|
||||||
|
box_face = '#D6EAF8'
|
||||||
|
box_edge = '#2980B9'
|
||||||
|
whisker_col = '#7F8C8D'
|
||||||
|
median_col = '#C0392B'
|
||||||
|
mean_col = '#1ABC9C'
|
||||||
|
flier_face = '#95A5A6'
|
||||||
|
flier_edge = '#7F8C8D'
|
||||||
|
|
||||||
|
for box in bp['boxes']:
|
||||||
|
box.set(facecolor=box_face, edgecolor=box_edge, linewidth=1.5)
|
||||||
|
|
||||||
|
for whisker in bp['whiskers']:
|
||||||
|
whisker.set(color=whisker_col, linewidth=1.2)
|
||||||
|
|
||||||
|
for cap in bp['caps']:
|
||||||
|
cap.set(color=whisker_col, linewidth=1.2)
|
||||||
|
|
||||||
|
for median in bp['medians']:
|
||||||
|
median.set(color=median_col, linewidth=2)
|
||||||
|
|
||||||
|
for mean in bp['means']:
|
||||||
|
mean.set(
|
||||||
|
marker='o',
|
||||||
|
markerfacecolor=mean_col,
|
||||||
|
markeredgecolor='black',
|
||||||
|
markersize=6
|
||||||
|
)
|
||||||
|
|
||||||
|
for flier in bp['fliers']:
|
||||||
|
flier.set(
|
||||||
|
marker='o',
|
||||||
|
markerfacecolor=flier_face,
|
||||||
|
markeredgecolor=flier_edge,
|
||||||
|
alpha=0.6,
|
||||||
|
markersize=4
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mean accuracy label above each box
|
||||||
|
for i, acc in enumerate(mean_accuracies, start=1):
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
1.03,
|
||||||
|
f"{acc:.1%}",
|
||||||
|
ha='center',
|
||||||
|
va='bottom',
|
||||||
|
fontsize=9,
|
||||||
|
fontweight='bold'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perfect accuracy reference line
|
||||||
|
ax.axhline(1, color='black', linewidth=1.2, linestyle='--', alpha=0.7)
|
||||||
|
|
||||||
|
# Labels and formatting
|
||||||
|
ax.set_xlabel('Functional System', fontsize=11, fontweight='bold')
|
||||||
|
ax.set_ylabel('Continuous Accuracy', fontsize=11, fontweight='bold')
|
||||||
|
|
||||||
|
ax.set_ylim(-0.05, 1.10)
|
||||||
|
ax.set_yticks(np.arange(0, 1.01, 0.1))
|
||||||
|
ax.set_yticklabels([f"{int(y * 100)}%" for y in np.arange(0, 1.01, 0.1)])
|
||||||
|
|
||||||
|
plt.xticks(rotation=45, ha='right')
|
||||||
|
|
||||||
|
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
|
||||||
|
for spine in ['top', 'right']:
|
||||||
|
ax.spines[spine].set_visible(False)
|
||||||
|
|
||||||
|
# Legend
|
||||||
|
legend_handles = [
|
||||||
|
Patch(facecolor=box_face, edgecolor=box_edge, label='IQR of continuous accuracy'),
|
||||||
|
Line2D([0], [0], color=median_col, lw=2, label='Median'),
|
||||||
|
Line2D([0], [0], marker='o', color='w', markerfacecolor=mean_col,
|
||||||
|
markeredgecolor='black', markersize=7, label='Mean'),
|
||||||
|
Line2D([0], [0], marker='o', color='w', markerfacecolor=flier_face,
|
||||||
|
markeredgecolor=flier_edge, alpha=0.8, markersize=6, label='Outlier'),
|
||||||
|
Line2D([0], [0], color='black', lw=1.2, linestyle='--', label='Perfect accuracy')
|
||||||
|
]
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
handles=legend_handles,
|
||||||
|
loc='lower center',
|
||||||
|
bbox_to_anchor=(0.5, 1.06),
|
||||||
|
ncol=5,
|
||||||
|
frameon=False
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.88])
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Functional Systems + EDSS Continuous Accuracy Boxplot
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib.patches import Patch
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
|
||||||
|
figure_save_path = 'project/visuals/functional_systems_edss_continuous_accuracy_boxplot.svg'
|
||||||
|
|
||||||
|
# --- Functional systems + EDSS using your actual column names ---
|
||||||
|
functional_systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
|
||||||
|
|
||||||
|
# EDSS
|
||||||
|
('GT.EDSS', 'result.EDSS')
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- Robust parser ---
|
||||||
|
def safe_parse(s):
|
||||||
|
"""Convert to float, handling comma decimals like '3,5'."""
|
||||||
|
if pd.isna(s):
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
if isinstance(s, (int, float, np.integer, np.floating)):
|
||||||
|
return float(s)
|
||||||
|
|
||||||
|
s_clean = str(s).replace(',', '.').strip()
|
||||||
|
|
||||||
|
if s_clean == "":
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
try:
|
||||||
|
return float(s_clean)
|
||||||
|
except ValueError:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
|
||||||
|
# --- Build accuracy data ---
|
||||||
|
boxplot_data = []
|
||||||
|
system_labels = []
|
||||||
|
predicted_counts = []
|
||||||
|
missing_prediction_counts = []
|
||||||
|
total_gt_counts = []
|
||||||
|
mean_accuracies = []
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
|
||||||
|
if gt_col not in df.columns:
|
||||||
|
print(f"Skipping {gt_col}: GT column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if res_col not in df.columns:
|
||||||
|
print(f"Skipping {res_col}: result column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
|
||||||
|
# Only rows where ground truth exists
|
||||||
|
gt_exists = gt.notna()
|
||||||
|
total_gt = gt_exists.sum()
|
||||||
|
|
||||||
|
if total_gt == 0:
|
||||||
|
print(f"Skipping {system_name}: no ground-truth values")
|
||||||
|
continue
|
||||||
|
|
||||||
|
gt_valid = gt[gt_exists]
|
||||||
|
res_valid = res[gt_exists]
|
||||||
|
|
||||||
|
# Count cases where GT exists but LLM prediction is missing
|
||||||
|
missing_count = res_valid.isna().sum()
|
||||||
|
|
||||||
|
# For the boxplot, use only rows where both GT and prediction exist
|
||||||
|
both_exist = res_valid.notna()
|
||||||
|
|
||||||
|
if both_exist.sum() == 0:
|
||||||
|
print(f"Skipping {system_name}: no predicted values")
|
||||||
|
continue
|
||||||
|
|
||||||
|
gt_eval = gt_valid[both_exist]
|
||||||
|
res_eval = res_valid[both_exist]
|
||||||
|
|
||||||
|
# Functional systems are usually scored 0-5.
|
||||||
|
# EDSS is usually scored 0-10.
|
||||||
|
if system_name == "EDSS":
|
||||||
|
score_range = 10
|
||||||
|
clean_name = "EDSS"
|
||||||
|
else:
|
||||||
|
score_range = 5
|
||||||
|
clean_name = system_name.replace('_', ' ').title()
|
||||||
|
|
||||||
|
# Continuous accuracy:
|
||||||
|
# exact match = 1.0
|
||||||
|
# off by 1 point in FS = 0.8
|
||||||
|
# off by 1 point in EDSS = 0.9
|
||||||
|
abs_error = (res_eval - gt_eval).abs()
|
||||||
|
accuracy = 1 - (abs_error / score_range)
|
||||||
|
|
||||||
|
# Keep values between 0 and 1
|
||||||
|
accuracy = accuracy.clip(lower=0, upper=1)
|
||||||
|
|
||||||
|
boxplot_data.append(accuracy.values)
|
||||||
|
system_labels.append(clean_name)
|
||||||
|
predicted_counts.append(len(gt_eval))
|
||||||
|
missing_prediction_counts.append(missing_count)
|
||||||
|
total_gt_counts.append(total_gt)
|
||||||
|
mean_accuracies.append(accuracy.mean())
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{clean_name}: "
|
||||||
|
f"GT={total_gt}, predicted={len(gt_eval)}, missing={missing_count}, "
|
||||||
|
f"mean accuracy={accuracy.mean():.1%}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not boxplot_data:
|
||||||
|
raise ValueError("No valid accuracy data available for plotting.")
|
||||||
|
|
||||||
|
|
||||||
|
# --- X-axis labels ---
|
||||||
|
xtick_labels = [
|
||||||
|
f"{label}\nGT={gt_n}, predicted={pred_n}, missing={miss_n}"
|
||||||
|
for label, gt_n, pred_n, miss_n
|
||||||
|
in zip(system_labels, total_gt_counts, predicted_counts, missing_prediction_counts)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# --- Plot ---
|
||||||
|
fig, ax = plt.subplots(figsize=(17, 8))
|
||||||
|
|
||||||
|
bp = ax.boxplot(
|
||||||
|
boxplot_data,
|
||||||
|
vert=True,
|
||||||
|
patch_artist=True,
|
||||||
|
labels=xtick_labels,
|
||||||
|
showmeans=True,
|
||||||
|
meanline=False,
|
||||||
|
widths=0.55
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Styling ---
|
||||||
|
box_face = '#D6EAF8'
|
||||||
|
box_edge = '#2980B9'
|
||||||
|
whisker_col = '#7F8C8D'
|
||||||
|
median_col = '#C0392B'
|
||||||
|
mean_col = '#1ABC9C'
|
||||||
|
flier_face = '#95A5A6'
|
||||||
|
flier_edge = '#7F8C8D'
|
||||||
|
|
||||||
|
for box in bp['boxes']:
|
||||||
|
box.set(
|
||||||
|
facecolor=box_face,
|
||||||
|
edgecolor=box_edge,
|
||||||
|
linewidth=1.5
|
||||||
|
)
|
||||||
|
|
||||||
|
for whisker in bp['whiskers']:
|
||||||
|
whisker.set(
|
||||||
|
color=whisker_col,
|
||||||
|
linewidth=1.2
|
||||||
|
)
|
||||||
|
|
||||||
|
for cap in bp['caps']:
|
||||||
|
cap.set(
|
||||||
|
color=whisker_col,
|
||||||
|
linewidth=1.2
|
||||||
|
)
|
||||||
|
|
||||||
|
for median in bp['medians']:
|
||||||
|
median.set(
|
||||||
|
color=median_col,
|
||||||
|
linewidth=2
|
||||||
|
)
|
||||||
|
|
||||||
|
for mean in bp['means']:
|
||||||
|
mean.set(
|
||||||
|
marker='o',
|
||||||
|
markerfacecolor=mean_col,
|
||||||
|
markeredgecolor='black',
|
||||||
|
markersize=6
|
||||||
|
)
|
||||||
|
|
||||||
|
for flier in bp['fliers']:
|
||||||
|
flier.set(
|
||||||
|
marker='o',
|
||||||
|
markerfacecolor=flier_face,
|
||||||
|
markeredgecolor=flier_edge,
|
||||||
|
alpha=0.6,
|
||||||
|
markersize=4
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Mean accuracy labels above each box ---
|
||||||
|
for i, acc in enumerate(mean_accuracies, start=1):
|
||||||
|
ax.text(
|
||||||
|
i,
|
||||||
|
1.03,
|
||||||
|
f"{acc:.1%}",
|
||||||
|
ha='center',
|
||||||
|
va='bottom',
|
||||||
|
fontsize=9,
|
||||||
|
fontweight='bold'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Perfect accuracy reference line ---
|
||||||
|
ax.axhline(
|
||||||
|
1,
|
||||||
|
color='black',
|
||||||
|
linewidth=1.2,
|
||||||
|
linestyle='--',
|
||||||
|
alpha=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Labels and formatting ---
|
||||||
|
ax.set_xlabel(
|
||||||
|
'Functional System / EDSS',
|
||||||
|
fontsize=11,
|
||||||
|
fontweight='bold'
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_ylabel(
|
||||||
|
'Continuous Accuracy',
|
||||||
|
fontsize=11,
|
||||||
|
fontweight='bold'
|
||||||
|
)
|
||||||
|
|
||||||
|
#ax.set_title(
|
||||||
|
# 'Continuous Accuracy of Functional Systems and EDSS',
|
||||||
|
# fontsize=14,
|
||||||
|
# fontweight='bold',
|
||||||
|
# pad=35
|
||||||
|
#)
|
||||||
|
|
||||||
|
ax.set_ylim(-0.05, 1.10)
|
||||||
|
|
||||||
|
yticks = np.arange(0, 1.01, 0.1)
|
||||||
|
ax.set_yticks(yticks)
|
||||||
|
ax.set_yticklabels([f"{int(y * 100)}%" for y in yticks])
|
||||||
|
|
||||||
|
plt.xticks(rotation=45, ha='right')
|
||||||
|
|
||||||
|
ax.yaxis.grid(True, linestyle='--', alpha=0.3)
|
||||||
|
ax.set_axisbelow(True)
|
||||||
|
|
||||||
|
for spine in ['top', 'right']:
|
||||||
|
ax.spines[spine].set_visible(False)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Legend ---
|
||||||
|
legend_handles = [
|
||||||
|
Patch(
|
||||||
|
facecolor=box_face,
|
||||||
|
edgecolor=box_edge,
|
||||||
|
label='IQR of continuous accuracy'
|
||||||
|
),
|
||||||
|
Line2D(
|
||||||
|
[0], [0],
|
||||||
|
color=median_col,
|
||||||
|
lw=2,
|
||||||
|
label='Median'
|
||||||
|
),
|
||||||
|
Line2D(
|
||||||
|
[0], [0],
|
||||||
|
marker='o',
|
||||||
|
color='w',
|
||||||
|
markerfacecolor=mean_col,
|
||||||
|
markeredgecolor='black',
|
||||||
|
markersize=7,
|
||||||
|
label='Mean'
|
||||||
|
),
|
||||||
|
Line2D(
|
||||||
|
[0], [0],
|
||||||
|
marker='o',
|
||||||
|
color='w',
|
||||||
|
markerfacecolor=flier_face,
|
||||||
|
markeredgecolor=flier_edge,
|
||||||
|
alpha=0.8,
|
||||||
|
markersize=6,
|
||||||
|
label='Outlier'
|
||||||
|
),
|
||||||
|
Line2D(
|
||||||
|
[0], [0],
|
||||||
|
color='black',
|
||||||
|
lw=1.2,
|
||||||
|
linestyle='--',
|
||||||
|
label='Perfect accuracy'
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
handles=legend_handles,
|
||||||
|
loc='lower center',
|
||||||
|
bbox_to_anchor=(0.5, 1.08),
|
||||||
|
ncol=5,
|
||||||
|
frameon=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Save and show ---
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.86])
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Functional Systems + EDSS Error Category Stacked Bar Plot
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib.patches import Patch
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
figure_save_path = 'project/visuals/functional_systems_edss_error_categories.svg'
|
||||||
|
|
||||||
|
# --- Functional systems + EDSS using your actual column names ---
|
||||||
|
systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS'),
|
||||||
|
('GT.EDSS', 'result.EDSS')
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- Robust parser ---
|
||||||
|
def safe_parse(s):
|
||||||
|
"""Convert to float, handling comma decimals like '3,5'."""
|
||||||
|
if pd.isna(s):
|
||||||
|
return np.nan
|
||||||
|
if isinstance(s, (int, float, np.integer, np.floating)):
|
||||||
|
return float(s)
|
||||||
|
|
||||||
|
s_clean = str(s).replace(',', '.').strip()
|
||||||
|
|
||||||
|
if s_clean == "":
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
try:
|
||||||
|
return float(s_clean)
|
||||||
|
except ValueError:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
|
|
||||||
|
# --- Categorize absolute error ---
|
||||||
|
def categorize_error(abs_error):
|
||||||
|
if abs_error == 0:
|
||||||
|
return "Exact"
|
||||||
|
elif abs_error <= 0.5:
|
||||||
|
return "≤0.5 error"
|
||||||
|
elif abs_error <= 1:
|
||||||
|
return "≤1 error"
|
||||||
|
else:
|
||||||
|
return ">1 error"
|
||||||
|
|
||||||
|
|
||||||
|
# --- Prepare data ---
|
||||||
|
rows = []
|
||||||
|
|
||||||
|
for gt_col, res_col in systems_to_plot:
|
||||||
|
|
||||||
|
if gt_col not in df.columns:
|
||||||
|
print(f"Skipping {gt_col}: GT column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if res_col not in df.columns:
|
||||||
|
print(f"Skipping {res_col}: result column not found")
|
||||||
|
continue
|
||||||
|
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
if system_name == "EDSS":
|
||||||
|
clean_name = "EDSS"
|
||||||
|
else:
|
||||||
|
clean_name = system_name.replace("_", " ").title()
|
||||||
|
|
||||||
|
gt = df[gt_col].apply(safe_parse)
|
||||||
|
res = df[res_col].apply(safe_parse)
|
||||||
|
|
||||||
|
# Evaluate only cases where ground truth exists
|
||||||
|
gt_exists = gt.notna()
|
||||||
|
gt_valid = gt[gt_exists]
|
||||||
|
res_valid = res[gt_exists]
|
||||||
|
|
||||||
|
if len(gt_valid) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for gt_value, res_value in zip(gt_valid, res_valid):
|
||||||
|
if pd.isna(res_value):
|
||||||
|
category = "Missing"
|
||||||
|
else:
|
||||||
|
abs_error = abs(res_value - gt_value)
|
||||||
|
category = categorize_error(abs_error)
|
||||||
|
|
||||||
|
rows.append({
|
||||||
|
"system": clean_name,
|
||||||
|
"category": category
|
||||||
|
})
|
||||||
|
|
||||||
|
plot_df = pd.DataFrame(rows)
|
||||||
|
|
||||||
|
if plot_df.empty:
|
||||||
|
raise ValueError("No valid data available for plotting.")
|
||||||
|
|
||||||
|
category_order = [
|
||||||
|
"Exact",
|
||||||
|
"≤0.5 error",
|
||||||
|
"≤1 error",
|
||||||
|
">1 error",
|
||||||
|
"Missing"
|
||||||
|
]
|
||||||
|
|
||||||
|
system_order = [
|
||||||
|
"Visual Optic Functions",
|
||||||
|
"Cerebellar Functions",
|
||||||
|
"Brainstem Functions",
|
||||||
|
"Sensory Functions",
|
||||||
|
"Pyramidal Functions",
|
||||||
|
"Ambulation",
|
||||||
|
"Cerebral Functions",
|
||||||
|
"Bowel And Bladder Functions",
|
||||||
|
"EDSS"
|
||||||
|
]
|
||||||
|
|
||||||
|
counts = (
|
||||||
|
plot_df
|
||||||
|
.groupby(["system", "category"])
|
||||||
|
.size()
|
||||||
|
.unstack(fill_value=0)
|
||||||
|
.reindex(index=system_order)
|
||||||
|
.reindex(columns=category_order, fill_value=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove systems that were not available
|
||||||
|
counts = counts.dropna(how="all")
|
||||||
|
|
||||||
|
# Convert to percentages for easier comparison
|
||||||
|
percentages = counts.div(counts.sum(axis=1), axis=0) * 100
|
||||||
|
|
||||||
|
|
||||||
|
# --- Plot ---
|
||||||
|
fig, ax = plt.subplots(figsize=(13, 7))
|
||||||
|
|
||||||
|
colors = {
|
||||||
|
"Exact": "#2ECC71",
|
||||||
|
"≤0.5 error": "#A9DFBF",
|
||||||
|
"≤1 error": "#F9E79F",
|
||||||
|
">1 error": "#E67E22",
|
||||||
|
"Missing": "#E74C3C"
|
||||||
|
}
|
||||||
|
|
||||||
|
left = np.zeros(len(percentages))
|
||||||
|
|
||||||
|
for category in category_order:
|
||||||
|
values = percentages[category].values
|
||||||
|
|
||||||
|
ax.barh(
|
||||||
|
percentages.index,
|
||||||
|
values,
|
||||||
|
left=left,
|
||||||
|
color=colors[category],
|
||||||
|
edgecolor="white",
|
||||||
|
linewidth=0.8,
|
||||||
|
label=category
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add labels only if segment is large enough
|
||||||
|
for i, value in enumerate(values):
|
||||||
|
if value >= 4:
|
||||||
|
ax.text(
|
||||||
|
left[i] + value / 2,
|
||||||
|
i,
|
||||||
|
f"{value:.1f}%",
|
||||||
|
ha="center",
|
||||||
|
va="center",
|
||||||
|
fontsize=8,
|
||||||
|
fontweight="bold"
|
||||||
|
)
|
||||||
|
|
||||||
|
left += values
|
||||||
|
|
||||||
|
|
||||||
|
# Add total n and missing count at the right side
|
||||||
|
for i, system in enumerate(percentages.index):
|
||||||
|
total_n = int(counts.loc[system].sum())
|
||||||
|
missing_n = int(counts.loc[system, "Missing"])
|
||||||
|
|
||||||
|
ax.text(
|
||||||
|
101,
|
||||||
|
i,
|
||||||
|
f"n={total_n}, missing={missing_n}",
|
||||||
|
va="center",
|
||||||
|
ha="left",
|
||||||
|
fontsize=9
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Formatting ---
|
||||||
|
ax.set_xlim(0, 115)
|
||||||
|
ax.set_xlabel("Percentage of Cases", fontsize=11, fontweight="bold")
|
||||||
|
ax.set_ylabel("Functional System / EDSS", fontsize=11, fontweight="bold")
|
||||||
|
|
||||||
|
#ax.set_title(
|
||||||
|
# "Prediction Error Categories by Functional System and EDSS",
|
||||||
|
# fontsize=14,
|
||||||
|
# fontweight="bold",
|
||||||
|
# pad=20
|
||||||
|
#)
|
||||||
|
|
||||||
|
ax.set_xticks(np.arange(0, 101, 10))
|
||||||
|
ax.set_xticklabels([f"{x}%" for x in np.arange(0, 101, 10)])
|
||||||
|
|
||||||
|
ax.xaxis.grid(True, linestyle="--", alpha=0.3)
|
||||||
|
ax.set_axisbelow(True)
|
||||||
|
|
||||||
|
for spine in ["top", "right", "left"]:
|
||||||
|
ax.spines[spine].set_visible(False)
|
||||||
|
|
||||||
|
ax.legend(
|
||||||
|
loc="lower center",
|
||||||
|
bbox_to_anchor=(0.5, 1.02),
|
||||||
|
ncol=5,
|
||||||
|
frameon=False
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.92])
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
plt.savefig(figure_save_path, format="svg", bbox_inches="tight")
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user