Some minor changes
76
GPUMD/Umap/umap_make.py
Normal file
@@ -0,0 +1,76 @@
from pathlib import Path

import numpy as np
import matplotlib
matplotlib.use("Agg")  # save figures only; never open a window
import matplotlib.pyplot as plt
from umap import UMAP


def umap_dir_to_pngs(dir_path: str) -> None:
    """
    Run UMAP (30-D -> 2-D) on every .npy file in a directory and save a scatter plot.
    - Each input .npy is expected to have shape (n_samples, 30) or (30, n_samples).
    - Plots are written next to the inputs as <original name>_umap.png.
    """
    p = Path(dir_path)
    if not p.is_dir():
        raise ValueError(f"{dir_path!r} is not a valid directory")

    files = sorted(p.glob("*.npy"))
    if not files:
        print(f"No .npy files found in directory {p}")
        return

    for f in files:
        try:
            data = np.load(f)
            if data.ndim == 2:
                if data.shape[1] == 30:
                    X = data
                elif data.shape[0] == 30:
                    X = data.T
                else:
                    print(f"[skip] {f.name}: shape={data.shape}, no 30-D feature axis found")
                    continue
            else:
                print(f"[skip] {f.name}: expected a 2-D array, got shape={data.shape}")
                continue

            # Drop rows that contain non-finite values
            mask = np.isfinite(X).all(axis=1)
            if not np.all(mask):
                X = X[mask]
                print(f"[note] {f.name}: removed rows containing NaN/Inf")

            n_samples = X.shape[0]
            if n_samples < 3:
                print(f"[skip] {f.name}: too few samples (n={n_samples}) for a stable embedding")
                continue

            # Keep n_neighbors within a valid range
            n_neighbors = min(15, max(2, n_samples - 1))
            reducer = UMAP(
                n_components=2,
                n_neighbors=n_neighbors,
                min_dist=0.1,
                metric="euclidean",
                random_state=42,
            )
            emb = reducer.fit_transform(X)

            fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
            ax.scatter(emb[:, 0], emb[:, 1], s=6, c="#1f77b4", alpha=0.8, edgecolors="none")
            ax.set_title(f"{f.name} • UMAP (n={len(X)}, nn={n_neighbors})", fontsize=10)
            ax.set_xlabel("UMAP-1")
            ax.set_ylabel("UMAP-2")
            ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
            fig.tight_layout()

            out_png = f.with_suffix("").as_posix() + "_umap.png"
            fig.savefig(out_png)
            plt.close(fig)
            print(f"[done] {f.name} -> {out_png}")
        except Exception as e:
            print(f"[error] failed to process {f.name}: {e}")


if __name__ == "__main__":
    umap_dir_to_pngs("data")
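A quick way to smoke-test umap_make.py is to drop a random (n_samples, 30) array into the target folder and call the function. A minimal sketch, assuming the script is importable from the working directory; the folder and file names here are illustrative only:

import numpy as np
from pathlib import Path
from umap_make import umap_dir_to_pngs

Path("data").mkdir(exist_ok=True)
np.save("data/demo.npy", np.random.rand(200, 30))  # stand-in 30-D descriptors
umap_dir_to_pngs("data")                           # writes data/demo_umap.png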
161
GPUMD/Umap/umap_make_2.py
Normal file
@@ -0,0 +1,161 @@
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

try:
    from umap import UMAP
except Exception:
    from umap.umap_ import UMAP


def umap_dir_shared_coords(
    dir_path: str,
    *,
    metric: str = "cosine",
    n_neighbors: int = 15,
    min_dist: float = 0.0,
    spread: float = 1.2,
    standardize: bool = False,
    context: bool = True,
    make_joint: bool = True,
    init: str = "random",  # key point: disable spectral init to avoid the warning; "pca" also works
    jitter: float = 0.0,   # optional: add tiny noise before fitting, e.g. 1e-6
    random_state: int = 42
) -> None:
    """
    Embed every .npy file in a directory into one shared UMAP coordinate system.
    - Each .npy has shape (n_samples, 30) or (30, n_samples)
    - Axis limits are unified; each file gets *_umap_shared.png, plus an optional overview plot
    """
    p = Path(dir_path)
    if not p.is_dir():
        raise ValueError(f"{dir_path!r} is not a valid directory")

    files = sorted(p.glob("*.npy"))
    if not files:
        print(f"No .npy files found in directory {p}")
        return

    X_list, paths, counts = [], [], []
    for f in files:
        try:
            data = np.load(f)
            if data.ndim != 2:
                print(f"[skip] {f.name}: expected a 2-D array, got shape={data.shape}")
                continue

            if data.shape[1] == 30:
                X = data
            elif data.shape[0] == 30:
                X = data.T
            else:
                print(f"[skip] {f.name}: shape={data.shape}, no 30-D feature axis found")
                continue

            mask = np.isfinite(X).all(axis=1)
            if not np.all(mask):
                X = X[mask]
                print(f"[note] {f.name}: removed rows containing NaN/Inf")

            if X.shape[0] < 3:
                print(f"[skip] {f.name}: too few samples (n={X.shape[0]})")
                continue

            X_list.append(X)
            paths.append(f)
            counts.append(X.shape[0])
        except Exception as e:
            print(f"[error] failed to read {f.name}: {e}")

    if not X_list:
        print("No usable data files found")
        return

    X_all = np.vstack(X_list)

    if standardize:
        mean = X_all.mean(axis=0)
        std = X_all.std(axis=0)
        std[std == 0] = 1.0
        X_all = (X_all - mean) / std

    if jitter and jitter > 0:
        rng = np.random.default_rng(random_state)
        X_all = X_all + rng.normal(scale=jitter, size=X_all.shape)

    reducer = UMAP(
        n_components=2,
        n_neighbors=int(max(2, n_neighbors)),
        min_dist=float(min_dist),
        spread=float(spread),
        metric=metric,
        init=init,  # key change: avoids the spectral-initialization warning
        random_state=random_state,
    )
    Z_all = reducer.fit_transform(X_all)

    x_min, x_max = float(Z_all[:, 0].min()), float(Z_all[:, 0].max())
    y_min, y_max = float(Z_all[:, 1].min()), float(Z_all[:, 1].max())
    pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1.0
    pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1.0

    base_colors = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
    ]

    start = 0
    for i, (f, n) in enumerate(zip(paths, counts)):
        Zi = Z_all[start:start + n]
        start += n

        fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
        if context:
            ax.scatter(Z_all[:, 0], Z_all[:, 1], s=5, c="#cccccc",
                       alpha=0.35, edgecolors="none", label="All")
        ax.scatter(Zi[:, 0], Zi[:, 1], s=10,
                   c=base_colors[i % len(base_colors)],
                   alpha=0.9, edgecolors="none", label=f.name)

        ax.set_title(
            f"{f.name} • UMAP(shared) (nn={n_neighbors}, min={min_dist}, metric={metric}, init={init})",
            fontsize=9
        )
        ax.set_xlabel("UMAP-1")
        ax.set_ylabel("UMAP-2")
        ax.set_xlim(x_min - pad_x, x_max + pad_x)
        ax.set_ylim(y_min - pad_y, y_max + pad_y)
        ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
        if context:
            ax.legend(loc="best", fontsize=8, frameon=False)
        fig.tight_layout()

        out_png = f.with_suffix("").as_posix() + "_umap_shared.png"
        fig.savefig(out_png)
        plt.close(fig)
        print(f"[done] {f.name} -> {out_png}")

    if make_joint:
        start = 0
        fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
        for i, (f, n) in enumerate(zip(paths, counts)):
            Zi = Z_all[start:start + n]; start += n
            ax.scatter(Zi[:, 0], Zi[:, 1], s=8,
                       c=base_colors[i % len(base_colors)],
                       alpha=0.85, edgecolors="none", label=f.name)
        ax.set_title(f"UMAP(shared) overview (metric={metric}, nn={n_neighbors}, min={min_dist}, init={init})",
                     fontsize=10)
        ax.set_xlabel("UMAP-1"); ax.set_ylabel("UMAP-2")
        ax.set_xlim(x_min - pad_x, x_max + pad_x)
        ax.set_ylim(y_min - pad_y, y_max + pad_y)
        ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
        ax.legend(loc="best", fontsize=8, frameon=False, ncol=1)
        fig.tight_layout()
        out_png = Path(dir_path) / "umap_shared_overview.png"
        fig.savefig(out_png.as_posix())
        plt.close(fig)
        print(f"[done] overview -> {out_png}")

if __name__ == "__main__":
    umap_dir_shared_coords("data")
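Since every file is embedded in a single fit, the rows of Z_all line up with np.vstack(X_list), and the manual start/start+n bookkeeping above is equivalent to one np.split call. A small sketch with made-up counts:

import numpy as np
counts = [120, 80, 50]                   # per-file sample counts, as collected above
Z_all = np.random.rand(sum(counts), 2)   # stand-in for the fitted embedding
blocks = np.split(Z_all, np.cumsum(counts)[:-1])
assert [len(b) for b in blocks] == counts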
88
GPUMD/data_POSCAR/origin/pnma.vasp
Normal file
@@ -0,0 +1,88 @@
Li Y Cl
1.0000000000000000
12.1082364219999992 -0.0000000000000000 0.0000000000000000
0.0000420925000000 12.6964871139000000 0.0000000000000000
0.0000111360000000 0.0000097283000000 11.1520040839999997
Li Y Cl
24 8 48
Cartesian
3.0170113299999999 11.0208475999999997 6.5429541999999996
9.0710813300000002 11.0208076100000003 6.5429413099999998
3.0372732299999998 1.6755553700000001 0.9669378900000000
9.0914532300000008 1.6755853700000001 0.9668849900000001
5.9960454600000004 8.0228419300000002 4.6273539599999998
12.0502254600000001 8.0228319300000006 4.6273410699999999
6.0439837100000000 8.0239104000000001 0.9669839400000000
12.0980237099999997 8.0238703999999998 0.9669410500000000
2.9687930800000002 11.0219791300000001 10.2032142199999996
9.0228930799999993 11.0219691300000004 10.2032413300000009
3.0851749800000001 1.6745967300000000 4.6273578600000000
9.1393549800000002 1.6745967399999999 4.6273049700000000
0.0581325900000000 4.6736939399999997 10.2033481199999994
6.1122525899999998 4.6736539400000003 10.2033652299999993
0.0102008400000000 4.6725925699999999 6.5430081500000004
6.0643308400000002 4.6725225699999999 6.5430152499999998
6.0582017800000001 11.3105425600000000 6.4882826299999996
0.0040517800000000 11.3105725600000007 6.4882655299999996
3.0311532099999998 7.7341853599999997 0.9123054900000001
9.0851632099999993 7.7341953600000002 0.9123126000000000
3.0230812999999999 4.9623175699999997 6.4882665900000003
9.0772212999999997 4.9622875700000000 6.4882637000000001
12.1043127199999994 1.3859403699999999 0.9123036700000000
6.0501727199999999 1.3859103699999999 0.9123265600000000
0.0501691200000000 4.7065010400000000 2.8276881199999999
6.1042791200000002 4.7064910400000004 2.8277152200000000
6.0040072100000001 7.9899634099999997 8.4036839699999994
12.0581272100000003 7.9900134100000004 8.4037010799999994
3.0772167300000000 1.6417582399999999 8.4037178800000003
9.1312967300000007 1.6417482400000001 8.4036949900000000
2.9769896100000000 11.0547561999999999 2.8276842100000001
9.0310896100000004 11.0547362000000007 2.8277013100000001
4.5156189400000004 10.0925444199999994 4.7485038499999996
10.5696489400000004 10.0924044199999994 4.7485109599999999
1.5387092400000000 2.6040715400000001 10.3245482299999995
7.5927392300000003 2.6040115400000001 10.3245253399999992
1.3626281400000000 9.1096203100000004 6.4718857500000002
7.4167081399999999 9.1095703100000005 6.4718328500000002
1.4883697199999999 8.9523616199999996 10.3245457500000004
7.5425797200000000 8.9522516299999992 10.3245128600000005
4.3896168400000004 9.9351631099999995 0.8958039700000000
10.4437368399999997 9.9351431100000003 0.8958210800000000
1.6644776500000000 2.7613398100000000 6.4718681100000000
7.7185976500000004 2.7613198099999998 6.4718352200000000
4.6916063499999998 3.5868826100000000 0.8958463400000000
10.7457363499999996 3.5868826100000000 0.8958134500000000
4.5657084499999998 3.7442043400000000 4.7485363400000002
10.6197684500000005 3.7442543399999999 4.7484934399999998
1.4271435299999999 5.7840809399999999 8.3327970300000000
7.4812435300000004 5.7840109399999999 8.3328141400000000
4.6270127299999997 6.9124334500000000 2.7567950600000000
10.6811427299999995 6.9124334500000000 2.7567721600000001
4.4542422500000001 0.5641837300000000 2.7567976500000002
10.5083422500000001 0.5641637400000000 2.7567847500000000
1.5999540200000000 12.1323306500000001 8.3327844399999993
7.6540440199999997 12.1323306500000001 8.3327815399999992
1.4488319300000001 5.8551694799999998 4.7388569900000004
7.5029319299999999 5.8551594800000002 4.7388941000000004
4.6052662299999998 6.8413464700000004 10.3148350900000008
10.6594162299999997 6.8413164799999997 10.3148222000000001
4.4984539600000000 0.5241951300000000 6.4733476400000001
10.5525539599999991 0.5241751300000000 6.4733547500000004
4.4759357399999997 0.4930566900000000 10.3148076700000004
10.5300357400000006 0.4930866900000000 10.3148047700000003
1.4714200500000001 5.8240679200000001 0.8973369900000000
7.5254900500000002 5.8240779299999996 0.8973441000000000
4.5827044399999997 6.8724549899999996 6.4733550900000001
10.6368744399999997 6.8724349900000004 6.4733521999999999
1.5557505300000001 12.1723877900000002 0.8973444400000000
7.6098205300000004 12.1722777900000008 0.8973515500000000
1.5782524200000001 12.2033792699999992 4.7388244200000003
7.6324124199999996 12.2033992700000002 4.7388615300000003
1.5507855200000000 2.5414585299999999 2.7575782499999999
7.6049855199999996 2.5414785300000000 2.7574953600000001
4.5034907500000001 10.1550858599999998 8.3335738300000006
10.5574507499999992 10.1550558599999992 8.3335109400000000
4.5777702700000003 3.8068257399999998 8.3335863099999994
10.6319002699999992 3.8067957400000001 8.3335634100000000
1.4764060000000001 8.8896886500000001 2.7575657800000002
7.5304859999999998 8.8896686500000008 2.7575728900000001
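This POSCAR-format file (24 Li, 8 Y, 48 Cl in Cartesian coordinates) can be sanity-checked with pymatgen, which other scripts in this commit already use. A minimal sketch, with the path assumed relative to the repo root:

from pymatgen.core import Structure

s = Structure.from_file("GPUMD/data_POSCAR/origin/pnma.vasp")
print(s.composition.reduced_formula, s.num_sites)  # Li3YCl6 80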
149
GPUMD/raw2xyz.py
Normal file
@@ -0,0 +1,149 @@
import os
import numpy as np


def convert_raw_to_gpumd_xyz(input_folder: str, output_filename: str = "gpumd_nep_training_data.xyz"):
    """
    Convert DeePMD-kit style .raw training data into the extended XYZ format needed for GPUMD NEP training.
    Adjusted to what GPUMD expects: a Properties field in the comment line,
    and per-atom forces appended to each coordinate line.

    Args:
        input_folder (str): Path to the folder containing the .raw files (e.g. './set.000/').
        output_filename (str): Name of the GPUMD extended XYZ file to write.

    Raises:
        FileNotFoundError: If a required .raw file is missing.
        ValueError: If the data does not match the expected format.
    """
    required_files = [
        'box.raw', 'coord.raw', 'energy.raw', 'force.raw',
        'type.raw', 'type_map.raw', 'virial.raw'
    ]
    # Check that all required files exist
    for filename in required_files:
        filepath = os.path.join(input_folder, filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError(
                f"Missing required file: {filepath}. Please ensure all .raw files are in the specified folder.")
    print(f"Loading raw from folder: {input_folder}")

    # --- 1. Read the data ---
    try:
        # Read type_map.raw
        with open(os.path.join(input_folder, 'type_map.raw'), 'r') as f:
            type_map_list = [line.strip() for line in f if line.strip()]  # drop empty lines

        # Load the first line of coord.raw to determine num_atoms
        first_coord_line = np.loadtxt(os.path.join(input_folder, 'coord.raw'), max_rows=1)
        if first_coord_line.ndim == 0:  # only a single number
            num_atoms = 1
        else:
            num_atoms = first_coord_line.shape[0] // 3
        if num_atoms == 0:
            raise ValueError(f"Could not determine num_atoms from coord.raw. It seems empty or malformed.")

        # With the correct num_atoms, reload type.raw to get the atom-type list
        with open(os.path.join(input_folder, 'type.raw'), 'r') as f:
            all_types_lines = f.readlines()
        if not all_types_lines:
            raise ValueError(f"{os.path.join(input_folder, 'type.raw')} is empty or malformed.")

        # Assume every configuration shares the same atom-type sequence; only the first one is needed
        first_type_config = np.array([int(x) for x in all_types_lines[0].strip().split()])
        if len(first_type_config) != num_atoms:
            # Try the other common DeePMD type.raw layout: one long column of per-atom types.
            # If the number of lines equals the atom count, assume one atom type per line.
            if len(all_types_lines) == num_atoms:
                atom_types_numeric = np.array([int(line.strip()) for line in all_types_lines])
            else:
                raise ValueError(
                    f"Mismatch between num_atoms ({num_atoms}) derived from coord.raw and type.raw format. "
                    f"First line of type.raw has {len(first_type_config)} types, total lines {len(all_types_lines)}. "
                    f"Please check type.raw format and adjust script.")
        else:
            atom_types_numeric = first_type_config  # normal case: the first line lists all atom types

        atom_symbols = [type_map_list[t] for t in atom_types_numeric]

        # Read the remaining data
        boxes = np.loadtxt(os.path.join(input_folder, 'box.raw')).reshape(-1, 3, 3)
        coords_flat = np.loadtxt(os.path.join(input_folder, 'coord.raw'))
        energies = np.loadtxt(os.path.join(input_folder, 'energy.raw'))
        forces_flat = np.loadtxt(os.path.join(input_folder, 'force.raw'))
        virials_flat = np.loadtxt(os.path.join(input_folder, 'virial.raw'))  # may hold 9 components

    except Exception as e:
        raise ValueError(f"Error reading .raw files. Please check their format. Details: {e}")

    # Validate array dimensions
    num_configs = len(energies)
    expected_coord_cols = num_atoms * 3
    expected_virial_cols = 9  # DeePMD usually writes 9 components

    if coords_flat.shape[1] != expected_coord_cols:
        raise ValueError(
            f"coord.raw has {coords_flat.shape[1]} columns, but expected {expected_coord_cols} (N_atoms * 3).")
    if boxes.shape[0] != num_configs:
        raise ValueError(f"box.raw has {boxes.shape[0]} configurations, but expected {num_configs}. Data mismatch.")
    if forces_flat.shape[1] != expected_coord_cols:
        raise ValueError(
            f"force.raw has {forces_flat.shape[1]} columns, but expected {expected_coord_cols} (N_atoms * 3). Check file format.")
    if virials_flat.shape[0] != num_configs or virials_flat.shape[1] != expected_virial_cols:
        raise ValueError(
            f"virial.raw has shape {virials_flat.shape}, but expected ({num_configs}, {expected_virial_cols}). Check file format.")

    coords = coords_flat.reshape(num_configs, num_atoms, 3)
    forces = forces_flat.reshape(num_configs, num_atoms, 3)
    virials_matrix = virials_flat.reshape(num_configs, 3, 3)

    print(f"Loaded {num_configs} configurations with {num_atoms} atoms each.")

    # --- 2. Write the GPUMD NEP extended XYZ file ---
    # Make sure the output directory exists
    output_dir = os.path.dirname(output_filename)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    output_filepath = output_filename  # use the given output_filename directly as the final path

    with open(output_filepath, 'w') as f:
        for i in range(num_configs):
            # First line: number of atoms
            f.write(f"{num_atoms}\n")

            # Second line: metadata
            box_matrix_flat = boxes[i].flatten()
            box_str = " ".join(f"{x:.10f}" for x in box_matrix_flat)
            energy_str = f"{energies[i]:.10f}"

            virial_tensor = virials_matrix[i]
            # --- Key change: write out all nine virial components ---
            # Flatten the 3x3 matrix to obtain the nine components
            virial_gpumd_components = virial_tensor.flatten()
            virial_str = " ".join(f"{x:.10f}" for x in virial_gpumd_components)

            # Build the GPUMD-compatible comment line
            config_type_str = f"Config_type=dpgen_iter{i:03d}"  # example: iteration tag, customizable
            weight_str = "Weight=1.0"
            properties_str = "Properties=species:S:1:pos:R:3:forces:R:3"  # key change

            f.write(
                f'{config_type_str} {weight_str} Lattice="{box_str}" Energy={energy_str} Virial="{virial_str}" pbc="T T T" {properties_str}\n'
            )

            # Remaining lines: atom symbol, coordinates, and forces
            for j in range(num_atoms):
                x, y, z = coords[i, j]
                fx, fy, fz = forces[i, j]
                f.write(f"{atom_symbols[j]} {x:.10f} {y:.10f} {z:.10f} {fx:.10f} {fy:.10f} {fz:.10f}\n")

    print(f"Successfully converted {num_configs} configurations to {output_filepath}")
    print(f"Output file saved at: {output_filepath}")


# --- How to use this function ---
if __name__ == "__main__":
    # Example usage:
    input_folder_path = 'data/dpmd_data/lyc/training_data/p3m1_data/raw'
    output_file_path = 'data/dpmd_data/lyc/training_data/p3m1_data/p3m1_train.xyz'

    convert_raw_to_gpumd_xyz(input_folder=input_folder_path, output_filename=output_file_path)
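For reference, each frame emitted by convert_raw_to_gpumd_xyz has the layout below; the angle-bracket entries are placeholders, not real values, and the field names match the f.write calls above:

<num_atoms>
Config_type=dpgen_iter000 Weight=1.0 Lattice="ax ay az bx by bz cx cy cz" Energy=<E> Virial="v11 v12 v13 v21 v22 v23 v31 v32 v33" pbc="T T T" Properties=species:S:1:pos:R:3:forces:R:3
<symbol> <x> <y> <z> <fx> <fy> <fz>   (one line per atom)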
180
GPUMD/swap_li.py
Normal file
@@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: ascii -*-
"""
Randomly swap one Li-Y pair in a VASP5 POSCAR and write N new files.
- Keeps coordinate mode (Direct/Cartesian), Selective Dynamics flags, and Velocities.
- Requires VASP5+ POSCAR (with element symbols line).
"""

import random
from pathlib import Path


def _is_ints(tokens):
    try:
        _ = [int(t) for t in tokens]
        return True
    except ValueError:
        return False

def _find_species_index(species, target):
    t = target.lower()
    for i, s in enumerate(species):
        if s.lower() == t:
            return i
    raise ValueError("Element '%s' not found in species line: %s" % (target, " ".join(species)))

def parse_poscar(lines):
    if len(lines) < 8:
        raise ValueError("POSCAR too short")

    comment = lines[0].rstrip("\n")
    scale = lines[1].rstrip("\n")
    lv = [lines[2].rstrip("\n"), lines[3].rstrip("\n"), lines[4].rstrip("\n")]

    i = 5
    tokens = lines[i].split()
    if _is_ints(tokens):
        raise ValueError("VASP4 format (no element symbols line) is not supported.")
    species = tokens
    i += 1

    counts_line = lines[i].rstrip("\n")
    counts = [int(x) for x in counts_line.split()]
    i += 1

    selective = False
    sel_line = None
    if i < len(lines) and lines[i].strip().lower().startswith("s"):
        selective = True
        sel_line = lines[i].rstrip("\n")
        i += 1

    coord_line = lines[i].rstrip("\n")
    i += 1

    natoms = sum(counts)
    pos_start = i
    pos_end = i + natoms
    if pos_end > len(lines):
        raise ValueError("Atom count exceeds file length.")
    pos_lines = [lines[j].rstrip("\n") for j in range(pos_start, pos_end)]

    # Optional Velocities section
    k = pos_end
    while k < len(lines) and lines[k].strip() == "":
        k += 1

    vel_header = None
    vel_lines = None
    vel_end = k
    if k < len(lines) and lines[k].strip().lower().startswith("veloc"):
        vel_header = lines[k].rstrip("\n")
        vel_start = k + 1
        vel_end = vel_start + natoms
        if vel_end > len(lines):
            raise ValueError("Velocities section length inconsistent with atom count.")
        vel_lines = [lines[j].rstrip("\n") for j in range(vel_start, vel_end)]

    tail_lines = [lines[j].rstrip("\n") for j in range(vel_end, len(lines))] if vel_end < len(lines) else []

    # Species index ranges (by order in species list)
    starts = []
    acc = 0
    for c in counts:
        starts.append(acc)
        acc += c
    species_ranges = []
    for idx, sp in enumerate(species):
        s, e = starts[idx], starts[idx] + counts[idx]
        species_ranges.append((sp, s, e))

    return {
        "comment": comment,
        "scale": scale,
        "lv": lv,
        "species": species,
        "counts": counts,
        "counts_line": counts_line,
        "selective": selective,
        "sel_line": sel_line,
        "coord_line": coord_line,
        "natoms": natoms,
        "pos_lines": pos_lines,
        "vel_header": vel_header,
        "vel_lines": vel_lines,
        "tail_lines": tail_lines,
        "species_ranges": species_ranges,
    }

def build_poscar(data, pos_lines, vel_lines=None):
    out = []
    out.append(data["comment"])
    out.append(data["scale"])
    out.extend(data["lv"])
    out.append(" ".join(data["species"]))
    out.append(data["counts_line"])
    if data["selective"]:
        out.append(data["sel_line"] if data["sel_line"] is not None else "Selective dynamics")
    out.append(data["coord_line"])
    out.extend(pos_lines)
    if data["vel_header"] is not None and vel_lines is not None:
        out.append(data["vel_header"])
        out.extend(vel_lines)
    if data["tail_lines"]:
        out.extend(data["tail_lines"])
    return "\n".join(out) + "\n"

def _swap_once(data, rng, li_label="Li", y_label="Y"):
    si_li = _find_species_index(data["species"], li_label)
    si_y = _find_species_index(data["species"], y_label)
    _, li_start, li_end = data["species_ranges"][si_li]
    _, y_start, y_end = data["species_ranges"][si_y]

    li_pick = rng.randrange(li_start, li_end)
    y_pick = rng.randrange(y_start, y_end)

    new_pos = list(data["pos_lines"])
    new_pos[li_pick], new_pos[y_pick] = new_pos[y_pick], new_pos[li_pick]

    new_vel = None
    if data["vel_lines"] is not None:
        new_vel = list(data["vel_lines"])
        new_vel[li_pick], new_vel[y_pick] = new_vel[y_pick], new_vel[li_pick]

    return new_pos, new_vel, (li_pick, y_pick)

def swap(n, input_file, output_dir):
    """
    Generate n POSCAR files, each with one random Li-Y swap.

    Returns: list of Path to written files.
    """
    input_path = Path(input_file)
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    lines = input_path.read_text().splitlines()
    data = parse_poscar(lines)

    rng = random.Random()
    base = input_path.name

    out_paths = []
    for k in range(1, n + 1):
        new_pos, new_vel, picked = _swap_once(data, rng)
        txt = build_poscar(data, new_pos, new_vel)
        out_path = out_dir / f"swap_{k}_{base}"
        out_path.write_text(txt)
        out_paths.append(out_path)
        print(f"Wrote {out_path} (swapped Li idx {picked[0]} <-> Y idx {picked[1]})")
    return out_paths

# --------- Editable defaults for direct run ---------
INPUT_FILE = "data_POSCAR/origin/p3m1.vasp"  # path to input POSCAR
OUTPUT_DIR = "data_POSCAR/p3m1"  # output directory
N = 5  # number of files to generate
# ----------------------------------------------------
if __name__ == "__main__":
    # Direct-run entry: edit INPUT_FILE/OUTPUT_DIR/N above to change behavior.
    swap(n=N, input_file=INPUT_FILE, output_dir=OUTPUT_DIR)
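swap() draws from an unseeded random.Random(), so repeated runs produce different swaps. For reproducible output, one hedged variant is to drive the helpers with a seeded generator; the seed value and input path below are illustrative:

import random
from pathlib import Path
from swap_li import parse_poscar, _swap_once

data = parse_poscar(Path("data_POSCAR/origin/p3m1.vasp").read_text().splitlines())
rng = random.Random(42)  # fixed seed -> the same Li-Y swap on every run
new_pos, new_vel, picked = _swap_once(data, rng)
print("swapped indices:", picked)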
140
GPUMD/t-SNE/t-SNE.py
Normal file
@@ -0,0 +1,140 @@
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

def tsne_dir_shared_coords(
    dir_path: str,
    *,
    metric: str = "euclidean",        # "cosine" is worth trying; use "euclidean" to keep scale differences
    perplexity: float = 50.0,         # 30-50 is a good range for 30k-50k samples
    n_iter: int = 1000,               # note: not currently forwarded to TSNE below
    early_exaggeration: float = 12.0,
    learning_rate="auto",
    standardize: bool = False,
    pca_dim: int | None = None,       # reduce to pca_dim (e.g. 20) with PCA before t-SNE for speed
    context: bool = True,
    make_joint: bool = True,
    init: str = "pca",
    random_state: int = 42
) -> None:
    p = Path(dir_path)
    if not p.is_dir():
        raise ValueError(f"{dir_path!r} is not a valid directory")

    files = sorted(p.glob("*.npy"))
    if not files:
        print(f"No .npy files found in directory {p}")
        return

    X_list, paths, counts = [], [], []
    for f in files:
        try:
            data = np.load(f)
            if data.ndim != 2:
                print(f"[skip] {f.name}: expected a 2-D array, got shape={data.shape}")
                continue

            # Normalize the orientation to (n_samples, 30)
            if data.shape[1] == 30:
                X = data
            elif data.shape[0] == 30:
                X = data.T
            else:
                print(f"[skip] {f.name}: shape={data.shape}, no 30-D feature axis found")
                continue

            mask = np.isfinite(X).all(axis=1)
            if not np.all(mask):
                X = X[mask]
                print(f"[note] {f.name}: removed rows containing NaN/Inf")

            if X.shape[0] < 3:
                print(f"[skip] {f.name}: too few samples (n={X.shape[0]})")
                continue

            X_list.append(X)
            paths.append(f)
            counts.append(X.shape[0])
        except Exception as e:
            print(f"[error] failed to read {f.name}: {e}")

    if not X_list:
        print("No usable data files found")
        return

    X_all = np.vstack(X_list)

    if standardize:
        mean = X_all.mean(axis=0)
        std = X_all.std(axis=0); std[std == 0] = 1.0
        X_all = (X_all - mean) / std

    if pca_dim is not None and pca_dim > 2:
        X_all = PCA(n_components=pca_dim, random_state=random_state).fit_transform(X_all)

    tsne = TSNE(
        n_components=2,
        metric=metric,
        perplexity=float(perplexity),
        early_exaggeration=float(early_exaggeration),
        learning_rate=learning_rate,
        init=init,
        random_state=random_state,
        method="barnes_hut",  # suited to large sample counts
        angle=0.5,
        verbose=0,
    )
    Z_all = tsne.fit_transform(X_all)

    # Shared axis limits
    x_min, x_max = float(Z_all[:, 0].min()), float(Z_all[:, 0].max())
    y_min, y_max = float(Z_all[:, 1].min()), float(Z_all[:, 1].max())
    pad_x = 0.05 * (x_max - x_min) if x_max > x_min else 1.0
    pad_y = 0.05 * (y_max - y_min) if y_max > y_min else 1.0

    colors = [
        "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
        "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
    ]

    # One plot per file
    start = 0
    for i, (f, n) in enumerate(zip(paths, counts)):
        Zi = Z_all[start:start + n]; start += n
        fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
        if context:
            ax.scatter(Z_all[:, 0], Z_all[:, 1], s=5, c="#cccccc", alpha=0.35, edgecolors="none", label="All")
        ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.9, edgecolors="none", label=f.name)
        ax.set_title(f"{f.name} • t-SNE(shared) (perp={perplexity}, metric={metric})", fontsize=9)
        ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
        ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
        ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
        if context: ax.legend(loc="best", fontsize=8, frameon=False)
        fig.tight_layout()
        out_png = f.with_suffix("").as_posix() + "_tsne_shared.png"
        fig.savefig(out_png); plt.close(fig)
        print(f"[done] {f.name} -> {out_png}")

    # Overview plot
    if make_joint:
        start = 0
        fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
        for i, (f, n) in enumerate(zip(paths, counts)):
            Zi = Z_all[start:start + n]; start += n
            ax.scatter(Zi[:, 0], Zi[:, 1], s=8, c=colors[i % len(colors)], alpha=0.85, edgecolors="none", label=f.name)
        ax.set_title(f"t-SNE(shared) overview (perp={perplexity}, metric={metric})", fontsize=10)
        ax.set_xlabel("t-SNE-1"); ax.set_ylabel("t-SNE-2")
        ax.set_xlim(x_min - pad_x, x_max + pad_x); ax.set_ylim(y_min - pad_y, y_max + pad_y)
        ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.5)
        ax.legend(loc="best", fontsize=8, frameon=False)
        fig.tight_layout()
        out_png = Path(dir_path) / "tsne_shared_overview.png"
        fig.savefig(out_png.as_posix()); plt.close(fig)
        print(f"[done] overview -> {out_png}")


if __name__ == "__main__":
    tsne_dir_shared_coords("data")
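The pca_dim option follows the common practice of compressing the features with PCA before the much slower t-SNE stage; on these 30-D descriptors it mainly buys speed at large sample counts. An example invocation of the function above (argument values illustrative):

tsne_dir_shared_coords("data", pca_dim=20, perplexity=30.0, metric="euclidean")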
3
Li_Conductivity/data/conductivity_results.csv
Normal file
@@ -0,0 +1,3 @@
Temperature(K),Conductivity(S/m)
300,0.0148
425,6.3173
@@ -1,7 +1,7 @@
from utils.MSD import *

if __name__ == '__main__':
    # file_path_li = 'data/msd_li.dat'
    # file_path_li = 'raw/msd_li.dat'
    # final_msd = plot_and_get_final_msd(file_path_li, ion_name='Li⁺')

    num_li_ions = 144  # !! Be sure to replace this example value with the real one for your system !!
@@ -234,7 +234,7 @@ if __name__ == "__main__":

        # Check that the required files exist
        msd_path = os.path.join(folder_path, "msd_li.dat")
        data_path = os.path.join(folder_path, "LYC.data")
        data_path = os.path.join(folder_path, "LYC.raw")
        if not os.path.exists(msd_path):
            print(f"  skipped: missing {msd_path}")
            continue
@@ -133,14 +133,134 @@ def copy_cif_with_O_or_S_robust(source_dir: str, target_dir: str, dry_run: bool
        print(f"Dry run finished: a real run would copy {copied_count} files.")
    else:
        print(f"Successfully copied {copied_count} files to the target folder.")


def copy_cif_without_Br_or_Cl(source_dir: str, target_dir: str, dry_run: bool = False):
    """
    Select the CIF files in the source folder whose contents contain no 'Br' or 'Cl' elements,
    and copy them to the target folder.
    (Robust version: correctly parses the element-symbol column of the CIF.)

    :param source_dir: Source folder path containing the CIF files.
    :param target_dir: Target folder path for the selected files.
    :param dry_run: If True, only print the files that would be copied without actually copying.
    """
    # 1. Path handling and validation (same as the original function)
    source_path = Path(source_dir)
    target_path = Path(target_dir)
    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a folder.")
        return
    if not dry_run and not target_path.exists():
        target_path.mkdir(parents=True, exist_ok=True)
        print(f"Target folder '{target_dir}' created.")

    print(f"Source folder: {source_path}")
    print(f"Target folder: {target_path}")
    if dry_run:
        print("\n--- *** Dry run mode *** ---")
        print("--- No files will actually be copied ---")

    # 2. Scan and filter
    print("\nScanning source folder, excluding CIF files that contain Br or Cl...")
    copied_count = 0
    checked_files = 0
    error_files = 0
    excluded_files = 0

    # Iterate over all .cif files
    for file_path in source_path.glob('*.cif'):
        if not file_path.is_file():
            continue

        checked_files += 1
        contains_br_or_cl = False  # flag: does this file contain Br or Cl?

        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                lines = f.readlines()

            # Step A: locate the column holding the element symbols
            element_col_idx = find_element_column_index(lines)

            if element_col_idx != -1:
                # Prefer the structural data for an exact check
                for line in lines:
                    line_stripped = line.strip()
                    # Skip empty lines, comments, and definition lines
                    if not line_stripped or line_stripped.startswith(('#', '_', 'loop_')):
                        continue

                    parts = line_stripped.split()
                    # Make sure the line has enough columns
                    if len(parts) > element_col_idx:
                        atom_symbol = parts[element_col_idx].strip()
                        # Check whether the element is Br or Cl (also covers labels like Br-)
                        if atom_symbol.upper().startswith('BR') or atom_symbol.upper().startswith('CL'):
                            contains_br_or_cl = True
                            break  # one hit is enough; stop checking this file

            # Step B: if nothing was found above, fall back to the chemical formula
            if not contains_br_or_cl:
                # Use any() for an efficient check that stops at the first match
                is_in_formula = any(
                    line.strip().startswith(('_chemical_formula_sum', '_chemical_formula_structural')) and
                    (' Br' in line or ' Cl' in line)
                    for line in lines
                )
                if is_in_formula:
                    contains_br_or_cl = True

            # Step C: copy or skip based on the result
            if contains_br_or_cl:
                # Contains Br or Cl: report and skip
                print(f"Excluded: '{file_path.name}' (Br or Cl element detected)")
                excluded_files += 1
            else:
                # No Br or Cl: copy the file
                target_file_path = target_path / file_path.name
                print(f"Match found: '{file_path.name}' (no Br or Cl)")
                if not dry_run:
                    shutil.copy2(file_path, target_file_path)
                copied_count += 1

        except Exception as e:
            error_files += 1
            print(f"!! Error while processing file '{file_path.name}': {e}")

    # 3. Final report (similar to the original function, with an exclusion count added)
    print("\n--- Summary ---")
    print(f"Checked {checked_files} .cif files.")
    print(f"Excluded {excluded_files} files containing Br or Cl.")
    if error_files > 0:
        print(f"{error_files} files raised errors during processing.")
    if dry_run:
        print(f"Dry run finished: a real run would copy {copied_count} files.")
    else:
        print(f"Successfully copied {copied_count} files to the target folder.")


if __name__ == '__main__':
    # !! Important: change the paths below to the actual paths on your machine
    source_folder = "D:/download/2025-10/data_all/input/input"
    target_folder = "D:/download/2025-10/data_all/output"
    # source_folder = "D:/download/2025-10/data_all/input/input"
    # target_folder = "D:/download/2025-10/data_all/output"
    #
    # # --- First run: dry-run mode ---
    # print("================ First run: dry-run mode ================")
    # copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=True)
    #
    # print("\n\n=======================================================")
    # input("Review the dry-run output above. If it looks right, press Enter to perform the actual copy...")
    # print("=======================================================")
    #
    # # --- Second run: actually copy ---
    # print("\n================ Second run: actual copy mode ================")
    # copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=False)

    source_folder = "D:/download/2025-10/data_OS/input"
    target_folder = "D:/download/2025-10/data_withoutBrCl/input"

    # --- First run: dry-run mode ---
    print("================ First run: dry-run mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=True)
    copy_cif_without_Br_or_Cl(source_folder, target_folder, dry_run=True)

    print("\n\n=======================================================")
    input("Review the dry-run output above. If it looks right, press Enter to perform the actual copy...")
@@ -148,4 +268,4 @@ if __name__ == '__main__':

    # --- Second run: actually copy ---
    print("\n================ Second run: actual copy mode ================")
    copy_cif_with_O_or_S_robust(source_folder, target_folder, dry_run=False)
    copy_cif_without_Br_or_Cl(source_folder, target_folder, dry_run=False)
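The element test in copy_cif_without_Br_or_Cl matches the start of the symbol label rather than a raw substring, so decorated CIF symbols such as 'Br-' or 'Cl1-' are still caught while 'C' and 'Ca' are not. A tiny standalone check:

for label in ("Cl", "Cl1-", "Br-", "C", "Ca"):
    print(label, "->", label.upper().startswith(("BR", "CL")))  # True for the first three only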
111
contrast learning/split.py
Normal file
@@ -0,0 +1,111 @@
import os
import shutil
import random
from pathlib import Path


def split_dataset(source_dir: str, output_dir: str, test_ratio: float = 0.2):
    """
    Split the files in the source folder into 'train' and 'test' subdirectories
    under the output folder, according to the given ratio.

    Args:
        source_dir (str): Folder containing all the data files.
        output_dir (str): Folder in which the 'train' and 'test' directories are created.
        test_ratio (float, optional): Fraction of files assigned to the test set. Defaults to 0.2.
    """
    print("--- Starting dataset split ---")

    # 1. Path handling and validation
    source_path = Path(source_dir)
    output_path = Path(output_dir)

    if not source_path.is_dir():
        print(f"Error: source folder '{source_dir}' does not exist or is not a directory.")
        return

    # 2. Create the output folders (train and test)
    train_dir = output_path / 'train'
    test_dir = output_path / 'test'

    try:
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)
        print(f"Output directories ready: \n  train -> {train_dir}\n  test  -> {test_dir}")
    except OSError as e:
        print(f"Error: failed to create output directories: {e}")
        return

    # 3. Collect all files and shuffle them
    all_files = [f for f in source_path.iterdir() if f.is_file()]

    if not all_files:
        print(f"Warning: source folder '{source_dir}' contains no files to split.")
        return

    random.shuffle(all_files)
    total_files = len(all_files)
    print(f"Found {total_files} files in the source folder.")

    # 4. Compute the split sizes
    num_test = int(total_files * test_ratio)
    num_train = total_files - num_test

    print(f"Split plan -> train: {num_train} files | test: {num_test} files")

    # 5. Split the file list
    test_files = all_files[:num_test]
    train_files = all_files[num_test:]

    # 6. Helper for copying/moving files
    def copy_files(files_to_copy, destination_dir):
        copied_count = 0
        for file_path in files_to_copy:
            try:
                # Note: copy is used here because it is safer.
                # To move files and empty the source folder, change shutil.copy to shutil.move
                shutil.copy(file_path, destination_dir)
                copied_count += 1
            except Exception as e:
                print(f"Error while handling file '{file_path.name}': {e}")
        return copied_count

    # 7. Copy the files into their folders
    print(f"\nCopying files into the 'train' folder...")
    copied_train = copy_files(train_files, train_dir)
    print(f"Copied {copied_train} files into the training set.")

    print(f"\nCopying files into the 'test' folder...")
    copied_test = copy_files(test_files, test_dir)
    print(f"Copied {copied_test} files into the test set.")

    print("\n--- Dataset split complete! ---")


# --- How to use this function ---
if __name__ == '__main__':
    # --- Configure your folder paths here ---

    # Folder holding the original dataset
    # e.g.: 'C:/Users/YourUser/Desktop/my_dataset' (Windows)
    # or: '/home/user/project/raw/all_images' (Linux/macOS)
    SOURCE_DATA_DIR = 'D:/download/2025-10/data_OS/input/S'

    # Where the 'train' and 'test' folders should be created
    # e.g.: 'C:/Users/YourUser/Desktop/split_output' (Windows)
    # or: '/home/user/project/raw/processed' (Linux/macOS)
    # Use '.' to create them in the directory containing this script
    OUTPUT_DIR = 'D:/download/2025-10/data_OS/train/S'

    # --- Configuration done; call the function ---

    # If the example path does not exist, create it and fill it with dummy files for the demo
    if not os.path.exists(SOURCE_DATA_DIR):
        print(f"Demo directory '{SOURCE_DATA_DIR}' does not exist; creating it with 100 example files...")
        os.makedirs(SOURCE_DATA_DIR)
        for i in range(100):
            with open(os.path.join(SOURCE_DATA_DIR, f'file_{i + 1:03d}.txt'), 'w') as f:
                f.write(f'This is file {i + 1}.')
        print("Example files created.")

    # Run the split
    split_dataset(SOURCE_DATA_DIR, OUTPUT_DIR, test_ratio=0.2)
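split_dataset shuffles with the unseeded module RNG, so each run produces a different partition. If a reproducible split is needed, one small variant (seed value arbitrary) is to seed before calling:

import random
random.seed(0)  # fix the shuffle order
split_dataset(SOURCE_DATA_DIR, OUTPUT_DIR, test_ratio=0.2)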
@@ -82,7 +82,7 @@ def process_cif_folder(cif_folder_path: str, output_csv_path: str):
if __name__ == "__main__":
    # ----- Configuration -----
    # Change this path to the folder that actually holds your CIF files
    CIF_DIRECTORY = "data/0921"
    CIF_DIRECTORY = "raw/0921"

    # Output CSV file name
    OUTPUT_CSV = "corner_sharing_results.csv"
@@ -349,7 +349,7 @@ def check_only_corner_sharing(sharing_results: Dict[str, Dict[str, int]]) -> int
    else:
        return 0  # no sharing relationships at all; also return 0

# structure = Structure.from_file("../data/0921/wjy_001.cif")
# structure = Structure.from_file("../raw/0921/wjy_001.cif")
# a = CS_catulate(structure,notice=True)
# b = CS_count(structure,a)
# print(f"{a}\n{b}")
@@ -205,6 +205,6 @@ def export_envs(envlist, sp='Li', envtype='both', fname=None):
            f.write("Site index " + str(index) + ": " + str(i) + '\n')


# struct = Structure.from_file("../data/0921/wjy_475.cif")
# struct = Structure.from_file("../raw/0921/wjy_475.cif")
# site_info = extract_sites(struct, envtype="both")
# export_envs(site_info, sp="Li", envtype="both")
BIN
data_get/new_v1/data/input/最新数据库核查过gsj.xlsx
Normal file
Binary file not shown.
127
data_get/new_v1/data_get.py
Normal file
@@ -0,0 +1,127 @@
import pandas as pd
import os
import re


def extract_cif_from_xlsx(
    xlsx_path: str,
    output_dir: str,
    naming_mode: str = 'formula',
    name_col: int = 0,
    cif_col: int = 1,
    prefix: str = 'wjy'
):
    """
    Extract CIF data from an XLSX file and save each entry as a separate .cif file.

    Args:
        xlsx_path (str): Path to the input XLSX file.
        output_dir (str): Folder in which the .cif files are written.
        naming_mode (str, optional): Naming scheme for the CIF files.
            Either 'formula' (use the name in the first column) or
            'auto' (use the prefix plus an auto-incrementing number).
            Defaults to 'formula'.
        name_col (int, optional): Zero-based index of the column holding the file name. Defaults to 0.
        cif_col (int, optional): Zero-based index of the column holding the CIF content. Defaults to 1.
        prefix (str, optional): File-name prefix used in 'auto' naming mode. Defaults to 'wjy'.

    Raises:
        FileNotFoundError: If xlsx_path does not exist.
        ValueError: If naming_mode is invalid.
        Exception: Any other error raised during processing.
    """
    # --- 1. Validate arguments and prepare ---
    if not os.path.exists(xlsx_path):
        raise FileNotFoundError(f"Error: input file not found -> {xlsx_path}")

    if naming_mode not in ['formula', 'auto']:
        raise ValueError(f"Error: 'naming_mode' must be 'formula' or 'auto', got '{naming_mode}'")

    # Create the output directory if it does not exist
    os.makedirs(output_dir, exist_ok=True)
    print(f"CIF files will be saved to: {output_dir}")

    try:
        # --- 2. Read the XLSX file ---
        # header=None treats the first row as data rather than column names
        df = pd.read_excel(xlsx_path, header=None)

        # Skip the original header row ('formula', 'cif') if present
        if str(df.iloc[0, name_col]).strip().lower() == 'formula' and str(df.iloc[0, cif_col]).strip().lower() == 'cif':
            df = df.iloc[1:]
            print("Detected and skipped the header row.")

        # --- 3. Iterate over the rows and generate the files ---
        success_count = 0
        for index, row in df.iterrows():
            # Get the file name and CIF content
            formula_name = str(row[name_col])
            cif_content = str(row[cif_col])

            # Skip rows with missing content
            if pd.isna(row[name_col]) or pd.isna(row[cif_col]) or not cif_content.strip():
                print(f"Warning: row {index + 2} is incomplete; skipped.")
                continue

            # --- 4. Determine the file name from the naming mode ---
            if naming_mode == 'formula':
                # Sanitize the file name: replace characters unsuitable for file names
                # e.g. (PO4)3 becomes _PO4_3, and / becomes _
                safe_filename = re.sub(r'[\\/*?:"<>|()]', '_', formula_name)
                filename = f"{safe_filename}.cif"
            else:  # naming_mode == 'auto'
                # Use format to keep the numbering uniform, e.g. 001, 002
                filename = f"{prefix}_{success_count + 1:03d}.cif"

            # Build the full output path
            output_path = os.path.join(output_dir, filename)

            # --- 5. Write the CIF file ---
            try:
                with open(output_path, 'w', encoding='utf-8') as f:
                    f.write(cif_content)
                success_count += 1
            except IOError as e:
                print(f"Error: could not write file {output_path}. Reason: {e}")

        print(f"\nDone! Successfully extracted and generated {success_count} CIF files.")

    except Exception as e:
        print(f"Error while processing the XLSX file: {e}")


# --- Usage example ---
if __name__ == '__main__':
    # Assume the XLSX file is named 'materials.xlsx' and sits next to this script
    source_xlsx_file = 'input/cif_dataset.xlsx'

    # Create an example file if it does not exist
    if not os.path.exists(source_xlsx_file):
        print(f"Example file '{source_xlsx_file}' not found; creating one...")
        example_data = {
            'formula': ['Li3Al0.3Ti1.7(PO4)3', 'Li6.5La3Zr1.75W0.25O12', 'Invalid/Name*Test'],
            'cif': ['# CIF Data for Li3Al0.3...\n_atom_site_type_symbol\n Li\n Al\n Ti\n P\n O',
                    '# CIF Data for Li6.5La3...\n_symmetry_space_group_name_H-M \'I a -3 d\'',
                    '# CIF Data for Invalid Name Test']
        }
        pd.DataFrame(example_data).to_excel(source_xlsx_file, index=False, header=True)
        print("Example file created.")

    # --- Example 1: name files after the 'formula' column ---
    # print("\n--- Example 1: 'formula' naming mode ---")
    # output_folder_1 = 'cif_by_formula'
    # extract_cif_from_xlsx(
    #     xlsx_path=source_xlsx_file,
    #     output_dir=output_folder_1,
    #     naming_mode='formula'
    # )

    # --- Example 2: auto-numbered 'wjy_###' names ---
    print("\n--- Example 2: 'auto' naming mode ---")
    output_folder_2 = 'cif_by_auto'
    extract_cif_from_xlsx(
        xlsx_path=source_xlsx_file,
        output_dir=output_folder_2,
        naming_mode='auto',
        prefix='wjy'
    )
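The 'formula' naming mode depends on the re.sub cleanup above; a quick illustration of what it does to a typical composition string:

import re
name = "Li3Al0.3Ti1.7(PO4)3"
print(re.sub(r'[\\/*?:"<>|()]', '_', name))  # -> Li3Al0.3Ti1.7_PO4_3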
79
dpgen/create_supercell_poscar.py
Normal file
@@ -0,0 +1,79 @@
from pymatgen.core import Structure
from pymatgen.io.vasp import Poscar

def create_supercell_poscar(cif_path, supercell_matrix, output_filename="POSCAR_supercell"):
    """
    Read a crystal structure from a CIF file, build a supercell from the given matrix,
    and write a VASP POSCAR file.

    Args:
        cif_path (str): Path to the input CIF file.
        supercell_matrix (list or tuple): 3x3 supercell matrix.
            - For a simple diagonal expansion (e.g. 2x2x4), use: [[2, 0, 0], [0, 2, 0], [0, 0, 4]]
            - For a non-diagonal expansion (e.g. a_s=3a, b_s=2a+4b, c_s=6c), use: [[3, 0, 0], [2, 4, 0], [0, 0, 6]]
        output_filename (str): Name of the POSCAR file to write. Defaults to "POSCAR_supercell".

    Returns:
        bool: True if the file was written successfully, otherwise False.
    """
    try:
        # 1. Load the structure from the CIF file
        # from_file reads the file directly
        # primitive=False keeps the cell defined in the CIF rather than its primitive cell
        original_structure = Structure.from_file(cif_path, primitive=False)

        print("--- Original cell ---")
        print(f"  Atoms: {original_structure.num_sites}")
        print(f"  Formula: {original_structure.composition.reduced_formula}")
        print(f"  Lattice parameters (a, b, c, α, β, γ):")
        lat = original_structure.lattice
        print(f"  {lat.a:.4f}, {lat.b:.4f}, {lat.c:.4f}, {lat.alpha:.2f}, {lat.beta:.2f}, {lat.gamma:.2f}")

        # 2. Build the supercell
        # Note: pymatgen maps the atomic coordinates automatically
        supercell_structure = original_structure * supercell_matrix

        print("\n--- Supercell ---")
        print(f"  Supercell matrix: {supercell_matrix}")
        print(f"  New atom count: {supercell_structure.num_sites}")
        print(f"  New formula: {supercell_structure.composition.reduced_formula}")
        print(f"  New lattice parameters (a, b, c, α, β, γ):")
        super_lat = supercell_structure.lattice
        print(f"  {super_lat.a:.4f}, {super_lat.b:.4f}, {super_lat.c:.4f}, {super_lat.alpha:.2f}, {super_lat.beta:.2f}, {super_lat.gamma:.2f}")

        # 3. Create a Poscar object and write the file
        # The comment argument sets the first line of the POSCAR file
        poscar = Poscar(supercell_structure, comment=f"Supercell from {cif_path}")
        poscar.write_file(output_filename)

        print(f"\nSuccess! Supercell structure written to: {output_filename}")
        return True

    except Exception as e:
        print(f"An error occurred: {e}")
        return False

# --- Usage examples ---

# Assume the CIF file is named "origin.cif" and sits next to this script.
# If you are reproducing Wang Shuo's work, a different expansion scheme may have been used.
# For example, the large supercell for MD simulations is 2x2x4 [source_id: 3],
# while the small supercell for DP-GEN exploration is 1x1x2 [source_id: 3].

# Example 1: 1x1x2 small supercell for DP-GEN exploration (60 atoms)
print("="*40)
print("Generating the 1x1x2 supercell (for DP-GEN exploration)...")
matrix_1x1x2 = [[1, 0, 0], [0, 1, 0], [0, 0, 2]]
create_supercell_poscar("data/P3ma/origin.cif", matrix_1x1x2, "data/P3ma/output/POSCAR_1x1x2_60atoms")

# Example 2: 2x2x4 large supercell for LAMMPS MD simulations (480 atoms)
# print("\n" + "="*40)
# print("Generating the 2x2x4 supercell (for LAMMPS MD simulations)...")
# matrix_2x2x4 = [[2, 0, 0], [0, 2, 0], [0, 0, 4]]
# create_supercell_poscar("origin.cif", matrix_2x2x4, "POSCAR_2x2x4_480atoms")
#
# # Example 3: large supercell used by Geng et al. (2160 atoms) [source_id: 1]
# # The supercell matrix is a_s = 3a_0, b_s = 2a_0 + 4b_0, c_s = 6c_0 [source_id: 1]
# print("\n" + "="*40)
# print("Generating the non-diagonal supercell (Geng et al.)...")
# matrix_geng = [[3, 0, 0], [2, 4, 0], [0, 0, 6]]
# create_supercell_poscar("origin.cif", matrix_geng, "POSCAR_Geng_2160atoms")
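The atom counts quoted in the examples follow from the supercell matrix: the new cell contains |det(M)| copies of the original, so a 30-atom Li3YCl6 cell gives the 60-, 480-, and 2160-atom figures above. A quick check:

import numpy as np
for m, label in [([[1, 0, 0], [0, 1, 0], [0, 0, 2]], "1x1x2"),
                 ([[2, 0, 0], [0, 2, 0], [0, 0, 4]], "2x2x4"),
                 ([[3, 0, 0], [2, 4, 0], [0, 0, 6]], "Geng")]:
    k = int(round(abs(np.linalg.det(np.array(m, dtype=float)))))
    print(label, k, 30 * k)  # 1x1x2 -> 60, 2x2x4 -> 480, Geng -> 2160 atoms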
@@ -147,5 +147,5 @@ def make_pnma_poscar_from_cif(cif_path: str,
    print(f"Wrote {out_poscar}; total atoms = {len(s)}; composition = {comp}")

if __name__ == "__main__":
    # make_model3_poscar_from_cif("data/P3ma/model3.cif","data/P3ma/supercell_model4.poscar")
    make_pnma_poscar_from_cif("data/Pnma/origin.cif","data/Pnma/supercell_pnma.poscar",seed=42)
    # make_model3_poscar_from_cif("raw/P3ma/model3.cif","raw/P3ma/supercell_model4.poscar")
    make_pnma_poscar_from_cif("data/Pnma/origin.cif","raw/Pnma/supercell_pnma.poscar",seed=42)
@@ -0,0 +1,240 @@
import pymatgen.core as mg
from pymatgen.io.cif import CifParser
from pymatgen.transformations.standard_transformations import SupercellTransformation
import random
import os


def create_ordered_structure_from_disordered(disordered_structure):
    """
    Manually convert a disordered structure with partial occupancies into an ordered one,
    following the approach used in plus.py.
    """
    s = disordered_structure.copy()

    # Identify the partial occupancies that need handling.
    # Per model3.cif: Y2 (z≈0.488, occ=0.75), Y3 (z≈-0.065, occ=0.25), Li2 (z≈0.5, occ=0.5) [model3.cif]
    y2_indices, y3_indices, li2_indices = [], [], []

    for i, site in enumerate(s.sites):
        # Use the z coordinate to identify the specific partial occupancies
        z = site.frac_coords[2]
        if site.species_string == "Y":
            if abs(z - 0.488) < 0.05:
                y2_indices.append(i)
            elif abs(z - (-0.065)) < 0.05 or abs(z - (1 - 0.065)) < 0.05:
                y3_indices.append(i)
        elif site.species_string == "Li":
            if abs(z - 0.5) < 0.05:
                li2_indices.append(i)

    # Randomly pick the atoms to keep according to the occupancies
    def choose_keep(indices, keep_fraction):
        num_to_keep = int(round(len(indices) * keep_fraction))
        return set(random.sample(indices, num_to_keep))

    keep_y2 = choose_keep(y2_indices, 0.75)
    keep_y3 = choose_keep(y3_indices, 0.25)
    keep_li2 = choose_keep(li2_indices, 0.50)

    # Collect the indices of all atoms to remove
    to_remove_indices = [i for i in y2_indices if i not in keep_y2]
    to_remove_indices.extend([i for i in y3_indices if i not in keep_y3])
    to_remove_indices.extend([i for i in li2_indices if i not in keep_li2])

    # Remove from the back to the front to avoid shifting indices
    s.remove_sites(sorted(to_remove_indices, reverse=True))

    # --- Key cleanup step ---
    # Final pass to make sure every site is ordered
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            # Convert the Composition to a dict and keep the majority species [plus.py]
            species_dict = site.species.as_dict()
            main_specie = max(species_dict.items(), key=lambda item: item[1])[0]
            s.replace(i, main_specie)

    return s


def create_supercells_from_file(cif_path, output_path="."):
    """
    From the given CIF file, build three supercells of different sizes and defect contents
    and save them as POSCAR files.
    """
    if not os.path.exists(cif_path):
        print(f"Error: file '{cif_path}' does not exist.")
        return

    print(f"Reading structure from {cif_path}...")
    parser = CifParser(cif_path)
    disordered_structure = parser.parse_structures(primitive=False)[0]

    structure = create_ordered_structure_from_disordered(disordered_structure)
    print(f"Converted the disordered structure into an ordered unit cell with {len(structure)} atoms.")

    os.makedirs(output_path, exist_ok=True)

    # Task 1: 60-atom supercell (no defects)
    print("\n--- Generating the 60-atom defect-free supercell (1x1x2) ---")
    tf_60 = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 2]])
    sc_60_no_defect = tf_60.apply_transformation(structure)
    print(f"Total atoms: {len(sc_60_no_defect)}, formula: {sc_60_no_defect.composition.reduced_formula}")
    sc_60_no_defect.to(fmt="poscar", filename=os.path.join(output_path, "POSCAR_60_no_defect"))
    print(f"Saved file: {os.path.join(output_path, 'POSCAR_60_no_defect')}")

    # Task 2: 60-atom supercell with one antisite defect pair
    print("\n--- Generating the 60-atom supercell with one antisite defect pair ---")
    sc_60_defect = sc_60_no_defect.copy()
    li_indices = [i for i, site in enumerate(sc_60_defect.sites) if site.species_string == 'Li']
    y_indices = [i for i, site in enumerate(sc_60_defect.sites) if site.species_string == 'Y']

    if li_indices and y_indices:
        li_swap_idx, y_swap_idx = random.choice(li_indices), random.choice(y_indices)
        sc_60_defect.replace(li_swap_idx, "Y")
        sc_60_defect.replace(y_swap_idx, "Li")
        print(f"Introduced one antisite pair. Concentration: {2 / (len(li_indices) + len(y_indices)) * 100:.2f}%")
        sc_60_defect.to(fmt="poscar", filename=os.path.join(output_path, "POSCAR_60_antisite_defect"))
        print(f"Saved file: {os.path.join(output_path, 'POSCAR_60_antisite_defect')}")

    # Task 3: 90-atom supercell with one antisite defect pair
    print("\n--- Generating the 90-atom supercell with one antisite defect pair ---")
    tf_90 = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 3]])
    sc_90_no_defect = tf_90.apply_transformation(structure)
    sc_90_defect = sc_90_no_defect.copy()
    li_indices = [i for i, site in enumerate(sc_90_defect.sites) if site.species_string == 'Li']
    y_indices = [i for i, site in enumerate(sc_90_defect.sites) if site.species_string == 'Y']

    if li_indices and y_indices:
        li_swap_idx, y_swap_idx = random.choice(li_indices), random.choice(y_indices)
        sc_90_defect.replace(li_swap_idx, "Y")
        sc_90_defect.replace(y_swap_idx, "Li")
        print(f"Total atoms: {len(sc_90_defect)}, concentration: {2 / (len(li_indices) + len(y_indices)) * 100:.2f}%")
        sc_90_defect.to(fmt="poscar", filename=os.path.join(output_path, "POSCAR_90_antisite_defect"))
        print(f"Saved file: {os.path.join(output_path, 'POSCAR_90_antisite_defect')}")


def create_ordered_p3ma_structure(disordered_structure):
    """
    Manually convert the disordered P3ma phase (partial occupancies on Y2, Y3, Li2)
    into an ordered structure.
    """
    s = disordered_structure.copy()

    # Per model3.cif, identify Y2 (z≈0.488, occ=0.75), Y3 (z≈-0.065, occ=0.25), Li2 (z≈0.5, occ=0.5) [model3.cif]
    y2_indices, y3_indices, li2_indices = [], [], []

    for i, site in enumerate(s.sites):
        z = site.frac_coords[2]
        if site.species_string == "Y":
            if abs(z - 0.488) < 0.05:
                y2_indices.append(i)
            elif abs(z - (-0.065)) < 0.05 or abs(z - (1 - 0.065)) < 0.05:
                y3_indices.append(i)
        elif site.species_string == "Li":
            if abs(z - 0.5) < 0.05:
                li2_indices.append(i)

    # Randomly pick the atoms to keep according to the occupancies
    def choose_keep(indices, keep_fraction):
        num_to_keep = int(round(len(indices) * keep_fraction))
        return set(random.sample(indices, num_to_keep))

    keep_y2 = choose_keep(y2_indices, 0.75)
    keep_y3 = choose_keep(y3_indices, 0.25)
    keep_li2 = choose_keep(li2_indices, 0.50)

    # Collect the indices of all atoms to remove
    to_remove_indices = [i for i in y2_indices if i not in keep_y2]
    to_remove_indices.extend([i for i in y3_indices if i not in keep_y3])
    to_remove_indices.extend([i for i in li2_indices if i not in keep_li2])

    s.remove_sites(sorted(to_remove_indices, reverse=True))

    # Final pass to make sure every site is ordered
    for i, site in enumerate(s.sites):
        if not site.is_ordered:
            species_dict = site.species.as_dict()
            main_specie = max(species_dict.items(), key=lambda item: item[1])[0]
            s.replace(i, main_specie)

    return s


def create_multiple_p3ma_supercells(cif_path, num_configs=5, output_path="."):
    """
    Read the P3ma-phase CIF and, for supercells of different sizes, generate several
    configurations with antisite defects at different positions.
    """
    if not os.path.exists(cif_path):
        print(f"Error: file '{cif_path}' does not exist.")
        return

    print(f"Reading P3ma structure from {cif_path}...")
    parser = CifParser(cif_path)
    disordered_structure = parser.parse_structures(primitive=False)[0]

    structure = create_ordered_p3ma_structure(disordered_structure)
    print(f"Converted the disordered P3ma structure into an ordered unit cell with {len(structure)} atoms.")

    os.makedirs(output_path, exist_ok=True)

    target_sizes = [60, 90]
    for size in target_sizes:
        print(f"\n--- Generating {num_configs} distinct configurations for the ~{size}-atom version ---")

        # 1. Build the base supercell
        if size == 60:
            tf = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 2]])
            filename_suffix = "60_approx"
        else:  # size == 90
            tf = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 3]])
            filename_suffix = "90_approx"

        base_supercell = tf.apply_transformation(structure)
        print(f"Base supercell generated; actual atom count: {len(base_supercell)}")

        li_indices = [i for i, site in enumerate(base_supercell.sites) if site.species_string == 'Li']
        y_indices = [i for i, site in enumerate(base_supercell.sites) if site.species_string == 'Y']

        if not li_indices or not y_indices:
            print("Error: not enough Li or Y atoms in the supercell to introduce defects.")
            continue

        # 2. Loop to generate several unique defect configurations
        used_pairs = set()
        for i in range(num_configs):
            defect_supercell = base_supercell.copy()

            # Make sure the randomly chosen swap pair has not been used before.
            # Cap the number of attempts to avoid an endless loop for very small cells
            max_tries = len(li_indices) * len(y_indices)
            for _ in range(max_tries):
                li_swap_idx = random.choice(li_indices)
                y_swap_idx = random.choice(y_indices)
                pair = tuple(sorted((li_swap_idx, y_swap_idx)))
                if pair not in used_pairs:
                    used_pairs.add(pair)
                    break
            else:
                print(f"  Warning: no more unique swap pairs available; stopped at configuration {i}.")
                break

            # Introduce the defect
            defect_supercell.replace(li_swap_idx, "Y")
            defect_supercell.replace(y_swap_idx, "Li")

            print(f"  Configuration {i}: introduced one antisite pair (Li at index {li_swap_idx} <-> Y at index {y_swap_idx}).")

            # 3. Save as a numbered POSCAR file
            poscar_filename = f"POSCAR_P3ma_{filename_suffix}_antisite_defect_{i}"
            poscar_path = os.path.join(output_path, poscar_filename)
            defect_supercell.to(fmt="poscar", filename=poscar_path)
            print(f"  Saved file: {poscar_path}")

if __name__ == '__main__':
    # --- Usage ---
    # 1. Save your CIF file, e.g. as "Li3YCl6.cif"
    # 2. Pass the file name to the function
    cif_file_path = "data/P3ma/model3.cif"  # change to your CIF file name
    output_directory = "raw/P3ma/output"  # an output directory can be specified
|
||||
|
||||
# create_supercells_from_file(cif_file_path, output_directory)
|
||||
create_multiple_p3ma_supercells(cif_file_path,output_path=output_directory)
|
||||
print("所有任务完成!")
|
||||
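
A quick sanity check on the antisite concentration printed above, as a minimal sketch. It assumes the Li3YCl6 stoichiometry suggested by the usage notes (3 Li + 1 Y + 6 Cl = 10 atoms per formula unit), so the numbers are illustrative rather than read from the CIF:

# Hedged check of the concentration formula 2 / (n_Li + n_Y) * 100 used above,
# assuming Li3YCl6 stoichiometry.
def antisite_concentration(n_atoms_total: int) -> float:
    n_fu = n_atoms_total // 10        # formula units in the cell
    n_li, n_y = 3 * n_fu, 1 * n_fu    # cation counts
    return 2 / (n_li + n_y) * 100     # one swapped pair counts as two defects

assert abs(antisite_concentration(60) - 8.33) < 0.01   # 18 Li + 6 Y  -> 2/24
assert abs(antisite_concentration(90) - 5.56) < 0.01   # 27 Li + 9 Y  -> 2/36
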
115
dpgen/supercell_make_pnma.py
Normal file
@@ -0,0 +1,115 @@
import pymatgen.core as mg
from pymatgen.io.cif import CifParser
from pymatgen.transformations.standard_transformations import SupercellTransformation
import random
import os


def create_ordered_pnma_structure(disordered_structure):
    """
    Manually convert the disordered Pnma-phase structure (mainly partial Li
    occupancy) into an ordered structure.
    """
    s = disordered_structure.copy()

    # Per origin.cif, the Li site occupancy is 0.75 [5]
    partial_li_indices = [i for i, site in enumerate(s.sites) if "Li" in site.species and not site.is_ordered]

    # Randomly choose which Li atoms to keep according to the 0.75 occupancy
    num_to_keep = int(round(len(partial_li_indices) * 0.75))
    keep_indices = set(random.sample(partial_li_indices, num_to_keep))

    # Collect the indices of atoms to be removed
    to_remove_indices = [i for i in partial_li_indices if i not in keep_indices]

    s.remove_sites(sorted(to_remove_indices, reverse=True))

    # Rebuild a brand-new, fully ordered structure to avoid any side effects
    ordered_species = []
    ordered_coords = []
    for site in s.sites:
        # Keep only the dominant element of each site
        main_specie = site.species.elements[0]
        ordered_species.append(main_specie)
        ordered_coords.append(site.frac_coords)

    final_structure = mg.Structure(s.lattice, ordered_species, ordered_coords)

    return final_structure


def create_multiple_pnma_supercells(cif_path, num_configs=3, output_path="."):
    """
    Read the Pnma-phase CIF and, for each target supercell size, generate several
    configurations with different antisite-defect positions.
    """
    if not os.path.exists(cif_path):
        print(f"Error: file '{cif_path}' does not exist.")
        return

    print(f"Reading Pnma structure from {cif_path} ...")
    parser = CifParser(cif_path)
    disordered_structure = parser.parse_structures(primitive=False)[0]

    structure = create_ordered_pnma_structure(disordered_structure)
    print(f"Converted the disordered Pnma structure into an ordered unit cell with {len(structure)} atoms.")

    os.makedirs(output_path, exist_ok=True)

    target_sizes = [60, 90]
    for size in target_sizes:
        print(f"\n--- Generating {num_configs} distinct configurations for the ~{size}-atom version ---")

        # 1. Build the reference supercell
        if size == 60:
            tf = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 2]])
            filename_suffix = "60_approx"
        else:  # size == 90
            tf = SupercellTransformation([[1, 0, 0], [0, 1, 0], [0, 0, 3]])
            filename_suffix = "90_approx"

        base_supercell = tf.apply_transformation(structure)
        print(f"Built the reference supercell; actual atom count: {len(base_supercell)}")

        li_indices = [i for i, site in enumerate(base_supercell.sites) if site.species_string == 'Li']
        y_indices = [i for i, site in enumerate(base_supercell.sites) if site.species_string == 'Y']

        if not li_indices or not y_indices:
            print("Error: not enough Li or Y atoms in the supercell to introduce a defect.")
            continue

        # 2. Loop to generate several unique defect configurations
        used_pairs = set()
        for i in range(num_configs):
            defect_supercell = base_supercell.copy()

            # Make sure the randomly chosen swap pair has never been used before.
            # As in the P3ma script, cap the attempts so small cells cannot
            # cause an endless loop.
            max_tries = len(li_indices) * len(y_indices)
            for _ in range(max_tries):
                li_swap_idx = random.choice(li_indices)
                y_swap_idx = random.choice(y_indices)
                # Use the sorted tuple as the key so (a, b) and (b, a) count as the same pair
                pair = tuple(sorted((li_swap_idx, y_swap_idx)))
                if pair not in used_pairs:
                    used_pairs.add(pair)
                    break
            else:
                print(f"  Warning: no more unique swap pairs available; stopped at configuration {i}.")
                break

            # Introduce the defect
            defect_supercell.replace(li_swap_idx, "Y")
            defect_supercell.replace(y_swap_idx, "Li")

            print(f"  Configuration {i}: introduced one antisite pair (Li at index {li_swap_idx} <-> Y at index {y_swap_idx}).")

            # 3. Save as a numbered POSCAR file
            poscar_filename = f"POSCAR_Pnma_{filename_suffix}_antisite_defect_{i}"
            poscar_path = os.path.join(output_path, poscar_filename)
            defect_supercell.to(fmt="poscar", filename=poscar_path)
            print(f"  Saved file: {poscar_path}")


if __name__ == '__main__':
    # Save your Pnma-phase CIF file and adjust this path accordingly.
    # Here we use the reference file name 'origin.cif' you provided.
    cif_file_path = "data/Pnma/origin.cif"
    output_directory = "raw/Pnma/output"

    create_multiple_pnma_supercells(cif_file_path, num_configs=3, output_path=output_directory)

    print("\nPnma phase processing finished!")
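
For intuition about the 0.75-occupancy selection in create_ordered_pnma_structure, here is a minimal standalone sketch; the site count of 8 is hypothetical and only for illustration:

import random

random.seed(0)  # seeding makes the selection reproducible between runs
partial_li_indices = list(range(8))  # hypothetical: 8 partially occupied Li sites
num_to_keep = int(round(len(partial_li_indices) * 0.75))
keep = set(random.sample(partial_li_indices, num_to_keep))
print(num_to_keep, sorted(keep))     # 6 of the 8 sites survive; 2 become vacancies
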
197
dpgen/supercell_make_wangshuo.py
Normal file
@@ -0,0 +1,197 @@
import random
from collections import defaultdict, Counter

from pymatgen.core import Structure, Element
from pymatgen.io.lammps.data import LammpsData

# ASE fallback (optional)
try:
    from ase.io import write as ase_write
    from ase import Atoms as ASEAtoms
    HAS_ASE = True
except Exception:
    HAS_ASE = False

# ===== User parameters =====
cif_filename = "data/P3ma/origin.cif"  # your input CIF (with partial occupancies) [5]
supercell_matrix = [[1, 0, 0], [0, 1, 0], [0, 0, 2]]  # 1x1x2 supercell (replication count = 2) [3]
out_lammps = "lyc_P3m1_1x1x2_from_cif_ordered.vasp"  # output LAMMPS data file
seed = 2025
strict_count = True  # strict quota: each parent site receives an integer number of atoms/vacancies in the supercell according to its occupancy
# ====================

random.seed(seed)


def species_to_probs(site):
    """
    Convert a site's species occupancies into a list [(species or None (vacancy), prob)]
    with the probabilities normalized to sum to 1.
    If the total occupancy is < 1, append a vacancy (None).
    Oxidation states are stripped; only the element is kept.
    """
    sp_items = site.species.items()
    total = 0.0
    pairs = []
    for spc, occ in sp_items:
        # Convert to Element (strip the oxidation state)
        try:
            e = Element(spc.symbol) if hasattr(spc, "symbol") else Element(str(spc))
        except Exception:
            e = Element(str(spc))
        pairs.append((e, float(occ)))
        total += float(occ)
    if total < 1.0 - 1e-10:
        pairs.append((None, 1.0 - total))  # vacancy
        total = 1.0
    # Normalize
    if abs(total - 1.0) > 1e-10:
        pairs = [(e, p / total) for (e, p) in pairs]
    return pairs


def draw_counts_from_probs(n, probs):
    """
    Given a replication count n and probabilities probs, return {species/None: count}
    such that the counts sum to n.
    Start from rounding, then correct the residual until the total equals n.
    """
    # Initial allocation
    counts = {sp: int(round(p * n)) for sp, p in probs}
    s = sum(counts.values())
    if s == n:
        return counts

    # Residual correction: when short, add to high-probability species first;
    # when over, subtract from low-probability species first.
    if n > s:
        need = n - s
        probs_sorted = sorted(probs, key=lambda x: x[1], reverse=True)
        for i in range(need):
            sp = probs_sorted[i % len(probs_sorted)][0]
            counts[sp] = counts.get(sp, 0) + 1
    else:
        need = s - n
        probs_sorted = sorted(probs, key=lambda x: x[1])  # reduce low-probability species first
        idx = 0
        while need > 0 and idx < 50 * len(probs_sorted):
            sp = probs_sorted[idx % len(probs_sorted)][0]
            if counts.get(sp, 0) > 0:
                counts[sp] -= 1
                need -= 1
            idx += 1
    return counts


def collapse_disorder_to_ordered_supercell(struct, M, strict=True):
    """
    Procedure:
    1) Tag every site in the unit cell with a parent_id
    2) Expand to the supercell M
    3) Group replicas by parent site; within each group (replication count n),
       assign an integer number of species/vacancies to the replica sites
       according to the occupancies
       - Ordered sites: keep every replica as-is
       - Disordered / partially occupied sites: strict quota or independent sampling
    4) Return a fully ordered supercell Structure
    """
    s0 = struct.copy()
    s0.add_site_property("parent_id", list(range(s0.num_sites)))

    sc = s0.copy()
    sc.make_supercell(M)

    # Group by parent site
    groups = defaultdict(list)
    for i, site in enumerate(sc.sites):
        pid = sc.site_properties["parent_id"][i]
        groups[pid].append(i)

    new_species = []
    new_fracs = []
    new_lat = sc.lattice

    for pid, idx_list in groups.items():
        # Use the group's first replica to define the occupancies
        site0 = sc[idx_list[0]]
        # Ordered site: keep everything (a single element with occupancy 1)
        if site0.is_ordered:
            species_elem = list(site0.species.keys())[0]
            for i in idx_list:
                new_species.append(species_elem)
                new_fracs.append(sc[i].frac_coords)
            continue

        # Disordered / partial occupancy: probabilistic assignment
        probs = species_to_probs(site0)
        n = len(idx_list)

        if strict:
            counts = draw_counts_from_probs(n, probs)
            # Build the assignment pool and shuffle it
            pool = []
            for sp, c in counts.items():
                pool += [sp] * c
            random.shuffle(pool)
            # Assign one entry to each replica site
            for i, sp in zip(idx_list, pool):
                if sp is None:
                    continue  # vacancy -> drop this site
                new_species.append(sp)
                new_fracs.append(sc[i].frac_coords)
        else:
            # Independent sampling
            import bisect
            species_list = [sp for sp, p in probs]
            cum = []
            ssum = 0.0
            for _, p in probs:
                ssum += p
                cum.append(ssum)
            for i in idx_list:
                r = random.random()
                j = bisect.bisect_left(cum, r)
                sp = species_list[j]
                if sp is None:
                    continue
                new_species.append(sp)
                new_fracs.append(sc[i].frac_coords)

    ordered_sc = Structure(new_lat, new_species, new_fracs, to_unit_cell=True, coords_are_cartesian=False)
    # Strip any residual oxidation states (the LAMMPS atomic style does not need them)
    try:
        ordered_sc.remove_oxidation_states()
    except Exception:
        pass
    return ordered_sc


# 1) Read the CIF (with partial occupancies)
s_in = Structure.from_file(cif_filename, primitive=False)
print(f"Read: {cif_filename}, unit-cell sites: {s_in.num_sites}, ordered?: {s_in.is_ordered}")

# 2) Freeze the partial occupancies on the 1x1x2 supercell -> fully ordered supercell
ordered_sc = collapse_disorder_to_ordered_supercell(s_in, supercell_matrix, strict=strict_count)
print(f"Built the ordered supercell: sites={ordered_sc.num_sites}, ordered?: {ordered_sc.is_ordered}")

# 3) Print the element counts to verify the stoichiometry
elem_count = Counter([sp.symbol for sp in ordered_sc.species])
print("Element counts:", dict(elem_count))

# 4) Write the LAMMPS data file (pymatgen first, ASE as a fallback)
wrote = False
try:
    ldata = LammpsData.from_structure(ordered_sc, atom_style="atomic")
    ldata.write_file(out_lammps)
    wrote = True
    print(f"Wrote LAMMPS data file: {out_lammps} (pymatgen)")
except Exception as e:
    print("pymatgen failed to write the LAMMPS data file:", e)

if not wrote and HAS_ASE:
    try:
        ase_atoms = ASEAtoms(
            symbols=[sp.symbol for sp in ordered_sc.species],
            positions=ordered_sc.cart_coords,
            cell=ordered_sc.lattice.matrix,
            pbc=True
        )
        ase_write(out_lammps, ase_atoms, format="lammps-data", atom_style="atomic")
        wrote = True
        print(f"Wrote LAMMPS data file: {out_lammps} (ASE)")
    except Exception as e:
        print("ASE also failed to write the LAMMPS data file:", e)

if not wrote:
    print("Failed to write the LAMMPS data file; please send me the error messages.")
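
As a worked example of the strict-quota allocation above, here is a minimal sketch with a hypothetical partially occupied site rather than a call into the script (the occupancy 0.75 matches the Pnma Li site, the replica counts are made up):

# A 0.75-occupied site replicated n times: round(0.75 * n) atoms, the rest vacancies.
n = 16
probs = [("Li", 0.75), (None, 0.25)]  # None stands for a vacancy
counts = {sp: int(round(p * n)) for sp, p in probs}
assert counts == {"Li": 12, None: 4}
assert sum(counts.values()) == n      # rounding is exact here, no residual correction needed
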
14
mcp/main.py
@@ -11,15 +11,16 @@ from starlette.routing import Mount
from system_tools import create_system_mcp
from materialproject_mcp import create_materials_mcp
from softBV_remake import create_softbv_mcp
from paper_search_mcp import create_paper_search_mcp
# from paper_search_mcp import create_paper_search_mcp
from topological_analysis_models import create_topological_analysis_mcp

from vasp_mcp import create_vasp_mcp
# Create the MCP instances
system_mcp = create_system_mcp()
materials_mcp = create_materials_mcp()
softbv_mcp = create_softbv_mcp()
paper_search_mcp = create_paper_search_mcp()
# paper_search_mcp = create_paper_search_mcp()
topological_analysis_mcp = create_topological_analysis_mcp()
vasp_mcp = create_vasp_mcp()
# Start each MCP's session manager inside Starlette's lifespan
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
@@ -27,8 +28,9 @@ async def lifespan(app: Starlette):
        await stack.enter_async_context(system_mcp.session_manager.run())
        await stack.enter_async_context(materials_mcp.session_manager.run())
        await stack.enter_async_context(softbv_mcp.session_manager.run())
        await stack.enter_async_context(paper_search_mcp.session_manager.run())
        # await stack.enter_async_context(paper_search_mcp.session_manager.run())
        await stack.enter_async_context(topological_analysis_mcp.session_manager.run())
        await stack.enter_async_context(vasp_mcp.session_manager.run())
        yield  # while the server is running
        # cleaned up automatically on exit

@@ -39,8 +41,9 @@ app = Starlette(
        Mount("/system", app=system_mcp.streamable_http_app()),
        Mount("/materials", app=materials_mcp.streamable_http_app()),
        Mount("/softBV", app=softbv_mcp.streamable_http_app()),
        Mount("/papersearch", app=paper_search_mcp.streamable_http_app()),
        # Mount("/papersearch", app=paper_search_mcp.streamable_http_app()),
        Mount("/topologicalAnalysis", app=topological_analysis_mcp.streamable_http_app()),
        Mount("/vasp", app=vasp_mcp.streamable_http_app()),
    ],
)

@@ -52,4 +55,5 @@ app = Starlette(
# http://localhost:8000/softBV
# http://localhost:8000/papersearch
# http://localhost:8000/topologicalAnalysis
# http://localhost:8000/vasp
# If browser clients need access (CORS must expose Mcp-Session-Id), see the CORS configuration example in the README [1]
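
A minimal sketch of how this Starlette app could be served; it assumes uvicorn is installed and that the file is importable as main:app, neither of which is shown in the diff:

# hypothetical launcher, matching the http://localhost:8000/... endpoints listed above
import uvicorn

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000)
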
@@ -560,7 +560,7 @@ def _parse_print_cube_output(raw_text: str) -> PrintCubeResult:
        matrix.append(parts)
        return matrix

    # Find key lines and parse data
    # Find key lines and parse raw
    name = re.search(r"CELL: name: (.*)", raw_text).group(1).strip()
    total_atoms = int(re.search(r"CELL: total atom: (\d+)", raw_text).group(1))
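
To see what the patterns above extract, here is a minimal sketch run against a hypothetical print_cube output line (the sample text is illustrative, not real tool output):

import re

raw_text = "CELL: total atom: 128"  # hypothetical line
total_atoms = int(re.search(r"CELL: total atom: (\d+)", raw_text).group(1))
assert total_atoms == 128
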
362
mcp/vasp_mcp.py
Normal file
@@ -0,0 +1,362 @@
# vasp_mcp.py
import os
import posixpath
import warnings
from dataclasses import dataclass
from typing import Any, AsyncIterator
from contextlib import asynccontextmanager

import asyncssh
from pymatgen.core import Structure

from mcp.server.fastmcp import FastMCP, Context
from mcp.server.session import ServerSession
from pymatgen.io.cif import CifParser

# --- VASP-specific configuration ---
REMOTE_HOST = os.getenv("REMOTE_HOST", '202.121.182.208')
REMOTE_USER = os.getenv("REMOTE_USER", 'koko125')
PRIVATE_KEY_PATH = os.getenv("PRIVATE_KEY_PATH", 'D:/tool/tool/id_rsa.txt')

# VASP pseudopotential library and sandbox paths
POTCAR_BASE_PATH = "/cluster/home/koko125/tool/potcar_mcp"
DEFAULT_SANDBOX = f"/cluster/home/{REMOTE_USER}/sandbox"
VASP_ENV_SCRIPT = "/cluster/home/koko125/intel/oneapi/setvars.sh"
VASP_MPI_RUN_CMD = "mpirun -np 4 /cluster/home/koko125/vasp/bin_cpu/vasp_std"


def shell_quote(arg: str) -> str:
    """Safely quote a shell argument."""
    return "'" + str(arg).replace("'", "'\"'\"'") + "'"


# --- Shared context ---
@dataclass
class VaspContext:
    ssh_connection: asyncssh.SSHClientConnection
    potcar_base: str
    sandbox_path: str
    env_script: str
    mpi_run_cmd: str


# --- Lifespan manager ---
@asynccontextmanager
async def vasp_lifespan(_server: FastMCP) -> AsyncIterator[VaspContext]:
    """Open the SSH connection and inject the VASP context."""
    conn: asyncssh.SSHClientConnection | None = None
    try:
        conn = await asyncssh.connect(
            REMOTE_HOST, username=REMOTE_USER, client_keys=[PRIVATE_KEY_PATH], known_hosts=None
        )
        yield VaspContext(
            ssh_connection=conn,
            potcar_base=POTCAR_BASE_PATH,
            sandbox_path=DEFAULT_SANDBOX,
            env_script=VASP_ENV_SCRIPT,
            mpi_run_cmd=VASP_MPI_RUN_CMD
        )
    finally:
        if conn:
            conn.close()
            await conn.wait_closed()


# --- VASP MCP factory ---
def create_vasp_mcp() -> FastMCP:
    """Create an MCP instance holding the VASP helper tools."""
    mcp = FastMCP(
        name="VASP POTCAR Tools",
        instructions="Dedicated tool set for querying and preparing VASP POTCAR files.",
        lifespan=vasp_lifespan,
        streamable_http_path="/",
    )

    # Reuse the safe path-join helper from system_tools.py so target paths stay inside the sandbox
    def _safe_join_sandbox(sandbox_root: str, relative_path: str) -> str:
        rel = (relative_path or ".").strip().lstrip("/")
        combined = posixpath.normpath(posixpath.join(sandbox_root, rel))
        root_norm = sandbox_root.rstrip("/")
        if combined != root_norm and not combined.startswith(root_norm + "/"):
            raise ValueError("Path escape: the target path must stay inside the sandbox directory")
        if ".." in combined.split("/"):
            raise ValueError("Illegal path: '..' is not allowed")
        return combined

    # --- Tool 1: list the available pseudopotential library types ---
    @mcp.tool()
    async def list_potcar_types(ctx: Context[ServerSession, VaspContext]) -> list[str]:
        """
        List every pseudopotential type available in the central library
        (e.g. 'PBE_potpaw', 'PAW_GGA_PBE').
        """
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection
        try:
            # With ls -F, directories carry a trailing '/'
            result = await conn.run(f"ls -F {shell_quote(app_ctx.potcar_base)}", check=True)
            potcar_types = [
                name.strip('/') for name in result.stdout.strip().split() if name.endswith('/')
            ]
            await ctx.info(f"Available pseudopotential libraries: {potcar_types}")
            return potcar_types
        except Exception as e:
            await ctx.error(f"Failed to list POTCAR types: {e}")
            return []

    # --- Tool 2 (new): query the elements available in a given library ---
    @mcp.tool()
    async def query_potcar_elements(ctx: Context[ServerSession, VaspContext], potcar_type: str) -> list[str]:
        """
        Query which elements have pseudopotentials in the given library type.
        """
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        # Build the source directory path safely
        source_dir = posixpath.join(app_ctx.potcar_base, potcar_type)
        if ".." in potcar_type or "/" in potcar_type:
            await ctx.error(f"Illegal pseudopotential library type: {potcar_type}")
            return []

        await ctx.info(f"Querying available elements in '{source_dir}'...")
        try:
            result = await conn.run(f"ls -F {shell_quote(source_dir)}", check=True)
            elements = [
                name.strip('/') for name in result.stdout.strip().split() if name.endswith('/')
            ]
            return elements
        except asyncssh.ProcessError as e:
            msg = f"Element query failed: library '{potcar_type}' may not exist. Stderr: {e.stderr}"
            await ctx.error(msg)
            return []
        except Exception as e:
            await ctx.error(f"Unexpected error while querying POTCAR elements: {e}")
            return []

    # --- Tool 3 (new): safely copy a POTCAR file from the central library into the sandbox ---
    @mcp.tool()
    async def copy_potcar_file(
        ctx: Context[ServerSession, VaspContext],
        potcar_type: str,
        element: str,
        destination_path: str
    ) -> dict[str, str]:
        """
        Safely copy the POTCAR file of a given element from the central library
        to a target path inside the user sandbox.
        For example, copy the 'Si' pseudopotential of the 'PBE_potpaw' library
        to 'sio2_relax/POTCAR_Si'.
        """
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        try:
            # 1. Build the source file path safely (only subdirectories of potcar_base are reachable)
            if ".." in potcar_type or "/" in potcar_type or ".." in element or "/" in element:
                raise ValueError("Illegal pseudopotential library type or element name.")
            source_file = posixpath.join(app_ctx.potcar_base, potcar_type, element, "POTCAR")

            # 2. Build the destination path safely (must stay inside the sandbox)
            dest_file_abs = _safe_join_sandbox(app_ctx.sandbox_path, destination_path)

            # 3. Run the cp command
            cmd = f"cp {shell_quote(source_file)} {shell_quote(dest_file_abs)}"
            await ctx.info(f"Performing safe copy: cp {source_file} -> {dest_file_abs}")
            await conn.run(cmd, check=True)

            return {"status": "success", "source": source_file, "destination": destination_path}

        except asyncssh.ProcessError as e:
            msg = f"POTCAR copy failed. Check that potcar_type and element are correct. Stderr: {e.stderr}"
            await ctx.error(msg)
            return {"status": "error", "message": msg}
        except ValueError as e:
            await ctx.error(f"Path validation failed: {e}")
            return {"status": "error", "message": str(e)}
        except Exception as e:
            await ctx.error(f"Unexpected error while copying POTCAR: {e}")
            return {"status": "error", "message": str(e)}

    # (Optional) A simple cif_to_poscar tool is still worth keeping because it is
    # handy for determining the element order; it is registered under a distinct
    # name here so it does not collide with the robust version defined below.
    @mcp.tool()
    def cif_to_poscar_simple(cif_content: str, sort_structure: bool = True) -> dict[str, str]:
        """Convert CIF content to VASP POSCAR content and provide the ordered element list."""
        # ... (same code as the earlier version)
        try:
            structure = Structure.from_str(cif_content, fmt="cif")
            if sort_structure:
                structure = structure.get_sorted_structure()

            elements = [site.specie.symbol for site in structure]
            unique_elements = sorted(set(elements), key=elements.index)

            poscar_content = structure.to(fmt="poscar")

            return {
                "status": "success",
                "poscar_content": poscar_content,
                "elements": " ".join(unique_elements)
            }
        except Exception as e:
            return {"status": "error", "message": f"CIF conversion failed: {e}"}

    @mcp.tool()
    async def cif_to_poscar(cif_content: str, ctx: Context[ServerSession, VaspContext], sort_structure: bool = True) -> dict[str, str]:
        """
        Robustly convert CIF content to VASP POSCAR content.
        The tool tries several parsing strategies to cope with badly formatted CIF files.
        On success it returns the POSCAR content and the ordered element list needed for the POTCAR.
        On failure it returns a detailed error message.

        Args:
            cif_content (str): Full content of the CIF file describing the crystal structure.
            sort_structure (bool): Whether to sort the atoms to match pymatgen's POTCAR convention. Defaults to True.
        """
        structure = None
        last_exception = None

        # Suppress warnings pymatgen may emit so they do not pollute the output
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # --- Strategy 1: standard parsing ---
            try:
                await ctx.debug("Trying strategy 1: standard CIF parsing...")
                structure = Structure.from_str(cif_content, fmt="cif")
                await ctx.info("Strategy 1 succeeded: standard parsing done.")
            except Exception as e:
                await ctx.warning(f"Strategy 1 failed: {e}")
                last_exception = e

            # --- Strategy 2: lenient parsing (skip the valence check) ---
            if structure is None:
                try:
                    await ctx.debug("Trying strategy 2: lenient parsing (no valence check)...")
                    # Use the low-level CifParser with the valence check disabled
                    parser = CifParser.from_string(cif_content, check_valence=False)
                    structure = parser.get_structures(primitive=True)[0]
                    await ctx.info("Strategy 2 succeeded: lenient parsing done.")
                except Exception as e:
                    await ctx.warning(f"Strategy 2 failed: {e}")
                    last_exception = e

            # --- Strategy 3: use the raw coordinates without symmetry handling ---
            if structure is None:
                try:
                    await ctx.debug("Trying strategy 3: raw coordinates (primitive=False)...")
                    parser = CifParser.from_string(cif_content)
                    # Take the structure as written in the file instead of the derived primitive cell
                    structure = parser.get_structures(primitive=False)[0]
                    await ctx.info("Strategy 3 succeeded: raw coordinates used.")
                except Exception as e:
                    await ctx.warning(f"Strategy 3 failed: {e}")
                    last_exception = e

        # --- If every strategy failed ---
        if structure is None:
            error_message = (
                "Could not parse a crystal structure from the provided CIF content; every parsing strategy failed.\n"
                f"Last error encountered: {last_exception}\n"
                "Suggestion: check whether the CIF file is badly corrupted. The AI can try to regenerate the CIF, or ask the user for POSCAR content directly."
            )
            await ctx.error(error_message)
            return {"status": "error", "message": error_message}

        # --- Post-processing on success ---
        try:
            # Sort the structure so the element order matches pymatgen's POTCAR logic
            if sort_structure:
                structure = structure.get_sorted_structure()

            # Read the ordered element list off the structure
            elements = [el.symbol for el in structure.composition.elements]

            # Generate the POSCAR content (the element line is always written)
            poscar_content = structure.to(fmt="poscar")

            return {
                "status": "success",
                "poscar_content": poscar_content,
                "elements": " ".join(elements)  # ordered elements as a space-separated string
            }
        except Exception as e:
            # Rare, but can happen during structure post-processing
            final_error = f"The structure parsed successfully, but POSCAR generation failed: {e}"
            await ctx.error(final_error)
            return {"status": "error", "message": final_error}

    @mcp.tool()
    async def test_vasp_run(
        ctx: Context[ServerSession, VaspContext],
        job_directory: str
    ) -> dict[str, Any]:
        """
        Launch VASP's "lite mode" (--lite) in the given directory for a quick test.
        This mode fully parses all input files (INCAR, POSCAR, ...) and validates
        the parameters, but never starts the actual ionic or electronic steps,
        so it usually finishes within seconds.

        Args:
            job_directory (str): Remote sandbox subdirectory holding all VASP input files.
        """
        app_ctx = ctx.request_context.lifespan_context
        conn = app_ctx.ssh_connection

        try:
            # Build the absolute working-directory path safely
            workdir_abs = _safe_join_sandbox(app_ctx.sandbox_path, job_directory)

            await ctx.info(f"Starting VASP input validation (lite mode) in '{workdir_abs}'...")

            # Key point: append the --lite flag after the VASP executable.
            # Note: mpirun [options] <executable> [args],
            # so --lite must follow vasp_std.
            # app_ctx.mpi_run_cmd is assumed to be "mpirun -np 4 .../vasp_std"
            command_with_lite = f"{app_ctx.mpi_run_cmd} --lite"

            # Build the full shell command that sources the environment and runs VASP.
            # The f-string is assembled in pieces to avoid multi-line nesting problems.
            inner_command = (
                f"cd {shell_quote(workdir_abs)}; "
                f"source {shell_quote(app_ctx.env_script)}; "
                f"{command_with_lite}"
            )
            full_shell_command = f"bash -lc {shell_quote(inner_command)}"

            # Wait for completion with a plain conn.run. check=False because we
            # handle non-zero exit codes ourselves.
            proc = await conn.run(full_shell_command, check=False)

            # Analyze the result
            stdout = proc.stdout or ""
            stderr = proc.stderr or ""

            # VASP usually signals a clean init-and-exit with this line
            success_indicator = "General timing and accounting informations for this job"

            if proc.exit_status == 0 and success_indicator in stdout:
                test_passed = True
                conclusion = "Test passed: the VASP input files are valid and every parameter parsed correctly."
                await ctx.info(conclusion)
            else:
                test_passed = False
                conclusion = "Test failed: VASP reported an error or did not finish initialization. Check the stderr output below."
                await ctx.warning(conclusion)

            return {
                "status": "completed",
                "test_passed": test_passed,
                "conclusion": conclusion,
                "exit_status": proc.exit_status,
                "stdout_preview": "\n".join(stdout.splitlines()[-20:]),  # last 20 lines only, to avoid flooding the log
                "stderr": stderr
            }

        except Exception as e:
            msg = f"Unexpected error while running test_vasp_run: {e}"
            await ctx.error(msg)
            return {"status": "error", "message": str(e)}

    # Note: the _safe_join_sandbox helper must stay defined inside create_vasp_mcp (or otherwise be accessible)

    return mcp
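
The quoting trick in shell_quote closes the single-quoted string, inserts an escaped literal quote, and reopens it; a quick standalone check (a sketch duplicating the one-liner above):

def shell_quote(arg: str) -> str:
    return "'" + str(arg).replace("'", "'\"'\"'") + "'"

assert shell_quote("abc") == "'abc'"
assert shell_quote("it's") == "'it'\"'\"'s'"   # the embedded quote becomes '"'"'
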
122
rss/nature_filter_rss.py
Normal file
@@ -0,0 +1,122 @@
import feedparser
import requests
from feedgen.feed import FeedGenerator
from datetime import datetime, timezone
import time

# --- 1. Configuration ---

# Your keyword list, matched case-insensitively
KEYWORDS = ['solid-state battery', 'lithium metal', 'anode-free', 'electrolyte']

# The Nature-family journal RSS feeds you want to monitor
SOURCE_FEEDS = {
    'Nature': 'https://www.nature.com/nature/rss/current',
    'Nat Commun': 'https://www.nature.com/ncomms/rss/current',
    'Nat Energy': 'https://www.nature.com/nenergy/rss/current',
    'Nat Mater': 'https://www.nature.com/nmat/rss/current',
    'Nat Nanotechnol': 'https://www.nature.com/nnano/rss/current',
    'Nat Sustain': 'https://www.nature.com/natsustain/rss/current',
    'Nat Chem': 'https://www.nature.com/nchem/rss/current',
    'Nat Synth': 'https://www.nature.com/natsynth/rss/current',
    'Nat Catal': 'https://www.nature.com/natcatal/rss/current',
    'Nat Rev Mater': 'https://www.nature.com/natrevmat/rss/current',
    'Nat Rev Chem': 'https://www.nature.com/natrevchem/rss/current',
    'Nat Rev Earth Environ': 'https://www.nature.com/natrevearthenviron/rss/current',
}

# Output RSS file path; make sure ttrss can reach it through the web server
OUTPUT_FILE = '/var/www/html/rss/nature_filtered_feed.xml'


# --- 2. Core logic ---

def fetch_and_filter():
    """Fetch every source, filter the articles, and return the list of matches."""

    print(f"Starting feed fetch at {datetime.now()}")

    matched_articles = []
    # Track the links of already-added articles in a set to avoid duplicates
    seen_links = set()

    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    for name, url in SOURCE_FEEDS.items():
        print(f" -> Fetching from {name}...")
        try:
            # Fetch with requests for better network-error handling and a spoofed User-Agent
            response = requests.get(url, headers=headers, timeout=15)
            response.raise_for_status()  # make sure the request succeeded

            # Parse the fetched content with feedparser
            feed = feedparser.parse(response.content)

            for entry in feed.entries:
                # Skip articles whose link has already been processed
                if entry.link in seen_links:
                    continue

                # Concatenate title and summary to search both at once
                content_to_check = (entry.title + ' ' + entry.get('summary', '')).lower()

                # Check whether any keyword appears in the content
                if any(keyword.lower() in content_to_check for keyword in KEYWORDS):
                    print(f" [MATCH FOUND] in {name}: {entry.title}")

                    # Prefix the title with the source journal for nicer display in RSS readers
                    entry.title = f"[{name}] {entry.title}"
                    matched_articles.append(entry)
                    seen_links.add(entry.link)

            # Be polite and avoid hitting the servers too frequently
            time.sleep(1)

        except requests.RequestException as e:
            print(f" [ERROR] Could not fetch {name}: {e}")
        except Exception as e:
            print(f" [ERROR] An unexpected error occurred for {name}: {e}")

    print(f"\nFound {len(matched_articles)} matching articles in total.")
    return matched_articles


def generate_filtered_feed(articles):
    """Generate the new RSS file from the filtered article list."""

    fg = FeedGenerator()
    fg.title('My Filtered Nature Research Feed')
    fg.link(href='https://www.nature.com', rel='alternate')
    fg.description(f"Custom RSS feed for Nature journals, filtered by keywords: {', '.join(KEYWORDS)}")

    # Sort the articles by publication date (newest first); fall back to the
    # epoch for entries that carry no parsable date at all
    articles.sort(key=lambda x: x.get('published_parsed') or x.get('updated_parsed') or time.gmtime(0), reverse=True)

    for entry in articles:
        fe = fg.add_entry()
        fe.id(entry.link)  # use the article link as the unique ID
        fe.title(entry.title)
        fe.link(href=entry.link)
        # feedparser has already extracted the summary for us
        fe.description(entry.get('summary', 'No summary available.'))

        # Handle the publication date
        pub_date = entry.get('published_parsed')
        if pub_date:
            # Convert to a timezone-aware datetime object
            fe.published(datetime.fromtimestamp(time.mktime(pub_date)).replace(tzinfo=timezone.utc))

    # Write the file
    fg.rss_file(OUTPUT_FILE, pretty=True)
    print(f"Successfully generated new RSS feed at {OUTPUT_FILE}")


# --- 3. Entry point ---
if __name__ == "__main__":
    filtered_articles = fetch_and_filter()
    if filtered_articles:
        generate_filtered_feed(filtered_articles)
    else:
        print("No new matching articles found. RSS file not updated.")
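
The filter above is a plain case-insensitive substring test over title plus summary; a minimal sketch with a purely hypothetical entry (fabricated for illustration only, not from any feed):

KEYWORDS = ['solid-state battery', 'lithium metal', 'anode-free', 'electrolyte']

entry = {  # hypothetical entry
    'title': 'A stable Anode-Free cell design',
    'summary': 'We report a polymer electrolyte ...',
}
content_to_check = (entry['title'] + ' ' + entry.get('summary', '')).lower()
assert any(keyword.lower() in content_to_check for keyword in KEYWORDS)
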