Skip to content

Commit

Permalink
Get through attestation
Browse files Browse the repository at this point in the history
  • Loading branch information
tannewt committed Sep 25, 2024
1 parent 391fc0d commit 672c864
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 59 deletions.
146 changes: 123 additions & 23 deletions circuitmatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import hashlib
import pathlib
import json
import os
import struct
import time

import cryptography
import ecdsa

from typing import Optional

Expand All @@ -18,6 +18,18 @@
from . import session
from . import tlv

TEST_CERTS = pathlib.Path(
"/home/tannewt/repos/esp-matter/connectedhomeip/connectedhomeip/credentials/test/attestation/"
)
TEST_PAI_CERT_DER = TEST_CERTS / "Chip-Test-PAI-FFF1-8000-Cert.der"
TEST_PAI_CERT_PEM = TEST_CERTS / "Chip-Test-PAI-FFF1-8000-Cert.pem"
TEST_DAC_CERT_DER = TEST_CERTS / "Chip-Test-DAC-FFF1-8000-0000-Cert.der"
TEST_DAC_CERT_PEM = TEST_CERTS / "Chip-Test-DAC-FFF1-8000-0000-Cert.pem"
TEST_DAC_KEY_DER = TEST_CERTS / "Chip-Test-DAC-FFF1-8000-0000-Key.der"
TEST_DAC_KEY_PEM = TEST_CERTS / "Chip-Test-DAC-FFF1-8000-0000-Key.pem"

TEST_CD_CERT_DER = pathlib.Path("certification_declaration.der")

__version__ = "0.0.0"

# Section 4.11.2
Expand Down Expand Up @@ -192,9 +204,9 @@ def process_counter(self, counter) -> bool:


class MessageCounter:
def __init__(self, starting_value=None):
def __init__(self, starting_value=None, random_source=None):
if starting_value is None:
starting_value = os.urandom(4)
starting_value = random_source.urandom(4)
starting_value = struct.unpack("<I", starting_value)[0]
starting_value >>= 4
starting_value += 1
Expand Down Expand Up @@ -305,7 +317,7 @@ def send(self, message):


class SecureSessionContext:
def __init__(self, socket, local_session_id):
def __init__(self, random_source, socket, local_session_id):
self.session_type = None
"""Records whether the session was established using CASE or PASE."""
self.session_role_initiator = False
Expand All @@ -320,7 +332,7 @@ def __init__(self, socket, local_session_id):
"""Encrypts data in messages sent from the session establishment responder to the initiator."""
self.shared_secret = None
"""Computed during the CASE protocol execution and re-used when CASE session resumption is implemented."""
self.local_message_counter = MessageCounter()
self.local_message_counter = MessageCounter(random_source=random_source)
"""Secure Session Message Counter for outbound messages."""
self.message_reception_state = None
"""Provides tracking for the Secure Session Message Counter of the remote"""
Expand Down Expand Up @@ -736,7 +748,7 @@ def __str__(self):


class SessionManager:
def __init__(self, socket):
def __init__(self, random_source, socket):
persist_path = pathlib.Path("counters.json")
if persist_path.exists():
self.nonvolatile = json.loads(persist_path.read_text())
Expand All @@ -745,17 +757,22 @@ def __init__(self, socket):
self.nonvolatile["check_in_counter"] = None
self.nonvolatile["group_encrypted_data_message_counter"] = None
self.nonvolatile["group_encrypted_control_message_counter"] = None
self.unencrypted_message_counter = MessageCounter()
self.unencrypted_message_counter = MessageCounter(random_source=random_source)
self.group_encrypted_data_message_counter = MessageCounter(
self.nonvolatile["group_encrypted_data_message_counter"]
self.nonvolatile["group_encrypted_data_message_counter"],
random_source=random_source,
)
self.group_encrypted_control_message_counter = MessageCounter(
self.nonvolatile["group_encrypted_control_message_counter"]
self.nonvolatile["group_encrypted_control_message_counter"],
random_source=random_source,
)
self.check_in_counter = MessageCounter(
self.nonvolatile["check_in_counter"], random_source=random_source
)
self.check_in_counter = MessageCounter(self.nonvolatile["check_in_counter"])
self.unsecured_session_context = {}
self.secure_session_contexts = ["reserved"]
self.socket = socket
self.random = random_source

def _increment(self, value):
return (value + 1) % 0xFFFFFFFF
Expand Down Expand Up @@ -836,7 +853,7 @@ def new_context(self):
session_id = self.secure_session_contexts.index(None)

self.secure_session_contexts[session_id] = SecureSessionContext(
self.socket, session_id
self.random, self.socket, session_id
)
return self.secure_session_contexts[session_id]

Expand Down Expand Up @@ -887,41 +904,115 @@ def __init__(self):
self.basic_commissioning_info = basic_commissioning_info

def arm_fail_safe(
self, args: data_model.GeneralCommissioningCluster.ArmFailSafe
self, session, args: data_model.GeneralCommissioningCluster.ArmFailSafe
) -> data_model.GeneralCommissioningCluster.ArmFailSafeResponse:
response = data_model.GeneralCommissioningCluster.ArmFailSafeResponse()
response.ErrorCode = data_model.CommissioningErrorEnum.OK
return response

def set_regulatory_config(
self, args: data_model.GeneralCommissioningCluster.SetRegulatoryConfig
self, session, args: data_model.GeneralCommissioningCluster.SetRegulatoryConfig
) -> data_model.GeneralCommissioningCluster.SetRegulatoryConfigResponse:
response = data_model.GeneralCommissioningCluster.SetRegulatoryConfigResponse()
response.ErrorCode = data_model.CommissioningErrorEnum.OK
return response


class AttestationElements(tlv.Structure):
certification_declaration = tlv.OctetStringMember(0x01, max_length=400)
attestation_nonce = tlv.OctetStringMember(0x02, max_length=32)
timestamp = tlv.IntMember(0x03, signed=False, octets=4)
firmware_information = tlv.OctetStringMember(0x04, max_length=16, optional=True)
"""Used for secure boot. We don't support it."""


class NOCSRElements(tlv.Structure):
csr = tlv.OctetStringMember(0x01, max_length=1024)
CSRNonce = tlv.OctetStringMember(0x02, max_length=32)
# Skip vendor reserved


class NodeOperationalCredentialsCluster(data_model.NodeOperationalCredentialsCluster):
def __init__(self):
self.dac_key = ecdsa.keys.SigningKey.from_der(
TEST_DAC_KEY_DER.read_bytes(), hashfunc=hashlib.sha256
)

self.new_key_for_update = False

def certificate_chain_request(
self, args: data_model.NodeOperationalCredentialsCluster.CertificateChainRequest
self,
session,
args: data_model.NodeOperationalCredentialsCluster.CertificateChainRequest,
) -> data_model.NodeOperationalCredentialsCluster.CertificateChainResponse:
response = (
data_model.NodeOperationalCredentialsCluster.CertificateChainResponse()
)
if args.CertificateType == data_model.CertificateChainTypeEnum.PAI:
print("PAI")
response.Certificate = TEST_PAI_CERT_DER.read_bytes()
elif args.CertificateType == data_model.CertificateChainTypeEnum.DAC:
print("DAC")
response.Certificate = b""
response.Certificate = TEST_DAC_CERT_DER.read_bytes()
return response

def attestation_request(
self, args: data_model.NodeOperationalCredentialsCluster.AttestationRequest
self,
session,
args: data_model.NodeOperationalCredentialsCluster.AttestationRequest,
) -> data_model.NodeOperationalCredentialsCluster.AttestationResponse:
print("attestation")
elements = AttestationElements()
elements.certification_declaration = TEST_CD_CERT_DER.read_bytes()
elements.attestation_nonce = args.AttestationNonce
elements.timestamp = int(time.time())
elements = elements.encode()
print("elements", len(elements), elements[:3].hex(" "))
print(
"challeng",
len(session.attestation_challenge),
session.attestation_challenge[:3].hex(" "),
)
attestation_tbs = elements.tobytes() + session.attestation_challenge
response = data_model.NodeOperationalCredentialsCluster.AttestationResponse()
response.AttestationElements = b""
response.AttestationSignature = b""
response.AttestationElements = elements
response.AttestationSignature = self.dac_key.sign_deterministic(
attestation_tbs,
hashfunc=hashlib.sha256,
sigencode=ecdsa.util.sigencode_string,
)
return response

def csr_request(
self, session, args: data_model.NodeOperationalCredentialsCluster.CsrRequest
) -> data_model.NodeOperationalCredentialsCluster.CsrResponse:
# Section 6.4.6.1
# CSR stands for Certificate Signing Request. A NOCSR is a Node Operational Certificate Signing Request

self.new_key_for_update = args.IsForUpdateNOC

# class CSRRequest(tlv.Structure):
# CSRNonce = tlv.OctetStringMember(0, 32)
# IsForUpdateNOC = tlv.BoolMember(1, optional=True, default=False)

# Generate a new key pair.
new_key_csr = b"TODO"

# Create a CSR to reply back with. Sign it with the new private key.
elements = NOCSRElements()
elements.csr = new_key_csr
elements.CSRNonce = args.CSRNonce
elements = elements.encode()
nocsr_tbs = elements.tobytes() + session.attestation_challenge

# class CSRResponse(tlv.Structure):
# NOCSRElements = tlv.OctetStringMember(0, RESP_MAX)
# AttestationSignature = tlv.OctetStringMember(1, 64)
response = data_model.NodeOperationalCredentialsCluster.CsrResponse()
response.NOCSRElements = elements
response.AttestationSignature = self.dac_key.sign_deterministic(
nocsr_tbs, hashfunc=hashlib.sha256, sigencode=ecdsa.util.sigencode_string
)
return response


Expand All @@ -933,7 +1024,7 @@ def __init__(
random_source,
state_filename,
vendor_id=0xFFF1,
product_id=0,
product_id=0x8000,
):
self.socketpool = socketpool
self.mdns_server = mdns_server
Expand Down Expand Up @@ -963,7 +1054,7 @@ def __init__(
self.socket.bind((UDP_IP, self.UDP_PORT))
self.socket.setblocking(False)

self.manager = SessionManager(self.socket)
self.manager = SessionManager(self.random, self.socket)

print(f"Listening on UDP port {self.UDP_PORT}")

Expand Down Expand Up @@ -1039,10 +1130,10 @@ def get_report(self, cluster, path):
# report.AttributeStatus = astatus
return report

def invoke(self, cluster, path, fields, command_ref):
def invoke(self, session, cluster, path, fields, command_ref):
print("invoke", path)
response = interaction_model.InvokeResponseIB()
cdata = cluster.invoke(path, fields)
cdata = cluster.invoke(session, path, fields)
if cdata is None:
cstatus = interaction_model.CommandStatusIB()
cstatus.CommandPath = path
Expand Down Expand Up @@ -1216,6 +1307,9 @@ def process_packet(self, address, data):
elif protocol_opcode == SecureProtocolOpcode.ICD_CHECK_IN:
print("Received ICD Check-in")
elif message.protocol_id == ProtocolId.INTERACTION_MODEL:
secure_session_context = self.manager.secure_session_contexts[
message.session_id
]
if protocol_opcode == InteractionModelOpcode.READ_REQUEST:
print("Received Read Request")
read_request, _ = interaction_model.ReadRequestMessage.decode(
Expand Down Expand Up @@ -1272,7 +1366,12 @@ def process_packet(self, address, data):
cluster = self._endpoints[endpoint][path.Cluster]
path.Endpoint = endpoint
invoke_responses.append(
self.invoke(cluster, path, invoke.CommandFields)
self.invoke(
secure_session_context,
cluster,
path,
invoke.CommandFields,
)
)
else:
print(f"Cluster 0x{path.Cluster:02x} not found")
Expand All @@ -1281,6 +1380,7 @@ def process_packet(self, address, data):
cluster = self._endpoints[path.Endpoint][path.Cluster]
invoke_responses.append(
self.invoke(
secure_session_context,
cluster,
path,
invoke.CommandFields,
Expand Down
Loading

0 comments on commit 672c864

Please sign in to comment.