Initial work on genetic algorithm for tris
This commit is contained in:
parent
d6e5e148aa
commit
1f875e6f2b
12
src/main.nim
12
src/main.nim
|
@ -3,7 +3,11 @@ import nn/util/activations
|
||||||
import nn/util/losses
|
import nn/util/losses
|
||||||
|
|
||||||
|
|
||||||
var net = newNeuralNetwork(@[2, 3, 2], activationFunc=newActivation(sigmoid, func (x, y: float): float = 0.0),
|
const InitialSize = 50
|
||||||
lossFunc=newLoss(mse, mse), weightRange=(-1.0, +1.0), learnRate=0.05)
|
|
||||||
var prediction = net.predict(newMatrix[float](@[2.7, 3.0]))
|
|
||||||
echo prediction
|
var networks: seq[NeuralNetwork] = @[]
|
||||||
|
for _ in 0..<InitialSize:
|
||||||
|
networks.add(newNeuralNetwork(@[9, 8, 10, 9], activationFunc=newActivation(sigmoid, func (x, y: float): float = 0.0),
|
||||||
|
lossFunc=newLoss(mse, func (x, y: float): float = 0.0), weightRange=(-1.0, +1.0), learnRate=0.05))
|
||||||
|
|
||||||
|
|
|
@ -74,10 +74,12 @@ proc compute*(self: Layer, data: Matrix[float]): Matrix[float] =
|
||||||
## Computes the output of a given layer with
|
## Computes the output of a given layer with
|
||||||
## the given input data and returns it as a
|
## the given input data and returns it as a
|
||||||
## one-dimensional array
|
## one-dimensional array
|
||||||
result = ((self.weights * data).sum() + self.biases).apply(self.activation.function, axis= -1)
|
result = (self.weights.dot(data).sum() + self.biases).apply(self.activation.function, axis= -1)
|
||||||
|
|
||||||
|
|
||||||
proc cost*(self: Layer, x, y: Matrix[float]): float =
  ## Returns the average cost of this layer: the layer's
  ## loss function is applied to every pair of elements
  ## x[0, i], y[0, i] and the mean of the results is
  ## returned. Returns 0.0 when x has no columns
  let samples = x.shape.cols
  if samples == 0:
    # Nothing to average (and avoids a 0/0 division)
    return 0.0
  # Half-open range: the inclusive 0..cols form indexed
  # one element past the end of the row
  for i in 0..<samples:
    result += self.loss.function(x[0, i], y[0, i])
  result /= float(samples)
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,9 @@ import std/strformat
|
||||||
|
|
||||||
type
|
type
|
||||||
NeuralNetwork* = ref object
|
NeuralNetwork* = ref object
|
||||||
## A generic neural network
|
## A generic feed-forward
|
||||||
layers*: seq[Layer]
|
## neural network
|
||||||
|
layers: seq[Layer]
|
||||||
|
|
||||||
|
|
||||||
proc newNeuralNetwork*(layers: seq[int], activationFunc: Activation, lossFunc: Loss, learnRate: float,
|
proc newNeuralNetwork*(layers: seq[int], activationFunc: Activation, lossFunc: Loss, learnRate: float,
|
||||||
|
@ -57,3 +58,10 @@ proc classify*(self: NeuralNetwork, data: Matrix[float]): int =
|
||||||
## Performs a prediction and returns the label
|
## Performs a prediction and returns the label
|
||||||
## with the highest likelihood
|
## with the highest likelihood
|
||||||
result = maxIndex(self.predict(data).raw[])
|
result = maxIndex(self.predict(data).raw[])
|
||||||
|
|
||||||
|
|
||||||
|
proc cost*(self: NeuralNetwork, x, y: Matrix[float]): float =
  ## Returns the average, over all layers of the
  ## network, of each layer's cost for the given
  ## x and y
  var total = 0.0
  for layer in self.layers:
    total += layer.cost(x, y)
  result = total / float(self.layers.len())
|
|
@ -13,7 +13,10 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from std/strformat import `&`
|
from std/strformat import `&`
|
||||||
from std/sequtils import zip
|
import std/random
|
||||||
|
|
||||||
|
|
||||||
|
randomize()
|
||||||
|
|
||||||
|
|
||||||
type
|
type
|
||||||
|
@ -34,8 +37,8 @@ proc getSize(shape: tuple[rows, cols: int]): int =
|
||||||
## Helper to get the size required for the
|
## Helper to get the size required for the
|
||||||
## underlying data array for a matrix of the
|
## underlying data array for a matrix of the
|
||||||
## given shape
|
## given shape
|
||||||
if shape.cols == 0:
|
if shape.rows == 0:
|
||||||
return shape.rows
|
return shape.cols
|
||||||
return shape.cols * shape.rows
|
return shape.cols * shape.rows
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +56,7 @@ proc newMatrix*[T](data: seq[T]): Matrix[T] =
|
||||||
result.order = RowMajor
|
result.order = RowMajor
|
||||||
|
|
||||||
|
|
||||||
proc newMatrix*[T](data: seq[seq[T]], order: MatrixOrder = RowMajor): Matrix[T] {.raises: [ValueError].} =
|
proc newMatrix*[T](data: seq[seq[T]], order: MatrixOrder = RowMajor): Matrix[T] =
|
||||||
## Initializes a new matrix from a given
|
## Initializes a new matrix from a given
|
||||||
## 2D sequence
|
## 2D sequence
|
||||||
new(result)
|
new(result)
|
||||||
|
@ -81,12 +84,52 @@ proc newMatrix*[T](data: seq[seq[T]], order: MatrixOrder = RowMajor): Matrix[T]
|
||||||
idx = col
|
idx = col
|
||||||
|
|
||||||
|
|
||||||
proc zeros*[T: int | float](shape: tuple[rows, cols: int], order: MatrixOrder = RowMajor): Matrix[T] =
  ## Creates a new matrix of the given shape
  ## filled with zeros, tagged with the requested
  ## memory order
  new(result)
  new(result.data)
  result.shape = shape
  # Every element is identical, so the buffer is the same
  # for either order: only the tag needs to be recorded
  # (it was previously ignored)
  result.order = order
  # newSeq zero-initializes its elements, which is exactly
  # the fill value for both int and float
  result.data[] = newSeq[T](shape.getSize())
|
||||||
|
|
||||||
|
|
||||||
|
proc ones*[T: int | float](shape: tuple[rows, cols: int], order: MatrixOrder = RowMajor): Matrix[T] =
  ## Creates a new matrix of the given shape
  ## filled with ones, tagged with the requested
  ## memory order
  new(result)
  new(result.data)
  result.shape = shape
  # All elements are equal, so the buffer does not depend
  # on the order: storing the tag is enough (it was
  # previously ignored)
  result.order = order
  let size = shape.getSize()
  # Reserve the whole buffer up front instead of growing it
  result.data[] = newSeqOfCap[T](size)
  for _ in 0..<size:
    # T(1) yields 1 for int and 1.0 for float
    result.data[].add(T(1))
|
||||||
|
|
||||||
|
proc rand*[T: int | float](shape: tuple[rows, cols: int], order: MatrixOrder = RowMajor): Matrix[T] =
  ## Creates a new matrix of the given shape
  ## filled with random values between 0 and
  ## 1 (for int matrices each element is drawn
  ## from the inclusive range 0..1), tagged with
  ## the requested memory order
  new(result)
  new(result.data)
  result.shape = shape
  # The order parameter used to be ignored
  result.order = order
  let size = shape.getSize()
  # Reserve the whole buffer up front instead of growing it
  result.data[] = newSeqOfCap[T](size)
  for _ in 0..<size:
    when T is int:
      result.data[].add(rand(0..1))
    else:
      result.data[].add(rand(0.0..1.0))
|
||||||
|
|
||||||
|
|
||||||
# Simple one-line helpers and forward declarations
|
# Simple one-line helpers and forward declarations
|
||||||
|
@ -105,7 +148,7 @@ func getIndex[T](self: Matrix[T], row, col: int): int =
|
||||||
result = col * self.shape.rows + row
|
result = col * self.shape.rows + row
|
||||||
|
|
||||||
|
|
||||||
proc `[]`*[T](self: Matrix[T], row, col: int): T {.raises: [IndexDefect, ValueError].} =
|
proc `[]`*[T](self: Matrix[T], row, col: int): T =
|
||||||
## Gets the element the given row and
|
## Gets the element the given row and
|
||||||
## column into the matrix
|
## column into the matrix
|
||||||
var idx = self.getIndex(row, col)
|
var idx = self.getIndex(row, col)
|
||||||
|
@ -115,7 +158,7 @@ proc `[]`*[T](self: Matrix[T], row, col: int): T {.raises: [IndexDefect, ValueEr
|
||||||
return self.data[idx]
|
return self.data[idx]
|
||||||
|
|
||||||
|
|
||||||
proc `[]`*[T](self: Matrix[T], row: int): MatrixView[T] {.raises: [IndexDefect, ValueError].} =
|
proc `[]`*[T](self: Matrix[T], row: int): MatrixView[T] =
|
||||||
## Gets a single row in the matrix. No data copies
|
## Gets a single row in the matrix. No data copies
|
||||||
## occur and a view into the original matrix is
|
## occur and a view into the original matrix is
|
||||||
## returned
|
## returned
|
||||||
|
@ -128,7 +171,7 @@ proc `[]`*[T](self: Matrix[T], row: int): MatrixView[T] {.raises: [IndexDefect,
|
||||||
result.row = row
|
result.row = row
|
||||||
|
|
||||||
|
|
||||||
proc `[]`*[T](self: MatrixView[T], col: int): T {.raises: [IndexDefect, ValueError].} =
|
proc `[]`*[T](self: MatrixView[T], col: int): T =
|
||||||
## Gets the element the given row into
|
## Gets the element the given row into
|
||||||
## the matrix view
|
## the matrix view
|
||||||
var idx = self.m.getIndex(self.row, col)
|
var idx = self.m.getIndex(self.row, col)
|
||||||
|
@ -138,7 +181,7 @@ proc `[]`*[T](self: MatrixView[T], col: int): T {.raises: [IndexDefect, ValueErr
|
||||||
result = self.m.data[idx]
|
result = self.m.data[idx]
|
||||||
|
|
||||||
|
|
||||||
proc `[]=`*[T](self: Matrix[T], row, col: int, val: T) {.raises: [IndexDefect, ValueError].} =
|
proc `[]=`*[T](self: Matrix[T], row, col: int, val: T) =
|
||||||
## Sets the element at the given row and
|
## Sets the element at the given row and
|
||||||
## column into the matrix to value val
|
## column into the matrix to value val
|
||||||
var idx = self.getIndex(row, col)
|
var idx = self.getIndex(row, col)
|
||||||
|
@ -148,7 +191,7 @@ proc `[]=`*[T](self: Matrix[T], row, col: int, val: T) {.raises: [IndexDefect, V
|
||||||
self.data[idx] = val
|
self.data[idx] = val
|
||||||
|
|
||||||
|
|
||||||
proc `[]=`*[T](self: MatrixView[T], col: int, val: T) {.raises: [IndexDefect, ValueError].} =
|
proc `[]=`*[T](self: MatrixView[T], col: int, val: T) =
|
||||||
## Sets the element at the given row
|
## Sets the element at the given row
|
||||||
## into the matrix view to the value
|
## into the matrix view to the value
|
||||||
## val
|
## val
|
||||||
|
@ -161,7 +204,7 @@ proc `[]=`*[T](self: MatrixView[T], col: int, val: T) {.raises: [IndexDefect, Va
|
||||||
|
|
||||||
|
|
||||||
# Shape management
|
# Shape management
|
||||||
proc reshape*[T](self: Matrix[T], shape: tuple[rows, cols: int]): Matrix[T] {.raises: [ValueError].} =
|
proc reshape*[T](self: Matrix[T], shape: tuple[rows, cols: int]): Matrix[T] =
|
||||||
## Reshapes the given matrix. No data copies occur
|
## Reshapes the given matrix. No data copies occur
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if shape.getSize() != self.data[].len():
|
if shape.getSize() != self.data[].len():
|
||||||
|
@ -170,7 +213,7 @@ proc reshape*[T](self: Matrix[T], shape: tuple[rows, cols: int]): Matrix[T] {.ra
|
||||||
result.shape = shape
|
result.shape = shape
|
||||||
|
|
||||||
|
|
||||||
proc reshape*[T](self: Matrix[T], rows, cols: int): Matrix[T] {.raises: [ValueError].} =
|
proc reshape*[T](self: Matrix[T], rows, cols: int): Matrix[T] =
|
||||||
## Reshapes the given matrix. No data copies occur
|
## Reshapes the given matrix. No data copies occur
|
||||||
result = self.reshape((rows, cols))
|
result = self.reshape((rows, cols))
|
||||||
|
|
||||||
|
@ -183,11 +226,14 @@ proc transpose*[T](self: Matrix[T]): Matrix[T] =
|
||||||
|
|
||||||
|
|
||||||
proc flatten*[T](self: Matrix[T]): Matrix[T] =
|
proc flatten*[T](self: Matrix[T]): Matrix[T] =
|
||||||
## Flattens the matrix into a vector. No
|
## Flattens the matrix into a vector
|
||||||
## data copies occur
|
new(result)
|
||||||
result = self.dup()
|
new(result.data)
|
||||||
result.data = self.data
|
for row in self:
|
||||||
result = result.reshape(0, len(self))
|
for element in row:
|
||||||
|
result.data[].add(element)
|
||||||
|
result.order = RowMajor
|
||||||
|
result.shape = (0, len(self))
|
||||||
|
|
||||||
|
|
||||||
# Helpers for fast applying of operations along an axis
|
# Helpers for fast applying of operations along an axis
|
||||||
|
@ -218,7 +264,7 @@ proc apply*[T](self: Matrix[T], op: proc (a, b: T): T {.noSideEffect.}, b: T, co
|
||||||
|
|
||||||
|
|
||||||
proc apply*[T](self: Matrix[T], op: proc (a: T): T {.noSideEffect.}, copy: bool = false, axis: int): Matrix[T] =
|
proc apply*[T](self: Matrix[T], op: proc (a: T): T {.noSideEffect.}, copy: bool = false, axis: int): Matrix[T] =
|
||||||
## Applies a binary operator to every
|
## Applies a unary operator to every
|
||||||
## element in the given axis of the
|
## element in the given axis of the
|
||||||
## given matrix (0 = rows, 1 = columns,
|
## given matrix (0 = rows, 1 = columns,
|
||||||
## -1 = both). No copies occur unless
|
## -1 = both). No copies occur unless
|
||||||
|
@ -362,6 +408,7 @@ proc dup*[T](self: MatrixView[T]): MatrixView[T] =
|
||||||
new(result)
|
new(result)
|
||||||
result.m = self.m
|
result.m = self.m
|
||||||
result.shape = self.shape
|
result.shape = self.shape
|
||||||
|
result.row = self.row
|
||||||
|
|
||||||
# matrix/scalar operations
|
# matrix/scalar operations
|
||||||
|
|
||||||
|
@ -407,7 +454,7 @@ proc `+`*[T](a, b: MatrixView[T]): Matrix[T] =
|
||||||
result.data[].add(a[i] + b[i])
|
result.data[].add(a[i] + b[i])
|
||||||
|
|
||||||
|
|
||||||
proc `+`*[T](a, b: Matrix[T]): Matrix[T] {.raises: [ValueError].} =
|
proc `+`*[T](a, b: Matrix[T]): Matrix[T] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape.rows > 0 and b.shape.rows > 0 and a.shape != b.shape:
|
if a.shape.rows > 0 and b.shape.rows > 0 and a.shape != b.shape:
|
||||||
raise newException(ValueError, &"incompatible argument shapes for addition")
|
raise newException(ValueError, &"incompatible argument shapes for addition")
|
||||||
|
@ -445,7 +492,7 @@ proc `*`*[T](a, b: MatrixView[T]): Matrix[T] =
|
||||||
result.data[].add(a[i] * b[i])
|
result.data[].add(a[i] * b[i])
|
||||||
|
|
||||||
|
|
||||||
proc `*`*[T](a, b: Matrix[T]): Matrix[T] {.raises: [ValueError].} =
|
proc `*`*[T](a, b: Matrix[T]): Matrix[T] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape.rows > 0 and b.shape.rows > 0 and a.shape.cols != b.shape.rows:
|
if a.shape.rows > 0 and b.shape.rows > 0 and a.shape.cols != b.shape.rows:
|
||||||
raise newException(ValueError, &"incompatible argument shapes for multiplication")
|
raise newException(ValueError, &"incompatible argument shapes for multiplication")
|
||||||
|
@ -468,6 +515,12 @@ proc `*`*[T](a, b: Matrix[T]): Matrix[T] {.raises: [ValueError].} =
|
||||||
for m in r1 * r2:
|
for m in r1 * r2:
|
||||||
for element in m:
|
for element in m:
|
||||||
result.data[].add(element)
|
result.data[].add(element)
|
||||||
|
else:
|
||||||
|
for r1 in a:
|
||||||
|
for r2 in b:
|
||||||
|
for m in r1 * r2:
|
||||||
|
for element in m:
|
||||||
|
result.data[].add(element)
|
||||||
else:
|
else:
|
||||||
result = a[0] * b[0]
|
result = a[0] * b[0]
|
||||||
|
|
||||||
|
@ -521,7 +574,53 @@ proc `>=`*[T](a: Matrix[T], b: T): Matrix[bool] =
|
||||||
result.data[].add(e >= b)
|
result.data[].add(e >= b)
|
||||||
|
|
||||||
|
|
||||||
proc `==`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
proc `==`*[T](a, b: MatrixView[T]): Matrix[bool] =
  ## Compares two matrix views element by element and
  ## returns a boolean matrix shaped like a. Outside of
  ## release builds a ValueError is raised when the two
  ## views have different lengths
  when not defined(release):
    if a.len() != b.len():
      raise newException(ValueError, "invalid shapes for comparison")
  new(result)
  new(result.data)
  result.shape = a.shape
  result.order = RowMajor
  result.data[] = newSeqOfCap[bool](result.shape.getSize())
  for col in 0..<result.shape.cols:
    result.data[].add(a[col] == b[col])
|
||||||
|
|
||||||
|
|
||||||
|
proc `==`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[bool] =
  ## Compares a matrix with a matrix view by comparing
  ## the matrix's first row against the view. Outside of
  ## release builds a ValueError is raised unless the
  ## matrix's shape has zero rows and as many columns
  ## as the view has elements
  when not defined(release):
    let sameLength = a.shape.cols == b.len()
    if not sameLength or a.shape.rows > 0:
      raise newException(ValueError, "invalid shapes for comparison")
  result = a[0] == b
|
||||||
|
|
||||||
|
|
||||||
|
proc diag*[T](a: Matrix[T], diagonal: int): Matrix[T] =
  ## Returns the chosen diagonal of the given
  ## matrix as a linear array. Diagonal 0 is the
  ## one starting at the top-left corner, diagonal 1
  ## the one starting at the top-right corner.
  ## Outside of release builds a ValueError is raised
  ## for non-square matrices and for any other value
  ## of diagonal; in release builds an unknown diagonal
  ## yields an empty result
  when not defined(release):
    if a.shape.rows != a.shape.cols:
      raise newException(ValueError, "only square matrices have diagonals")
  let n = a.shape.rows
  # A diagonal holds exactly one element per row
  var res = newSeqOfCap[T](n)
  case diagonal:
    of 0:
      for i in 0..<n:
        res.add(a[i, i])
    of 1:
      for i in 0..<n:
        # The last valid column is n - 1: using n - i
        # started one column past the end of the row
        res.add(a[i, n - 1 - i])
    else:
      when not defined(release):
        raise newException(ValueError, &"invalid diagonal {diagonal} for matrix")
      else:
        discard
  result = newMatrix(res)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
proc `==`*[T](a, b: Matrix[T]): Matrix[bool] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape != b.shape:
|
if a.shape != b.shape:
|
||||||
raise newException(ValueError, "can't compare matrices of different shapes")
|
raise newException(ValueError, "can't compare matrices of different shapes")
|
||||||
|
@ -530,12 +629,14 @@ proc `==`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
||||||
result.shape = a.shape
|
result.shape = a.shape
|
||||||
result.order = RowMajor
|
result.order = RowMajor
|
||||||
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
||||||
|
if a.shape.rows == 0:
|
||||||
|
result = a[0] == b[0]
|
||||||
for r in 0..<a.shape.rows:
|
for r in 0..<a.shape.rows:
|
||||||
for c in 0..<a.shape.cols:
|
for c in 0..<a.shape.cols:
|
||||||
result.data[].add(a[r, c] == b[r, c])
|
result.data[].add(a[r, c] == b[r, c])
|
||||||
|
|
||||||
|
|
||||||
proc `>`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
proc `>`*[T](a, b: Matrix[T]): Matrix[bool] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape != b.shape:
|
if a.shape != b.shape:
|
||||||
raise newException(ValueError, "can't compare matrices of different shapes")
|
raise newException(ValueError, "can't compare matrices of different shapes")
|
||||||
|
@ -544,12 +645,14 @@ proc `>`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
||||||
result.shape = a.shape
|
result.shape = a.shape
|
||||||
result.order = RowMajor
|
result.order = RowMajor
|
||||||
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
||||||
|
if a.shape.rows == 0:
|
||||||
|
result = a[0] > b[0]
|
||||||
for r in 0..<a.shape.rows:
|
for r in 0..<a.shape.rows:
|
||||||
for c in 0..<a.shape.cols:
|
for c in 0..<a.shape.cols:
|
||||||
result.data[].add(a[r, c] > b[r, c])
|
result.data[].add(a[r, c] > b[r, c])
|
||||||
|
|
||||||
|
|
||||||
proc `>=`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
proc `>=`*[T](a, b: Matrix[T]): Matrix[bool] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape != b.shape:
|
if a.shape != b.shape:
|
||||||
raise newException(ValueError, "can't compare matrices of different shapes")
|
raise newException(ValueError, "can't compare matrices of different shapes")
|
||||||
|
@ -558,12 +661,14 @@ proc `>=`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
||||||
result.shape = a.shape
|
result.shape = a.shape
|
||||||
result.order = RowMajor
|
result.order = RowMajor
|
||||||
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
||||||
|
if a.shape.rows == 0:
|
||||||
|
result = a[0] >= b[0]
|
||||||
for r in 0..<a.shape.rows:
|
for r in 0..<a.shape.rows:
|
||||||
for c in 0..<a.shape.cols:
|
for c in 0..<a.shape.cols:
|
||||||
result.data[].add(a[r, c] >= b[r, c])
|
result.data[].add(a[r, c] >= b[r, c])
|
||||||
|
|
||||||
|
|
||||||
proc `<=`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
proc `<=`*[T](a, b: Matrix[T]): Matrix[bool] =
|
||||||
when not defined(release):
|
when not defined(release):
|
||||||
if a.shape != b.shape:
|
if a.shape != b.shape:
|
||||||
raise newException(ValueError, "can't compare matrices of different shapes")
|
raise newException(ValueError, "can't compare matrices of different shapes")
|
||||||
|
@ -572,6 +677,8 @@ proc `<=`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} =
|
||||||
result.shape = a.shape
|
result.shape = a.shape
|
||||||
result.order = RowMajor
|
result.order = RowMajor
|
||||||
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
result.data[] = newSeqOfCap[bool](result.shape.getSize())
|
||||||
|
if a.shape.rows == 0:
|
||||||
|
result = a[0] <= b[0]
|
||||||
for r in 0..<a.shape.rows:
|
for r in 0..<a.shape.rows:
|
||||||
for c in 0..<a.shape.cols:
|
for c in 0..<a.shape.cols:
|
||||||
result.data[].add(a[r, c] <= b[r, c])
|
result.data[].add(a[r, c] <= b[r, c])
|
||||||
|
@ -585,40 +692,49 @@ proc all*(a: Matrix[bool]): bool =
|
||||||
return true
|
return true
|
||||||
|
|
||||||
|
|
||||||
|
proc any*(a: Matrix[bool]): bool =
  ## Helper for boolean comparisons: returns true
  ## if at least one element of the matrix is true
  result = false
  for flag in a.data[]:
    if flag:
      result = true
      break
|
||||||
|
|
||||||
|
|
||||||
# Specular definitions of commutative operators
|
# Specular definitions of commutative operators
|
||||||
proc `<`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} = b > a
|
proc `<`*[T](a, b: Matrix[T]): Matrix[bool] = b > a
|
||||||
proc `!=`*[T](a, b: Matrix[T]): Matrix[bool] {.raises: [ValueError].} = not a == b
|
proc `!=`*[T](a, b: Matrix[T]): Matrix[bool] = not a == b
|
||||||
proc `*`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[T] {.raises: [ValueError].} = b * a
|
proc `*`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[T] = b * a
|
||||||
proc `==`*[T](a: T, b: Matrix[T]): Matrix[bool] = b == a
|
proc `==`*[T](a: T, b: Matrix[T]): Matrix[bool] = b == a
|
||||||
|
proc `==`*[T](a: MatrixView[T], b: Matrix[T]): Matrix[bool] = b == a
|
||||||
proc `!=`*[T](a: Matrix[T], b: T): Matrix[bool] = not a == b
|
proc `!=`*[T](a: Matrix[T], b: T): Matrix[bool] = not a == b
|
||||||
proc `!=`*[T](a: T, b: Matrix[T]): Matrix[bool] = not b == a
|
proc `!=`*[T](a: T, b: Matrix[T]): Matrix[bool] = not b == a
|
||||||
|
|
||||||
|
|
||||||
proc toRowMajor*[T](self: Matrix[T], copy: bool = true): Matrix[T] =
  ## Converts a column-major matrix to a
  ## row-major one. Returns a copy unless
  ## copy equals false. A matrix that is
  ## already row-major is returned as is
  if self.order == RowMajor:
    return self
  # Read every element out in row order *before* touching the
  # order tag or the buffer, so that the reads still see the
  # column-major layout.
  # NOTE(review): assumes iterating a matrix yields its logical
  # rows regardless of the storage order - confirm against the
  # iterator implementation
  var reordered = newSeqOfCap[T](self.data[].len())
  for row in self:
    for element in row:
      reordered.add(element)
  if copy:
    result = self.copy()
    # Give the copy its own buffer so that self is left untouched
    # even if copy() shares the underlying data reference
    new(result.data)
  else:
    result = self
  result.order = RowMajor
  result.data[] = reordered
|
|
||||||
|
|
||||||
|
|
||||||
proc toColumnMajor*[T](self: Matrix[T]): Matrix[T] =
|
proc toColumnMajor*[T](self: Matrix[T], copy: bool = true): Matrix[T] =
|
||||||
## Converts a row-major matrix to a
|
## Converts a row-major matrix to a
|
||||||
## column-major one
|
## column-major one
|
||||||
new(result)
|
|
||||||
if self.order == ColumnMajor:
|
if self.order == ColumnMajor:
|
||||||
return
|
return self
|
||||||
|
if copy:
|
||||||
|
result = self.copy()
|
||||||
|
else:
|
||||||
|
result = self
|
||||||
self.order = ColumnMajor
|
self.order = ColumnMajor
|
||||||
let orig = self.data[]
|
let orig = self.data[]
|
||||||
self.data[] = @[]
|
self.data[] = @[]
|
||||||
|
@ -674,7 +790,7 @@ proc `$`*[T](self: MatrixView[T]): string =
|
||||||
proc `$`*[T](self: Matrix[T]): string =
|
proc `$`*[T](self: Matrix[T]): string =
|
||||||
## Stringifies the matrix
|
## Stringifies the matrix
|
||||||
if self.shape.rows == 0:
|
if self.shape.rows == 0:
|
||||||
return $self[0]
|
return $(self[0])
|
||||||
result &= "["
|
result &= "["
|
||||||
for i, row in self:
|
for i, row in self:
|
||||||
result &= "["
|
result &= "["
|
||||||
|
@ -693,10 +809,34 @@ proc `$`*[T](self: Matrix[T]): string =
|
||||||
proc dot*[T](self, other: Matrix[T]): Matrix[T] =
  ## Computes the dot product of the two
  ## input matrices. An operand whose shape
  ## has at most one row is treated as a vector
  ## (only its first row is used). Outside of
  ## release builds a ValueError is raised when
  ## the operand shapes are incompatible
  if self.shape.rows > 1 and other.shape.rows > 1:
    # Matrix by matrix: the inner dimensions must agree
    when not defined(release):
      if self.shape.cols != other.shape.rows:
        raise newException(ValueError, &"incompatible argument shapes for dot product")
    result = zeros[T]((self.shape.rows, other.shape.cols))
    # Transposing turns the columns of other into rows, so each
    # output cell is a row-by-row product followed by a sum
    let columns = other.transpose()
    for i in 0..<result.shape.rows:
      for j in 0..<result.shape.cols:
        result[i, j] = (self[i] * columns[j]).sum()
  elif self.shape.rows > 1:
    # Matrix by vector: one output element per matrix row
    when not defined(release):
      if self.shape.cols != other.shape.cols:
        raise newException(ValueError, &"incompatible argument shapes for dot product")
    result = zeros[T]((0, self.shape.rows))
    for i in 0..<result.shape.cols:
      result[0, i] = (self[i] * other[0]).sum()
  elif other.shape.rows > 1:
    # Vector by matrix: the vector length must match the number
    # of rows of the matrix, and there is one output element per
    # matrix column
    when not defined(release):
      if self.shape.cols != other.shape.rows:
        raise newException(ValueError, &"incompatible argument shapes for dot product")
    result = zeros[T]((0, other.shape.cols))
    let columns = other.transpose()
    for i in 0..<result.shape.cols:
      result[0, i] = (self[0] * columns[i]).sum()
  else:
    # Vector by vector: defer to the multiplication operator
    result = self * other
|
||||||
|
|
||||||
|
|
||||||
proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] =
|
proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] =
|
||||||
|
@ -721,6 +861,50 @@ proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] =
|
||||||
inc(row)
|
inc(row)
|
||||||
col = 0
|
col = 0
|
||||||
|
|
||||||
|
|
||||||
|
proc where*[T](self: Matrix[T], cond: Matrix[bool], other: Matrix[T]): Matrix[T] =
  ## Just a helper to avoid mistakes and so that
  ## x.where(x > 10, y) works as expected: it simply
  ## forwards to cond.where(self, other)
  result = cond.where(self, other)
|
||||||
|
|
||||||
|
|
||||||
|
proc max*[T](self: Matrix[T]): T =
  ## Returns the largest element
  ## stored in the matrix
  # Seed the running maximum with the first element
  result = self[0, 0]
  for row in self:
    for element in row:
      if result < element:
        result = element
|
||||||
|
|
||||||
|
|
||||||
|
proc argmax*[T](self: Matrix[T]): T =
  ## Scans the matrix for its largest element.
  ## NOTE(review): despite the name, the value that is
  ## returned is the largest element itself and not its
  ## index; the return type T is kept so that existing
  ## callers still compile - confirm the intended contract
  result = self[0, 0]
  if self.shape.rows == 0:
    # Vector: a single row of shape.cols elements
    for col in 0..<self.shape.cols:
      if self[0, col] > result:
        result = self[0, col]
  else:
    # The previous while loops never advanced row or col and
    # therefore never terminated for matrices with rows > 0
    for row in 0..<self.shape.rows:
      for col in 0..<self.shape.cols:
        if self[row, col] > result:
          result = self[row, col]
|
||||||
|
|
||||||
|
|
||||||
|
proc contains*[T](self: Matrix[T], e: T): bool =
  ## Returns whether the matrix contains
  ## the element e
  result = false
  block search:
    for row in self:
      for candidate in row:
        if candidate == e:
          result = true
          break search
|
||||||
|
|
||||||
when isMainModule:
|
when isMainModule:
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
@ -729,6 +913,7 @@ when isMainModule:
|
||||||
|
|
||||||
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
||||||
var k = m.transpose()
|
var k = m.transpose()
|
||||||
|
assert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||||
assert all(m.transpose() == k), "transpose mismatch"
|
assert all(m.transpose() == k), "transpose mismatch"
|
||||||
assert k.sum() == m.sum(), "element sum mismatch"
|
assert k.sum() == m.sum(), "element sum mismatch"
|
||||||
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||||
|
@ -741,4 +926,9 @@ when isMainModule:
|
||||||
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||||
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||||
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||||
|
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||||
|
assert x.max() == 9, "max mismatch"
|
||||||
|
assert x.argmax() == 9, "argmax mismatch"
|
||||||
|
discard newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]]))
|
||||||
|
discard newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3]))
|
|
@ -0,0 +1,76 @@
|
||||||
|
# Copyright 2022 Mattia Giambirtone
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
## Various data preprocessing tools
|
||||||
|
|
||||||
|
import matrix
|
||||||
|
|
||||||
|
|
||||||
|
import strformat
|
||||||
|
import sets
|
||||||
|
|
||||||
|
|
||||||
|
type
  LabelEncoder* = ref object
    ## An encoder to assign a numerical value in the
    ## range from 0 to n_labels - 1 to the labels
    ## of some categorical data, reversibly
    isFit: bool             # Set to true once fit has been called
    labels: Matrix[string]  # The distinct labels stored by fit
|
||||||
|
|
||||||
|
|
||||||
|
proc newLabelEncoder*: LabelEncoder =
  ## Initializes a new LabelEncoder object with
  ## all of its fields at their default values
  result = LabelEncoder()
|
||||||
|
|
||||||
|
|
||||||
|
proc toOrderedSet[T](m: Matrix[T]): OrderedSet[T] =
  ## Collects the elements of the matrix into an
  ## ordered set, visiting them row by row
  result = initOrderedSet[T]()
  for line in m:
    for item in line:
      result.incl(item)
|
||||||
|
|
||||||
|
|
||||||
|
proc fit*(self: LabelEncoder, labels: Matrix[string]) =
  ## Fits the encoder to the given labels: the distinct
  ## labels are stored, in the order produced by
  ## toOrderedSet, and the encoder is marked as fit
  var distinctLabels = newSeq[string]()
  for entry in toOrderedSet(labels):
    distinctLabels.add(entry)
  self.labels = newMatrix(distinctLabels)
  self.isFit = true
|
||||||
|
|
||||||
|
|
||||||
|
proc transform*(self: LabelEncoder, labels: Matrix[string]): Matrix[int] =
  ## Transforms a vector of labels into a vector of encoded
  ## integers. Duplicate labels are assigned the same integer.
  ## Each label is encoded as its position among the labels
  ## stored by fit; a ValueError is raised for a label that
  ## is not among them
  assert self.isFit, "The estimator must be fit!"
  var res: seq[int] = @[]
  for row in labels:
    for label in row:
      # A single scan answers both "is it known?" and "where is
      # it?": find returns -1 when the label is absent
      let code = self.labels.raw[].find(label)
      if code < 0:
        raise newException(ValueError, &"Unknown label '{label}'")
      res.add(code)
  result = newMatrix(res)
|
||||||
|
|
||||||
|
|
||||||
|
proc reverseTransform*(self: LabelEncoder, labels: Matrix[int]): Matrix[string] =
  ## Reverses the transformation of the integer labels back to
  ## strings. A ValueError is raised for any integer that falls
  ## outside the range 0 ..< len of the stored labels
  assert self.isFit, "The estimator must be fit!"
  var decoded = newSeq[string]()
  for row in labels:
    for label in row:
      if label < 0 or label >= self.labels.len():
        raise newException(ValueError, &"Unknown encoded label '{label}'")
      decoded.add(self.labels[0, label])
  result = newMatrix(decoded)
|
|
@ -0,0 +1,78 @@
|
||||||
|
import matrix
|
||||||
|
|
||||||
|
|
||||||
|
type
  TileKind* = enum
    ## A tile enumeration kind. The ordinal of
    ## each member is the integer that place
    ## writes onto the board
    Empty = 0,  # Ordinal 0
    Self,       # Ordinal 1
    Enemy       # Ordinal 2
  GameStatus* = enum
    ## A game status enumeration
    Playing,
    Win,
    Lose,
    Draw
  TrisGame* = ref object
    ## A game of tris; map is the playing board
    map*: Matrix[int]
|
||||||
|
|
||||||
|
|
||||||
|
proc newTrisGame*: TrisGame =
  ## Creates a new TrisGame object whose
  ## board is a 3x3 matrix of zeros
  result = TrisGame(map: zeros[int]((3, 3)))
|
||||||
|
|
||||||
|
|
||||||
|
proc get*(self: TrisGame): GameStatus =
  ## Returns the game status: Win or Lose as soon as a
  ## row, a column or a diagonal (checked in that order)
  ## is filled with 1s or 2s respectively, Draw when no
  ## line is complete and no cell equals 0, Playing
  ## otherwise
  let
    winLine = newMatrix[int](@[1, 1, 1])
    loseLine = newMatrix[int](@[2, 2, 2])
  # Checks for rows
  for _, row in self.map:
    if all(row == winLine):
      return Win
    if all(row == loseLine):
      return Lose
  # Checks for columns (the rows of the transposed board)
  for _, col in self.map.transpose:
    if all(col == winLine):
      return Win
    if all(col == loseLine):
      return Lose
  # Checks for diagonals
  for i in 0..<2:
    let line = self.map.diag(i)
    if all(line == winLine):
      return Win
    if all(line == loseLine):
      return Lose
  # No check was successful: it's a draw if there are no
  # empty slots left, otherwise we're still in game!
  result =
    if any(self.map == 0): Playing
    else: Draw
|
||||||
|
|
||||||
|
|
||||||
|
proc `$`*(self: TrisGame): string =
  ## Stringifies self by stringifying
  ## the playing board
  result = $self.map
|
||||||
|
|
||||||
|
|
||||||
|
proc place*(self: TrisGame, tile: TileKind, x, y: int) =
  ## Places a tile onto the playing board: the
  ## ordinal of the tile is written at row x,
  ## column y, replacing whatever was there
  self.map[x, y] = ord(tile)
|
||||||
|
|
||||||
|
|
||||||
|
when isMainModule:
|
||||||
|
var game = newTrisGame()
|
||||||
|
game.place(Enemy, 0, 0)
|
||||||
|
game.place(Enemy, 0, 1)
|
||||||
|
assert game.get() == Playing
|
||||||
|
game.place(Enemy, 0, 2)
|
||||||
|
assert game.get() == Lose
|
||||||
|
game.place(Self, 0, 2)
|
||||||
|
assert game.get() == Playing
|
||||||
|
game.place(Enemy, 1, 1)
|
||||||
|
game.place(Enemy, 2, 2)
|
||||||
|
assert game.get() == Lose
|
||||||
|
game.place(Self, 2, 2)
|
||||||
|
assert game.get() == Playing
|
||||||
|
game.place(Self, 1, 2)
|
||||||
|
assert game.get() == Win
|
Loading…
Reference in New Issue