Pārlūkot izejas kodu

Check for recursive types in toEnv function

Getty Ritter 9 gadi atpakaļ
vecāks
revīzija
69f330ec3c
2 mainītis faili ar 28 papildinājumiem un 16 dzēšanām
  1. 2 1
      gidl.cabal
  2. 26 15
      src/Gidl/Parse.hs

+ 2 - 1
gidl.cabal

@@ -51,7 +51,8 @@ library
                        transformers,
                        ivory-artifact,
                        s-cargot,
-                       text
+                       text,
+                       mtl
   hs-source-dirs:      src
   default-language:    Haskell2010
   ghc-options:         -Wall

+ 26 - 15
src/Gidl/Parse.hs

@@ -5,6 +5,7 @@ module Gidl.Parse (parseDecls) where
 
 import           Control.Applicative ((<$>), (<*>))
 import           Control.Monad ((>=>))
+import           Control.Monad.Reader (ask, lift, local, runReaderT)
 import           Data.List (partition, group, intercalate)
 import           Data.SCargot.Comments (withHaskellComments)
 import           Data.SCargot.General ( SExprSpec
@@ -52,7 +53,7 @@ toEnv decls = do
   unlessEmpty (duplicated interfaceNames)
       (\n -> "Interface named '" ++ n ++ "' declared multiple times")
 
-  typs <- mapM (getTypePair . getName) typDs
+  typs <- mapM (flip runReaderT [] . getTypePair . getName) typDs
   ifcs <- mapM getIfacePair interfaceNames
   return (TypeEnv typs, InterfaceEnv ifcs)
   where (typDs, ifcDs) = partition isTypeDecl decls
@@ -67,47 +68,51 @@ toEnv decls = do
         ifcMap = [(getName i, toInterface i) | i <- ifcDs]
 
         -- this is gross because I'm trying to make sure declarations
-        -- can happen in any order. XXX: prevent recursion!
+        -- can happen in any order.
         getType n = snd `fmap` getTypePair n
-        getTypePair n = case lookup n typMap of
-            Just (Right t) -> return t
-            Just (Left l)  -> Left l
-            Nothing        -> throw ("Unknown primitive type: " ++ n)
+        getTypePair n = do
+          env <- ask
+          if n `elem` env
+            then lift $ throw ("Types cannot be recursive.\n" ++
+                               showCycle env)
+            else case lookup n typMap of
+              Just rs -> rs
+              Nothing -> lift $ throw ("Unknown primitive type: " ++ n)
 
         getIface n = snd `fmap` getIfacePair n
         getIfacePair n = case lookup n ifcMap of
           Just (Right i) -> return i
           Just (Left l)  -> Left l
-          Nothing        -> throw ("Unknown interface: " ++ n)
+          Nothing        -> Left ("Unknown interface: " ++ n)
 
         getPrimType n = do
           t <- getType n
           case t of
             PrimType t' -> return t'
-            _ -> throw ("Expected primitive type but got " ++ show t)
+            _ -> lift $ throw ("Expected primitive type but got " ++ show t)
 
         -- converts a decl to an actual type
-        toType (NewtypeDecl n t) = do
+        toType (NewtypeDecl n t) = local (n:) $ do
           t' <- getPrimType t
           return (n, PrimType (Newtype n t'))
-        toType (EnumDecl (n, s) ts) = do
-          unlessEmpty (duplicated (map fst ts))
+        toType (EnumDecl (n, s) ts) = local (n:) $ do
+          lift $ unlessEmpty (duplicated (map fst ts))
               (\i -> "Enum identifier '" ++ i
                   ++ "' repeated in declaration of 'Enum " ++ n ++ "'")
-          unlessEmpty (duplicated (map snd ts))
+          lift $ unlessEmpty (duplicated (map snd ts))
               (\i -> "Enum value '" ++ (show i)
                   ++ "' repeated in declaration of 'Enum " ++ n ++ "'")
           return (n, PrimType (EnumType n s ts))
-        toType (StructDecl n ss) = do
+        toType (StructDecl n ss) = local (n:) $ do
           ps <- mapM (getPrimType . snd) ss
           return (n, StructType n (zip (map fst ss) ps))
         toType _ = error "[unreachable]"
 
         toMethod (n, AttrDecl perm t) = do
-          t' <- getType t
+          t' <- runReaderT (getType t) []
           return (n, AttrMethod perm t')
         toMethod (n, StreamDecl rate t) = do
-          t' <- getType t
+          t' <- runReaderT (getType t) []
           return (n, StreamMethod rate t')
 
         toInterface (InterfaceDecl n is ms) = do
@@ -124,6 +129,12 @@ toEnv decls = do
         isTypeDecl InterfaceDecl {} = False
         isTypeDecl _                = True
 
+        showCycle []        = error "[unreachable]"
+        showCycle [x]       = "  In recursive type `" ++ x ++ "`"
+        showCycle ls@(x:_)  = "  In mutually recursive cycle " ++ go ls
+          where go (y:ys) = "`" ++ y ++ "` => " ++ go ys
+                go []     = "`" ++ x ++ "`"
+
 parseDecls :: String -> Either String (TypeEnv, InterfaceEnv)
 parseDecls = return . pack >=> decode gidlSpec >=> toEnv