Files
screen/py/expansion.py
2025-12-07 13:56:33 +08:00

500 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from distutils.dir_util import remove_tree
from pymatgen.core import Structure, Lattice,Species,PeriodicSite
import numpy as np
from collections import defaultdict
import math
import spglib
from functools import reduce
from fractions import Fraction
import random
import re
import os
def mark_atoms_randomly(numerator,denominator):
"""
根据扩胞因子和占据数量生成随机占据字典
Args:
factors: 扩胞因子字典 {'x': int, 'y': int, 'z': int}
atom_number: 需要占据的副本数量
Returns:
字典 {0: 1或0, 1: 1或0, ..., total_copies-1: 1或0}
"""
if numerator > denominator:
raise ValueError(f"atom_number ({numerator}) 不能超过扩胞总数 (x*y*z = {denominator})")
# 生成所有副本索引 [0, 1, 2, ..., total_copies-1]
atom_dice = list(range(denominator))
# 随机选择 atom_number 个副本占据
selected_atoms = random.sample(atom_dice, numerator)
# 创建结果字典 {0: 1或0, 1: 1或0, ...}
result = {atom: 1 if atom in selected_atoms else 0 for atom in atom_dice}
return result
def extract_oxi_state(species_str,element):
"""
从物种字符串中提取指定元素的氧化态
参数:
species_str: 物种字符串,如 "Li+:0.689, Sc3+:0.311"
element: 要提取的元素符号,如 "Sc"
返回:
int: 氧化态数值(如 Sc3+ → 3Sc- → -1Sc3- → -3
如果未找到或没有氧化态则返回 0
"""
# 分割字符串获取各个物种部分
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
# 提取元素和电荷部分(冒号前的内容)
element_with_charge = part.split(":")[0].strip()
# 检查是否匹配目标元素
if element in element_with_charge:
# 提取电荷部分
charge_part = element_with_charge[len(element):]
# 处理无数字情况(如"Sc+"
if not any(c.isdigit() for c in charge_part):
if "+" in charge_part:
return 1
elif "-" in charge_part:
return -1
else:
return 0
# 处理有数字情况(如"Sc3+"
sign = 1
if "-" in charge_part:
sign = -1
# 提取数字部分
digits = ""
for c in charge_part:
if c.isdigit():
digits += c
if digits: # 确保有提取到数字
return sign * int(digits)
return 0 # 默认返回0
def factorize_to_three_factors(n,type_sym=None,keep_module=None):
factors = []
# 遍历可能的x值
if type_sym == None:
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
# 遍历可能的y值
for y in range(1, remaining_n + 1):
if remaining_n % y == 0:
z = remaining_n // y
factors.append({'x': x, 'y': y, 'z': z})
if type_sym == "xyz":
for x in range(1, n + 1):
if n % x == 0:
remaining_n = n // x
# 遍历可能的y值
for y in range(1, remaining_n + 1):
if remaining_n % y == 0 and y <= x:
z = remaining_n // y
if z <= y:
factors.append({'x': x, 'y': y, 'z': z})
if keep_module=='random':
import random
# 创建一个因子列表的副本,并随机打乱顺序
shuffled_factors = factors.copy()
random.shuffle(shuffled_factors)
return shuffled_factors
else:
def sort_key(item):
"""返回一个用于排序的元组"""
return (item['x'] + item['y'] + item['z'], item['z'], item['y'], item['x'])
# 使用 sorted() 函数(返回一个新的排序后的列表,不改变原列表)
sorted_factor = sorted(factors, key=sort_key)
return sorted_factor
def typejudge(number):
if number in [1, 2]:
return "Triclinic"
elif 3 <= number <= 15:
return "Monoclinic"
elif 16 <= number <= 74:
return "Orthorhombic"
elif 75 <= number <= 142:
return "Tetragonal"
elif 143 <= number <= 167:
return "Trigonal"
elif 168 <= number <= 194:
return "Hexagonal"
elif 195 <= number <= 230:
return "Cubic"
else:
return "Unknown"
def strategy_divide(struct,total,keep_module=None):
space_group_info = struct.get_space_group_info()
space_group_symbol = space_group_info[0]
all_spacegroup_symbols = [spglib.get_spacegroup_type(i) for i in range(1, 531)]
symbol = all_spacegroup_symbols[0]
for symbol_i in all_spacegroup_symbols:
if space_group_symbol == symbol_i.international_short:
symbol = symbol_i
space_type = typejudge(symbol.number)
print(f"当前空间群符号为{space_group_symbol},序号为{symbol.number},对应的晶体体系为{space_type}")
divides = []
if space_type == "Hexagonal":
print('当前为六方晶系,暂不处理')
if space_type == "Cubic":
print("当前为立方晶体,三个方向同步")
divides = factorize_to_three_factors(total, "xyz",keep_module=keep_module)
else:
print("为其他晶系,假设三个方向不同")
divides = factorize_to_three_factors(total,keep_module=keep_module)
return divides
def get_first_non_explicit_element(species_str, explict_element= ["Li","Li+"]):
"""
从物种字符串中获取第一个不在explict_element中的元素符号
参数:
species_str: 物种字符串,如 "Li+:0.689, Sc3+:0.311"
explict_element: 需要排除的元素列表,如 ["Li"]
返回:
str: 第一个符合条件的元素符号,如 "Sc"
如果没有找到则返回空字符串 ""
"""
if not species_str.strip():
return ""
# 分割字符串获取各个物种部分
species_parts = [part.strip() for part in species_str.split(",") if part.strip()]
for part in species_parts:
# 提取元素符号(去掉电荷和占据数部分)
element_with_charge = part.split(":")[0].strip()
# 提取纯元素符号(去掉数字和特殊符号)
pure_element = ''.join([c for c in element_with_charge if c.isalpha()])
if pure_element not in explict_element:
return pure_element
return ""
def calculate_expansion_factor(Occupation_list,calculate_type='high'):
"""
计算Occupation_list的扩大倍数支持不同精度模式
参数:
Occupation_list: List[Dict], 每个字典包含:
{
"occupation": float,
"atom_serial": List[int],
"numerator": None,
"denominator": None
}
calculate_type: str, 计算精度模式 ('high', 'normal', 'low')
- high: 精确分数(默认)
- normal: 分母≤100的最接近分数
- low: 分母≤10的最接近分数
返回:
int: 扩大倍数(所有分母的最小公倍数)
List[Dict]: 更新后的Occupation_list包含分子和分母
"""
if not Occupation_list:
return 1, []
# Step 1: 根据精度要求计算分数
for entry in Occupation_list:
occu = entry["occupation"]
if calculate_type == 'high':
# 高精度模式 - 使用精确分数
fraction = Fraction(occu).limit_denominator()
elif calculate_type == 'normal':
# 普通精度 - 分母≤100
fraction = Fraction(occu).limit_denominator(100)
elif calculate_type == 'low':
# 低精度 - 分母≤10
fraction = Fraction(occu).limit_denominator(10)
elif calculate_type == 'very low':
# 低精度 - 分母≤10
fraction = Fraction(occu).limit_denominator(5)
else:
raise ValueError("calculate_type必须是'high', 'normal''low'")
entry["numerator"] = fraction.numerator
entry["denominator"] = fraction.denominator
# Step 2: 计算所有分母的最小公倍数
denominators = [entry["denominator"] for entry in Occupation_list]
lcm = reduce(lambda a, b: a * b // math.gcd(a, b), denominators, 1)
# Step 3: 统一分母
for entry in Occupation_list:
denominator = entry["denominator"]
entry["numerator"] = entry["numerator"] * (lcm // denominator)
entry["denominator"] = lcm
return lcm, Occupation_list
def get_occu(s_str,explict_element):
'''
这里暂时不考虑无化合价的情况
Args:
s_str:
Returns:
'''
if not s_str.strip():
return {}
pattern = r'([A-Za-z0-9+-]+):([0-9.]+)'
matches = re.findall(pattern, s_str)
result = {}
for species, occu in matches:
try:
if species not in explict_element:
return occu
except ValueError:
continue # 忽略无效数字
return 1
def process_cif_file(struct, explict_element=["Li", "Li+"]):
"""
统计结构中各原子的occupation情况忽略occupation=1.0的原子)并分类
参数:
struct: Structure对象 (从CIF文件读取)
返回:
List[Dict]: Occupation_list每个字典格式为:
{
"occupation": list, # 占据值不为1.0
"atom_serial": List[int], # 原子序号列表
"numerator": None, # 预留分子
"denominator": None # 预留分母
"split":list[string]#对应的值
}
"""
if not isinstance(struct, Structure):
raise TypeError("输入必须为pymatgen的Structure对象")
occupation_dict = defaultdict(list)
# 用于记录每个occupation对应的元素列表
split_dict = {}
for i, site in enumerate(struct):
# 获取当前原子的occupation默认为1.0
occu = get_occu(site.species_string, explict_element)
# 忽略occupation=1.0的原子
if occu != 1.0:
if site.species.chemical_system not in explict_element:
occupation_dict[occu].append(i + 1) # 原子序号从1开始计数
# 提取元素名称列表
elements = []
if ':' in site.species_string:
# 格式如 'S:0.494, Cl:0.506' 或 'S2-:0.494, Cl-:0.506'
parts = site.species_string.split(',')
for part in parts:
# 提取冒号前的部分并去除前后空格
element_with_valence = part.strip().split(':')[0].strip()
# 从带有价态的元素符号中提取纯元素符号(只保留元素符号部分)
# 元素符号通常是一个大写字母,可能后跟一个小写字母
import re
element_match = re.match(r'([A-Z][a-z]?)', element_with_valence)
if element_match:
element = element_match.group(1)
elements.append(element)
else:
# 只有一个元素,也需要处理可能的价态
import re
element_match = re.match(r'([A-Z][a-z]?)', site.species_string)
if element_match:
elements = [element_match.group(1)]
# 存储该occupation对应的元素列表
split_dict[occu] = elements
# 转换为要求的输出格式
Occupation_list = [
{
"occupation": occu,
"atom_serial": serials,
"numerator": None,
"denominator": None,
"split": split_dict.get(occu, []) # 添加split字段
}
for occu, serials in occupation_dict.items()
]
return Occupation_list
def merge_structures(structure_list, merge_dict):
"""
按指定方向合并多个结构
参数:
structure_list: List[Structure], 待合并的结构列表(所有结构必须具有相同的晶格)
merge_dict: Dict[str, int], 指定各方向的合并次数(如 {"x":1, "y":1, "z":2}
返回:
Structure: 合并后的新结构
"""
if not structure_list:
raise ValueError("结构列表不能为空")
# 检查所有结构是否具有相同的晶格
ref_lattice = structure_list[0].lattice
for s in structure_list[1:]:
if not np.allclose(s.lattice.matrix, ref_lattice.matrix):
raise ValueError("所有结构的晶格必须相同")
# 计算总合并次数
total_merge = merge_dict.get("x", 1) * merge_dict.get("y", 1) * merge_dict.get("z", 1)
if len(structure_list) != total_merge:
raise ValueError(f"结构数量({len(structure_list)})与合并次数({total_merge})不匹配")
# 获取参考结构的晶格参数
a, b, c = ref_lattice.abc
alpha, beta, gamma = ref_lattice.angles
# 计算新晶格尺寸
new_a = a * merge_dict.get("x", 1)
new_b = b * merge_dict.get("y", 1)
new_c = c * merge_dict.get("z", 1)
new_lattice = Lattice.from_parameters(new_a, new_b, new_c, alpha, beta, gamma)
# 合并所有原子
all_sites = []
for i, structure in enumerate(structure_list):
# 计算当前结构在合并后的偏移量
x_offset = (i // (merge_dict.get("y", 1) * merge_dict.get("z", 1))) % merge_dict.get("x", 1)
y_offset = (i // merge_dict.get("z", 1)) % merge_dict.get("y", 1)
z_offset = i % merge_dict.get("z", 1)
# 对每个原子应用偏移
for site in structure:
coords = site.frac_coords.copy()
coords[0] = (coords[0] + x_offset) / merge_dict.get("x", 1)
coords[1] = (coords[1] + y_offset) / merge_dict.get("y", 1)
coords[2] = (coords[2] + z_offset) / merge_dict.get("z", 1)
all_sites.append({"species": site.species, "coords": coords})
# 创建新结构
return Structure(new_lattice, [site["species"] for site in all_sites], [site["coords"] for site in all_sites])
def generate_structure_list(base_structure,occupation_list,explict_element=["Li","Li+"]):
if not occupation_list:
return [base_structure.copy()]
lcm = occupation_list[0]["denominator"]
structure_list = [base_structure.copy() for _ in range(lcm)]
for entry in occupation_list:
numerator = entry["numerator"]
denominator = entry["denominator"]
atom_indices = entry["atom_serial"] # 注意原子序号从1开始
for atom_idx in atom_indices:
occupancy_dict = mark_atoms_randomly(numerator=numerator,denominator=denominator)
original_site = base_structure.sites[atom_idx - 1]
element = get_first_non_explicit_element(original_site.species_string,explict_element)
for copy_idx ,occupy in occupancy_dict.items():
structure_list[copy_idx].remove_sites([atom_idx-1])
oxi_state= extract_oxi_state(original_site.species_string,element)
if len(entry["split"])==1:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
else:
species_dict = {Species("Li", 1.0):0.0}
new_site = PeriodicSite(
species = species_dict,
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
else:
if occupy:
new_site = PeriodicSite(
species=Species(element, oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
else:
new_site = PeriodicSite(
species=Species(entry['split'][1], oxi_state),
coords=original_site.frac_coords,
lattice=structure_list[copy_idx].lattice,
to_unit_cell=True,
label=original_site.label
)
structure_list[copy_idx].sites.insert(atom_idx - 1, new_site)
return structure_list
def expansion(input_file,output_folder,keep_number,calculate_type='high',keep_module=None):
structure_origin = Structure.from_file(input_file)
lmp,oc_list = calculate_expansion_factor(process_cif_file(structure_origin),calculate_type=calculate_type)
strategy = strategy_divide(structure_origin,lmp,keep_module)
st_list = generate_structure_list(structure_origin,oc_list)
# 获取基础文件名(不含路径和扩展名)
base_name = os.path.splitext(os.path.basename(input_file))[0]
mergeds=[]
names=[]
if len(strategy)< keep_number:
keep_number = len(strategy)
for index in range(keep_number):
merged = merge_structures(st_list, strategy[index])
suffix = "x{}y{}z{}".format(
strategy[index]["x"],
strategy[index]["y"],
strategy[index]["z"]
)
output_filename=''
if keep_module=='classify':
print(f"{base_name}采用扩展方式为{suffix}")
output_filename=f"{base_name}.cif"
elif keep_module=='random':
print(f"{base_name}采用扩展方式为{suffix}")
output_filename=f"{base_name}-{suffix}.cif"
else:
output_filename = f"{base_name}-{suffix}.cif"
output_path = os.path.join(output_folder, output_filename)
merged.to(filename=output_path, fmt="cif")
print(f"Saved: {output_path}")
if keep_module=='classify':
return merged
if keep_module=='random':
mergeds.append(merged)
names.append(output_filename)
return mergeds,names
if __name__ == "__main__":
#expansion("../data/tmp/36.cif","../data/tmp",1,calculate_type='low')
expansion("../data/input_ClBr_set/36.cif", "../data/tmp", 3, calculate_type='low',keep_module='random')
#expansion("../data/input/1234.cif", "../data/input/output", 1, calculate_type='low',keep_module='classify')
# s1 = Structure.from_file("../data/input_pre/mp-6783.cif")
# s2 = Structure.from_file("../data/input_pre/ICSD_1234.cif")
# print(process_cif_file(s2))
# lmp,oc_list=calculate_expansion_factor(process_cif_file(s2))
# print(oc_list)
# strategy = strategy_divide(s2,lmp)
# print(strategy)
# st_list=generate_structure_list(s2,oc_list)
# merged = merge_structures(st_list,strategy[0])
# # merged = merge_structures([s1, s2], {"x": 1, "y": 1, "z": 2})
# merged.to("merged.cif", "cif") # 保存合并后的结构