一些小修改
This commit is contained in:
76
GPUMD/Umap/umap_make.py
Normal file
76
GPUMD/Umap/umap_make.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg") # 仅保存图片,不弹窗
|
||||
import matplotlib.pyplot as plt
|
||||
from umap import UMAP
|
||||
|
||||
def umap_dir_to_pngs(dir_path: str) -> None:
|
||||
"""
|
||||
对目录内每个 .npy 文件执行 UMAP(30D->2D) 并保存散点图。
|
||||
- 输入 .npy 期望形状为 (n_samples, 30) 或 (30, n_samples)
|
||||
- 输出图片保存在同目录,命名为 <原文件名>_umap.png
|
||||
"""
|
||||
p = Path(dir_path)
|
||||
if not p.is_dir():
|
||||
raise ValueError(f"{dir_path!r} 不是有效文件夹")
|
||||
|
||||
files = sorted(p.glob("*.npy"))
|
||||
if not files:
|
||||
print(f"目录 {p} 中未找到 .npy 文件")
|
||||
return
|
||||
|
||||
for f in files:
|
||||
try:
|
||||
data = np.load(f)
|
||||
if data.ndim == 2:
|
||||
if data.shape[1] == 30:
|
||||
X = data
|
||||
elif data.shape[0] == 30:
|
||||
X = data.T
|
||||
else:
|
||||
print(f"[跳过] {f.name}: shape={data.shape}, 未检测到 30 维特征")
|
||||
continue
|
||||
else:
|
||||
print(f"[跳过] {f.name}: 期望二维数组,实际 shape={data.shape}")
|
||||
continue
|
||||
|
||||
# 清理非数值行
|
||||
mask = np.isfinite(X).all(axis=1)
|
||||
if not np.all(mask):
|
||||
X = X[mask]
|
||||
print(f"[提示] {f.name}: 移除了含 NaN/Inf 的样本行")
|
||||
|
||||
n_samples = X.shape[0]
|
||||
if n_samples < 3:
|
||||
print(f"[跳过] {f.name}: 样本数过少(n={n_samples}),无法稳定降维")
|
||||
continue
|
||||
|
||||
# 确保 n_neighbors 合法
|
||||
n_neighbors = min(15, max(2, n_samples - 1))
|
||||
reducer = UMAP(
|
||||
n_components=2,
|
||||
n_neighbors=n_neighbors,
|
||||
min_dist=0.1,
|
||||
metric="euclidean",
|
||||
random_state=42,
|
||||
)
|
||||
emb = reducer.fit_transform(X)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
|
||||
ax.scatter(emb[:, 0], emb[:, 1], s=6, c="#1f77b4", alpha=0.8, edgecolors="none")
|
||||
ax.set_title(f"{f.name} • UMAP (n={len(X)}, nn={n_neighbors})", fontsize=10)
|
||||
ax.set_xlabel("UMAP-1")
|
||||
ax.set_ylabel("UMAP-2")
|
||||
ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
|
||||
fig.tight_layout()
|
||||
|
||||
out_png = f.with_suffix("").as_posix() + "_umap.png"
|
||||
fig.savefig(out_png)
|
||||
plt.close(fig)
|
||||
print(f"[完成] {f.name} -> {out_png}")
|
||||
except Exception as e:
|
||||
print(f"[错误] 处理 {f.name} 失败: {e}")
|
||||
|
||||
if __name__=="__main__":
|
||||
umap_dir_to_pngs("data")
|
||||
Reference in New Issue
Block a user