Derivatives are now vectorized

This commit is contained in:
Mattia Giambirtone 2023-03-20 10:31:09 +01:00
parent e91869e5ab
commit dca0efd1f6
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
2 changed files with 6 additions and 6 deletions

View File

@ -10,9 +10,9 @@ proc mse(a, b: Matrix[float]): float =
result = (b - a).apply(proc (x: float): float = pow(x, 2), axis = -1).sum() / len(a).float
# Derivative of MSE
func dxMSE*(x, y: float): float = 2 * (x - y)
func dxMSE*(x, y: Matrix[float]): Matrix[float] = 2.0 * (x - y)
func dx*(x, y: float): float = 0.0
func dx*(x, y: Matrix[float]): Matrix[float] = zeros[float](x.shape)
# A bunch of vectorized activation functions
func sigmoid*(input: Matrix[float]): Matrix[float] =

View File

@ -41,11 +41,11 @@ type
Loss* = ref object
## A loss function and its derivative
function: proc (a, b: Matrix[float]): float
derivative: proc (x, y: float): float {.noSideEffect.}
derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}
Activation* = ref object
## An activation function
function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}
derivative: proc (x, y: float): float {.noSideEffect.}
derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}
Layer* = ref object
## A generic neural network
## layer
@ -70,14 +70,14 @@ proc `$`*(self: NeuralNetwork): string =
result = &"NeuralNetwork(learnRate={self.learnRate}, layers={self.layers})"
proc newLoss*(function: proc (a, b: Matrix[float]): float, derivative: proc (x, y: float): float {.noSideEffect.}): Loss =
proc newLoss*(function: proc (a, b: Matrix[float]): float, derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}): Loss =
## Creates a new Loss object
new(result)
result.function = function
result.derivative = derivative
proc newActivation*(function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}, derivative: proc (x, y: float): float {.noSideEffect.}): Activation =
proc newActivation*(function: proc (input: Matrix[float]): Matrix[float] {.noSideEffect.}, derivative: proc (x, y: Matrix[float]): Matrix[float] {.noSideEffect.}): Activation =
## Creates a new Activation object
new(result)
result.function = function