Files
solidstate-tools/GPUMD/Umap/umap_make_2.py
2025-11-19 12:23:17 +08:00

161 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
try:
from umap import UMAP
except Exception:
from umap.umap_ import UMAP
def _load_feature_matrix(f, n_features):
    """Load one ``.npy`` file as an ``(n_samples, n_features)`` array.

    Returns the cleaned array, or ``None`` (after printing the reason) when
    the file cannot be used: it is not 2-D, neither axis has length
    ``n_features``, fewer than 3 finite rows remain, or reading fails.
    """
    # Deliberately broad: this is a per-file best-effort boundary, so one
    # unreadable or oddly-typed file must not abort the whole batch.
    try:
        data = np.load(f)
        if data.ndim != 2:
            print(f"[跳过] {f.name}: 期望二维数组,实际 shape={data.shape}")
            return None
        # Accept either orientation; a square array is taken as-is.
        if data.shape[1] == n_features:
            X = data
        elif data.shape[0] == n_features:
            X = data.T
        else:
            print(f"[跳过] {f.name}: shape={data.shape}, 未检测到 {n_features} 维特征")
            return None
        # Drop every sample row that contains NaN or +/-Inf.
        finite_rows = np.isfinite(X).all(axis=1)
        if not np.all(finite_rows):
            X = X[finite_rows]
            print(f"[提示] {f.name}: 移除了含 NaN/Inf 的样本行")
        if X.shape[0] < 3:
            print(f"[跳过] {f.name}: 样本数过少(n={X.shape[0]})")
            return None
        return X
    except Exception as e:
        print(f"[错误] 读取 {f.name} 失败: {e}")
        return None


def _decorate_axes(ax, xlim, ylim):
    """Apply the axis labels, shared limits and grid used by every figure."""
    ax.set_xlabel("UMAP-1")
    ax.set_ylabel("UMAP-2")
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)


def umap_dir_shared_coords(
    dir_path: str,
    *,
    metric: str = "cosine",
    n_neighbors: int = 15,
    min_dist: float = 0.0,
    spread: float = 1.2,
    standardize: bool = False,
    context: bool = True,
    make_joint: bool = True,
    init: str = "random",
    jitter: float = 0.0,
    random_state: int = 42,
    n_features: int = 30,
) -> None:
    """Plot every ``.npy`` file of a directory in one shared 2-D UMAP embedding.

    All usable files are stacked and embedded with a single UMAP fit, so the
    resulting coordinates (and the axis limits of every figure) are directly
    comparable between files. One ``<stem>_umap_shared.png`` is written next
    to each input file, plus an optional ``umap_shared_overview.png`` in
    ``dir_path``. Progress and skip reasons are printed to stdout.

    Args:
        dir_path: Directory that is searched (non-recursively) for ``*.npy``.
        metric: Distance metric forwarded to UMAP.
        n_neighbors: UMAP neighbourhood size; values below 2 are raised to 2.
        min_dist: UMAP ``min_dist``.
        spread: UMAP ``spread``.
        standardize: If true, z-score each feature over the stacked data
            (features with zero standard deviation are only centred).
        context: If true, each per-file figure also shows all points in grey
            and a legend.
        make_joint: If true, also write the overview figure with all files.
        init: UMAP initialisation. The original author chose ``"random"``
            to avoid the warnings seen with spectral initialisation;
            ``"pca"`` was noted as an alternative.
        jitter: If positive, standard deviation of Gaussian noise added to
            the data before fitting (e.g. ``1e-6``).
        random_state: Seed for both the jitter noise and UMAP.
        n_features: Expected feature dimension. Each file must have shape
            ``(n_samples, n_features)`` or ``(n_features, n_samples)``.

    Raises:
        ValueError: If ``dir_path`` is not an existing directory.
    """
    p = Path(dir_path)
    if not p.is_dir():
        raise ValueError(f"{dir_path!r} 不是有效文件夹")

    files = sorted(p.glob("*.npy"))
    if not files:
        print(f"目录 {p} 中未找到 .npy 文件")
        return

    # Keep only the files that yield a usable matrix, preserving sorted order.
    X_list, paths = [], []
    for f in files:
        X = _load_feature_matrix(f, n_features)
        if X is not None:
            X_list.append(X)
            paths.append(f)
    if not X_list:
        print("未找到可用的数据文件")
        return

    X_all = np.vstack(X_list)
    if standardize:
        mean = X_all.mean(axis=0)
        std = X_all.std(axis=0)
        std[std == 0] = 1.0  # constant features: avoid division by zero
        X_all = (X_all - mean) / std
    if jitter and jitter > 0:
        rng = np.random.default_rng(random_state)
        X_all = X_all + rng.normal(scale=jitter, size=X_all.shape)

    reducer = UMAP(
        n_components=2,
        n_neighbors=int(max(2, n_neighbors)),
        min_dist=float(min_dist),
        spread=float(spread),
        metric=metric,
        init=init,  # see docstring: non-spectral init avoids warnings
        random_state=random_state,
    )
    Z_all = reducer.fit_transform(X_all)

    # Shared axis limits with 5% padding (1.0 when an axis is degenerate).
    x_min, x_max = float(Z_all[:, 0].min()), float(Z_all[:, 0].max())
    y_min, y_max = float(Z_all[:, 1].min()), float(Z_all[:, 1].max())
    pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1.0
    pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1.0
    xlim = (x_min - pad_x, x_max + pad_x)
    ylim = (y_min - pad_y, y_max + pad_y)

    base_colors = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",
    ]

    # Split the joint embedding back into per-file chunks once; rows of
    # Z_all are in the same order in which the files were stacked.
    groups = []
    start = 0
    for f, X in zip(paths, X_list):
        stop = start + X.shape[0]
        groups.append((f, Z_all[start:stop]))
        start = stop

    for i, (f, Zi) in enumerate(groups):
        fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
        try:
            if context:
                ax.scatter(Z_all[:, 0], Z_all[:, 1], s=5, c="#cccccc",
                           alpha=0.35, edgecolors="none", label="All")
            ax.scatter(Zi[:, 0], Zi[:, 1], s=10,
                       c=base_colors[i % len(base_colors)],
                       alpha=0.9, edgecolors="none", label=f.name)
            ax.set_title(
                f"{f.name} • UMAP(shared) (nn={n_neighbors}, min={min_dist}, metric={metric}, init={init})",
                fontsize=9,
            )
            _decorate_axes(ax, xlim, ylim)
            if context:
                ax.legend(loc="best", fontsize=8, frameon=False)
            fig.tight_layout()
            out_png = f.with_name(f"{f.stem}_umap_shared.png").as_posix()
            fig.savefig(out_png)
        finally:
            # Always release the figure, even if layout or saving fails.
            plt.close(fig)
        print(f"[完成] {f.name} -> {out_png}")

    if make_joint:
        fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
        try:
            for i, (f, Zi) in enumerate(groups):
                ax.scatter(Zi[:, 0], Zi[:, 1], s=8,
                           c=base_colors[i % len(base_colors)],
                           alpha=0.85, edgecolors="none", label=f.name)
            ax.set_title(
                f"UMAP(shared) overview (metric={metric}, nn={n_neighbors}, min={min_dist}, init={init})",
                fontsize=10,
            )
            _decorate_axes(ax, xlim, ylim)
            ax.legend(loc="best", fontsize=8, frameon=False, ncol=1)
            fig.tight_layout()
            out_png = p / "umap_shared_overview.png"
            fig.savefig(out_png.as_posix())
        finally:
            plt.close(fig)
        print(f"[完成] 总览 -> {out_png}")
# Script entry point: process the relative "data" directory with defaults.
if __name__ == "__main__":
    umap_dir_shared_coords("data")