Extended diag() method and updated matrix tests
This commit is contained in:
parent
dca0efd1f6
commit
d83205e09d
19
src/main.nim
19
src/main.nim
|
@ -12,16 +12,20 @@ proc mse(a, b: Matrix[float]): float =
|
|||
# Derivative of MSE
|
||||
func dxMSE*(x, y: Matrix[float]): Matrix[float] = 2.0 * (x - y)
|
||||
|
||||
func dx*(x, y: Matrix[float]): Matrix[float] = zeros[float](x.shape)
|
||||
|
||||
# A bunch of vectorized activation functions
|
||||
func sigmoid*(input: Matrix[float]): Matrix[float] =
|
||||
result = input.apply(proc (x: float): float = 1 / (1 + exp(-x)) , axis = -1)
|
||||
|
||||
func sigmoidDerivative*(input: Matrix[float]): Matrix[float] = sigmoid(input) * (1.0 - sigmoid(input))
|
||||
|
||||
|
||||
func softmax*(input: Matrix[float]): Matrix[float] =
|
||||
var input = input - input.max()
|
||||
result = input.apply(math.exp, axis = -1) / input.apply(math.exp, axis = -1).sum()
|
||||
|
||||
func softmaxDerivative*(input: Matrix[float]): Matrix[float] = zeros[float](input.shape)
|
||||
|
||||
|
||||
func step*(input: Matrix[float]): Matrix[float] = input.apply(proc (x: float): float = (if x < 0.0: 0.0 else: x), axis = -1)
|
||||
func silu*(input: Matrix[float]): Matrix[float] = input.apply(proc (x: float): float = 1 / (1 + exp(-x)), axis= -1)
|
||||
func relu*(input: Matrix[float]): Matrix[float] = input.apply(proc (x: float): float = max(0.0, x), axis = -1)
|
||||
|
@ -33,9 +37,10 @@ func htan*(input: Matrix[float]): Matrix[float] =
|
|||
input.apply(f, axis = -1)
|
||||
|
||||
|
||||
var mlp = newNeuralNetwork(@[newDenseLayer(2, 3, newActivation(sigmoid, dx)), newDenseLayer(3, 2, newActivation(sigmoid, dx)),
|
||||
newDenseLayer(2, 3, newActivation(softmax, dx))],
|
||||
lossFunc=newLoss(mse, dxMSE),
|
||||
learnRate=0.05, weightRange=(start: -1.0, stop: 1.0), biasRange=(start: -10.0, stop: 10.0),
|
||||
momentum=0.55)
|
||||
var mlp = newNeuralNetwork(@[newDenseLayer(2, 3, newActivation(sigmoid, sigmoidDerivative)),
|
||||
newDenseLayer(3, 2, newActivation(sigmoid, sigmoidDerivative)),
|
||||
newDenseLayer(2, 3, newActivation(softmax, softmaxDerivative))],
|
||||
lossFunc=newLoss(mse, dxMSE), learnRate=0.05, momentum=0.55,
|
||||
weightRange=(start: -1.0, stop: 1.0), biasRange=(start: -10.0, stop: 10.0))
|
||||
echo mlp.feedforward(newMatrix[float](@[1.0, 2.0]))
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ type
|
|||
Activation* = ref object
|
||||
## An activation function
|
||||
function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}
|
||||
derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}
|
||||
derivative: proc (x: Matrix[float]): Matrix[float] {.noSideEffect.}
|
||||
Layer* = ref object
|
||||
## A generic neural network
|
||||
## layer
|
||||
|
@ -77,7 +77,7 @@ proc newLoss*(function: proc (a, b: Matrix[float]): float, derivative: proc (x,
|
|||
result.derivative = derivative
|
||||
|
||||
|
||||
proc newActivation*(function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}, derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}): Activation =
|
||||
proc newActivation*(function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}, derivative: proc (x: Matrix[float]): Matrix[float] {.noSideEffect.}): Activation =
|
||||
## Creates a new Activation object
|
||||
new(result)
|
||||
result.function = function
|
||||
|
|
|
@ -661,18 +661,29 @@ proc `==`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[bool] =
|
|||
return a[0] == b
|
||||
|
||||
|
||||
proc diag*[T](a: Matrix[T], offset: int = 0): Matrix[T] =
|
||||
## Returns the diagonal of the given
|
||||
## matrix starting at the given offset
|
||||
if offset >= a.shape.cols:
|
||||
return newMatrix[T](@[])
|
||||
var current = offset.ind2sub(a.shape)
|
||||
var res = newSeqOfCap[T](a.shape.getSize())
|
||||
while current.row < a.shape.rows and current.col < a.shape.cols:
|
||||
res.add(a.data[a.getIndex(current.row, current.col)])
|
||||
inc(current.row)
|
||||
inc(current.col)
|
||||
result = newMatrix(res)
|
||||
proc diag*[T](a: Matrix[T], k: int = 0): Matrix[T] =
|
||||
## Returns the kth diagonal of
|
||||
## the given matrix if a is 2-D
|
||||
## or a 2-D matrix with a on its
|
||||
## kth diagonal if it is 1-D
|
||||
if a.shape.rows > 0:
|
||||
if k >= a.shape.cols:
|
||||
return newMatrix[T](@[])
|
||||
var current = k.ind2sub(a.shape)
|
||||
var res = newSeqOfCap[T](a.shape.getSize())
|
||||
while current.row < a.shape.rows and current.col < a.shape.cols:
|
||||
res.add(a.data[a.getIndex(current.row, current.col)])
|
||||
inc(current.row)
|
||||
inc(current.col)
|
||||
result = newMatrix(res)
|
||||
else:
|
||||
let size = len(a) + k
|
||||
result = zeros[T]((size, size))
|
||||
var current = k.ind2sub(a.shape)
|
||||
for e in a[0]:
|
||||
result[current.row, current.col] = e
|
||||
inc(current.row)
|
||||
inc(current.col)
|
||||
|
||||
|
||||
proc fliplr*[T](self: Matrix[T]): Matrix[T] =
|
||||
|
@ -1033,30 +1044,32 @@ when isMainModule:
|
|||
|
||||
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
||||
var k = m.transpose()
|
||||
assert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
assert all(m.transpose() == k), "transpose mismatch"
|
||||
assert k.sum() == m.sum(), "element sum mismatch"
|
||||
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
doAssert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
doAssert all(m.transpose() == k), "transpose mismatch"
|
||||
doAssert k.sum() == m.sum(), "element sum mismatch"
|
||||
doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
var y = newMatrix[int](@[1, 2, 3, 4])
|
||||
assert y.sum() == 10, "element sum mismatch"
|
||||
assert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
assert all(m + m == m * 2), "m + m != m * 2"
|
||||
doAssert y.sum() == 10, "element sum mismatch"
|
||||
doAssert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
doAssert all(m + m == m * 2), "m + m != m * 2"
|
||||
var z = newMatrix[int](@[1, 2, 3])
|
||||
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
doAssert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
assert x.max() == 9, "max mismatch"
|
||||
assert x.argmax() == 10, "argmax mismatch"
|
||||
assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
assert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
assert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
assert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
assert m.diag(3).len() == 0
|
||||
doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
doAssert x.max() == 9, "max mismatch"
|
||||
doAssert x.argmax() == 10, "argmax mismatch"
|
||||
doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
doAssert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
doAssert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
doAssert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
doAssert m.diag(3).len() == 0
|
||||
var j = m.fliplr()
|
||||
assert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
assert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
assert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
doAssert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
doAssert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
var o = newMatrix[int](@[1, 2, 3])
|
||||
doAssert all(o.diag() == newMatrix[int](@[@[1, 0, 0], @[0, 2, 0], @[0, 0, 3]]))
|
Loading…
Reference in New Issue