Chess FEN dataloading

Expects bullet-style text input
FEN | EVAL | WDL
This commit is contained in:
Quinn
2025-11-08 12:33:17 -06:00
parent 5f6946d5f5
commit 3ff2ec4f8b
13 changed files with 922 additions and 57 deletions

View File

@@ -34,6 +34,7 @@ EXE ?= Ember$(EXE_EXT)
# Source and object files
SRCS := $(wildcard ./src/*.cpp)
SRCS += $(wildcard ./src/*/*.cpp)
SRCS += ./external/fmt/format.cpp
OBJS := $(SRCS:.cpp=.o)
DEPS := $(OBJS:.o=.d)

View File

@@ -4,31 +4,26 @@
int main() {
Ember::Network net(
Ember::layers::Input(28, 28, 1),
Ember::layers::Convolution(64, 3),
Ember::layers::Input(2 * 6 * 64),
Ember::layers::Linear(128),
Ember::activations::ReLU(),
Ember::layers::MaxPool(),
Ember::layers::Flatten(),
Ember::layers::Linear(512),
Ember::activations::ReLU(),
Ember::layers::Linear(64),
Ember::activations::ReLU(),
Ember::layers::Linear(10),
Ember::activations::Softmax()
Ember::layers::Linear(1)
);
Ember::dataloaders::ImageDataLoader dataloader("../datasets/FashionMNIST/", 64, 0.9, 6, 28, 28);
constexpr Ember::usize evalScale = 400;
Ember::dataloaders::chess::BulletTextDataLoader dataloader("../datasets/preludeData.txt", 1024 * 16, evalScale, 1);
Ember::optimizers::Adam optimizer(net);
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::CrossEntropyLoss());
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::SigmoidMSE(evalScale));
std::cout << net << std::endl;
learner.addCallbacks(
Ember::callbacks::DropLROnPlateau(3, 0.3),
Ember::callbacks::StopWhenNoProgress(5),
Ember::callbacks::DropLROnPlateau(3, 0.3, Ember::Metric::TRAIN_LOSS),
Ember::callbacks::StopWhenNoProgress(5, Ember::Metric::TRAIN_LOSS),
Ember::callbacks::AutosaveBest("../net.bin")
);
learner.learn(0.001, 20, 1);
}
learner.learn(0.005, 40, 1);
}

View File

@@ -1,15 +1,23 @@
#include "activation.h"
#include <algorithm>
namespace Ember {
namespace internal::activations {
float ReLU(const float x) {
return std::max(x, 0.0f);
}
float CReLU(const float x) {
return std::clamp(x, 0.0f, 1.0f);
}
namespace derivatives {
float ReLU(const float x) {
return x > 0 ? 1 : 0;
}
float CReLU(const float x) {
return x > 0 && x < 1 ? 1 : 0;
}
}
}
@@ -27,6 +35,19 @@ namespace Ember {
}
void CReLU::forward(const Layer& previous) {
for (usize prev = 0; prev < previous.values.size(); prev++)
values.data[prev] = internal::activations::CReLU(previous.values.data[prev]);
}
Tensor CReLU::backward(const Layer& previous, const Tensor& gradOutput) const {
Tensor result(gradOutput.dims());
for (usize prev = 0; prev < gradOutput.size(); prev++)
result.data[prev] = gradOutput.data[prev] * internal::activations::derivatives::CReLU(previous.values.data[prev]);
return result;
}
void Softmax::forward(const Layer& previous) {
const usize batchSize = previous.values.dim(0);
const usize numClasses = previous.values.dim(1);
@@ -72,4 +93,4 @@ namespace Ember {
return result;
}
}
}
}

View File

@@ -22,6 +22,20 @@ namespace Ember {
}
};
struct CReLU : internal::NonComputeLayer {
void forward(const Layer& previous) override;
Tensor backward(const Layer& previous, const Tensor& gradOutput) const override;
std::unique_ptr<Layer> clone() override {
return std::make_unique<CReLU>(*this);
}
std::string str() const override {
return fmt::format("Clipped ReLU - {}", dims());
}
};
struct Softmax : internal::NonComputeLayer {
void forward(const Layer& previous) override;

415
src/chess/board.cpp Normal file
View File

@@ -0,0 +1,415 @@
#include "board.h"
#include "../util.h"
#include <sstream>
#include <random>
#include <bit>
using std::popcount;
#define ctzll(x) std::countr_zero(x)
namespace Ember::chess {
// Helpers and util functions
constexpr Square toSquare(const Rank rank, const File file) { return static_cast<Square>((static_cast<int>(rank) << 3) | file); }
// Takes square (h8) and converts it into a bitboard index (64)
constexpr Square parseSquare(const std::string_view square) { return static_cast<Square>((square.at(1) - '1') * 8 + (square.at(0) - 'a')); }
bool readBit(const u64 bb, const i8 sq) { return (1ULL << sq) & bb; }
template<bool value>
void setBit(auto& bitboard, const usize index) {
assert(index <= sizeof(bitboard) * 8);
if constexpr (value)
bitboard |= (1ULL << index);
else
bitboard &= ~(1ULL << index);
}
Square getLSB(auto bb) {
assert(bb > 0);
return static_cast<Square>(ctzll(bb));
}
Square popLSB(auto& bb) {
assert(bb > 0);
const Square sq = getLSB(bb);
bb &= bb - 1;
return sq;
}
template<int dir>
u64 shift(const u64 bb) {
return dir > 0 ? bb << dir : bb >> -dir;
}
u64 shift(const int dir, const u64 bb) { return dir > 0 ? bb << dir : bb >> -dir; }
constexpr u8 castleIndex(const Color c, const bool kingside) { return c == WHITE ? (kingside ? 3 : 2) : (kingside ? 1 : 0); }
constexpr Square flipRank(Square s) { return Square(s ^ 0b111000); }
constexpr Square flipFile(Square s) { return Square(s ^ 0b000111); }
// Encodes a chess move
class Move {
u16 move;
public:
constexpr Move() = default;
constexpr ~Move() = default;
constexpr Move(const u8 startSquare, const u8 endSquare, const MoveType flags = STANDARD_MOVE) {
move = startSquare | flags;
move |= endSquare << 6;
}
constexpr Move(const u8 startSquare, const u8 endSquare, const PieceType promo) {
move = startSquare | PROMOTION;
move |= endSquare << 6;
move |= (promo - 1) << 12;
}
Move(std::string strIn, Board& board);
constexpr static Move null() { return Move(a1, a1); }
std::string toString() const;
Square from() const { return static_cast<Square>(move & 0b111111); }
Square to() const { return static_cast<Square>((move >> 6) & 0b111111); }
MoveType typeOf() const { return static_cast<MoveType>(move & 0xC000); }
PieceType promo() const {
assert(typeOf() == PROMOTION);
return static_cast<PieceType>(((move >> 12) & 0b11) + 1);
}
bool isNull() const { return *this == null(); }
bool operator==(const Move other) const { return move == other.move; }
friend std::ostream& operator<<(std::ostream& os, const Move& m) {
os << m.toString();
return os;
}
};
// Returns the piece on a square as a character
char Board::getPieceAt(const i8 sq) const {
assert(sq >= 0);
assert(sq < 64);
if (getPiece(sq) == NO_PIECE_TYPE)
return ' ';
constexpr char whiteSymbols[] = { 'P', 'N', 'B', 'R', 'Q', 'K' };
constexpr char blackSymbols[] = { 'p', 'n', 'b', 'r', 'q', 'k' };
if (((1ULL << sq) & byColor[WHITE]) != 0)
return whiteSymbols[getPiece(sq)];
return blackSymbols[getPiece(sq)];
}
void Board::placePiece(const Color c, const PieceType pt, const int sq) {
assert(sq >= 0);
assert(sq < 64);
auto& BB = byPieces[pt];
assert(!readBit(BB, sq));
BB ^= 1ULL << sq;
byColor[c] ^= 1ULL << sq;
mailbox[sq] = pt;
}
void Board::removePiece(Color c, PieceType pt, int sq) {
assert(sq >= 0);
assert(sq < 64);
auto& BB = byPieces[pt];
assert(readBit(BB, sq));
BB ^= 1ULL << sq;
byColor[c] ^= 1ULL << sq;
mailbox[sq] = NO_PIECE_TYPE;
}
void Board::removePiece(Color c, int sq) {
assert(sq >= 0);
assert(sq < 64);
auto& BB = byPieces[getPiece(sq)];
assert(readBit(BB, sq));
BB ^= 1ULL << sq;
byColor[c] ^= 1ULL << sq;
mailbox[sq] = NO_PIECE_TYPE;
}
void Board::resetMailbox() {
mailbox.fill(NO_PIECE_TYPE);
for (u8 i = 0; i < 64; i++) {
PieceType& sq = mailbox[i];
const u64 mask = 1ULL << i;
if (mask & pieces(PAWN))
sq = PAWN;
else if (mask & pieces(KNIGHT))
sq = KNIGHT;
else if (mask & pieces(BISHOP))
sq = BISHOP;
else if (mask & pieces(ROOK))
sq = ROOK;
else if (mask & pieces(QUEEN))
sq = QUEEN;
else if (mask & pieces(KING))
sq = KING;
}
}
void Board::setCastlingRights(const Color c, const Square sq, const bool value) { castling[castleIndex(c, ctzll(pieces(c, KING)) < sq)] = (value == false ? NO_SQUARE : sq); }
void Board::unsetCastlingRights(const Color c) { castling[castleIndex(c, true)] = castling[castleIndex(c, false)] = NO_SQUARE; }
Square Board::castleSq(const Color c, const bool kingside) const { return castling[castleIndex(c, kingside)]; }
u8 Board::count(const PieceType pt) const { return popcount(pieces(pt)); }
u64 Board::pieces() const { return byColor[WHITE] | byColor[BLACK]; }
u64 Board::pieces(const Color c) const { return byColor[c]; }
u64 Board::pieces(const PieceType pt) const { return byPieces[pt]; }
u64 Board::pieces(const Color c, const PieceType pt) const { return byPieces[pt] & byColor[c]; }
u64 Board::pieces(const PieceType pt1, const PieceType pt2) const { return byPieces[pt1] | byPieces[pt2]; }
u64 Board::pieces(const Color c, const PieceType pt1, const PieceType pt2) const { return (byPieces[pt1] | byPieces[pt2]) & byColor[c]; }
// Load a board from the FEN
void Board::loadFromFEN(const std::string fen) {
// Clear all squares
byPieces.fill(0);
byColor.fill(0);
const std::vector<std::string> tokens = split(fen, ' ');
const std::vector<std::string> rankTokens = split(tokens[0], '/');
int currIdx = 56;
constexpr char whitePieces[6] = { 'P', 'N', 'B', 'R', 'Q', 'K' };
constexpr char blackPieces[6] = { 'p', 'n', 'b', 'r', 'q', 'k' };
for (const std::string& rank : rankTokens) {
for (const char c : rank) {
if (isdigit(c)) { // Empty squares
currIdx += c - '0';
continue;
}
for (int i = 0; i < 6; i++) {
if (c == whitePieces[i]) {
setBit<1>(byPieces[i], currIdx);
setBit<1>(byColor[WHITE], currIdx);
break;
}
if (c == blackPieces[i]) {
setBit<1>(byPieces[i], currIdx);
setBit<1>(byColor[BLACK], currIdx);
break;
}
}
currIdx++;
}
currIdx -= 16;
}
if (tokens[1] == "w")
stm = WHITE;
else
stm = BLACK;
castling.fill(NO_SQUARE);
if (tokens[2].find('-') == std::string::npos) {
// Standard FEN and maybe XFEN later
if (tokens[2].find('K') != std::string::npos)
castling[castleIndex(WHITE, true)] = h1;
if (tokens[2].find('Q') != std::string::npos)
castling[castleIndex(WHITE, false)] = a1;
if (tokens[2].find('k') != std::string::npos)
castling[castleIndex(BLACK, true)] = h8;
if (tokens[2].find('q') != std::string::npos)
castling[castleIndex(BLACK, false)] = a8;
// FRC FEN
if (std::tolower(tokens[2][0]) >= 'a' && std::tolower(tokens[2][0]) <= 'h') {
for (char token : tokens[2]) {
const auto file = static_cast<File>(std::tolower(token) - 'a');
if (std::isupper(token))
setCastlingRights(WHITE, toSquare(RANK1, file), true);
else
setCastlingRights(BLACK, toSquare(RANK8, file), true);
}
}
}
if (tokens[3] != "-")
epSquare = parseSquare(tokens[3]);
else
epSquare = NO_SQUARE;
halfMoveClock = tokens.size() > 4 ? (stoi(tokens[4])) : 0;
fullMoveClock = tokens.size() > 5 ? (stoi(tokens[5])) : 1;
resetMailbox();
}
// Return the type of the piece on the square
PieceType Board::getPiece(const i8 sq) const {
assert(sq >= 0);
assert(sq < 64);
return mailbox[sq];
}
bool Board::isCapture(const Move m) const { return ((1ULL << m.to() & pieces(~stm)) || m.typeOf() == EN_PASSANT); }
std::vector<float> Board::asInputLayer() const {
const auto getFeature = [this](const Color pieceColor, const Square square) {
const bool enemy = stm != pieceColor;
const int squareIndex = (stm == BLACK) ? flipRank(square) : static_cast<int>(square);
return enemy * 64 * 6 + getPiece(square) * 64 + squareIndex;
};
std::vector<float> res(2 * 6 * 64);
u64 whitePieces = pieces(WHITE);
u64 blackPieces = pieces(BLACK);
while (whitePieces) {
const Square sq = popLSB(whitePieces);
const usize feature = getFeature(WHITE, sq);
res[feature] = true;
}
while (blackPieces) {
const Square sq = popLSB(blackPieces);
const usize feature = getFeature(BLACK, sq);
res[feature] = true;
}
return res;
}
void Board::move(const Move m) {
epSquare = NO_SQUARE;
Square from = m.from();
Square to = m.to();
MoveType mt = m.typeOf();
PieceType pt = getPiece(from);
PieceType toPT = NO_PIECE_TYPE;
removePiece(stm, pt, from);
if (isCapture(m)) {
toPT = getPiece(to);
halfMoveClock = 0;
if (mt != EN_PASSANT) {
removePiece(~stm, toPT, to);
}
}
else {
if (pt == PAWN)
halfMoveClock = 0;
else
halfMoveClock++;
}
switch (mt) {
case STANDARD_MOVE:
placePiece(stm, pt, to);
if (pt == PAWN && (to + 16 == from || to - 16 == from)
&& (pieces(~stm, PAWN) & (shift<EAST>((1ULL << to) & ~MASK_FILE[FILE_H]) | shift<WEST>((1ULL << to) & ~MASK_FILE[FILE_A])))) // Only set EP square if it could be taken
epSquare = static_cast<Square>(stm == WHITE ? from + NORTH : from + SOUTH);
break;
case EN_PASSANT:
removePiece(~stm, PAWN, to + (stm == WHITE ? SOUTH : NORTH));
placePiece(stm, pt, to);
break;
case CASTLE:
assert(getPiece(to) == ROOK);
removePiece(stm, ROOK, to);
if (stm == WHITE) {
if (from < to) {
placePiece(stm, KING, g1);
placePiece(stm, ROOK, f1);
}
else {
placePiece(stm, KING, c1);
placePiece(stm, ROOK, d1);
}
}
else {
if (from < to) {
placePiece(stm, KING, g8);
placePiece(stm, ROOK, f8);
}
else {
placePiece(stm, KING, c8);
placePiece(stm, ROOK, d8);
}
}
break;
case PROMOTION:
placePiece(stm, m.promo(), to);
break;
}
assert(popcount(pieces(WHITE, KING)) == 1);
assert(popcount(pieces(BLACK, KING)) == 1);
if (pt == ROOK) {
const Square sq = castleSq(stm, from > ctzll(pieces(stm, KING)));
if (from == sq)
setCastlingRights(stm, from, false);
}
else if (pt == KING)
unsetCastlingRights(stm);
if (toPT == ROOK) {
const Square sq = castleSq(~stm, to > ctzll(pieces(~stm, KING)));
if (to == sq)
setCastlingRights(~stm, to, false);
}
stm = ~stm;
fullMoveClock += stm == WHITE;
}
std::ostream& operator<<(std::ostream& os, const Board& board) {
os << "\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n";
usize line = 1;
for (i32 rank = (board.stm == WHITE) * 7; (board.stm == WHITE) ? rank >= 0 : rank < 8; (board.stm == WHITE) ? rank-- : rank++) {
os << "\u2502 ";
for (usize file = 0; file < 8; file++) {
const auto sq = static_cast<Square>(rank * 8 + file);
const auto color = ((1ULL << sq) & board.pieces(WHITE)) ? "\033[33m" : "\033[34m";
os << color << board.getPieceAt(sq) << "\033[0m ";
}
os << "\u2502 " << rank + 1 << "\n";
}
os << "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n";
os << " a b c d e f g h\n";
return os;
}
}

136
src/chess/board.h Normal file
View File

@@ -0,0 +1,136 @@
#pragma once
#include <array>
#include "../types.h"
#include "../tensor.h"
namespace Ember::chess {
enum Color {
BLACK,
WHITE
};
enum PieceType : i8 {
PAWN,
KNIGHT,
BISHOP,
ROOK,
QUEEN,
KING,
NO_PIECE_TYPE
};
// clang-format off
enum Square : i8 {
a1, b1, c1, d1, e1, f1, g1, h1,
a2, b2, c2, d2, e2, f2, g2, h2,
a3, b3, c3, d3, e3, f3, g3, h3,
a4, b4, c4, d4, e4, f4, g4, h4,
a5, b5, c5, d5, e5, f5, g5, h5,
a6, b6, c6, d6, e6, f6, g6, h6,
a7, b7, c7, d7, e7, f7, g7, h7,
a8, b8, c8, d8, e8, f8, g8, h8,
NO_SQUARE
};
enum Direction : int {
NORTH = 8,
NORTH_EAST = 9,
EAST = 1,
SOUTH_EAST = -7,
SOUTH = -8,
SOUTH_WEST = -9,
WEST = -1,
NORTH_WEST = 7,
NORTH_NORTH = 16,
SOUTH_SOUTH = -16
};
enum File : int {
FILE_A, FILE_B, FILE_C, FILE_D, FILE_E, FILE_F, FILE_G, FILE_H
};
enum Rank : int {
RANK1, RANK2, RANK3, RANK4, RANK5, RANK6, RANK7, RANK8
};
constexpr u64 MASK_FILE[8] = {
0x101010101010101, 0x202020202020202, 0x404040404040404, 0x808080808080808, 0x1010101010101010, 0x2020202020202020, 0x4040404040404040, 0x8080808080808080,
};
constexpr u64 MASK_RANK[8] = {
0xff, 0xff00, 0xff0000, 0xff000000, 0xff00000000, 0xff0000000000, 0xff000000000000, 0xff00000000000000
};
//Inverts the color (WHITE -> BLACK) and (BLACK -> WHITE)
constexpr Color operator~(const Color c) { return static_cast<Color>(c ^ 1); }
inline Square& operator++(Square& s) { return s = static_cast<Square>(static_cast<i8>(s) + 1); }
inline Square& operator--(Square& s) { return s = static_cast<Square>(static_cast<i8>(s) - 1); }
constexpr Square operator+(const Square s, const Direction d) { return static_cast<Square>(static_cast<i8>(s) + static_cast<i8>(d)); }
constexpr Square operator-(const Square s, const Direction d) { return static_cast<Square>(static_cast<i8>(s) - static_cast<i8>(d)); }
inline Square& operator+=(Square& s, const Direction d) { return s = s + d; }
inline Square& operator-=(Square& s, const Direction d) { return s = s - d; }
//clang-format on
enum MoveType {
STANDARD_MOVE = 0, EN_PASSANT = 0x4000, CASTLE = 0x8000, PROMOTION = 0xC000
};
constexpr std::array<Square, 4> ROOK_CASTLE_END_SQ = { d8, f8, d1, f1 };
constexpr std::array<Square, 4> KING_CASTLE_END_SQ = { c8, g8, c1, g1 };
class Move;
struct Board {
// Index is based on square, returns the piece type
std::array<PieceType, 64> mailbox;
// Indexed pawns, knights, bishops, rooks, queens, king
std::array<u64, 6> byPieces;
// Index is based on color, so black is colors[0]
std::array<u64, 2> byColor;
// En passant square
Square epSquare;
// Index KQkq
std::array<Square, 4> castling;
Color stm;
usize halfMoveClock;
usize fullMoveClock;
private:
void placePiece(Color c, PieceType pt, int sq);
void removePiece(Color c, PieceType pt, int sq);
void removePiece(Color c, int sq);
void resetMailbox();
void setCastlingRights(Color c, Square sq, bool value);
void unsetCastlingRights(Color c);
Square castleSq(Color c, bool kingside) const;
public:
u8 count(PieceType pt) const;
u64 pieces() const;
u64 pieces(Color c) const;
u64 pieces(PieceType pt) const;
u64 pieces(Color c, PieceType pt) const;
u64 pieces(PieceType pt1, PieceType pt2) const;
u64 pieces(Color c, PieceType pt1, PieceType pt2) const;
void loadFromFEN(std::string fen);
char getPieceAt(i8 i) const;
PieceType getPiece(i8 sq) const;
bool isCapture(Move m) const;
std::vector<float> asInputLayer() const;
void move(Move m);
friend std::ostream& operator<<(std::ostream& os, const Board& board);
};
}

174
src/chess/dataloader.cpp Normal file
View File

@@ -0,0 +1,174 @@
#include "../dataloader.h"
#include "board.h"
#include "../../external/fmt/format.h"
#include <filesystem>
#include <algorithm>
#include <fstream>
#include <random>
#include <chrono>
#include <omp.h>
#include "../util.h"
namespace Ember::dataloaders::chess {
BulletTextDataLoader::BulletTextDataLoader(const std::string& filePath, const u64 batchSize, const usize evalScale, const u64 threads) : DataLoader(batchSize, threads), filePath(filePath), evalScale(evalScale) {
fmt::println("Attempting to open file '{}'", filePath);
if (!std::filesystem::exists(filePath) || std::filesystem::is_directory(filePath))
exitWithMsg("Data file does not exist or is a directory: " + filePath, 1);
std::string l;
file = std::ifstream(filePath);
while (std::getline(file, l))
numSamples++;
fmt::println("Found {} positions", formatNum(numSamples));
}
void BulletTextDataLoader::loadBatch(const usize batchIdx) {
data[batchIdx].input.resize(batchSize, static_cast<usize>(2 * 6 * 64));
data[batchIdx].target.resize(batchSize, static_cast<usize>(1));
data[batchIdx].target.fill(0);
data[batchIdx].input.fill(0);
std::vector<std::vector<internal::DataPoint>> localData(threads);
std::string l;
std::vector<std::string> lines;
u64 linesRead = 0;
// Load lines into a buffer
while (linesRead < batchSize) {
if (!std::getline(file, l)) {
file = std::ifstream(filePath);
continue;
}
// Discard null chars because windows uses encoding that
// places a \0 after every character
std::erase_if(l, [](const char c) { return c == '\0'; });
if (std::ranges::all_of(l.begin(), l.end(), [](const char c) { return std::isspace(c); }))
continue;
lines.emplace_back(l);
linesRead++;
}
assert(lines.size() == batchSize);
std::vector<u64> shuffledIndexes;
shuffledIndexes.reserve(batchSize);
// This is for batch shuffling
for (usize i = 0; i < batchSize; i++)
shuffledIndexes.push_back(i);
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
std::default_random_engine rng(seed);
// Shuffle the vector
std::ranges::shuffle(shuffledIndexes, rng);
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
for (usize i = 0; i < batchSize; i++) {
std::string& line = lines[shuffledIndexes[i]];
// Strip UTF-16 BOM if present
if (line.size() >= 2 && static_cast<unsigned char>(line[0]) == 0xFF && static_cast<unsigned char>(line[1]) == 0xFE)
line = line.substr(2);
const auto tokens = split(line, '|');
if (tokens.size() != 3)
exitWithMsg(fmt::format("Expected 3 tokens, got {}. Failed to parse line: {}", tokens.size(), line), 1);
assert(tokens.size() == 3);
const std::string& fen = tokens[0];
const float eval = std::stof(tokens[1]);
// Token 2 is discarded b/c it's the WDL which is not
// used yet
Ember::chess::Board board{};
board.loadFromFEN(fen);
std::vector<float> input = board.asInputLayer();
std::memcpy(&data[batchIdx].input[i, 0], input.data(), sizeof(float) * input.size());
data[batchIdx].target[i, 0] = eval * evalScale;
}
}
void BulletTextDataLoader::loadTestSet() {
// Local file object, always reads the
// first batchSize samples for
// more consistent results
std::ifstream file(filePath);
data[currBatch].input.resize(batchSize, static_cast<usize>(2 * 6 * 64));
data[currBatch].target.resize(batchSize, static_cast<usize>(1));
data[currBatch].target.fill(0);
data[currBatch].input.fill(0);
std::vector<std::vector<internal::DataPoint>> localData(threads);
std::string l;
std::vector<std::string> lines;
u64 linesRead = 0;
// Load lines into a buffer
while (linesRead < batchSize) {
if (!std::getline(file, l))
exitWithMsg(fmt::format("Failed to load test batch. Is the data corrupted? Are there at least {} data points?", batchSize), 1);
// Discard null chars because windows uses encoding that
// places a \0 after every character
std::erase_if(l, [](const char c) { return c == '\0'; });
if (std::ranges::all_of(l.begin(), l.end(), [](const char c) { return std::isspace(c); }))
continue;
lines.emplace_back(l);
linesRead++;
}
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
for (usize i = 0; i < batchSize; i++) {
std::string& line = lines[i];
// Strip UTF-16 BOM if present
if (line.size() >= 2 && static_cast<unsigned char>(line[0]) == 0xFF && static_cast<unsigned char>(line[1]) == 0xFE)
line = line.substr(2);
const auto tokens = split(line, '|');
if (tokens.size() != 3)
exitWithMsg(fmt::format("Expected 3 tokens, got {}. Failed to parse line: {}", tokens.size(), line), 1);
assert(tokens.size() == 3);
const std::string& fen = tokens[0];
const float eval = std::stof(tokens[1]);
Ember::chess::Board board{};
board.loadFromFEN(fen);
std::vector<float> input = board.asInputLayer();
std::memcpy(&data[currBatch].input[i, 0], input.data(), sizeof(float) * input.size());
data[currBatch].target[i, 0] = eval * evalScale;
}
}
u64 BulletTextDataLoader::countCorrect(const Tensor& output, const Tensor& target) {
u64 numCorrect = 0;
for (usize i = 0; i < target.dim(0); i++)
numCorrect += (std::round(output[i, 0] / evalScale) == std::round(target[i, 0] / evalScale));
return numCorrect;
}
}

View File

@@ -38,17 +38,12 @@ std::vector<float> loadGreyscaleImage(const std::string& path, const Ember::usiz
}
namespace Ember::dataloaders {
ImageDataLoader::ImageDataLoader(const std::string& dataDir, const u64 batchSize, const float trainSplit, const u64 threads, const usize width, const usize height)
: DataLoader(batchSize, trainSplit, threads) {
this->width = width;
this->height = height;
ImageDataLoader::ImageDataLoader(const std::string& dataDir, const u64 batchSize, const u64 threads, const float trainSplit, const usize width, const usize height)
: DataLoader(batchSize, threads), dataDir(dataDir), trainSplit(trainSplit), width(width), height(height) {
fmt::println("Attempting to open data dir '{}'", dataDir);
if (!std::filesystem::exists(dataDir) || !std::filesystem::is_directory(dataDir))
exitWithMsg("Data directory does not exist or is not a directory: " + dataDir, 1);
this->dataDir = dataDir;
for (const auto &entry: std::filesystem::directory_iterator(this->dataDir)) {
if (entry.is_directory())
types.push_back(entry.path().string());
@@ -95,7 +90,7 @@ namespace Ember::dataloaders {
std::vector<std::vector<internal::DataPoint>> localData(threads);
#pragma omp parallel for num_threads(threads)
#pragma omp parallel for num_threads(std::max<usize>(threads, 1))
for (usize i = 0; i < batchSize; i++) {
std::mt19937 rng{ std::random_device{}() + omp_get_thread_num()};
@@ -142,4 +137,22 @@ namespace Ember::dataloaders {
}
}
}
}
u64 ImageDataLoader::countCorrect(const Tensor& output, const Tensor& target) {
u64 numCorrect = 0;
for (usize i = 0; i < target.dim(0); i++) {
usize guess = 0;
usize goal = 0;
for (usize j = 0; j < target.dim(1); j++) {
if (output[i, j] > output[i, guess])
guess = j;
if (target[i, j] > target[i, goal])
goal = j;
}
numCorrect += (guess == goal);
}
return numCorrect;
}
}

View File

@@ -3,6 +3,7 @@
#include "types.h"
#include "tensor.h"
#include <fstream>
#include <vector>
#include <future>
#include <random>
@@ -21,22 +22,17 @@ namespace Ember {
struct DataLoader {
u64 threads;
u64 batchSize;
float trainSplit;
u64 numSamples;
u64 numSamples = 0;
usize currBatch = 0;
usize currBatch;
std::future<void> dataFuture;
std::array<DataPoint, 2> data;
DataLoader(const u64 batchSize, const float trainSplit, const u64 threads) {
DataLoader(const u64 batchSize, const u64 threads) {
this->threads = threads;
this->batchSize = batchSize;
this->trainSplit = trainSplit;
this->numSamples = 0;
this->currBatch = 0;
}
// Loads batch into other buffer
@@ -57,10 +53,14 @@ namespace Ember {
return data[currBatch];
}
void swapBuffers() {
virtual void swapBuffers() {
currBatch ^= 1;
}
// Returns the number of "correct" outputs
// from the network
virtual u64 countCorrect(const Tensor& output, const Tensor& target) = 0;
virtual ~DataLoader() = default;
};
}
@@ -72,17 +72,46 @@ namespace Ember {
std::vector<u64> samplesPerType;
std::vector<std::vector<std::string>> allImages;
std::vector<usize> trainSamplesPerType;
usize numTrainSamples;
usize numTestSamples;
std::vector<u64> trainSamplesPerType;
u64 numTrainSamples;
u64 numTestSamples;
float trainSplit;
usize width;
usize height;
ImageDataLoader(const std::string& dataDir, const u64 batchSize, const float trainSplit, const u64 threads = 0, const usize width = 0, const usize height = 0);
ImageDataLoader(const std::string& dataDir, const u64 batchSize, const u64 threads, const float trainSplit, const usize width = 0, const usize height = 0);
void loadBatch(const usize batchIdx) override;
void loadTestSet() override;
u64 countCorrect(const Tensor& output, const Tensor& target) override;
};
// Defined in ./chess/*
namespace chess {
struct BulletTextDataLoader : internal::DataLoader {
std::string filePath;
u64 batchNumber = 0;
usize evalScale = 0;
std::ifstream file;
BulletTextDataLoader(const std::string& filePath, const u64 batchSize, const usize evalScale, const u64 threads = 0);
void loadBatch(const usize batchIdx) override;
void loadTestSet() override;
u64 countCorrect(const Tensor& output, const Tensor& target) override;
void swapBuffers() override {
batchNumber++;
currBatch ^= 1;
}
};
}
}
}

View File

@@ -53,7 +53,6 @@ namespace Ember {
// Returns { test loss, test accuracy }
const auto getTestLossAcc = [&]() {
usize numCorrect = 0;
dataLoader.loadTestSet();
const internal::DataPoint& data = dataLoader.batchData();
const usize testSize = data.input.dim(0);
@@ -62,17 +61,7 @@ namespace Ember {
const float loss = lossFunc->forward(net.output(), data.target);
for (usize i = 0; i < data.target.dim(0); i++) {
usize guess = 0;
usize goal = 0;
for (usize j = 0; j < data.target.dim(1); j++) {
if (net.output()[i, j] > net.output()[i, guess])
guess = j;
if (data.target[i, j] > data.target[i, goal])
goal = j;
}
numCorrect += (guess == goal);
}
const u64 numCorrect = dataLoader.countCorrect(net.output(), data.target);
return std::pair<float, float>{ loss, numCorrect / static_cast<float>(testSize ? testSize : 1) };
};
@@ -160,7 +149,7 @@ namespace Ember {
internal::cursor::up();
internal::cursor::up();
internal::cursor::begin();
fmt::println("{:>5L}{:>14.5f}{:>13}{:>17}{:>12}", epoch, trainLoss / currentBatch, "Pending", "Pending", formatTime(stopwatch.elapsed()));
fmt::print("{:>5L}{:>14.5f}{:>13}{:>17}{:>12}\n", epoch, trainLoss / currentBatch, "Pending", "Pending", formatTime(stopwatch.elapsed()));
std::cout << progressBar.report(currentBatch + 1, batchesPerEpoch, 63) << " " << std::endl;
afterBatch:

View File

@@ -7,20 +7,56 @@ namespace Ember::loss {
float loss = 0;
for (usize i = 0; i < output.size(); i++)
loss += std::pow(output.data[i] - target.data[i], 2);
return loss / output.dim(1);
return loss / output.size();
}
Tensor MeanSquaredError::backward(const Tensor& output, const Tensor& target) {
Tensor gradient;
gradient.resize(output.dims());
const float scalar = 2.0f / output.dim(1);
const float scalar = 2.0f / output.size();
for (usize i = 0; i < output.size(); i++)
gradient[i] = (output.data[i] - target.data[i]) * scalar;
return gradient;
}
float SigmoidMSE::forward(const Tensor& output, const Tensor& target) {
assert(output.size() == target.size());
float loss = 0;
for (usize i = 0; i < output.size(); i++) {
const float imprecision = std::abs(sigmoid(output.data[i]) - sigmoid(target.data[i]));
loss += std::pow(imprecision, 2);
}
return loss / output.size() - offset;
}
Tensor SigmoidMSE::backward(const Tensor& output, const Tensor& target) {
assert(output.size() == target.size());
Tensor gradient;
gradient.resize(output.dims());
const float scalar = 2.0f / output.size();
for (usize i = 0; i < output.size(); i++) {
const float expOutput = std::exp(a + b * output.data[i]);
const float expTarget = std::exp(a + b * target.data[i]);
const float fOutput = k / (1.0f + expOutput);
const float fTarget = k / (1.0f + expTarget);
const float fprime_out = -k * b * expOutput / ((1.0f + expOutput) * (1.0f + expOutput));
gradient.data[i] = scalar * (fOutput - fTarget) * fprime_out;
}
return gradient;
}
float CrossEntropyLoss::forward(const Tensor& output, const Tensor& target) {
assert(output.size() == target.size());
@@ -31,7 +67,7 @@ namespace Ember::loss {
prob = std::min(prob, 1.0f);
loss -= target.data[i] * std::log(prob);
}
return loss / output.dim(0);
return loss / output.size();
}
Tensor CrossEntropyLoss::backward(const Tensor& output, const Tensor& target) {
@@ -40,7 +76,7 @@ namespace Ember::loss {
Tensor gradient;
gradient.resize(output.dims());
const float scalar = 1.0f / output.dim(0);
const float scalar = 1.0f / output.size();
for (usize i = 0; i < output.size(); i++) {
const float prob = std::max(output.data[i], 1e-10f);
gradient.data[i] = -target.data[i] / prob * scalar;

View File

@@ -19,6 +19,31 @@ namespace Ember {
Tensor backward(const Tensor& output, const Tensor& target) override;
};
// Apply sigmoid to the input and target before
// calculating the loss using MSE
// Modeled by the below function, paste into Desmos to see it
// f\left(x\right)=\left(\frac{k}{1+e^{\left(a+bx\right)}}\right)
struct SigmoidMSE : internal::LossFunction {
float a = 1;
float b = -0.25;
float k = 1;
float offset;
SigmoidMSE(const float a, const float b, const float k) : a(a), b(b), k(k) {
offset = -std::pow(sigmoid(0), 2);
}
explicit SigmoidMSE(const float horizontalStretch) {
b /= horizontalStretch;
offset = -std::pow(sigmoid(0), 2);
}
float sigmoid(const float x) const { return k / (1 + std::exp(a + b * x)); }
float forward(const Tensor& output, const Tensor& target) override;
Tensor backward(const Tensor& output, const Tensor& target) override;
};
struct CrossEntropyLoss : internal::LossFunction {
float forward(const Tensor& output, const Tensor& target) override;
Tensor backward(const Tensor& output, const Tensor& target) override;

View File

@@ -2,6 +2,8 @@
#include "types.h"
#include <sstream>
namespace Ember {
// Formats a number with commas
inline std::string formatNum(const i64 v) {
@@ -17,4 +19,19 @@ namespace Ember {
return s;
}
inline std::vector<std::string> split(const std::string& str, const char delim) {
std::vector<std::string> result;
std::istringstream stream(str);
for (std::string token{}; std::getline(stream, token, delim);) {
if (token.empty())
continue;
result.push_back(token);
}
return result;
}
}