500 lines
20 KiB
Python
500 lines
20 KiB
Python
from distutils.dir_util import remove_tree
|
||
|
||
from pymatgen.core import Structure, Lattice,Species,PeriodicSite
|
||
import numpy as np
|
||
from collections import defaultdict
|
||
import math
|
||
import spglib
|
||
from functools import reduce
|
||
from fractions import Fraction
|
||
import random
|
||
import re
|
||
import os
|
||
|
||
def mark_atoms_randomly(numerator, denominator):
    """Randomly decide which of ``denominator`` supercell copies hold an atom.

    Exactly ``numerator`` of the copy indices ``0 .. denominator - 1`` are
    drawn uniformly at random without replacement and marked as occupied.

    Args:
        numerator: Number of copies that should be occupied.
        denominator: Total number of copies (the supercell multiplicity).

    Returns:
        dict[int, int]: ``{copy_index: 1 or 0}`` for every copy index, where
        ``1`` means the copy is occupied; exactly ``numerator`` values are 1.

    Raises:
        ValueError: If ``numerator`` exceeds ``denominator`` (a negative
            ``numerator`` is rejected by ``random.sample`` with ValueError too).
    """
    if numerator > denominator:
        raise ValueError(f"atom_number ({numerator}) 不能超过扩胞总数 (x*y*z = {denominator})")

    copy_indices = range(denominator)
    # A set makes each membership test O(1); testing against the sampled
    # list made building the result O(denominator * numerator).
    occupied = set(random.sample(copy_indices, numerator))
    return {idx: 1 if idx in occupied else 0 for idx in copy_indices}
|
||
def extract_oxi_state(species_str, element):
    """Extract the oxidation state of ``element`` from a pymatgen species string.

    The element symbol must match exactly. The previous substring test let
    ``"S"`` match ``"Sc3+"`` (returning Sc's charge) and ``"C"`` match
    ``"Cl-"``.

    Args:
        species_str: Species string such as ``"Li+:0.689, Sc3+:0.311"``.
        element: Element symbol to look up, e.g. ``"Sc"``.

    Returns:
        int, or float for fractional states such as ``"Fe2.5+"``: the signed
        oxidation state. Examples: ``Sc3+`` -> 3, ``Sc-`` -> -1,
        ``Sc3-`` -> -3, ``Sc+`` -> 1; digits without a sign count as positive.
        Returns 0 if ``element`` does not appear or carries no charge.
    """
    # symbol, optional leading sign, optional magnitude, optional trailing sign
    pattern = r"([A-Z][a-z]?)([+-]?)(\d+(?:\.\d+)?)?([+-]?)"
    for part in species_str.split(","):
        token = part.split(":")[0].strip()
        match = re.match(pattern, token)
        if match is None or match.group(1) != element:
            continue
        _, lead_sign, magnitude, trail_sign = match.groups()
        sign_char = trail_sign or lead_sign
        if magnitude is None:
            # Bare sign ("Sc+" / "Sc-") means +/-1; no sign at all means 0.
            return {"+": 1, "-": -1}.get(sign_char, 0)
        value = float(magnitude)
        value = int(value) if value.is_integer() else value
        return -value if sign_char == "-" else value
    return 0
|
||
def factorize_to_three_factors(n, type_sym=None, keep_module=None):
    """Enumerate factorizations ``n = x * y * z`` into positive integers.

    Args:
        n: Integer to factorize (the total supercell multiplicity); values
            below 1 give an empty list.
        type_sym: ``None`` keeps every ordered triple. ``"xyz"`` keeps only
            triples with ``x >= y >= z``, one representative per permutation,
            for when the three axes are equivalent. Any other value yields an
            empty list.
        keep_module: ``"random"`` returns the triples shuffled. Otherwise
            they are sorted by ``(x + y + z, z, y, x)`` so the most compact
            expansions come first.

    Returns:
        list[dict]: ``[{'x': x, 'y': y, 'z': z}, ...]``.
    """
    factors = []
    if type_sym is None or type_sym == "xyz":
        for x in range(1, n + 1):
            if n % x:
                continue
            remaining = n // x
            for y in range(1, remaining + 1):
                if remaining % y:
                    continue
                z = remaining // y
                if type_sym == "xyz" and not (x >= y >= z):
                    continue
                factors.append({'x': x, 'y': y, 'z': z})

    if keep_module == 'random':
        # Shuffle a copy so the enumeration order above stays untouched.
        shuffled = factors.copy()
        random.shuffle(shuffled)
        return shuffled

    return sorted(factors, key=lambda f: (f['x'] + f['y'] + f['z'], f['z'], f['y'], f['x']))
|
||
def typejudge(number):
    """Map an international space-group number to its crystal system name.

    Numbers 1 and 2 are triclinic. The remaining systems are matched against
    inclusive ranges up to 230, and anything else gives ``"Unknown"``.
    """
    if number in [1, 2]:
        return "Triclinic"

    # (lowest, highest, system), inclusive, checked in ascending order.
    system_ranges = (
        (3, 15, "Monoclinic"),
        (16, 74, "Orthorhombic"),
        (75, 142, "Tetragonal"),
        (143, 167, "Trigonal"),
        (168, 194, "Hexagonal"),
        (195, 230, "Cubic"),
    )
    for lowest, highest, system in system_ranges:
        if lowest <= number <= highest:
            return system
    return "Unknown"
|
||
def strategy_divide(struct, total, keep_module=None):
    """Return candidate supercell expansions whose product ``x * y * z`` is ``total``.

    The crystal system comes from the international space-group number that
    ``struct.get_space_group_info()`` reports. Cubic structures only get
    triples with ``x >= y >= z``. Every other system, hexagonal included for
    now, gets all ordered triples.

    Args:
        struct: pymatgen ``Structure``.
        total: Required product ``x * y * z``.
        keep_module: Forwarded to ``factorize_to_three_factors``; ``"random"``
            shuffles the candidates.

    Returns:
        list[dict]: candidate factors ``{'x': int, 'y': int, 'z': int}``.
    """
    # get_space_group_info() already returns (symbol, international number).
    # Re-deriving the number by scanning all 530 spglib Hall settings by
    # symbol was slow, and it silently fell back to space group 1
    # (triclinic) whenever the symbol spelling did not match.
    space_group_symbol, space_group_number = struct.get_space_group_info()
    space_type = typejudge(space_group_number)
    print(f"当前空间群符号为{space_group_symbol},序号为{space_group_number},对应的晶体体系为{space_type}")

    if space_type == "Hexagonal":
        print('当前为六方晶系,暂不处理')

    if space_type == "Cubic":
        print("当前为立方晶体,三个方向同步")
        return factorize_to_three_factors(total, "xyz", keep_module=keep_module)

    # Hexagonal has no special handling yet and is treated like the rest.
    print("为其他晶系,假设三个方向不同")
    return factorize_to_three_factors(total, keep_module=keep_module)
|
||
def get_first_non_explicit_element(species_str, explict_element=("Li", "Li+")):
    """Return the first element symbol in ``species_str`` not in ``explict_element``.

    Args:
        species_str: pymatgen species string, e.g. ``"Li+:0.689, Sc3+:0.311"``.
        explict_element: Symbols to skip, e.g. ``["Li"]``. Each species is
            reduced to its alphabetic characters (charge digits, signs and
            the occupancy are dropped) before the ``in`` test. The default is
            an immutable tuple to avoid the shared-mutable-default pitfall;
            any container that supports ``in`` works.

    Returns:
        str: the first remaining symbol, e.g. ``"Sc"``, or ``""`` if the
        string is blank or every species is excluded.
    """
    for part in species_str.split(","):
        if not part.strip():
            continue  # tolerate empty segments such as "A:0.5, , B:0.5"
        symbol = "".join(ch for ch in part.split(":")[0] if ch.isalpha())
        if symbol not in explict_element:
            return symbol
    return ""
|
||
def calculate_expansion_factor(Occupation_list, calculate_type='high'):
    """Turn each occupancy into a fraction and bring all of them to a common denominator.

    Args:
        Occupation_list: list of dicts, e.g. as produced by ``process_cif_file``.
            Each dict has an ``"occupation"`` value accepted by
            ``fractions.Fraction``: a float, an int, or a decimal string such
            as ``"0.311"``. The dicts are updated in place with
            ``"numerator"`` and ``"denominator"``.
        calculate_type: precision mode.
            - ``'high'``: denominator <= 1_000_000 (the default limit of
              ``Fraction.limit_denominator``)
            - ``'normal'``: denominator <= 100
            - ``'low'``: denominator <= 10
            - ``'very low'``: denominator <= 5

    Returns:
        tuple[int, list[dict]]: the expansion factor, which is the least
        common multiple of all denominators, and the same list whose entries
        now share that denominator. An empty input returns ``(1, [])``.

    Raises:
        ValueError: for an unrecognised ``calculate_type`` (non-empty input
            only), or an occupancy ``Fraction`` cannot parse.
    """
    max_denominators = {'high': 1_000_000, 'normal': 100, 'low': 10, 'very low': 5}

    if not Occupation_list:
        return 1, []

    if calculate_type not in max_denominators:
        raise ValueError("calculate_type必须是'high', 'normal', 'low'或'very low'")
    limit = max_denominators[calculate_type]

    # Step 1: approximate each occupancy by the closest fraction within the limit.
    for entry in Occupation_list:
        fraction = Fraction(entry["occupation"]).limit_denominator(limit)
        entry["numerator"] = fraction.numerator
        entry["denominator"] = fraction.denominator

    # Step 2: the expansion factor is the LCM of all denominators.
    lcm = reduce(
        lambda a, b: a * b // math.gcd(a, b),
        (entry["denominator"] for entry in Occupation_list),
        1,
    )

    # Step 3: rewrite every fraction over the common denominator.
    for entry in Occupation_list:
        entry["numerator"] *= lcm // entry["denominator"]
        entry["denominator"] = lcm

    return lcm, Occupation_list
|
||
def get_occu(s_str, explict_element):
    """Return the occupancy of the first species in ``s_str`` not in ``explict_element``.

    Args:
        s_str: pymatgen species string, e.g. ``"Li+:0.689, Sc3+:0.311"``.
            Only ``species:occupancy`` pairs are considered.
        explict_element: species tokens to skip. The comparison uses the full
            token including any charge, e.g. ``"Li+"``, so list both ``"Li"``
            and ``"Li+"`` to skip lithium in either form.

    Returns:
        str | int: the occupancy exactly as written in ``s_str``, e.g.
        ``"0.311"``. It stays text so callers can build an exact ``Fraction``.
        Returns the int ``1`` when there is no ``species:occupancy`` pair for
        a non-excluded species (a fully occupied site or a blank string);
        blank input previously returned an unhashable ``{}``.
    """
    # '.' is allowed in the species token so fractional oxidation states
    # such as "Fe2.5+" are captured whole rather than as the trailing "5+".
    for species, occu in re.findall(r'([A-Za-z0-9.+-]+):([0-9.]+)', s_str):
        if species not in explict_element:
            return occu
    return 1
|
||
def process_cif_file(struct, explict_element=("Li", "Li+")):
    """Group the partially occupied sites of ``struct`` by occupancy.

    A site is recorded when both of these hold:
    - ``get_occu`` reports an occupancy other than 1 for its first species
      not in ``explict_element``;
    - the site's ``species.chemical_system`` is not in ``explict_element``.

    Args:
        struct: pymatgen ``Structure``, e.g. read from a CIF file.
        explict_element: species/element labels to exclude. The default is an
            immutable tuple; any container that supports ``in`` works.

    Returns:
        list[dict]: one entry per distinct occupancy, in first-seen order::

            {
                "occupation": str,         # occupancy text from get_occu, e.g. "0.311"
                "atom_serial": list[int],  # 1-based indices of the sites sharing it
                "numerator": None,         # placeholder, filled by calculate_expansion_factor
                "denominator": None,       # placeholder, filled by calculate_expansion_factor
                "split": list[str],        # element symbols of the last such site,
                                           # in species-string order
            }

    Raises:
        TypeError: if ``struct`` is not a pymatgen ``Structure``.
    """
    if not isinstance(struct, Structure):
        raise TypeError("输入必须为pymatgen的Structure对象")

    # Leading element symbol: one capital letter plus an optional lowercase one.
    symbol_pattern = re.compile(r'([A-Z][a-z]?)')
    occupation_dict = defaultdict(list)  # occupancy -> 1-based site indices
    split_dict = {}  # occupancy -> element symbols on a site with that occupancy

    for i, site in enumerate(struct):
        species_string = site.species_string
        occu = get_occu(species_string, explict_element)
        if occu == 1.0:
            continue  # fully occupied (or only excluded species) -> ignore
        if site.species.chemical_system in explict_element:
            continue
        occupation_dict[occu].append(i + 1)  # site serials are 1-based

        if ':' in species_string:
            # e.g. 'S:0.494, Cl:0.506' or 'S2-:0.494, Cl-:0.506'
            tokens = [part.strip().split(':')[0].strip() for part in species_string.split(',')]
        else:
            # a single species, possibly charged, e.g. 'Sc3+'
            tokens = [species_string]
        matches = (symbol_pattern.match(token) for token in tokens)
        split_dict[occu] = [m.group(1) for m in matches if m]

    return [
        {
            "occupation": occu,
            "atom_serial": serials,
            "numerator": None,
            "denominator": None,
            "split": split_dict.get(occu, []),
        }
        for occu, serials in occupation_dict.items()
    ]
|
||
def merge_structures(structure_list, merge_dict):
    """Tile ``structure_list`` into a single ``nx * ny * nz`` supercell.

    Structure ``i`` is placed at cell offset ``(ix, iy, iz)``, with ``iz``
    varying fastest, then ``iy``, then ``ix``
    (``i = (ix * ny + iy) * nz + iz``).

    Args:
        structure_list: list of pymatgen ``Structure``. All must share the
            same lattice, checked with ``np.allclose`` on the lattice matrices.
        merge_dict: ``{"x": nx, "y": ny, "z": nz}``; missing keys default to 1.

    Returns:
        Structure: the supercell. Its lattice vectors are the reference
        vectors scaled by ``nx``, ``ny`` and ``nz``, so the input
        orientation is preserved rather than re-derived from lengths and
        angles. Site species are copied; site labels and properties are not.

    Raises:
        ValueError: for an empty list, mismatched lattices, non-positive
            factors, or ``len(structure_list) != nx * ny * nz``.
    """
    if not structure_list:
        raise ValueError("结构列表不能为空")

    ref_lattice = structure_list[0].lattice
    for s in structure_list[1:]:
        if not np.allclose(s.lattice.matrix, ref_lattice.matrix):
            raise ValueError("所有结构的晶格必须相同")

    nx, ny, nz = (merge_dict.get(axis, 1) for axis in ("x", "y", "z"))
    if min(nx, ny, nz) < 1:
        raise ValueError(f"合并次数必须为正整数: {merge_dict}")
    total_merge = nx * ny * nz
    if len(structure_list) != total_merge:
        raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")

    scale = np.array([nx, ny, nz], dtype=float)
    # Rows of the lattice matrix are the a, b, c vectors; scaling each row
    # enlarges that axis while keeping the Cartesian frame of the input.
    new_lattice = Lattice(ref_lattice.matrix * scale[:, None])

    species, frac_coords = [], []
    for i, structure in enumerate(structure_list):
        offset = np.array([(i // (ny * nz)) % nx, (i // nz) % ny, i % nz])
        for site in structure:
            species.append(site.species)
            # Shift into cell `offset`, then rescale to supercell fractions.
            frac_coords.append((site.frac_coords + offset) / scale)

    return Structure(new_lattice, species, frac_coords)
|
||
def _site_like(original_site, species, lattice):
    """Build a PeriodicSite with ``species`` at ``original_site``'s fractional position and label."""
    return PeriodicSite(
        species=species,
        coords=original_site.frac_coords,
        lattice=lattice,
        to_unit_cell=True,
        label=original_site.label,
    )


def _pick_substitute(split, element):
    """Return the first symbol in ``split`` that differs from ``element`` (fallback: ``split[1]``)."""
    for candidate in split:
        if candidate != element:
            return candidate
    return split[1]


def generate_structure_list(base_structure, occupation_list, explict_element=("Li", "Li+")):
    """Create ``lcm`` ordered copies of ``base_structure`` that realise the partial occupancies.

    For every site listed in ``occupation_list``:
    - Exactly ``numerator`` of the ``denominator`` copies, chosen at random
      per site, receive the site's first non-excluded element with its
      oxidation state.
    - The remaining copies receive the other element from ``entry["split"]``
      (with that element's own oxidation state) when the site is a
      substitution.
    - Otherwise the remaining copies get a zero-occupancy placeholder
      (vacancy). The placeholder keeps site indices aligned across copies.

    Args:
        base_structure: pymatgen ``Structure`` holding the disordered sites.
        occupation_list: entries from ``calculate_expansion_factor``. All
            entries must share the same ``"denominator"``.
        explict_element: species ignored when choosing the occupying element.

    Returns:
        list[Structure]: ``lcm`` structures, or one copy of
        ``base_structure`` when ``occupation_list`` is empty.
    """
    if not occupation_list:
        return [base_structure.copy()]

    # calculate_expansion_factor gives every entry the same (LCM) denominator.
    lcm = occupation_list[0]["denominator"]
    structure_list = [base_structure.copy() for _ in range(lcm)]

    for entry in occupation_list:
        split = entry["split"]
        for atom_idx in entry["atom_serial"]:
            site_idx = atom_idx - 1  # atom_serial is 1-based
            occupancy = mark_atoms_randomly(numerator=entry["numerator"], denominator=entry["denominator"])
            original_site = base_structure.sites[site_idx]
            species_string = original_site.species_string
            element = get_first_non_explicit_element(species_string, explict_element)
            oxi_state = extract_oxi_state(species_string, element)

            # Substitution: the other element must carry its OWN oxidation
            # state. Reusing `oxi_state` turned e.g. Cl- into Cl2- on S/Cl sites.
            substitute = None
            if len(split) >= 2:
                substitute = _pick_substitute(split, element)
                substitute_oxi = extract_oxi_state(species_string, substitute)

            for copy_idx, occupied in occupancy.items():
                target = structure_list[copy_idx]
                if occupied:
                    species = Species(element, oxi_state)
                elif substitute is not None:
                    species = Species(substitute, substitute_oxi)
                else:
                    # Zero amount -> empty site: a vacancy that keeps indices stable.
                    species = {Species("Li", 1.0): 0.0}
                # Replace in place (O(1)) instead of remove_sites + insert (O(n)).
                target[site_idx] = _site_like(original_site, species, target.lattice)

    return structure_list
|
||
def expansion(input_file, output_folder, keep_number, calculate_type='high', keep_module=None):
    """Expand a partially occupied structure file into ordered supercells written as CIF.

    Args:
        input_file: path of any file ``Structure.from_file`` can read.
        output_folder: directory for the output CIFs; created if missing.
            An empty string means the current directory.
        keep_number: maximum number of expansion strategies to write. It is
            capped at the number available; values <= 0 write nothing.
        calculate_type: precision mode for ``calculate_expansion_factor``.
        keep_module:
            - ``'classify'``: writes ``<base>.cif`` for the first strategy
              and returns that Structure.
            - ``'random'``: strategies are shuffled; each written
              ``<base>-x{X}y{Y}z{Z}.cif`` is collected and returned.
            - anything else: the same file names are written in
              sorted-strategy order, and nothing is collected.

    Returns:
        ``'classify'``: the merged Structure, or ``([], [])`` if nothing was
        written. Otherwise: ``(structures, file_names)``, which is only
        populated in ``'random'`` mode.
    """
    structure_origin = Structure.from_file(input_file)
    lcm, occupation_list = calculate_expansion_factor(
        process_cif_file(structure_origin), calculate_type=calculate_type
    )
    strategy = strategy_divide(structure_origin, lcm, keep_module)
    structure_list = generate_structure_list(structure_origin, occupation_list)

    # Base file name without directory or extension.
    base_name = os.path.splitext(os.path.basename(input_file))[0]
    if output_folder:
        # Structure.to() does not create missing directories.
        os.makedirs(output_folder, exist_ok=True)

    mergeds = []
    names = []
    for factors in strategy[:max(keep_number, 0)]:
        merged = merge_structures(structure_list, factors)
        suffix = "x{}y{}z{}".format(factors["x"], factors["y"], factors["z"])

        if keep_module == 'classify':
            print(f"{base_name}采用扩展方式为{suffix}")
            output_filename = f"{base_name}.cif"
        else:
            if keep_module == 'random':
                print(f"{base_name}采用扩展方式为{suffix}")
            output_filename = f"{base_name}-{suffix}.cif"

        output_path = os.path.join(output_folder, output_filename)
        merged.to(filename=output_path, fmt="cif")
        print(f"Saved: {output_path}")

        if keep_module == 'classify':
            return merged
        if keep_module == 'random':
            mergeds.append(merged)
            names.append(output_filename)

    return mergeds, names
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point. Expands one example CIF using 'low' precision
    # (occupancy denominators <= 10) and writes up to 3 randomly ordered
    # supercell variants to ../data/tmp. Paths are relative to the current
    # working directory. The commented-out lines are alternative invocations
    # and a manual step-by-step version of the same pipeline.
    #expansion("../data/tmp/36.cif","../data/tmp",1,calculate_type='low')
    expansion("../data/input_ClBr_set/36.cif", "../data/tmp", 3, calculate_type='low',keep_module='random')
    #expansion("../data/input/1234.cif", "../data/input/output", 1, calculate_type='low',keep_module='classify')
    # s1 = Structure.from_file("../data/input_pre/mp-6783.cif")
    # s2 = Structure.from_file("../data/input_pre/ICSD_1234.cif")
    # print(process_cif_file(s2))
    # lmp,oc_list=calculate_expansion_factor(process_cif_file(s2))
    # print(oc_list)
    # strategy = strategy_divide(s2,lmp)
    # print(strategy)
    # st_list=generate_structure_list(s2,oc_list)
    # merged = merge_structures(st_list,strategy[0])
    # # merged = merge_structures([s1, s2], {"x": 1, "y": 1, "z": 2})
    # merged.to("merged.cif", "cif") # save the merged structure
|