From 01525da8896563971a571bff4ef219fc21819074 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Mon, 20 Mar 2023 12:02:00 +0100 Subject: [PATCH] Added matrix.diagflat() and fast softmax derivative --- src/main.nim | 5 ++++- src/nn/util/matrix.nim | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/main.nim b/src/main.nim index 7346739..a7cf3d7 100644 --- a/src/main.nim +++ b/src/main.nim @@ -23,7 +23,10 @@ func softmax*(input: Matrix[float]): Matrix[float] = var input = input - input.max() result = input.apply(math.exp, axis = -1) / input.apply(math.exp, axis = -1).sum() -func softmaxDerivative*(input: Matrix[float]): Matrix[float] = zeros[float](input.shape) + +func softmaxDerivative*(input: Matrix[float]): Matrix[float] = + var input = input.reshape(input.shape.cols, 1) + result = input.diagflat() - input.dot(input.transpose()) func step*(input: Matrix[float]): Matrix[float] = input.apply(proc (x: float): float = (if x < 0.0: 0.0 else: x), axis = -1) diff --git a/src/nn/util/matrix.nim b/src/nn/util/matrix.nim index f84fbed..c7c0edc 100644 --- a/src/nn/util/matrix.nim +++ b/src/nn/util/matrix.nim @@ -662,8 +662,8 @@ proc `==`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[bool] = proc diag*[T](a: Matrix[T], k: int = 0): Matrix[T] = - ## Returns the kth diagonal of - ## the given matrix if a is 2-D + ## Returns the kth diagonal of + ## the given matrix if a is 2-D ## or a 2-D matrix with a on its ## kth diagonal if it is 1-D if a.shape.rows > 0: @@ -686,6 +686,12 @@ proc diag*[T](a: Matrix[T], k: int = 0): Matrix[T] = inc(current.col) +proc diagflat*[T](a: Matrix[T], k: int = 0): Matrix[T] = + ## Create a 2-D array with the flattened + ## input as a diagonal + result = a.flatten().diag(k) + + proc fliplr*[T](self: Matrix[T]): Matrix[T] = ## Flips each row in the matrix left ## to right. A copy is returned @@ -937,7 +943,9 @@ proc dot*[T](self, other: Matrix[T]): Matrix[T] = proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] = - ## Behaves like numpy.where() + ## Return elements chosen from x or y depending on cond + ## Where cond is true, take elements from x, otherwise + ## take elements from y when not defined(release): if not (x.shape == y.shape and y.shape == cond.shape): raise newException(ValueError, &"all inputs must be of equal shape for where()") @@ -960,7 +968,9 @@ proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] = proc where*[T](cond: Matrix[bool], x: Matrix[T], y: T): Matrix[T] = - ## Behaves like numpy.where, but with a constant + ## Behaves like where but with a constant instead of + ## an array. When cond is true, take elements from x, + ## otherwise take y when not defined(release): if not (x.shape == cond.shape): raise newException(ValueError, &"all inputs must be of equal shape for where()") @@ -1072,4 +1082,6 @@ when isMainModule: doAssert all(j.diag(1) == newMatrix[int](@[2, 4])) doAssert all(j.diag(2) == newMatrix[int](@[1])) var o = newMatrix[int](@[1, 2, 3]) - doAssert all(o.diag() == newMatrix[int](@[@[1, 0, 0], @[0, 2, 0], @[0, 0, 3]])) \ No newline at end of file + doAssert all(o.diag() == newMatrix[int](@[@[1, 0, 0], @[0, 2, 0], @[0, 0, 3]])) + var n = newMatrix[int](@[@[1, 2], @[3, 4]]) + doAssert all(n.diagflat() == newMatrix[int](@[@[1, 0, 0, 0], @[0, 2, 0, 0], @[0, 0, 3, 0], @[0, 0, 0, 4]])) \ No newline at end of file