Compare commits
4 Commits
Dashboard_
...
clean
| Author | SHA1 | Date | |
|---|---|---|---|
| a29d9fcba5 | |||
| c986ab92c5 | |||
| b2e9ccd2b6 | |||
| 2f1bd2bfd0 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,7 +6,7 @@
|
|||||||
.env
|
.env
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
|
/reference/
|
||||||
# 2. Ignore virtual environments COMPLETELY
|
# 2. Ignore virtual environments COMPLETELY
|
||||||
# This must come BEFORE the unignore rule
|
# This must come BEFORE the unignore rule
|
||||||
env*/
|
env*/
|
||||||
|
|||||||
@@ -662,7 +662,7 @@ print("\nFirst few rows:")
|
|||||||
print(df.head())
|
print(df.head())
|
||||||
|
|
||||||
# Hardcode specific patient names
|
# Hardcode specific patient names
|
||||||
patient_names = ['bc55b1b2']
|
patient_names = ['113c1470']
|
||||||
|
|
||||||
# Define the functional systems (columns to plot) - adjust based on actual column names
|
# Define the functional systems (columns to plot) - adjust based on actual column names
|
||||||
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
||||||
@@ -746,3 +746,551 @@ plt.tight_layout()
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
##
|
##
|
||||||
|
|
||||||
|
|
||||||
|
# %% Table
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Load the data
|
||||||
|
file_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv'
|
||||||
|
df = pd.read_csv(file_path, sep='\t')
|
||||||
|
|
||||||
|
# Convert MedDatum to datetime
|
||||||
|
df['MedDatum'] = pd.to_datetime(df['MedDatum'])
|
||||||
|
|
||||||
|
# Check what columns actually exist in the dataset
|
||||||
|
print("Available columns:")
|
||||||
|
print(df.columns.tolist())
|
||||||
|
print("\nFirst few rows:")
|
||||||
|
print(df.head())
|
||||||
|
|
||||||
|
# Check data types
|
||||||
|
print("\nData types:")
|
||||||
|
print(df.dtypes)
|
||||||
|
|
||||||
|
# Hardcode specific patient names
|
||||||
|
patient_names = ['6ccda8c6']
|
||||||
|
|
||||||
|
# Define the functional systems (columns to plot)
|
||||||
|
functional_systems = ['EDSS', 'Visual', 'Sensory', 'Motor', 'Brainstem', 'Cerebellar', 'Autonomic', 'Bladder', 'Intellectual']
|
||||||
|
|
||||||
|
# Create subplots
|
||||||
|
num_plots = len(functional_systems)
|
||||||
|
num_cols = 2
|
||||||
|
num_rows = (num_plots + num_cols - 1) // num_cols
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 4*num_rows), sharex=False)
|
||||||
|
if num_plots == 1:
|
||||||
|
axes = [axes]
|
||||||
|
elif num_rows == 1:
|
||||||
|
axes = axes
|
||||||
|
else:
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
# Plot for the hardcoded patient
|
||||||
|
for i, system in enumerate(functional_systems):
|
||||||
|
# Filter data for this specific patient
|
||||||
|
patient_data = df[df['unique_id'] == patient_names[0]].sort_values('MedDatum')
|
||||||
|
|
||||||
|
# Check if patient data exists
|
||||||
|
if patient_data.empty:
|
||||||
|
print(f"No data found for patient: {patient_names[0]}")
|
||||||
|
axes[i].set_title(f'Functional System: {system} (No data)')
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if the system column exists
|
||||||
|
if system in patient_data.columns:
|
||||||
|
# Plot only valid data (non-null values)
|
||||||
|
valid_data = patient_data.dropna(subset=[system])
|
||||||
|
|
||||||
|
if not valid_data.empty:
|
||||||
|
# Ensure MedDatum is properly formatted for plotting
|
||||||
|
axes[i].plot(valid_data['MedDatum'], valid_data[system], marker='o', linewidth=2, label=system)
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
axes[i].set_title(f'Functional System: {system}')
|
||||||
|
axes[i].grid(True, alpha=0.3)
|
||||||
|
axes[i].legend()
|
||||||
|
else:
|
||||||
|
axes[i].set_title(f'Functional System: {system} (No valid data)')
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
else:
|
||||||
|
# Try to find similar column names
|
||||||
|
found_column = None
|
||||||
|
for col in df.columns:
|
||||||
|
if system.lower() in col.lower():
|
||||||
|
found_column = col
|
||||||
|
break
|
||||||
|
|
||||||
|
if found_column:
|
||||||
|
valid_data = patient_data.dropna(subset=[found_column])
|
||||||
|
if not valid_data.empty:
|
||||||
|
axes[i].plot(valid_data['MedDatum'], valid_data[found_column], marker='o', linewidth=2, label=found_column)
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
axes[i].set_title(f'Functional System: {system} (found as: {found_column})')
|
||||||
|
axes[i].grid(True, alpha=0.3)
|
||||||
|
axes[i].legend()
|
||||||
|
else:
|
||||||
|
axes[i].set_title(f'Functional System: {system} (No valid data)')
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
else:
|
||||||
|
axes[i].set_title(f'Functional System: {system} (Column not found)')
|
||||||
|
axes[i].set_ylabel('Score')
|
||||||
|
|
||||||
|
# Hide empty subplots
|
||||||
|
for i in range(len(functional_systems), len(axes)):
|
||||||
|
axes[i].set_visible(False)
|
||||||
|
|
||||||
|
# Set x-axis label for the last row only
|
||||||
|
for i in range(len(functional_systems)):
|
||||||
|
if i >= len(axes) - num_cols: # Last row
|
||||||
|
axes[i].set_xlabel('Date')
|
||||||
|
|
||||||
|
# Format x-axis dates
|
||||||
|
for ax in axes:
|
||||||
|
if ax.get_lines(): # Only format if there are lines to plot
|
||||||
|
ax.tick_params(axis='x', rotation=45)
|
||||||
|
ax.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))
|
||||||
|
|
||||||
|
# Automatically adjust layout
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Histogram Fig1
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.font_manager as fm
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
def create_visit_frequency_plot(
|
||||||
|
file_path,
|
||||||
|
output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data',
|
||||||
|
output_filename='visit_frequency_distribution.svg',
|
||||||
|
fontsize=10,
|
||||||
|
color_scheme_path='colors.json'
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a publication-ready bar chart of patient visit frequency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): Path to the input TSV file.
|
||||||
|
output_dir (str): Directory to save the output SVG file.
|
||||||
|
output_filename (str): Name of the output SVG file.
|
||||||
|
fontsize (int): Font size for all text elements (labels, title).
|
||||||
|
color_scheme_path (str): Path to the JSON file containing the color palette.
|
||||||
|
"""
|
||||||
|
# --- 1. Load Data and Color Scheme ---
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(file_path, sep='\t')
|
||||||
|
print("Data loaded successfully.")
|
||||||
|
# Sort data for easier visual comparison
|
||||||
|
df = df.sort_values(by='Visits Count')
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: The file was not found at {file_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(color_scheme_path, 'r') as f:
|
||||||
|
colors = json.load(f)
|
||||||
|
# Select a blue from the sequential palette for the bars
|
||||||
|
bar_color = colors['sequential']['blues'][-2] # A saturated blue
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.")
|
||||||
|
bar_color = '#2171b5' # A common matplotlib blue
|
||||||
|
|
||||||
|
# --- 2. Set up the Plot with Scientific Style ---
|
||||||
|
plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height
|
||||||
|
|
||||||
|
# Set the font to Arial
|
||||||
|
arial_font = fm.FontProperties(family='Arial', size=fontsize)
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
plt.rcParams['font.size'] = fontsize
|
||||||
|
|
||||||
|
# --- 3. Create the Bar Chart ---
|
||||||
|
ax = plt.gca()
|
||||||
|
bars = plt.bar(
|
||||||
|
x=df['Visits Count'],
|
||||||
|
height=df['Unique Patients'],
|
||||||
|
color=bar_color,
|
||||||
|
edgecolor='black',
|
||||||
|
linewidth=0.5, # Minimum line thickness
|
||||||
|
width=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- NEW: Explicitly set x-ticks and labels to ensure all are shown ---
|
||||||
|
# Get the unique visit counts to use as tick labels
|
||||||
|
visit_counts = df['Visits Count'].unique()
|
||||||
|
# Set the x-ticks to be at the center of each bar
|
||||||
|
ax.set_xticks(visit_counts)
|
||||||
|
# Set the x-tick labels to be the visit counts, using the specified font
|
||||||
|
ax.set_xticklabels(visit_counts, fontproperties=arial_font)
|
||||||
|
# --- END OF NEW SECTION ---
|
||||||
|
|
||||||
|
# --- 4. Customize Axes and Layout (Nature style) ---
|
||||||
|
# Display only left and bottom axes
|
||||||
|
ax.spines['top'].set_visible(False)
|
||||||
|
ax.spines['right'].set_visible(False)
|
||||||
|
|
||||||
|
# Turn off axis ticks (the marks, not the labels)
|
||||||
|
plt.tick_params(axis='both', which='both', length=0)
|
||||||
|
|
||||||
|
# Remove grid lines
|
||||||
|
plt.grid(False)
|
||||||
|
|
||||||
|
# Set background to white (no shading)
|
||||||
|
ax.set_facecolor('white')
|
||||||
|
plt.gcf().set_facecolor('white')
|
||||||
|
|
||||||
|
# --- 5. Add Labels and Title ---
|
||||||
|
plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10)
|
||||||
|
plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10)
|
||||||
|
plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20)
|
||||||
|
|
||||||
|
# --- 6. Add y-axis values on top of each bar ---
|
||||||
|
# This adds the count of unique patients directly above each bar.
|
||||||
|
ax.bar_label(bars, fmt='%d', padding=3)
|
||||||
|
|
||||||
|
# --- 7. Export the Figure ---
|
||||||
|
# Ensure the output directory exists
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
full_output_path = os.path.join(output_dir, output_filename)
|
||||||
|
plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight')
|
||||||
|
print(f"\nFigure saved as '{full_output_path}'")
|
||||||
|
|
||||||
|
# --- 8. (Optional) Display the Plot ---
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
# --- Main execution ---
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Define the file path
|
||||||
|
input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'
|
||||||
|
|
||||||
|
# Call the function to create and save the plot
|
||||||
|
create_visit_frequency_plot(
|
||||||
|
file_path=input_file,
|
||||||
|
fontsize=10 # Using a 10 pt font size as per guidelines
|
||||||
|
)
|
||||||
|
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Scatter Plot functional system
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
# Set the font to Arial for all text in the plot, as per the guidelines
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
|
||||||
|
# Define the path to your data file
|
||||||
|
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||||
|
|
||||||
|
# Define the path to save the color mapping JSON file
|
||||||
|
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
||||||
|
|
||||||
|
# Define the path to save the final figure
|
||||||
|
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'
|
||||||
|
|
||||||
|
# --- 1. Load the Dataset ---
|
||||||
|
try:
|
||||||
|
# Load the TSV file
|
||||||
|
df = pd.read_csv(data_path, sep='\t')
|
||||||
|
print(f"Successfully loaded data from {data_path}")
|
||||||
|
print(f"Data shape: {df.shape}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: The file at {data_path} was not found.")
|
||||||
|
# Exit or handle the error appropriately
|
||||||
|
raise
|
||||||
|
|
||||||
|
# --- 2. Define Functional Systems and Create Color Mapping ---
|
||||||
|
# List of tuples containing (ground_truth_column, result_column)
|
||||||
|
functional_systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
|
||||||
|
]
|
||||||
|
|
||||||
|
# Extract system names for color mapping and legend
|
||||||
|
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||||
|
|
||||||
|
# Define a professional color palette (dark blue theme)
|
||||||
|
# This is a qualitative palette with distinct, accessible colors
|
||||||
|
colors = [
|
||||||
|
'#003366', # Dark Blue
|
||||||
|
'#336699', # Medium Blue
|
||||||
|
'#6699CC', # Light Blue
|
||||||
|
'#99CCFF', # Very Light Blue
|
||||||
|
'#FF9966', # Coral
|
||||||
|
'#FF6666', # Light Red
|
||||||
|
'#CC6699', # Magenta
|
||||||
|
'#9966CC' # Purple
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create a dictionary mapping system names to colors
|
||||||
|
color_map = dict(zip(system_names, colors))
|
||||||
|
|
||||||
|
# Ensure the directory for the JSON file exists
|
||||||
|
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Save the color map to a JSON file
|
||||||
|
with open(color_json_path, 'w') as f:
|
||||||
|
json.dump(color_map, f, indent=4)
|
||||||
|
|
||||||
|
print(f"Color mapping saved to {color_json_path}")
|
||||||
|
|
||||||
|
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
|
||||||
|
agreement_percentages = {}
|
||||||
|
legend_labels = {}
|
||||||
|
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Convert columns to numeric, setting errors to NaN
|
||||||
|
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
|
||||||
|
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
|
||||||
|
|
||||||
|
# Ensure we are comparing the same rows
|
||||||
|
common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
|
||||||
|
gt_data = gt_numeric.loc[common_index]
|
||||||
|
res_data = res_numeric.loc[common_index]
|
||||||
|
|
||||||
|
# Calculate agreement percentage
|
||||||
|
if len(gt_data) > 0:
|
||||||
|
agreement = (gt_data == res_data).mean() * 100
|
||||||
|
else:
|
||||||
|
agreement = 0 # Handle case with no valid data
|
||||||
|
|
||||||
|
agreement_percentages[system_name] = agreement
|
||||||
|
|
||||||
|
# Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
|
||||||
|
formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
|
||||||
|
legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
|
||||||
|
|
||||||
|
# --- 4. Reshape Data for Plotting ---
|
||||||
|
plot_data = []
|
||||||
|
for gt_col, res_col in functional_systems_to_plot:
|
||||||
|
system_name = gt_col.split('.')[1]
|
||||||
|
|
||||||
|
# Convert columns to numeric, setting errors to NaN
|
||||||
|
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
|
||||||
|
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
|
||||||
|
|
||||||
|
# Create a temporary DataFrame with the numeric data
|
||||||
|
temp_df = pd.DataFrame({
|
||||||
|
'system': system_name,
|
||||||
|
'ground_truth': gt_numeric,
|
||||||
|
'inference': res_numeric
|
||||||
|
})
|
||||||
|
|
||||||
|
# Drop rows where either value is NaN, as they cannot be plotted
|
||||||
|
temp_df = temp_df.dropna()
|
||||||
|
|
||||||
|
plot_data.append(temp_df)
|
||||||
|
|
||||||
|
# Concatenate all the temporary DataFrames into one
|
||||||
|
plot_df = pd.concat(plot_data, ignore_index=True)
|
||||||
|
|
||||||
|
if plot_df.empty:
|
||||||
|
print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
|
||||||
|
else:
|
||||||
|
print(f"Prepared plot data with {len(plot_df)} data points.")
|
||||||
|
|
||||||
|
# --- 5. Create the Scatter Plot ---
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
|
||||||
|
# Plot each functional system with its assigned color and formatted legend label
|
||||||
|
for system, group in plot_df.groupby('system'):
|
||||||
|
plt.scatter(
|
||||||
|
group['ground_truth'],
|
||||||
|
group['inference'],
|
||||||
|
label=legend_labels[system],
|
||||||
|
color=color_map[system],
|
||||||
|
alpha=0.7,
|
||||||
|
s=30
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a diagonal line representing perfect agreement (y = x)
|
||||||
|
# This line helps visualize how close the predictions are to the ground truth
|
||||||
|
if not plot_df.empty:
|
||||||
|
plt.plot(
|
||||||
|
[plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
|
||||||
|
[plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
|
||||||
|
color='black',
|
||||||
|
linestyle='--',
|
||||||
|
linewidth=0.8,
|
||||||
|
alpha=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- 6. Apply Styling and Labels ---
|
||||||
|
plt.xlabel('Ground Truth', fontsize=12)
|
||||||
|
plt.ylabel('LLM Inference', fontsize=12)
|
||||||
|
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)
|
||||||
|
|
||||||
|
# Apply scientific visualization styling rules
|
||||||
|
ax = plt.gca()
|
||||||
|
ax.spines['top'].set_visible(False)
|
||||||
|
ax.spines['right'].set_visible(False)
|
||||||
|
ax.tick_params(axis='both', which='both', length=0) # Remove ticks
|
||||||
|
ax.grid(False) # Remove grid lines
|
||||||
|
plt.legend(title='Functional System', frameon=False, fontsize=10)
|
||||||
|
|
||||||
|
# --- 7. Save and Display the Figure ---
|
||||||
|
# Ensure the directory for the figure exists
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
|
||||||
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
|
print(f"Figure successfully saved to {figure_save_path}")
|
||||||
|
|
||||||
|
# Display the plot
|
||||||
|
plt.show()
|
||||||
|
##
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# %% Confusion Matrix functional systems
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.colors as mcolors
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
plt.rcParams['font.family'] = 'Arial'
|
||||||
|
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
|
||||||
|
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
|
||||||
|
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
|
||||||
|
|
||||||
|
# --- 1. Load the Dataset ---
|
||||||
|
df = pd.read_csv(data_path, sep='\t')
|
||||||
|
|
||||||
|
# --- 2. Define Functional Systems and Colors ---
|
||||||
|
functional_systems_to_plot = [
|
||||||
|
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
|
||||||
|
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
|
||||||
|
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
|
||||||
|
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
|
||||||
|
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
|
||||||
|
('GT.AMBULATION', 'result.AMBULATION'),
|
||||||
|
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
|
||||||
|
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
|
||||||
|
]
|
||||||
|
|
||||||
|
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
|
||||||
|
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
|
||||||
|
color_map = dict(zip(system_names, colors))
|
||||||
|
|
||||||
|
# --- 3. Categorization Function ---
|
||||||
|
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
|
||||||
|
category_to_index = {cat: i for i, cat in enumerate(categories)}
|
||||||
|
n_categories = len(categories)
|
||||||
|
|
||||||
|
def categorize_edss(value):
|
||||||
|
if pd.isna(value): return np.nan
|
||||||
|
idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0
|
||||||
|
return categories[min(idx, len(categories)-1)]
|
||||||
|
|
||||||
|
# --- 4. Prepare Mixed Color Matrix ---
|
||||||
|
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
|
||||||
|
|
||||||
|
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
|
||||||
|
# CRITICAL FIX: Convert to numeric and drop NaNs in one go
|
||||||
|
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
|
||||||
|
temp_df = df[[gt_col, res_col]].copy()
|
||||||
|
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
|
||||||
|
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
|
||||||
|
|
||||||
|
valid_df = temp_df.dropna()
|
||||||
|
|
||||||
|
for _, row in valid_df.iterrows():
|
||||||
|
gt_cat = categorize_edss(row[gt_col])
|
||||||
|
res_cat = categorize_edss(row[res_col])
|
||||||
|
if gt_cat in category_to_index and res_cat in category_to_index:
|
||||||
|
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
|
||||||
|
|
||||||
|
# Create an RGB image matrix (initially white/empty)
|
||||||
|
rgb_matrix = np.ones((n_categories, n_categories, 3))
|
||||||
|
|
||||||
|
# Create an Alpha matrix for the "Total Count" intensity
|
||||||
|
total_counts = np.sum(cell_system_counts, axis=2)
|
||||||
|
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
|
||||||
|
|
||||||
|
for i in range(n_categories):
|
||||||
|
for j in range(n_categories):
|
||||||
|
count_sum = total_counts[i, j]
|
||||||
|
if count_sum > 0:
|
||||||
|
# Calculate weighted average color
|
||||||
|
mixed_rgb = np.zeros(3)
|
||||||
|
for s_idx, s_name in enumerate(system_names):
|
||||||
|
weight = cell_system_counts[i, j, s_idx] / count_sum
|
||||||
|
system_rgb = mcolors.to_rgb(color_map[s_name])
|
||||||
|
mixed_rgb += np.array(system_rgb) * weight
|
||||||
|
rgb_matrix[i, j] = mixed_rgb
|
||||||
|
|
||||||
|
# --- 5. Plotting ---
|
||||||
|
fig, ax = plt.subplots(figsize=(12, 10))
|
||||||
|
|
||||||
|
# Display the mixed color matrix
|
||||||
|
# We use alpha based on count to show density (optional, but recommended)
|
||||||
|
im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper')
|
||||||
|
|
||||||
|
# Add text labels for total counts in each cell
|
||||||
|
for i in range(n_categories):
|
||||||
|
for j in range(n_categories):
|
||||||
|
if total_counts[i, j] > 0:
|
||||||
|
# Determine text color based on brightness of background
|
||||||
|
lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2]
|
||||||
|
text_col = "white" if lum < 0.5 else "black"
|
||||||
|
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9)
|
||||||
|
|
||||||
|
# --- 6. Styling ---
|
||||||
|
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12)
|
||||||
|
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12)
|
||||||
|
ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20)
|
||||||
|
|
||||||
|
ax.set_xticks(np.arange(n_categories))
|
||||||
|
ax.set_xticklabels(categories)
|
||||||
|
ax.set_yticks(np.arange(n_categories))
|
||||||
|
ax.set_yticklabels(categories)
|
||||||
|
|
||||||
|
# Custom Legend
|
||||||
|
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
|
||||||
|
labels = [name.replace('_', ' ').capitalize() for name in system_names]
|
||||||
|
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
|
||||||
|
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
135
Data/style2.py
135
Data/style2.py
@@ -1,135 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import dataframe_image as dfi
|
|
||||||
# Load data
|
|
||||||
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')
|
|
||||||
|
|
||||||
# 1. Identify all GT and result columns
|
|
||||||
gt_columns = [col for col in df.columns if col.startswith('GT.')]
|
|
||||||
result_columns = [col for col in df.columns if col.startswith('result.')]
|
|
||||||
|
|
||||||
print("GT Columns found:", gt_columns)
|
|
||||||
print("Result Columns found:", result_columns)
|
|
||||||
|
|
||||||
# 2. Create proper mapping between GT and result columns
|
|
||||||
# Handle various naming conventions (spaces, underscores, etc.)
|
|
||||||
column_mapping = {}
|
|
||||||
|
|
||||||
for gt_col in gt_columns:
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
|
|
||||||
# Clean the base name for matching - remove spaces, underscores, etc.
|
|
||||||
# Try different matching approaches
|
|
||||||
candidates = [
|
|
||||||
f'result.{base_name}', # Exact match
|
|
||||||
f'result.{base_name.replace(" ", "_")}', # With underscores
|
|
||||||
f'result.{base_name.replace("_", " ")}', # With spaces
|
|
||||||
f'result.{base_name.replace(" ", "")}', # No spaces
|
|
||||||
f'result.{base_name.replace("_", "")}' # No underscores
|
|
||||||
]
|
|
||||||
|
|
||||||
# Also try case-insensitive matching
|
|
||||||
candidates.append(f'result.{base_name.lower()}')
|
|
||||||
candidates.append(f'result.{base_name.upper()}')
|
|
||||||
|
|
||||||
# Try to find matching result column
|
|
||||||
matched = False
|
|
||||||
for candidate in candidates:
|
|
||||||
if candidate in result_columns:
|
|
||||||
column_mapping[gt_col] = candidate
|
|
||||||
matched = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# If no exact match found, try partial matching
|
|
||||||
if not matched:
|
|
||||||
# Try to match by removing special characters and comparing
|
|
||||||
base_clean = ''.join(e for e in base_name if e.isalnum() or e in ['_', ' '])
|
|
||||||
for result_col in result_columns:
|
|
||||||
result_base = result_col.replace('result.', '')
|
|
||||||
result_clean = ''.join(e for e in result_base if e.isalnum() or e in ['_', ' '])
|
|
||||||
if base_clean.lower() == result_clean.lower():
|
|
||||||
column_mapping[gt_col] = result_col
|
|
||||||
matched = True
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Column mapping:", column_mapping)
|
|
||||||
|
|
||||||
# 3. Faster, vectorized computation using the corrected mapping
|
|
||||||
data_list = []
|
|
||||||
|
|
||||||
for gt_col, result_col in column_mapping.items():
|
|
||||||
print(f"Processing {gt_col} vs {result_col}")
|
|
||||||
|
|
||||||
# Convert to numeric, forcing errors to NaN
|
|
||||||
s1 = pd.to_numeric(df[gt_col], errors='coerce').astype(float)
|
|
||||||
s2 = pd.to_numeric(df[result_col], errors='coerce').astype(float)
|
|
||||||
|
|
||||||
# Calculate matches (abs difference <= 0.5)
|
|
||||||
diff = np.abs(s1 - s2)
|
|
||||||
matches = (diff <= 0.5).sum()
|
|
||||||
|
|
||||||
# Determine the denominator (total valid comparisons)
|
|
||||||
valid_count = diff.notna().sum()
|
|
||||||
|
|
||||||
if valid_count > 0:
|
|
||||||
percentage = (matches / valid_count) * 100
|
|
||||||
else:
|
|
||||||
percentage = 0
|
|
||||||
|
|
||||||
# Extract clean base name for display
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
|
|
||||||
data_list.append({
|
|
||||||
'GT': base_name,
|
|
||||||
'Match %': round(percentage, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 4. Prepare Data for Plotting
|
|
||||||
match_df = pd.DataFrame(data_list)
|
|
||||||
match_df = match_df.sort_values('Match %', ascending=False) # Sort for better visual flow
|
|
||||||
|
|
||||||
# 5. Create the Styled Gradient Table
|
|
||||||
def style_agreement_table(df):
|
|
||||||
return (df.style
|
|
||||||
.format({'Match %': '{:.1f}%'}) # Add % sign
|
|
||||||
.background_gradient(cmap='RdYlGn', subset=['Match %'], vmin=50, vmax=100) # Red to Green gradient
|
|
||||||
.set_properties(**{
|
|
||||||
'text-align': 'center',
|
|
||||||
'font-size': '12pt',
|
|
||||||
'border-collapse': 'collapse',
|
|
||||||
'border': '1px solid #D3D3D3'
|
|
||||||
})
|
|
||||||
.set_table_styles([
|
|
||||||
# Style the header
|
|
||||||
{'selector': 'th', 'props': [
|
|
||||||
('background-color', '#404040'),
|
|
||||||
('color', 'white'),
|
|
||||||
('font-weight', 'bold'),
|
|
||||||
('text-transform', 'uppercase'),
|
|
||||||
('padding', '10px')
|
|
||||||
]},
|
|
||||||
# Add hover effect
|
|
||||||
{'selector': 'tr:hover', 'props': [('background-color', '#f5f5f5')]}
|
|
||||||
])
|
|
||||||
.set_caption("EDSS Agreement Analysis: Ground Truth vs. Results (Tolerance ±0.5)")
|
|
||||||
)
|
|
||||||
|
|
||||||
# To display in a Jupyter Notebook:
|
|
||||||
styled_table = style_agreement_table(match_df)
|
|
||||||
styled_table
|
|
||||||
|
|
||||||
dfi.export(styled_table, "styled_table.png")
|
|
||||||
#styled_table.to_html("agreement_report.html")
|
|
||||||
# 6. Save as SVG
|
|
||||||
|
|
||||||
#plt.savefig("agreement_table.svg", format='svg', dpi=300, bbox_inches='tight')
|
|
||||||
#print("Successfully saved agreement_table.svg")
|
|
||||||
|
|
||||||
# Show plot if running in a GUI environment
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import seaborn as sns
|
|
||||||
|
|
||||||
# Sample data (replace with your actual df)
|
|
||||||
df = pd.read_csv("/home/shahin/Lab/Doktorarbeit/Barcelona/Data/Join_edssandsub.tsv", sep='\t')
|
|
||||||
|
|
||||||
# Identify GT and Result columns
|
|
||||||
gt_columns = [col for col in df.columns if col.startswith('GT.')]
|
|
||||||
result_columns = [col for col in df.columns if col.startswith('result.')]
|
|
||||||
|
|
||||||
# Create mapping
|
|
||||||
column_mapping = {}
|
|
||||||
for gt_col in gt_columns:
|
|
||||||
base_name = gt_col.replace('GT.', '')
|
|
||||||
result_col = f'result.{base_name}'
|
|
||||||
if result_col in result_columns:
|
|
||||||
column_mapping[gt_col] = result_col
|
|
||||||
|
|
||||||
# Function to compute match percentage for each GT-Result pair
|
|
||||||
def compute_match_percentages(df, column_mapping):
|
|
||||||
percentages = []
|
|
||||||
for gt_col, result_col in column_mapping.items():
|
|
||||||
count = 0
|
|
||||||
total = len(df)
|
|
||||||
|
|
||||||
for _, row in df.iterrows():
|
|
||||||
gt_val = row[gt_col]
|
|
||||||
result_val = row[result_col]
|
|
||||||
|
|
||||||
# Handle NaN values
|
|
||||||
if pd.isna(gt_val) or pd.isna(result_val):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle non-numeric values
|
|
||||||
try:
|
|
||||||
gt_float = float(gt_val)
|
|
||||||
result_float = float(result_val)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
# Skip rows with non-numeric values
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if values are within 0.5 tolerance
|
|
||||||
if abs(gt_float - result_float) <= 0.5:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
percentage = (count / total) * 100
|
|
||||||
percentages.append({
|
|
||||||
'GT_Column': gt_col,
|
|
||||||
'Result_Column': result_col,
|
|
||||||
'Match_Percentage': round(percentage, 1)
|
|
||||||
})
|
|
||||||
|
|
||||||
return pd.DataFrame(percentages)
|
|
||||||
|
|
||||||
# Compute match percentages
|
|
||||||
match_df = compute_match_percentages(df, column_mapping)
|
|
||||||
|
|
||||||
# Create a pivot table for gradient display (optional but helpful)
|
|
||||||
pivot_table = match_df.set_index(['GT_Column', 'Result_Column'])['Match_Percentage'].unstack(fill_value=0)
|
|
||||||
|
|
||||||
# Apply gradient background
|
|
||||||
cm = sns.light_palette("green", as_cmap=True)
|
|
||||||
styled_table = pivot_table.style.background_gradient(cmap=cm, axis=None)
|
|
||||||
|
|
||||||
# Display result
|
|
||||||
print("Agreement Percentage Table (with gradient):")
|
|
||||||
styled_table
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Save the styled table to a file
|
|
||||||
styled_table.to_html("agreement_report.html")
|
|
||||||
print("Report saved to agreement_report.html")
|
|
||||||
Reference in New Issue
Block a user