fennel.lua 82 KB


  1. --[[
  2. Copyright (c) 2016-2018 Calvin Rose and contributors
  3. Permission is hereby granted, free of charge, to any person obtaining a copy of
  4. this software and associated documentation files (the "Software"), to deal in
  5. the Software without restriction, including without limitation the rights to
  6. use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
  7. the Software, and to permit persons to whom the Software is furnished to do so,
  8. subject to the following conditions:
  9. The above copyright notice and this permission notice shall be included in all
  10. copies or substantial portions of the Software.
  11. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  12. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  13. FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
  14. COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  15. IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  16. CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  17. ]]
  -- Bind frequently used globals to locals: local access is faster, and the
  -- compiler keeps working even if user code later reassigns the globals.
  local setmetatable, getmetatable = setmetatable, getmetatable
  local type, assert, tostring = type, assert, tostring
  local pairs, ipairs = pairs, ipairs
  -- Lua 5.1 / LuaJIT provide the global `unpack`; 5.2+ moved it into `table`.
  local unpack = unpack or table.unpack
  --
  -- Main Types and support functions
  --
  -- __tostring for symbols and exprs: their payload string lives at index 1.
  local function deref(self) return self[1] end
  local SYMBOL_MT = { 'SYMBOL', __tostring = deref }
  local EXPR_MT = { 'EXPR', __tostring = deref }
  -- The single shared value that stands for `...`.
  local VARARG = setmetatable({ '...' }, { 'VARARG', __tostring = deref })
  local LIST_MT = { 'LIST',
    -- Render a list as "(a, b, c)". Walk indices 1..#self directly instead of
    -- using ipairs: ipairs stops at the first nil, which previously left
    -- `strs` shorter than the range handed to table.concat and made printing
    -- a list with a nil hole raise an error. Holes now print as "nil".
    __tostring = function (self)
      local strs = {}
      for i = 1, #self do
        strs[i] = tostring(self[i])
      end
      return '(' .. table.concat(strs, ', ') .. ')'
    end
  }
  local SEQUENCE_MT = { 'SEQUENCE' }
  -- Compile a string of Lua source into a function whose globals resolve in
  -- `environment` (default: _ENV when available, otherwise _G). On Lua 5.1 /
  -- LuaJIT this uses loadstring + setfenv; on Lua 5.2+ it uses load with an
  -- env argument (text chunks only). `filename` is the chunk name. Raises an
  -- error (via assert) if the code fails to compile.
  local function loadCode(code, environment, filename)
    local env = environment or _ENV or _G
    if not (setfenv and loadstring) then
      return assert(load(code, filename, "t", env))
    end
    local chunk = assert(loadstring(code, filename))
    setfenv(chunk, env)
    return chunk
  end
  -- Create a new list: the given values in a table tagged with LIST_MT.
  local function list(...)
    local items = {...}
    return setmetatable(items, LIST_MT)
  end
  -- Create a new symbol named `str`. Every string-keyed entry of the optional
  -- `meta` table (the parser passes line, filename, bytestart, byteend) is
  -- copied onto the symbol; non-string keys are ignored.
  local function sym(str, scope, meta)
    local s = {str, scope = scope}
    for key, value in pairs(meta or {}) do
      if type(key) == 'string' then
        s[key] = value
      end
    end
    return setmetatable(s, SYMBOL_MT)
  end
  -- Create a new sequence (the value produced by a [] literal).
  local function sequence(...)
    local items = {...}
    return setmetatable(items, SEQUENCE_MT)
  end
  -- Create a new expr: a fragment of generated Lua code plus its kind.
  -- etype is one of
  --   "literal"    - literals like numbers, strings, nil, true, false
  --   "expression" - complex Lua code that may have side effects, but is an expression
  --   "statement"  - same as expression, but also valid as a statement (function calls)
  --   "varg"       - the varargs symbol (compile1 uses 'varg')
  --   "sym"        - symbol reference
  local function expr(strcode, etype)
    local e = { strcode, type = etype }
    return setmetatable(e, EXPR_MT)
  end
  -- Return the shared VARARG value (what the parser produces for `...`).
  local function varg()
    return VARARG
  end
  -- Returns x if it is the VARARG value, false otherwise.
  local function isVarg(x)
    if x == VARARG then
      return x
    end
    return false
  end
  -- Checks if an object is a List. Returns the object if it is, false otherwise.
  local function isList(x)
    if type(x) ~= 'table' or getmetatable(x) ~= LIST_MT then
      return false
    end
    return x
  end
  -- Checks if an object is a symbol. Returns the object if it is, false otherwise.
  local function isSym(x)
    if type(x) ~= 'table' or getmetatable(x) ~= SYMBOL_MT then
      return false
    end
    return x
  end
  -- Checks if an object is any kind of table EXCEPT a list, a symbol, or the
  -- VARARG value (plain tables and sequences pass). Returns the object or false.
  local function isTable(x)
    if type(x) ~= 'table' or x == VARARG then
      return false
    end
    local mt = getmetatable(x)
    if mt == LIST_MT or mt == SYMBOL_MT then
      return false
    end
    return x
  end
  -- Checks if an object is a sequence (created with a [] literal).
  -- Returns the object if it is, false otherwise.
  local function isSequence(x)
    if type(x) ~= 'table' or getmetatable(x) ~= SEQUENCE_MT then
      return false
    end
    return x
  end
  --
  -- Parser
  --
  -- Convert a stream of chunks into a stream of bytes.
  -- getchunk(parserState) returns the next string chunk; nil/false or '' ends
  -- the stream for good. Returns two functions: the byte reader (which passes
  -- its parserState argument through to getchunk) and a function that
  -- discards whatever is left of the current chunk.
  local function granulate(getchunk)
    local chunk, pos, exhausted = '', 1, false
    local function nextByte(parserState)
      if exhausted then return nil end
      if pos > #chunk then
        chunk = getchunk(parserState)
        if not chunk or chunk == '' then
          exhausted = true
          return nil
        end
        pos = 1
      end
      local byte = chunk:byte(pos)
      pos = pos + 1
      return byte
    end
    local function clearBuffer()
      chunk = ''
    end
    return nextByte, clearBuffer
  end
  -- Turn a string into a byte reader: each call returns the next byte as a
  -- number, then nil once the string is used up.
  local function stringStream(str)
    local pos = 0
    return function()
      pos = pos + 1
      local byte = str:byte(pos)
      return byte
    end
  end
  -- Delimiter bytes for (), [] and {}: each opener maps to the byte of its
  -- closer, and each closer maps to true.
  local delims = {}
  for _, pair in ipairs({'()', '[]', '{}'}) do
    local opener, closer = pair:byte(1, 2)
    delims[opener] = closer
    delims[closer] = true
  end
  -- Whitespace: space, the control range 9-13 (tab through carriage return),
  -- and the comma, which the reader treats as whitespace.
  local function iswhitespace(b)
    return b == 44 or b == 32 or (9 <= b and b <= 13)
  end
  -- True if byte b may appear inside a symbol: anything above space except
  -- delimiters, '"' (34), "'" (39), ',' (44), ';' (59), '`' (96) and
  -- DEL (127).
  local function issymbolchar(b)
    if b <= 32 or delims[b] then
      return false
    end
    return b ~= 34 and b ~= 39 and b ~= 44 and b ~= 59 and b ~= 96 and b ~= 127
  end
  -- Prefix bytes that the reader expands into a wrapping form,
  -- e.g. `a becomes (quote a) and @a becomes (unquote a).
  local prefixes = {}
  prefixes[96] = 'quote'   -- `
  prefixes[64] = 'unquote' -- @
  -- Parse one value at a time from a stream of bytes.
  --
  -- getbyte(state) must return the next byte as a number, or nil at the end
  -- of input; `state` is a table { stackSize = <number of unfinished forms> }.
  -- filename is used in error messages and in AST position metadata.
  --
  -- Returns two functions:
  --   * a reader: each call returns `true, value` once a complete top-level
  --     value has been read, or nil when the stream is finished. On bad input
  --     it raises "<message> in <filename>:<line>" as soon as the problem is
  --     seen, without reading further bytes.
  --   * a reset function that discards any partially read forms.
  local function parser(getbyte, filename)
    -- Stack of unfinished values: open lists/sequences/tables, pending prefix
    -- markers ({ prefix = 'quote' }), and the string currently being read.
    local stack = {}
    -- One byte of push-back, plus current line and byte position.
    local line = 1
    local byteindex = 0
    local lastb
    local function ungetb(ub)
      if ub == 10 then line = line - 1 end
      byteindex = byteindex - 1
      lastb = ub
    end
    local function getb()
      local r
      if lastb then
        r, lastb = lastb, nil
      else
        r = getbyte({ stackSize = #stack })
      end
      byteindex = byteindex + 1
      if r == 10 then line = line + 1 end
      return r
    end
    local function parseError(msg)
      return error(msg .. ' in ' .. (filename or 'unknown') .. ':' .. line, 0)
    end
    -- Convert an array of bytes to a string. A single string.char(unpack(bytes))
    -- fails with "too many results to unpack" on long tokens (the limit is
    -- about 8000 values on Lua 5.1 / LuaJIT), so convert in bounded slices.
    local function bytesToString(bytes)
      local count = #bytes
      local parts = {}
      for i = 1, count, 4096 do
        parts[#parts + 1] = string.char(unpack(bytes, i, math.min(i + 4095, count)))
      end
      return table.concat(parts)
    end
    -- Parse stream
    return function()
      -- Dispatch when we complete a value
      local done, retval
      local function dispatch(v)
        if #stack == 0 then
          retval = v
          done = true
        elseif stack[#stack].prefix then
          -- A pending prefix wraps the value: `a becomes (quote a).
          local stacktop = stack[#stack]
          stack[#stack] = nil
          return dispatch(list(sym(stacktop.prefix), v))
        else
          table.insert(stack[#stack], v)
        end
      end
      -- Throw nice error when we expect more characters
      -- but reach end of stream.
      local function badend()
        local accum = {}
        for _, item in ipairs(stack) do
          -- prefix markers have no closer; appending nil adds nothing
          accum[#accum + 1] = item.closer
        end
        parseError(('expected closing delimiter%s %s'):format(
          #stack == 1 and "" or "s",
          bytesToString(accum)))
      end
      -- The main parse loop
      repeat
        local b
        -- Skip whitespace
        repeat
          b = getb()
        until not b or not iswhitespace(b)
        if not b then
          if #stack > 0 then badend() end
          return nil
        end
        if b == 59 then -- ; Comment
          repeat
            b = getb()
          until not b or b == 10 -- newline
        elseif type(delims[b]) == 'number' then -- Opening delimiter
          table.insert(stack, setmetatable({
            closer = delims[b],
            line = line,
            filename = filename,
            bytestart = byteindex
          }, LIST_MT))
        elseif delims[b] then -- Closing delimiter
          if #stack == 0 then parseError 'unexpected closing delimiter' end
          local last = stack[#stack]
          local val
          -- A prefix marker has no closer: it is still waiting for its form,
          -- so a closing delimiter here is an error (previously this fell
          -- through to string.char(nil) and raised an unrelated error).
          if not last.closer then
            parseError('unexpected closing delimiter ' .. string.char(b) ..
              ' after ' .. last.prefix .. ' prefix')
          end
          if last.closer ~= b then
            parseError('unexpected delimiter ' .. string.char(b) ..
              ', expected ' .. string.char(last.closer))
          end
          last.byteend = byteindex -- Set closing byte index
          if b == 41 then -- ; )
            val = last
          elseif b == 93 then -- ; ]
            val = sequence()
            for i = 1, #last do
              val[i] = last[i]
            end
          else -- ; }
            if #last % 2 ~= 0 then
              parseError('expected even number of values in table literal')
            end
            val = {}
            for i = 1, #last, 2 do
              val[last[i]] = last[i + 1]
            end
          end
          stack[#stack] = nil
          dispatch(val)
        elseif b == 34 or b == 39 then -- Quoted string
          local start = b
          local state = "base"
          local chars = {start}
          stack[#stack + 1] = {closer = b}
          repeat
            b = getb()
            chars[#chars + 1] = b
            if state == "base" then
              if b == 92 then
                state = "backslash"
              elseif b == start then
                state = "done"
              end
            else
              -- state == "backslash"
              state = "base"
            end
          until not b or (state == "done")
          if not b then badend() end
          stack[#stack] = nil
          local raw = bytesToString(chars)
          -- Escape raw control characters so the literal stays on one line
          -- for load(). Always emit three-digit decimal escapes: a shorter
          -- escape such as \9 followed by a literal digit (a tab then "1")
          -- would otherwise be read back as \91, a different character.
          local formatted = raw:gsub("[\1-\31]", function (c)
            return ('\\%03d'):format(c:byte())
          end)
          local loadFn = loadCode(('return %s'):format(formatted), nil, filename)
          dispatch(loadFn())
        elseif prefixes[b] then -- expand prefix byte into wrapping form eg. '`a' into '(quote a)'
          table.insert(stack, {
            prefix = prefixes[b]
          })
        else -- Try symbol
          local chars = {}
          local bytestart = byteindex
          repeat
            chars[#chars + 1] = b
            b = getb()
          until not b or not issymbolchar(b)
          if b then ungetb(b) end
          local rawstr = bytesToString(chars)
          if rawstr == 'true' then dispatch(true)
          elseif rawstr == 'false' then dispatch(false)
          elseif rawstr == '...' then dispatch(VARARG)
          elseif rawstr:match('^:.+$') then -- keyword style strings
            dispatch(rawstr:sub(2))
          else
            -- Tokens starting with a digit must be numbers; others become
            -- numbers if they parse as one, otherwise symbols. Underscores
            -- are ignored when reading numbers (1_000 == 1000).
            local forceNumber = rawstr:match('^%d')
            local numberWithStrippedUnderscores = rawstr:gsub("_", "")
            local x
            if forceNumber then
              x = tonumber(numberWithStrippedUnderscores) or
                parseError('could not read token "' .. rawstr .. '"')
            else
              x = tonumber(numberWithStrippedUnderscores) or
                sym(rawstr, nil, { line = line,
                  filename = filename,
                  bytestart = bytestart,
                  byteend = byteindex, })
            end
            dispatch(x)
          end
        end
      until done
      return true, retval
    end, function ()
      stack = {}
    end
  end
  --
  -- Compilation
  --
  -- Create a new Scope, optionally under a parent scope. Scopes are compile
  -- time constructs that track local variables, name mangling, specials, and
  -- symbol metadata. Each lookup table falls back to the parent's table
  -- through an __index metatable, so nested scopes inherit entries.
  -- User code can reach them through the '*compiler' special form (may change).
  local function makeScope(parent)
    local scope = { parent = parent }
    for _, field in ipairs({'unmanglings', 'manglings', 'specials', 'symmeta'}) do
      scope[field] = setmetatable({}, { __index = parent and parent[field] })
    end
    -- The vararg flag is copied from the parent at creation time.
    scope.vararg = parent and parent.vararg
    scope.depth = parent and ((parent.depth or 0) + 1) or 0
    return scope
  end
  369. depth = parent and ((parent.depth or 0) + 1) or 0
  370. }
  371. end
  372. -- Assert a condition and raise a compile error with line numbers. The ast arg
  373. -- should be unmodified so that its first element is the form being called.
  374. local function assertCompile(condition, msg, ast)
  375. -- if we use regular `assert' we can't provide the `level' argument of zero
  376. if not condition then
  377. error(string.format("Compile error in '%s' %s:%s: %s",
  378. isSym(ast[1]) and ast[1][1] or ast[1] or '()',
  379. ast.filename or "unknown", ast.line or '?', msg), 0)
  380. end
  381. return condition
  382. end
  383. local GLOBAL_SCOPE = makeScope()
  384. GLOBAL_SCOPE.vararg = true
  385. local SPECIALS = GLOBAL_SCOPE.specials
  386. local COMPILER_SCOPE = makeScope(GLOBAL_SCOPE)
  387. local luaKeywords = {
  388. 'and', 'break', 'do', 'else', 'elseif', 'end', 'false', 'for', 'function',
  389. 'if', 'in', 'local', 'nil', 'not', 'or', 'repeat', 'return', 'then', 'true',
  390. 'until', 'while'
  391. }
  392. for i, v in ipairs(luaKeywords) do
  393. luaKeywords[v] = i
  394. end
  395. local function isValidLuaIdentifier(str)
  396. return (str:match('^[%a_][%w_]*$') and not luaKeywords[str])
  397. end
  398. -- Allow printing a string to Lua, also keep as 1 line.
  399. local serializeSubst = {
  400. ['\a'] = '\\a',
  401. ['\b'] = '\\b',
  402. ['\f'] = '\\f',
  403. ['\n'] = 'n',
  404. ['\t'] = '\\t',
  405. ['\v'] = '\\v'
  406. }
  407. local function serializeString(str)
  408. local s = ("%q"):format(str)
  409. s = s:gsub('.', serializeSubst):gsub("[\128-\255]", function(c)
  410. return "\\" .. c:byte()
  411. end)
  412. return s
  413. end
  414. -- A multi symbol is a symbol that is actually composed of
  415. -- two or more symbols using the dot syntax. The main differences
  416. -- from normal symbols is that they cannot be declared local, and
  417. -- they may have side effects on invocation (metatables)
  418. local function isMultiSym(str)
  419. if type(str) ~= 'string' then return end
  420. local parts = {}
  421. for part in str:gmatch('[^%.]+') do
  422. parts[#parts + 1] = part
  423. end
  424. return #parts > 0 and
  425. str:match('%.') and
  426. (not str:match('%.%.')) and
  427. str:byte() ~= string.byte '.' and
  428. str:byte(-1) ~= string.byte '.' and
  429. parts
  430. end
  431. -- Mangler for global symbols. Does not protect against collisions,
  432. -- but makes them unlikely. This is the mangling that is exposed to
  433. -- to the world.
  434. local function globalMangling(str)
  435. if isValidLuaIdentifier(str) then
  436. return str
  437. end
  438. -- Use underscore as escape character
  439. return '__fnl_global__' .. str:gsub('[^%w]', function (c)
  440. return ('_%02x'):format(c:byte())
  441. end)
  442. end
  443. -- Reverse a global mangling. Takes a Lua identifier and
  444. -- returns the fennel symbol string that created it.
  445. local function globalUnmangling(identifier)
  446. local rest = identifier:match('^__fnl_global__(.*)$')
  447. if rest then
  448. return rest:gsub('_[%da-f][%da-f]', function (code)
  449. return string.char(tonumber(code:sub(2), 16))
  450. end)
  451. else
  452. return identifier
  453. end
  454. end
  455. -- Creates a symbol from a string by mangling it.
  456. -- ensures that the generated symbol is unique
  457. -- if the input string is unique in the scope.
  458. local function localMangling(str, scope, ast)
  459. if scope.manglings[str] then
  460. return scope.manglings[str]
  461. end
  462. local append = 0
  463. local mangling = str
  464. assertCompile(not isMultiSym(str), 'did not expect multi symbol ' .. str, ast)
  465. -- Mapping mangling to a valid Lua identifier
  466. if luaKeywords[mangling] or mangling:match('^%d') then
  467. mangling = '_' .. mangling
  468. end
  469. mangling = mangling:gsub('-', '_')
  470. mangling = mangling:gsub('[^%w_]', function (c)
  471. return ('_%02x'):format(c:byte())
  472. end)
  473. local raw = mangling
  474. while scope.unmanglings[mangling] do
  475. mangling = raw .. append
  476. append = append + 1
  477. end
  478. scope.unmanglings[mangling] = str
  479. scope.manglings[str] = mangling
  480. return mangling
  481. end
  482. -- Combine parts of a symbol
  483. local function combineParts(parts, scope)
  484. local ret = scope.manglings[parts[1]] or globalMangling(parts[1])
  485. for i = 2, #parts do
  486. if isValidLuaIdentifier(parts[i]) then
  487. ret = ret .. '.' .. parts[i]
  488. else
  489. ret = ret .. '[' .. serializeString(parts[i]) .. ']'
  490. end
  491. end
  492. return ret
  493. end
  494. -- Generates a unique symbol in the scope.
  495. local function gensym(scope)
  496. local mangling
  497. local append = 0
  498. repeat
  499. mangling = '_' .. append .. '_'
  500. append = append + 1
  501. until not scope.unmanglings[mangling]
  502. scope.unmanglings[mangling] = true
  503. return mangling
  504. end
  505. -- Check if a binding is valid
  506. local function checkBindingValid(symbol, scope, ast)
  507. -- Check if symbol will be over shadowed by special
  508. local name = symbol[1]
  509. assertCompile(not scope.specials[name],
  510. ("symbol %s may be overshadowed by a special form or macro"):format(name), ast)
  511. end
  512. -- Declare a local symbol
  513. local function declareLocal(symbol, meta, scope, ast)
  514. checkBindingValid(symbol, scope, ast)
  515. local name = symbol[1]
  516. assertCompile(not isMultiSym(name), "did not expect mutltisym", ast)
  517. local mangling = localMangling(name, scope, ast)
  518. scope.symmeta[name] = meta
  519. return mangling
  520. end
  521. -- If there's a provided list of allowed globals, don't let references
  522. -- thru that aren't on the list. This list is set at the compiler
  523. -- entry points of compile and compileStream.
  524. local allowedGlobals
  525. local function globalAllowed(name)
  526. if not allowedGlobals then return true end
  527. for _, g in ipairs(allowedGlobals) do
  528. if g == name then return true end
  529. end
  530. end
  531. -- Convert symbol to Lua code. Will only work for local symbols
  532. -- if they have already been declared via declareLocal
  533. local function symbolToExpression(symbol, scope, isReference)
  534. local name = symbol[1]
  535. local parts = isMultiSym(name) or {name}
  536. local etype = (#parts > 1) and "expression" or "sym"
  537. local isLocal = scope.manglings[parts[1]]
  538. -- if it's a reference and not a symbol which introduces a new binding
  539. -- then we need to check for allowed globals
  540. assertCompile(not isReference or isLocal or globalAllowed(parts[1]),
  541. 'unknown global in strict mode: ' .. parts[1], symbol)
  542. return expr(combineParts(parts, scope), etype)
  543. end
  544. -- Emit Lua code
  545. local function emit(chunk, out, ast)
  546. if type(out) == 'table' then
  547. table.insert(chunk, out)
  548. else
  549. table.insert(chunk, {leaf = out, ast = ast})
  550. end
  551. end
  552. -- Do some peephole optimization.
  553. local function peephole(chunk)
  554. if chunk.leaf then return chunk end
  555. -- Optimize do ... end in some cases.
  556. if #chunk == 3 and
  557. chunk[1].leaf == 'do' and
  558. not chunk[2].leaf and
  559. chunk[3].leaf == 'end' then
  560. return peephole(chunk[2])
  561. end
  562. -- Recurse
  563. for i, v in ipairs(chunk) do
  564. chunk[i] = peephole(v)
  565. end
  566. return chunk
  567. end
  568. -- correlate line numbers in input with line numbers in output
  569. local function flattenChunkCorrelated(mainChunk)
  570. local function flatten(chunk, out, lastLine, file)
  571. if chunk.leaf then
  572. out[lastLine] = (out[lastLine] or "") .. " " .. chunk.leaf
  573. else
  574. for _, subchunk in ipairs(chunk) do
  575. -- Ignore empty chunks
  576. if subchunk.leaf or #subchunk > 0 then
  577. -- don't increase line unless it's from the same file
  578. if subchunk.ast and file == subchunk.ast.file then
  579. lastLine = math.max(lastLine, subchunk.ast.line or 0)
  580. end
  581. lastLine = flatten(subchunk, out, lastLine, file)
  582. end
  583. end
  584. end
  585. return lastLine
  586. end
  587. local out = {}
  588. local last = flatten(mainChunk, out, 1, mainChunk.file)
  589. for i = 1, last do
  590. if out[i] == nil then out[i] = "" end
  591. end
  592. return table.concat(out, "\n")
  593. end
  594. -- Flatten a tree of indented Lua source code lines.
  595. -- Tab is what is used to indent a block.
  596. local function flattenChunk(sm, chunk, tab, depth)
  597. if type(tab) == 'boolean' then tab = tab and ' ' or '' end
  598. if chunk.leaf then
  599. local code = chunk.leaf
  600. local info = chunk.ast
  601. -- Just do line info for now to save memory
  602. if sm then sm[#sm + 1] = info and info.line or -1 end
  603. return code
  604. else
  605. local parts = {}
  606. for i = 1, #chunk do
  607. -- Ignore empty chunks
  608. if chunk[i].leaf or #(chunk[i]) > 0 then
  609. local sub = flattenChunk(sm, chunk[i], tab, depth + 1)
  610. if depth > 0 then sub = tab .. sub:gsub('\n', '\n' .. tab) end
  611. table.insert(parts, sub)
  612. end
  613. end
  614. return table.concat(parts, '\n')
  615. end
  616. end
  617. -- Some global state for all fennel sourcemaps. For the time being,
  618. -- this seems the easiest way to store the source maps.
  619. -- Sourcemaps are stored with source being mapped as the key, prepended
  620. -- with '@' if it is a filename (like debug.getinfo returns for source).
  621. -- The value is an array of mappings for each line.
  622. local fennelSourcemap = {}
  623. -- TODO: loading, unloading, and saving sourcemaps?
  624. local function makeShortSrc(source)
  625. source = source:gsub('\n', ' ')
  626. if #source <= 49 then
  627. return '[fennel "' .. source .. '"]'
  628. else
  629. return '[fennel "' .. source:sub(1, 46) .. '..."]'
  630. end
  631. end
  632. -- Return Lua source and source map table
  633. local function flatten(chunk, options)
  634. local sm = options.sourcemap and {}
  635. chunk = peephole(chunk)
  636. if(options.correlate) then
  637. return flattenChunkCorrelated(chunk), {}
  638. else
  639. local ret = flattenChunk(sm, chunk, options.indent, 0)
  640. if sm then
  641. local key, short_src
  642. if options.filename then
  643. short_src = options.filename
  644. key = '@' .. short_src
  645. else
  646. key = ret
  647. short_src = makeShortSrc(options.source or ret)
  648. end
  649. sm.short_src = short_src
  650. sm.key = key
  651. fennelSourcemap[key] = sm
  652. end
  653. return ret, sm
  654. end
  655. end
  656. -- Convert expressions to Lua string
  657. local function exprs1(exprs)
  658. local t = {}
  659. for _, e in ipairs(exprs) do
  660. t[#t + 1] = e[1]
  661. end
  662. return table.concat(t, ', ')
  663. end
  664. -- Compile side effects for a chunk
  665. local function keepSideEffects(exprs, chunk, start, ast)
  666. start = start or 1
  667. for j = start, #exprs do
  668. local se = exprs[j]
  669. -- Avoid the rogue 'nil' expression (nil is usually a literal,
  670. -- but becomes an expression if a special form
  671. -- returns 'nil'.)
  672. if se.type == 'expression' and se[1] ~= 'nil' then
  673. emit(chunk, ('do local _ = %s end'):format(tostring(se)), ast)
  674. elseif se.type == 'statement' then
  675. emit(chunk, tostring(se), ast)
  676. end
  677. end
  678. end
  679. -- Does some common handling of returns and register
  680. -- targets for special forms. Also ensures a list expression
  681. -- has an acceptable number of expressions if opts contains the
  682. -- "nval" option.
  683. local function handleCompileOpts(exprs, parent, opts, ast)
  684. if opts.nval then
  685. local n = opts.nval
  686. if n ~= #exprs then
  687. local len = #exprs
  688. if len > n then
  689. -- Drop extra
  690. keepSideEffects(exprs, parent, n + 1, ast)
  691. for i = n + 1, len do
  692. exprs[i] = nil
  693. end
  694. else
  695. -- Pad with nils
  696. for i = #exprs + 1, n do
  697. exprs[i] = expr('nil', 'literal')
  698. end
  699. end
  700. end
  701. end
  702. if opts.tail then
  703. emit(parent, ('return %s'):format(exprs1(exprs)), ast)
  704. end
  705. if opts.target then
  706. emit(parent, ('%s = %s'):format(opts.target, exprs1(exprs)), ast)
  707. end
  708. if opts.tail or opts.target then
  709. -- Prevent statements and expression from being used twice if they
  710. -- have side-effects. Since if the target or tail options are set,
  711. -- the expressions are already emitted, we should not return them. This
  712. -- is fine, as when these options are set, the caller doesn't need the result
  713. -- anyways.
  714. exprs = {}
  715. end
  716. return exprs
  717. end
  718. -- Compile an AST expression in the scope into parent, a tree
  719. -- of lines that is eventually compiled into Lua code. Also
  720. -- returns some information about the evaluation of the compiled expression,
  721. -- which can be used by the calling function. Macros
  722. -- are resolved here, as well as special forms in that order.
  723. -- the 'ast' param is the root AST to compile
  724. -- the 'scope' param is the scope in which we are compiling
  725. -- the 'parent' param is the table of lines that we are compiling into.
  726. -- add lines to parent by appending strings. Add indented blocks by appending
  727. -- tables of more lines.
  728. -- the 'opts' param contains info about where the form is being compiled.
  729. -- Options include:
  730. -- 'target' - mangled name of symbol(s) being compiled to.
  731. -- Could be one variable, 'a', or a list, like 'a, b, _0_'.
  732. -- 'tail' - boolean indicating tail position if set. If set, form will generate a return
  733. -- instruction.
  734. -- 'nval' - The number of values to compile to if it is known to be a fixed value.
  735. local function compile1(ast, scope, parent, opts)
  736. opts = opts or {}
  737. local exprs = {}
  738. -- Compile the form
  739. if isList(ast) then
  740. -- Function call or special form
  741. local len = #ast
  742. assertCompile(len > 0, "expected a function to call", ast)
  743. -- Test for special form
  744. local first = ast[1]
  745. if isSym(first) then -- Resolve symbol
  746. first = first[1]
  747. end
  748. local special = scope.specials[first]
  749. if special and isSym(ast[1]) then
  750. -- Special form
  751. exprs = special(ast, scope, parent, opts) or expr('nil', 'literal')
  752. -- Be very accepting of strings or expression
  753. -- as well as lists or expressions
  754. if type(exprs) == 'string' then exprs = expr(exprs, 'expression') end
  755. if getmetatable(exprs) == EXPR_MT then exprs = {exprs} end
  756. -- Unless the special form explicitly handles the target, tail, and nval properties,
  757. -- (indicated via the 'returned' flag), handle these options.
  758. if not exprs.returned then
  759. exprs = handleCompileOpts(exprs, parent, opts, ast)
  760. elseif opts.tail or opts.target then
  761. exprs = {}
  762. end
  763. exprs.returned = true
  764. return exprs
  765. else
  766. -- Function call
  767. local fargs = {}
  768. local fcallee = compile1(ast[1], scope, parent, {
  769. nval = 1
  770. })[1]
  771. assertCompile(fcallee.type ~= 'literal',
  772. 'cannot call literal value', ast)
  773. fcallee = tostring(fcallee)
  774. for i = 2, len do
  775. local subexprs = compile1(ast[i], scope, parent, {
  776. nval = i ~= len and 1 or nil
  777. })
  778. fargs[#fargs + 1] = subexprs[1] or expr('nil', 'literal')
  779. if i == len then
  780. -- Add sub expressions to function args
  781. for j = 2, #subexprs do
  782. fargs[#fargs + 1] = subexprs[j]
  783. end
  784. else
  785. -- Emit sub expression only for side effects
  786. keepSideEffects(subexprs, parent, 2, ast[i])
  787. end
  788. end
  789. local call = ('%s(%s)'):format(tostring(fcallee), exprs1(fargs))
  790. exprs = handleCompileOpts({expr(call, 'statement')}, parent, opts, ast)
  791. end
  792. elseif isVarg(ast) then
  793. assertCompile(scope.vararg, "unexpected vararg", ast)
  794. exprs = handleCompileOpts({expr('...', 'varg')}, parent, opts, ast)
  795. elseif isSym(ast) then
  796. local e
  797. -- Handle nil as special symbol - it resolves to the nil literal rather than
  798. -- being unmangled. Alternatively, we could remove it from the lua keywords table.
  799. if ast[1] == 'nil' then
  800. e = expr('nil', 'literal')
  801. else
  802. e = symbolToExpression(ast, scope, true)
  803. end
  804. exprs = handleCompileOpts({e}, parent, opts, ast)
  805. elseif type(ast) == 'nil' or type(ast) == 'boolean' then
  806. exprs = handleCompileOpts({expr(tostring(ast), 'literal')}, parent, opts)
  807. elseif type(ast) == 'number' then
  808. local n = ('%.17g'):format(ast)
  809. exprs = handleCompileOpts({expr(n, 'literal')}, parent, opts)
  810. elseif type(ast) == 'string' then
  811. local s = serializeString(ast)
  812. exprs = handleCompileOpts({expr(s, 'literal')}, parent, opts)
  813. elseif type(ast) == 'table' then
  814. local buffer = {}
  815. for i = 1, #ast do -- Write numeric keyed values.
  816. local nval = i ~= #ast and 1
  817. buffer[#buffer + 1] = exprs1(compile1(ast[i], scope, parent, {nval = nval}))
  818. end
  819. local keys = {}
  820. for k, _ in pairs(ast) do -- Write other keys.
  821. if type(k) ~= 'number' or math.floor(k) ~= k or k < 1 or k > #ast then
  822. local kstr
  823. if type(k) == 'string' and isValidLuaIdentifier(k) then
  824. kstr = k
  825. else
  826. kstr = '[' .. tostring(compile1(k, scope, parent, {nval = 1})[1]) .. ']'
  827. end
  828. table.insert(keys, { kstr, k })
  829. end
  830. end
  831. table.sort(keys, function (a, b) return a[1] < b[1] end)
  832. for _, k in ipairs(keys) do
  833. local v = ast[k[2]]
  834. buffer[#buffer + 1] = ('%s = %s'):format(
  835. k[1], tostring(compile1(v, scope, parent, {nval = 1})[1]))
  836. end
  837. local tbl = '{' .. table.concat(buffer, ', ') ..'}'
  838. exprs = handleCompileOpts({expr(tbl, 'expression')}, parent, opts, ast)
  839. else
  840. assertCompile(false, 'could not compile value of type ' .. type(ast), ast)
  841. end
  842. exprs.returned = true
  843. return exprs
  844. end
  845. -- SPECIALS --
  846. -- For statements and expressions, put the value in a local to avoid
  847. -- double-evaluating it.
  848. local function once(val, ast, scope, parent)
  849. if val.type == 'statement' or val.type == 'expression' then
  850. local s = gensym(scope)
  851. emit(parent, ('local %s = %s'):format(s, tostring(val)), ast)
  852. return expr(s, 'sym')
  853. else
  854. return val
  855. end
  856. end
  857. -- Implements destructuring for forms like let, bindings, etc.
  858. -- Takes a number of options to control behavior.
  859. -- var: Whether or not to mark symbols as mutable
  860. -- declaration: begin each assignment with 'local' in output
  861. -- nomulti: disallow multisyms in the destructuring. Used for (local) and (global).
  862. -- noundef: Don't set undefined bindings. (set)
  863. -- forceglobal: Don't allow local bindings
  864. local function destructure(to, from, ast, scope, parent, opts)
  865. opts = opts or {}
  866. local isvar = opts.isvar
  867. local declaration = opts.declaration
  868. local nomulti = opts.nomulti
  869. local noundef = opts.noundef
  870. local forceglobal = opts.forceglobal
  871. local forceset = opts.forceset
  872. local setter = declaration and "local %s = %s" or "%s = %s"
  873. -- Get Lua source for symbol, and check for errors
  874. local function getname(symbol, up1)
  875. local raw = symbol[1]
  876. assertCompile(not (nomulti and isMultiSym(raw)),
  877. 'did not expect multisym', up1)
  878. if declaration then
  879. return declareLocal(symbol, {var = isvar}, scope, symbol)
  880. else
  881. local parts = isMultiSym(raw) or {raw}
  882. local meta = scope.symmeta[parts[1]]
  883. if #parts == 1 and not forceset then
  884. assertCompile(not(forceglobal and meta),
  885. 'expected global, found var', up1)
  886. assertCompile(meta or not noundef,
  887. 'expected local var ' .. parts[1], up1)
  888. assertCompile(not (meta and not meta.var),
  889. 'expected local var', up1)
  890. end
  891. return symbolToExpression(symbol, scope)[1]
  892. end
  893. end
  894. -- Compile the outer most form. We can generate better Lua in this case.
  895. local function compileTopTarget(lvalues)
  896. -- Calculate initial rvalue
  897. local inits = {}
  898. for _, x in ipairs(lvalues) do
  899. table.insert(inits, scope.manglings[x] and x or 'nil')
  900. end
  901. local init = table.concat(inits, ', ')
  902. local lvalue = table.concat(lvalues, ', ')
  903. local plen = #parent
  904. local ret = compile1(from, scope, parent, {target = lvalue})
  905. if declaration then
  906. if #parent == plen + 1 and parent[#parent].leaf then
  907. -- A single leaf emitted means an simple assignment a = x was emitted
  908. parent[#parent].leaf = 'local ' .. parent[#parent].leaf
  909. else
  910. table.insert(parent, plen + 1, { leaf = 'local ' .. lvalue .. ' = ' .. init, ast = ast})
  911. end
  912. end
  913. return ret
  914. end
  915. -- Recursive auxiliary function
  916. local function destructure1(left, rightexprs, up1, top)
  917. if isSym(left) and left[1] ~= "nil" then
  918. checkBindingValid(left, scope, left)
  919. local lname = getname(left, up1)
  920. if top then
  921. compileTopTarget({lname})
  922. else
  923. emit(parent, setter:format(lname, exprs1(rightexprs)), left)
  924. end
  925. elseif isTable(left) then -- table destructuring
  926. if top then rightexprs = compile1(from, scope, parent) end
  927. local s = gensym(scope)
  928. emit(parent, ("local %s = %s"):format(s, exprs1(rightexprs)), left)
  929. for k, v in pairs(left) do
  930. if isSym(left[k]) and left[k][1] == "&" then
  931. assertCompile(type(k) == "number" and not left[k+2],
  932. "expected rest argument in final position", left)
  933. local subexpr = expr(('{(table.unpack or unpack)(%s, %s)}'):format(s, k),
  934. 'expression')
  935. destructure1(left[k+1], {subexpr}, left)
  936. return
  937. else
  938. if type(k) ~= "number" then k = serializeString(k) end
  939. local subexpr = expr(('%s[%s]'):format(s, k), 'expression')
  940. destructure1(v, {subexpr}, left)
  941. end
  942. end
  943. elseif isList(left) then -- values destructuring
  944. local leftNames, tables = {}, {}
  945. for i, name in ipairs(left) do
  946. local symname
  947. if isSym(name) then -- binding directly to a name
  948. symname = getname(name, up1)
  949. else -- further destructuring of tables inside values
  950. symname = gensym(scope)
  951. tables[i] = {name, expr(symname, 'sym')}
  952. end
  953. table.insert(leftNames, symname)
  954. end
  955. if top then
  956. compileTopTarget(leftNames)
  957. else
  958. local lvalue = table.concat(leftNames, ', ')
  959. emit(parent, setter:format(lvalue, exprs1(rightexprs)), left)
  960. end
  961. for _, pair in pairs(tables) do -- recurse if left-side tables found
  962. destructure1(pair[1], {pair[2]}, left)
  963. end
  964. else
  965. assertCompile(false, 'unable to destructure ' .. tostring(left), up1)
  966. end
  967. if top then return {returned = true} end
  968. end
  969. return destructure1(to, nil, ast, true)
  970. end
  971. -- Unlike most expressions and specials, 'values' resolves with multiple
  972. -- values, one for each argument, allowing multiple return values. The last
  973. -- expression can return multiple arguments as well, allowing for more than the number
  974. -- of expected arguments.
  975. local function values(ast, scope, parent)
  976. local len = #ast
  977. local exprs = {}
  978. for i = 2, len do
  979. local subexprs = compile1(ast[i], scope, parent, {
  980. nval = (i ~= len) and 1
  981. })
  982. exprs[#exprs + 1] = subexprs[1]
  983. if i == len then
  984. for j = 2, #subexprs do
  985. exprs[#exprs + 1] = subexprs[j]
  986. end
  987. end
  988. end
  989. return exprs
  990. end
  991. -- Compile a list of forms for side effects
  992. local function compileDo(ast, scope, parent, start)
  993. start = start or 2
  994. local len = #ast
  995. local subScope = makeScope(scope)
  996. for i = start, len do
  997. compile1(ast[i], subScope, parent, {
  998. nval = 0
  999. })
  1000. end
  1001. end
-- Implements a do statement, starting at the 'start' element. By default, start is 2.
-- Also used by `let`, which passes a pre-filled `chunk` (its bindings) and the
-- `subScope` those bindings were declared in.
-- ast: the form; forms ast[start..#ast] are the body.
-- opts: compile options for the whole form (target / tail / nval).
-- chunk: buffer the body is compiled into (default: a new table).
-- subScope: scope for the body (default: a new child of `scope`).
-- Returns retexprs. In the plain `do` case it is just {returned = true}. When
-- opts.nval is a positive count, it is that table plus one 'sym' expr per
-- generated local. When there is no target, no tail, and no nval, it is a
-- 'statement' expr calling the generated local function.
local function doImpl(ast, scope, parent, opts, start, chunk, subScope)
  start = start or 2
  subScope = subScope or makeScope(scope)
  chunk = chunk or {}
  local len = #ast
  local outerTarget = opts.target
  local outerTail = opts.tail
  local retexprs = {returned = true}
  -- See if we need special handling to get the return values
  -- of the do block
  if not outerTarget and opts.nval ~= 0 and not outerTail then
    if opts.nval then
      -- Generate a local target
      -- (`local a, b` is emitted before `do` so the last body form can
      -- assign into those names and they stay visible after `end`)
      local syms = {}
      for i = 1, opts.nval do
        local s = gensym(scope)
        syms[i] = s
        retexprs[i] = expr(s, 'sym')
      end
      outerTarget = table.concat(syms, ', ')
      emit(parent, ('local %s'):format(outerTarget), ast)
      emit(parent, 'do', ast)
    else
      -- We will use an IIFE for the do
      -- (the body is compiled in tail position so it returns its values;
      -- `...` is forwarded when the enclosing scope is vararg)
      local fname = gensym(scope)
      local fargs = scope.vararg and '...' or ''
      emit(parent, ('local function %s(%s)'):format(fname, fargs), ast)
      retexprs = expr(fname .. '(' .. fargs .. ')', 'statement')
      outerTail = true
      outerTarget = nil
    end
  else
    emit(parent, 'do', ast)
  end
  -- Compile the body
  if start > len then
    -- In the unlikely case we do a do with no arguments.
    compile1(nil, subScope, chunk, {
      tail = outerTail,
      target = outerTarget
    })
    -- There will be no side effects
  else
    for i = start, len do
      -- Only the last form receives the target/tail/nval; earlier forms are
      -- compiled for side effects only (nval = 0).
      local subopts = {
        nval = i ~= len and 0 or opts.nval,
        tail = i == len and outerTail or nil,
        target = i == len and outerTarget or nil
      }
      local subexprs = compile1(ast[i], subScope, chunk, subopts)
      if i ~= len then
        -- NOTE(review): these are kept in `parent`, i.e. ahead of everything
        -- in `chunk`; presumably compile1 with nval = 0 leaves nothing here
        -- in practice -- confirm.
        keepSideEffects(subexprs, parent, nil, ast[i])
      end
    end
  end
  emit(parent, chunk, ast)
  emit(parent, 'end', ast)
  return retexprs
end
SPECIALS['do'] = doImpl
SPECIALS['values'] = values
-- The fn special declares a function. Syntax is similar to other lisps;
-- (fn optional-name [arg ...] (body))
-- Further decoration such as docstrings, meta info, and multibody functions a possibility.
-- Naming: a plain symbol name is declared as a local in `scope` and emitted
-- as `local function name(...)`; a multisym name (e.g. a.b) is assigned with
-- `name = function(...)`; with no name, a gensym'd local function is used.
-- Parameters may be symbols, a trailing `...`, or table patterns, which are
-- bound to a gensym'd parameter and destructured at the top of the body.
-- The last body form is compiled in tail position; earlier ones with nval = 0.
-- Returns a 'sym' expr naming the function.
SPECIALS['fn'] = function(ast, scope, parent)
  local fScope = makeScope(scope)
  local fChunk = {}
  local index = 2
  local fnName = isSym(ast[index])
  local isLocalFn
  fScope.vararg = false
  if fnName and fnName[1] ~= 'nil' then
    isLocalFn = not isMultiSym(fnName[1])
    if isLocalFn then
      fnName = declareLocal(fnName, {}, scope, ast)
    else
      fnName = symbolToExpression(fnName, scope)[1]
    end
    index = index + 1
  else
    isLocalFn = true
    fnName = gensym(scope)
  end
  local argList = assertCompile(isTable(ast[index]),
    'expected vector arg list [a b ...]', ast)
  local argNameList = {}
  for i = 1, #argList do
    if isVarg(argList[i]) then
      assertCompile(i == #argList, "expected vararg in last parameter position", ast)
      argNameList[i] = '...'
      fScope.vararg = true
    elseif(isSym(argList[i]) and argList[i][1] ~= "nil"
      and not isMultiSym(argList[i][1])) then
      argNameList[i] = declareLocal(argList[i], {}, fScope, ast)
    elseif isTable(argList[i]) then
      -- Table pattern: take the argument under a generated name, then
      -- destructure it into the function body chunk.
      local raw = sym(gensym(scope))
      argNameList[i] = declareLocal(raw, {}, fScope, ast)
      destructure(argList[i], raw, ast, fScope, fChunk,
        { declaration = true, nomulti = true })
    else
      assertCompile(false, 'expected symbol for function parameter', ast)
    end
  end
  for i = index + 1, #ast do
    compile1(ast[i], fScope, fChunk, {
      tail = i == #ast,
      nval = i ~= #ast and 0 or nil
    })
  end
  if isLocalFn then
    emit(parent, ('local function %s(%s)')
      :format(fnName, table.concat(argNameList, ', ')), ast)
  else
    emit(parent, ('%s = function(%s)')
      :format(fnName, table.concat(argNameList, ', ')), ast)
  end
  emit(parent, fChunk, ast)
  emit(parent, 'end', ast)
  return expr(fnName, 'sym')
end
  1123. -- (lua "print('hello!')") -> prints hello, evaluates to nil
  1124. -- (lua "print 'hello!'" "10") -> prints hello, evaluates to the number 10
  1125. -- (lua nil "{1,2,3}") -> Evaluates to a table literal
  1126. SPECIALS['lua'] = function(ast, _, parent)
  1127. assertCompile(#ast == 2 or #ast == 3,
  1128. "expected 2 or 3 arguments in 'lua' special form", ast)
  1129. if ast[2] ~= nil then
  1130. table.insert(parent, {leaf = tostring(ast[2]), ast = ast})
  1131. end
  1132. if #ast == 3 then
  1133. return tostring(ast[3])
  1134. end
  1135. end
  1136. -- Wrapper for table access
  1137. SPECIALS['.'] = function(ast, scope, parent)
  1138. local len = #ast
  1139. assertCompile(len > 1, "expected table argument", ast)
  1140. local lhs = compile1(ast[2], scope, parent, {nval = 1})
  1141. if len == 2 then
  1142. return tostring(lhs[1])
  1143. else
  1144. local indices = {}
  1145. for i = 3, len do
  1146. local index = ast[i]
  1147. if type(index) == 'string' and isValidLuaIdentifier(index) then
  1148. table.insert(indices, '.' .. index)
  1149. else
  1150. index = compile1(index, scope, parent, {nval = 1})[1]
  1151. table.insert(indices, '[' .. tostring(index) .. ']')
  1152. end
  1153. end
  1154. -- extra parens are needed for table literals
  1155. if isTable(ast[2]) then
  1156. return '(' .. tostring(lhs[1]) .. ')' .. table.concat(indices)
  1157. else
  1158. return tostring(lhs[1]) .. table.concat(indices)
  1159. end
  1160. end
  1161. end
  1162. SPECIALS['global'] = function(ast, scope, parent)
  1163. assertCompile(#ast == 3, "expected name and value", ast)
  1164. if allowedGlobals then table.insert(allowedGlobals, ast[2][1]) end
  1165. destructure(ast[2], ast[3], ast, scope, parent, {
  1166. nomulti = true,
  1167. forceglobal = true
  1168. })
  1169. end
  1170. SPECIALS['set'] = function(ast, scope, parent)
  1171. assertCompile(#ast == 3, "expected name and value", ast)
  1172. destructure(ast[2], ast[3], ast, scope, parent, {
  1173. noundef = true
  1174. })
  1175. end
  1176. SPECIALS['set-forcibly!'] = function(ast, scope, parent)
  1177. assertCompile(#ast == 3, "expected name and value", ast)
  1178. destructure(ast[2], ast[3], ast, scope, parent, {
  1179. forceset = true
  1180. })
  1181. end
  1182. SPECIALS['local'] = function(ast, scope, parent)
  1183. assertCompile(#ast == 3, "expected name and value", ast)
  1184. destructure(ast[2], ast[3], ast, scope, parent, {
  1185. declaration = true,
  1186. nomulti = true
  1187. })
  1188. end
  1189. SPECIALS['var'] = function(ast, scope, parent)
  1190. assertCompile(#ast == 3, "expected name and value", ast)
  1191. destructure(ast[2], ast[3], ast, scope, parent, {
  1192. declaration = true,
  1193. nomulti = true,
  1194. isvar = true
  1195. })
  1196. end
  1197. SPECIALS['let'] = function(ast, scope, parent, opts)
  1198. local bindings = ast[2]
  1199. assertCompile(isList(bindings) or isTable(bindings),
  1200. 'expected table for destructuring', ast)
  1201. assertCompile(#bindings % 2 == 0,
  1202. 'expected even number of name/value bindings', ast)
  1203. assertCompile(#ast >= 3, 'missing body expression', ast)
  1204. local subScope = makeScope(scope)
  1205. local subChunk = {}
  1206. for i = 1, #bindings, 2 do
  1207. destructure(bindings[i], bindings[i + 1], ast, subScope, subChunk, {
  1208. declaration = true,
  1209. nomulti = true
  1210. })
  1211. end
  1212. return doImpl(ast, scope, parent, opts, 3, subChunk, subScope)
  1213. end
  1214. -- For setting items in a table
  1215. SPECIALS['tset'] = function(ast, scope, parent)
  1216. assertCompile(#ast > 3, ('tset form needs table, key, and value'), ast)
  1217. local root = compile1(ast[2], scope, parent, {nval = 1})[1]
  1218. local keys = {}
  1219. for i = 3, #ast - 1 do
  1220. local key = compile1(ast[i], scope, parent, {nval = 1})[1]
  1221. keys[#keys + 1] = tostring(key)
  1222. end
  1223. local value = compile1(ast[#ast], scope, parent, {nval = 1})[1]
  1224. local rootstr = tostring(root)
  1225. local fmtstr = (rootstr:match('^{')) and '(%s)[%s] = %s' or '%s[%s] = %s'
  1226. emit(parent, fmtstr:format(tostring(root),
  1227. table.concat(keys, ']['),
  1228. tostring(value)), ast)
  1229. end
-- The if special form behaves like the cond form in
-- many languages
-- (if c1 body1 c2 body2 ... else?) : conditions and bodies alternate; when
-- #ast is even and greater than 3, the final form is the else branch.
-- How the value is delivered depends on opts:
--   'iife'   - no tail/target/nval: the whole if is wrapped in a local
--              function and a 'statement' expr calling it is returned.
--   'target' - a positive nval and no target: `local s1, ...` is declared
--              and the branches assign into it; sym exprs are returned.
--   'none'   - otherwise: the code is spliced into `parent` using the
--              caller's tail/target; {returned = true} is returned.
SPECIALS['if'] = function(ast, scope, parent, opts)
  local doScope = makeScope(scope)
  local branches = {}
  local elseBranch = nil
  -- Calculate some external stuff. Optimizes for tail calls and what not
  local wrapper, innerTail, innerTarget, targetExprs
  if opts.tail or opts.target or opts.nval then
    if opts.nval and opts.nval ~= 0 and not opts.target then
      -- We need to create a target
      targetExprs = {}
      local accum = {}
      for i = 1, opts.nval do
        local s = gensym(scope)
        accum[i] = s
        targetExprs[i] = expr(s, 'sym')
      end
      wrapper = 'target'
      innerTail = opts.tail
      innerTarget = table.concat(accum, ', ')
    else
      wrapper = 'none'
      innerTail = opts.tail
      innerTarget = opts.target
    end
  else
    wrapper = 'iife'
    innerTail = true
    innerTarget = nil
  end
  -- Compile bodies and conditions
  local bodyOpts = {
    tail = innerTail,
    target = innerTarget,
    nval = opts.nval
  }
  -- Compile body form ast[i] into its own chunk under a child of doScope.
  local function compileBody(i)
    local chunk = {}
    local cscope = makeScope(doScope)
    keepSideEffects(compile1(ast[i], cscope, chunk, bodyOpts),
      chunk, nil, ast[i])
    return {
      chunk = chunk,
      scope = cscope
    }
  end
  for i = 2, #ast - 1, 2 do
    -- Each condition is compiled into its own chunk so any setup statements
    -- it needs can be placed right before its test.
    local condchunk = {}
    local res = compile1(ast[i], doScope, condchunk, {nval = 1})
    local cond = res[1]
    --print(ast[i], res, cond)
    local branch = compileBody(i + 1)
    branch.cond = cond
    branch.condchunk = condchunk
    -- "nested" means: chain this branch as `elseif` (only possible when it
    -- is not the first branch and its condition needed no setup statements).
    branch.nested = i ~= 2 and next(condchunk, nil) == nil
    table.insert(branches, branch)
  end
  local hasElse = #ast > 3 and #ast % 2 == 0
  if hasElse then elseBranch = compileBody(#ast) end
  -- Emit code
  -- (s names the wrapper function; it is only used in the 'iife' case)
  local s = gensym(scope)
  local buffer = {}
  local lastBuffer = buffer
  for i = 1, #branches do
    local branch = branches[i]
    local fstr = not branch.nested and 'if %s then' or 'elseif %s then'
    local condLine = fstr:format(tostring(branch.cond))
    if branch.nested then
      emit(lastBuffer, branch.condchunk, ast)
    else
      for _, v in ipairs(branch.condchunk) do emit(lastBuffer, v, ast) end
    end
    emit(lastBuffer, condLine, ast)
    emit(lastBuffer, branch.chunk, ast)
    if i == #branches then
      if hasElse then
        emit(lastBuffer, 'else', ast)
        emit(lastBuffer, elseBranch.chunk, ast)
      end
      emit(lastBuffer, 'end', ast)
    elseif not branches[i + 1].nested then
      -- The next condition needs setup statements, so it cannot be an
      -- `elseif`: open `else`, continue in a fresh nested buffer, and close
      -- this `if` after that buffer.
      emit(lastBuffer, 'else', ast)
      local nextBuffer = {}
      emit(lastBuffer, nextBuffer, ast)
      emit(lastBuffer, 'end', ast)
      lastBuffer = nextBuffer
    end
  end
  if wrapper == 'iife' then
    local iifeargs = scope.vararg and '...' or ''
    emit(parent, ('local function %s(%s)'):format(tostring(s), iifeargs), ast)
    emit(parent, buffer, ast)
    emit(parent, 'end', ast)
    return expr(('%s(%s)'):format(tostring(s), iifeargs), 'statement')
  elseif wrapper == 'none' then
    -- Splice result right into code
    for i = 1, #buffer do
      emit(parent, buffer[i], ast)
    end
    return {returned = true}
  else -- wrapper == 'target'
    emit(parent, ('local %s'):format(innerTarget), ast)
    for i = 1, #buffer do
      emit(parent, buffer[i], ast)
    end
    return targetExprs
  end
end
  1339. -- (each [k v (pairs t)] body...) => []
  1340. SPECIALS['each'] = function(ast, scope, parent)
  1341. local binding = assertCompile(isTable(ast[2]), 'expected binding table', ast)
  1342. local iter = table.remove(binding, #binding) -- last item is iterator call
  1343. local bindVars = {}
  1344. local destructures = {}
  1345. for _, v in ipairs(binding) do
  1346. assertCompile(isSym(v) or isTable(v),
  1347. 'expected iterator symbol or table', ast)
  1348. if(isSym(v)) then
  1349. table.insert(bindVars, declareLocal(v, {}, scope, ast))
  1350. else
  1351. local raw = sym(gensym(scope))
  1352. destructures[raw] = v
  1353. table.insert(bindVars, declareLocal(raw, {}, scope, ast))
  1354. end
  1355. end
  1356. emit(parent, ('for %s in %s do'):format(
  1357. table.concat(bindVars, ', '),
  1358. tostring(compile1(iter, scope, parent, {nval = 1})[1])), ast)
  1359. local chunk = {}
  1360. for raw, args in pairs(destructures) do
  1361. destructure(args, raw, ast, scope, chunk,
  1362. { declaration = true, nomulti = true })
  1363. end
  1364. compileDo(ast, scope, chunk, 3)
  1365. emit(parent, chunk, ast)
  1366. emit(parent, 'end', ast)
  1367. end
  1368. -- (while condition body...) => []
  1369. SPECIALS['while'] = function(ast, scope, parent)
  1370. local len1 = #parent
  1371. local condition = compile1(ast[2], scope, parent, {nval = 1})[1]
  1372. local len2 = #parent
  1373. local subChunk = {}
  1374. if len1 ~= len2 then
  1375. -- Compound condition
  1376. emit(parent, 'while true do', ast)
  1377. -- Move new compilation to subchunk
  1378. for i = len1 + 1, len2 do
  1379. subChunk[#subChunk + 1] = parent[i]
  1380. parent[i] = nil
  1381. end
  1382. emit(parent, ('if %s then break end'):format(condition[1]), ast)
  1383. else
  1384. -- Simple condition
  1385. emit(parent, 'while ' .. tostring(condition) .. ' do', ast)
  1386. end
  1387. compileDo(ast, makeScope(scope), subChunk, 3)
  1388. emit(parent, subChunk, ast)
  1389. emit(parent, 'end', ast)
  1390. end
  1391. SPECIALS['for'] = function(ast, scope, parent)
  1392. local ranges = assertCompile(isTable(ast[2]), 'expected binding table', ast)
  1393. local bindingSym = assertCompile(isSym(table.remove(ast[2], 1)),
  1394. 'expected iterator symbol', ast)
  1395. local rangeArgs = {}
  1396. for i = 1, math.min(#ranges, 3) do
  1397. rangeArgs[i] = tostring(compile1(ranges[i], scope, parent, {nval = 1})[1])
  1398. end
  1399. emit(parent, ('for %s = %s do'):format(
  1400. declareLocal(bindingSym, {}, scope, ast),
  1401. table.concat(rangeArgs, ', ')), ast)
  1402. local chunk = {}
  1403. compileDo(ast, scope, chunk, 3)
  1404. emit(parent, chunk, ast)
  1405. emit(parent, 'end', ast)
  1406. end
  1407. SPECIALS[':'] = function(ast, scope, parent)
  1408. assertCompile(#ast >= 3, 'expected at least 3 arguments', ast)
  1409. -- Compile object
  1410. local objectexpr = compile1(ast[2], scope, parent, {nval = 1})[1]
  1411. -- Compile method selector
  1412. local methodstring
  1413. local methodident = false
  1414. if type(ast[3]) == 'string' and isValidLuaIdentifier(ast[3]) then
  1415. methodident = true
  1416. methodstring = ast[3]
  1417. else
  1418. methodstring = tostring(compile1(ast[3], scope, parent, {nval = 1})[1])
  1419. objectexpr = once(objectexpr, ast[2], scope, parent)
  1420. end
  1421. -- Compile arguments
  1422. local args = {}
  1423. for i = 4, #ast do
  1424. local subexprs = compile1(ast[i], scope, parent, {
  1425. nval = i ~= #ast and 1 or nil
  1426. })
  1427. for j = 1, #subexprs do
  1428. args[#args + 1] = tostring(subexprs[j])
  1429. end
  1430. end
  1431. local fstring
  1432. if methodident then
  1433. fstring = objectexpr.type == 'literal'
  1434. and '(%s):%s(%s)'
  1435. or '%s:%s(%s)'
  1436. else
  1437. -- Make object first argument
  1438. table.insert(args, 1, tostring(objectexpr))
  1439. fstring = objectexpr.type == 'sym'
  1440. and '%s[%s](%s)'
  1441. or '(%s)[%s](%s)'
  1442. end
  1443. return expr(fstring:format(
  1444. tostring(objectexpr),
  1445. methodstring,
  1446. table.concat(args, ', ')), 'statement')
  1447. end
  1448. local function defineArithmeticSpecial(name, zeroArity, unaryPrefix)
  1449. local paddedOp = ' ' .. name .. ' '
  1450. SPECIALS[name] = function(ast, scope, parent)
  1451. local len = #ast
  1452. if len == 1 then
  1453. assertCompile(zeroArity ~= nil, 'Expected more than 0 arguments', ast)
  1454. return expr(zeroArity, 'literal')
  1455. else
  1456. local operands = {}
  1457. for i = 2, len do
  1458. local subexprs = compile1(ast[i], scope, parent, {
  1459. nval = (i == 1 and 1 or nil)
  1460. })
  1461. for j = 1, #subexprs do
  1462. operands[#operands + 1] = tostring(subexprs[j])
  1463. end
  1464. end
  1465. if #operands == 1 then
  1466. if unaryPrefix then
  1467. return '(' .. unaryPrefix .. paddedOp .. operands[1] .. ')'
  1468. else
  1469. return operands[1]
  1470. end
  1471. else
  1472. return '(' .. table.concat(operands, paddedOp) .. ')'
  1473. end
  1474. end
  1475. end
  1476. end
  1477. defineArithmeticSpecial('+', '0')
  1478. defineArithmeticSpecial('..', "''")
  1479. defineArithmeticSpecial('^')
  1480. defineArithmeticSpecial('-', nil, '')
  1481. defineArithmeticSpecial('*', '1')
  1482. defineArithmeticSpecial('%')
  1483. defineArithmeticSpecial('/', nil, '1')
  1484. defineArithmeticSpecial('//', nil, '1')
  1485. defineArithmeticSpecial('or', 'false')
  1486. defineArithmeticSpecial('and', 'true')
  1487. local function defineComparatorSpecial(name, realop, chainOp)
  1488. local op = realop or name
  1489. SPECIALS[name] = function(ast, scope, parent)
  1490. local len = #ast
  1491. assertCompile(len > 2, 'expected at least two arguments', ast)
  1492. local lhs = compile1(ast[2], scope, parent, {nval = 1})[1]
  1493. local lastval = compile1(ast[3], scope, parent, {nval = 1})[1]
  1494. -- avoid double-eval by introducing locals for possible side-effects
  1495. if len > 3 then lastval = once(lastval, ast[3], scope, parent) end
  1496. local out = ('(%s %s %s)'):
  1497. format(tostring(lhs), op, tostring(lastval))
  1498. if len > 3 then
  1499. for i = 4, len do -- variadic comparison
  1500. local nextval = once(compile1(ast[i], scope, parent, {nval = 1})[1],
  1501. ast[i], scope, parent)
  1502. out = (out .. " %s (%s %s %s)"):
  1503. format(chainOp or 'and', tostring(lastval), op, tostring(nextval))
  1504. lastval = nextval
  1505. end
  1506. out = '(' .. out .. ')'
  1507. end
  1508. return out
  1509. end
  1510. end
  1511. defineComparatorSpecial('>')
  1512. defineComparatorSpecial('<')
  1513. defineComparatorSpecial('>=')
  1514. defineComparatorSpecial('<=')
  1515. defineComparatorSpecial('=', '==')
  1516. defineComparatorSpecial('~=', '~=', 'or')
  1517. defineComparatorSpecial('not=', '~=', 'or')
  1518. local function defineUnarySpecial(op, realop)
  1519. SPECIALS[op] = function(ast, scope, parent)
  1520. assertCompile(#ast == 2, 'expected one argument', ast)
  1521. local tail = compile1(ast[2], scope, parent, {nval = 1})
  1522. return (realop or op) .. tostring(tail[1])
  1523. end
  1524. end
  1525. defineUnarySpecial('not', 'not ')
  1526. defineUnarySpecial('#')
  1527. -- Save current macro scope
  1528. local macroCurrentScope = GLOBAL_SCOPE
  1529. -- Covert a macro function to a special form
  1530. local function macroToSpecial(mac)
  1531. return function(ast, scope, parent, opts)
  1532. local oldScope = macroCurrentScope
  1533. macroCurrentScope = scope
  1534. local ok, transformed = pcall(mac, unpack(ast, 2))
  1535. macroCurrentScope = oldScope
  1536. assertCompile(ok, transformed, ast)
  1537. local result = compile1(transformed, scope, parent, opts)
  1538. return result
  1539. end
  1540. end
  1541. local function compile(ast, options)
  1542. options = options or {}
  1543. local oldGlobals = allowedGlobals
  1544. allowedGlobals = options.allowedGlobals
  1545. if options.indent == nil then options.indent = ' ' end
  1546. local chunk = {}
  1547. local scope = options.scope or makeScope(GLOBAL_SCOPE)
  1548. local exprs = compile1(ast, scope, chunk, {tail = true})
  1549. keepSideEffects(exprs, chunk, nil, ast)
  1550. allowedGlobals = oldGlobals
  1551. return flatten(chunk, options)
  1552. end
  1553. -- map a function across all pairs in a table
  1554. local function quoteTmap(f, t)
  1555. local res = {}
  1556. for k,v in pairs(t) do
  1557. local nk, nv = f(k, v)
  1558. if nk then
  1559. res[nk] = nv
  1560. end
  1561. end
  1562. return res
  1563. end
  1564. -- make a transformer for key / value table pairs, preserving all numeric keys
  1565. local function entryTransform(fk,fv)
  1566. return function(k, v)
  1567. if type(k) == 'number' then
  1568. return k,fv(v)
  1569. else
  1570. return fk(k),fv(v)
  1571. end
  1572. end
  1573. end
  1574. -- consume everything return nothing
  1575. local function no() end
  1576. local function mixedConcat(t, joiner)
  1577. local ret = ""
  1578. local s = ""
  1579. local seen = {}
  1580. for k,v in ipairs(t) do
  1581. table.insert(seen, k)
  1582. ret = ret .. s .. v
  1583. s = joiner
  1584. end
  1585. for k,v in pairs(t) do
  1586. if not(seen[k]) then
  1587. ret = ret .. s .. '[' .. k .. ']' .. '=' .. v
  1588. s = joiner
  1589. end
  1590. end
  1591. return ret
  1592. end
  1593. -- expand a quoted form into a data literal, evaluating unquote
  1594. local function doQuote (form, scope, parent, runtime)
  1595. local q = function (x) return doQuote(x, scope, parent, runtime) end
  1596. -- vararg
  1597. if isVarg(form) then
  1598. assertCompile(not runtime, "quoted ... may only be used at compile time", form)
  1599. return "_VARARG"
  1600. -- symbol
  1601. elseif isSym(form) then
  1602. assertCompile(not runtime, "symbols may only be used at compile time", form)
  1603. return ("sym('%s')"):format(deref(form))
  1604. -- unquote
  1605. elseif isList(form) and isSym(form[1]) and (deref(form[1]) == 'unquote') then
  1606. local payload = form[2]
  1607. local res = unpack(compile1(payload, scope, parent))
  1608. return res[1]
  1609. -- list
  1610. elseif isList(form) then
  1611. assertCompile(not runtime, "lists may only be used at compile time", form)
  1612. local mapped = quoteTmap(entryTransform(no, q), form)
  1613. return 'list(' .. mixedConcat(mapped, ", ") .. ')'
  1614. -- table
  1615. elseif type(form) == 'table' then
  1616. local mapped = quoteTmap(entryTransform(q, q), form)
  1617. return '{' .. mixedConcat(mapped, ", ") .. '}'
  1618. -- string
  1619. elseif type(form) == 'string' then
  1620. return serializeString(form)
  1621. else
  1622. return tostring(form)
  1623. end
  1624. end
  1625. SPECIALS['quote'] = function(ast, scope, parent)
  1626. assertCompile(#ast == 2, "quote only takes a single form")
  1627. local runtime, thisScope = true, scope
  1628. while thisScope do
  1629. thisScope = thisScope.parent
  1630. if thisScope == COMPILER_SCOPE then runtime = false end
  1631. end
  1632. return doQuote(ast[2], scope, parent, runtime)
  1633. end
  1634. local function compileStream(strm, options)
  1635. options = options or {}
  1636. local oldGlobals = allowedGlobals
  1637. allowedGlobals = options.allowedGlobals
  1638. if options.indent == nil then options.indent = ' ' end
  1639. local scope = options.scope or makeScope(GLOBAL_SCOPE)
  1640. local vals = {}
  1641. for ok, val in parser(strm, options.filename) do
  1642. if not ok then break end
  1643. vals[#vals + 1] = val
  1644. end
  1645. local chunk = {}
  1646. for i = 1, #vals do
  1647. local exprs = compile1(vals[i], scope, chunk, {
  1648. tail = i == #vals
  1649. })
  1650. keepSideEffects(exprs, chunk, nil, vals[i])
  1651. end
  1652. allowedGlobals = oldGlobals
  1653. return flatten(chunk, options)
  1654. end
  1655. local function compileString(str, options)
  1656. local strm = stringStream(str)
  1657. return compileStream(strm, options)
  1658. end
  1659. ---
  1660. --- Evaluation
  1661. ---
  1662. -- Convert a fennel environment table to a Lua environment table.
  1663. -- This means automatically unmangling globals when getting a value,
  1664. -- and mangling values when setting a value. This means the original
  1665. -- env will see its values updated as expected, regardless of mangling rules.
  1666. local function wrapEnv(env)
  1667. return setmetatable({}, {
  1668. __index = function(_, key)
  1669. if type(key) == 'string' then
  1670. key = globalUnmangling(key)
  1671. end
  1672. return env[key]
  1673. end,
  1674. __newindex = function(_, key, value)
  1675. if type(key) == 'string' then
  1676. key = globalMangling(key)
  1677. end
  1678. env[key] = value
  1679. end,
  1680. -- checking the __pairs metamethod won't work automatically in Lua 5.1
  1681. -- sadly, but it's important for 5.2+ and can be done manually in 5.1
  1682. __pairs = function()
  1683. local pt = {}
  1684. for key, value in pairs(env) do
  1685. if type(key) == 'string' then
  1686. pt[globalUnmangling(key)] = value
  1687. else
  1688. pt[key] = value
  1689. end
  1690. end
  1691. return next, pt, nil
  1692. end,
  1693. })
  1694. end
  1695. -- A custom traceback function for Fennel that looks similar to
  1696. -- the Lua's debug.traceback.
  1697. -- Use with xpcall to produce fennel specific stacktraces.
  1698. local function traceback(msg, start)
  1699. local level = start or 2 -- Can be used to skip some frames
  1700. local lines = {}
  1701. if msg then
  1702. table.insert(lines, msg)
  1703. end
  1704. table.insert(lines, 'stack traceback:')
  1705. while true do
  1706. local info = debug.getinfo(level, "Sln")
  1707. if not info then break end
  1708. local line
  1709. if info.what == "C" then
  1710. if info.name then
  1711. line = (' [C]: in function \'%s\''):format(info.name)
  1712. else
  1713. line = ' [C]: in ?'
  1714. end
  1715. else
  1716. local remap = fennelSourcemap[info.source]
  1717. if remap and remap[info.currentline] then
  1718. -- And some global info
  1719. info.short_src = remap.short_src
  1720. local mapping = remap[info.currentline]
  1721. -- Overwrite info with values from the mapping (mapping is now
  1722. -- just integer, but may eventually be a table)
  1723. info.currentline = mapping
  1724. end
  1725. if info.what == 'Lua' then
  1726. local n = info.name and ("'" .. info.name .. "'") or '?'
  1727. line = (' %s:%d: in function %s'):format(info.short_src, info.currentline, n)
  1728. elseif info.short_src == '(tail call)' then
  1729. line = ' (tail call)'
  1730. else
  1731. line = (' %s:%d: in main chunk'):format(info.short_src, info.currentline)
  1732. end
  1733. end
  1734. table.insert(lines, line)
  1735. level = level + 1
  1736. end
  1737. return table.concat(lines, '\n')
  1738. end
  1739. local function currentGlobalNames(env)
  1740. local names = {}
  1741. for k in pairs(env or _G) do
  1742. k = globalUnmangling(k)
  1743. table.insert(names, k)
  1744. end
  1745. return names
  1746. end
  1747. local function eval(str, options, ...)
  1748. options = options or {}
  1749. -- eval and dofile are considered "live" entry points, so we can assume
  1750. -- that the globals available at compile time are a reasonable allowed list
  1751. -- UNLESS there's a metatable on env, in which case we can't assume that
  1752. -- pairs will return all the effective globals; for instance openresty
  1753. -- sets up _G in such a way that all the globals are available thru
  1754. -- the __index meta method, but as far as pairs is concerned it's empty.
  1755. if options.allowedGlobals == nil and not getmetatable(options.env) then
  1756. options.allowedGlobals = currentGlobalNames(options.env)
  1757. end
  1758. local env = options.env and wrapEnv(options.env)
  1759. local luaSource = compileString(str, options)
  1760. local loader = loadCode(luaSource, env,
  1761. options.filename and ('@' .. options.filename) or str)
  1762. return loader(...)
  1763. end
  1764. local function dofileFennel(filename, options, ...)
  1765. options = options or {sourcemap = true}
  1766. if options.allowedGlobals == nil then
  1767. options.allowedGlobals = currentGlobalNames(options.env)
  1768. end
  1769. local f = assert(io.open(filename, "rb"))
  1770. local source = f:read("*all"):gsub("^#![^\n]*\n", "")
  1771. f:close()
  1772. options.filename = options.filename or filename
  1773. return eval(source, options, ...)
  1774. end
  1775. -- Implements a configurable repl
  1776. local function repl(options)
  1777. local opts = options or {}
  1778. -- This would get set for us when calling eval, but we want to seed it
  1779. -- with a value that is persistent so it doesn't get reset on each eval.
  1780. if opts.allowedGlobals == nil then
  1781. options.allowedGlobals = currentGlobalNames(opts.env)
  1782. end
  1783. local env = opts.env and wrapEnv(opts.env) or setmetatable({}, {
  1784. __index = _ENV or _G
  1785. })
  1786. local function defaultReadChunk(parserState)
  1787. io.write(parserState.stackSize > 0 and '.. ' or '>> ')
  1788. io.flush()
  1789. local input = io.read()
  1790. return input and input .. '\n'
  1791. end
  1792. local function defaultOnValues(xs)
  1793. io.write(table.concat(xs, '\t'))
  1794. io.write('\n')
  1795. end
  1796. local function defaultOnError(errtype, err, luaSource)
  1797. if (errtype == 'Lua Compile') then
  1798. io.write('Bad code generated - likely a bug with the compiler:\n')
  1799. io.write('--- Generated Lua Start ---\n')
  1800. io.write(luaSource .. '\n')
  1801. io.write('--- Generated Lua End ---\n')
  1802. end
  1803. if (errtype == 'Runtime') then
  1804. io.write(traceback(err, 4))
  1805. io.write('\n')
  1806. else
  1807. io.write(('%s error: %s\n'):format(errtype, tostring(err)))
  1808. end
  1809. end
  1810. -- Read options
  1811. local readChunk = opts.readChunk or defaultReadChunk
  1812. local onValues = opts.onValues or defaultOnValues
  1813. local onError = opts.onError or defaultOnError
  1814. local pp = opts.pp or tostring
  1815. -- Make parser
  1816. local bytestream, clearstream = granulate(readChunk)
  1817. local chars = {}
  1818. local read, reset = parser(function (parserState)
  1819. local c = bytestream(parserState)
  1820. chars[#chars + 1] = c
  1821. return c
  1822. end)
  1823. local envdbg = (opts.env or _G)["debug"]
  1824. -- if the environment doesn't support debug.getlocal you can't save locals
  1825. local saveLocals = opts.saveLocals ~= false and envdbg and envdbg.getlocal
  1826. local saveSource = table.
  1827. concat({"local ___i___ = 1",
  1828. "while true do",
  1829. " local name, value = debug.getlocal(1, ___i___)",
  1830. " if(name and name ~= \"___i___\") then",
  1831. " ___replLocals___[name] = value",
  1832. " ___i___ = ___i___ + 1",
  1833. " else break end end"}, "\n")
  1834. local spliceSaveLocals = function(luaSource)
  1835. -- we do some source munging in order to save off locals from each chunk
  1836. -- and reintroduce them to the beginning of the next chunk, allowing
  1837. -- locals to work in the repl the way you'd expect them to.
  1838. env.___replLocals___ = env.___replLocals___ or {}
  1839. local splicedSource = {}
  1840. for line in luaSource:gmatch("([^\n]+)\n?") do
  1841. table.insert(splicedSource, line)
  1842. end
  1843. -- reintroduce locals from the previous time around
  1844. local bind = "local %s = ___replLocals___['%s']"
  1845. for name in pairs(env.___replLocals___) do
  1846. table.insert(splicedSource, 1, bind:format(name, name))
  1847. end
  1848. -- save off new locals at the end - if safe to do so (i.e. last line is a return)
  1849. if (string.match(splicedSource[#splicedSource], "^ *return .*$")) then
  1850. if (#splicedSource > 1) then
  1851. table.insert(splicedSource, #splicedSource, saveSource)
  1852. end
  1853. end
  1854. return table.concat(splicedSource, "\n")
  1855. end
  1856. local scope = makeScope(GLOBAL_SCOPE)
  1857. local replCompleter = function(text, from, to)
  1858. local matches = {}
  1859. local inputFragment = string.lower(text):sub(from, to):gsub("[%s)(]*(.+)", "%1")
  1860. -- adds any matching keys from the provided generator/iterator to matches
  1861. local function addMatchesFromGen(next, param, state)
  1862. for k in next, param, state do
  1863. if #matches >= 40 then break -- cap completions at 40 to avoid overwhelming output
  1864. elseif inputFragment == k:sub(0, #inputFragment):lower() then table.insert(matches, k) end
  1865. end
  1866. end
  1867. addMatchesFromGen(pairs(env._ENV or env._G or {}))
  1868. addMatchesFromGen(pairs(env.___replLocals___ or {}))
  1869. addMatchesFromGen(pairs(SPECIALS or {}))
  1870. addMatchesFromGen(pairs(scope.specials or {}))
  1871. return matches
  1872. end
  1873. if options.registerCompleter then options.registerCompleter(replCompleter) end
  1874. -- REPL loop
  1875. while true do
  1876. chars = {}
  1877. local ok, parseok, x = pcall(read)
  1878. local srcstring = string.char(unpack(chars))
  1879. if not ok then
  1880. onError('Parse', parseok)
  1881. clearstream()
  1882. reset()
  1883. else
  1884. if not parseok then break end -- eof
  1885. local compileOk, luaSource = pcall(compile, x, {
  1886. sourcemap = opts.sourcemap,
  1887. source = srcstring,
  1888. scope = scope,
  1889. })
  1890. if not compileOk then
  1891. clearstream()
  1892. onError('Compile', luaSource) -- luaSource is error message in this case
  1893. else
  1894. if saveLocals then
  1895. luaSource = spliceSaveLocals(luaSource)
  1896. end
  1897. local luacompileok, loader = pcall(loadCode, luaSource, env)
  1898. if not luacompileok then
  1899. clearstream()
  1900. onError('Lua Compile', loader, luaSource)
  1901. else
  1902. local loadok, ret = xpcall(function () return {loader()} end,
  1903. function (runtimeErr)
  1904. onError('Runtime', runtimeErr)
  1905. end)
  1906. if loadok then
  1907. env._ = ret[1]
  1908. env.__ = ret
  1909. for i = 1, #ret do ret[i] = pp(ret[i]) end
  1910. onValues(ret)
  1911. end
  1912. end
  1913. end
  1914. end
  1915. end
  1916. end
  1917. local macroLoaded = {}
  1918. local module = {
  1919. parser = parser,
  1920. granulate = granulate,
  1921. stringStream = stringStream,
  1922. compile = compile,
  1923. compileString = compileString,
  1924. compileStream = compileStream,
  1925. compile1 = compile1,
  1926. mangle = globalMangling,
  1927. unmangle = globalUnmangling,
  1928. list = list,
  1929. sym = sym,
  1930. varg = varg,
  1931. scope = makeScope,
  1932. gensym = gensym,
  1933. eval = eval,
  1934. repl = repl,
  1935. dofile = dofileFennel,
  1936. macroLoaded = macroLoaded,
  1937. path = "./?.fnl;./?/init.fnl",
  1938. traceback = traceback,
  1939. version = "0.2.1",
  1940. }
  1941. local function searchModule(modulename)
  1942. modulename = modulename:gsub("%.", "/")
  1943. for path in string.gmatch(module.path..";", "([^;]*);") do
  1944. local filename = path:gsub("%?", modulename)
  1945. local file = io.open(filename, "rb")
  1946. if(file) then
  1947. file:close()
  1948. return filename
  1949. end
  1950. end
  1951. end
  1952. module.makeSearcher = function(options)
  1953. return function(modulename)
  1954. local opts = {}
  1955. for k,v in pairs(options or {}) do opts[k] = v end
  1956. local filename = searchModule(modulename)
  1957. if filename then
  1958. return function(modname)
  1959. return dofileFennel(filename, opts, modname)
  1960. end
  1961. end
  1962. end
  1963. end
  1964. -- This will allow regular `require` to work with Fennel:
  1965. -- table.insert(package.loaders, fennel.searcher)
  1966. module.searcher = module.makeSearcher()
  1967. module.make_searcher = module.makeSearcher -- oops backwards compatibility
  1968. local function makeCompilerEnv(ast, scope, parent)
  1969. return setmetatable({
  1970. -- State of compiler if needed
  1971. _SCOPE = scope,
  1972. _CHUNK = parent,
  1973. _AST = ast,
  1974. _IS_COMPILER = true,
  1975. _SPECIALS = SPECIALS,
  1976. _VARARG = VARARG,
  1977. -- Expose the module in the compiler
  1978. fennel = module,
  1979. -- Useful for macros and meta programming. All of Fennel can be accessed
  1980. -- via fennel.myfun, for example (fennel.eval "(print 1)").
  1981. list = list,
  1982. sym = sym,
  1983. unpack = unpack,
  1984. gensym = function() return sym(gensym(macroCurrentScope)) end,
  1985. ["list?"] = isList,
  1986. ["multi-sym?"] = isMultiSym,
  1987. ["sym?"] = isSym,
  1988. ["table?"] = isTable,
  1989. ["sequence?"] = isSequence,
  1990. ["varg?"] = isVarg,
  1991. ["get-scope"] = function() return macroCurrentScope end,
  1992. ["in-scope?"] = function(symbol)
  1993. return macroCurrentScope.manglings[tostring(symbol)]
  1994. end
  1995. }, { __index = _ENV or _G })
  1996. end
  1997. local function macroGlobals(env, globals)
  1998. local allowed = {}
  1999. for k in pairs(env) do
  2000. local g = globalUnmangling(k)
  2001. table.insert(allowed, g)
  2002. end
  2003. if globals then
  2004. for _, k in pairs(globals) do
  2005. table.insert(allowed, k)
  2006. end
  2007. end
  2008. return allowed
  2009. end
  2010. local function addMacros(macros, ast, scope)
  2011. assertCompile(isTable(macros), 'expected macros to be table', ast)
  2012. for k, v in pairs(macros) do
  2013. scope.specials[k] = macroToSpecial(v)
  2014. end
  2015. end
  2016. local function loadMacros(modname, ast, scope, parent)
  2017. local filename = assertCompile(searchModule(modname),
  2018. modname .. " not found.", ast)
  2019. local env = makeCompilerEnv(ast, scope, parent)
  2020. local globals = macroGlobals(env, currentGlobalNames())
  2021. return dofileFennel(filename, { env = env, allowedGlobals = globals,
  2022. scope = COMPILER_SCOPE })
  2023. end
  2024. SPECIALS['require-macros'] = function(ast, scope, parent)
  2025. assertCompile(#ast == 2, "Expected one module name argument", ast)
  2026. local modname = ast[2]
  2027. if not macroLoaded[modname] then
  2028. macroLoaded[modname] = loadMacros(modname, ast, scope, parent)
  2029. end
  2030. addMacros(macroLoaded[modname], ast, scope, parent)
  2031. end
  2032. local function evalCompiler(ast, scope, parent)
  2033. local luaSource = compile(ast, { scope = makeScope(COMPILER_SCOPE) })
  2034. local loader = loadCode(luaSource, wrapEnv(makeCompilerEnv(ast, scope, parent)))
  2035. return loader()
  2036. end
  2037. SPECIALS['macros'] = function(ast, scope, parent)
  2038. assertCompile(#ast == 2, "Expected one table argument", ast)
  2039. local macros = evalCompiler(ast[2], scope, parent)
  2040. addMacros(macros, ast, scope, parent)
  2041. end
  2042. SPECIALS['eval-compiler'] = function(ast, scope, parent)
  2043. local oldFirst = ast[1]
  2044. ast[1] = sym('do')
  2045. local val = evalCompiler(ast, scope, parent)
  2046. ast[1] = oldFirst
  2047. return val
  2048. end
-- Load standard macros.
-- stdmacros is Fennel source for a table mapping macro names to macro
-- functions: the threading macros ->, ->> and their nil-short-circuiting
-- variants -?> / -?>>, doto, when, partial, lambda (which inserts a
-- non-nil assertion for every argument whose name does not start with
-- "?" and is not "...") and match (pattern matching that can unify with
-- locals already in scope). It is evaluated in a compiler environment by
-- the loader that follows, which installs each entry into SPECIALS.
local stdmacros = [===[
{"->" (fn [val ...]
(var x val)
(each [_ e (ipairs [...])]
(let [elt (if (list? e) e (list e))]
(table.insert elt 2 x)
(set x elt)))
x)
"->>" (fn [val ...]
(var x val)
(each [_ e (pairs [...])]
(let [elt (if (list? e) e (list e))]
(table.insert elt x)
(set x elt)))
x)
"-?>" (fn [val ...]
(if (= 0 (# [...]))
val
(let [els [...]
e (table.remove els 1)
el (if (list? e) e (list e))
tmp (gensym)]
(table.insert el 2 tmp)
`(let [@tmp @val]
(if @tmp
(-?> @el @(unpack els))
@tmp)))))
"-?>>" (fn [val ...]
(if (= 0 (# [...]))
val
(let [els [...]
e (table.remove els 1)
el (if (list? e) e (list e))
tmp (gensym)]
(table.insert el tmp)
`(let [@tmp @val]
(if @tmp
(-?>> @el @(unpack els))
@tmp)))))
:doto (fn [val ...]
(let [name (gensym)
form `(let [@name @val])]
(each [_ elt (pairs [...])]
(table.insert elt 2 name)
(table.insert form elt))
(table.insert form name)
form))
:when (fn [condition body1 ...]
(assert body1 "expected body")
`(if @condition
(do @body1 @...)))
:partial (fn [f ...]
(let [body (list f ...)]
(table.insert body _VARARG)
`(fn [@_VARARG] @body)))
:lambda (fn [...]
(let [args [...]
has-internal-name? (sym? (. args 1))
arglist (if has-internal-name? (. args 2) (. args 1))
arity-check-position (if has-internal-name? 3 2)]
(assert (> (# args) 1) "missing body expression")
(each [i a (ipairs arglist)]
(if (and (not (: (tostring a) :match "^?"))
(~= (tostring a) "..."))
(table.insert args arity-check-position
`(assert (~= nil @a)
(: "Missing argument %s on %s:%s"
:format @(tostring a)
@(or a.filename "unknown")
@(or a.line "?"))))))
`(fn @(unpack args))))
:match
(fn match [val ...]
;; this function takes the AST of values and a single pattern and returns a
;; condition to determine if it matches as well as a list of bindings to
;; introduce for the duration of the body if it does match.
(fn match-pattern [vals pattern unifications]
;; we have to assume we're matching against multiple values here until we
;; know we're either in a multi-valued clause (in which case we know the #
;; of vals) or we're not, in which case we only care about the first one.
(let [[val] vals]
(if (and (sym? pattern) ; unification with outer locals (or nil)
(or (in-scope? pattern)
(= :nil (tostring pattern))))
(values `(= @val @pattern) [])
;; unify a local we've seen already
(and (sym? pattern)
(. unifications (tostring pattern)))
(values `(= @(. unifications (tostring pattern)) @val) [])
;; bind a fresh local
(sym? pattern)
(do (if (~= (tostring pattern) "_")
(tset unifications (tostring pattern) val))
(values (if (: (tostring pattern) :find "^?")
true `(~= @(sym :nil) @val))
[pattern val]))
;; multi-valued patterns (represented as lists)
(list? pattern)
(let [condition `(and)
bindings []]
(each [i pat (ipairs pattern)]
(let [(subcondition subbindings) (match-pattern [(. vals i)] pat
unifications)]
(table.insert condition subcondition)
(each [_ b (ipairs subbindings)]
(table.insert bindings b))))
(values condition bindings))
;; table patterns)
(= (type pattern) :table)
(let [condition `(and (= (type @val) :table))
bindings []]
(each [k pat (pairs pattern)]
(if (and (sym? pat) (= "&" (tostring pat)))
(do (assert (not (. pattern (+ k 2)))
"expected rest argument in final position")
(table.insert bindings (. pattern (+ k 1)))
(table.insert bindings [`(select @k ((or unpack table.unpack)
@val))]))
(and (= :number (type k))
(= "&" (tostring (. pattern (- k 1)))))
nil ; don't process the pattern right after &; already got it
(let [subval `(. @val @k)
(subcondition subbindings) (match-pattern [subval] pat
unifications)]
(table.insert condition subcondition)
(each [_ b (ipairs subbindings)]
(table.insert bindings b)))))
(values condition bindings))
;; literal value
(values `(= @val @pattern) []))))
(fn match-condition [vals clauses]
(let [out `(if)]
(for [i 1 (# clauses) 2]
(let [pattern (. clauses i)
body (. clauses (+ i 1))
(condition bindings) (match-pattern vals pattern {})]
(table.insert out condition)
(table.insert out `(let @bindings @body))))
out))
;; how many multi-valued clauses are there? return a list of that many gensyms
(fn val-syms [clauses]
(let [syms (list (gensym))]
(for [i 1 (# clauses) 2]
(if (list? (. clauses i))
(each [valnum (ipairs (. clauses i))]
(if (not (. syms valnum))
(tset syms valnum (gensym))))))
syms))
;; wrap it in a way that prevents double-evaluation of the matched value
(let [clauses [...]
vals (val-syms clauses)]
(if (~= 0 (% (# clauses) 2)) ; treat odd final clause as default
(table.insert clauses (# clauses) (sym :_)))
;; protect against multiple evaluation of the value, bind against as
;; many values as we ever match against in the clauses.
(list (sym :let) [vals val]
(match-condition vals clauses))))
}
]===]
  2209. do
  2210. local env = makeCompilerEnv(nil, COMPILER_SCOPE, {})
  2211. for name, fn in pairs(eval(stdmacros, {
  2212. env = env,
  2213. scope = makeScope(COMPILER_SCOPE),
  2214. -- assume the code to load globals doesn't have any mistaken globals,
  2215. -- otherwise this can be problematic when loading fennel in contexts
  2216. -- where _G is an empty table with an __index metamethod. (openresty)
  2217. allowedGlobals = false,
  2218. })) do
  2219. SPECIALS[name] = macroToSpecial(fn)
  2220. end
  2221. end
  2222. SPECIALS['λ'] = SPECIALS['lambda']
  2223. return module