Files
solidstate-tools/GPUMD/t-SNE/t-SNE.py
2025-11-19 12:23:17 +08:00

140 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
def tsne_dir_shared_coords(
dir_path: str,
*,
metric: str = "euclidean", # 可试 "cosine";想保留尺度差异用 "euclidean"
perplexity: float = 50.0, # 30k~50k 样本建议 30~50
n_iter: int = 1000,
early_exaggeration: float = 12.0,
learning_rate = "auto",
standardize: bool = False,
pca_dim: int | None = None, # 先用 PCA 降到 pca_dim(如 20) 再跑 t-SNE可提速
context: bool = True,
make_joint: bool = True,
init: str = "pca",
random_state: int = 42
) -> None:
p = Path(dir_path)
if not p.is_dir():
raise ValueError(f"{dir_path!r} 不是有效文件夹")
files = sorted(p.glob("*.npy"))
if not files:
print(f"目录 {p} 中未找到 .npy 文件")
return
X_list, paths, counts = [], [], []
for f in files:
try:
data = np.load(f)
if data.ndim != 2:
print(f"[跳过] {f.name}: 期望二维数组,实际 shape={data.shape}")
continue
# 统一到 (n_samples, 30)
if data.shape[1] == 30:
X = data
elif data.shape[0] == 30:
X = data.T
else:
print(f"[跳过] {f.name}: shape={data.shape}, 未检测到 30 维特征")
continue
mask = np.isfinite(X).all(axis=1)
if not np.all(mask):
X = X[mask]
print(f"[提示] {f.name}: 移除了含 NaN/Inf 的样本行")
if X.shape[0] < 3:
print(f"[跳过] {f.name}: 样本数过少(n={X.shape[0]})")
continue
X_list.append(X)
paths.append(f)
counts.append(X.shape[0])
except Exception as e:
print(f"[错误] 读取 {f.name} 失败: {e}")
if not X_list:
print("未找到可用的数据文件")
return
X_all = np.vstack(X_list)
if standardize:
mean = X_all.mean(axis=0)
std = X_all.std(axis=0); std[std == 0] = 1.0
X_all = (X_all - mean) / std
if pca_dim is not None and pca_dim > 2:
X_all = PCA(n_components=pca_dim, random_state=random_state).fit_transform(X_all)
tsne = TSNE(
n_components=2,
metric=metric,
perplexity=float(perplexity),
early_exaggeration=float(early_exaggeration),
learning_rate=learning_rate,
init=init,
random_state=random_state,
method="barnes_hut", # 适合大样本
angle=0.5,
verbose=0,
)
Z_all = tsne.fit_transform(X_all)
# 统一坐标轴范围
x_min, x_max = float(Z_all[:, 0].min()), float(Z_all[:, 0].max())
y_min, y_max = float(Z_all[:, 1].min()), float(Z_all[:, 1].max())
pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1.0
pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1.0
colors = [
"#1f77b4","#ff7f0e","#2ca02c","#d62728","#9467bd",
"#8c564b","#e377c2","#7f7f7f","#bcbd22","#17becf"
]
# 分文件出图
start = 0
for i, (f, n) in enumerate(zip(paths, counts)):
Zi = Z_all[start:start + n]; start += n
fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
if context:
ax.scatter(Z_all[:, 0], Z_all[:, 1], s=5, c="#cccccc", alpha=0.35, edgecolors="none", label="All")
ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.9, edgecolors="none", label=f.name)
ax.set_title(f"{f.name} • t-SNE(shared) (perp={perplexity}, metric={metric})", fontsize=9)
ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
if context: ax.legend(loc="best", fontsize=8, frameon=False)
fig.tight_layout()
out_png = f.with_suffix("").as_posix() + "_tsne_shared.png"
fig.savefig(out_png); plt.close(fig)
print(f"[完成] {f.name} -> {out_png}")
# 总览图
if make_joint:
start = 0
fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
for i, (f, n) in enumerate(zip(paths, counts)):
Zi = Z_all[start:start + n]; start += n
ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.85, edgecolors="none", label=f.name)
ax.set_title(f"t-SNE(shared) overview (perp={perplexity}, metric={metric})", fontsize=10)
ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
ax.legend(loc="best", fontsize=8, frameon=False)
fig.tight_layout()
out_png = Path(dir_path) / "tsne_shared_overview.png"
fig.savefig(out_png.as_posix()); plt.close(fig)
print(f"[完成] 总览 -> {out_png}")
if __name__ == "__main__":
    # Script entry point: process the "data" directory with default settings.
    tsne_dir_shared_coords(dir_path="data")