403 lines
16 KiB
Nim
403 lines
16 KiB
Nim
# Copyright 2022 Mattia Giambirtone & All Contributors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import meta/ast
|
|
import meta/token
|
|
|
|
import parseutils
|
|
import strformat
|
|
import strutils
|
|
import math
|
|
|
|
|
|
type
  WarningKind* = enum ## Categories of diagnostics the optimizer can report
    unreachableCode, nameShadowing, isWithALiteral, equalityWithSingleton,
    valueOverflow, implicitConversion, invalidOperation

  Warning* = ref object
    ## A diagnostic produced during optimization, paired with the AST
    ## node that triggered it
    kind*: WarningKind
    node*: ASTNode

  Optimizer* = ref object
    ## Optimizer state: the warnings collected so far and whether
    ## constant folding is enabled
    warnings: seq[Warning]
    foldConstants*: bool
|
|
|
|
|
|
proc initOptimizer*(foldConstants: bool = true): Optimizer =
  ## Builds a new optimizer with an empty warning list. `foldConstants`
  ## controls whether optimizeNode rewrites nodes at all
  Optimizer(foldConstants: foldConstants, warnings: @[])
|
|
|
|
|
|
proc newWarning(self: Optimizer, kind: WarningKind, node: ASTNode) =
  ## Appends a warning of the given kind, attached to `node`, to the
  ## optimizer's warning list
  let warning = Warning(kind: kind, node: node)
  self.warnings.add(warning)
|
|
|
|
|
|
proc `$`*(self: Warning): string =
  ## Debug representation of a warning: its kind and offending node
  result = "Warning(kind=" & $self.kind & ", node=" & $self.node & ")"
|
|
|
|
|
|
# optimizeUnary and optimizeBinary recurse into their operands through
# optimizeNode, which is only defined further down: declare it up front
proc optimizeNode(self: Optimizer, node: ASTNode): ASTNode
|
|
|
|
|
|
proc optimizeConstant(self: Optimizer, node: ASTNode): ASTNode =
  ## Checks and normalizes constant (literal) AST nodes.
  ##
  ## - Decimal integers are validated; an out-of-range literal emits a
  ##   `valueOverflow` warning and the node is returned unchanged.
  ## - Hexadecimal, binary and octal integers are converted to a new
  ##   decimal IntExpr.
  ## - Floats are re-emitted as a new FloatExpr holding the parsed value.
  ## - Any other node (or everything, if `foldConstants` is false) is
  ##   returned as-is.
  ##
  ## The literal parses are deliberately NOT placed inside `assert`: with
  ## assertions disabled (e.g. `-d:danger`) the whole assert expression,
  ## parse included, would be skipped.
  if not self.foldConstants:
    return node
  case node.kind
  of intExpr:
    let lexeme = IntExpr(node).literal.lexeme
    var value: int
    try:
      let consumed = parseInt(lexeme, value)
      # The lexer should only produce well-formed integer lexemes
      doAssert consumed == lexeme.len()
    except ValueError:
      # parseInt raises ValueError when the literal does not fit in an int
      self.newWarning(valueOverflow, node)
    result = node
  of hexExpr, binExpr, octExpr:
    let literal =
      case node.kind
      of hexExpr: HexExpr(node).literal
      of binExpr: BinExpr(node).literal
      else: OctExpr(node).literal
    var value: int
    var consumed: int
    try:
      consumed =
        case node.kind
        of hexExpr: parseHex(literal.lexeme, value)
        of binExpr: parseBin(literal.lexeme, value)
        else: parseOct(literal.lexeme, value)
    except ValueError:
      # NOTE(review): parseutils documents parseHex/parseBin/parseOct as
      # not checking for overflow, so this branch may never fire and
      # oversized literals may wrap around silently — confirm
      self.newWarning(valueOverflow, node)
      return node
    doAssert consumed == literal.lexeme.len()
    result = IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $value, line: literal.line, pos: (start: -1, stop: -1)))
  of floatExpr:
    let literal = FloatExpr(node).literal
    var value: float
    try:
      discard parseFloat(literal.lexeme, value)
    except ValueError:
      self.newWarning(valueOverflow, node)
      return node
    result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $value, line: literal.line, pos: (start: -1, stop: -1)))
  else:
    result = node
|
|
|
|
|
|
proc optimizeUnary(self: Optimizer, node: UnaryExpr): ASTNode =
  ## Attempts to constant-fold a unary expression.
  ##
  ## The operand is optimized first and stored back into `node`. Then:
  ## - `-` and `~` (bitwise not) on an int literal yield a new int
  ##   literal; negating the smallest int emits `valueOverflow` instead.
  ## - `-` on a float literal yields a new float literal; `~` on a float
  ##   emits `invalidOperation`.
  ## In every other case (unknown operator, non-literal or overflowed
  ## operand) `node` itself is returned, holding its optimized operand.
  let a = self.optimizeNode(node.a)
  # Keep the operand's optimizations even if this node can't be folded
  node.a = a
  if self.warnings.len() > 0 and self.warnings[^1].kind == valueOverflow and
      self.warnings[^1].node == a:
    # The operand overflowed: the compiler will report it later
    return node
  case a.kind
  of intExpr:
    let lexeme = IntExpr(a).literal.lexeme
    var x: int
    # Parse outside of any assert so it also runs with assertions off
    try:
      if parseInt(lexeme, x) != lexeme.len():
        return node
    except ValueError:
      return node
    case node.operator.kind
    of Tilde:
      x = not x
    of Minus:
      if x == low(int):
        # -low(int) is not representable as an int
        self.newWarning(valueOverflow, node)
        return node
      x = -x
    else:
      # Don't fold operators we don't know: that would silently drop them
      return node
    result = IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $x, line: node.operator.line, pos: (start: -1, stop: -1)))
  of floatExpr:
    var x: float
    discard parseFloat(FloatExpr(a).literal.lexeme, x)
    case node.operator.kind
    of Minus:
      x = -x
    of Tilde:
      self.newWarning(invalidOperation, node)
      return node
    else:
      return node
    result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $x, line: node.operator.line, pos: (start: -1, stop: -1)))
  else:
    result = node
|
|
|
|
|
|
proc optimizeBinary(self: Optimizer, node: BinaryExpr): ASTNode =
  ## Attempts to constant-fold a binary expression.
  ##
  ## Both operands are optimized first, then folded when possible:
  ## - int op int for `+ - * // ** % ^ & |` (yielding an int) and `/`
  ##   (yielding a float). `//` floors, e.g. -7 // 2 == -4.
  ## - int/float op float/int for `+ - * / // ** %`. Every int operand is
  ##   converted and reported with an `implicitConversion` warning.
  ## - string + string (concatenation).
  ## - string * int and int * string (repetition by a non-negative count).
  ##
  ## `==` against true/false/nil/nan/inf emits `equalityWithSingleton`;
  ## `is` against a literal emits `isWithALiteral`; integer overflow
  ## emits `valueOverflow` on `node`. Nothing is folded when:
  ## - the operator is unsupported (e.g. comparisons),
  ## - an operation overflows,
  ## - an integer is divided by zero, or
  ## - an operand is not a literal.
  ## In those cases `node` itself is returned with its operands replaced
  ## by their optimized form.

  proc readInt(n: ASTNode, value: var int): bool =
    ## Parses the decimal lexeme of an integer literal node into `value`.
    ## Returns false, instead of raising, if the lexeme overflows or is not
    ## entirely consumed. Never placed inside an assert so that it also
    ## runs when assertions are disabled
    let lexeme = IntExpr(n).literal.lexeme
    try:
      result = parseInt(lexeme, value) == lexeme.len()
    except ValueError:
      result = false

  let a = self.optimizeNode(node.a)
  let b = self.optimizeNode(node.b)

  # Gives up on folding: keeps the expression, but with optimized operands
  template unfolded(): untyped =
    node.a = a
    node.b = b
    node

  if self.warnings.len() > 0 and self.warnings[^1].kind == valueOverflow and
      (self.warnings[^1].node == a or self.warnings[^1].node == b):
    # An operand overflowed: the compiler will report it later
    return unfolded()

  const singletons = {trueExpr, falseExpr, nilExpr, nanExpr, infExpr}
  const literals = {strExpr, intExpr, tupleExpr, dictExpr, listExpr, setExpr}
  if node.operator.kind == DoubleEqual:
    if a.kind in singletons:
      self.newWarning(equalityWithSingleton, a)
    elif b.kind in singletons:
      self.newWarning(equalityWithSingleton, b)
  elif node.operator.kind == Is:
    if a.kind in literals:
      self.newWarning(isWithALiteral, a)
    elif b.kind in literals:
      self.newWarning(isWithALiteral, b)

  if a.kind == intExpr and b.kind == intExpr:
    # Integer operations
    var x, y, z: int
    if not readInt(a, x) or not readInt(b, y):
      return unfolded()
    try:
      case node.operator.kind
      of Plus:
        z = x + y
      of Minus:
        z = x - y
      of Asterisk:
        z = x * y
      of FloorDiv:
        if y == 0:
          # Division by zero is left for the runtime to report
          return unfolded()
        # Exact and rounded toward negative infinity, unlike int(x / y)
        # which truncates and goes through a lossy float
        z = floorDiv(x, y)
      of DoubleAsterisk:
        if y >= 0:
          z = x ^ y
        elif x == 0:
          # Zero to a negative power is a division by zero
          return unfolded()
        else:
          # Nim's `^` only accepts natural exponents, so go through
          # math's pow on floats; the result is truncated to an int
          z = pow(x.float, y.float).int
      of Percentage:
        if y == 0:
          return unfolded()
        z = x mod y
      of Caret:
        z = x xor y
      of Ampersand:
        z = x and y
      of Pipe:
        z = x or y
      of Slash:
        # Special case: true division always yields a float
        return FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $(x / y), line: IntExpr(a).literal.line, pos: (start: -1, stop: -1)))
      else:
        # Not a foldable integer operator (e.g. a comparison)
        return unfolded()
    except OverflowDefect:
      self.newWarning(valueOverflow, node)
      return unfolded()
    except RangeDefect:
      # TODO: What warning do we raise here?
      return unfolded()
    result = IntExpr(kind: intExpr, literal: Token(kind: Integer, lexeme: $z, line: IntExpr(a).literal.line, pos: (start: -1, stop: -1)))
  elif a.kind in {intExpr, floatExpr} and b.kind in {intExpr, floatExpr}:
    # Float operations: at least one operand is a float here, since
    # int op int was handled above
    var x, y, z: float
    if a.kind == intExpr:
      var temp: int
      if not readInt(a, temp):
        return unfolded()
      x = float(temp)
      self.newWarning(implicitConversion, a)
    else:
      discard parseFloat(FloatExpr(a).literal.lexeme, x)
    if b.kind == intExpr:
      var temp: int
      if not readInt(b, temp):
        return unfolded()
      y = float(temp)
      self.newWarning(implicitConversion, b)
    else:
      discard parseFloat(FloatExpr(b).literal.lexeme, y)
    case node.operator.kind
    of Plus:
      z = x + y
    of Minus:
      z = x - y
    of Asterisk:
      z = x * y
    of Slash:
      z = x / y
    of FloorDiv:
      z = floor(x / y)
    of DoubleAsterisk:
      z = pow(x, y)
    of Percentage:
      z = x mod y
    else:
      return unfolded()
    result = FloatExpr(kind: floatExpr, literal: Token(kind: Float, lexeme: $z, line: LiteralExpr(a).literal.line, pos: (start: -1, stop: -1)))
  elif a.kind == strExpr and b.kind == strExpr:
    if node.operator.kind != Plus:
      return unfolded()
    let sa = StrExpr(a)
    let sb = StrExpr(b)
    # Lexemes keep their surrounding quotes: strip them, concatenate the
    # contents and re-quote the result
    result = StrExpr(kind: strExpr, literal: Token(kind: String, lexeme: "'" & sa.literal.lexeme[1..<(^1)] & sb.literal.lexeme[1..<(^1)] & "'", line: sa.literal.line, pos: (start: -1, stop: -1)))
  elif (a.kind == strExpr and b.kind == intExpr) or
      (a.kind == intExpr and b.kind == strExpr):
    let (text, times) = if a.kind == strExpr: (StrExpr(a), b) else: (StrExpr(b), a)
    var count: int
    # A negative count is not folded (strutils.repeat takes a Natural)
    if node.operator.kind != Asterisk or not readInt(times, count) or count < 0:
      return unfolded()
    result = StrExpr(kind: strExpr, literal: Token(kind: String, lexeme: "'" & text.literal.lexeme[1..<(^1)].repeat(count) & "'", line: text.literal.line, pos: (start: -1, stop: -1)))
  else:
    # There's no constant folding we can do!
    result = unfolded()
|
|
|
|
|
|
proc detectClosures(self: Optimizer, node: FunDecl) =
  ## Intended to go through a function's code and detect references to
  ## variables declared in enclosing local scopes.
  ##
  ## Currently it only collects the variable, function and class
  ## declarations found directly in the function's body.
  ##
  ## NOTE(review): unfinished. The original ended with a `for` loop over
  ## these declarations that had no body (a syntax error), and nothing in
  ## this module calls this proc. The collected names are not used yet.
  var names: seq[Declaration] = @[]
  for line in BlockStmt(node.body).code:
    case line.kind
    of varDecl:
      names.add(VarDecl(line))
    of funDecl:
      names.add(FunDecl(line))
    of classDecl:
      names.add(ClassDecl(line))
    else:
      discard
  # TODO: resolve references to `names` made from nested functions
  discard names
|
|
|
|
|
|
proc optimizeNode(self: Optimizer, node: ASTNode): ASTNode =
  ## Analyzes an AST node and attempts to perform optimizations on it,
  ## recursing into its children.
  ##
  ## Container nodes are updated in place, with their children replaced
  ## by optimized versions. Literals and unary/binary expressions may be
  ## replaced by new, folded nodes. Grouping expressions are unnested.
  ## If no optimization applies, or `self.foldConstants` is false, the
  ## same node is returned. A nil node (the tree uses nil for absent
  ## optional children, e.g. try clauses) is returned as-is instead of
  ## crashing on `node.kind`.
  if node.isNil or not self.foldConstants:
    return node
  case node.kind
  of exprStmt:
    let stmt = ExprStmt(node)
    result = newExprStmt(self.optimizeNode(stmt.expression), stmt.token)
  of intExpr, hexExpr, octExpr, binExpr, floatExpr, strExpr:
    result = self.optimizeConstant(node)
  of unaryExpr:
    result = self.optimizeUnary(UnaryExpr(node))
  of binaryExpr:
    result = self.optimizeBinary(BinaryExpr(node))
  of groupingExpr:
    # Recursively unnests groups
    result = self.optimizeNode(GroupingExpr(node).expression)
  of callExpr:
    let call = CallExpr(node)
    for i, positional in call.arguments.positionals:
      call.arguments.positionals[i] = self.optimizeNode(positional)
    for i, argument in call.arguments.keyword:
      call.arguments.keyword[i].value = self.optimizeNode(argument.value)
    result = call
  of sliceExpr:
    let slice = SliceExpr(node)
    for i, e in slice.ends:
      slice.ends[i] = self.optimizeNode(e)
    slice.slicee = self.optimizeNode(slice.slicee)
    result = slice
  of tryStmt:
    let stmt = TryStmt(node)
    stmt.body = self.optimizeNode(stmt.body)
    if stmt.finallyClause != nil:
      stmt.finallyClause = self.optimizeNode(stmt.finallyClause)
    if stmt.elseClause != nil:
      stmt.elseClause = self.optimizeNode(stmt.elseClause)
    for i in 0 ..< stmt.handlers.len:
      stmt.handlers[i].body = self.optimizeNode(stmt.handlers[i].body)
    result = stmt
  of funDecl:
    let decl = FunDecl(node)
    for i, default in decl.defaults:
      decl.defaults[i] = self.optimizeNode(default)
    decl.body = self.optimizeNode(decl.body)
    result = decl
  of blockStmt:
    let blk = BlockStmt(node)
    for i, n in blk.code:
      blk.code[i] = self.optimizeNode(n)
    result = blk
  of varDecl:
    let decl = VarDecl(node)
    decl.value = self.optimizeNode(decl.value)
    result = decl
  of assignExpr:
    let asgn = AssignExpr(node)
    asgn.value = self.optimizeNode(asgn.value)
    result = asgn
  of listExpr:
    let l = ListExpr(node)
    for i, e in l.members:
      l.members[i] = self.optimizeNode(e)
    result = node
  of setExpr:
    let s = SetExpr(node)
    for i, e in s.members:
      s.members[i] = self.optimizeNode(e)
    result = node
  of tupleExpr:
    let t = TupleExpr(node)
    for i, e in t.members:
      t.members[i] = self.optimizeNode(e)
    result = node
  of dictExpr:
    let d = DictExpr(node)
    for i, e in d.keys:
      d.keys[i] = self.optimizeNode(e)
    for i, e in d.values:
      d.values[i] = self.optimizeNode(e)
    result = node
  else:
    result = node
|
|
|
|
|
|
proc optimize*(self: Optimizer, tree: seq[ASTNode]): tuple[tree: seq[ASTNode], warnings: seq[Warning]] =
  ## Runs the optimizer over every top-level node of `tree`.
  ##
  ## Returns the optimized tree together with the optimizer's warnings.
  ## Nodes may be rewritten in place or replaced, and are returned
  ## unchanged when nothing can be optimized. Constant folding is skipped
  ## entirely when the optimizer was created with `foldConstants = false`.
  ##
  ## NOTE(review): warnings are appended to the optimizer and not cleared
  ## here, so reusing one optimizer also returns earlier warnings.
  ## Closure detection (detectClosures) is not invoked by this proc —
  ## confirm whether it should be.
  var optimized = newSeqOfCap[ASTNode](tree.len)
  for topLevel in tree:
    optimized.add(self.optimizeNode(topLevel))
  result.tree = optimized
  result.warnings = self.warnings
|