# pycorrana.core.visualizer 源代码 (source code)

"""
可视化核心模块
==============
提供相关性热力图、散点图矩阵、箱线图等可视化功能。
"""

import warnings
from typing import Optional, List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


class CorrVisualizer:
    """
    Correlation visualizer.

    Bundles several plots for exploring correlations: a heatmap, a scatter
    plot matrix, grouped distribution plots, a correlation network and a bar
    chart of significant pairs.
    """

    def __init__(self, style: str = 'seaborn-v0_8-whitegrid'):
        """
        Parameters
        ----------
        style : str, default='seaborn-v0_8-whitegrid'
            Matplotlib style to activate. If it cannot be applied, the
            seaborn whitegrid style is used when available, otherwise
            matplotlib's 'default' style.
        """
        try:
            plt.style.use(style)
        except Exception:
            # Best-effort fallback; unlike a bare ``except`` this lets
            # KeyboardInterrupt / SystemExit propagate.
            fallback = 'seaborn-v0_8-whitegrid'
            if fallback not in plt.style.available:
                fallback = 'default'
            plt.style.use(fallback)

        self.default_figsize = (10, 8)
        self.default_dpi = 100

    @staticmethod
    def _infer_format(savefig) -> str:
        """Return the image format implied by a path's extension ('png' if none)."""
        import os

        # splitext only inspects the final path component, so dots in
        # directory names ("./out/plot", "run.v2/plot") are not mistaken
        # for an extension.
        extension = os.path.splitext(os.fspath(savefig))[1]
        return extension[1:].lower() or 'png'

    @classmethod
    def _save_figure(cls, fig, savefig, dpi: int = 100) -> None:
        """Write ``fig`` to ``savefig`` and print a confirmation message."""
        fig.savefig(
            savefig,
            dpi=dpi,
            bbox_inches='tight',
            format=cls._infer_format(savefig),
        )
        print(f" 图表已保存: {savefig}")

    @staticmethod
    def _cluster_order(corr_matrix: pd.DataFrame) -> List[int]:
        """Return the leaf order of an average-linkage clustering on 1 - |r|."""
        from scipy.cluster.hierarchy import linkage, dendrogram
        from scipy.spatial.distance import squareform

        # Missing correlations are treated as "no correlation".
        dist = 1.0 - np.abs(corr_matrix.fillna(0).to_numpy(dtype=float))
        # Make the matrix a well-formed distance matrix: symmetric,
        # non-negative, zero diagonal. Otherwise a NaN or inexact diagonal
        # makes squareform's validation reject it.
        dist = (dist + dist.T) / 2.0
        np.clip(dist, 0.0, None, out=dist)
        np.fill_diagonal(dist, 0.0)

        linkage_matrix = linkage(squareform(dist, checks=False), method='average')
        return list(dendrogram(linkage_matrix, no_plot=True)['leaves'])

    def plot_heatmap(self,
                     corr_matrix: pd.DataFrame,
                     figsize: Tuple[int, int] = (10, 8),
                     annot: bool = True,
                     fmt: str = '.2f',
                     cmap: str = 'RdBu_r',
                     center: float = 0,
                     vmin: float = -1,
                     vmax: float = 1,
                     cluster: bool = False,
                     mask_upper: bool = False,
                     title: Optional[str] = None,
                     savefig: Optional[str] = None,
                     dpi: int = 100,
                     **kwargs) -> "plt.Figure":
        """
        Draw a correlation heatmap.

        Parameters
        ----------
        corr_matrix : pd.DataFrame
            Correlation matrix.
        figsize : tuple, default=(10, 8)
            Figure size.
        annot : bool, default=True
            Whether to annotate cells with their values.
        fmt : str, default='.2f'
            Number format of the annotations.
        cmap : str, default='RdBu_r'
            Colormap.
        center : float, default=0
            Value at the centre of the colormap.
        vmin, vmax : float, default=-1, 1
            Colour range.
        cluster : bool, default=False
            Reorder variables by hierarchical clustering. If clustering
            fails, a warning is issued and the original order is kept.
        mask_upper : bool, default=False
            Show only the lower triangle (diagonal included).
        title : str, optional
            Figure title; a default title is used when omitted.
        savefig : str, optional
            Path to save the figure to.
        dpi : int, default=100
            Resolution, used for both the figure and the saved file.
        **kwargs
            Forwarded to ``seaborn.heatmap``.

        Returns
        -------
        plt.Figure
            The matplotlib figure.
        """
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi)

        plot_data = corr_matrix.copy()

        if cluster:
            try:
                order = self._cluster_order(plot_data)
                plot_data = plot_data.iloc[order, order]
            except Exception as e:
                # Clustering is optional: keep plotting in the original order.
                warnings.warn(f"聚类失败: {e},使用原始顺序")

        # Hide everything strictly above the diagonal when requested.
        mask = None
        if mask_upper:
            mask = np.triu(np.ones_like(plot_data, dtype=bool), k=1)

        sns.heatmap(
            plot_data,
            mask=mask,
            annot=annot,
            fmt=fmt,
            cmap=cmap,
            center=center,
            vmin=vmin,
            vmax=vmax,
            square=True,
            linewidths=0.5,
            cbar_kws={"shrink": 0.8},
            ax=ax,
            **kwargs
        )

        ax.set_title(title if title else '相关性热力图',
                     fontsize=14, fontweight='bold')

        # Rotate the tick labels of this axes explicitly.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

        fig.tight_layout()

        if savefig:
            self._save_figure(fig, savefig, dpi=dpi)

        plt.show()
        return fig

    def plot_pairplot(self,
                      data: pd.DataFrame,
                      columns: Optional[List[str]] = None,
                      hue: Optional[str] = None,
                      diag_kind: str = 'kde',
                      kind: str = 'scatter',
                      corner: bool = False,
                      height: float = 2.5,
                      aspect: float = 1,
                      savefig: Optional[str] = None,
                      **kwargs) -> "sns.PairGrid":
        """
        Draw a scatter plot matrix.

        Parameters
        ----------
        data : pd.DataFrame
            Input data.
        columns : list, optional
            Columns to plot; all columns when omitted.
        hue : str, optional
            Column used for colour grouping.
        diag_kind : str, default='kde'
            Diagonal plot type: 'kde' or 'hist'.
        kind : str, default='scatter'
            Off-diagonal plot type: 'scatter' or 'reg'.
        corner : bool, default=False
            Draw only the lower triangle.
        height : float, default=2.5
            Height of each subplot.
        aspect : float, default=1
            Width/height ratio of each subplot.
        savefig : str, optional
            Path to save the figure to.
        **kwargs
            Forwarded to ``seaborn.pairplot``. ``plot_kws`` and ``diag_kws``
            given here are merged over this method's defaults.

        Returns
        -------
        sns.PairGrid
            The seaborn PairGrid.
        """
        if columns:
            selected = list(columns)
            # Avoid a duplicated column when hue is already listed.
            if hue and hue not in selected:
                selected.append(hue)
            plot_data = data[selected]
        else:
            plot_data = data

        # Merge caller-supplied keyword dicts instead of passing the same
        # keyword to seaborn twice.
        user_plot_kws = kwargs.pop('plot_kws', None) or {}
        user_diag_kws = kwargs.pop('diag_kws', None) or {}

        marker_style = {'alpha': 0.6, 's': 30}
        if kind == 'scatter':
            plot_kws = {**marker_style, **user_plot_kws}
        elif kind == 'reg':
            # regplot takes marker styling via scatter_kws, not as
            # top-level keywords.
            scatter_kws = {**marker_style,
                           **(user_plot_kws.get('scatter_kws') or {})}
            plot_kws = {**user_plot_kws, 'scatter_kws': scatter_kws}
        else:
            plot_kws = dict(user_plot_kws)
        diag_kws = {'fill': True, **user_diag_kws}

        g = sns.pairplot(
            plot_data,
            hue=hue,
            diag_kind=diag_kind,
            kind=kind,
            corner=corner,
            height=height,
            aspect=aspect,
            plot_kws=plot_kws,
            diag_kws=diag_kws,
            **kwargs
        )

        g.fig.suptitle('散点图矩阵', y=1.02, fontsize=14, fontweight='bold')

        if savefig:
            self._save_figure(g.fig, savefig, dpi=100)

        plt.show()
        return g

    def plot_boxplot(self,
                     data: pd.DataFrame,
                     numeric_col: str,
                     categorical_col: str,
                     kind: str = 'box',
                     figsize: Tuple[int, int] = (10, 6),
                     palette: str = 'Set2',
                     show_points: bool = False,
                     savefig: Optional[str] = None,
                     **kwargs) -> "plt.Figure":
        """
        Draw a numeric variable grouped by a categorical variable.

        Parameters
        ----------
        data : pd.DataFrame
            Input data.
        numeric_col : str
            Name of the numeric column (y axis).
        categorical_col : str
            Name of the categorical column (x axis).
        kind : str, default='box'
            Plot type: 'box', 'violin', 'boxen', 'strip' or 'swarm'.
        figsize : tuple, default=(10, 6)
            Figure size.
        palette : str, default='Set2'
            Colour palette.
        show_points : bool, default=False
            Overlay the raw data points (only for 'box', 'violin', 'boxen').
        savefig : str, optional
            Path to save the figure to.
        **kwargs
            Forwarded to the selected seaborn plotting function.

        Returns
        -------
        plt.Figure
            The matplotlib figure.

        Raises
        ------
        ValueError
            If ``kind`` is not one of the supported plot types.
        """
        plotters = {
            'box': sns.boxplot,
            'violin': sns.violinplot,
            'boxen': sns.boxenplot,
            'strip': sns.stripplot,
            'swarm': sns.swarmplot,
        }
        # Validate before creating the figure so a bad ``kind`` does not
        # leave an empty figure behind.
        if kind not in plotters:
            raise ValueError(f"未知的图表类型: {kind}")

        fig, ax = plt.subplots(figsize=figsize)

        plotters[kind](
            data=data,
            x=categorical_col,
            y=numeric_col,
            palette=palette,
            ax=ax,
            **kwargs
        )

        # Overlay the raw observations on the summary plots.
        if show_points and kind in ('box', 'violin', 'boxen'):
            sns.stripplot(
                data=data,
                x=categorical_col,
                y=numeric_col,
                color='black',
                alpha=0.3,
                size=3,
                ax=ax
            )

        ax.set_title(f'{numeric_col} by {categorical_col}',
                     fontsize=14, fontweight='bold')
        ax.set_xlabel(categorical_col, fontsize=12)
        ax.set_ylabel(numeric_col, fontsize=12)

        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()

        if savefig:
            self._save_figure(fig, savefig, dpi=100)

        plt.show()
        return fig

    def plot_correlation_network(self,
                                 corr_matrix: pd.DataFrame,
                                 threshold: float = 0.5,
                                 figsize: Tuple[int, int] = (12, 12),
                                 savefig: Optional[str] = None,
                                 **kwargs) -> "Optional[plt.Figure]":
        """
        Draw a correlation network.

        Each column is a node; two nodes are connected when the absolute
        value of their correlation is at least ``threshold``.

        Parameters
        ----------
        corr_matrix : pd.DataFrame
            Correlation matrix.
        threshold : float, default=0.5
            Minimum |r| for an edge to be drawn.
        figsize : tuple, default=(12, 12)
            Figure size.
        savefig : str, optional
            Path to save the figure to.
        **kwargs
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        plt.Figure or None
            The matplotlib figure, or ``None`` (with a warning) when
            networkx is not installed.
        """
        try:
            import networkx as nx
        except ImportError:
            warnings.warn("networkx未安装,无法绘制网络图")
            return None

        graph = nx.Graph()
        names = list(corr_matrix.columns)
        graph.add_nodes_from(names)

        # Visit each unordered pair once; NaN correlations never satisfy
        # the comparison and therefore add no edge.
        for i, first in enumerate(names):
            for second in names[i + 1:]:
                corr_val = corr_matrix.loc[first, second]
                if abs(corr_val) >= threshold:
                    graph.add_edge(first, second, weight=abs(corr_val))

        fig, ax = plt.subplots(figsize=figsize)

        pos = nx.spring_layout(graph, k=3, iterations=50)

        # Node size grows with the number of connections.
        node_sizes = [graph.degree(node) * 500 + 300 for node in graph.nodes()]

        # Edge width grows with the correlation strength.
        widths = [graph[u][v]['weight'] * 3 for u, v in graph.edges()]

        nx.draw_networkx_nodes(graph, pos, node_size=node_sizes,
                               node_color='lightblue', alpha=0.9, ax=ax)
        nx.draw_networkx_edges(graph, pos, width=widths, alpha=0.5, ax=ax)
        nx.draw_networkx_labels(graph, pos, font_size=10, ax=ax)

        ax.set_title(f'相关性网络图 (|r| >= {threshold})',
                     fontsize=14, fontweight='bold')
        ax.axis('off')

        fig.tight_layout()

        if savefig:
            self._save_figure(fig, savefig, dpi=100)

        plt.show()
        return fig

    def plot_significant_pairs(self,
                               significant_pairs: List[dict],
                               top_n: int = 20,
                               figsize: Tuple[int, int] = (10, 8),
                               savefig: Optional[str] = None,
                               **kwargs) -> "Optional[plt.Figure]":
        """
        Draw a horizontal bar chart of significant correlation pairs.

        Parameters
        ----------
        significant_pairs : list
            List of dicts; each is read for the keys 'var1', 'var2' and
            'correlation'. The first ``top_n`` entries are plotted in the
            given order (no sorting is done here).
        top_n : int, default=20
            Maximum number of pairs to show.
        figsize : tuple, default=(10, 8)
            Figure size.
        savefig : str, optional
            Path to save the figure to.
        **kwargs
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        plt.Figure or None
            The matplotlib figure, or ``None`` (with a warning) when
            ``significant_pairs`` is empty.
        """
        if not significant_pairs:
            warnings.warn("没有显著相关对可绘制")
            return None

        pairs = significant_pairs[:top_n]

        labels = [f"{p['var1']}\nvs\n{p['var2']}" for p in pairs]
        values = [abs(p['correlation']) for p in pairs]
        # Bar length is |r|; the sign is encoded in the colour.
        colors = ['green' if p['correlation'] > 0 else 'red' for p in pairs]

        fig, ax = plt.subplots(figsize=figsize)

        positions = range(len(labels))
        ax.barh(positions, values, color=colors, alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
        ax.set_xlabel('|Correlation|', fontsize=12)
        # Report the number of pairs actually drawn, which may be fewer
        # than top_n.
        ax.set_title(f'Top {len(pairs)} Significant Correlations',
                     fontsize=14, fontweight='bold')
        ax.invert_yaxis()

        # Value label just past the end of each bar.
        for row, val in enumerate(values):
            ax.text(val + 0.01, row, f'{val:.3f}', va='center', fontsize=9)

        fig.tight_layout()

        if savefig:
            self._save_figure(fig, savefig, dpi=100)

        plt.show()
        return fig