Added matrix.diagflat() and fast softmax derivative

This commit is contained in:
Mattia Giambirtone 2023-03-20 12:02:00 +01:00
parent d83205e09d
commit 01525da889
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
2 changed files with 21 additions and 6 deletions

View File

@ -23,7 +23,10 @@ func softmax*(input: Matrix[float]): Matrix[float] =
var input = input - input.max()
result = input.apply(math.exp, axis = -1) / input.apply(math.exp, axis = -1).sum()
# Placeholder derivative: ignores the input's values and returns whatever
# zeros[float] builds for the input's shape (presumably an all-zero matrix
# of that shape - confirm against the zeros() implementation).
# NOTE(review): a second softmaxDerivative with the same signature follows
# directly below; in this diff view this line looks like the removed side.
func softmaxDerivative*(input: Matrix[float]): Matrix[float] = zeros[float](input.shape)
func softmaxDerivative*(input: Matrix[float]): Matrix[float] =
  ## Computes diagflat(v) - v.dot(v^T), where v is the input
  ## reshaped into a single column with input.shape.cols rows.
  ## NOTE(review): this equals the softmax Jacobian only when the
  ## input already holds softmax outputs, and the reshape presumes
  ## the input has exactly input.shape.cols elements (a 1-D or
  ## single-row matrix) - confirm both against the callers
  let column = input.reshape(input.shape.cols, 1)
  # Square matrix carrying the column's values on its main diagonal
  let diagonal = column.diagflat()
  # Outer product of the column with itself
  let outer = column.dot(column.transpose())
  result = diagonal - outer
func step*(input: Matrix[float]): Matrix[float] =
  ## Applies an element-wise mapping over the input (axis = -1):
  ## values strictly below zero become 0.0, every other value is
  ## passed through unchanged.
  ## NOTE(review): despite the name, non-negative values are kept
  ## as-is rather than mapped to 1.0 - confirm this is intended
  func zeroIfNegative(x: float): float =
    if x < 0.0:
      0.0
    else:
      x
  result = input.apply(zeroIfNegative, axis = -1)

View File

@ -662,8 +662,8 @@ proc `==`*[T](a: Matrix[T], b: MatrixView[T]): Matrix[bool] =
proc diag*[T](a: Matrix[T], k: int = 0): Matrix[T] =
## Returns the kth diagonal of
## the given matrix if a is 2-D
## Returns the kth diagonal of
## the given matrix if a is 2-D
## or a 2-D matrix with a on its
## kth diagonal if it is 1-D
if a.shape.rows > 0:
@ -686,6 +686,12 @@ proc diag*[T](a: Matrix[T], k: int = 0): Matrix[T] =
inc(current.col)
proc diagflat*[T](a: Matrix[T], k: int = 0): Matrix[T] =
  ## Flattens the input and hands the result to diag(), which
  ## places it on the kth diagonal of a new 2-D matrix
  ## (k defaults to 0)
  let flattened = a.flatten()
  result = diag(flattened, k)
proc fliplr*[T](self: Matrix[T]): Matrix[T] =
## Flips each row in the matrix left
## to right. A copy is returned
@ -937,7 +943,9 @@ proc dot*[T](self, other: Matrix[T]): Matrix[T] =
proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] =
## Behaves like numpy.where()
## Return elements chosen from x or y depending on cond
## Where cond is true, take elements from x, otherwise
## take elements from y
when not defined(release):
if not (x.shape == y.shape and y.shape == cond.shape):
raise newException(ValueError, &"all inputs must be of equal shape for where()")
@ -960,7 +968,9 @@ proc where*[T](cond: Matrix[bool], x, y: Matrix[T]): Matrix[T] =
proc where*[T](cond: Matrix[bool], x: Matrix[T], y: T): Matrix[T] =
## Behaves like numpy.where, but with a constant
## Behaves like where but with a constant instead of
## an array. When cond is true, take elements from x,
## otherwise take y
when not defined(release):
if not (x.shape == cond.shape):
raise newException(ValueError, &"all inputs must be of equal shape for where()")
@ -1072,4 +1082,6 @@ when isMainModule:
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
doAssert all(j.diag(2) == newMatrix[int](@[1]))
var o = newMatrix[int](@[1, 2, 3])
doAssert all(o.diag() == newMatrix[int](@[@[1, 0, 0], @[0, 2, 0], @[0, 0, 3]]))
doAssert all(o.diag() == newMatrix[int](@[@[1, 0, 0], @[0, 2, 0], @[0, 0, 3]]))
var n = newMatrix[int](@[@[1, 2], @[3, 4]])
doAssert all(n.diagflat() == newMatrix[int](@[@[1, 0, 0, 0], @[0, 2, 0, 0], @[0, 0, 3, 0], @[0, 0, 0, 4]]))