Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implements a macro for easy union construction from expression #6

Merged
merged 1 commit into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,26 @@ proc search[T, U](x: T, needle: U): auto =

let idx = find(x, needle)
if idx >= 0:
result <- x[idx] # sugar for assigning without converting via `as`
result <- x[idx] # sugar for assignment without conversion

assert [1, 2, 42, 20, 1000].search(10) of None
assert [1, 2, 42, 20, 1000].search(42) as int == 42
# For `==`, no explicit conversion is necessary
assert [1, 2, 42, 20, 1000].search(42) == 42
# Types that are not active at the moment will simply be treated as not equal
assert [1, 2, 42, 20, 1000].search(1) != None()

proc `{}`[T](x: seq[T], idx: Natural): auto =
## An array accessor for seq[T] but doesn't raise if the index is not there
# Using makeUnion, an expression may return more than one type
makeUnion:
if idx in 0 ..< x.len:
x[idx]
else:
None()

assert @[1]{2} of None
assert @[42]{0} == 42
```

See the [documentation][0] for more information on features and limitations of
Expand Down
95 changes: 95 additions & 0 deletions tests/tmaker.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pkg/balls

import union

suite "makeUnion() tests":
test "unionTail is present in makeUnion":
let x = makeUnion:
if true:
unionTail(10)
else:
[1, 2, 3]

check x == 10

test "if expression":
let x = makeUnion:
if true:
10
elif false:
4.2
else:
"string"

check x is union(int | float | string)
check x == 10

test "block expression":
let x = makeUnion:
block:
if true:
10
elif false:
4.2
else:
"string"

check x is union(int | float | string)
check x == 10

test "pragma block expression":
let x = makeUnion:
{.line: instantiationInfo(0).}:
if true:
10
elif false:
4.2
else:
"string"

check x is union(int | float | string)
check x == 10

test "when expression":
let x = makeUnion:
when false:
[1, 2, 3, 4]
else:
if true:
@[1, 2, 3, 4]
else:
true

check x is union(seq[int] | bool)
check x == @[1, 2, 3, 4]

test "case expression":
let x = makeUnion:
case 10
of 1, 3, 6, 8:
RootObj()
of 2, 7, 4, 5:
[4, 2]
elif false:
RootRef()
elif true:
42
else:
@["string"]

check x is union(RootObj | array[2, int] | RootRef | int | seq[string])
check x == 42

test "try expression":
let x = makeUnion:
try:
RootObj()
except ValueError:
[4, 2]
except KeyError:
RootRef()
except:
42

check x is union(RootObj | array[2, int] | RootRef | int)
check x == RootObj()
213 changes: 212 additions & 1 deletion union.nim
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,26 @@ runnableExamples:

let idx = find(x, needle)
if idx >= 0:
result <- x[idx]
result <- x[idx] # sugar for assignment without conversion

assert [1, 2, 42, 20, 1000].search(10) of None
assert [1, 2, 42, 20, 1000].search(42) as int == 42
# For `==`, no explicit conversion is necessary
assert [1, 2, 42, 20, 1000].search(42) == 42
# Types that are not active at the moment will simply be treated as not equal
assert [1, 2, 42, 20, 1000].search(1) != None()

proc `{}`[T](x: seq[T], idx: Natural): auto =
## An array accessor for seq[T] but doesn't raise if the index is not there
# Using makeUnion, an expression may return more than one type
makeUnion:
if idx in 0 ..< x.len:
x[idx]
else:
None()

assert @[1]{2} of None
assert @[42]{0} == 42

import std/[
algorithm, macros, macrocache, sequtils, typetraits, options
Expand Down Expand Up @@ -401,3 +417,198 @@ template `==`*[T; U: Union](x: T, u: U): untyped =
##
## Returns false if `u` current type is not `T`.
u == x

proc exprFilter(n: NimNode, fn: proc(n: NimNode): NimNode): NimNode =
## Produce a new tree from `n` by running `fn` on all things that looks like
## an expression tail.
##
## This is because we are working on untyped AST, thus we have little details
## on whether something is an expression.
proc branchFilter(n: NimNode, fn: proc(n: NimNode): NimNode): NimNode =
## Shared logic for filtering elif/of/else/except/finally branches
case n.kind
of nnkElifBranch, nnkElifExpr:
# Copy the node and condition
result = copyNimNode(n).add(copy n[0]):
# Rewrite body
exprFilter(n.last, fn)
of nnkOfBranch, nnkExceptBranch:
# Copy the node
result = copyNimNode(n)
# Copy matching constraints (all node but last)
for idx in 0 ..< n.len - 1:
result.add copy(n[idx])

# Rewrite body
result.add exprFilter(n.last, fn)
of nnkElse, nnkElseExpr:
# Copy the node and rewrite body
result = copyNimNode(n).add:
exprFilter(n.last, fn)
of nnkFinally:
# Copy the node, it can't have expression
result = copy(n)
else:
raise newException(Defect):
"unknown node kind passed to branchFilter: " & $n.kind

case n.kind
of nnkStmtList, nnkStmtListExpr:
result = copyNimNode(n)

for idx in 0 ..< n.len - 1: # copy everything but the last node
result.add copy(n[idx])

# run the filter on the last node
result.add exprFilter(n.last, fn)
of nnkBlockStmt, nnkBlockExpr, nnkPragmaBlock:
# Copy the node and the label/pragma list
result = copyNimNode(n).add(copy n[0]):
# Run filter on block body
exprFilter(n.last, fn)
of nnkIfStmt, nnkIfExpr, nnkWhenStmt:
# Copy the node
result = copyNimNode(n)

# Rewrite children
for child in n.items:
result.add branchFilter(child, fn)
of nnkCaseStmt:
# Copy the node
result = copyNimNode(n)

# Rewrite children
for idx, child in n.pairs:
if idx == 0:
# This is the matching constraint, copy as is
result.add copy(child)
else:
result.add branchFilter(child, fn)
of nnkTryStmt:
# Copy the node
result = copyNimNode(n)

for idx, child in n.pairs:
if idx == 0:
# Rewrite the try body
result.add exprFilter(child, fn)
else:
# Process branches
result.add branchFilter(child, fn)
else:
# If it's not a known expression block type, it's an expression
result = fn(n)
if result.isNil:
result = copy(n)

proc filter(n: NimNode, fn: proc(n: NimNode): NimNode): NimNode =
## Produce a new tree by running `fn` on all nodes.
##
## If `fn` returns non-nil, filter will not recurse into that node.
## Otherwise, the `n` will be copied and filter will apply `fn`
## on all of `n` children.
result = fn(n)
if result.isNil:
result = copyNimNode(n)
for c in n.items:
result.add filter(c, fn)

template unionExpr(T, expr: typed) {.pragma.}
## Tag the expression `expr` with a type to be collected by
## `collectUnion`.

macro unionTail(n: typed): untyped =
## Analyze `n` and produce `unionExpr` tag for `collectUnion`.
# If `n` has a type
if n.typeKind notin {ntyNone, ntyVoid}:
# Produce a `{.unionExpr(typeof(n), n).}: <nothing>` tag
result = newStmtList:
# Obtain the type from `n`, and copy `n` lineinfo into it
let exprTyp = getTypeInst(n)
exprTyp.copyLineInfo(n)
# We have to use a block or the compiler will complain with:
#
# Error: cannot attach a custom pragma to <module>
nnkPragmaBlock.newTree(
nnkPragma.newTree(newCall(bindSym"unionExpr", exprTyp, n)),
newStmtList()
)
else:
# If n doesn't have a type, do nothing
result = n

proc getUnionExpr(n: NimNode): Option[tuple[typ, expr: NimNode]] =
## Returns the data within `unionExpr` tag, if `n` is one.
if (
n.kind == nnkPragmaBlock and n[0].kind == nnkPragma and
n[0].last[0] == bindSym"unionExpr"
):
result = some((n[0].last[1], n[0].last[2]))

macro collectUnion(expr: typed): untyped =
## Collect annotated data from makeUnion() and friends and
## turn expr into an actual expression.
var types: NimNode = nil
# Collect all types into a typeclass
discard expr.filter do (n: NimNode) -> NimNode:
let unionExpr = getUnionExpr(n)
# If this is an unionExpr annotation
if unionExpr.isSome:
# Obtain the tagged type
let taggedType = copy unionExpr.get.typ
types =
if types.isNil:
taggedType
else:
types.infix(bindSym"|", taggedType)

# Build an union typedesc from the typeclass
let unionType = newTypedesc:
newCall(bindSym"union", types)

# Run another filter pass, this time replace all tags
# with conversions of the body to the union type
result = expr.filter do (n: NimNode) -> NimNode:
let unionExpr = getUnionExpr(n)
if unionExpr.isSome:
infix(copy(unionExpr.get.expr), bindSym"as"):
newCall(bindSym"union", copy(types))
else:
nil

macro makeUnion*(expr: untyped): untyped =
## Produce an union from expression `expr`. The expression may return
## multiple different types, of which will be combined into one union type.
##
## The expression must return more than one type. A compile-time error will
## be raised if the expression returns only one type.
##
## Due to compiler limitations, this macro cannot evaluate macros within
## `expr` and might miss a few expressions. In those cases, the expressions
## need to be analyzed can be tagged by making a call to `unionTail`, which
## is introduced into `expr` scope.
runnableExamples:
let x = makeUnion:
if true:
10
else:
"string"

assert x is union(int | string)

template introduceUnionTail(expr: untyped): untyped =
## A small helper to introduce `unionTail` to expr's scope
bind unionTail
template unionTail(x: untyped) {.used.} = unionTail(x)
expr

result = newStmtList:
# Run collectUnion on the tagged tree to finalize it
newCall(bindSym"collectUnion"):
newStmtList:
newCall(bindSym"introduceUnionTail"):
# Add the tagged tree
expr.exprFilter do (n: NimNode) -> NimNode:
# For each "expression tail", call unionTail to process it
newCall(bindSym"unionTail"):
copy(n)