From 4c67602bebb1631fb1b60c594d238b22ddde7a33 Mon Sep 17 00:00:00 2001 From: Modestas Valauskas Date: Thu, 30 Apr 2026 17:28:27 +0200 Subject: [PATCH 1/2] Vectorize Encoder vocabulary build and bulk encode with numpy The original Encoder had two pure-Python passes: it built the vocab dict by iterating every character and dispatching a dict membership test, and its encode() ran a per-character dict lookup. Both are O(N) in Python. On Shakespeare (~1M chars) that is fine. On larger corpora (e.g. a 500MB wiki dump, ~520M chars) those Python loops dominate startup. This commit: - Builds the vocabulary as `sorted(set(text))` (one pass in C). - Adds `encode_array(text)` which converts text to UTF-32-LE bytes via the C codec, views the buffer as `uint32` codepoints, and gathers through a precomputed lookup table indexed by ord(c). Output is a `np.ndarray[int32]` ready to be moved to a torch tensor. - Caches the inverse dict so decode() does not rebuild it on every call. - Bulk encoding is chunked (default 4M chars/block) so peak transient memory stays bounded for very large corpora. Existing public API is preserved: `encode`, `decode`, and `vocab` behave the same. `encode_array` is additive. Measured on a 500 MB wiki dump (522M characters): vocab build: ~2.6 s (was estimated at minutes) encode_array: ~0.8 s (was estimated at minutes) --- src/encoder.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/src/encoder.py b/src/encoder.py index 71a2c9a..b833e8c 100644 --- a/src/encoder.py +++ b/src/encoder.py @@ -1,31 +1,39 @@ -from torch import nn +import numpy as np + class Encoder: def __init__(self, path): with open(path) as f: - lines = f.readlines() - - text = "\n".join(lines) - - self.dictionary = {} + text = f.read() - index = 1 + chars = sorted(set(text)) self.dictionary = {"[MASK]": 0} - for char in text: - if self.dictionary.get(char): + for i, c in enumerate(chars, start=1): + self.dictionary[c] = i + self._inv = {v: k for k, v in self.dictionary.items()} + + max_ord = max((ord(c) for c in chars), default=0) + self._lookup = np.zeros(max_ord + 1, dtype=np.int32) + for c, i in self.dictionary.items(): + if c == "[MASK]": continue - - self.dictionary[char] = index - index += 1 - + self._lookup[ord(c)] = i + def encode(self, text): return [self.dictionary[char] for char in text] - + + def encode_array(self, text, chunk_chars=4_000_000): + out = np.empty(len(text), dtype=np.int32) + for i in range(0, len(text), chunk_chars): + block = text[i:i + chunk_chars] + cps = np.frombuffer(block.encode("utf-32-le"), dtype=np.uint32) + out[i:i + len(cps)] = self._lookup[cps] + return out + def decode(self, encoded) -> list[int]: - inv_dict = {v: k for k, v in self.dictionary.items()} - return [inv_dict[encoding] for encoding in encoded] - + return [self._inv[e] for e in encoded] + def vocab(self): return self.dictionary.keys() From 84b8639bd99d84ab59e17a705f1fc4981b7751ab Mon Sep 17 00:00:00 2001 From: Modestas Valauskas Date: Thu, 30 Apr 2026 17:28:39 +0200 Subject: [PATCH 2/2] Make dataset path configurable and add training CLI flags Adds a small set of CLI knobs needed to point training at a different corpus and to recover from interruptions, plus a few correctness/perf tweaks that come along for the ride: - `Transformer(data_path=...)` constructor argument; previously the path was hardcoded to "data/input.txt". Threaded through train.py, sample.py, and export_onnx.py via a `--data` flag (default unchanged). - `--batch-size` and `--seq-len` flags so hyperparameters can be tuned without editing source. - `--resume ` flag that loads a saved state_dict before training. Useful for picking up a long run after a crash, machine reboot, or any other interruption. Only the model weights are restored; the optimizer state and step counter are not. - Use the new `encoder.encode_array()` and store the corpus as `int32` on device. The vocabulary easily fits in 32 bits (this PR's wiki sample has ~720 chars, the full HF wiki dump has ~7000), so int64 was wasting 50% of corpus memory. On a 500MB wiki corpus this saves ~2 GB of device memory. - Read the corpus with `f.read()` instead of `"\n".join(f.readlines())`. The old form silently doubled every newline. No vocab change, the encoder was building from the same join'd text. Sanity-checked: training on tiny Shakespeare with default flags gives the same it/s and matching loss curve as before. --- src/export_onnx.py | 7 ++++--- src/model.py | 4 ++-- src/sample.py | 3 ++- src/train.py | 28 +++++++++++++++++----------- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/export_onnx.py b/src/export_onnx.py index c2033f2..5558e95 100644 --- a/src/export_onnx.py +++ b/src/export_onnx.py @@ -7,11 +7,11 @@ from src.model import Transformer -def export(checkpoint: str, out_dir: str, seq_len: int): +def export(checkpoint: str, out_dir: str, seq_len: int, data: str): out = Path(out_dir) out.mkdir(parents=True, exist_ok=True) - model = Transformer() + model = Transformer(data_path=data) model.load_state_dict(torch.load(checkpoint, map_location="cpu")) model.eval() @@ -39,8 +39,9 @@ def main(): p.add_argument("--checkpoint", default="checkpoints/checkpoint.pt") p.add_argument("--out", default="web/public") p.add_argument("--seq-len", type=int, default=128) + p.add_argument("--data", default="data/input.txt") args = p.parse_args() - export(args.checkpoint, args.out, args.seq_len) + export(args.checkpoint, args.out, args.seq_len, args.data) if __name__ == "__main__": diff --git a/src/model.py b/src/model.py index 83bf56b..1cbfa20 100644 --- a/src/model.py +++ b/src/model.py @@ -3,10 +3,10 @@ class Transformer(nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, data_path="data/input.txt", *args, **kwargs): super().__init__(*args, **kwargs) - self.encoder = Encoder("data/input.txt") + self.encoder = Encoder(data_path) vocab_size = len(self.encoder.vocab()) max_seq_len = 1024 diff --git a/src/sample.py b/src/sample.py index 26a28d4..ea9ae27 100644 --- a/src/sample.py +++ b/src/sample.py @@ -36,10 +36,11 @@ def main(): p.add_argument("--query", default="To be, ") p.add_argument("--length", type=int, default=64) p.add_argument("--device", default="mps") + p.add_argument("--data", default="data/input.txt") args = p.parse_args() device = torch.device(args.device) - model = Transformer().to(device) + model = Transformer(data_path=args.data).to(device) model.load_state_dict(torch.load(args.checkpoint, map_location=device)) model.eval() diff --git a/src/train.py b/src/train.py index 6565563..9905fb4 100644 --- a/src/train.py +++ b/src/train.py @@ -10,31 +10,37 @@ def main(): p = argparse.ArgumentParser() p.add_argument("--device", default="mps") + p.add_argument("--data", default="data/input.txt") + p.add_argument("--batch-size", type=int, default=64) + p.add_argument("--seq-len", type=int, default=128) + p.add_argument("--resume", default=None, help="path to checkpoint.pt to resume from") args = p.parse_args() os.makedirs("checkpoints", exist_ok=True) - data = "data/input.txt" + data = args.data device = torch.device(args.device) with open(data) as f: - lines = f.readlines() + text = f.read() - text = "\n".join(lines) - - seq_len = 128 + seq_len = args.seq_len iterations = 10000000 - batch_size = 64 + batch_size = args.batch_size - transformer = Transformer().to(device) + transformer = Transformer(data_path=data).to(device) + if args.resume: + transformer.load_state_dict(torch.load(args.resume, map_location=device)) + print(f"resumed weights from {args.resume}") optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4) - corpus = torch.tensor(transformer.encoder.encode(text), dtype=torch.long, device=device) + corpus_np = transformer.encoder.encode_array(text) + corpus = torch.from_numpy(corpus_np).to(device=device, dtype=torch.int32) corpus_len = corpus.shape[0] - arange_seq = torch.arange(seq_len, device=device) + arange_seq = torch.arange(seq_len, device=device, dtype=torch.int64) def grab_batch() -> torch.Tensor: - starts = torch.randint(0, corpus_len - seq_len, (batch_size,), device=device) - return corpus[starts[:, None] + arange_seq[None, :]] + starts = torch.randint(0, corpus_len - seq_len, (batch_size,), device=device, dtype=torch.int64) + return corpus[starts[:, None] + arange_seq[None, :]].long() def add_noise(input: list[int], t: float) -> str: mask = (torch.rand(batch_size, seq_len, device=device) < mask_prob).long()