Added matrix test for softmax function and changed assert statements to doAssert ones
This commit is contained in:
parent
9c191adedd
commit
7737e47f11
|
@ -995,30 +995,34 @@ when isMainModule:
|
||||||
|
|
||||||
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
||||||
var k = m.transpose()
|
var k = m.transpose()
|
||||||
assert k[2, 1] == m[1, 2], "transpose mismatch"
|
doAssert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||||
assert all(m.transpose() == k), "transpose mismatch"
|
doAssert all(m.transpose() == k), "transpose mismatch"
|
||||||
assert k.sum() == m.sum(), "element sum mismatch"
|
doAssert k.sum() == m.sum(), "element sum mismatch"
|
||||||
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||||
assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||||
var y = newMatrix[int](@[1, 2, 3, 4])
|
var y = newMatrix[int](@[1, 2, 3, 4])
|
||||||
assert y.sum() == 10, "element sum mismatch"
|
doAssert y.sum() == 10, "element sum mismatch"
|
||||||
assert (y + y).sum() == 20, "matrix sum mismatch"
|
doAssert (y + y).sum() == 20, "matrix sum mismatch"
|
||||||
assert all(m + m == m * 2), "m + m != m * 2"
|
doAssert all(m + m == m * 2), "m + m != m * 2"
|
||||||
var z = newMatrix[int](@[1, 2, 3])
|
var z = newMatrix[int](@[1, 2, 3])
|
||||||
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
doAssert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||||
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||||
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||||
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||||
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||||
assert x.max() == 9, "max mismatch"
|
doAssert x.max() == 9, "max mismatch"
|
||||||
assert x.argmax() == 10, "argmax mismatch"
|
doAssert x.argmax() == 10, "argmax mismatch"
|
||||||
assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||||
assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||||
assert all(m.diag() == newMatrix[int](@[1, 5]))
|
doAssert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||||
assert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
doAssert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||||
assert all(m.diag(2) == newMatrix[int](@[3]))
|
doAssert all(m.diag(2) == newMatrix[int](@[3]))
|
||||||
assert m.diag(3).len() == 0
|
doAssert m.diag(3).len() == 0
|
||||||
var j = m.fliplr()
|
var j = m.fliplr()
|
||||||
assert all(j.diag() == newMatrix[int](@[3, 5]))
|
doAssert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||||
assert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||||
assert all(j.diag(2) == newMatrix[int](@[1]))
|
doAssert all(j.diag(2) == newMatrix[int](@[1]))
|
||||||
|
# A little test for the softmax function
|
||||||
|
var mat = newMatrix[float](@[123.0, 456.0, 789.0])
|
||||||
|
mat = mat - mat.max()
|
||||||
|
doAssert (mat.apply(math.exp, axis = -1) / sum(mat.apply(math.exp, axis = -1))).sum() == 1.0
|
Loading…
Reference in New Issue