diff --git a/src/openbench/metric/metric.py b/src/openbench/metric/metric.py index 559712f..ccf1478 100644 --- a/src/openbench/metric/metric.py +++ b/src/openbench/metric/metric.py @@ -63,6 +63,11 @@ class MetricOptions(Enum): # Ref: https://en.wikipedia.org/wiki/Word_error_rate WER = "wer" + # Character Error Rate + # Evaluates transcription accuracy at character level, suitable for CJK languages + # Ref: https://en.wikipedia.org/wiki/Word_error_rate (applied at character granularity) + CER = "cer" + # Concatenated minimum-Permutation Word Error Rate # Evaluates multi-speaker transcription by finding the optimal speaker permutation # Ref: https://arxiv.org/abs/2004.09249 diff --git a/src/openbench/metric/word_error_metrics/word_error_metrics.py b/src/openbench/metric/word_error_metrics/word_error_metrics.py index 4cef2a1..ee6651f 100644 --- a/src/openbench/metric/word_error_metrics/word_error_metrics.py +++ b/src/openbench/metric/word_error_metrics/word_error_metrics.py @@ -297,6 +297,94 @@ def compute_metric(self, detail: Details) -> float: return (S + D + I) / N if N > 0 else 0.0 +def _split_to_chars(words: list[str]) -> list[str]: + """Split word-level tokens into individual characters, stripping whitespace.""" + return [ch for w in words for ch in w.strip() if ch.strip()] + + +@MetricRegistry.register_metric( + ( + PipelineType.TRANSCRIPTION, + PipelineType.ORCHESTRATION, + PipelineType.STREAMING_TRANSCRIPTION, + ), + MetricOptions.CER, +) +class CharacterErrorRate(BaseWordErrorMetric): + """Character Error Rate (CER) implementation. + + This metric evaluates transcription accuracy at the character level. + Both reference and hypothesis tokens are split into individual characters + before alignment, making it suitable for CJK languages (Chinese, Japanese, + Korean) where word boundaries are not marked by spaces. + + CER = (S + D + I) / N (same formula as WER, applied to characters) + """ + + @classmethod + def metric_name(cls) -> str: + return "cer" + + @classmethod + def metric_components(cls) -> MetricComponents: + return [ + "num_substitutions", + "num_deletions", + "num_insertions", + "num_characters", + ] + + def compute_components( + self, + reference: Transcript, + hypothesis: Transcript, + **kwargs, + ) -> dict[str, int]: + ref_words, _ = parse_diarzed_words(reference) + hyp_words, _ = parse_diarzed_words(hypothesis) + + if self.use_text_normalizer: + ref_words, _ = self.text_normalizer(words=ref_words, speakers=None) + hyp_words, _ = self.text_normalizer(words=hyp_words, speakers=None) + + ref_chars = _split_to_chars(ref_words) + hyp_chars = _split_to_chars(hyp_words) + + result = jiwer.compute_measures( + truth=" ".join(ref_chars), + hypothesis=" ".join(hyp_chars), + ) + result = AlignmentMetrics(**result) + alignments = result.ops[0] + + num_substitutions = 0 + num_deletions = 0 + num_insertions = 0 + + for alignment in alignments: + if alignment.type == "substitute": + num_substitutions += alignment.ref_end_idx - alignment.ref_start_idx + elif alignment.type == "delete": + num_deletions += alignment.ref_end_idx - alignment.ref_start_idx + elif alignment.type == "insert": + num_insertions += alignment.hyp_end_idx - alignment.hyp_start_idx + + return { + "num_substitutions": num_substitutions, + "num_deletions": num_deletions, + "num_insertions": num_insertions, + "num_characters": len(ref_chars), + } + + def compute_metric(self, detail: Details) -> float: + S = detail["num_substitutions"] + D = detail["num_deletions"] + I = detail["num_insertions"] # noqa: E741 + N = detail["num_characters"] + + return (S + D + I) / N if N > 0 else 0.0 + + @MetricRegistry.register_metric(PipelineType.ORCHESTRATION, MetricOptions.CPWER) class ConcatenatedMinimumPermutationWER(BaseWordErrorMetric): """Concatenated minimum-Permutation Word Error Rate (cpWER) implementation.