adding some visualizations

This commit is contained in:
2026-01-26 02:02:19 +01:00
parent 2f1bd2bfd0
commit b2e9ccd2b6

View File

@@ -748,7 +748,7 @@ plt.show()
##
# %% name
# %% Table
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
@@ -860,4 +860,437 @@ for ax in axes:
plt.tight_layout()
plt.show()
##
# %% Histogram Fig1
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import json
import os
def create_visit_frequency_plot(
file_path,
output_dir='/home/shahin/Lab/Doktorarbeit/Barcelona/Data',
output_filename='visit_frequency_distribution.svg',
fontsize=10,
color_scheme_path='colors.json'
):
"""
Creates a publication-ready bar chart of patient visit frequency.
Args:
file_path (str): Path to the input TSV file.
output_dir (str): Directory to save the output SVG file.
output_filename (str): Name of the output SVG file.
fontsize (int): Font size for all text elements (labels, title).
color_scheme_path (str): Path to the JSON file containing the color palette.
"""
# --- 1. Load Data and Color Scheme ---
try:
df = pd.read_csv(file_path, sep='\t')
print("Data loaded successfully.")
# Sort data for easier visual comparison
df = df.sort_values(by='Visits Count')
except FileNotFoundError:
print(f"Error: The file was not found at {file_path}")
return
try:
with open(color_scheme_path, 'r') as f:
colors = json.load(f)
# Select a blue from the sequential palette for the bars
bar_color = colors['sequential']['blues'][-2] # A saturated blue
except FileNotFoundError:
print(f"Warning: Color scheme file not found at {color_scheme_path}. Using default blue.")
bar_color = '#2171b5' # A common matplotlib blue
# --- 2. Set up the Plot with Scientific Style ---
plt.figure(figsize=(7.94, 6)) # Single-column width (7.94 cm) with appropriate height
# Set the font to Arial
arial_font = fm.FontProperties(family='Arial', size=fontsize)
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.size'] = fontsize
# --- 3. Create the Bar Chart ---
ax = plt.gca()
bars = plt.bar(
x=df['Visits Count'],
height=df['Unique Patients'],
color=bar_color,
edgecolor='black',
linewidth=0.5, # Minimum line thickness
width=0.7
)
# --- NEW: Explicitly set x-ticks and labels to ensure all are shown ---
# Get the unique visit counts to use as tick labels
visit_counts = df['Visits Count'].unique()
# Set the x-ticks to be at the center of each bar
ax.set_xticks(visit_counts)
# Set the x-tick labels to be the visit counts, using the specified font
ax.set_xticklabels(visit_counts, fontproperties=arial_font)
# --- END OF NEW SECTION ---
# --- 4. Customize Axes and Layout (Nature style) ---
# Display only left and bottom axes
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Turn off axis ticks (the marks, not the labels)
plt.tick_params(axis='both', which='both', length=0)
# Remove grid lines
plt.grid(False)
# Set background to white (no shading)
ax.set_facecolor('white')
plt.gcf().set_facecolor('white')
# --- 5. Add Labels and Title ---
plt.xlabel('Number of Visits', fontproperties=arial_font, labelpad=10)
plt.ylabel('Number of Unique Patients', fontproperties=arial_font, labelpad=10)
plt.title('Distribution of Patient Visit Frequency', fontproperties=arial_font, pad=20)
# --- 6. Add y-axis values on top of each bar ---
# This adds the count of unique patients directly above each bar.
ax.bar_label(bars, fmt='%d', padding=3)
# --- 7. Export the Figure ---
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True)
full_output_path = os.path.join(output_dir, output_filename)
plt.savefig(full_output_path, format='svg', dpi=300, bbox_inches='tight')
print(f"\nFigure saved as '{full_output_path}'")
# --- 8. (Optional) Display the Plot ---
# plt.show()
# --- Main execution ---
if __name__ == '__main__':
# Define the file path
input_file = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/visit_freuency.tsv'
# Call the function to create and save the plot
create_visit_frequency_plot(
file_path=input_file,
fontsize=10 # Using a 10 pt font size as per guidelines
)
##
# %% Scatter Plot functional system
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
# --- Configuration ---
# Set the font to Arial for all text in the plot, as per the guidelines
plt.rcParams['font.family'] = 'Arial'
# Define the path to your data file
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
# Define the path to save the color mapping JSON file
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
# Define the path to save the final figure
figure_save_path = 'project/visuals/edss_functional_systems_comparison.svg'
# --- 1. Load the Dataset ---
try:
# Load the TSV file
df = pd.read_csv(data_path, sep='\t')
print(f"Successfully loaded data from {data_path}")
print(f"Data shape: {df.shape}")
except FileNotFoundError:
print(f"Error: The file at {data_path} was not found.")
# Exit or handle the error appropriately
raise
# --- 2. Define Functional Systems and Create Color Mapping ---
# List of tuples containing (ground_truth_column, result_column)
functional_systems_to_plot = [
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
('GT.AMBULATION', 'result.AMBULATION'),
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]
# Extract system names for color mapping and legend
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
# Define a professional color palette (dark blue theme)
# This is a qualitative palette with distinct, accessible colors
colors = [
'#003366', # Dark Blue
'#336699', # Medium Blue
'#6699CC', # Light Blue
'#99CCFF', # Very Light Blue
'#FF9966', # Coral
'#FF6666', # Light Red
'#CC6699', # Magenta
'#9966CC' # Purple
]
# Create a dictionary mapping system names to colors
color_map = dict(zip(system_names, colors))
# Ensure the directory for the JSON file exists
os.makedirs(os.path.dirname(color_json_path), exist_ok=True)
# Save the color map to a JSON file
with open(color_json_path, 'w') as f:
json.dump(color_map, f, indent=4)
print(f"Color mapping saved to {color_json_path}")
# --- 3. Calculate Agreement Percentages and Format Legend Labels ---
agreement_percentages = {}
legend_labels = {}
for gt_col, res_col in functional_systems_to_plot:
system_name = gt_col.split('.')[1]
# Convert columns to numeric, setting errors to NaN
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
# Ensure we are comparing the same rows
common_index = gt_numeric.dropna().index.intersection(res_numeric.dropna().index)
gt_data = gt_numeric.loc[common_index]
res_data = res_numeric.loc[common_index]
# Calculate agreement percentage
if len(gt_data) > 0:
agreement = (gt_data == res_data).mean() * 100
else:
agreement = 0 # Handle case with no valid data
agreement_percentages[system_name] = agreement
# Format the system name for the legend (e.g., "VISUAL_OPTIC_FUNCTIONS" -> "Visual Optic Functions")
formatted_name = " ".join(word.capitalize() for word in system_name.split('_'))
legend_labels[system_name] = f"{formatted_name} ({agreement:.1f}%)"
# --- 4. Reshape Data for Plotting ---
plot_data = []
for gt_col, res_col in functional_systems_to_plot:
system_name = gt_col.split('.')[1]
# Convert columns to numeric, setting errors to NaN
gt_numeric = pd.to_numeric(df[gt_col], errors='coerce')
res_numeric = pd.to_numeric(df[res_col], errors='coerce')
# Create a temporary DataFrame with the numeric data
temp_df = pd.DataFrame({
'system': system_name,
'ground_truth': gt_numeric,
'inference': res_numeric
})
# Drop rows where either value is NaN, as they cannot be plotted
temp_df = temp_df.dropna()
plot_data.append(temp_df)
# Concatenate all the temporary DataFrames into one
plot_df = pd.concat(plot_data, ignore_index=True)
if plot_df.empty:
print("Warning: No valid numeric data to plot after conversion. The plot will be blank.")
else:
print(f"Prepared plot data with {len(plot_df)} data points.")
# --- 5. Create the Scatter Plot ---
plt.figure(figsize=(10, 8))
# Plot each functional system with its assigned color and formatted legend label
for system, group in plot_df.groupby('system'):
plt.scatter(
group['ground_truth'],
group['inference'],
label=legend_labels[system],
color=color_map[system],
alpha=0.7,
s=30
)
# Add a diagonal line representing perfect agreement (y = x)
# This line helps visualize how close the predictions are to the ground truth
if not plot_df.empty:
plt.plot(
[plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
[plot_df['ground_truth'].min(), plot_df['ground_truth'].max()],
color='black',
linestyle='--',
linewidth=0.8,
alpha=0.7
)
# --- 6. Apply Styling and Labels ---
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('LLM Inference', fontsize=12)
plt.title('Comparison of EDSS Functional Systems: Ground Truth vs. LLM Inference', fontsize=14)
# Apply scientific visualization styling rules
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(axis='both', which='both', length=0) # Remove ticks
ax.grid(False) # Remove grid lines
plt.legend(title='Functional System', frameon=False, fontsize=10)
# --- 7. Save and Display the Figure ---
# Ensure the directory for the figure exists
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
print(f"Figure successfully saved to {figure_save_path}")
# Display the plot
plt.show()
##
# %% Confusion Matrix functional systems
import pandas as pd
import matplotlib.pyplot as plt
import json
import os
import numpy as np
import matplotlib.colors as mcolors
# --- Configuration ---
plt.rcParams['font.family'] = 'Arial'
data_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/comparison.tsv'
color_json_path = '/home/shahin/Lab/Doktorarbeit/Barcelona/Data/functional_system_colors.json'
figure_save_path = 'project/visuals/edss_combined_confusion_matrix_mixed.svg'
# --- 1. Load the Dataset ---
df = pd.read_csv(data_path, sep='\t')
# --- 2. Define Functional Systems and Colors ---
functional_systems_to_plot = [
('GT.VISUAL_OPTIC_FUNCTIONS', 'result.VISUAL OPTIC FUNCTIONS'),
('GT.CEREBELLAR_FUNCTIONS', 'result.CEREBELLAR FUNCTIONS'),
('GT.BRAINSTEM_FUNCTIONS', 'result.BRAINSTEM FUNCTIONS'),
('GT.SENSORY_FUNCTIONS', 'result.SENSORY FUNCTIONS'),
('GT.PYRAMIDAL_FUNCTIONS', 'result.PYRAMIDAL FUNCTIONS'),
('GT.AMBULATION', 'result.AMBULATION'),
('GT.CEREBRAL_FUNCTIONS', 'result.CEREBRAL FUNCTIONS'),
('GT.BOWEL_AND_BLADDER_FUNCTIONS', 'result.BOWEL AND BLADDER FUNCTIONS')
]
system_names = [name.split('.')[1] for name, _ in functional_systems_to_plot]
colors = ['#003366', '#336699', '#6699CC', '#99CCFF', '#FF9966', '#FF6666', '#CC6699', '#9966CC']
color_map = dict(zip(system_names, colors))
# --- 3. Categorization Function ---
categories = ['0-1', '1-2', '2-3', '3-4', '4-5', '5-6', '6-7', '7-8', '8-9', '9-10']
category_to_index = {cat: i for i, cat in enumerate(categories)}
n_categories = len(categories)
def categorize_edss(value):
if pd.isna(value): return np.nan
idx = int(min(max(value, 0), 10) - 0.001) if value > 0 else 0
return categories[min(idx, len(categories)-1)]
# --- 4. Prepare Mixed Color Matrix ---
cell_system_counts = np.zeros((n_categories, n_categories, len(system_names)))
for s_idx, (gt_col, res_col) in enumerate(functional_systems_to_plot):
# CRITICAL FIX: Convert to numeric and drop NaNs in one go
# 'coerce' turns non-numeric strings into NaN so they don't crash the script
temp_df = df[[gt_col, res_col]].copy()
temp_df[gt_col] = pd.to_numeric(temp_df[gt_col], errors='coerce')
temp_df[res_col] = pd.to_numeric(temp_df[res_col], errors='coerce')
valid_df = temp_df.dropna()
for _, row in valid_df.iterrows():
gt_cat = categorize_edss(row[gt_col])
res_cat = categorize_edss(row[res_col])
if gt_cat in category_to_index and res_cat in category_to_index:
cell_system_counts[category_to_index[gt_cat], category_to_index[res_cat], s_idx] += 1
# Create an RGB image matrix (initially white/empty)
rgb_matrix = np.ones((n_categories, n_categories, 3))
# Create an Alpha matrix for the "Total Count" intensity
total_counts = np.sum(cell_system_counts, axis=2)
max_count = np.max(total_counts) if np.max(total_counts) > 0 else 1
for i in range(n_categories):
for j in range(n_categories):
count_sum = total_counts[i, j]
if count_sum > 0:
# Calculate weighted average color
mixed_rgb = np.zeros(3)
for s_idx, s_name in enumerate(system_names):
weight = cell_system_counts[i, j, s_idx] / count_sum
system_rgb = mcolors.to_rgb(color_map[s_name])
mixed_rgb += np.array(system_rgb) * weight
rgb_matrix[i, j] = mixed_rgb
# --- 5. Plotting ---
fig, ax = plt.subplots(figsize=(12, 10))
# Display the mixed color matrix
# We use alpha based on count to show density (optional, but recommended)
im = ax.imshow(rgb_matrix, interpolation='nearest', origin='upper')
# Add text labels for total counts in each cell
for i in range(n_categories):
for j in range(n_categories):
if total_counts[i, j] > 0:
# Determine text color based on brightness of background
lum = 0.2126 * rgb_matrix[i,j,0] + 0.7152 * rgb_matrix[i,j,1] + 0.0722 * rgb_matrix[i,j,2]
text_col = "white" if lum < 0.5 else "black"
ax.text(j, i, int(total_counts[i, j]), ha="center", va="center", color=text_col, fontsize=9)
# --- 6. Styling ---
ax.set_xlabel('LLM Inference (EDSS Category)', fontsize=12)
ax.set_ylabel('Ground Truth (EDSS Category)', fontsize=12)
ax.set_title('Blended Confusion Matrix (Color = Weighted System Mixture)', fontsize=14, pad=20)
ax.set_xticks(np.arange(n_categories))
ax.set_xticklabels(categories)
ax.set_yticks(np.arange(n_categories))
ax.set_yticklabels(categories)
# Custom Legend
handles = [plt.Rectangle((0,0),1,1, color=color_map[name]) for name in system_names]
labels = [name.replace('_', ' ').capitalize() for name in system_names]
ax.legend(handles, labels, title='Functional Systems', loc='upper left', bbox_to_anchor=(1.05, 1), frameon=False)
plt.tight_layout()
os.makedirs(os.path.dirname(figure_save_path), exist_ok=True)
plt.savefig(figure_save_path, format='svg', bbox_inches='tight')
plt.show()
#