Skip to content

Commit

Permalink
syscalls: Implement validations automatically for all fat pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
sysheap committed Jan 4, 2025
1 parent 9c128ba commit e50068b
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 77 deletions.
1 change: 1 addition & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#![feature(macro_metavar_expr_concat)]
#![feature(auto_traits)]
#![feature(negative_impls)]
#![feature(str_from_raw_parts)]

pub mod array_vec;
pub mod big_endian;
Expand Down
40 changes: 21 additions & 19 deletions common/src/pointer.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
/// This trait both abstracts *const T and *mut T
/// It can be used if a method can receive both types of pointers
pub trait Pointer<T>: Clone + Copy {
pub trait Pointer: Clone + Copy {
type Pointee;

fn as_raw(&self) -> usize;
fn as_pointer(ptr: usize) -> Self;
}

impl<T> Pointer<T> for *const T {
impl<T> Pointer for *const T {
type Pointee = T;

fn as_raw(&self) -> usize {
*self as usize
}
Expand All @@ -15,7 +19,9 @@ impl<T> Pointer<T> for *const T {
}
}

impl<T> Pointer<T> for *mut T {
impl<T> Pointer for *mut T {
type Pointee = T;

fn as_raw(&self) -> usize {
*self as usize
}
Expand All @@ -30,50 +36,46 @@ pub struct FatPointer<Ptr> {
len: usize,
}

impl<Ptr: Clone + Copy> FatPointer<Ptr> {
fn new(ptr: Ptr, len: usize) -> Self {
impl<Ptr: Pointer> FatPointer<Ptr> {
pub fn new(ptr: Ptr, len: usize) -> Self {
Self { ptr, len }
}

pub fn ptr(&self) -> Ptr {
self.ptr
}

#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
self.len
}
}

pub trait AsFatPointer {
type T;
fn as_fat_pointer(&self) -> FatPointer<*const Self::T>;
}

pub trait AsFatPointerMut {
type T;
fn as_fat_pointer_mut(&mut self) -> FatPointer<*mut Self::T>;
type Ptr;
fn to_fat_pointer(self) -> FatPointer<Self::Ptr>;
}

impl AsFatPointer for &str {
type T = u8;
type Ptr = *const u8;

fn as_fat_pointer(&self) -> FatPointer<*const u8> {
fn to_fat_pointer(self) -> FatPointer<Self::Ptr> {
FatPointer::new(self.as_ptr(), self.len())
}
}

impl<T> AsFatPointer for &[T] {
type T = T;
type Ptr = *const T;

fn as_fat_pointer(&self) -> FatPointer<*const T> {
fn to_fat_pointer(self) -> FatPointer<Self::Ptr> {
FatPointer::new(self.as_ptr(), self.len())
}
}

impl<T> AsFatPointerMut for &mut [T] {
type T = T;
impl<T> AsFatPointer for &mut [T] {
type Ptr = *mut T;

fn as_fat_pointer_mut(&mut self) -> FatPointer<*mut T> {
fn to_fat_pointer(self) -> FatPointer<Self::Ptr> {
FatPointer::new(self.as_mut_ptr(), self.len())
}
}
36 changes: 31 additions & 5 deletions common/src/ref_conversion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pointer::{AsFatPointer, AsFatPointerMut, FatPointer};
use crate::pointer::{AsFatPointer, FatPointer};

auto trait IsValue {}

Expand All @@ -11,6 +11,8 @@ impl<T> !IsValue for &mut T {}
pub trait RefToPointer<T> {
type Out;
fn to_pointer_if_ref(self) -> Self::Out;
#[allow(clippy::missing_safety_doc)]
unsafe fn to_ref_if_pointer(input: Self::Out) -> Self;
}

impl<T: IsValue> RefToPointer<T> for T {
Expand All @@ -19,6 +21,10 @@ impl<T: IsValue> RefToPointer<T> for T {
fn to_pointer_if_ref(self) -> Self::Out {
self
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
input
}
}

impl<T> RefToPointer<T> for &T {
Expand All @@ -27,6 +33,10 @@ impl<T> RefToPointer<T> for &T {
fn to_pointer_if_ref(self) -> Self::Out {
self
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
&*input
}
}

impl<T> RefToPointer<T> for &mut T {
Expand All @@ -35,28 +45,44 @@ impl<T> RefToPointer<T> for &mut T {
fn to_pointer_if_ref(self) -> Self::Out {
self
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
&mut *input
}
}

impl RefToPointer<&str> for &str {
type Out = FatPointer<*const u8>;

fn to_pointer_if_ref(self) -> Self::Out {
self.as_fat_pointer()
self.to_fat_pointer()
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
core::str::from_raw_parts(input.ptr(), input.len())
}
}

impl RefToPointer<&[u8]> for &[u8] {
type Out = FatPointer<*const u8>;

fn to_pointer_if_ref(self) -> Self::Out {
self.as_fat_pointer()
self.to_fat_pointer()
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
core::slice::from_raw_parts(input.ptr(), input.len())
}
}

impl RefToPointer<&mut [u8]> for &mut [u8] {
type Out = FatPointer<*mut u8>;

fn to_pointer_if_ref(mut self) -> Self::Out {
self.as_fat_pointer_mut()
fn to_pointer_if_ref(self) -> Self::Out {
self.to_fat_pointer()
}

unsafe fn to_ref_if_pointer(input: Self::Out) -> Self {
core::slice::from_raw_parts_mut(input.ptr(), input.len())
}
}
2 changes: 1 addition & 1 deletion common/src/syscalls/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ macro_rules! syscalls {
$(fn $name$(<$lt>)?(&mut self, $($arg_name: Self::ArgWrapper<$arg_ty>),*) -> $ret;)*

/// Validate a pointer such that it is a valid userspace pointer
fn validate_and_translate_pointer<T, PTR: $crate::pointer::Pointer<T>>(&self, ptr: PTR) -> Option<PTR>;
fn validate_and_translate_pointer<PTR: $crate::pointer::Pointer>(&self, ptr: PTR) -> Option<PTR>;

fn dispatch(&mut self, nr: usize, arg: usize, ret: usize) -> $crate::syscalls::SyscallStatus {
use $crate::syscalls::SyscallStatus;
Expand Down
10 changes: 5 additions & 5 deletions kernel/src/memory/page_tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,14 +484,14 @@ impl RootPageTableHolder {
.is_some_and(|entry| entry.get_validity() && entry.get_user_mode_accessible())
}

pub fn is_valid_userspace_fat_ptr<T>(
pub fn is_valid_userspace_fat_ptr<PTR: Pointer>(
&self,
ptr: impl Pointer<T>,
ptr: PTR,
len: usize,
writable: bool,
) -> bool {
let start = ptr.as_raw();
let end = start + (core::mem::size_of::<T>() * len);
let end = start + (core::mem::size_of::<PTR::Pointee>() * len);
// We only need to check for each PAGE_SIZE step if it is mapped
for addr in (start..end).step_by(PAGE_SIZE) {
let entry = unwrap_or_return!(self.get_page_table_entry_for_address(addr), false);
Expand All @@ -509,11 +509,11 @@ impl RootPageTableHolder {
true
}

pub fn is_valid_userspace_ptr<T>(&self, ptr: impl Pointer<T>, writable: bool) -> bool {
pub fn is_valid_userspace_ptr(&self, ptr: impl Pointer, writable: bool) -> bool {
self.is_valid_userspace_fat_ptr(ptr, 1, writable)
}

pub fn translate_userspace_address_to_physical_address<T, PTR: Pointer<T>>(
pub fn translate_userspace_address_to_physical_address<PTR: Pointer>(
&self,
ptr: PTR,
) -> Option<PTR> {
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/syscalls/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl KernelSyscalls for SyscallHandler {
}

#[doc = r" Validate a pointer such that it is a valid userspace pointer"]
fn validate_and_translate_pointer<T, PTR: Pointer<T>>(&self, ptr: PTR) -> Option<PTR> {
fn validate_and_translate_pointer<PTR: Pointer>(&self, ptr: PTR) -> Option<PTR> {
self.current_process.with_lock(|p| {
let pt = p.get_page_table();
if !pt.is_valid_userspace_ptr(ptr, true) {
Expand Down
56 changes: 10 additions & 46 deletions kernel/src/syscalls/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use core::ops::{Deref, DerefMut};
use common::{
constructable::Constructable,
net::UDPDescriptor,
pointer::{FatPointer, Pointer},
ref_conversion::RefToPointer,
syscalls::{SysSocketError, ValidationError},
unwrap_or_return,
Expand Down Expand Up @@ -46,53 +47,16 @@ impl Validatable<SharedAssignedSocket> for UserspaceArgument<UDPDescriptor> {
}
}

impl<'a> Validatable<&'a str> for UserspaceArgument<&'a str> {
/// I know this is really unreadable. However, I like to learn the power of
/// the type system and traits.
/// What this impl does is basically implement validation for all RefToPointer
/// where Out is an FatPointer.
impl<Ptr: Pointer, T: RefToPointer<T, Out = FatPointer<Ptr>>> Validatable<T>
for UserspaceArgument<T>
{
type Error = ValidationError;
fn validate(self, handler: &mut SyscallHandler) -> Result<&'a str, Self::Error> {
let start = self.inner.ptr();
let len = self.inner.len();
let ptr = handler.current_process().with_lock(|p| {
let pt = p.get_page_table();
if !pt.is_valid_userspace_fat_ptr(start, len, false) {
return None;
}
pt.translate_userspace_address_to_physical_address(start)
});

if let Some(ptr) = ptr {
// SAFETY: We validated the pointer above
unsafe { Ok(core::str::from_raw_parts(ptr, len)) }
} else {
Err(ValidationError::InvalidPtr)
}
}
}

impl<'a> Validatable<&'a [u8]> for UserspaceArgument<&'a [u8]> {
type Error = ValidationError;
fn validate(self, handler: &mut SyscallHandler) -> Result<&'a [u8], Self::Error> {
let start = self.inner.ptr();
let len = self.inner.len();
let ptr = handler.current_process().with_lock(|p| {
let pt = p.get_page_table();
if !pt.is_valid_userspace_fat_ptr(start, len, false) {
return None;
}
pt.translate_userspace_address_to_physical_address(start)
});

if let Some(ptr) = ptr {
// SAFETY: We validated the pointer above
unsafe { Ok(core::slice::from_raw_parts(ptr, len)) }
} else {
Err(ValidationError::InvalidPtr)
}
}
}

impl<'a> Validatable<&'a mut [u8]> for UserspaceArgument<&'a mut [u8]> {
type Error = ValidationError;
fn validate(self, handler: &mut SyscallHandler) -> Result<&'a mut [u8], Self::Error> {
fn validate(self, handler: &mut SyscallHandler) -> Result<T, Self::Error> {
let start = self.inner.ptr();
let len = self.inner.len();
let ptr = handler.current_process().with_lock(|p| {
Expand All @@ -105,7 +69,7 @@ impl<'a> Validatable<&'a mut [u8]> for UserspaceArgument<&'a mut [u8]> {

if let Some(ptr) = ptr {
// SAFETY: We validated the pointer above
unsafe { Ok(core::slice::from_raw_parts_mut(ptr, len)) }
unsafe { Ok(T::to_ref_if_pointer(FatPointer::new(ptr, len))) }
} else {
Err(ValidationError::InvalidPtr)
}
Expand Down

0 comments on commit e50068b

Please sign in to comment.