Rpc.hs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. module Gidl.Backend.Rpc (
  2. rpcBackend
  3. ) where
  4. import qualified Paths_gidl as P
  5. import Gidl.Backend.Cabal (cabalFileArtifact,defaultCabalFile,filePathToPackage)
  6. import Gidl.Backend.Haskell.Interface (interfaceModule,ifModuleName)
  7. import Gidl.Backend.Haskell.Types
  8. (typeModule,isUserDefined,typeModuleName,userTypeModuleName
  9. ,importType,importDecl, qualifiedImportDecl)
  10. import Gidl.Interface
  11. (Interface(..),MethodName,Method(..),Perm(..)
  12. ,interfaceMethods)
  13. import Gidl.Schema
  14. (Schema(..),producerSchema,consumerSchema,Message(..)
  15. ,consumerMessages,interfaceTypes,getResponseMessage)
  16. import Gidl.Types (Type(..))
  17. import Data.Char (isSpace)
  18. import Data.List (nub)
  19. import Ivory.Artifact
  20. (Artifact,artifactPath,artifactFileName,artifactPath,artifactText
  21. ,artifactCabalFile)
  22. import Ivory.Artifact.Template (artifactCabalFileTemplate)
  23. import Text.PrettyPrint.Mainland
  24. (Doc,prettyLazyText,text,empty,(<+>),(</>),(<>),char,line,parens
  25. ,punctuate,stack,tuple,dot,spread,cat,hang,nest,align,comma
  26. ,braces,brackets,dquotes)
  27. -- External Interface ----------------------------------------------------------
  -- | Every artifact produced by the RPC backend.
  --
  -- Arguments: the interfaces to generate servers for, the cabal package
  -- name, and a dot-separated namespace string (split with 'strToNs').
  --
  -- Output: the cabal file, the support Makefile, and, under @src@, one
  -- type module per user-defined type, an interface module plus an RPC
  -- server module per interface, and the shared @Rpc@ base module.
  rpcBackend :: [Interface] -> String -> String -> [Artifact]
  rpcBackend iis pkgName nsStr = cabal : makefile : srcArtifacts
    where
      cabal        = cabalFileArtifact (defaultCabalFile pkgName modules buildDeps)
      makefile     = artifactCabalFile P.getDataDir "support/rpc/Makefile"
      srcArtifacts = map (artifactPath "src") sourceMods

      ns = strToNs nsStr

      -- packages the generated sources depend on
      buildDeps =
        [ "cereal", "QuickCheck", "snap-core", "snap-server", "stm"
        , "aeson", "transformers", "containers" ]

      modules = map (filePathToPackage . artifactFileName) sourceMods

      sourceMods = typeMods ++ ifaceMods ++ [rpcBaseModule ns]

      -- each type used by any interface, once
      allTypes = nub (concatMap interfaceTypes iis)

      typeMods =
        map (typeModule True (ns ++ ["Types"])) (filter isUserDefined allTypes)

      ifaceMods = concatMap perIface iis
      perIface i = [ interfaceModule True (ns ++ ["Interface"]) i, rpcModule ns i ]
  -- | The shared RPC support module, instantiated from the data-file template
  -- @support/rpc/Base.hs.template@ and placed in the namespace directories
  -- followed by @Rpc@. The template variable @module_path@ is set to the
  -- dotted form of the same path.
  rpcBaseModule :: [String] -> Artifact
  rpcBaseModule ns = artifactPath dir template
    where
      dir        = concatMap (++ "/") ns ++ "Rpc"
      modulePath = concatMap (++ ".") ns ++ "Rpc"
      template   =
        artifactCabalFileTemplate P.getDataDir "support/rpc/Base.hs.template"
          [ ("module_path", modulePath) ]
  54. -- Utilities -------------------------------------------------------------------
  -- | Split a dot-separated namespace string into its components.
  --
  -- Leading whitespace before each component is skipped and empty
  -- components are dropped. Each component is cut at its first whitespace
  -- character, so @"Foo Bar.Baz"@ yields @["Foo","Baz"]@.
  strToNs :: String -> [String]
  strToNs str =
    case rest of
      '.' : more -> here ++ strToNs more
      _          -> here
    where
      (component, rest) = break (== '.') (dropWhile isSpace str)
      -- at most one component comes from this segment
      here = [ takeWhile (not . isSpace) component | not (null component) ]
  -- | All methods of an interface: the methods of each interface listed in
  -- its second field (recursively, in list order), followed by its own
  -- methods.
  allMethods :: Interface -> [(MethodName,Method)]
  allMethods iface = case iface of
    Interface _ parents own -> (parents >>= allMethods) ++ own
  -- | Whether a schema carries no messages.
  isEmptySchema :: Schema -> Bool
  isEmptySchema (Schema _ msgs) = case msgs of
    [] -> True
    _  -> False
  68. -- Server Generation -----------------------------------------------------------
  -- | The generated RPC server module for one interface. The file is named
  -- after 'ifModuleName' and placed in the namespace directories followed
  -- by @Rpc@.
  rpcModule :: [String] -> Interface -> Artifact
  rpcModule ns iface = artifactPath dir (artifactText fileName rendered)
    where
      modName  = ifModuleName iface
      dir      = concatMap (++ "/") ns ++ "Rpc"
      fileName = modName ++ ".hs"
      rendered = prettyLazyText 1000 (genServer ns iface modName)
  -- | Render the complete source of an RPC server module, in this order:
  -- language pragmas (@RecordWildCards@ only when a stream manager is
  -- generated), the module header, the type/interface/base imports, the
  -- web-server imports, the stream manager definitions (if any) and finally
  -- @rpcServer@.
  genServer :: [String] -> Interface -> String -> Doc
  genServer ns iface ifaceMod = stack (pragmas ++ header ++ body)
    where
      pragmas =
        [ text "{-# LANGUAGE RecordWildCards #-}" | useManager ]
        ++ [ text "{-# LANGUAGE OverloadedStrings #-}"
           , text "{-# OPTIONS_GHC -fno-warn-unused-imports #-}"
           ]
      header =
        [ moduleHeader ns ifaceMod
        , line
        , importTypes ns iface
        , importInterface ns ifaceMod
        , line
        , text "import" <+> ppModName (ns ++ ["Rpc","Base"])
        ]
      body =
        [ line
        , webServerImports hasConsumer
        , line
        , line
        , managerDefs
        , runServer hasConsumer useManager iface inQ outQ
        ]
      hasConsumer               = not (isEmptySchema (consumerSchema iface))
      (useManager, managerDefs) = managerDef hasConsumer iface inQ
      (inQ, outQ)               = queueTypes iface
  -- | The module declaration of a generated server. It names the module from
  -- the namespace, @Rpc@ and the given module name, and exports
  -- @rpcServer@ and @Config(..)@.
  moduleHeader :: [String] -> String -> Doc
  moduleHeader ns m = spread [keyword, modName, exports, text "where"]
    where
      keyword = text "module"
      modName = ppModName (ns ++ ["Rpc", m])
      exports = tuple (map text ["rpcServer", "Config(..)"])
  -- | Import declarations for the type modules the interface needs, all
  -- under the namespace's @Types@ prefix.
  --
  -- The types carried by stream methods are imported with 'importDecl'.
  -- The attribute types, followed by all of 'interfaceTypes', are imported
  -- with 'qualifiedImportDecl'. No de-duplication is done here.
  --
  -- NOTE(review): an earlier comment said these imports hide everything
  -- because only the ToJSON/FromJSON instances are needed; that depends on
  -- 'importDecl' / 'qualifiedImportDecl', which are not visible here. Confirm.
  importTypes :: [String] -> Interface -> Doc
  importTypes ns iface = stack (plainImports ++ qualImports)
    where
      (streamTys, attrTys) = partitionTypes iface
      plainImports = [ importDecl withNs (importType ty) | ty <- streamTys ]
      qualImports  = [ qualifiedImportDecl withNs (importType ty)
                     | ty <- attrTys ++ interfaceTypes iface ]
      withNs m = ppModName (ns ++ ["Types"]) <> char '.' <> text m
  -- | Split the types of the interface's methods into those carried by
  -- stream methods (first component) and those carried by attribute
  -- methods (second component). Each list is de-duplicated. It is in
  -- reverse method order, keeping the last occurrence of each type first.
  partitionTypes :: Interface -> ([Type],[Type])
  partitionTypes iface = (nub (reverse streamTys), nub (reverse attrTys))
    where
      methods   = map snd (interfaceMethods iface)
      streamTys = [ ty | StreamMethod _ ty <- methods ]
      attrTys   = [ ty | AttrMethod _ ty <- methods ]
  -- | Unqualified import of the interface module generated under the
  -- namespace's @Interface@ prefix.
  importInterface :: [String] -> String -> Doc
  importInterface ns ifaceName =
    text "import" <+> ppModName (ns ++ ["Interface", ifaceName])
  -- | Imports needed by the generated Snap server.
  --
  -- @msum@ and @decode@ are only used by consumer-side handlers, so they
  -- are imported only when the interface has a consumer schema.
  webServerImports :: Bool -> Doc
  webServerImports hasConsumer = stack (map text (consumerOnly ++ always))
    where
      consumerOnly
        | hasConsumer = [ "import Control.Monad (msum)"
                        , "import Data.Aeson (decode)"
                        ]
        | otherwise   = []
      always =
        [ "import qualified Snap.Core as Snap"
        , "import Control.Concurrent (forkIO)"
        , "import Control.Concurrent.STM"
        , "import Control.Monad (forever)"
        , "import Control.Monad.IO.Class (liftIO)"
        , "import Data.Aeson (encode,Value(Null))"
        ]
  -- | Rendered type of the queue of producer messages
  -- (a @TQueue@ of the interface's producer type).
  type InputQueue = Doc
  -- | Rendered type of the queue of consumer messages
  -- (a @TQueue@ of the interface's consumer type).
  type OutputQueue = Doc
  -- | Both queue types for an interface. Each message type name is
  -- 'ifModuleName' followed by the producer or consumer schema name.
  queueTypes :: Interface -> (InputQueue,OutputQueue)
  queueTypes iface = (tqueue prodName, tqueue consName)
    where
      Schema prodName _ = producerSchema iface
      Schema consName _ = consumerSchema iface
      tqueue suffix = text "TQueue" <+> text (ifModuleName iface ++ suffix)
  -- | The type signature of the generated @rpcServer@, followed by its
  -- definition.
  runServer :: Bool -> Bool -> Interface -> InputQueue -> OutputQueue -> Doc
  runServer hasConsumer useMgr iface input output = signature </> definition
    where
      signature  = runServerSig hasConsumer input output
      definition = runServerDef hasConsumer useMgr iface
  -- | Signature of @rpcServer@. It takes the input queue, then the output
  -- queue (only when there is a consumer schema), then a @Config@, and runs
  -- in @IO ()@.
  runServerSig :: Bool -> InputQueue -> OutputQueue -> Doc
  runServerSig hasConsumer input output =
    text "rpcServer ::" <+> hang 2 (arrow argTys)
    where
      argTys = input : optionalOutput ++ [text "Config", text "IO ()"]
      optionalOutput = if hasConsumer then [output] else []
  -- | Generate the definition of @rpcServer@.
  --
  -- The generated body consists of, in order:
  --
  -- * @state <- mkState@, when a stream manager is generated;
  -- * @input' <- newTQueueIO@, when there is both a manager and a consumer
  --   schema; the manager forwards non-stream messages to this queue;
  -- * a @forkIO@ of the manager, when a manager is generated;
  -- * @conn <- newConn ...@, when there is a consumer schema;
  -- * @runServer cfg $ Snap.route [...]@, with one route per method.
  --
  -- Statements that do not apply are left out entirely. Previously an
  -- 'empty' placeholder was emitted, which produced a blank statement line
  -- in the generated @do@ block, or a dangling @do@ when no manager existed.
  runServerDef :: Bool -> Bool -> Interface -> Doc
  runServerDef hasConsumer useMgr iface =
    hang 2 (text "rpcServer" <+> body)
    where
      args = spread $
        [ text "input" ] ++
        [ text "output" | hasConsumer ] ++
        [ text "cfg" ]
      body = args <+> char '=' </> nest 2 (doStmts stmts)
      -- With both a manager and a consumer, the manager reads 'input' and
      -- forwards non-stream messages to a fresh queue. That queue, not
      -- 'input', is then handed to newConn.
      splitInput = hasConsumer && useMgr
      connInput
        | splitInput = text "input'"
        | otherwise  = text "input"
      stmts = [ text "state <- mkState" | useMgr ]
           ++ [ text "input' <- newTQueueIO" | splitInput ]
           ++ [ spread $ [ text "_ <- forkIO (manager state input" ]
                      ++ [ text "input'" | hasConsumer ]
                      ++ [ text ")" ]
              | useMgr ]
           ++ [ text "conn <- newConn output" <+> connInput
                  <+> seqNumGetter | hasConsumer ]
           ++ [ text "runServer cfg $ Snap.route" </> routesDef ]
      routesDef = nest 2 (align (routes iface (text "state")))
      seqNumGetter = parens (text "SequenceNum.unSequenceNum ."
        <+> text "seqNumGetter" <> text (ifModuleName iface) <> text prodName)
      Schema prodName _ = producerSchema iface
  -- | The Snap routing table: a bracketed, comma-separated list with one
  -- route entry per method returned by 'allMethods'.
  routes :: Interface -> Doc -> Doc
  routes iface state = align (char '[' <> nest 1 entries <> char ']')
    where
      Interface ifacePfx _ _ = iface
      Schema consSuffix _    = consumerSchema iface
      entries = stack (commas [ mkRoute ifacePfx consSuffix state m
                              | m <- allMethods iface ])
  -- | One parenthesised route entry for a method. The URL is the interface
  -- prefix and the method name joined by a slash, followed by its handler.
  --
  -- A stream gets a GET handler. An attribute gets a read (GET) and/or
  -- write (POST) handler per consumer message, depending on its
  -- permission. Several handlers are combined with @msum@.
  mkRoute :: String -> String -> Doc -> (MethodName,Method) -> Doc
  mkRoute ifacePfx consSuffix state method@(name,mty) =
    parens (path <> comma </> combine (handlers mty))
    where
      path = dquotes (text ifacePfx <> char '/' <> text name)

      combine hs = case hs of
        [single] -> single
        _        -> nest 2 (text "msum" </> brackets (stack (commas hs)))

      readers m = [ readAttr method consSuffix msg | msg <- consumerMessages m ]
      writers m = [ writeAttr consSuffix msg | msg <- consumerMessages m ]

      handlers StreamMethod {}           = [ readStream state name ]
      handlers (AttrMethod Read _)       = readers method
      handlers (AttrMethod Write _)      = writers method
      handlers (AttrMethod ReadWrite ty) =
        readers (name, AttrMethod Read ty) ++ writers (name, AttrMethod Write ty)
  212. [ readAttr method consSuffix m
  213. | m <- consumerMessages (name,AttrMethod Read ty) ] ++
  214. [ writeAttr consSuffix m | m <- consumerMessages (name,AttrMethod Write ty) ]
  215. readStream :: Doc -> MethodName -> Doc
  216. readStream state name = nest 2 $ text "Snap.method Snap.GET $"
  217. </> doStmts
  218. [ text "x <- liftIO (atomically (readTSampleVar" <+> svar <> text "))"
  219. , text "let e = case x of Just v -> encode v; Nothing -> encode Null"
  220. , text "Snap.writeLBS e"
  221. ]
  222. where
  223. svar = parens (fieldName name <+> state)
  -- | Constructor name for a message: the type-module name of the message
  -- name, followed by the given suffix.
  constrName :: String -> Message -> String
  constrName suffix msg = case msg of
    Message n _ -> userTypeModuleName n ++ suffix
  -- | GET handler for reading an attribute.
  --
  -- The generated code does the following:
  --
  -- * sends the consumer request for @msg@ (its constructor composed with
  --   @SequenceNum.SequenceNum@) through @sendRequest conn@;
  -- * pattern-matches the producer response given by 'getResponseMessage'
  --   and binds its payload as @resp@;
  -- * sets a JSON content type and writes the encoded payload.
  --
  -- Only attribute methods are accepted; anything else is an internal error.
  readAttr :: (MethodName,Method) -> String -> Message -> Doc
  readAttr (attrname, AttrMethod _ t) suffix msg =
    text "Snap.method Snap.GET $" <+> doStmts [request, setType, reply]
    where
      resp@(Message _ (StructType respTyName _)) = getResponseMessage attrname t
      respMod = text (userTypeModuleName respTyName)
      respPat = parens (text (constrName "Producer" resp)
                  <+> parens (respMod <> dot <> respMod <+> text "_ resp"))
      request = respPat
        <+> text "<- liftIO $ sendRequest conn $"
        <+> text (constrName suffix msg)
        <+> dot <+> text "SequenceNum.SequenceNum"
      setType = text "Snap.modifyResponse $ Snap.setContentType \"application/json\""
      reply   = text "Snap.writeLBS (encode resp)"
  readAttr _ _ _ = error "impossible readAttr"
  -- | POST handler for writing an attribute.
  --
  -- The generated code reads the request body with a limit argument of
  -- 32768 and JSON-decodes it. On success it sends the consumer message for
  -- @msg@ (the decoded value wrapped with the sequence number) through
  -- @sendRequest conn@. On a decode failure it sets response code 400.
  writeAttr :: String -> Message -> Doc
  writeAttr suffix msg =
    text "Snap.method Snap.POST $" <+> doStmts [readBody, dispatch]
    where
      Message _ (StructType sname _) = msg
      structMod = text (userTypeModuleName sname)
      readBody  = text "bytes <- Snap.readRequestBody 32768"
      sendIt    = text "_ <- sendRequest conn $ \\ snum ->"
        <+> text (constrName suffix msg)
        <+> parens (structMod <> dot <> structMod
                    <+> text "(SequenceNum.SequenceNum snum)" <+> text "req")
      onJust    = text "Just req -> liftIO $" <+> doStmts [sendIt, text "return ()"]
      onNothing = text "Nothing -> Snap.modifyResponse $ Snap.setResponseCode 400"
      dispatch  = text "case decode bytes of" </> (onJust </> onNothing)
  260. -- The stream manager ----------------------------------------------------------
  -- | Definitions backing the stream manager: the @State@ record,
  -- @mkState@, and the @manager@ loop.
  --
  -- Returns @(False, empty)@ when the interface has no stream methods, and
  -- @(True, defs)@ otherwise. The flag says whether a manager is generated.
  --
  -- The generated @manager@ loops forever, reading a message from the input
  -- queue. A stream message is written to that stream's @TSampleVar@. When
  -- there is a consumer schema, any other message is written to the
  -- @filtered@ queue.
  --
  -- NOTE(review): the case constructor is derived from the stream's /type/
  -- ('typeModuleName'), not its method name. Two streams of the same type
  -- would therefore yield identical alternatives, and only the first would
  -- ever match. Confirm against the schema's message naming.
  managerDef :: Bool -> Interface -> InputQueue -> (Bool,Doc)
  managerDef hasConsumer iface input
    | null streams = (False, empty)
    | otherwise    = (True, stack defs </> empty)
    where
      streams = [ (name, ty) | (name, StreamMethod _ ty) <- allMethods iface ]
      (stateType, stateDecl) = stateDef streams

      -- 'empty' entries render as blank separator lines
      defs =
        [ stateDecl
        , empty
        , mkStateDef streams
        , empty
        , managerSig
        , managerBody
        ]

      managerSig =
        text "manager ::"
          <+> arrow (stateType : input : [ input | hasConsumer ] ++ [ text "IO ()" ])

      managerBody = nest 2 $ spread $
        [ text "manager state input" ] ++
        [ text "filtered" | hasConsumer ] ++
        [ text "= forever $" </> doStmts loop ]

      loop =
        [ text "msg <- atomically (readTQueue input)"
        , nest 2 (text "case msg of" </> stack alts)
        ]
      alts = map streamAlt streams ++ [ fallthrough | hasConsumer ]

      Schema prodSuffix _ = producerSchema iface
      -- producer-message constructor for a stream element of type ty
      prodCon ty = text (typeModuleName ty ++ prodSuffix)
      -- store a stream element in that stream's sample variable
      streamAlt (n, ty) = prodCon ty <+> text "x -> atomically (writeTSampleVar"
        <+> parens (fieldName n <+> text "state")
        <+> text "x)"
      -- pass every non-stream message on to the filtered queue
      fallthrough = text "notStream -> atomically (writeTQueue filtered notStream)"
  -- | The @State@ record, with one @TSampleVar@ field per stream. Returns
  -- the rendered type name and the data declaration.
  --
  -- It does not special-case an empty stream list; callers decide whether
  -- a state type is needed at all.
  stateDef :: [(MethodName,Type)] -> (Doc,Doc)
  stateDef streams = (text "State", decl)
    where
      decl = nest 2 (text "data State = State" <+> braces body)
      body = align (stack (commas (map field streams)))
      field (name, ty) =
        spread [ fieldName name, text "::", text "TSampleVar", text (typeModuleName ty) ]
  -- | Definition of @mkState :: IO State@. It binds each stream field with
  -- @newTSampleVarIO@ and builds the record with a @State { .. }@ wildcard,
  -- which needs @RecordWildCards@.
  mkStateDef :: [(MethodName,Type)] -> Doc
  mkStateDef streams = stack [signature, definition]
    where
      signature  = text "mkState :: IO State"
      definition = nest 2 (text "mkState =" </> nest 3 (doStmts binds))
      binds = map (\(n, _) -> fieldName n <+> text "<- newTSampleVarIO") streams
              ++ [ text "return State { .. }" ]
  -- | Field name in the generated @State@ record for a stream method:
  -- @stream_@ followed by the method name.
  fieldName :: MethodName -> Doc
  fieldName name = text ("stream_" ++ name)
  317. -- Pretty-printing Helpers -----------------------------------------------------
  -- | Join types with arrows, so @[a, b, c]@ renders as @a -> b -> c@.
  --
  -- 'punctuate' glues the separator directly onto the preceding element,
  -- so the separator carries its own leading space. With a bare @"->"@ the
  -- output was @a-> b-> c@.
  arrow :: [Doc] -> Doc
  arrow ts = spread (punctuate (text " ->") ts)
  -- | Append a comma to every document except the last.
  commas :: [Doc] -> [Doc]
  commas ds = punctuate comma ds
  -- | Concatenate documents separated by dots, with no surrounding spaces.
  dots :: [Doc] -> Doc
  dots ds = cat (punctuate dot ds)
  -- | Render a hierarchical module name from its components.
  ppModName :: [String] -> Doc
  ppModName parts = dots [ text p | p <- parts ]
  -- | Render statements as a @do@ block, aligned after the @do@ keyword.
  -- A single statement is emitted bare, without @do@. Each statement is
  -- nested by 2 so its continuation lines stay inside the layout block.
  doStmts :: [Doc] -> Doc
  doStmts ds = case ds of
    [single] -> nest 2 single
    _        -> text "do" <+> align (stack [ nest 2 d | d <- ds ])