update gitignore
This commit is contained in:
@@ -662,7 +662,7 @@ print("\nFirst few rows:")
|
||||
print(df.head())
|
||||
|
||||
# Hardcode specific patient names
|
||||
patient_names = ['113c1470']
|
||||
patient_names = ['6b56865d']
|
||||
|
||||
# Define the functional systems (columns to plot) - adjust based on actual column names
|
||||
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
||||
@@ -1183,7 +1183,6 @@ import matplotlib.colors as mcolors
|
||||
# --- Configuration ---
|
||||
plt.rcParams['font.family'] = 'Arial'
|
||||
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
||||
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
|
||||
|
||||
# --- 1. Load the Dataset ---
|
||||
@@ -1212,19 +1211,19 @@ n_categories = len(categories)
|
||||
|
||||
def categorize_edss(value):
|
||||
if pd.isna(value): return np.nan
|
||||
idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0
|
||||
# Ensure value is float to avoid TypeError
|
||||
val = float(value)
|
||||
idx = int(min(max(val, 0), 10) - 0.001) if val > 0 else 0
|
||||
return categories[min(idx, len(categories)-1)]
|
||||
|
||||
# --- 4. Prepare Mixed Color Matrix ---
|
||||
# --- 4. Prepare Mixed Color Matrix with Saturation ---
|
||||
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
|
||||
|
||||
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
||||
# CRITICAL FIX: Convert to numeric and drop NaNs in one go
|
||||
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
|
||||
# Fix: Ensure numeric conversion to avoid string comparison errors
|
||||
temp_df = df[[gt_col, res_col]].copy()
|
||||
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
|
||||
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
|
||||
|
||||
valid_df = temp_df.dropna()
|
||||
|
||||
for _, row in valid_df.iterrows():
|
||||
@@ -1233,10 +1232,9 @@ for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
||||
if gt_cat in category_to_index and res_cat in category_to_index:
|
||||
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
|
||||
|
||||
# Create an RGB image matrix (initially white/empty)
|
||||
rgb_matrix = np.ones((n_categories, n_categories, 3))
|
||||
# Create an RGBA image matrix (10x10x4)
|
||||
rgba_matrix = np.zeros((n_categories, n_categories, 4))
|
||||
|
||||
# Create an Alpha matrix for the "Total Count" intensity
|
||||
total_counts = np.sum(cell_system_counts, axis=2)
|
||||
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
|
||||
|
||||
@@ -1244,52 +1242,69 @@ for i in range(n_categories):
|
||||
for j in range(n_categories):
|
||||
count_sum = total_counts[i, j]
|
||||
if count_sum > 0:
|
||||
# Calculate weighted average color
|
||||
mixed_rgb = np.zeros(3)
|
||||
for s_idx, s_name in enumerate(system_names):
|
||||
weight = cell_system_counts[i, j, s_idx] / count_sum
|
||||
system_rgb = mcolors.to_rgb(color_map[s_name])
|
||||
mixed_rgb += np.array(system_rgb) * weight
|
||||
rgb_matrix[i, j] = mixed_rgb
|
||||
|
||||
# Set RGB channels
|
||||
rgba_matrix[i, j, :3] = mixed_rgb
|
||||
|
||||
# Set Alpha channel (Saturation Effect)
|
||||
# Using a square root scale to make lower counts more visible but still "lighter"
|
||||
alpha = np.sqrt(count_sum / max_count)
|
||||
# Ensure alpha is at least 0.1 so it's not invisible
|
||||
rgba_matrix[i, j, 3] = max(0.1, alpha)
|
||||
else:
|
||||
# Empty cells are white
|
||||
rgba_matrix[i, j] = [1, 1, 1, 0]
|
||||
|
||||
# --- 5. Plotting ---
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
|
||||
# Display the mixed color matrix
|
||||
# We use alpha based on count to show density (optional, but recommended)
|
||||
im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper')
|
||||
# Show the matrix
|
||||
# Note: we use origin='lower' if you want 0-1 at the bottom,
|
||||
# but confusion matrices usually have 0-1 at the top (origin='upper')
|
||||
im = ax.imshow(rgba_matrix, interpolation='nearest', origin='upper')
|
||||
|
||||
# Add text labels for total counts in each cell
|
||||
# Add count labels
|
||||
for i in range(n_categories):
|
||||
for j in range(n_categories):
|
||||
if total_counts[i, j] > 0:
|
||||
# Determine text color based on brightness of background
|
||||
lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2]
|
||||
text_col = "white" if lum < 0.5 else "black"
|
||||
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9)
|
||||
# Background brightness for text contrast
|
||||
bg_color = rgba_matrix[i, j, :3]
|
||||
lum = 0.2126 * bg_color[0] + 0.7152 * bg_color[1] + 0.0722 * bg_color[2]
|
||||
# If alpha is low, background is effectively white, so use black text
|
||||
text_col = "white" if (lum < 0.5 and rgba_matrix[i,j,3] > 0.5) else "black"
|
||||
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center",
|
||||
color=text_col, fontsize=10, fontweight='bold')
|
||||
|
||||
# --- 6. Styling ---
|
||||
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12)
|
||||
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12)
|
||||
ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20)
|
||||
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12, labelpad=10)
|
||||
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12, labelpad=10)
|
||||
ax.set_title('Saturated Confusion Matrix\nColor = System Mixture | Opacity = Density', fontsize=14, pad=20)
|
||||
|
||||
ax.set_xticks(np.arange(n_categories))
|
||||
ax.set_xticklabels(categories)
|
||||
ax.set_yticks(np.arange(n_categories))
|
||||
ax.set_yticklabels(categories)
|
||||
|
||||
# Remove the frame/spines for a cleaner look
|
||||
for spine in ax.spines.values():
|
||||
spine.set_visible(False)
|
||||
|
||||
# Custom Legend
|
||||
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
|
||||
labels = [name.replace('_', ' ').capitalize() for name in system_names]
|
||||
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)
|
||||
ax.legend(handles, labels, title='Functional Systems', loc='upper left',
|
||||
bbox_to_anchor=(1.05, 1), frameon=False)
|
||||
|
||||
plt.tight_layout()
|
||||
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
#
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user