Adding UniSRec model implemented on lightweight class hierarchy with pytorch preprocessing #306
Open

TOPAPEC wants to merge 7 commits into MTSWebServices:main
Conversation
Standalone sequential recommender package that mimics the ModelBase interface without touching existing rectools code.

- FlatSASRec — plain ID-embedding SASRec encoder.
- UniSRec — pretrained text embeddings + PCA/BN adaptor, three-phase training (ID embeddings -> adaptor only -> full finetune).
- Uses a lightweight rank_topk instead of TorchRanker; reuses SASRecDataPreparator for the data pipeline.
- 30 tests, smoke scripts for both models.
- Fix: NaN * 0 = NaN in IEEE 754, which breaks attention padding masking via multiplication; switched to masked_fill.
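The masking fix rests on a basic IEEE 754 property that is easy to reproduce outside the model. A minimal numpy sketch of the failure mode and the masked_fill-style remedy (illustrative only, not the PR's actual attention code):

```python
import numpy as np

# Padded positions can end up holding nan (e.g. after a 0/0 during
# attention normalization over an all-padding row).
scores = np.array([0.7, np.nan, 0.2])
pad_mask = np.array([1.0, 0.0, 1.0])  # 0 marks a padding position

# Masking by multiplication does NOT clear the nan: nan * 0 = nan in IEEE 754.
by_mul = scores * pad_mask
assert np.isnan(by_mul[1])

# A masked_fill-style replacement overwrites the value unconditionally,
# so whatever was stored there (nan included) is gone.
by_fill = np.where(pad_mask == 0.0, -np.inf, scores)
assert by_fill[1] == -np.inf
```

In torch the equivalent is `tensor.masked_fill(mask, value)`, which selects rather than multiplies, so NaN in a masked slot cannot leak through.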
New config options:

- ffn_type: conv1d / linear_gelu / linear_relu, plus ffn_expansion
- optimizer: adam / adamw
- scheduler: cosine_warmup (with warmup_ratio, min_lr_ratio)
- loss: softmax / BCE / gBCE / sampled_softmax (with gbce_t)
- patience: early stopping via EarlyStopping callback + validation split
- data_preparator: accept a custom preparator instance

31 tests passing.
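As one illustration of the options above, a cosine_warmup schedule parameterized by warmup_ratio and min_lr_ratio reduces to a plain learning-rate multiplier over training steps. A hypothetical sketch (the helper name and exact shape are assumptions, not the PR's code):

```python
import math

def cosine_warmup_multiplier(step: int, total_steps: int,
                             warmup_ratio: float = 0.1,
                             min_lr_ratio: float = 0.05) -> float:
    """LR multiplier: linear warmup to 1.0, then cosine decay to min_lr_ratio."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return step / warmup_steps  # linear ramp 0 -> 1
    # cosine decay from 1 down to min_lr_ratio over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

A multiplier of this form plugs directly into `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's base LR by the returned value each step.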
added 3 commits · April 24, 2026 22:17
- Add hash-based ID mapping (splitmix64) as an alternative to dense torch.unique mapping in build_sequences and align_embeddings.
- Add UniSRecModel.export_to_onnx() for native ONNX export of the encoder and item embeddings (project_all).
- Add UniSRecModel.map_item_ids() for external→internal ID conversion at inference time (works for both dense and hash modes).
- Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers that duplicated UniSRecModel functionality).
- Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, map_item_ids for both modes.
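splitmix64 itself is a standard finalizer-style mixing function. A self-contained integer sketch of it (a torch-tensor variant would apply the same shift/multiply constants elementwise on int64 tensors):

```python
MASK64 = (1 << 64) - 1

def splitmix64(x: int) -> int:
    """splitmix64 mix step: maps a 64-bit integer to a well-scrambled 64-bit hash."""
    z = (x + 0x9E3779B97F4A7C15) & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return (z ^ (z >> 31)) & MASK64
```

The trade-off versus a dense `torch.unique` mapping: hashing avoids sorting/deduplicating all IDs up front, at the cost of a (tiny) collision probability and an extra lookup step to turn hashes into contiguous internal indices.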
2e923df to d68834f
New
rectools.fast_transformers module — standalone transformer sequential recommenders that work with raw torch tensors, without going through Dataset/pandas.

- GPU-native preprocessing. build_sequences() builds left-padded interaction sequences entirely in torch (argsort + scatter). On ML-20M (20M interactions) this takes 0.5s vs 14.6s for the pandas-based SASRecDataPreparator — roughly 30x faster.
- FlatSASRec. Pre-norm SASRec encoder with plain ID embeddings, no ItemNet hierarchy. Wraps into FlatSASRecModel (inherits ModelBase) so it plugs into standard RecTools fit/recommend.
- UniSRec. Three-phase sequential recommender with pretrained text embeddings and a learnable PCA adaptor: UniSRecModel.fit(user_ids, item_ids, timestamps) takes raw tensors end-to-end. Supports softmax/BCE/gBCE/sampled_softmax losses, Adam/AdamW, cosine warmup scheduler, gradient clipping, early stopping, checkpoint save/load. FFN blocks are configurable (conv1d, linear_gelu, linear_relu).
- rank_topk() — batched top-k with CSR viewed-item filtering and whitelist support.

Benchmark (ML-20M, 10 epochs, softmax, Adam, n_factors=256)
UniSRec ID: +4.6% HR@10, +6.0% NDCG@10, 1.65x faster overall.
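The CSR viewed-item filtering behind rank_topk (listed above) fits in a few lines. A numpy illustration of the idea — the PR's version operates on torch tensors in batches and additionally supports whitelists:

```python
import numpy as np

def rank_topk(scores: np.ndarray, k: int,
              viewed_indptr: np.ndarray, viewed_indices: np.ndarray) -> np.ndarray:
    """Per-user top-k item indices, excluding already-viewed items.

    Viewed items are given in CSR form: viewed_indices[viewed_indptr[u]:
    viewed_indptr[u + 1]] lists the item indices user u has interacted with.
    """
    scores = scores.copy()
    for u in range(scores.shape[0]):
        seen = viewed_indices[viewed_indptr[u]:viewed_indptr[u + 1]]
        scores[u, seen] = -np.inf  # mask out already-viewed items
    # take the k best per row, then order those k by score descending
    cand = np.argpartition(-scores, k - 1, axis=1)[:, :k]
    order = np.argsort(-np.take_along_axis(scores, cand, axis=1), axis=1)
    return np.take_along_axis(cand, order, axis=1)
```

Setting masked scores to `-inf` before the top-k pass keeps the selection a single vectorized operation instead of a post-hoc filter that could return fewer than k items.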
New files
Source (9 modules, 1683 lines):
- rectools/fast_transformers/gpu_data.py — build_sequences, align_embeddings, GPUBatchDataset, make_dataloader
- rectools/fast_transformers/net.py — FlatSASRec, SASRecBlock
- rectools/fast_transformers/lightning_wrap.py — FlatSASRecLightning
- rectools/fast_transformers/model.py — FlatSASRecModel, FlatSASRecConfig
- rectools/fast_transformers/ranking.py — rank_topk
- rectools/fast_transformers/unisrec_net.py — UniSRec, FeedForward, make_ffn
- rectools/fast_transformers/unisrec_lightning.py — UniSRecLightning, loss/optimizer/scheduler dispatch
- rectools/fast_transformers/unisrec_model.py — UniSRecModel (three-phase fit, checkpoint)

Tests (143 tests, 1920 lines):
- tests/fast_transformers/test_gpu_data.py — sequence building, alignment, dataset/dataloader
- tests/fast_transformers/test_net.py, test_lightning_wrap.py, test_model.py — FlatSASRec stack
- tests/fast_transformers/test_unisrec_net.py, test_unisrec_lightning.py, test_unisrec_model.py — UniSRec stack
- tests/fast_transformers/test_ranking.py — top-k, filtering, edge cases

Scripts:
- scripts/compare_sasrec_unisrec.py — full benchmark with markdown report generation
- scripts/comparison_report.md — benchmark results

Test plan
- pytest tests/fast_transformers/ -q
- FlatSASRecModel fit/recommend through the standard RecTools API on a small dataset