update gitignore

This commit is contained in:
2026-02-04 15:29:56 +01:00
parent b2e9ccd2b6
commit c2ccb8cd11
2 changed files with 45 additions and 26 deletions

View File

@@ -662,7 +662,7 @@ print("\nFirst few rows:")
print(df.head())
# Hardcode specific patient names
patient_names = ['113c1470']
patient_names = ['6b56865d']
# Define the functional systems (columns to plot) - adjust based on actual column names
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
@@ -1183,7 +1183,6 @@ import matplotlib.colors as mcolors
# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
# --- 1. Load the Dataset ---
@@ -1212,19 +1211,19 @@ n_categories = len(categories)
def categorize_edss(value):
    """Map an EDSS score to its label in the module-level ``categories`` list.

    Args:
        value: The score to categorize. May be a number, a numeric string,
            or a missing value.

    Returns:
        ``np.nan`` when ``value`` is missing (as reported by ``pd.isna``) or
        converts to a float NaN; otherwise the entry of ``categories`` for the
        bin the score falls into.

    Raises:
        ValueError: If ``value`` is a string that ``float()`` cannot parse.
        TypeError: If ``value`` is of a type ``float()`` does not accept.
    """
    if pd.isna(value):
        return np.nan

    # Convert before any comparison: max()/min() between a str and an int
    # raises TypeError, so numeric strings such as "3.5" must become floats
    # first.
    val = float(value)

    # Strings like "nan" pass the pd.isna() check above but convert to NaN;
    # treat them as missing instead of letting them fall into the first bin.
    if np.isnan(val):
        return np.nan

    # Clamp to [0, 10]. Subtracting 0.001 before truncating puts a score that
    # sits exactly on an integer boundary into the lower bin
    # (1.0 -> index 0, 1.5 -> index 1, 10 -> index 9).
    idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0

    # Cap the index so it can never run past the last category.
    return categories[min(idx, len(categories) - 1)]
# --- 4. Prepare Mixed Color Matrix ---
# --- 4. Prepare Mixed Color Matrix with Saturation ---
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
# CRITICAL FIX: Convert to numeric and drop NaNs in one go
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
# Fix: Ensure numeric conversion to avoid string comparison errors
temp_df = df[[gt_col, res_col]].copy()
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
valid_df = temp_df.dropna()
for _, row in valid_df.iterrows():
@@ -1233,10 +1232,9 @@ for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
if gt_cat in category_to_index and res_cat in category_to_index:
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
# Create an RGB image matrix (initially white/empty)
rgb_matrix = np.ones((n_categories, n_categories, 3))
# Create an RGBA image matrix (10x10x4)
rgba_matrix = np.zeros((n_categories, n_categories, 4))
# Create an Alpha matrix for the "Total Count" intensity
total_counts = np.sum(cell_system_counts, axis=2)
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
@@ -1244,52 +1242,69 @@ for i in range(n_categories):
for j in range(n_categories):
count_sum = total_counts[i, j]
if count_sum > 0:
# Calculate weighted average color
mixed_rgb = np.zeros(3)
for s_idx, s_name in enumerate(system_names):
weight = cell_system_counts[i, j, s_idx] / count_sum
system_rgb = mcolors.to_rgb(color_map[s_name])
mixed_rgb += np.array(system_rgb) * weight
rgb_matrix[i, j] = mixed_rgb
# Set RGB channels
rgba_matrix[i, j, :3] = mixed_rgb
# Set Alpha channel (Saturation Effect)
# Using a square root scale to make lower counts more visible but still "lighter"
alpha = np.sqrt(count_sum / max_count)
# Ensure alpha is at least 0.1 so it's not invisible
rgba_matrix[i, j, 3] = max(0.1, alpha)
else:
# Empty cells are white
rgba_matrix[i, j] = [1, 1, 1, 0]
# --- 5. Plotting ---
# Render the RGBA confusion matrix, annotate each populated cell with its
# total count, style the axes, and write the figure to `figure_save_path`.
fig, ax = plt.subplots(figsize=(12, 10))

# Draw the matrix exactly once. The alpha channel of `rgba_matrix` carries the
# per-cell density, so no opaque image may be drawn underneath it.
# origin='upper' puts the first category in the top row.
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')

# Add count labels (one per non-empty cell)
for i in range(n_categories):
    for j in range(n_categories):
        if total_counts[i, j] > 0:
            # Relative luminance of the cell colour, used to pick a readable
            # text colour.
            bg_color = rgba_matrix[i, j, :3]
            lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2]
            # A low-alpha cell is mostly see-through, so only switch to white
            # text when the cell is both dark and reasonably opaque.
            # NOTE(review): assumes a light axes background - confirm.
            text_col = "white" if (lum < 0.5 and rgba_matrix[i, j, 3] > 0.5) else "black"
            ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
                    color=text_col, fontsize=10, fontweight='bold')

# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)

# One tick per category on both axes, labelled with the category names.
ax.set_xticks(np.arange(n_categories))
ax.set_xticklabels(categories)
ax.set_yticks(np.arange(n_categories))
ax.set_yticklabels(categories)

# Remove the frame/spines for a cleaner look
for spine in ax.spines.values():
    spine.set_visible(False)

# Custom legend: one colour swatch per functional system, placed outside the
# axes on the right.
handles = [plt.Rectangle((0, 0), 1, 1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
          bbox_to_anchor=(1.05, 1), frameon=False)

plt.tight_layout()

# os.makedirs('') raises FileNotFoundError, so only create the target
# directory when the save path actually has a directory component.
save_dir = os.path.dirname(figure_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
#