peon/src/frontend/optimizer.nim

403 lines
16 KiB
Nim

# Copyright 2022 Mattia Giambirtone & All Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import meta/ast
import meta/token
import parseutils
import strformat
import strutils
import math
type
  WarningKind* = enum
    ## Categories of diagnostics the optimizer can report
    unreachableCode, ## Not emitted by the code visible in this module
    nameShadowing, ## Not emitted by the code visible in this module
    isWithALiteral, ## An operand of the `is` operator is a literal
    equalityWithSingleton, ## An operand of `==` is true, false, nil, nan or inf
    valueOverflow, ## A literal could not be parsed or folding it overflowed
    implicitConversion, ## An integer operand was converted to float while folding
    invalidOperation ## An operator was applied to an unsupported operand (e.g. `~` on a float)
  Warning* = ref object
    ## A single diagnostic produced while optimizing
    kind*: WarningKind ## What the warning is about
    node*: ASTNode ## The AST node the warning refers to
  Optimizer* = ref object
    ## Holds the optimizer's configuration and the warnings
    ## it has collected so far
    warnings: seq[Warning] ## Private: every warning recorded by this optimizer
    foldConstants*: bool ## When false, nodes are returned without being optimized
proc initOptimizer*(foldConstants: bool = true): Optimizer =
  ## Creates a new optimizer with an empty list of
  ## warnings. Constant folding is enabled unless
  ## foldConstants is set to false
  result = Optimizer(warnings: @[], foldConstants: foldConstants)
proc newWarning(self: Optimizer, kind: WarningKind, node: ASTNode) =
  ## Records a warning of the given kind about the given
  ## node in the optimizer's list of warnings
  let warning = Warning(kind: kind, node: node)
  self.warnings.add(warning)
proc `$`*(self: Warning): string =
  ## Returns a human-readable representation of a warning
  ## showing its kind and the node it refers to
  result = &"Warning(kind={self.kind}, node={self.node})"
# Forward declaration: optimizeNode is implemented further down, but
# optimizeUnary and optimizeBinary (defined before it) call it on their
# operands, so its signature has to be known in advance
proc optimizeNode(self: Optimizer, node: ASTNode): ASTNode
proc optimizeConstant(self: Optimizer, node: ASTNode): ASTNode =
  ## Performs some checks on constant AST nodes. Hexadecimal,
  ## binary and octal integer literals are rewritten as decimal
  ## integer literals and float literals are re-emitted using
  ## Nim's string representation of the parsed value. Decimal
  ## integers are only checked, never rewritten. If parsing a
  ## literal raises ValueError, a valueOverflow warning is
  ## recorded and the original node is returned. Any other kind
  ## of node (and any node when self.foldConstants is false) is
  ## returned untouched
  if not self.foldConstants:
    return node

  proc decimalLiteral(value: int, source: Token): IntExpr =
    ## Builds the decimal integer literal that replaces a
    ## literal written in another base. The new token keeps
    ## the line of the original one and has position (-1, -1)
    IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $value, line: source.line, pos: (start: -1, stop: -1)))

  # In every branch below the parsing call is deliberately kept
  # *outside* of assert: assert's argument is not evaluated when
  # assertions are disabled (e.g. -d:danger), which would leave
  # the value at 0 and skip overflow detection entirely
  case node.kind:
    of intExpr:
      let lexeme = IntExpr(node).literal.lexeme
      var value: int
      try:
        let consumed = parseInt(lexeme, value)
        assert consumed == len(lexeme)
      except ValueError:
        self.newWarning(valueOverflow, node)
      result = node
    of hexExpr:
      # NOTE(review): std/parseutils documents parseHex, parseBin and
      # parseOct as not checking for overflow, so the ValueError
      # handlers of these three branches may never fire - confirm
      let literal = HexExpr(node).literal
      var value: int
      try:
        let consumed = parseHex(literal.lexeme, value)
        assert consumed == len(literal.lexeme)
      except ValueError:
        self.newWarning(valueOverflow, node)
        return node
      result = decimalLiteral(value, literal)
    of binExpr:
      let literal = BinExpr(node).literal
      var value: int
      try:
        let consumed = parseBin(literal.lexeme, value)
        assert consumed == len(literal.lexeme)
      except ValueError:
        self.newWarning(valueOverflow, node)
        return node
      result = decimalLiteral(value, literal)
    of octExpr:
      let literal = OctExpr(node).literal
      var value: int
      try:
        let consumed = parseOct(literal.lexeme, value)
        assert consumed == len(literal.lexeme)
      except ValueError:
        self.newWarning(valueOverflow, node)
        return node
      result = decimalLiteral(value, literal)
    of floatExpr:
      let literal = FloatExpr(node).literal
      var value: float
      try:
        discard parseFloat(literal.lexeme, value)
      except ValueError:
        self.newWarning(valueOverflow, node)
        return node
      result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $value, line: literal.line, pos: (start: -1, stop: -1)))
    else:
      result = node
proc optimizeUnary(self: Optimizer, node: UnaryExpr): ASTNode =
  ## Attempts to optimize unary expressions. The operand is
  ## optimized first and stored back into the node, so even
  ## when the operation itself can't be folded the returned
  ## node carries the optimized operand. Folding happens for
  ## `-` and `~` on integer literals and for `-` on float
  ## literals: the result is a new literal node. `~` on a
  ## float records an invalidOperation warning. Any other
  ## operator is left alone rather than being dropped
  node.a = self.optimizeNode(node.a)
  let operand = node.a
  if self.warnings.len() > 0 and self.warnings[^1].kind == valueOverflow and self.warnings[^1].node == operand:
    # We can't optimize further, the overflow will be caught in the compiler
    return node
  case operand.kind:
    of intExpr:
      let lexeme = IntExpr(operand).literal.lexeme
      var value: int
      # The call is kept outside of assert so that it still runs
      # when assertions are disabled (e.g. with -d:danger)
      let consumed = parseInt(lexeme, value)
      assert consumed == len(lexeme)
      case node.operator.kind:
        of Tilde:
          value = not value
        of Minus:
          try:
            value = -value
          except OverflowDefect:
            # Negating low(int) doesn't fit in an int
            self.newWarning(valueOverflow, node)
            return node
        else:
          # We don't know how to fold this operator: folding the
          # expression into its bare operand would discard it
          return node
      result = IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $value, line: node.operator.line, pos: (start: -1, stop: -1)))
    of floatExpr:
      var value: float
      discard parseFloat(FloatExpr(operand).literal.lexeme, value)
      case node.operator.kind:
        of Minus:
          value = -value
        of Tilde:
          self.newWarning(invalidOperation, node)
          return node
        else:
          # Same as above: never silently drop an operator
          return node
      result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $value, line: node.operator.line, pos: (start: -1, stop: -1)))
    else:
      result = node
proc optimizeBinary(self: Optimizer, node: BinaryExpr): ASTNode =
  ## Attempts to optimize binary expressions. Both operands are
  ## optimized first and stored back into the node, so the node
  ## returned when no folding is possible still carries the
  ## optimized operands. Folding is attempted for:
  ## - two integer literals (arithmetic and bitwise operators)
  ## - two numeric literals of which at least one is a float
  ## - the concatenation (`+`) of two string literals
  ## - the repetition (`*`) of a string literal by an integer literal
  ## Expressions using any other operator, divisions and modulos
  ## by zero, string repetitions with a negative count and
  ## operations whose result overflows are never folded.
  ## Warnings are recorded for `==` against singletons, `is`
  ## against literals, implicit int to float conversions and
  ## overflows

  proc intLiteralValue(expr: ASTNode, value: var int): bool =
    ## Stores the value of an integer literal into value and
    ## returns true. Returns false if expr is not an integer
    ## literal or if its lexeme doesn't fit into an int
    if expr.kind != intExpr:
      return false
    let lexeme = IntExpr(expr).literal.lexeme
    try:
      result = parseInt(lexeme, value) == lexeme.len()
    except ValueError:
      result = false

  proc numericLiteralValue(expr: ASTNode, value: var float): bool =
    ## Like intLiteralValue, but accepts both integer and
    ## float literals and yields the value as a float
    case expr.kind:
      of intExpr:
        var whole: int
        result = intLiteralValue(expr, whole)
        if result:
          value = float(whole)
      of floatExpr:
        result = parseFloat(FloatExpr(expr).literal.lexeme, value) > 0
      else:
        result = false

  proc isQuoted(lexeme: string): bool =
    ## Returns true if lexeme starts and ends with the
    ## same quote character (either ' or ")
    lexeme.len() >= 2 and lexeme[0] in {'\'', '"'} and lexeme[^1] == lexeme[0]

  node.a = self.optimizeNode(node.a)
  node.b = self.optimizeNode(node.b)
  let
    a = node.a
    b = node.b
  if self.warnings.len() > 0 and self.warnings[^1].kind == valueOverflow and (self.warnings[^1].node == a or self.warnings[^1].node == b):
    # We can't optimize further, the overflow will be caught in the compiler
    return node
  if node.operator.kind == DoubleEqual:
    if a.kind in {trueExpr, falseExpr, nilExpr, nanExpr, infExpr}:
      self.newWarning(equalityWithSingleton, a)
    elif b.kind in {trueExpr, falseExpr, nilExpr, nanExpr, infExpr}:
      self.newWarning(equalityWithSingleton, b)
  elif node.operator.kind == Is:
    if a.kind in {strExpr, intExpr, tupleExpr, dictExpr, listExpr, setExpr}:
      self.newWarning(isWithALiteral, a)
    elif b.kind in {strExpr, intExpr, tupleExpr, dictExpr, listExpr, setExpr}:
      self.newWarning(isWithALiteral, b)
  if a.kind == intExpr and b.kind == intExpr:
    # Optimizes integer operations
    var x, y, z: int
    if not (intLiteralValue(a, x) and intLiteralValue(b, y)):
      # One of the literals doesn't fit in an int. Checking this here
      # (instead of only peeking at the last warning as done above)
      # also covers an overflow of a that was followed by other warnings
      return node
    try:
      case node.operator.kind:
        of Plus:
          z = x + y
        of Minus:
          z = x - y
        of Asterisk:
          z = x * y
        of FloorDiv:
          if y == 0:
            # Division by zero is left for later stages to deal with
            return node
          # Truncates towards zero just like int(x / y) does, but
          # without the precision loss of going through a float
          z = x div y
        of DoubleAsterisk:
          if y >= 0:
            z = x ^ y
          else:
            # Nim's builtin pow operator can't handle
            # negative exponents, so we use math's
            # pow and convert from/to floats instead
            # NOTE(review): the result is truncated to an int
            # (e.g. 2 ** -1 folds to 0) - confirm this is intended
            z = pow(x.float, y.float).int
        of Percentage:
          if y == 0:
            return node
          z = x mod y
        of Caret:
          z = x xor y
        of Ampersand:
          z = x and y
        of Pipe:
          z = x or y
        of Slash:
          if y == 0:
            return node
          # Special case, yields a float
          return FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $(x / y), line: IntExpr(a).literal.line, pos: (start: -1, stop: -1)))
        else:
          # Not an operator we can fold (we must return here or the
          # expression would be replaced by the integer literal 0)
          return node
    except OverflowDefect:
      self.newWarning(valueOverflow, node)
      return node
    except RangeDefect:
      # TODO: What warning do we raise here?
      return node
    result = IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $z, line: IntExpr(a).literal.line, pos: (start: -1, stop: -1)))
  elif a.kind in {intExpr, floatExpr} and b.kind in {intExpr, floatExpr}:
    # Optimizes float operations. At least one operand is a float here
    # (two integers are handled above) and both are literals: folding
    # with only one literal operand would mean converting an arbitrary
    # node to FloatExpr, which fails at runtime
    var x, y, z: float
    if not (numericLiteralValue(a, x) and numericLiteralValue(b, y)):
      return node
    if a.kind == intExpr:
      self.newWarning(implicitConversion, a)
    if b.kind == intExpr:
      self.newWarning(implicitConversion, b)
    try:
      case node.operator.kind:
        of Plus:
          z = x + y
        of Minus:
          z = x - y
        of Asterisk:
          z = x * y
        of FloorDiv, Slash:
          # NOTE(review): FloorDiv performs a plain division on
          # floats, the result is not floored - confirm
          z = x / y
        of DoubleAsterisk:
          z = pow(x, y)
        of Percentage:
          z = x mod y
        else:
          # Not an operator we can fold (we must return here or the
          # expression would be replaced by the float literal 0.0)
          return node
    except OverflowDefect:
      self.newWarning(valueOverflow, node)
      return node
    result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $z, line: LiteralExpr(a).literal.line, pos: (start: -1, stop: -1)))
  elif a.kind == strExpr and b.kind == strExpr:
    let
      left = StrExpr(a).literal
      right = StrExpr(b).literal
    # The operands' own quote character is reused so that a quote of
    # the other kind embedded in either string stays valid: this is
    # only possible when both operands are quoted the same way
    if node.operator.kind != Plus or not isQuoted(left.lexeme) or not isQuoted(right.lexeme) or left.lexeme[0] != right.lexeme[0]:
      return node
    let quote = left.lexeme[0]
    result = StrExpr(kind: strExpr, literal: Token(kind: String, lexeme: quote & left.lexeme[1..^2] & right.lexeme[1..^2] & quote, line: left.line, pos: (start: -1, stop: -1)))
  elif (a.kind == strExpr and b.kind == intExpr) or (a.kind == intExpr and b.kind == strExpr):
    if node.operator.kind != Asterisk:
      return node
    let (text, times) = if a.kind == strExpr: (a, b) else: (b, a)
    var count: int
    # A negative count can't be passed to repeat (it takes a Natural)
    if not intLiteralValue(times, count) or count < 0:
      return node
    let literal = StrExpr(text).literal
    if not isQuoted(literal.lexeme):
      return node
    let quote = literal.lexeme[0]
    result = StrExpr(kind: strExpr, literal: Token(kind: String, lexeme: quote & literal.lexeme[1..^2].repeat(count) & quote, line: literal.line, pos: (start: -1, stop: -1)))
  else:
    # There's no constant folding we can do!
    result = node
proc detectClosures(self: Optimizer, node: FunDecl) =
  ## Goes through a function's code and collects the
  ## declarations (variables, functions and classes)
  ## found at the top level of its body. Detecting
  ## references to variables in enclosing local scopes
  ## is not implemented yet: the collected names are
  ## currently not used for anything
  var names: seq[Declaration] = @[]
  for line in BlockStmt(node.body).code:
    case line.kind:
      of varDecl:
        names.add(VarDecl(line))
      of funDecl:
        names.add(FunDecl(line))
      of classDecl:
        names.add(ClassDecl(line))
      else:
        discard
  for name in names:
    # TODO: Check whether the declaration refers to names from
    # enclosing scopes. The explicit discard keeps the (otherwise
    # empty) loop syntactically valid until that is implemented
    discard
proc optimizeNode(self: Optimizer, node: ASTNode): ASTNode =
  ## Analyzes an AST node and attempts to perform
  ## optimizations on it. If no optimizations can be
  ## applied, if self.foldConstants is set to false
  ## or if node is nil, then the same node is returned.
  ## Literals, unary and binary expressions are handed
  ## to their dedicated procedures, grouping expressions
  ## are replaced by their (optimized) content and the
  ## children of calls, slices, try statements, function
  ## and variable declarations, blocks, assignments and
  ## collection literals are optimized in place
  if node.isNil() or not self.foldConstants:
    # This procedure is also called on optional children
    # (such as a try statement's clauses), which may be nil
    return node
  case node.kind:
    of exprStmt:
      result = newExprStmt(self.optimizeNode(ExprStmt(node).expression), ExprStmt(node).token)
    of intExpr, hexExpr, octExpr, binExpr, floatExpr, strExpr:
      result = self.optimizeConstant(node)
    of unaryExpr:
      result = self.optimizeUnary(UnaryExpr(node))
    of binaryExpr:
      result = self.optimizeBinary(BinaryExpr(node))
    of groupingExpr:
      # Recursively unnests groups
      result = self.optimizeNode(GroupingExpr(node).expression)
    of callExpr:
      var call = CallExpr(node)
      for i, positional in call.arguments.positionals:
        call.arguments.positionals[i] = self.optimizeNode(positional)
      for i, (key, value) in call.arguments.keyword:
        call.arguments.keyword[i].value = self.optimizeNode(value)
      result = call
    of sliceExpr:
      var slice = SliceExpr(node)
      for i, e in slice.ends:
        slice.ends[i] = self.optimizeNode(e)
      slice.slicee = self.optimizeNode(slice.slicee)
      result = slice
    of tryStmt:
      var tryNode = TryStmt(node)
      tryNode.body = self.optimizeNode(tryNode.body)
      if tryNode.finallyClause != nil:
        tryNode.finallyClause = self.optimizeNode(tryNode.finallyClause)
      if tryNode.elseClause != nil:
        tryNode.elseClause = self.optimizeNode(tryNode.elseClause)
      for i in 0..tryNode.handlers.high():
        tryNode.handlers[i].body = self.optimizeNode(tryNode.handlers[i].body)
      result = tryNode
    of funDecl:
      var decl = FunDecl(node)
      for i, defaultValue in decl.defaults:
        decl.defaults[i] = self.optimizeNode(defaultValue)
      decl.body = self.optimizeNode(decl.body)
      result = decl
    of blockStmt:
      var blk = BlockStmt(node)
      for i, n in blk.code:
        blk.code[i] = self.optimizeNode(n)
      result = blk
    of varDecl:
      var decl = VarDecl(node)
      decl.value = self.optimizeNode(decl.value)
      result = decl
    of assignExpr:
      var asgn = AssignExpr(node)
      asgn.value = self.optimizeNode(asgn.value)
      result = asgn
    of listExpr:
      var l = ListExpr(node)
      for i, e in l.members:
        l.members[i] = self.optimizeNode(e)
      result = node
    of setExpr:
      var s = SetExpr(node)
      for i, e in s.members:
        s.members[i] = self.optimizeNode(e)
      result = node
    of tupleExpr:
      var t = TupleExpr(node)
      for i, e in t.members:
        t.members[i] = self.optimizeNode(e)
      result = node
    of dictExpr:
      var d = DictExpr(node)
      for i, e in d.keys:
        d.keys[i] = self.optimizeNode(e)
      for i, e in d.values:
        d.values[i] = self.optimizeNode(e)
      result = node
    else:
      result = node
proc optimize*(self: Optimizer, tree: seq[ASTNode]): tuple[tree: seq[ASTNode], warnings: seq[Warning]] =
  ## Runs the optimizer on every node of the given
  ## source tree, in order, and returns the resulting
  ## tree together with the warnings recorded by this
  ## optimizer. The output tree may be identical to the
  ## input tree if no optimization could be performed.
  ## Constant folding can be turned off by setting
  ## foldConstants to false when initializing the
  ## optimizer object
  result.tree = @[]
  for node in tree:
    let optimized = self.optimizeNode(node)
    result.tree.add(optimized)
  result.warnings = self.warnings