Chess FEN dataloading
Expects bullet-style text input, one position per line: FEN | EVAL | WDL
This commit is contained in:
1
makefile
1
makefile
@@ -34,6 +34,7 @@ EXE ?= Ember$(EXE_EXT)
|
||||
|
||||
# Source and object files
|
||||
SRCS := $(wildcard ./src/*.cpp)
|
||||
SRCS += $(wildcard ./src/*/*.cpp)
|
||||
SRCS += ./external/fmt/format.cpp
|
||||
OBJS := $(SRCS:.cpp=.o)
|
||||
DEPS := $(OBJS:.o=.d)
|
||||
|
||||
@@ -4,31 +4,26 @@
|
||||
|
||||
int main() {
|
||||
Ember::Network net(
|
||||
Ember::layers::Input(28, 28, 1),
|
||||
Ember::layers::Convolution(64, 3),
|
||||
Ember::layers::Input(2 * 6 * 64),
|
||||
Ember::layers::Linear(128),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::MaxPool(),
|
||||
Ember::layers::Flatten(),
|
||||
Ember::layers::Linear(512),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::Linear(64),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::Linear(10),
|
||||
Ember::activations::Softmax()
|
||||
Ember::layers::Linear(1)
|
||||
);
|
||||
|
||||
Ember::dataloaders::ImageDataLoader dataloader("../datasets/FashionMNIST/", 64, 0.9, 6, 28, 28);
|
||||
constexpr Ember::usize evalScale = 400;
|
||||
|
||||
Ember::dataloaders::chess::BulletTextDataLoader dataloader("../datasets/preludeData.txt", 1024 * 16, evalScale, 1);
|
||||
Ember::optimizers::Adam optimizer(net);
|
||||
|
||||
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::CrossEntropyLoss());
|
||||
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::SigmoidMSE(evalScale));
|
||||
|
||||
std::cout << net << std::endl;
|
||||
|
||||
learner.addCallbacks(
|
||||
Ember::callbacks::DropLROnPlateau(3, 0.3),
|
||||
Ember::callbacks::StopWhenNoProgress(5),
|
||||
Ember::callbacks::DropLROnPlateau(3, 0.3, Ember::Metric::TRAIN_LOSS),
|
||||
Ember::callbacks::StopWhenNoProgress(5, Ember::Metric::TRAIN_LOSS),
|
||||
Ember::callbacks::AutosaveBest("../net.bin")
|
||||
);
|
||||
|
||||
learner.learn(0.001, 20, 1);
|
||||
}
|
||||
learner.learn(0.005, 40, 1);
|
||||
}
|
||||
@@ -1,15 +1,23 @@
|
||||
#include "activation.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace Ember {
|
||||
namespace internal::activations {
|
||||
// Rectified linear unit: negative inputs become 0, everything else passes through unchanged.
float ReLU(const float x) {
    return x < 0.0f ? 0.0f : x;
}
|
||||
// Clipped ReLU: limits the input to the closed interval [0, 1].
float CReLU(const float x) {
    if (x < 0.0f)
        return 0.0f;
    if (1.0f < x)
        return 1.0f;
    return x;
}
|
||||
|
||||
namespace derivatives {
    // Slope of ReLU: 1 for strictly positive inputs, 0 otherwise (including at x == 0).
    float ReLU(const float x) {
        if (x > 0)
            return 1.0f;
        return 0.0f;
    }

    // Slope of CReLU: 1 strictly inside (0, 1), 0 on and outside the clipping bounds.
    float CReLU(const float x) {
        const bool insideLinearRegion = x > 0 && x < 1;
        return insideLinearRegion ? 1.0f : 0.0f;
    }
}
|
||||
}
|
||||
|
||||
@@ -27,6 +35,19 @@ namespace Ember {
|
||||
}
|
||||
|
||||
|
||||
void CReLU::forward(const Layer& previous) {
|
||||
for (usize prev = 0; prev < previous.values.size(); prev++)
|
||||
values.data[prev] = internal::activations::CReLU(previous.values.data[prev]);
|
||||
}
|
||||
// Backward pass: scales each incoming gradient by the CReLU slope evaluated at the
// corresponding value of the previous layer. The result has the dimensions of gradOutput.
Tensor CReLU::backward(const Layer& previous, const Tensor& gradOutput) const {
    Tensor gradInput(gradOutput.dims());

    const usize count = gradOutput.size();
    for (usize idx = 0; idx < count; ++idx) {
        const auto slope = internal::activations::derivatives::CReLU(previous.values.data[idx]);
        gradInput.data[idx] = gradOutput.data[idx] * slope;
    }

    return gradInput;
}
|
||||
|
||||
|
||||
void Softmax::forward(const Layer& previous) {
|
||||
const usize batchSize = previous.values.dim(0);
|
||||
const usize numClasses = previous.values.dim(1);
|
||||
@@ -72,4 +93,4 @@ namespace Ember {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,20 @@ namespace Ember {
|
||||
}
|
||||
};
|
||||
|
||||
struct CReLU : internal::NonComputeLayer {
|
||||
void forward(const Layer& previous) override;
|
||||
|
||||
Tensor backward(const Layer& previous, const Tensor& gradOutput) const override;
|
||||
|
||||
std::unique_ptr<Layer> clone() override {
|
||||
return std::make_unique<CReLU>(*this);
|
||||
}
|
||||
|
||||
std::string str() const override {
|
||||
return fmt::format("Clipped ReLU - {}", dims());
|
||||
}
|
||||
};
|
||||
|
||||
struct Softmax : internal::NonComputeLayer {
|
||||
void forward(const Layer& previous) override;
|
||||
|
||||
|
||||
415
src/chess/board.cpp
Normal file
415
src/chess/board.cpp
Normal file
@@ -0,0 +1,415 @@
|
||||
#include "board.h"
|
||||
|
||||
#include "../util.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <random>
|
||||
#include <bit>
|
||||
|
||||
using std::popcount;
|
||||
|
||||
#define ctzll(x) std::countr_zero(x)
|
||||
|
||||
namespace Ember::chess {
|
||||
// Helpers and util functions
|
||||
|
||||
// Builds a square from its rank and file: the rank goes into bits 3 and up, the file is OR-ed into the low bits.
constexpr Square toSquare(const Rank rank, const File file) {
    const int rankBits = static_cast<int>(rank) << 3;
    return static_cast<Square>(rankBits | file);
}
|
||||
|
||||
// Takes square (h8) and converts it into a bitboard index (64)
|
||||
constexpr Square parseSquare(const std::string_view square) { return static_cast<Square>((square.at(1) - '1') * 8 + (square.at(0) - 'a')); }
|
||||
|
||||
bool readBit(const u64 bb, const i8 sq) { return (1ULL << sq) & bb; }
|
||||
|
||||
template<bool value>
|
||||
void setBit(auto& bitboard, const usize index) {
|
||||
assert(index <= sizeof(bitboard) * 8);
|
||||
if constexpr (value)
|
||||
bitboard |= (1ULL << index);
|
||||
else
|
||||
bitboard &= ~(1ULL << index);
|
||||
}
|
||||
|
||||
// Returns the square of the lowest set bit. The bitboard must not be empty.
Square getLSB(auto bb) {
    assert(bb > 0);
    const auto trailingZeros = std::countr_zero(bb);
    return static_cast<Square>(trailingZeros);
}
|
||||
|
||||
// Clears the lowest set bit of `bb` and returns the square it occupied.
// The bitboard must not be empty.
Square popLSB(auto& bb) {
    assert(bb > 0);
    const Square lowest = getLSB(bb);
    // x & (x - 1) drops exactly the lowest set bit
    bb = bb & (bb - 1);
    return lowest;
}
|
||||
|
||||
template<int dir>
|
||||
u64 shift(const u64 bb) {
|
||||
return dir > 0 ? bb << dir : bb >> -dir;
|
||||
}
|
||||
|
||||
// Runtime variant: shifts left for positive `dir`, right by |dir| otherwise.
u64 shift(const int dir, const u64 bb) {
    if (dir > 0)
        return bb << dir;
    return bb >> -dir;
}
|
||||
|
||||
// Slot of a castling right in the castling array:
// 0 = black queenside, 1 = black kingside, 2 = white queenside, 3 = white kingside.
constexpr u8 castleIndex(const Color c, const bool kingside) {
    const int colorBase = (c == WHITE) ? 2 : 0;
    return colorBase + (kingside ? 1 : 0);
}
|
||||
|
||||
// Mirrors a square vertically by inverting the three rank bits (a1 <-> a8).
constexpr Square flipRank(Square s) {
    constexpr int rankBits = 0b111000;
    return static_cast<Square>(s ^ rankBits);
}

// Mirrors a square horizontally by inverting the three file bits (a1 <-> h1).
constexpr Square flipFile(Square s) {
    constexpr int fileBits = 0b000111;
    return static_cast<Square>(s ^ fileBits);
}
|
||||
|
||||
|
||||
|
||||
// Encodes a chess move
|
||||
class Move {
|
||||
u16 move;
|
||||
|
||||
public:
|
||||
constexpr Move() = default;
|
||||
constexpr ~Move() = default;
|
||||
|
||||
constexpr Move(const u8 startSquare, const u8 endSquare, const MoveType flags = STANDARD_MOVE) {
|
||||
move = startSquare | flags;
|
||||
move |= endSquare << 6;
|
||||
}
|
||||
|
||||
constexpr Move(const u8 startSquare, const u8 endSquare, const PieceType promo) {
|
||||
move = startSquare | PROMOTION;
|
||||
move |= endSquare << 6;
|
||||
move |= (promo - 1) << 12;
|
||||
}
|
||||
|
||||
Move(std::string strIn, Board& board);
|
||||
|
||||
constexpr static Move null() { return Move(a1, a1); }
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
Square from() const { return static_cast<Square>(move & 0b111111); }
|
||||
Square to() const { return static_cast<Square>((move >> 6) & 0b111111); }
|
||||
|
||||
MoveType typeOf() const { return static_cast<MoveType>(move & 0xC000); }
|
||||
|
||||
PieceType promo() const {
|
||||
assert(typeOf() == PROMOTION);
|
||||
return static_cast<PieceType>(((move >> 12) & 0b11) + 1);
|
||||
}
|
||||
|
||||
bool isNull() const { return *this == null(); }
|
||||
|
||||
bool operator==(const Move other) const { return move == other.move; }
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const Move& m) {
|
||||
os << m.toString();
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Returns the piece on a square as a character
|
||||
char Board::getPieceAt(const i8 sq) const {
|
||||
assert(sq >= 0);
|
||||
assert(sq < 64);
|
||||
if (getPiece(sq) == NO_PIECE_TYPE)
|
||||
return ' ';
|
||||
constexpr char whiteSymbols[] = { 'P', 'N', 'B', 'R', 'Q', 'K' };
|
||||
constexpr char blackSymbols[] = { 'p', 'n', 'b', 'r', 'q', 'k' };
|
||||
if (((1ULL << sq) & byColor[WHITE]) != 0)
|
||||
return whiteSymbols[getPiece(sq)];
|
||||
return blackSymbols[getPiece(sq)];
|
||||
}
|
||||
|
||||
void Board::placePiece(const Color c, const PieceType pt, const int sq) {
|
||||
assert(sq >= 0);
|
||||
assert(sq < 64);
|
||||
|
||||
auto& BB = byPieces[pt];
|
||||
|
||||
assert(!readBit(BB, sq));
|
||||
|
||||
BB ^= 1ULL << sq;
|
||||
byColor[c] ^= 1ULL << sq;
|
||||
|
||||
mailbox[sq] = pt;
|
||||
}
|
||||
|
||||
// Takes a piece of the given colour and type off `sq`, updating the piece bitboard,
// the colour bitboard and the mailbox. The piece bitboard must have that square set.
void Board::removePiece(Color c, PieceType pt, int sq) {
    assert(sq >= 0);
    assert(sq < 64);

    auto& pieceBoard = byPieces[pt];
    assert(readBit(pieceBoard, sq));

    const u64 squareMask = 1ULL << sq;
    pieceBoard ^= squareMask;
    byColor[c] ^= squareMask;

    mailbox[sq] = NO_PIECE_TYPE;
}
|
||||
|
||||
// Takes whatever piece the mailbox records on `sq` off the board for colour `c`.
void Board::removePiece(Color c, int sq) {
    assert(sq >= 0);
    assert(sq < 64);

    // The mailbox supplies the piece type; the typed overload does the bookkeeping.
    removePiece(c, getPiece(sq), sq);
}
|
||||
|
||||
void Board::resetMailbox() {
|
||||
mailbox.fill(NO_PIECE_TYPE);
|
||||
for (u8 i = 0; i < 64; i++) {
|
||||
PieceType& sq = mailbox[i];
|
||||
const u64 mask = 1ULL << i;
|
||||
if (mask & pieces(PAWN))
|
||||
sq = PAWN;
|
||||
else if (mask & pieces(KNIGHT))
|
||||
sq = KNIGHT;
|
||||
else if (mask & pieces(BISHOP))
|
||||
sq = BISHOP;
|
||||
else if (mask & pieces(ROOK))
|
||||
sq = ROOK;
|
||||
else if (mask & pieces(QUEEN))
|
||||
sq = QUEEN;
|
||||
else if (mask & pieces(KING))
|
||||
sq = KING;
|
||||
}
|
||||
}
|
||||
|
||||
void Board::setCastlingRights(const Color c, const Square sq, const bool value) { castling[castleIndex(c, ctzll(pieces(c, KING)) < sq)] = (value == false ? NO_SQUARE : sq); }
|
||||
|
||||
// Removes both castling rights of colour `c`.
void Board::unsetCastlingRights(const Color c) {
    castling[castleIndex(c, false)] = NO_SQUARE;
    castling[castleIndex(c, true)] = NO_SQUARE;
}
|
||||
|
||||
// Square stored for the requested castling right, or NO_SQUARE when the right is gone.
Square Board::castleSq(const Color c, const bool kingside) const {
    const u8 slot = castleIndex(c, kingside);
    return castling[slot];
}
|
||||
|
||||
// Number of pieces of the given type on the board, both colours combined.
u8 Board::count(const PieceType pt) const {
    const u64 occupancy = pieces(pt);
    return popcount(occupancy);
}
|
||||
|
||||
// All occupied squares.
u64 Board::pieces() const {
    return byColor[BLACK] | byColor[WHITE];
}

// Squares occupied by colour `c`.
u64 Board::pieces(const Color c) const {
    return byColor[c];
}

// Squares occupied by pieces of type `pt`, either colour.
u64 Board::pieces(const PieceType pt) const {
    return byPieces[pt];
}

// Squares occupied by pieces of type `pt` belonging to colour `c`.
u64 Board::pieces(const Color c, const PieceType pt) const {
    return byColor[c] & byPieces[pt];
}

// Squares occupied by pieces of either of the two types, either colour.
u64 Board::pieces(const PieceType pt1, const PieceType pt2) const {
    return byPieces[pt2] | byPieces[pt1];
}

// Squares occupied by pieces of either of the two types belonging to colour `c`.
u64 Board::pieces(const Color c, const PieceType pt1, const PieceType pt2) const {
    const u64 eitherType = byPieces[pt1] | byPieces[pt2];
    return eitherType & byColor[c];
}
|
||||
|
||||
// Load a board from the FEN
|
||||
void Board::loadFromFEN(const std::string fen) {
|
||||
// Clear all squares
|
||||
byPieces.fill(0);
|
||||
byColor.fill(0);
|
||||
|
||||
const std::vector<std::string> tokens = split(fen, ' ');
|
||||
|
||||
const std::vector<std::string> rankTokens = split(tokens[0], '/');
|
||||
|
||||
int currIdx = 56;
|
||||
|
||||
constexpr char whitePieces[6] = { 'P', 'N', 'B', 'R', 'Q', 'K' };
|
||||
constexpr char blackPieces[6] = { 'p', 'n', 'b', 'r', 'q', 'k' };
|
||||
|
||||
for (const std::string& rank : rankTokens) {
|
||||
for (const char c : rank) {
|
||||
if (isdigit(c)) { // Empty squares
|
||||
currIdx += c - '0';
|
||||
continue;
|
||||
}
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (c == whitePieces[i]) {
|
||||
setBit<1>(byPieces[i], currIdx);
|
||||
setBit<1>(byColor[WHITE], currIdx);
|
||||
break;
|
||||
}
|
||||
if (c == blackPieces[i]) {
|
||||
setBit<1>(byPieces[i], currIdx);
|
||||
setBit<1>(byColor[BLACK], currIdx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
currIdx++;
|
||||
}
|
||||
currIdx -= 16;
|
||||
}
|
||||
|
||||
if (tokens[1] == "w")
|
||||
stm = WHITE;
|
||||
else
|
||||
stm = BLACK;
|
||||
|
||||
castling.fill(NO_SQUARE);
|
||||
if (tokens[2].find('-') == std::string::npos) {
|
||||
// Standard FEN and maybe XFEN later
|
||||
if (tokens[2].find('K') != std::string::npos)
|
||||
castling[castleIndex(WHITE, true)] = h1;
|
||||
if (tokens[2].find('Q') != std::string::npos)
|
||||
castling[castleIndex(WHITE, false)] = a1;
|
||||
if (tokens[2].find('k') != std::string::npos)
|
||||
castling[castleIndex(BLACK, true)] = h8;
|
||||
if (tokens[2].find('q') != std::string::npos)
|
||||
castling[castleIndex(BLACK, false)] = a8;
|
||||
|
||||
// FRC FEN
|
||||
if (std::tolower(tokens[2][0]) >= 'a' && std::tolower(tokens[2][0]) <= 'h') {
|
||||
for (char token : tokens[2]) {
|
||||
const auto file = static_cast<File>(std::tolower(token) - 'a');
|
||||
|
||||
if (std::isupper(token))
|
||||
setCastlingRights(WHITE, toSquare(RANK1, file), true);
|
||||
else
|
||||
setCastlingRights(BLACK, toSquare(RANK8, file), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens[3] != "-")
|
||||
epSquare = parseSquare(tokens[3]);
|
||||
else
|
||||
epSquare = NO_SQUARE;
|
||||
|
||||
halfMoveClock = tokens.size() > 4 ? (stoi(tokens[4])) : 0;
|
||||
fullMoveClock = tokens.size() > 5 ? (stoi(tokens[5])) : 1;
|
||||
|
||||
resetMailbox();
|
||||
}
|
||||
|
||||
// Return the type of the piece on the square
|
||||
PieceType Board::getPiece(const i8 sq) const {
|
||||
assert(sq >= 0);
|
||||
assert(sq < 64);
|
||||
return mailbox[sq];
|
||||
}
|
||||
|
||||
bool Board::isCapture(const Move m) const { return ((1ULL << m.to() & pieces(~stm)) || m.typeOf() == EN_PASSANT); }
|
||||
|
||||
std::vector<float> Board::asInputLayer() const {
|
||||
const auto getFeature = [this](const Color pieceColor, const Square square) {
|
||||
const bool enemy = stm != pieceColor;
|
||||
const int squareIndex = (stm == BLACK) ? flipRank(square) : static_cast<int>(square);
|
||||
|
||||
return enemy * 64 * 6 + getPiece(square) * 64 + squareIndex;
|
||||
};
|
||||
|
||||
std::vector<float> res(2 * 6 * 64);
|
||||
|
||||
u64 whitePieces = pieces(WHITE);
|
||||
u64 blackPieces = pieces(BLACK);
|
||||
|
||||
while (whitePieces) {
|
||||
const Square sq = popLSB(whitePieces);
|
||||
|
||||
const usize feature = getFeature(WHITE, sq);
|
||||
|
||||
res[feature] = true;
|
||||
}
|
||||
|
||||
while (blackPieces) {
|
||||
const Square sq = popLSB(blackPieces);
|
||||
|
||||
const usize feature = getFeature(BLACK, sq);
|
||||
|
||||
res[feature] = true;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void Board::move(const Move m) {
|
||||
epSquare = NO_SQUARE;
|
||||
Square from = m.from();
|
||||
Square to = m.to();
|
||||
MoveType mt = m.typeOf();
|
||||
PieceType pt = getPiece(from);
|
||||
PieceType toPT = NO_PIECE_TYPE;
|
||||
|
||||
removePiece(stm, pt, from);
|
||||
if (isCapture(m)) {
|
||||
toPT = getPiece(to);
|
||||
halfMoveClock = 0;
|
||||
if (mt != EN_PASSANT) {
|
||||
removePiece(~stm, toPT, to);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (pt == PAWN)
|
||||
halfMoveClock = 0;
|
||||
else
|
||||
halfMoveClock++;
|
||||
}
|
||||
|
||||
switch (mt) {
|
||||
case STANDARD_MOVE:
|
||||
placePiece(stm, pt, to);
|
||||
if (pt == PAWN && (to + 16 == from || to - 16 == from)
|
||||
&& (pieces(~stm, PAWN) & (shift<EAST>((1ULL << to) & ~MASK_FILE[FILE_H]) | shift<WEST>((1ULL << to) & ~MASK_FILE[FILE_A])))) // Only set EP square if it could be taken
|
||||
epSquare = static_cast<Square>(stm == WHITE ? from + NORTH : from + SOUTH);
|
||||
break;
|
||||
case EN_PASSANT:
|
||||
removePiece(~stm, PAWN, to + (stm == WHITE ? SOUTH : NORTH));
|
||||
placePiece(stm, pt, to);
|
||||
break;
|
||||
case CASTLE:
|
||||
assert(getPiece(to) == ROOK);
|
||||
removePiece(stm, ROOK, to);
|
||||
if (stm == WHITE) {
|
||||
if (from < to) {
|
||||
placePiece(stm, KING, g1);
|
||||
placePiece(stm, ROOK, f1);
|
||||
}
|
||||
else {
|
||||
placePiece(stm, KING, c1);
|
||||
placePiece(stm, ROOK, d1);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (from < to) {
|
||||
placePiece(stm, KING, g8);
|
||||
placePiece(stm, ROOK, f8);
|
||||
}
|
||||
else {
|
||||
placePiece(stm, KING, c8);
|
||||
placePiece(stm, ROOK, d8);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case PROMOTION:
|
||||
placePiece(stm, m.promo(), to);
|
||||
break;
|
||||
}
|
||||
|
||||
assert(popcount(pieces(WHITE, KING)) == 1);
|
||||
assert(popcount(pieces(BLACK, KING)) == 1);
|
||||
|
||||
if (pt == ROOK) {
|
||||
const Square sq = castleSq(stm, from > ctzll(pieces(stm, KING)));
|
||||
if (from == sq)
|
||||
setCastlingRights(stm, from, false);
|
||||
}
|
||||
else if (pt == KING)
|
||||
unsetCastlingRights(stm);
|
||||
if (toPT == ROOK) {
|
||||
const Square sq = castleSq(~stm, to > ctzll(pieces(~stm, KING)));
|
||||
if (to == sq)
|
||||
setCastlingRights(~stm, to, false);
|
||||
}
|
||||
|
||||
stm = ~stm;
|
||||
|
||||
fullMoveClock += stm == WHITE;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Board& board) {
|
||||
os << "\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n";
|
||||
|
||||
usize line = 1;
|
||||
for (i32 rank = (board.stm == WHITE) * 7; (board.stm == WHITE) ? rank >= 0 : rank < 8; (board.stm == WHITE) ? rank-- : rank++) {
|
||||
os << "\u2502 ";
|
||||
for (usize file = 0; file < 8; file++) {
|
||||
const auto sq = static_cast<Square>(rank * 8 + file);
|
||||
const auto color = ((1ULL << sq) & board.pieces(WHITE)) ? "\033[33m" : "\033[34m";
|
||||
os << color << board.getPieceAt(sq) << "\033[0m ";
|
||||
}
|
||||
os << "\u2502 " << rank + 1 << "\n";
|
||||
}
|
||||
os << "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n";
|
||||
os << " a b c d e f g h\n";
|
||||
return os;
|
||||
}
|
||||
}
|
||||
136
src/chess/board.h
Normal file
136
src/chess/board.h
Normal file
@@ -0,0 +1,136 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "../types.h"
|
||||
#include "../tensor.h"
|
||||
|
||||
namespace Ember::chess {
|
||||
// Side colours. The values double as indices into per-colour arrays and are
// flipped by XOR with 1 (see operator~).
enum Color {
    BLACK = 0,
    WHITE = 1
};
|
||||
|
||||
// Piece kinds. The values PAWN..KING double as indices into the per-piece bitboard array;
// NO_PIECE_TYPE marks an empty square.
enum PieceType : i8 {
    PAWN = 0,
    KNIGHT = 1,
    BISHOP = 2,
    ROOK = 3,
    QUEEN = 4,
    KING = 5,
    NO_PIECE_TYPE = 6
};
|
||||
|
||||
// clang-format off
|
||||
// Board squares numbered rank by rank starting from a1 (a1 = 0, h1 = 7, a8 = 56, h8 = 63).
// NO_SQUARE is the sentinel for "no square".
enum Square : i8 {
    a1 = 0, b1, c1, d1, e1, f1, g1, h1, // rank 1
    a2,     b2, c2, d2, e2, f2, g2, h2, // rank 2
    a3,     b3, c3, d3, e3, f3, g3, h3, // rank 3
    a4,     b4, c4, d4, e4, f4, g4, h4, // rank 4
    a5,     b5, c5, d5, e5, f5, g5, h5, // rank 5
    a6,     b6, c6, d6, e6, f6, g6, h6, // rank 6
    a7,     b7, c7, d7, e7, f7, g7, h7, // rank 7
    a8,     b8, c8, d8, e8, f8, g8, h8, // rank 8
    NO_SQUARE = 64
};
|
||||
|
||||
// Offsets between square indices: one rank is 8 squares, one file is 1 square.
enum Direction : int {
    // Straight steps
    NORTH = 8,
    SOUTH = -8,
    EAST = 1,
    WEST = -1,

    // Diagonal steps
    NORTH_EAST = 9,
    NORTH_WEST = 7,
    SOUTH_EAST = -7,
    SOUTH_WEST = -9,

    // Two-rank steps
    NORTH_NORTH = 16,
    SOUTH_SOUTH = -16
};
|
||||
|
||||
// Board files, a through h, numbered 0 through 7.
enum File : int {
    FILE_A = 0,
    FILE_B = 1,
    FILE_C = 2,
    FILE_D = 3,
    FILE_E = 4,
    FILE_F = 5,
    FILE_G = 6,
    FILE_H = 7
};
|
||||
// Board ranks, 1 through 8, numbered 0 through 7.
enum Rank : int {
    RANK1 = 0,
    RANK2 = 1,
    RANK3 = 2,
    RANK4 = 3,
    RANK5 = 4,
    RANK6 = 5,
    RANK7 = 6,
    RANK8 = 7
};
|
||||
|
||||
// One bitboard per file (indexed by File) with every square of that file set.
constexpr u64 MASK_FILE[8] = {
    0x0101010101010101, // a
    0x0202020202020202, // b
    0x0404040404040404, // c
    0x0808080808080808, // d
    0x1010101010101010, // e
    0x2020202020202020, // f
    0x4040404040404040, // g
    0x8080808080808080, // h
};
|
||||
// One bitboard per rank (indexed by Rank) with every square of that rank set.
constexpr u64 MASK_RANK[8] = {
    0x00000000000000ff, // rank 1
    0x000000000000ff00, // rank 2
    0x0000000000ff0000, // rank 3
    0x00000000ff000000, // rank 4
    0x000000ff00000000, // rank 5
    0x0000ff0000000000, // rank 6
    0x00ff000000000000, // rank 7
    0xff00000000000000  // rank 8
};
|
||||
|
||||
//Inverts the color (WHITE -> BLACK) and (BLACK -> WHITE)
|
||||
constexpr Color operator~(const Color c) { return static_cast<Color>(c ^ 1); }
|
||||
|
||||
// Pre-increment: advances to the next square index.
inline Square& operator++(Square& s) {
    s = static_cast<Square>(static_cast<i8>(s) + 1);
    return s;
}

// Pre-decrement: steps back to the previous square index.
inline Square& operator--(Square& s) {
    s = static_cast<Square>(static_cast<i8>(s) - 1);
    return s;
}

// Square reached by stepping once in direction `d`. No bounds checking is done.
constexpr Square operator+(const Square s, const Direction d) {
    return static_cast<Square>(static_cast<i8>(s) + static_cast<i8>(d));
}

// Square reached by stepping once against direction `d`. No bounds checking is done.
constexpr Square operator-(const Square s, const Direction d) {
    return static_cast<Square>(static_cast<i8>(s) - static_cast<i8>(d));
}

// In-place step in direction `d`.
inline Square& operator+=(Square& s, const Direction d) {
    s = s + d;
    return s;
}

// In-place step against direction `d`.
inline Square& operator-=(Square& s, const Direction d) {
    s = s - d;
    return s;
}
|
||||
//clang-format on
|
||||
|
||||
// Kinds of move. The values occupy the top two bits of a 16-bit word.
enum MoveType {
    STANDARD_MOVE = 0,
    EN_PASSANT = 0x4000,
    CASTLE = 0x8000,
    PROMOTION = 0xC000
};
|
||||
|
||||
constexpr std::array<Square, 4> ROOK_CASTLE_END_SQ = { d8, f8, d1, f1 };
|
||||
constexpr std::array<Square, 4> KING_CASTLE_END_SQ = { c8, g8, c1, g1 };
|
||||
|
||||
class Move;
|
||||
|
||||
struct Board {
|
||||
// Index is based on square, returns the piece type
|
||||
std::array<PieceType, 64> mailbox;
|
||||
// Indexed pawns, knights, bishops, rooks, queens, king
|
||||
std::array<u64, 6> byPieces;
|
||||
// Index is based on color, so black is colors[0]
|
||||
std::array<u64, 2> byColor;
|
||||
|
||||
// En passant square
|
||||
Square epSquare;
|
||||
// Index KQkq
|
||||
std::array<Square, 4> castling;
|
||||
|
||||
Color stm;
|
||||
|
||||
usize halfMoveClock;
|
||||
usize fullMoveClock;
|
||||
|
||||
private:
|
||||
void placePiece(Color c, PieceType pt, int sq);
|
||||
void removePiece(Color c, PieceType pt, int sq);
|
||||
void removePiece(Color c, int sq);
|
||||
void resetMailbox();
|
||||
|
||||
void setCastlingRights(Color c, Square sq, bool value);
|
||||
void unsetCastlingRights(Color c);
|
||||
|
||||
Square castleSq(Color c, bool kingside) const;
|
||||
|
||||
public:
|
||||
u8 count(PieceType pt) const;
|
||||
|
||||
u64 pieces() const;
|
||||
u64 pieces(Color c) const;
|
||||
u64 pieces(PieceType pt) const;
|
||||
u64 pieces(Color c, PieceType pt) const;
|
||||
u64 pieces(PieceType pt1, PieceType pt2) const;
|
||||
u64 pieces(Color c, PieceType pt1, PieceType pt2) const;
|
||||
|
||||
void loadFromFEN(std::string fen);
|
||||
|
||||
char getPieceAt(i8 i) const;
|
||||
|
||||
PieceType getPiece(i8 sq) const;
|
||||
bool isCapture(Move m) const;
|
||||
|
||||
std::vector<float> asInputLayer() const;
|
||||
|
||||
void move(Move m);
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const Board& board);
|
||||
};
|
||||
}
|
||||
174
src/chess/dataloader.cpp
Normal file
174
src/chess/dataloader.cpp
Normal file
@@ -0,0 +1,174 @@
|
||||
#include "../dataloader.h"
|
||||
|
||||
#include "board.h"
|
||||
|
||||
#include "../../external/fmt/format.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <random>
|
||||
#include <chrono>
|
||||
#include <omp.h>
|
||||
|
||||
#include "../util.h"
|
||||
|
||||
namespace Ember::dataloaders::chess {
|
||||
// Opens a bullet-style text data file (one "FEN | EVAL | WDL" record per line),
// counts its lines into numSamples and leaves the stream positioned at the start.
//
// Exits with an error message when the path does not exist, is a directory,
// cannot be opened, or the file contains no lines at all.
BulletTextDataLoader::BulletTextDataLoader(const std::string& filePath, const u64 batchSize, const usize evalScale, const u64 threads) : DataLoader(batchSize, threads), filePath(filePath), evalScale(evalScale) {
    fmt::println("Attempting to open file '{}'", filePath);
    if (!std::filesystem::exists(filePath) || std::filesystem::is_directory(filePath))
        exitWithMsg("Data file does not exist or is a directory: " + filePath, 1);

    file = std::ifstream(filePath);
    // The file can exist and still be unreadable (permissions, locks, ...)
    if (!file.is_open())
        exitWithMsg("Failed to open data file: " + filePath, 1);

    std::string l;
    while (std::getline(file, l))
        numSamples++;

    // An empty file can never fill a batch, so fail here rather than later while loading
    if (numSamples == 0)
        exitWithMsg("Data file contains no positions: " + filePath, 1);

    // Counting ran the stream into EOF; reset it so the first batch reads from the start
    file.clear();
    file.seekg(0);

    fmt::println("Found {} positions", formatNum(numSamples));
}
|
||||
|
||||
void BulletTextDataLoader::loadBatch(const usize batchIdx) {
|
||||
data[batchIdx].input.resize(batchSize, static_cast<usize>(2 * 6 * 64));
|
||||
data[batchIdx].target.resize(batchSize, static_cast<usize>(1));
|
||||
|
||||
data[batchIdx].target.fill(0);
|
||||
data[batchIdx].input.fill(0);
|
||||
|
||||
std::vector<std::vector<internal::DataPoint>> localData(threads);
|
||||
|
||||
std::string l;
|
||||
std::vector<std::string> lines;
|
||||
u64 linesRead = 0;
|
||||
|
||||
// Load lines into a buffer
|
||||
while (linesRead < batchSize) {
|
||||
if (!std::getline(file, l)) {
|
||||
file = std::ifstream(filePath);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Discard null chars because windows uses encoding that
|
||||
// places a \0 after every character
|
||||
std::erase_if(l, [](const char c) { return c == '\0'; });
|
||||
|
||||
if (std::ranges::all_of(l.begin(), l.end(), [](const char c) { return std::isspace(c); }))
|
||||
continue;
|
||||
|
||||
lines.emplace_back(l);
|
||||
linesRead++;
|
||||
}
|
||||
|
||||
assert(lines.size() == batchSize);
|
||||
|
||||
std::vector<u64> shuffledIndexes;
|
||||
shuffledIndexes.reserve(batchSize);
|
||||
|
||||
// This is for batch shuffling
|
||||
for (usize i = 0; i < batchSize; i++)
|
||||
shuffledIndexes.push_back(i);
|
||||
|
||||
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
std::default_random_engine rng(seed);
|
||||
|
||||
// Shuffle the vector
|
||||
std::ranges::shuffle(shuffledIndexes, rng);
|
||||
|
||||
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
|
||||
for (usize i = 0; i < batchSize; i++) {
|
||||
std::string& line = lines[shuffledIndexes[i]];
|
||||
|
||||
// Strip UTF-16 BOM if present
|
||||
if (line.size() >= 2 && static_cast<unsigned char>(line[0]) == 0xFF && static_cast<unsigned char>(line[1]) == 0xFE)
|
||||
line = line.substr(2);
|
||||
|
||||
const auto tokens = split(line, '|');
|
||||
|
||||
if (tokens.size() != 3)
|
||||
exitWithMsg(fmt::format("Expected 3 tokens, got {}. Failed to parse line: {}", tokens.size(), line), 1);
|
||||
assert(tokens.size() == 3);
|
||||
|
||||
const std::string& fen = tokens[0];
|
||||
const float eval = std::stof(tokens[1]);
|
||||
// Token 2 is discarded b/c it's the WDL which is not
|
||||
// used yet
|
||||
|
||||
Ember::chess::Board board{};
|
||||
board.loadFromFEN(fen);
|
||||
|
||||
std::vector<float> input = board.asInputLayer();
|
||||
|
||||
std::memcpy(&data[batchIdx].input[i, 0], input.data(), sizeof(float) * input.size());
|
||||
data[batchIdx].target[i, 0] = eval * evalScale;
|
||||
}
|
||||
}
|
||||
|
||||
void BulletTextDataLoader::loadTestSet() {
|
||||
// Local file object, always reads the
|
||||
// first batchSize samples for
|
||||
// more consistent results
|
||||
std::ifstream file(filePath);
|
||||
|
||||
data[currBatch].input.resize(batchSize, static_cast<usize>(2 * 6 * 64));
|
||||
data[currBatch].target.resize(batchSize, static_cast<usize>(1));
|
||||
|
||||
data[currBatch].target.fill(0);
|
||||
data[currBatch].input.fill(0);
|
||||
|
||||
std::vector<std::vector<internal::DataPoint>> localData(threads);
|
||||
|
||||
std::string l;
|
||||
std::vector<std::string> lines;
|
||||
u64 linesRead = 0;
|
||||
|
||||
// Load lines into a buffer
|
||||
while (linesRead < batchSize) {
|
||||
if (!std::getline(file, l))
|
||||
exitWithMsg(fmt::format("Failed to load test batch. Is the data corrupted? Are there at least {} data points?", batchSize), 1);
|
||||
// Discard null chars because windows uses encoding that
|
||||
// places a \0 after every character
|
||||
std::erase_if(l, [](const char c) { return c == '\0'; });
|
||||
|
||||
if (std::ranges::all_of(l.begin(), l.end(), [](const char c) { return std::isspace(c); }))
|
||||
continue;
|
||||
|
||||
lines.emplace_back(l);
|
||||
linesRead++;
|
||||
}
|
||||
|
||||
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
|
||||
for (usize i = 0; i < batchSize; i++) {
|
||||
std::string& line = lines[i];
|
||||
|
||||
// Strip UTF-16 BOM if present
|
||||
if (line.size() >= 2 && static_cast<unsigned char>(line[0]) == 0xFF && static_cast<unsigned char>(line[1]) == 0xFE)
|
||||
line = line.substr(2);
|
||||
|
||||
const auto tokens = split(line, '|');
|
||||
|
||||
if (tokens.size() != 3)
|
||||
exitWithMsg(fmt::format("Expected 3 tokens, got {}. Failed to parse line: {}", tokens.size(), line), 1);
|
||||
assert(tokens.size() == 3);
|
||||
|
||||
const std::string& fen = tokens[0];
|
||||
const float eval = std::stof(tokens[1]);
|
||||
|
||||
Ember::chess::Board board{};
|
||||
board.loadFromFEN(fen);
|
||||
|
||||
std::vector<float> input = board.asInputLayer();
|
||||
|
||||
std::memcpy(&data[currBatch].input[i, 0], input.data(), sizeof(float) * input.size());
|
||||
data[currBatch].target[i, 0] = eval * evalScale;
|
||||
}
|
||||
}
|
||||
|
||||
// Counts the rows whose network output and target agree after both are divided
// by evalScale and rounded to the nearest whole number.
u64 BulletTextDataLoader::countCorrect(const Tensor& output, const Tensor& target) {
    u64 matches = 0;

    const usize rows = target.dim(0);
    for (usize row = 0; row < rows; ++row) {
        const auto predicted = std::round(output[row, 0] / evalScale);
        const auto expected = std::round(target[row, 0] / evalScale);
        if (predicted == expected)
            matches++;
    }

    return matches;
}
|
||||
}
|
||||
@@ -38,17 +38,12 @@ std::vector<float> loadGreyscaleImage(const std::string& path, const Ember::usiz
|
||||
}
|
||||
|
||||
namespace Ember::dataloaders {
|
||||
ImageDataLoader::ImageDataLoader(const std::string& dataDir, const u64 batchSize, const float trainSplit, const u64 threads, const usize width, const usize height)
|
||||
: DataLoader(batchSize, trainSplit, threads) {
|
||||
this->width = width;
|
||||
this->height = height;
|
||||
|
||||
ImageDataLoader::ImageDataLoader(const std::string& dataDir, const u64 batchSize, const u64 threads, const float trainSplit, const usize width, const usize height)
|
||||
: DataLoader(batchSize, threads), dataDir(dataDir), trainSplit(trainSplit), width(width), height(height) {
|
||||
fmt::println("Attempting to open data dir '{}'", dataDir);
|
||||
if (!std::filesystem::exists(dataDir) || !std::filesystem::is_directory(dataDir))
|
||||
exitWithMsg("Data directory does not exist or is not a directory: " + dataDir, 1);
|
||||
|
||||
this->dataDir = dataDir;
|
||||
|
||||
for (const auto &entry: std::filesystem::directory_iterator(this->dataDir)) {
|
||||
if (entry.is_directory())
|
||||
types.push_back(entry.path().string());
|
||||
@@ -95,7 +90,7 @@ namespace Ember::dataloaders {
|
||||
|
||||
std::vector<std::vector<internal::DataPoint>> localData(threads);
|
||||
|
||||
#pragma omp parallel for num_threads(threads)
|
||||
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
|
||||
for (usize i = 0; i < batchSize; i++) {
|
||||
std::mt19937 rng{ std::random_device{}() + omp_get_thread_num()};
|
||||
|
||||
@@ -142,4 +137,22 @@ namespace Ember::dataloaders {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Counts the rows where the column holding the largest output value is also the
// column holding the largest target value. Ties resolve to the lowest column index.
u64 ImageDataLoader::countCorrect(const Tensor& output, const Tensor& target) {
    u64 hits = 0;

    const usize rows = target.dim(0);
    const usize classes = target.dim(1);

    for (usize row = 0; row < rows; ++row) {
        usize predicted = 0;
        usize expected = 0;
        for (usize cls = 0; cls < classes; ++cls) {
            if (output[row, cls] > output[row, predicted])
                predicted = cls;
            if (target[row, cls] > target[row, expected])
                expected = cls;
        }
        if (predicted == expected)
            hits++;
    }

    return hits;
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "types.h"
|
||||
#include "tensor.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <future>
|
||||
#include <random>
|
||||
@@ -21,22 +22,17 @@ namespace Ember {
|
||||
struct DataLoader {
|
||||
u64 threads;
|
||||
u64 batchSize;
|
||||
float trainSplit;
|
||||
|
||||
u64 numSamples;
|
||||
u64 numSamples = 0;
|
||||
|
||||
usize currBatch = 0;
|
||||
|
||||
usize currBatch;
|
||||
std::future<void> dataFuture;
|
||||
std::array<DataPoint, 2> data;
|
||||
|
||||
DataLoader(const u64 batchSize, const float trainSplit, const u64 threads) {
|
||||
DataLoader(const u64 batchSize, const u64 threads) {
|
||||
this->threads = threads;
|
||||
this->batchSize = batchSize;
|
||||
this->trainSplit = trainSplit;
|
||||
|
||||
this->numSamples = 0;
|
||||
|
||||
this->currBatch = 0;
|
||||
}
|
||||
|
||||
// Loads batch into other buffer
|
||||
@@ -57,10 +53,14 @@ namespace Ember {
|
||||
return data[currBatch];
|
||||
}
|
||||
|
||||
void swapBuffers() {
|
||||
virtual void swapBuffers() {
|
||||
currBatch ^= 1;
|
||||
}
|
||||
|
||||
// Returns the number of "correct" outputs
|
||||
// from the network
|
||||
virtual u64 countCorrect(const Tensor& output, const Tensor& target) = 0;
|
||||
|
||||
virtual ~DataLoader() = default;
|
||||
};
|
||||
}
|
||||
@@ -72,17 +72,46 @@ namespace Ember {
|
||||
std::vector<u64> samplesPerType;
|
||||
std::vector<std::vector<std::string>> allImages;
|
||||
|
||||
std::vector<usize> trainSamplesPerType;
|
||||
usize numTrainSamples;
|
||||
usize numTestSamples;
|
||||
std::vector<u64> trainSamplesPerType;
|
||||
u64 numTrainSamples;
|
||||
u64 numTestSamples;
|
||||
|
||||
float trainSplit;
|
||||
|
||||
usize width;
|
||||
usize height;
|
||||
|
||||
ImageDataLoader(const std::string& dataDir, const u64 batchSize, const float trainSplit, const u64 threads = 0, const usize width = 0, const usize height = 0);
|
||||
ImageDataLoader(const std::string& dataDir, const u64 batchSize, const u64 threads, const float trainSplit, const usize width = 0, const usize height = 0);
|
||||
|
||||
void loadBatch(const usize batchIdx) override;
|
||||
void loadTestSet() override;
|
||||
|
||||
u64 countCorrect(const Tensor& output, const Tensor& target) override;
|
||||
};
|
||||
|
||||
// Defined in ./chess/*
|
||||
namespace chess {
|
||||
struct BulletTextDataLoader : internal::DataLoader {
|
||||
std::string filePath;
|
||||
|
||||
u64 batchNumber = 0;
|
||||
usize evalScale = 0;
|
||||
|
||||
std::ifstream file;
|
||||
|
||||
BulletTextDataLoader(const std::string& filePath, const u64 batchSize, const usize evalScale, const u64 threads = 0);
|
||||
|
||||
void loadBatch(const usize batchIdx) override;
|
||||
void loadTestSet() override;
|
||||
|
||||
u64 countCorrect(const Tensor& output, const Tensor& target) override;
|
||||
|
||||
|
||||
void swapBuffers() override {
|
||||
batchNumber++;
|
||||
currBatch ^= 1;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,6 @@ namespace Ember {
|
||||
|
||||
// Returns { test loss, test accuracy }
|
||||
const auto getTestLossAcc = [&]() {
|
||||
usize numCorrect = 0;
|
||||
dataLoader.loadTestSet();
|
||||
const internal::DataPoint& data = dataLoader.batchData();
|
||||
const usize testSize = data.input.dim(0);
|
||||
@@ -62,17 +61,7 @@ namespace Ember {
|
||||
|
||||
const float loss = lossFunc->forward(net.output(), data.target);
|
||||
|
||||
for (usize i = 0; i < data.target.dim(0); i++) {
|
||||
usize guess = 0;
|
||||
usize goal = 0;
|
||||
for (usize j = 0; j < data.target.dim(1); j++) {
|
||||
if (net.output()[i, j] > net.output()[i, guess])
|
||||
guess = j;
|
||||
if (data.target[i, j] > data.target[i, goal])
|
||||
goal = j;
|
||||
}
|
||||
numCorrect += (guess == goal);
|
||||
}
|
||||
const u64 numCorrect = dataLoader.countCorrect(net.output(), data.target);
|
||||
|
||||
return std::pair<float, float>{ loss, numCorrect / static_cast<float>(testSize ? testSize : 1) };
|
||||
};
|
||||
@@ -160,7 +149,7 @@ namespace Ember {
|
||||
internal::cursor::up();
|
||||
internal::cursor::up();
|
||||
internal::cursor::begin();
|
||||
fmt::println("{:>5L}{:>14.5f}{:>13}{:>17}{:>12}", epoch, trainLoss / currentBatch, "Pending", "Pending", formatTime(stopwatch.elapsed()));
|
||||
fmt::print("{:>5L}{:>14.5f}{:>13}{:>17}{:>12}\n", epoch, trainLoss / currentBatch, "Pending", "Pending", formatTime(stopwatch.elapsed()));
|
||||
std::cout << progressBar.report(currentBatch + 1, batchesPerEpoch, 63) << " " << std::endl;
|
||||
|
||||
afterBatch:
|
||||
|
||||
44
src/loss.cpp
44
src/loss.cpp
@@ -7,20 +7,56 @@ namespace Ember::loss {
|
||||
float loss = 0;
|
||||
for (usize i = 0; i < output.size(); i++)
|
||||
loss += std::pow(output.data[i] - target.data[i], 2);
|
||||
return loss / output.dim(1);
|
||||
return loss / output.size();
|
||||
}
|
||||
|
||||
// Gradient of the mean squared error with respect to `output`.
//
// Each element is 2 * (output - target) / N, where N is the total number of
// elements in `output`, matching a loss that is averaged over output.size().
//
// `scalar` is declared exactly once: the earlier `2.0f / output.dim(1)`
// declaration was a stale duplicate that conflicted with this one.
//
// Returns a tensor with the same dimensions as `output`.
Tensor MeanSquaredError::backward(const Tensor& output, const Tensor& target) {
    Tensor gradient;
    gradient.resize(output.dims());

    // Loop-invariant factor: derivative of the square (2) times the mean (1 / N)
    const float scalar = 2.0f / output.size();

    for (usize i = 0; i < output.size(); i++)
        gradient[i] = (output.data[i] - target.data[i]) * scalar;
    return gradient;
}
|
||||
|
||||
|
||||
float SigmoidMSE::forward(const Tensor& output, const Tensor& target) {
|
||||
assert(output.size() == target.size());
|
||||
|
||||
float loss = 0;
|
||||
for (usize i = 0; i < output.size(); i++) {
|
||||
const float imprecision = std::abs(sigmoid(output.data[i]) - sigmoid(target.data[i]));
|
||||
loss += std::pow(imprecision, 2);
|
||||
}
|
||||
return loss / output.size() - offset;
|
||||
}
|
||||
|
||||
// Gradient of the sigmoid-space mean squared error with respect to `output`.
//
// With f(x) = k / (1 + e^(a + b*x)), each element is
//     (2 / N) * (f(output) - f(target)) * f'(output)
// where f'(x) = -k * b * e^(a + b*x) / (1 + e^(a + b*x))^2 and N is the
// number of elements in `output`.
//
// Both tensors must hold the same number of elements (checked by assert).
// Returns a tensor with the same dimensions as `output`.
Tensor SigmoidMSE::backward(const Tensor& output, const Tensor& target) {
    assert(output.size() == target.size());

    Tensor grad;
    grad.resize(output.dims());

    const float meanFactor = 2.0f / output.size();

    for (usize idx = 0; idx < output.size(); idx++) {
        const float eOut = std::exp(a + b * output.data[idx]);
        const float eTgt = std::exp(a + b * target.data[idx]);

        const float sigOut = k / (1.0f + eOut);
        const float sigTgt = k / (1.0f + eTgt);

        // Derivative of the sigmoid evaluated at the network output
        const float slope = -k * b * eOut / ((1.0f + eOut) * (1.0f + eOut));

        grad.data[idx] = meanFactor * (sigOut - sigTgt) * slope;
    }

    return grad;
}
|
||||
|
||||
|
||||
float CrossEntropyLoss::forward(const Tensor& output, const Tensor& target) {
|
||||
assert(output.size() == target.size());
|
||||
|
||||
@@ -31,7 +67,7 @@ namespace Ember::loss {
|
||||
prob = std::min(prob, 1.0f);
|
||||
loss -= target.data[i] * std::log(prob);
|
||||
}
|
||||
return loss / output.dim(0);
|
||||
return loss / output.size();
|
||||
}
|
||||
|
||||
Tensor CrossEntropyLoss::backward(const Tensor& output, const Tensor& target) {
|
||||
@@ -40,7 +76,7 @@ namespace Ember::loss {
|
||||
Tensor gradient;
|
||||
gradient.resize(output.dims());
|
||||
|
||||
const float scalar = 1.0f / output.dim(0);
|
||||
const float scalar = 1.0f / output.size();
|
||||
for (usize i = 0; i < output.size(); i++) {
|
||||
const float prob = std::max(output.data[i], 1e-10f);
|
||||
gradient.data[i] = -target.data[i] / prob * scalar;
|
||||
|
||||
25
src/loss.h
25
src/loss.h
@@ -19,6 +19,31 @@ namespace Ember {
|
||||
Tensor backward(const Tensor& output, const Tensor& target) override;
|
||||
};
|
||||
|
||||
// Apply sigmoid to the input and target before
|
||||
// calculating the loss using MSE
|
||||
// Modeled by the below function, paste into Desmos to see it
|
||||
// f\left(x\right)=\left(\frac{k}{1+e^{\left(a+bx\right)}}\right)
|
||||
struct SigmoidMSE : internal::LossFunction {
|
||||
float a = 1;
|
||||
float b = -0.25;
|
||||
float k = 1;
|
||||
|
||||
float offset;
|
||||
|
||||
SigmoidMSE(const float a, const float b, const float k) : a(a), b(b), k(k) {
|
||||
offset = -std::pow(sigmoid(0), 2);
|
||||
}
|
||||
explicit SigmoidMSE(const float horizontalStretch) {
|
||||
b /= horizontalStretch;
|
||||
offset = -std::pow(sigmoid(0), 2);
|
||||
}
|
||||
|
||||
float sigmoid(const float x) const { return k / (1 + std::exp(a + b * x)); }
|
||||
|
||||
float forward(const Tensor& output, const Tensor& target) override;
|
||||
Tensor backward(const Tensor& output, const Tensor& target) override;
|
||||
};
|
||||
|
||||
struct CrossEntropyLoss : internal::LossFunction {
|
||||
float forward(const Tensor& output, const Tensor& target) override;
|
||||
Tensor backward(const Tensor& output, const Tensor& target) override;
|
||||
|
||||
17
src/util.h
17
src/util.h
@@ -2,6 +2,8 @@
|
||||
|
||||
#include "types.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
namespace Ember {
|
||||
// Formats a number with commas
|
||||
inline std::string formatNum(const i64 v) {
|
||||
@@ -17,4 +19,19 @@ namespace Ember {
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
// Splits `str` at every occurrence of `delim` and returns the pieces in the
// order they appear. Empty pieces (produced by leading, trailing or repeated
// delimiters) are discarded, so an empty or delimiter-only string yields an
// empty vector.
inline std::vector<std::string> split(const std::string& str, const char delim) {
    std::vector<std::string> pieces;
    std::istringstream reader(str);
    std::string piece;

    while (std::getline(reader, piece, delim)) {
        if (!piece.empty())
            pieces.push_back(piece);
    }

    return pieces;
}
|
||||
}
|
||||
Reference in New Issue
Block a user