Skip to content

Commit

Permalink
Added secure cleanup for most functions
Browse files Browse the repository at this point in the history
- Moved 'secure_cleanup' in 'wallet_client.py' to the DataManipulation class in 'cryptographic_util.py'
- Removed redundant 'secure_delete' function from 'wallet_client.py' because it was already in the DataManipulation of 'cryptographic_util.py'
- Most functions will now call the secure_cleanup method in the DataManipulation class whenever there is an error, return statment, or basically any exit point. This ensures that any (potential) sensitive variables are deleted but also purged from memory. This minimizes the potential risks associated with sensitive data lingering in memory. However proformance is slower as a result.
- Fixed the path to the denaro qr logo in the 'handle_new_encrypted_wallet'. It was causeing the script to error out when generating a new encrypted wallet with 2-Factor Authentication.
  • Loading branch information
The-Sycorax committed Oct 19, 2023
1 parent 7974c5f commit 5d03370
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 141 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ __pycache__/
*.py[cod]
env
denaro/wallet/wallets
wallets
wallets
testing.sh
80 changes: 63 additions & 17 deletions denaro/key_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from math import ceil
from datetime import datetime, timezone
from typing import Union
import os

# Importing third-party libraries
import base58
Expand All @@ -18,6 +19,12 @@
from icecream import ic
import binascii

# Get the absolute path of the directory containing the current script.
dir_path = os.path.dirname(os.path.realpath(__file__))
# Insert folder paths for modules
sys.path.insert(0, dir_path + "/wallet")
from denaro.wallet.cryptographic_util import DataManipulation

# Custom print function definition
_print = print # Saving the original print function for later use

Expand All @@ -36,7 +43,9 @@ def log(s):
Parameters:
s (str): Message to log
"""
logging.getLogger('denaro').info(s) # Logging the message under the 'denaro' namespace
# Logging the message under the 'denaro' namespace
logging.getLogger('denaro').info(s)
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None])

# Configure Icecream for custom logging
ic.configureOutput(outputFunction=log) # Redirecting icecream output to custom log function
Expand All @@ -52,7 +61,9 @@ def get_json(obj):
dict: Object as a dictionary.
"""
# Convert object to JSON and then deserialize it to dictionary
return json.loads(json.dumps(obj, default=lambda o: getattr(o, 'as_dict', getattr(o, '__dict__', str(o)))))
result = json.loads(json.dumps(obj, default=lambda o: getattr(o, 'as_dict', getattr(o, '__dict__', str(o)))))
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def timestamp():
"""
Expand All @@ -62,7 +73,9 @@ def timestamp():
int: Current timestamp in UTC timezone.
"""
# Getting current time, setting it to UTC and returning its timestamp
return int(datetime.now(timezone.utc).replace(tzinfo=timezone.utc).timestamp())
result = int(datetime.now(timezone.utc).replace(tzinfo=timezone.utc).timestamp())
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def sha256(message: Union[str, bytes]):
"""
Expand All @@ -78,7 +91,9 @@ def sha256(message: Union[str, bytes]):
if isinstance(message, str):
message = bytes.fromhex(message)
# Calculate SHA-256 hash and return it as a hexadecimal string
return hashlib.sha256(message).hexdigest()
result = hashlib.sha256(message).hexdigest()
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def byte_length(i: int):
"""
Expand All @@ -91,7 +106,9 @@ def byte_length(i: int):
int: Byte length of the integer.
"""
# Calculate byte length using bit length and ceiling function
return ceil(i.bit_length() / 8.0)
result = ceil(i.bit_length() / 8.0)
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def normalize_block(block) -> dict:
"""
Expand All @@ -109,6 +126,7 @@ def normalize_block(block) -> dict:
block['address'] = block['address'].strip(' ')
# Convert and normalize the 'timestamp' field to UTC timestamp
block['timestamp'] = int(block['timestamp'].replace(tzinfo=timezone.utc).timestamp())
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not block])
return block

def x_to_y(x: int, is_odd: bool = False):
Expand All @@ -129,7 +147,9 @@ def x_to_y(x: int, is_odd: bool = False):
# Compute the square root of y^2 modulo p
y_res, y_mod = mod_sqrt(y2, p)
# Return either y_res or y_mod based on whether y should be odd
return y_res if y_res % 2 == is_odd else y_mod
result = y_res if y_res % 2 == is_odd else y_mod
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

class AddressFormat(Enum):
"""
Expand All @@ -151,13 +171,18 @@ def bytes_to_point(point_bytes: bytes) -> Point:
# If the byte length is 64, it's a full point (x and y coordinates)
if len(point_bytes) == 64:
x, y = int.from_bytes(point_bytes[:32], ENDIAN), int.from_bytes(point_bytes[32:], ENDIAN) # Extract x and y from bytes
return Point(x, y, CURVE) # Return as Point object
result = Point(x, y, CURVE) # Return as Point object
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
# If the byte length is 33, it's a compressed point
elif len(point_bytes) == 33:
specifier = point_bytes[0] # First byte is the specifier for odd/even y-coordinate
x = int.from_bytes(point_bytes[1:], ENDIAN) # Extract x from the bytes
return Point(x, x_to_y(x, specifier == 43)) # Compute y and return as Point object
result = Point(x, x_to_y(x, specifier == 43)) # Compute y and return as Point object
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
else:
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None])
# Unsupported byte length
raise NotImplementedError()

Expand All @@ -178,9 +203,12 @@ def bytes_to_string(point_bytes: bytes) -> str:
elif len(point_bytes) == 33:
address_format = AddressFormat.COMPRESSED # Compressed format
else:
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None])
# Unsupported byte length
raise NotImplementedError()
return point_to_string(point, address_format) # Convert point to string based on the determined format
result = point_to_string(point, address_format) # Convert point to string based on the determined format
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def point_to_bytes(point: Point, address_format: AddressFormat = AddressFormat.FULL_HEX) -> bytes:
"""
Expand All @@ -195,11 +223,16 @@ def point_to_bytes(point: Point, address_format: AddressFormat = AddressFormat.F
"""
# If full hexadecimal format is chosen
if address_format is AddressFormat.FULL_HEX:
return point.x.to_bytes(32, byteorder=ENDIAN) + point.y.to_bytes(32, byteorder=ENDIAN)
result = point.x.to_bytes(32, byteorder=ENDIAN) + point.y.to_bytes(32, byteorder=ENDIAN)
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
# If compressed format is chosen
elif address_format is AddressFormat.COMPRESSED:
return string_to_bytes(point_to_string(point, AddressFormat.COMPRESSED))
result = string_to_bytes(point_to_string(point, AddressFormat.COMPRESSED))
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
else:
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None])
# Raise an exception for unsupported formats
raise NotImplementedError()

Expand All @@ -218,13 +251,18 @@ def point_to_string(point: Point, address_format: AddressFormat = AddressFormat.
# For full hexadecimal format
if address_format is AddressFormat.FULL_HEX:
point_bytes = point_to_bytes(point) # Convert point to bytes
return point_bytes.hex() # Convert bytes to hexadecimal string
result = point_bytes.hex() # Convert bytes to hexadecimal string
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
# For compressed format
elif address_format is AddressFormat.COMPRESSED:
# Convert point to Base58 string
address = base58.b58encode((42 if y % 2 == 0 else 43).to_bytes(1, ENDIAN) + x.to_bytes(32, ENDIAN))
return address if isinstance(address, str) else address.decode('utf-8') # Ensure the result is a string
result = address if isinstance(address, str) else address.decode('utf-8') # Ensure the result is a string
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result
else:
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None])
# Unsupported format
raise NotImplementedError()

Expand All @@ -241,9 +279,11 @@ def string_to_bytes(string: str) -> bytes:
try:
# Try to convert from hexadecimal to bytes
point_bytes = bytes.fromhex(string)
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not point_bytes])
except ValueError:
# If not hexadecimal, assume it's Base58 and decode it
point_bytes = base58.b58decode(string)
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not point_bytes])
return point_bytes

def string_to_point(string: str):
Expand All @@ -257,7 +297,9 @@ def string_to_point(string: str):
Point: The converted ECDSA point.
"""
# Convert the string to bytes and then to an ECDSA point
return bytes_to_point(string_to_bytes(string))
result = bytes_to_point(string_to_bytes(string))
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def hex_to_point(x_hex: str, y_hex: str, curve_obj):
"""
Expand All @@ -273,7 +315,9 @@ def hex_to_point(x_hex: str, y_hex: str, curve_obj):
"""
x_int = int(x_hex, 16) # Convert x from hex to integer
y_int = int(y_hex, 16) # Convert y from hex to integer
return Point(x_int, y_int, curve_obj) # Create and return the ECDSA point
result = Point(x_int, y_int, curve_obj) # Create and return the ECDSA point
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def private_to_public_key_fastecdsa(private_key_hex):
"""
Expand All @@ -298,7 +342,9 @@ def private_to_public_key_fastecdsa(private_key_hex):
compressed_public_key = prefix + format(public_point.x, '064x')

# Return the public point and its compressed representation
return public_point, compressed_public_key
result = public_point, compressed_public_key
DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result

def generate(mnemonic_phrase=None, passphrase=None, index=0, deterministic=False, fields=None):
"""
Expand Down Expand Up @@ -367,5 +413,5 @@ def generate(mnemonic_phrase=None, passphrase=None, index=0, deterministic=False
result["public_key"] = public_key_hex
if "address" in fields:
result["address"] = address

DataManipulation.secure_cleanup([var for var in locals().values() if var is not None and var is not result])
return result # Return the generated information as a dictionary
Loading

0 comments on commit 5d03370

Please sign in to comment.