diff --git a/src/saml2/config.py b/src/saml2/config.py index 1f423789a..8cb38617a 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -7,6 +7,7 @@ import os import re import sys +from functools import partial import six @@ -21,6 +22,7 @@ from saml2.mdstore import MetadataStore from saml2.saml import NAME_FORMAT_URI from saml2.virtual_org import VirtualOrg +from saml2.utility.config import RuleValidator, should_warning, must_error logger = logging.getLogger(__name__) @@ -542,6 +544,9 @@ def service_per_endpoint(self, context=None): res[endp] = (service, binding) return res + def validate(self): + pass + class SPConfig(Config): def_context = "sp" @@ -571,6 +576,47 @@ def ecp_endpoint(self, ipaddress): return None +class eIDASConfig(Config): + @classmethod + def assert_not_declared(cls, error_signal): + return (lambda x: x is None, + partial(error_signal, message="not be declared")) + + @classmethod + def assert_declared(cls, error_signal): + return (lambda x: x is not None, + partial(error_signal, message="be declared")) + + +class eIDASSPConfig(SPConfig, eIDASConfig): + def validate(self): + validators = [ + RuleValidator( + "single_logout_service", + self._sp_endpoints.get("single_logout_service"), + *self.assert_not_declared(should_warning) + ), + RuleValidator( + "artifact_resolution_service", + self._sp_endpoints.get("artifact_resolution_service"), + *self.assert_not_declared(should_warning) + ), + RuleValidator( + "manage_name_id_service", + self._sp_endpoints.get("manage_name_id_service"), + *self.assert_not_declared(should_warning) + ), + RuleValidator( + "KeyDescriptor", + self.cert_file or self.encryption_keypairs, + *self.assert_declared(must_error) + ) + ] + + for validator in validators: + validator.validate() + + class IdPConfig(Config): def_context = "idp" @@ -578,6 +624,10 @@ def __init__(self): Config.__init__(self) +class eIDASIdPConfig(IdPConfig): + pass + + def config_factory(_type, config): """ diff --git a/src/saml2/utility/__init__.py b/src/saml2/utility/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/saml2/utility/config.py b/src/saml2/utility/config.py new file mode 100644 index 000000000..5b9d4bb4e --- /dev/null +++ b/src/saml2/utility/config.py @@ -0,0 +1,43 @@ +import logging + + +logger = logging.getLogger(__name__) + + +class ConfigValidationError(Exception): + pass + + +class RuleValidator(object): + def __init__(self, element_name, element_value, validator, error_signal): + """ + :param element_name: the name of the element that will be + validated + :param element_value: function to be called + with config as parameter to fetch an element value + :param validator: function to be called + with a config element value as a parameter + :param error_signal: function to be called + with an element name and value to signal an error (can be a log + function, raise an error etc) + """ + self.element_name = element_name + self.element_value = element_value + self.validator = validator + self.error_signal = error_signal + + def validate(self): + if not self.validator(self.element_value): + self.error_signal(self.element_name) + + +def should_warning(element_name, message): + logger.warning("{element} SHOULD {message}".format( + element=element_name, message=message)) + + +def must_error(element_name, message): + error = "{element} MUST {message}".format( + element=element_name, message=message) + logger.error(error) + raise ConfigValidationError(error) diff --git a/tests/eidas/test_sp.py b/tests/eidas/test_sp.py index 38201508d..005c486bf 100644 --- a/tests/eidas/test_sp.py +++ b/tests/eidas/test_sp.py @@ -1,16 +1,20 @@ +import pytest +import copy +from saml2 import BINDING_HTTP_POST from saml2 import metadata from saml2 import samlp from saml2.client import Saml2Client from saml2.server import Server -from saml2.config import SPConfig +from saml2.config import eIDASSPConfig from eidas.sp_conf import CONFIG +from saml2.utility.config import ConfigValidationError class TestSP: def setup_class(self): self.server = Server("idp_conf") - self.conf = SPConfig() + self.conf = eIDASSPConfig() self.conf.load_file("sp_conf") self.client = Saml2Client(self.conf) @@ -18,6 +22,10 @@ def setup_class(self): def teardown_class(self): self.server.close() + @pytest.fixture(scope="function") + def config(self): + return copy.deepcopy(CONFIG) + def test_authn_request_force_authn(self): req_str = "{0}".format(self.client.create_authn_request( "http://www.example.com/sso", message_id="id1")[-1]) @@ -35,10 +43,10 @@ def test_sp_type_only_in_request(self): assert not any(filter(lambda x: x.tag == "SPType", entd.extensions.extension_elements)) - def test_sp_type_in_metadata(self): - CONFIG["service"]["sp"]["sp_type_in_metadata"] = True - sconf = SPConfig() - sconf.load(CONFIG) + def test_sp_type_in_metadata(self, config): + config["service"]["sp"]["sp_type_in_metadata"] = True + sconf = eIDASSPConfig() + sconf.load(config) custom_client = Saml2Client(sconf) req_str = "{0}".format(custom_client.create_authn_request( @@ -58,5 +66,49 @@ def test_node_country_in_metadata(self): entd.extensions.extension_elements)) -if __name__ == '__main__': - TestSP() +class TestSPConfig: + @pytest.fixture(scope="function") + def raise_error_on_warning(self, monkeypatch): + def r(*args, **kwargs): + raise ConfigValidationError() + monkeypatch.setattr("saml2.utility.config.logger.warning", r) + + @pytest.fixture(scope="function") + def config(self): + return copy.deepcopy(CONFIG) + + def test_singlelogout_declared(self, config, raise_error_on_warning): + config["service"]["sp"]["endpoints"]["single_logout_service"] = \ + [("https://example.com", BINDING_HTTP_POST)] + conf = eIDASSPConfig() + conf.load(config) + + with pytest.raises(ConfigValidationError): + conf.validate() + + def test_artifact_resolution_declared(self, config, raise_error_on_warning): + config["service"]["sp"]["endpoints"]["artifact_resolution_service"] = \ + [("https://example.com", BINDING_HTTP_POST)] + conf = eIDASSPConfig() + conf.load(config) + + with pytest.raises(ConfigValidationError): + conf.validate() + + def test_manage_nameid_service_declared(self, config, raise_error_on_warning): + config["service"]["sp"]["endpoints"]["manage_name_id_service"] = \ + [("https://example.com", BINDING_HTTP_POST)] + conf = eIDASSPConfig() + conf.load(config) + + with pytest.raises(ConfigValidationError): + conf.validate() + + def test_no_keydescriptor(self, config): + del config["cert_file"] + del config["encryption_keypairs"] + conf = eIDASSPConfig() + conf.load(config) + + with pytest.raises(ConfigValidationError): + conf.validate()