140 lines
5.1 KiB
Python
140 lines
5.1 KiB
Python
from pathlib import Path
|
||
import numpy as np
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
from sklearn.manifold import TSNE
|
||
from sklearn.decomposition import PCA
|
||
|
||
def tsne_dir_shared_coords(
|
||
dir_path: str,
|
||
*,
|
||
metric: str = "euclidean", # 可试 "cosine";想保留尺度差异用 "euclidean"
|
||
perplexity: float = 50.0, # 30k~50k 样本建议 30~50
|
||
n_iter: int = 1000,
|
||
early_exaggeration: float = 12.0,
|
||
learning_rate = "auto",
|
||
standardize: bool = False,
|
||
pca_dim: int | None = None, # 先用 PCA 降到 pca_dim(如 20) 再跑 t-SNE,可提速
|
||
context: bool = True,
|
||
make_joint: bool = True,
|
||
init: str = "pca",
|
||
random_state: int = 42
|
||
) -> None:
|
||
p = Path(dir_path)
|
||
if not p.is_dir():
|
||
raise ValueError(f"{dir_path!r} 不是有效文件夹")
|
||
|
||
files = sorted(p.glob("*.npy"))
|
||
if not files:
|
||
print(f"目录 {p} 中未找到 .npy 文件")
|
||
return
|
||
|
||
X_list, paths, counts = [], [], []
|
||
for f in files:
|
||
try:
|
||
data = np.load(f)
|
||
if data.ndim != 2:
|
||
print(f"[跳过] {f.name}: 期望二维数组,实际 shape={data.shape}")
|
||
continue
|
||
|
||
# 统一到 (n_samples, 30)
|
||
if data.shape[1] == 30:
|
||
X = data
|
||
elif data.shape[0] == 30:
|
||
X = data.T
|
||
else:
|
||
print(f"[跳过] {f.name}: shape={data.shape}, 未检测到 30 维特征")
|
||
continue
|
||
|
||
mask = np.isfinite(X).all(axis=1)
|
||
if not np.all(mask):
|
||
X = X[mask]
|
||
print(f"[提示] {f.name}: 移除了含 NaN/Inf 的样本行")
|
||
|
||
if X.shape[0] < 3:
|
||
print(f"[跳过] {f.name}: 样本数过少(n={X.shape[0]})")
|
||
continue
|
||
|
||
X_list.append(X)
|
||
paths.append(f)
|
||
counts.append(X.shape[0])
|
||
except Exception as e:
|
||
print(f"[错误] 读取 {f.name} 失败: {e}")
|
||
|
||
if not X_list:
|
||
print("未找到可用的数据文件")
|
||
return
|
||
|
||
X_all = np.vstack(X_list)
|
||
|
||
if standardize:
|
||
mean = X_all.mean(axis=0)
|
||
std = X_all.std(axis=0); std[std == 0] = 1.0
|
||
X_all = (X_all - mean) / std
|
||
|
||
if pca_dim is not None and pca_dim > 2:
|
||
X_all = PCA(n_components=pca_dim, random_state=random_state).fit_transform(X_all)
|
||
|
||
tsne = TSNE(
|
||
n_components=2,
|
||
metric=metric,
|
||
perplexity=float(perplexity),
|
||
early_exaggeration=float(early_exaggeration),
|
||
learning_rate=learning_rate,
|
||
init=init,
|
||
random_state=random_state,
|
||
method="barnes_hut", # 适合大样本
|
||
angle=0.5,
|
||
verbose=0,
|
||
)
|
||
Z_all = tsne.fit_transform(X_all)
|
||
|
||
# 统一坐标轴范围
|
||
x_min, x_max = float(Z_all[:, 0].min()), float(Z_all[:, 0].max())
|
||
y_min, y_max = float(Z_all[:, 1].min()), float(Z_all[:, 1].max())
|
||
pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1.0
|
||
pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1.0
|
||
|
||
colors = [
|
||
"#1f77b4","#ff7f0e","#2ca02c","#d62728","#9467bd",
|
||
"#8c564b","#e377c2","#7f7f7f","#bcbd22","#17becf"
|
||
]
|
||
|
||
# 分文件出图
|
||
start = 0
|
||
for i, (f, n) in enumerate(zip(paths, counts)):
|
||
Zi = Z_all[start:start + n]; start += n
|
||
fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
|
||
if context:
|
||
ax.scatter(Z_all[:, 0], Z_all[:, 1], s=5, c="#cccccc", alpha=0.35, edgecolors="none", label="All")
|
||
ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.9, edgecolors="none", label=f.name)
|
||
ax.set_title(f"{f.name} • t-SNE(shared) (perp={perplexity}, metric={metric})", fontsize=9)
|
||
ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
|
||
ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
|
||
ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
|
||
if context: ax.legend(loc="best", fontsize=8, frameon=False)
|
||
fig.tight_layout()
|
||
out_png = f.with_suffix("").as_posix() + "_tsne_shared.png"
|
||
fig.savefig(out_png); plt.close(fig)
|
||
print(f"[完成] {f.name} -> {out_png}")
|
||
|
||
# 总览图
|
||
if make_joint:
|
||
start = 0
|
||
fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
|
||
for i, (f, n) in enumerate(zip(paths, counts)):
|
||
Zi = Z_all[start:start + n]; start += n
|
||
ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.85, edgecolors="none", label=f.name)
|
||
ax.set_title(f"t-SNE(shared) overview (perp={perplexity}, metric={metric})", fontsize=10)
|
||
ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
|
||
ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
|
||
ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
|
||
ax.legend(loc="best", fontsize=8, frameon=False)
|
||
fig.tight_layout()
|
||
out_png = Path(dir_path) / "tsne_shared_overview.png"
|
||
fig.savefig(out_png.as_posix()); plt.close(fig)
|
||
print(f"[完成] 总览 -> {out_png}")
|
||
|
||
# Script entry point: embed and plot every .npy file found in ./data.
if __name__ == "__main__":
    tsne_dir_shared_coords(dir_path="data")