A draw-a-digit demo lives or dies on one thing nobody puts in the headline: whether the browser hands the model the same kind of picture it trained on. Here is the preprocessing math that keeps my MNIST MLP at 96.7% on synthetic messy drawings, against 98.85% on clean test images. It comes with the per-row quantization that shrinks the artifact from 962 KB to 141 KB with no measurable accuracy loss.
The distribution problem
A model trained on clean MNIST learns one specific input distribution: digits size-normalized into a 20x20 box, anti-aliased to grayscale, then centered in a 28x28 frame by center of mass. A person drawing in a browser produces none of that. They draw off-center, too big or too small, tilted, with whatever stroke weight the pen gives them. Dump raw canvas pixels into a 28x28 grid and the demo only classifies inputs that already happen to look like MNIST, which is to say almost none of them.
I fixed this from two sides: make the model tolerant during training, and make the browser reproduce the canonical MNIST preprocessing exactly at inference. Both matter, and the second is the one people skip.
Training the model to survive real hands
During training, about 80% of every batch is randomly affine-warped before it reaches the network: rotation in , scale in , and a shift of up to pixels, all sampled per image and resampled with bilinear interpolation. The remaining 20% is left clean so the published clean-test number stays an honest measurement of the unaugmented task.
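Below is a pure-Python sketch of the per-image augmentation. The ranges `MAX_ROT_DEG`, `SCALE_RANGE`, and `MAX_SHIFT` are hypothetical placeholders, not the trainer's actual values. The function names are illustrative. A real trainer would run batched tensor ops rather than per-pixel loops.

```python
import math
import random

# Hypothetical ranges: placeholders, not the trainer's actual values.
MAX_ROT_DEG = 15.0
SCALE_RANGE = (0.9, 1.1)
MAX_SHIFT = 2.0
AUG_PROB = 0.8  # about 80% of images are warped; the rest stay clean


def bilinear(img, x, y):
    """Sample img (list of rows) at real-valued (x, y); outside the frame reads as 0."""
    h, w = len(img), len(img[0])
    x0, y0 = math.floor(x), math.floor(y)
    fx, fy = x - x0, y - y0
    total = 0.0
    for dy, wy in ((0, 1 - fy), (1, fy)):
        for dx, wx in ((0, 1 - fx), (1, fx)):
            xi, yi = x0 + dx, y0 + dy
            if 0 <= xi < w and 0 <= yi < h:
                total += wx * wy * img[yi][xi]
    return total


def affine_warp(img, angle_deg, scale, tx, ty):
    """Rotate and scale about the image center, then shift by (tx, ty), via inverse mapping."""
    h, w = len(img), len(img[0])
    cx, cy = (w - 1) / 2, (h - 1) / 2
    a = math.radians(angle_deg)
    cos_a, sin_a = math.cos(a), math.sin(a)
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            # Undo the shift, then the rotation and scale, to find the source point.
            u, v = x - cx - tx, y - cy - ty
            sx = (cos_a * u + sin_a * v) / scale + cx
            sy = (-sin_a * u + cos_a * v) / scale + cy
            out[y][x] = bilinear(img, sx, sy)
    return out


def maybe_augment(img, rng):
    """Warp with probability AUG_PROB, sampling fresh parameters for each image."""
    if rng.random() >= AUG_PROB:
        return img
    return affine_warp(
        img,
        rng.uniform(-MAX_ROT_DEG, MAX_ROT_DEG),
        rng.uniform(*SCALE_RANGE),
        rng.uniform(-MAX_SHIFT, MAX_SHIFT),
        rng.uniform(-MAX_SHIFT, MAX_SHIFT),
    )


# Usage: a seeded generator makes the augmentation reproducible.
example = maybe_augment([[0.0] * 28 for _ in range(28)], random.Random(1234))
```

Inverse mapping walks the output pixels and samples the source for each one. That way every output pixel gets exactly one interpolated value and the warped image has no holes.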
The payoff is graceful degradation rather than a cliff. Clean test accuracy is 98.85%, and an augmented test set built as a proxy for messy drawn input still scores 97.62%. When the input stops being perfect, the model loses about a point instead of collapsing.
- architecture: MLP 784-128-32-10, ReLU on both hidden layers, softmax on the output (the network emits logits trained with cross-entropy; softmax is applied only at inference).
- parameters: 104,768 weights plus 170 biases. Small on purpose: a CNN would reach about 99.3%, but the point here is the pipeline, not the score.
- optimizer: Adam, learning rate cosine-annealed to 0 over 40 epochs, weight decay , batch size 128, seed 1234.
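The parameter counts follow from the layer sizes in the artifact's `arch.layers` field (784, 128, 32, 10). A dense layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights and `n_out` biases. Here is a quick check with a hypothetical helper:

```python
def count_params(layers):
    """Count dense-layer weights and biases for layer sizes like [784, 128, 32, 10]."""
    weights = sum(n_in * n_out for n_in, n_out in zip(layers, layers[1:]))
    biases = sum(layers[1:])
    return weights, biases


weights, biases = count_params([784, 128, 32, 10])
print(weights, biases)  # 104768 170
```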
The browser pipeline, step by step
This is the load-bearing accuracy detail. Given a drawing as a grayscale field where ink is high and background is 0 (white on black, matching MNIST polarity), the browser runs the exact sequence the training data went through. The single most important step is the one almost every demo skips: centering by center of mass, not by bounding box.
1. Bounding box. Scan for inked pixels above threshold 8 to find the tight box. An empty canvas returns the all-background vector rather than crashing.
2. Scale to fit 20x20. Compute $s = 20 / \max(w, h)$ for a box of width $w$ and height $h$, and resize by $s$. The longer side becomes 20 pixels and the aspect ratio is preserved.
3. Anti-aliased downscale. Area-average (box-filter) resample into that target size, which produces the soft grayscale edges real MNIST has. A hard binary 28x28 image looks nothing like the training data.
4. Center of mass into 28x28. Translate the glyph so its center of mass lands at the center of the 28x28 frame.
5. Scale and standardize. Divide by 255, then apply the training-set statistics below.
6. Flatten. Flatten row-major (pixel at row $r$, column $c$ goes to index $28r + c$) into a length-784 vector.
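Steps 1 to 3 can be sketched in pure Python on a list-of-rows canvas holding 0-255 ink values. This is an illustration, not the deployed JS, and several details are assumptions:

- the function names are hypothetical;
- the threshold is assumed to be on the 0-255 scale;
- the scale factor is taken as 20 divided by the longer box side;
- rounding the resized dimensions with Python's `round()` is a guess.

```python
THRESHOLD = 8  # ink threshold, assumed to be on the 0-255 pixel scale
FIT_BOX = 20   # longer side of the glyph after scaling


def bounding_box(img, threshold=THRESHOLD):
    """Tight box (x0, y0, x1, y1), exclusive on the high end, or None if nothing is inked."""
    rows = [y for y, row in enumerate(img) if any(p > threshold for p in row)]
    if not rows:
        return None
    cols = [x for x in range(len(img[0])) if any(row[x] > threshold for row in img)]
    return cols[0], rows[0], cols[-1] + 1, rows[-1] + 1


def area_weights(n_src, n_dst):
    """weights[d][s] = share of destination pixel d covered by source pixel s (box filter)."""
    ratio = n_src / n_dst
    weights = []
    for d in range(n_dst):
        lo, hi = d * ratio, (d + 1) * ratio
        row = [0.0] * n_src
        s = int(lo)
        while s < n_src and s < hi:
            overlap = min(hi, s + 1) - max(lo, s)
            if overlap > 0:
                row[s] = overlap / ratio
            s += 1
        weights.append(row)
    return weights


def area_resize(img, out_w, out_h):
    """Area-average resample, done separably: across each row, then down each column."""
    h, w = len(img), len(img[0])
    wx, wy = area_weights(w, out_w), area_weights(h, out_h)
    tmp = [[sum(wx[j][s] * row[s] for s in range(w)) for j in range(out_w)] for row in img]
    return [[sum(wy[i][s] * tmp[s][j] for s in range(h)) for j in range(out_w)]
            for i in range(out_h)]


def crop_and_fit(img, fit=FIT_BOX):
    """Crop to the ink, then scale so the longer side is `fit` pixels (aspect preserved)."""
    box = bounding_box(img)
    if box is None:
        return None  # empty canvas: the caller emits the all-background vector
    x0, y0, x1, y1 = box
    crop = [row[x0:x1] for row in img[y0:y1]]
    w, h = x1 - x0, y1 - y0
    s = fit / max(w, h)
    return area_resize(crop, max(1, round(w * s)), max(1, round(h * s)))
```

The separable form is valid because a 2D box filter is the product of a horizontal and a vertical one. Averaging across rows and then down columns therefore gives the same result as averaging each 2D cell directly.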
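The center of mass is computed over inked intensity, weighting each pixel coordinate by its intensity $I(x, y)$:

$$\bar{x} = \frac{\sum_{x,y} x\, I(x,y)}{\sum_{x,y} I(x,y)}, \qquad \bar{y} = \frac{\sum_{x,y} y\, I(x,y)}{\sum_{x,y} I(x,y)}$$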
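Here is a sketch of the centering step. It computes the intensity-weighted centroid of the fitted glyph, then pastes the glyph into a blank 28x28 frame with that centroid moved to the middle. Two details are assumptions, flagged in the docstring:

- the target is the geometric center, (frame - 1) / 2 = 13.5 in pixel-index coordinates;
- the shift is rounded to whole pixels.

The function names are hypothetical.

```python
def center_of_mass(img):
    """Intensity-weighted centroid (x_bar, y_bar); assumes at least one nonzero pixel."""
    total = sum(sum(row) for row in img)
    x_bar = sum(x * p for row in img for x, p in enumerate(row)) / total
    y_bar = sum(y * p for y, row in enumerate(img) for p in row) / total
    return x_bar, y_bar


def paste_centered(glyph, frame=28):
    """Place glyph in a frame x frame canvas with its center of mass near the frame center.

    Assumptions: the target is the geometric center (frame - 1) / 2 in pixel-index
    coordinates, and the shift is rounded to whole pixels (no sub-pixel resampling).
    """
    gx, gy = center_of_mass(glyph)
    target = (frame - 1) / 2
    ox, oy = round(target - gx), round(target - gy)
    out = [[0.0] * frame for _ in range(frame)]
    for y, row in enumerate(glyph):
        for x, p in enumerate(row):
            tx, ty = x + ox, y + oy
            if 0 <= tx < frame and 0 <= ty < frame:  # clip anything pushed off the frame
                out[ty][tx] = p
    return out
```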
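Standardization uses constants I computed over the training set on pixels scaled to 0-to-1. They match the canonical MNIST values of 0.1307 and 0.3081. For a raw pixel value $p \in [0, 255]$:

$$z = \frac{p/255 - \mu}{\sigma}, \qquad \mu = 0.1306605, \quad \sigma = 0.3081078$$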
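The last two steps reduce to one comprehension. This minimal sketch uses the `mean` and `std` values from the artifact's `preprocess` block. The constant and function names are hypothetical. Treating the empty-canvas result as every pixel set to the standardized value of 0 is an assumption about what "all-background vector" means.

```python
MNIST_MEAN = 0.1306605  # from the artifact's "preprocess" block
MNIST_STD = 0.3081078


def standardize_and_flatten(frame):
    """Scale 0-255 pixels to 0-1, standardize, and flatten row-major to one list."""
    # Outer loop over rows, inner over columns: pixel (r, c) lands at index 28 * r + c.
    return [(p / 255.0 - MNIST_MEAN) / MNIST_STD for row in frame for p in row]


def empty_canvas_vector(n=784):
    """Assumed background-only input: every pixel is the standardized value of 0."""
    return [(0.0 - MNIST_MEAN) / MNIST_STD] * n
```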
Verified against the literature: the LeCun MNIST page is explicit that error rates depend on how digits are centered, and that center of mass is the convention the original dataset used. End to end, I tested synthetic digits that were scaled 2x to 5.5x, rotated, thickened by dilation to mimic pen width, and pasted off-center onto a 200x200 canvas. After running through this exact pipeline, they score 96.7%. The demo is accurate on realistic drawings, not just on clean test images.
Quantization: 962 KB down to 141 KB
Floats serialized as JSON text made a 962 KB artifact, well over my 600 KB budget for the page. I quantize the weights to int8 with a symmetric scale per output row. Per-row rather than per-tensor is what preserves accuracy: each output neuron gets its own scale. A row of small weights is therefore not forced onto a handful of integer levels just because another row has large ones.
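Here is a minimal sketch of symmetric per-row int8 quantization. It assumes each row's scale is its largest absolute weight divided by 127, so values map into -127..127 and zero is exactly representable. That scale rule is a common convention, not something stated above, and the function names are hypothetical.

```python
def quantize_per_row(weights):
    """Symmetric int8 quantization with one scale per output row.

    Returns (q, scales) with q[r][j] an int in [-127, 127] and
    weights[r][j] approximately equal to scales[r] * q[r][j].
    """
    q, scales = [], []
    for row in weights:
        max_abs = max(abs(w) for w in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0  # any scale works for an all-zero row
        scales.append(scale)
        q.append([max(-127, min(127, round(w / scale))) for w in row])
    return q, scales


def dequantize(q, scales):
    """Recover approximate float weights: W[r][j] = scales[r] * q[r][j]."""
    return [[s * v for v in row] for row, s in zip(q, scales)]
```

For intuition, take a row [0.010, -0.020, 0.005] in the same tensor as a row whose largest weight is 1.0. A single tensor-wide scale of 1/127 rounds the small row to the integers 1, -3, 1. Per-row scaling gives it the full -127..127 range instead.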
At inference the forward pass dequantizes on the fly, $\hat{W}_{rj} = s_r\, q_{rj}$, where $s_r$ is row $r$'s scale and $q_{rj}$ is the stored int8 weight. It then computes $\hat{W} x + b$, with ReLU on the hidden layers and softmax on the output. Biases stay float because there are only 170 of them. The int8 model scores 98.86%, which is not a typo: per-row quantization landed within noise of the 98.85% float model. The artifact is 141 KB.
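Here is a pure-Python sketch of the dequantize-on-the-fly forward pass. It assumes each layer is held as a tuple of int8 rows, per-row float scales, and float biases. That tuple layout and the function names are hypothetical, not the artifact's JSON schema or the deployed JS.

```python
import math


def dense(q, scales, bias, x):
    """One layer: dequantize row r as scales[r] * q[r], then compute W x + b."""
    # The per-row scale is factored out of the dot product, which is algebraically
    # the same as multiplying every weight by it first.
    return [
        scales[r] * sum(qv * xv for qv, xv in zip(q[r], x)) + bias[r]
        for r in range(len(q))
    ]


def relu(v):
    return [max(0.0, a) for a in v]


def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]


def forward(layers, x):
    """layers: list of (q, scales, bias); ReLU on hidden layers, softmax on the output."""
    for i, (q, scales, bias) in enumerate(layers):
        x = dense(q, scales, bias, x)
        if i < len(layers) - 1:
            x = relu(x)
    return softmax(x)
```

Factoring the scale out costs one float multiply per output neuron instead of one per weight.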
The exported artifact, with the layer payloads elided:

```json
{
  "format": "mnist-mlp/v1",
  "arch": { "layers": [784,128,32,10], "hidden_activation": "relu" },
  "preprocess": { "fit_box": 20, "frame": 28, "center": "center_of_mass",
                  "mean": 0.1306605, "std": 0.3081078, "flatten": "row_major" },
  "metrics": { "test_accuracy_clean": 0.9885, "test_accuracy_augmented": 0.9762 },
  "layers": [ { "in": 784, "out": 128, "quant": "int8_per_row", ... } ]
}
```

Proving the three implementations agree
Three things compute this function: the PyTorch trainer, a numpy reference, and the deployed JavaScript forward pass. I cross-checked them rather than trusting that they matched. The JS forward pass agrees with the numpy reference on 200 of 200 real test vectors, and the numpy reference reproduces the torch accuracy, so the trainer, the JSON, and the 200-line zero-dependency JS loop all compute the same model.
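Here is one way to structure such a cross-check, as a hypothetical sketch. Run each implementation on the same inputs, then count matching predicted classes and record the largest absolute gap between output vectors. The text doesn't say which criterion "200 of 200" refers to, so this reports both. Any tolerance applied to the gap would be an assumption.

```python
def argmax(v):
    """Index of the largest entry (the predicted class)."""
    return max(range(len(v)), key=v.__getitem__)


def compare_outputs(reference, candidate):
    """Return (inputs with the same predicted class, largest absolute output gap)."""
    if len(reference) != len(candidate):
        raise ValueError("implementations were run on different numbers of inputs")
    agree = 0
    max_diff = 0.0
    for ref, cand in zip(reference, candidate):
        if argmax(ref) == argmax(cand):
            agree += 1
        max_diff = max(max_diff, max(abs(a - b) for a, b in zip(ref, cand)))
    return agree, max_diff
```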
The trainer, exporter, and IDX loader are in scripts/mnist/; the deployed format and the JS contract are documented in full in methodology.md. The from-scratch C++ kernels live at github.com/ShreeChaturvedi/fast-mnist-nn.