Skip to content

Commit

Permalink
add inline return
Browse files Browse the repository at this point in the history
  • Loading branch information
antonkesy committed Mar 14, 2024
1 parent aed94ba commit 3a41eea
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 35 deletions.
14 changes: 12 additions & 2 deletions examples/fibonacci.mmm
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
void main() {
int fib = fibbonacci_single_return(15);
println(str(fib));
int fib_v1 = fibbonacci_inline(15);
println(str(fib_v1));

int fib_v2 = fibbonacci_single_return(15);
println(str(fib_v2));
}

int fibbonacci_single_return(int n) {
Expand All @@ -10,3 +13,10 @@ int fibbonacci_single_return(int n) {
}
return ret;
}

int fibbonacci_inline(int n) {
if n < 2 {
return n;
}
return fibbonacci_inline(n - 1) + fibbonacci_inline(n - 2);
}
2 changes: 1 addition & 1 deletion src/Interpreter/BuiltIn.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ toString =
[IntValue i] -> show i
[FloatValue f] -> show f
[BoolValue b] -> show b
_ -> error "No matching type for str"
_ -> error ("No matching type for str: " ++ show val)
92 changes: 61 additions & 31 deletions src/Interpreter/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ module Interpreter.Interpreter (module Interpreter.Interpreter) where

import AST
import Control.Monad (foldM)
import qualified Data.Functor
import Data.Map.Strict as Map
import qualified Debug.Trace as Debug
import Interpreter.BuiltIn
import Interpreter.Manipulator
import Interpreter.ProgramState
Expand All @@ -16,6 +18,7 @@ interpret (Program statements) = do
-- putStrLn "Valid program"
let correctedStatments = ensureEntryPoint statements
let functionMap = getFunctionMap correctedStatments
-- print correctedStatments
_ <- foldM interpretStatement (InterpretState (ProgramState empty functionMap) Nothing) correctedStatments
-- putStrLn $ "End state: " ++ show endState
return ()
Expand All @@ -24,78 +27,104 @@ interpret (Program statements) = do

interpretStatement :: InterpretState -> Statement -> IO InterpretState
interpretStatement (InterpretState state _) (VariableStatement (Variable (VariableDeclaration name _) expression)) = do
value <- interpretExpression state expression
return (InterpretState (updateState state name value) Nothing)
(ScopeResult innerVars ret) <- interpretExpression state expression
let newState = updateOuterStateV state innerVars
return (InterpretState (updateState newState name ret) Nothing)
interpretStatement (InterpretState state _) (AssignmentStatement (Assignment name expression)) = do
value <- interpretExpression state expression
return (InterpretState (updateState state name value) Nothing)
(ScopeResult innerVars ret) <- interpretExpression state expression
let newState = updateOuterStateV state innerVars
return (InterpretState (updateState newState name ret) Nothing)
interpretStatement (InterpretState state _) (ExpressionStatement expression) = do
_ <- interpretExpression state expression
return (InterpretState state Nothing)
interpretStatement (InterpretState state _) (FunctionDefinitionStatement _) = do
return (InterpretState state Nothing)
interpretStatement (InterpretState state _) (ReturnStatement expression) = do
-- TODO: return should cancel the current function
ret <- interpretExpression state expression
return (InterpretState state (Just ret))
(ScopeResult innerVars ret) <- interpretExpression state expression
let newState = updateOuterStateV state innerVars
return (InterpretState newState ret)
interpretStatement (InterpretState state _) (ControlStatement control) = do
interpretControl state control

updateState :: ProgramState -> Name -> Value -> ProgramState
updateState (ProgramState vars funs) name value = ProgramState (Map.insert name value vars) funs
updateState :: ProgramState -> Name -> Maybe Value -> ProgramState
updateState (ProgramState vars funs) name value = do
case value of
Just v -> ProgramState (Map.insert name v vars) funs
Nothing -> ProgramState vars funs

-- Update variable in outer scope
updateOuterState :: ProgramState -> ProgramState -> ProgramState
updateOuterState (ProgramState outerVars funs) (ProgramState innerVars _) =
ProgramState (Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars) funs

-- if Map.member name vars then updateState (ProgramState vars funs) name value else ProgramState vars funs
updateOuterStateV :: ProgramState -> Map Name Value -> ProgramState
updateOuterStateV (ProgramState outerVars funs) innerVars =
ProgramState
(Map.unionWithKey (\_ inner _outer -> inner) innerVars outerVars)
funs

interpretExpression :: ProgramState -> Expression -> IO Value
interpretExpression :: ProgramState -> Expression -> IO ScopeResult
interpretExpression state (AtomicExpression atomic) = do
interpretAtomic state atomic
interpretExpression state (OperationExpression left operator right) = do
leftValue <- interpretExpression state left
rightValue <- interpretExpression state right
(ScopeResult _ (Just leftValue)) <- interpretExpression state left
(ScopeResult _ (Just rightValue)) <- interpretExpression state right
let value = interpretOperation operator leftValue rightValue
return value
return (ScopeResult (variables state) (Just value))

interpretAtomic :: ProgramState -> Atomic -> IO Value
interpretAtomic _ (LiteralAtomic literal) = do
interpretLiteral literal
interpretAtomic :: ProgramState -> Atomic -> IO ScopeResult
interpretAtomic (ProgramState vars _) (LiteralAtomic literal) = do
ret <- interpretLiteral literal
return (ScopeResult vars (Just ret))
interpretAtomic (ProgramState vars _) (VariableAtomic name) = do
let varValue = Map.lookup name vars
return $ case varValue of
Just value -> value
Just value -> ScopeResult vars (Just value)
Nothing -> error $ "Variable not found: " ++ name
interpretAtomic (ProgramState vars funs) (FunctionCallAtomic name args) = do
let isBuiltIn = Map.lookup name getAllBuiltIns
case isBuiltIn of
Just (BuiltIn _ _ fn) -> do
argValues <- getArgValues args
fn argValues
ret <- fn argValues
return (ScopeResult vars (Just ret))
Nothing -> do
let fun = Map.lookup name funs
case fun of
Just (FunctionDefinitionStatement (Function _ argDef _ body)) -> do
params <- mapExpressionToParam argDef args
let fnScope = ProgramState (Map.union params vars) funs
(InterpretState _ ret) <- foldM interpretStatement (InterpretState fnScope Nothing) body
case ret of
Just value -> return value
Nothing -> error $ "Function did not return a value: " ++ name
Nothing -> error $ "Function not found: " ++ name
(ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState fnScope Nothing) body True
let (ProgramState newVars _) = updateOuterStateV (ProgramState vars funs) innerVars
return (ScopeResult newVars ret)
_ -> error $ "Function not found: " ++ name
where
getArgValues :: [Expression] -> IO [Value]
getArgValues = mapM (interpretExpression (ProgramState vars funs))
getArgValues exprs =
mapM
(interpretExpression (ProgramState vars funs))
exprs
Data.Functor.<&> Prelude.map (\(ScopeResult _ (Just v)) -> v)

mapExpressionToParam :: [VariableDeclaration] -> [Expression] -> IO (Map Name Value)
mapExpressionToParam [] [] = pure Map.empty
mapExpressionToParam (VariableDeclaration n _ : rest) (expression : restExp) = do
val <- interpretExpression (ProgramState vars funs) expression
(ScopeResult _ (Just val)) <- interpretExpression (ProgramState vars funs) expression
restMap <- mapExpressionToParam rest restExp
return (Map.insert n val restMap)
mapExpressionToParam _ _ = error "Invalid number of arguments"

returnSkipWrapper :: InterpretState -> [Statement] -> Bool -> IO ScopeResult
returnSkipWrapper state (stmt : rest) inFunction = do
(InterpretState s ret) <- interpretStatement state stmt
case ret of
Just value -> return (ScopeResult (variables s) (Just value))
Nothing -> returnSkipWrapper (InterpretState s Nothing) rest inFunction
returnSkipWrapper state [] inFunction =
if inFunction
then error "missing return"
else return (ScopeResult (variables (programState state)) Nothing)

interpretLiteral :: Literal -> IO Value
interpretLiteral (IntLiteral value) = do
return $ IntValue value
Expand Down Expand Up @@ -128,13 +157,14 @@ interpretControl (ProgramState vars funs) (IfControl test body elseBody) = do
(BoolValue testValue) <- isTestValue (ProgramState vars funs) test
if testValue
then do
(InterpretState innerVars ret) <- foldM interpretStatement (InterpretState (ProgramState vars funs) Nothing) body
return $ InterpretState (updateOuterState (ProgramState vars funs) innerVars) ret
(ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs) Nothing) body False
return $ InterpretState (updateOuterStateV (ProgramState vars funs) innerVars) ret
else do
case elseBody of
Just elseStatements -> do
(InterpretState innerVars ret) <- foldM interpretStatement (InterpretState (ProgramState vars funs) Nothing) elseStatements
return $ InterpretState (updateOuterState (ProgramState vars funs) innerVars) ret
-- TODO: extract cancellable statements function
(ScopeResult innerVars ret) <- returnSkipWrapper (InterpretState (ProgramState vars funs) Nothing) elseStatements False
return $ InterpretState (updateOuterStateV (ProgramState vars funs) innerVars) ret
Nothing -> return $ InterpretState (ProgramState vars funs) Nothing
interpretControl (ProgramState vars funs) (WhileControl test body) = do
(BoolValue testValue) <- isTestValue (ProgramState vars funs) test
Expand All @@ -148,7 +178,7 @@ interpretControl (ProgramState vars funs) (WhileControl test body) = do

isTestValue :: ProgramState -> Expression -> IO Value
isTestValue (ProgramState vars funs) test = do
testValue <- interpretExpression (ProgramState vars funs) test
(ScopeResult _ (Just testValue)) <- interpretExpression (ProgramState vars funs) test
if not (isBoolValue testValue)
then do error "Control statement test must be a boolean value."
else return testValue
Expand Down
5 changes: 4 additions & 1 deletion src/Interpreter/ProgramState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Interpreter.ProgramState (module Interpreter.ProgramState) where
import AST
import Data.Map.Strict as Map

data Value = IntValue Int | FloatValue Float | BoolValue Bool | UnitValue | StringValue String
data Value = IntValue Int | FloatValue Float | BoolValue Bool | UnitValue | StringValue String | InterpreterErrorValue String
deriving (Show, Eq)

data ProgramState where
Expand All @@ -15,3 +15,6 @@ data ProgramState where
data InterpretState where
InterpretState :: {programState :: ProgramState, returnValue :: Maybe Value} -> InterpretState
deriving (Show, Eq)

data ScopeResult = ScopeResult (Map Name Value) (Maybe Value)
deriving (Show, Eq)

0 comments on commit 3a41eea

Please sign in to comment.