Abstract OpenBLAS calls into Tensor.madd()
This commit is contained in:
@@ -4,26 +4,31 @@
|
||||
|
||||
int main() {
|
||||
Ember::Network net(
|
||||
Ember::layers::Input(2 * 6 * 64),
|
||||
Ember::layers::Linear(128),
|
||||
Ember::layers::Input(28, 28, 1),
|
||||
Ember::layers::Convolution(64, 3),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::Linear(1)
|
||||
Ember::layers::MaxPool(),
|
||||
Ember::layers::Flatten(),
|
||||
Ember::layers::Linear(512),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::Linear(64),
|
||||
Ember::activations::ReLU(),
|
||||
Ember::layers::Linear(10),
|
||||
Ember::activations::Softmax()
|
||||
);
|
||||
|
||||
constexpr Ember::usize evalScale = 400;
|
||||
|
||||
Ember::dataloaders::chess::BulletTextDataLoader dataloader("../datasets/preludeData.txt", 1024 * 16, evalScale, 1);
|
||||
Ember::dataloaders::ImageDataLoader dataloader("../datasets/FashionMNIST/", 64, 6, 0.9, 28, 28);
|
||||
Ember::optimizers::Adam optimizer(net);
|
||||
|
||||
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::SigmoidMSE(evalScale));
|
||||
Ember::Learner learner(net, dataloader, optimizer, Ember::loss::CrossEntropyLoss());
|
||||
|
||||
std::cout << net << std::endl;
|
||||
|
||||
learner.addCallbacks(
|
||||
Ember::callbacks::DropLROnPlateau(3, 0.3, Ember::Metric::TRAIN_LOSS),
|
||||
Ember::callbacks::StopWhenNoProgress(5, Ember::Metric::TRAIN_LOSS),
|
||||
Ember::callbacks::DropLROnPlateau(3, 0.3),
|
||||
Ember::callbacks::StopWhenNoProgress(5),
|
||||
Ember::callbacks::AutosaveBest("../net.bin")
|
||||
);
|
||||
|
||||
learner.learn(0.005, 40, 1);
|
||||
learner.learn(0.001, 20, 1);
|
||||
}
|
||||
@@ -95,7 +95,7 @@ namespace Ember::layers {
|
||||
for (usize kx = 0; kx < kernelSize; kx++) {
|
||||
const usize ix = ox * stride + kx;
|
||||
const usize iy = oy * stride + ky;
|
||||
rowPtr[idx++] = previous.values[i, ix, iy];
|
||||
rowPtr[idx++] = previous.values[i, ix, iy, ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -149,7 +149,7 @@ namespace Ember::layers {
|
||||
for (usize kx = 0; kx < kernelSize; kx++) {
|
||||
const usize ix = ox * stride + kx;
|
||||
const usize iy = oy * stride + ky;
|
||||
rowPtr[idx++] = previous.values[i, ix, iy];
|
||||
rowPtr[idx++] = previous.values[i, ix, iy, ch];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -162,17 +162,16 @@ namespace Ember::layers {
|
||||
// gradOutput: (rows x numKernels)
|
||||
// localPatch: (rows x cols)
|
||||
// weightGrad: (numKernels x cols)
|
||||
cblas_sgemm(
|
||||
CblasRowMajor, CblasTrans, CblasNoTrans,
|
||||
weightGrad.madd(
|
||||
CblasTrans, CblasNoTrans,
|
||||
numKernels, cols, rows,
|
||||
1.0f,
|
||||
goPtr, numKernels,
|
||||
localPatch.data(), cols,
|
||||
(i == 0 ? 0.0f : 1.0f),
|
||||
weightGrad.ptr(), cols
|
||||
(i == 0 ? 0.0f : 1.0f)
|
||||
);
|
||||
|
||||
cblas_sgemm(
|
||||
sgemm(
|
||||
CblasRowMajor, CblasNoTrans, CblasNoTrans,
|
||||
rows, cols, numKernels,
|
||||
1.0f,
|
||||
|
||||
54
src/layer.h
54
src/layer.h
@@ -3,7 +3,6 @@
|
||||
#include "tensor.h"
|
||||
|
||||
#include <utility>
|
||||
#include <cblas.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
@@ -128,29 +127,12 @@ namespace Ember {
|
||||
// Fill values in the current layer
|
||||
void forward(const Layer& previous) override {
|
||||
const usize batchSize = values.dim(0);
|
||||
const usize inputSize = previous.values.size() / batchSize;
|
||||
const usize outputSize = values.size() / batchSize;
|
||||
|
||||
for (usize i = 0; i < batchSize; i++)
|
||||
std::memcpy(&values[i, 0], biases.ptr(), outputSize * sizeof(float));
|
||||
|
||||
// Batched matmul across all
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // previous.values: batch x inputSize
|
||||
CblasTrans, // weights: inputSize x outputSize
|
||||
static_cast<int>(batchSize),
|
||||
static_cast<int>(outputSize),
|
||||
static_cast<int>(inputSize),
|
||||
1.0f,
|
||||
previous.values.ptr(),
|
||||
static_cast<int>(inputSize),
|
||||
weights.ptr(),
|
||||
static_cast<int>(inputSize),
|
||||
1.0f,
|
||||
values.ptr(),
|
||||
static_cast<int>(outputSize)
|
||||
);
|
||||
values.madd(previous.values, weights, false, true);
|
||||
}
|
||||
|
||||
// Returns gradInput, weightGrad, biasGrad
|
||||
@@ -164,40 +146,10 @@ namespace Ember {
|
||||
Tensor biasGrad(outputSize);
|
||||
|
||||
// gradInput = (batch x outputSize) * (outputSize x inputSize)
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasNoTrans, // A = gradOutput
|
||||
CblasNoTrans, // B = weights
|
||||
static_cast<int>(batchSize),
|
||||
static_cast<int>(inputSize),
|
||||
static_cast<int>(outputSize),
|
||||
1.0f,
|
||||
gradOutput.ptr(),
|
||||
static_cast<int>(outputSize),
|
||||
weights.ptr(),
|
||||
static_cast<int>(inputSize),
|
||||
0.0f,
|
||||
gradInput.ptr(),
|
||||
static_cast<int>(inputSize)
|
||||
);
|
||||
gradInput.madd(gradOutput, weights);
|
||||
|
||||
// weightGrad = (outputSize x batch) * (batch x inputSize)
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
CblasTrans, // A = gradOutput
|
||||
CblasNoTrans, // B = previous.values
|
||||
static_cast<int>(outputSize),
|
||||
static_cast<int>(inputSize),
|
||||
static_cast<int>(batchSize),
|
||||
1.0f,
|
||||
gradOutput.ptr(),
|
||||
static_cast<int>(outputSize),
|
||||
previous.values.ptr(),
|
||||
static_cast<int>(inputSize),
|
||||
0.0f,
|
||||
weightGrad.ptr(),
|
||||
static_cast<int>(inputSize)
|
||||
);
|
||||
weightGrad.madd(gradOutput, previous.values, true, false);
|
||||
|
||||
// Sum over batch of gradOutput
|
||||
for (usize i = 0; i < batchSize; i++)
|
||||
|
||||
64
src/tensor.h
64
src/tensor.h
@@ -4,10 +4,13 @@
|
||||
|
||||
#include "../external/fmt/format.h"
|
||||
|
||||
#include <cblas.h>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
namespace Ember {
|
||||
#define sgemm cblas_sgemm;
|
||||
|
||||
namespace internal {
|
||||
template <typename T>
|
||||
concept UsizeLike = std::is_same_v<std::decay_t<T>, usize>;
|
||||
@@ -168,5 +171,66 @@ namespace Ember {
|
||||
((idx += static_cast<usize>(args) * strides[strideIdx++]), ...);
|
||||
return data[idx];
|
||||
}
|
||||
|
||||
|
||||
// Matrix operations
|
||||
|
||||
// Compute a * b then add to the current tensor
|
||||
void madd(const Tensor& a, const Tensor& b) { madd(a, b, false, false); }
|
||||
// Compute a * b then add to the current tensor
|
||||
void madd(const Tensor& a, const Tensor& b, const bool transposeA, const bool transposeB) {
|
||||
assert(dimensionality == 2);
|
||||
assert(a.dimensionality == 2);
|
||||
assert(b.dimensionality == 2);
|
||||
|
||||
// Logical dimensions for op(A) and op(B)
|
||||
const usize aRows = transposeA ? a.dim(1) : a.dim(0);
|
||||
const usize aCols = transposeA ? a.dim(0) : a.dim(1);
|
||||
const usize bRows = transposeB ? b.dim(1) : b.dim(0);
|
||||
const usize bCols = transposeB ? b.dim(0) : b.dim(1);
|
||||
|
||||
// Ensure dimensions are compatible for C = op(A) * op(B) + C
|
||||
assert(aCols == bRows);
|
||||
assert(this->dim(0) == aRows);
|
||||
assert(this->dim(1) == bCols);
|
||||
|
||||
// Matrix multiplication parameters
|
||||
const int M = static_cast<int>(this->dim(0)); // rows of C / op(A)
|
||||
const int N = static_cast<int>(this->dim(1)); // cols of C / op(B)
|
||||
const int K = static_cast<int>(aCols); // inner dimension
|
||||
|
||||
const auto transA = transposeA ? CblasTrans : CblasNoTrans;
|
||||
const auto transB = transposeB ? CblasTrans : CblasNoTrans;
|
||||
|
||||
// Leading dimensions assuming row major
|
||||
const int lda = static_cast<int>(a.dim(1));
|
||||
const int ldb = static_cast<int>(b.dim(1));
|
||||
const int ldc = static_cast<int>(this->dim(1));
|
||||
|
||||
// Perform C = op(A) * op(B) + C
|
||||
cblas_sgemm(
|
||||
CblasRowMajor,
|
||||
transA, transB,
|
||||
M, N, K,
|
||||
1.0f,
|
||||
a.ptr(), lda,
|
||||
b.ptr(), ldb,
|
||||
1.0f,
|
||||
this->ptr(), ldc
|
||||
);
|
||||
}
|
||||
// Compute a * b then add to the current tensor
|
||||
void madd(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const blasint M, const blasint N, const blasint K,
|
||||
const float alpha, const float* A, const blasint lda, const float* B, const blasint ldb, const float beta) {
|
||||
cblas_sgemm(
|
||||
CblasRowMajor, transA, transB,
|
||||
M, N, K,
|
||||
alpha,
|
||||
A, lda,
|
||||
B, ldb,
|
||||
beta,
|
||||
this->ptr(), static_cast<int>(this->dim(1))
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user