Skip to content

Commit

Permalink
XspectraCrystalWorkChain: Enable Symmetry Data Inputs (aiidateam#1028)
Browse files Browse the repository at this point in the history
Adds an input namespace for the `XspectraCrystalWorkChain` which
allows the user to define the spacegroup and equivalent sites data
for the incoming structure, thus instructing the WorkChain to generate
structures and run calculations for only the sites specified.

Changes:
* Adds the `symmetry_data` input namespace to `XspectraCrystalWorkChain`,
  which the `WorkChain` will use to generate structures and set the list
  of polarisation vectors to calculate.
* Adds input validation steps for the symmetry data to check for
  required information and for entries which may cause a crash, though
  does not check for issues beyond this in order to maximise flexibility
  of use.
* Fixes an oversight in `get_xspectra_structures` where the `supercell`
  entry was not returned to the outputs when external symmetry data were
  provided by the user.
  • Loading branch information
PNOGillespie authored and bastonero committed Jan 6, 2025
1 parent 2e9480d commit e3641b2
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st
new_supercell = get_supercell_result['new_supercell']
output_params['supercell_factors'] = multiples

result['supercell'] = new_supercell
output_params['supercell_num_sites'] = len(new_supercell.sites)
output_params['supercell_cell_matrix'] = new_supercell.cell
output_params['supercell_cell_lengths'] = new_supercell.cell_lengths
Expand Down
105 changes: 81 additions & 24 deletions src/aiida_quantumespresso/workflows/xspectra/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Uses QuantumESPRESSO pw.x and xspectra.x.
"""
from aiida import orm
from aiida.common import AttributeDict, ValidationError
from aiida.common import AttributeDict
from aiida.engine import ToContext, WorkChain, if_
from aiida.orm import UpfData as aiida_core_upf
from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory
Expand Down Expand Up @@ -173,6 +173,19 @@ def define(cls, spec):
help=('Input namespace to provide core wavefunction inputs for each element. Must follow the format: '
'``core_wfc_data__{symbol} = {node}``')
)
spec.input_namespace(
'symmetry_data',
valid_type=(orm.Dict, orm.Int),
dynamic=True,
required=False,
help=(
'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will '
'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known '
'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be '
'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_<site_index>". '
'See docstring of `get_xspectra_structures` for more information about inputs.'
)
)
spec.inputs.validator = cls.validate_inputs
spec.outline(
cls.setup,
Expand Down Expand Up @@ -370,7 +383,7 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements


@staticmethod
def validate_inputs(inputs, _):
def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements
"""Validate the inputs before launching the WorkChain."""
structure = inputs['structure']
kinds_present = [kind.name for kind in structure.kinds]
Expand All @@ -382,54 +395,92 @@ def validate_inputs(inputs, _):
if element not in elements_present:
extra_elements.append(element)
if len(extra_elements) > 0:
raise ValidationError(
return (
f'Some elements in ``elements_list`` {extra_elements} do not exist in the'
f' structure provided {elements_present}.'
)

abs_atom_marker = inputs['abs_atom_marker'].value
if abs_atom_marker in kinds_present:
raise ValidationError(
return (
f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the '
f'input structure ({kinds_present}).'
)

if not inputs['core']['get_powder_spectrum'].value:
raise ValidationError(
return (
'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.'
)

if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs:
raise ValidationError(
return (
'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.'
)

if 'core_wfc_data' in inputs:
core_wfc_data_list = sorted(inputs['core_wfc_data'].keys())
if core_wfc_data_list != absorbing_elements_list:
raise ValidationError(
return (
f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of'
f' absorbing elements ({absorbing_elements_list})'
)
else:
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except Exception as exc:
raise ValidationError(
'The core wavefunction data file is not of the correct format'
) from exc
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
raise ValidationError(
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)
empty_core_wfc_data = []
for key, value in inputs['core_wfc_data'].items():
header_line = value.get_content()[:40]
try:
num_core_states = int(header_line.split(' ')[5])
except: # pylint: disable=bare-except
return (
'The core wavefunction data file is not of the correct format'
) # pylint: enable=bare-except
if num_core_states == 0:
empty_core_wfc_data.append(key)
if len(empty_core_wfc_data) > 0:
return (
f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain '
'any wavefunction data.'
)

if 'symmetry_data' in inputs:
spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value
equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict()
if spacegroup_number <= 0 or spacegroup_number >= 231:
return (
f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).'
)

input_elements = []
required_keys = sorted(['symbol', 'multiplicity', 'kind_name', 'site_index'])
invalid_entries = []
# We check three things here: (1) are there any site indices which are outside of the possible
# range of site indices (2) do we have all the required keys for each entry,
# and (3) is there a mismatch between `absorbing_elements_list` and the elements specified
# in the entries of `equivalent_sites_data`. These checks are intended only to avoid a crash.
# We assume otherwise that the user knows what they're doing and has set everything else
# to their preferences correctly.
for site_label, value in equivalent_sites_data.items():
if not set(required_keys).issubset(set(value.keys())) :
invalid_entries.append(site_label)
elif value['symbol'] not in input_elements:
input_elements.append(value['symbol'])
if value['site_index'] < 0 or value['site_index'] >= len(structure.sites):
return (
f'The site index for {site_label} ({value["site_index"]}) is outside the range of '
+ f'sites within the structure (0-{len(structure.sites) -1}).'
)

if len(invalid_entries) != 0:
return (
f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}'
)

sorted_input_elements = sorted(input_elements)
if sorted_input_elements != absorbing_elements_list:
return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) '
f'do not match the list of absorbing elements ({absorbing_elements_list})')


# pylint: enable=too-many-return-statements
def setup(self):
"""Set required context variables."""
if 'core_wfc_data' in self.inputs.keys():
Expand Down Expand Up @@ -489,6 +540,12 @@ def get_xspectra_structures(self):
if 'spglib_settings' in self.inputs:
inputs['spglib_settings'] = self.inputs.spglib_settings

if 'symmetry_data' in self.inputs:
inputs['parse_symmetry'] = orm.Bool(False)
input_sym_data = self.inputs.symmetry_data
inputs['equivalent_sites_data'] = input_sym_data['equivalent_sites_data']
inputs['spacegroup_number'] = input_sym_data['spacegroup_number']

if 'relax' in self.inputs:
result = get_xspectra_structures(self.ctx.optimized_structure, **inputs)
else:
Expand Down

0 comments on commit e3641b2

Please sign in to comment.