diff --git a/src/encoder.py b/src/encoder.py index 71a2c9a..b833e8c 100644 --- a/src/encoder.py +++ b/src/encoder.py @@ -1,31 +1,39 @@ -from torch import nn +import numpy as np + class Encoder: def __init__(self, path): with open(path) as f: - lines = f.readlines() - - text = "\n".join(lines) - - self.dictionary = {} + text = f.read() - index = 1 + chars = sorted(set(text)) self.dictionary = {"[MASK]": 0} - for char in text: - if self.dictionary.get(char): + for i, c in enumerate(chars, start=1): + self.dictionary[c] = i + self._inv = {v: k for k, v in self.dictionary.items()} + + max_ord = max((ord(c) for c in chars), default=0) + self._lookup = np.zeros(max_ord + 1, dtype=np.int32) + for c, i in self.dictionary.items(): + if c == "[MASK]": continue - - self.dictionary[char] = index - index += 1 - + self._lookup[ord(c)] = i + def encode(self, text): return [self.dictionary[char] for char in text] - + + def encode_array(self, text, chunk_chars=4_000_000): + out = np.empty(len(text), dtype=np.int32) + for i in range(0, len(text), chunk_chars): + block = text[i:i + chunk_chars] + cps = np.frombuffer(block.encode("utf-32-le"), dtype=np.uint32) + out[i:i + len(cps)] = self._lookup[cps] + return out + def decode(self, encoded) -> list[int]: - inv_dict = {v: k for k, v in self.dictionary.items()} - return [inv_dict[encoding] for encoding in encoded] - + return [self._inv[e] for e in encoded] + def vocab(self): return self.dictionary.keys() diff --git a/src/export_onnx.py b/src/export_onnx.py index c2033f2..5558e95 100644 --- a/src/export_onnx.py +++ b/src/export_onnx.py @@ -7,11 +7,11 @@ from src.model import Transformer -def export(checkpoint: str, out_dir: str, seq_len: int): +def export(checkpoint: str, out_dir: str, seq_len: int, data: str): out = Path(out_dir) out.mkdir(parents=True, exist_ok=True) - model = Transformer() + model = Transformer(data_path=data) model.load_state_dict(torch.load(checkpoint, map_location="cpu")) model.eval() @@ -39,8 +39,9 @@ def main(): p.add_argument("--checkpoint", default="checkpoints/checkpoint.pt") p.add_argument("--out", default="web/public") p.add_argument("--seq-len", type=int, default=128) + p.add_argument("--data", default="data/input.txt") args = p.parse_args() - export(args.checkpoint, args.out, args.seq_len) + export(args.checkpoint, args.out, args.seq_len, args.data) if __name__ == "__main__": diff --git a/src/model.py b/src/model.py index 83bf56b..1cbfa20 100644 --- a/src/model.py +++ b/src/model.py @@ -3,10 +3,10 @@ class Transformer(nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, data_path="data/input.txt", *args, **kwargs): super().__init__(*args, **kwargs) - self.encoder = Encoder("data/input.txt") + self.encoder = Encoder(data_path) vocab_size = len(self.encoder.vocab()) max_seq_len = 1024 diff --git a/src/sample.py b/src/sample.py index 26a28d4..ea9ae27 100644 --- a/src/sample.py +++ b/src/sample.py @@ -36,10 +36,11 @@ def main(): p.add_argument("--query", default="To be, ") p.add_argument("--length", type=int, default=64) p.add_argument("--device", default="mps") + p.add_argument("--data", default="data/input.txt") args = p.parse_args() device = torch.device(args.device) - model = Transformer().to(device) + model = Transformer(data_path=args.data).to(device) model.load_state_dict(torch.load(args.checkpoint, map_location=device)) model.eval() diff --git a/src/train.py b/src/train.py index 6565563..9905fb4 100644 --- a/src/train.py +++ b/src/train.py @@ -10,31 +10,37 @@ def main(): p = argparse.ArgumentParser() p.add_argument("--device", default="mps") + p.add_argument("--data", default="data/input.txt") + p.add_argument("--batch-size", type=int, default=64) + p.add_argument("--seq-len", type=int, default=128) + p.add_argument("--resume", default=None, help="path to checkpoint.pt to resume from") args = p.parse_args() os.makedirs("checkpoints", exist_ok=True) - data = "data/input.txt" + data = args.data device = torch.device(args.device) with open(data) as f: - lines = f.readlines() + text = f.read() - text = "\n".join(lines) - - seq_len = 128 + seq_len = args.seq_len iterations = 10000000 - batch_size = 64 + batch_size = args.batch_size - transformer = Transformer().to(device) + transformer = Transformer(data_path=data).to(device) + if args.resume: + transformer.load_state_dict(torch.load(args.resume, map_location=device)) + print(f"resumed weights from {args.resume}") optimizer = torch.optim.Adam(transformer.parameters(), lr=1e-4) - corpus = torch.tensor(transformer.encoder.encode(text), dtype=torch.long, device=device) + corpus_np = transformer.encoder.encode_array(text) + corpus = torch.from_numpy(corpus_np).to(device=device, dtype=torch.int32) corpus_len = corpus.shape[0] - arange_seq = torch.arange(seq_len, device=device) + arange_seq = torch.arange(seq_len, device=device, dtype=torch.int64) def grab_batch() -> torch.Tensor: - starts = torch.randint(0, corpus_len - seq_len, (batch_size,), device=device) - return corpus[starts[:, None] + arange_seq[None, :]] + starts = torch.randint(0, corpus_len - seq_len, (batch_size,), device=device, dtype=torch.int64) + return corpus[starts[:, None] + arange_seq[None, :]].long() def add_noise(input: list[int], t: float) -> str: mask = (torch.rand(batch_size, seq_len, device=device) < mask_prob).long()