Added matrix test for softmax function and changed assert statements to doAssert ones

This commit is contained in:
Mattia Giambirtone 2023-03-18 21:20:28 +01:00
parent 9c191adedd
commit 7737e47f11
Signed by: nocturn9x
GPG Key ID: 8270F9F467971E59
1 changed files with 27 additions and 23 deletions

View File

@ -995,30 +995,34 @@ when isMainModule:
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
var k = m.transpose()
assert k[2, 1] == m[1, 2], "transpose mismatch"
assert all(m.transpose() == k), "transpose mismatch"
assert k.sum() == m.sum(), "element sum mismatch"
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
doAssert k[2, 1] == m[1, 2], "transpose mismatch"
doAssert all(m.transpose() == k), "transpose mismatch"
doAssert k.sum() == m.sum(), "element sum mismatch"
doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
var y = newMatrix[int](@[1, 2, 3, 4])
assert y.sum() == 10, "element sum mismatch"
assert (y + y).sum() == 20, "matrix sum mismatch"
assert all(m + m == m * 2), "m + m != m * 2"
doAssert y.sum() == 10, "element sum mismatch"
doAssert (y + y).sum() == 20, "matrix sum mismatch"
doAssert all(m + m == m * 2), "m + m != m * 2"
var z = newMatrix[int](@[1, 2, 3])
assert (m * z).sum() == 46, "matrix multiplication mismatch"
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
doAssert (m * z).sum() == 46, "matrix multiplication mismatch"
doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
assert x.max() == 9, "max mismatch"
assert x.argmax() == 10, "argmax mismatch"
assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
assert all(m.diag() == newMatrix[int](@[1, 5]))
assert all(m.diag(1) == newMatrix[int](@[2, 6]))
assert all(m.diag(2) == newMatrix[int](@[3]))
assert m.diag(3).len() == 0
doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
doAssert x.max() == 9, "max mismatch"
doAssert x.argmax() == 10, "argmax mismatch"
doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
doAssert all(m.diag() == newMatrix[int](@[1, 5]))
doAssert all(m.diag(1) == newMatrix[int](@[2, 6]))
doAssert all(m.diag(2) == newMatrix[int](@[3]))
doAssert m.diag(3).len() == 0
var j = m.fliplr()
assert all(j.diag() == newMatrix[int](@[3, 5]))
assert all(j.diag(1) == newMatrix[int](@[2, 4]))
assert all(j.diag(2) == newMatrix[int](@[1]))
doAssert all(j.diag() == newMatrix[int](@[3, 5]))
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
doAssert all(j.diag(2) == newMatrix[int](@[1]))
# A little test for the softmax function
var mat = newMatrix[float](@[123.0, 456.0, 789.0])
mat = mat - mat.max()
doAssert (mat.apply(math.exp, axis = -1) / sum(mat.apply(math.exp, axis = -1))).sum() == 1.0