Added matrix test for softmax function and changed assert statements to doAssert ones
parent
9c191adedd
commit
7737e47f11
|
@ -995,30 +995,34 @@ when isMainModule:
|
|||
|
||||
var m = newMatrix[int](@[@[1, 2, 3], @[4, 5, 6]])
|
||||
var k = m.transpose()
|
||||
assert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
assert all(m.transpose() == k), "transpose mismatch"
|
||||
assert k.sum() == m.sum(), "element sum mismatch"
|
||||
assert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
assert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
doAssert k[2, 1] == m[1, 2], "transpose mismatch"
|
||||
doAssert all(m.transpose() == k), "transpose mismatch"
|
||||
doAssert k.sum() == m.sum(), "element sum mismatch"
|
||||
doAssert all(k.sum(axis=1) == m.sum(axis=0)), "sum over axis mismatch"
|
||||
doAssert all(k.sum(axis=0) == m.sum(axis=1)), "sum over axis mismatch"
|
||||
var y = newMatrix[int](@[1, 2, 3, 4])
|
||||
assert y.sum() == 10, "element sum mismatch"
|
||||
assert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
assert all(m + m == m * 2), "m + m != m * 2"
|
||||
doAssert y.sum() == 10, "element sum mismatch"
|
||||
doAssert (y + y).sum() == 20, "matrix sum mismatch"
|
||||
doAssert all(m + m == m * 2), "m + m != m * 2"
|
||||
var z = newMatrix[int](@[1, 2, 3])
|
||||
assert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
assert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
doAssert (m * z).sum() == 46, "matrix multiplication mismatch"
|
||||
doAssert all(z * z == z.apply(pow, 2, axis = -1, copy=true)), "matrix multiplication mismatch"
|
||||
var x = newMatrix[int](@[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
assert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
assert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
assert x.max() == 9, "max mismatch"
|
||||
assert x.argmax() == 10, "argmax mismatch"
|
||||
assert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
assert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
assert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
assert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
assert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
assert m.diag(3).len() == 0
|
||||
doAssert (x < 5).where(x, x * 10).sum() == 360, "where mismatch"
|
||||
doAssert all((x < 5).where(x, x * 10) == x.where(x < 5, x * 10)), "where mismatch"
|
||||
doAssert x.max() == 9, "max mismatch"
|
||||
doAssert x.argmax() == 10, "argmax mismatch"
|
||||
doAssert all(newMatrix[int](@[12, 23]).dot(newMatrix[int](@[@[11, 22], @[33, 44]])) == newMatrix[int](@[891, 1276]))
|
||||
doAssert all(newMatrix[int](@[@[1, 2, 3], @[2, 3, 4]]).dot(newMatrix[int](@[1, 2, 3])) == newMatrix[int](@[14, 20]))
|
||||
doAssert all(m.diag() == newMatrix[int](@[1, 5]))
|
||||
doAssert all(m.diag(1) == newMatrix[int](@[2, 6]))
|
||||
doAssert all(m.diag(2) == newMatrix[int](@[3]))
|
||||
doAssert m.diag(3).len() == 0
|
||||
var j = m.fliplr()
|
||||
assert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
assert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
assert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
doAssert all(j.diag() == newMatrix[int](@[3, 5]))
|
||||
doAssert all(j.diag(1) == newMatrix[int](@[2, 4]))
|
||||
doAssert all(j.diag(2) == newMatrix[int](@[1]))
|
||||
# A little test for the softmax function
|
||||
var mat = newMatrix[float](@[123.0, 456.0, 789.0])
|
||||
mat = mat - mat.max()
|
||||
doAssert (mat.apply(math.exp, axis = -1) / sum(mat.apply(math.exp, axis = -1))).sum() == 1.0
|
Loading…
Reference in New Issue