Files
screen/py/CSM_reconstruct.py
2025-12-14 12:57:34 +08:00

224 lines
8.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import numpy as np
import argparse
from tqdm import tqdm
from scipy.spatial import ConvexHull
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element
from pymatgen.analysis.chemenv.coordination_environments.coordination_geometry_finder import LocalGeometryFinder
# ================= Configuration =================
# Relative paths are resolved against the current working directory (not this
# file's location); absolute paths are recommended so the folders are found.
INPUT_DIR = "../../solidstate-tools/corner-sharing/data/1209/input"  # searched recursively for .cif files
OUTPUT_DIR = "../output/CSM"  # one .dat per CIF, mirroring INPUT_DIR's sub-folders
TARGET_ELEMENT = 'Li'  # element whose sites are analysed
ENV_TYPE = 'both'  # geometries tested by site_env: 'tet', 'oct' or 'both'
# ==================================================
class HiddenPrints:
    """Context manager that discards everything written to ``sys.stdout``.

    Used to hide pymatgen's verbose console output. Only ``sys.stdout`` is
    redirected; ``sys.stderr`` is left untouched. Exceptions raised inside the
    ``with`` block are never suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own handle so __exit__ closes exactly the stream it opened,
        # even if code inside the block rebinds sys.stdout.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout comes back even if closing the sink fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
        return False
def non_elements(struct, anions_to_keep=None):
    """Return a copy of ``struct`` that contains only anion-forming species.

    Every species whose element symbol is not in ``anions_to_keep`` is removed
    from a copy; ``struct`` itself is not modified. Halogens (F, Cl, Br, I) are
    kept by default so that halide (e.g. chloride) structures are not emptied.

    Args:
        struct: pymatgen ``Structure`` to strip.
        anions_to_keep: optional iterable of element symbols to retain. Defaults
            to O, S, N, F, Cl, Br, I, P, Se, Te, As, Sb and C.

    Returns:
        A new ``Structure`` with only the retained species (possibly empty).
    """
    if anions_to_keep is None:
        anions_to_keep = {"O", "S", "N", "F", "Cl", "Br", "I", "P", "Se", "Te", "As", "Sb", "C"}
    else:
        anions_to_keep = set(anions_to_keep)
    stripped = struct.copy()
    # Pass the species objects themselves rather than bare symbols: in an
    # oxidation-state-decorated structure the sites hold e.g. Species("Li", 1),
    # which a plain "Li" string would not match, leaving the cation in place.
    species_to_remove = [
        species for species in stripped.composition.elements
        if species.symbol not in anions_to_keep
    ]
    if species_to_remove:
        stripped.remove_species(species_to_remove)
    return stripped
def _nearest_shell(structure, isite, n_target, r_min=1.0, r_max=4.0, n_steps=601):
    """Return the ``n_target`` nearest neighbors of site ``isite``, or None.

    Reproduces a scan of the cutoff radius over
    ``np.linspace(r_min, r_max, n_steps)``: at the first grid radius that
    encloses at least ``n_target`` neighbors, the shell is accepted only if it
    encloses exactly ``n_target``. A shell whose next neighbor falls within the
    same radius step is rejected. One neighbor query at ``r_max`` replaces one
    query per grid radius.
    """
    neighbors = sorted(
        structure.get_neighbors(structure.sites[isite], r_max),
        key=lambda nb: nb.nn_distance,
    )
    if len(neighbors) < n_target:
        return None
    dists = np.array([nb.nn_distance for nb in neighbors])
    # counts[k] = number of neighbors at distance <= radii[k].
    counts = np.searchsorted(dists, np.linspace(r_min, r_max, n_steps), side='right')
    reached = np.flatnonzero(counts >= n_target)
    if reached.size == 0 or counts[reached[0]] != n_target:
        return None
    return neighbors[:n_target]


def _polyhedron_env(structure, isite, neighbors, symbol, label, **lgf_kwargs):
    """Fit ChemEnv geometry ``symbol`` to ``neighbors`` around site ``isite``.

    Returns ``{'csm', 'vol', 'type': label}``, where 'csm' is the value ChemEnv
    reports for ``symbol`` and 'vol' is the convex-hull volume of the neighbor
    positions. Returns None if any step raises; this is best effort, e.g. for
    a degenerate (coplanar) hull.
    """
    coords = [nb.coords for nb in neighbors]
    try:
        with HiddenPrints():
            lgf = LocalGeometryFinder(only_symbols=[symbol], **lgf_kwargs)
            lgf.setup_structure(structure=structure)
            lgf.setup_local_geometry(isite=isite, coords=coords)
            volume = ConvexHull(coords).volume
            csm = lgf.get_coordination_symmetry_measures()[symbol]['csm']
    except Exception:
        return None
    return {'csm': csm, 'vol': volume, 'type': label}


def site_env(coord, struct, sp="Li", envtype='both'):
    """Classify the anion polyhedron around a probe atom placed at ``coord``.

    All non-anion species are stripped from ``struct`` (see ``non_elements``).
    A probe ``sp`` atom is then inserted at fractional coordinate ``coord``.
    Its nearest anion shell is tested against a tetrahedron ('T:4', 4
    neighbors) and/or an octahedron ('O:6', 6 neighbors). A shell qualifies
    only if, scanning the cutoff from 1 to 4 Angstrom, the neighbor count hits
    exactly 4 (or 6) before exceeding it.

    Args:
        coord: fractional coordinates of the probe site.
        struct: pymatgen ``Structure``; not modified.
        sp: species symbol used for the probe atom.
        envtype: 'tet', 'oct' or 'both'.

    Returns:
        dict with keys 'csm', 'vol' and 'type'. 'type' is 'tet'/'oct' for the
        candidate with the lowest CSM, 'Error_NoAnions' if no anions remain, or
        'Non_<envtype>' if no geometry could be fitted (csm/vol are NaN then).
    """
    stripped = non_elements(struct)
    # Nothing left to coordinate the probe (e.g. pure Li metal).
    if len(stripped) == 0:
        return {'csm': np.nan, 'vol': np.nan, 'type': 'Error_NoAnions'}

    with_li = stripped.copy()
    # Insert the probe explicitly at index 0 so every lookup below refers to it.
    # Relying on get_sorted_structure() to move it there silently analysed the
    # first anion instead whenever sorting failed (e.g. partial occupancies).
    with_li.insert(0, sp, coord, coords_are_cartesian=False, validate_proximity=False)

    geometries = (
        # (label, required neighbor count, ChemEnv symbol, LocalGeometryFinder kwargs)
        ('tet', 4, 'T:4', {}),
        ('oct', 6, 'O:6', {'permutations_safe_override': False}),
    )
    candidates = []
    for label, n_target, symbol, lgf_kwargs in geometries:
        if envtype not in ('both', label):
            continue
        shell = _nearest_shell(with_li, 0, n_target)
        if shell is None:
            continue
        env = _polyhedron_env(with_li, 0, shell, symbol, label, **lgf_kwargs)
        if env is not None:
            candidates.append(env)

    if not candidates:
        return {'csm': np.nan, 'vol': np.nan, 'type': 'Non_' + envtype}
    # Keep the lowest-CSM candidate; 'tet' wins ties because it is tried first.
    return min(candidates, key=lambda env: env['csm'])
def extract_sites(struct, sp="Li", envtype='both'):
    """Compute the coordination environment of every site occupied by ``sp``.

    Sites count as occupied even when ``sp`` is only one of several species
    on them. Each matching site is analysed by ``site_env`` with a copy of
    ``struct``.

    Args:
        struct: structure to scan.
        sp: element symbol of the target species.
        envtype: forwarded to ``site_env`` ('tet', 'oct' or 'both').

    Returns:
        List of dicts with keys 'site_index', 'frac_coords', 'type', 'csm' and
        'volume', in site order. Sites whose analysis raises are skipped.
    """
    results = []
    for idx, site in enumerate(struct):
        if not any(element.symbol == sp for element in site.species.elements):
            continue
        try:
            env = site_env(site.frac_coords, struct.copy(), sp, envtype)
            results.append({
                'site_index': idx,
                'frac_coords': site.frac_coords,
                'type': env.get('type', 'unknown'),
                'csm': env.get('csm', np.nan),
                'volume': env.get('vol', np.nan),
            })
        except Exception:
            # Best effort: one failing site must not abort the whole structure.
            continue
    return results
def export_envs(envlist, sp, envtype, fname):
    """Write ``envlist`` to ``fname`` as a plain-text report.

    The report has three header lines (title, species, envtype), followed by
    one line per entry: ``Site index <i>: <entry dict repr>``. Any existing
    file is overwritten.
    """
    header = (
        'List of environment information\n',
        f'Species : {sp}\n',
        f'Envtype : {envtype}\n',
    )
    with open(fname, 'w') as handle:
        handle.writelines(header)
        handle.writelines(
            f"Site index {entry['site_index']}: {entry}\n" for entry in envlist
        )
# ================= 主程序 =================
def _find_cif_files(root_dir):
    """Return sorted paths of all ``.cif`` files (any letter case) below ``root_dir``."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        found.extend(
            os.path.join(dirpath, name) for name in filenames
            if name.lower().endswith(".cif")
        )
    # os.walk order is filesystem-dependent; sort for reproducible runs.
    return sorted(found)


def run_csm_analysis():
    """Batch-compute site environments for every CIF under INPUT_DIR.

    INPUT_DIR is walked recursively. For each structure containing
    TARGET_ELEMENT, ``<name>.dat`` is written under OUTPUT_DIR (mirroring the
    sub-directory layout) via ``export_envs``. Structures where no environment
    could be extracted get a one-line placeholder file instead. Per-file
    failures are reported and skipped, so one bad CIF does not stop the batch.
    """
    # 1. Check the input directory.
    if not os.path.isdir(INPUT_DIR):
        print(f"错误: 输入目录不存在 -> {os.path.abspath(INPUT_DIR)}")
        return
    cif_files = _find_cif_files(INPUT_DIR)
    if not cif_files:
        print(f"{INPUT_DIR} 中未找到 .cif 文件。")
        return
    print(f"开始分析 {len(cif_files)} 个文件 (目标元素: {TARGET_ELEMENT}, 包含阴离子: F,Cl,Br,I,O,S,N...)")
    success_count = 0
    for cif_path in tqdm(cif_files, desc="Calculating CSM"):
        try:
            # Mirror the input sub-directory layout under OUTPUT_DIR.
            rel_dir = os.path.dirname(os.path.relpath(cif_path, INPUT_DIR))
            file_base = os.path.splitext(os.path.basename(cif_path))[0]
            target_dir = os.path.join(OUTPUT_DIR, rel_dir)
            # exist_ok avoids the check-then-create race of exists() + makedirs().
            os.makedirs(target_dir, exist_ok=True)
            target_dat_path = os.path.join(target_dir, f"{file_base}.dat")
            # Optional: skip files whose output already exists and is non-empty.
            # if os.path.exists(target_dat_path) and os.path.getsize(target_dat_path) > 0:
            #     continue
            struct = Structure.from_file(cif_path)
            # CIFs with oxidation-state labels (e.g. "Li1+") produce Species
            # objects that compare unequal to plain Elements, which made the
            # presence check below skip them; normalise to plain elements.
            struct.remove_oxidation_states()
            if Element(TARGET_ELEMENT) not in struct.composition.elements:
                continue
            env_list = extract_sites(struct, sp=TARGET_ELEMENT, envtype=ENV_TYPE)
            # Always write a file, even when nothing was found, to ease debugging.
            if env_list:
                export_envs(env_list, sp=TARGET_ELEMENT, envtype=ENV_TYPE, fname=target_dat_path)
                success_count += 1
            else:
                with open(target_dat_path, 'w') as f:
                    f.write(f"No {TARGET_ELEMENT} environments found (Check connectivity or anion types).")
        except Exception as e:
            # tqdm.write keeps the progress bar intact while reporting the error.
            tqdm.write(f"\n[Error] File: {os.path.basename(cif_path)} -> {e}")
            continue
    print(f"\n分析完成!成功生成 {success_count} 个文件。")
    print(f"输出目录: {os.path.abspath(OUTPUT_DIR)}")
# Run the batch analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    run_csm_analysis()