76 lines
2.4 KiB
Nim
76 lines
2.4 KiB
Nim
|
# Copyright 2022 Mattia Giambirtone
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
## Various data preprocessing tools
|
||
|
|
||
|
import matrix
|
||
|
|
||
|
|
||
|
import strformat
|
||
|
import sets
|
||
|
|
||
|
|
||
|
type
  LabelEncoder* = ref object
    ## An encoder to assign a numerical value in the
    ## range from 0 to n_labels - 1 to the labels
    ## of some categorical data, reversibly
    isFit: bool  # true once fit() has stored a set of labels
    labels: Matrix[string]  # unique labels recorded by fit()
|
||
|
|
||
|
|
||
|
proc newLabelEncoder*: LabelEncoder =
  ## Creates a fresh LabelEncoder whose fields all hold
  ## their default (zero) values, i.e. it is not fit yet
  LabelEncoder()
|
||
|
|
||
|
|
||
|
proc toOrderedSet[T](m: Matrix[T]): OrderedSet[T] =
  ## Collects every element obtained by iterating the matrix
  ## (and then each item it yields) into an ordered set, so
  ## each distinct value is kept once, in insertion order
  result = initOrderedSet[T]()
  for line in m:
    for item in line:
      result.incl item
|
||
|
|
||
|
|
||
|
proc fit*(self: LabelEncoder, labels: Matrix[string]) =
  ## Fits the encoder to the given labels. Each distinct label
  ## is stored exactly once, in the order the ordered set
  ## yields it (insertion order), and the encoder is then
  ## marked as fit. Calling fit again replaces the stored labels
  var distinctLabels: seq[string] = @[]
  for label in toOrderedSet(labels):
    distinctLabels.add(label)
  self.labels = newMatrix(distinctLabels)
  self.isFit = true
|
||
|
|
||
|
|
||
|
proc transform*(self: LabelEncoder, labels: Matrix[string]): Matrix[int] =
  ## Transforms a vector of labels into a vector of encoded
  ## integers. Duplicate labels are assigned the same integer:
  ## the encoding of a label is its index among the labels
  ## stored by fit().
  ##
  ## Raises ValueError if a label is not among the stored
  ## labels. The encoder must have been fit beforehand
  assert self.isFit, "The estimator must be fit!"
  var res: seq[int] = @[]
  for row in labels:
    for label in row:
      # One lookup serves both purposes: find() returns -1 for
      # a label that is not stored, otherwise its encoded value
      let code = self.labels.raw[].find(label)
      if code < 0:
        raise newException(ValueError, &"Unknown label '{label}'")
      res.add(code)
  result = newMatrix(res)
|
||
|
|
||
|
|
||
|
proc reverseTransform*(self: LabelEncoder, labels: Matrix[int]): Matrix[string] =
  ## Maps encoded integers back to the string labels they
  ## stand for. Raises ValueError when an integer is negative
  ## or not smaller than the length of the stored labels.
  ## The encoder must have been fit beforehand
  assert self.isFit, "The estimator must be fit!"
  var decoded: seq[string] = @[]
  for line in labels:
    for label in line:
      let upper = self.labels.len()
      if label < 0 or label >= upper:
        raise newException(ValueError, &"Unknown encoded label '{label}'")
      decoded.add(self.labels[0, label])
  result = newMatrix(decoded)
|