NNExperiments/src/util/preprocessing.nim

76 lines
2.4 KiB
Nim

# Copyright 2023 Mattia Giambirtone
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## Various data preprocessing tools
import matrix
import strformat
import sets
type
  LabelEncoder* = ref object
    ## Reversible encoder for categorical data: every distinct
    ## label is associated with an integer in the range from
    ## 0 to n_labels - 1, and the association can be undone
    isFit: bool             # Becomes true once fit has stored the labels
    labels: Matrix[string]  # Distinct labels recorded by fit
proc newLabelEncoder*: LabelEncoder =
  ## Creates a fresh, not yet fitted LabelEncoder. All of
  ## its fields start out with their default (zero) values
  result = LabelEncoder()
proc toOrderedSet[T](m: Matrix[T]): OrderedSet[T] =
  ## Gathers the elements of the given matrix into an ordered
  ## set: duplicates are dropped and the remaining elements keep
  ## the order in which they were first encountered
  result = initOrderedSet[T]()
  # Two-level walk: the outer loop goes over whatever iterating the
  # matrix yields (presumably its rows -- confirm against the matrix
  # module), the inner loop over the items of each of those
  for line in m:
    for item in line:
      result.incl item
proc fit*(self: LabelEncoder, labels: Matrix[string]) =
  ## Fits the encoder to the given labels. Each distinct label
  ## is stored exactly once, in order of first appearance, and
  ## the encoder is then marked as fit. Calling this again
  ## replaces any previously stored labels
  let distinctLabels = toOrderedSet(labels)
  # The final size is known up front, so reserve it in one go
  var ordered = newSeqOfCap[string](distinctLabels.len)
  for label in distinctLabels:
    ordered.add(label)
  self.labels = newMatrix(ordered)
  self.isFit = true
proc transform*(self: LabelEncoder, labels: Matrix[string]): Matrix[int] =
  ## Transforms a vector of labels into a vector of encoded
  ## integers. Duplicate labels are assigned the same integer:
  ## the encoded value of a label is its index among the labels
  ## stored by fit. The encoded values are collected into a single
  ## flat sequence, in iteration order, before being wrapped in a
  ## new matrix.
  ##
  ## Raises ValueError if a label is not among the fitted ones.
  ## The encoder must have been fit first (checked with assert)
  assert self.isFit, "The estimator must be fit!"
  var encoded: seq[int] = @[]
  for row in labels:
    for label in row:
      # One scan doubles as membership test and lookup: find
      # returns -1 when the label is absent from the fitted ones
      let index = self.labels.raw[].find(label)
      if index < 0:
        raise newException(ValueError, &"Unknown label '{label}'")
      encoded.add(index)
  result = newMatrix(encoded)
proc reverseTransform*(self: LabelEncoder,
                       labels: Matrix[int]): Matrix[string] =
  ## Reverses the transformation of the integer labels back to
  ## their string form. The decoded strings are collected into a
  ## single flat sequence, in iteration order, before being
  ## wrapped in a new matrix.
  ##
  ## Raises ValueError if an encoded label falls outside of the
  ## range from 0 to self.labels.len() - 1. The encoder must have
  ## been fit first (checked with assert)
  assert self.isFit, "The estimator must be fit!"
  var decoded: seq[string] = @[]
  for row in labels:
    for label in row:
      # NOTE(review): assumes len() of the stored label matrix is
      # the number of fitted labels -- confirm in the matrix module
      if label notin 0..<self.labels.len():
        raise newException(ValueError, &"Unknown encoded label '{label}'")
      # Looks the label up at row 0, column `label` of the stored
      # matrix (presumably a single-row vector -- TODO confirm)
      decoded.add(self.labels[0, label])
  result = newMatrix(decoded)