Skip to content
3 changes: 2 additions & 1 deletion kv_cache_benchmark/kv_cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,8 @@ def access_cache(self, key: str, phase: InferencePhase = InferencePhase.DECODE,
self.stats['storage_read_host_latencies'].append(timing.host)

if self.model_config.kv_cache_size_per_token > 0:
- num_tokens = entry_size / self.model_config.kv_cache_size_per_token
+ sharded_bytes_per_token = kv_cache_size_per_token / max(1, self.tensor_parallel)
+ num_tokens = entry_size / sharded_bytes_per_token
self.stats['storage_tokens_processed'] += num_tokens

return location, timing.total
Expand Down
Loading