From 7737e47f11d1234555aa261e6183cf85b75a1180 Mon Sep 17 00:00:00 2001 From: Mattia Giambirtone Date: Sat, 18 Mar 2023 21:20:28 +0100 Subject: [PATCH] Added matrix test for softmax function and changed assert statements to doAssert ones --- src/util/matrix.nim | 50 ++++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/util/matrix.nim b/src/util/matrix.nim index f5c0fd8..81555b3 100644 --- a/src/util/matrix.nim +++ b/src/util/matrix.nim @@ -995,30 +995,34 @@ when isMainModule: var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]]) var k = m.transpose() - assert k[2, 1] == m[1, 2], "transpose mismatch" - assert all(m.transpose() == k), "transpose mismatch" - assert k.sum() == m.sum(), "element sum mismatch" - assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch" - assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch" + doAssert k[2, 1] == m[1, 2], "transpose mismatch" + doAssert all(m.transpose() == k), "transpose mismatch" + doAssert k.sum() == m.sum(), "element sum mismatch" + doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch" + doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch" var y = newMatrix[int](@[1, 2, 3, 4]) - assert y.sum() == 10, "element sum mismatch" - assert (y + y).sum() == 20, "matrix sum mismatch" - assert all(m + m == m * 2), "m + m != m * 2" + doAssert y.sum() == 10, "element sum mismatch" + doAssert (y + y).sum() == 20, "matrix sum mismatch" + doAssert all(m + m == m * 2), "m + m != m * 2" var z = newMatrix[int](@[1, 2, 3]) - assert (m * z).sum() == 46, "matrix multiplication mismatch" - assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch" + doAssert (m * z).sum() == 46, "matrix multiplication mismatch" + doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch" var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch" - assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch" - assert x.max() == 9, "max mismatch" - assert x.argmax() == 10, "argmax mismatch" - assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276])) - assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20])) - assert all(m.diag() == newMatrix[int](@[1, 5])) - assert all(m.diag(1) == newMatrix[int](@[2, 6])) - assert all(m.diag(2) == newMatrix[int](@[3])) - assert m.diag(3).len() == 0 + doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch" + doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch" + doAssert x.max() == 9, "max mismatch" + doAssert x.argmax() == 10, "argmax mismatch" + doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276])) + doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20])) + doAssert all(m.diag() == newMatrix[int](@[1, 5])) + doAssert all(m.diag(1) == newMatrix[int](@[2, 6])) + doAssert all(m.diag(2) == newMatrix[int](@[3])) + doAssert m.diag(3).len() == 0 var j = m.fliplr() - assert all(j.diag() == newMatrix[int](@[3, 5])) - assert all(j.diag(1) == newMatrix[int](@[2, 4])) - assert all(j.diag(2) == newMatrix[int](@[1])) + doAssert all(j.diag() == newMatrix[int](@[3, 5])) + doAssert all(j.diag(1) == newMatrix[int](@[2, 4])) + doAssert all(j.diag(2) == newMatrix[int](@[1])) + # A little test for the softmax function + var mat = newMatrix[float](@[123.0, 456.0, 789.0]) + mat = mat - mat.max() + doAssert (mat.apply(math.exp, axis = -1) / sum(mat.apply(math.exp, axis = -1))).sum() == 1.0 \ No newline at end of file