From c81263967b4fb474c970a581597d8f4f18f6f115 Mon Sep 17 00:00:00 2001 From: Amin Samadi Date: Tue, 21 Apr 2026 15:04:59 -0700 Subject: [PATCH] Add pluggable embedding backends --- README.md | 29 ++ src/team_comm_tools/feature_builder.py | 37 ++- .../within_person_discursive_range.py | 13 +- src/team_comm_tools/utils/check_embeddings.py | 137 ++++++++- ..._discursive_diversity_custom_embeddings.py | 67 +++++ tests/test_pluggable_embeddings.py | 267 ++++++++++++++++++ 6 files changed, 531 insertions(+), 19 deletions(-) create mode 100644 tests/test_discursive_diversity_custom_embeddings.py create mode 100644 tests/test_pluggable_embeddings.py diff --git a/README.md b/README.md index 10232b88..c579f99c 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,35 @@ my_feature_builder = FeatureBuilder( my_feature_builder.featurize() ``` +If you already generate embeddings elsewhere in your pipeline, you can supply your own encoder instead of using the default sentence-transformers model: + +```python +import numpy as np +from openai import OpenAI + +client = OpenAI() + +def openai_encoder(texts): + response = client.embeddings.create( + model="text-embedding-3-small", + input=texts, + ) + return np.array([item.embedding for item in response.data]) + +my_feature_builder = FeatureBuilder( + input_df = my_pandas_dataframe, + conversation_id_col = "conversation_id", + speaker_id_col = "speaker_id", + message_col = "message", + vector_directory = "./vector_data/", + embedding_fn = openai_encoder, + embedding_backend_id = "openai-text-embedding-3-small", + embedding_dim = 1536, +) +``` + +When a custom `embedding_fn` is provided, the package keeps the vector cache separate for that backend and does not initialize the default sentence-transformers model unless it is actually needed. + ### Data Format We accept input data in the format of a Pandas DataFrame. Your data needs to have three (3) required input columns and one optional column. diff --git a/src/team_comm_tools/feature_builder.py b/src/team_comm_tools/feature_builder.py index 759bfbda..9ba38b58 100644 --- a/src/team_comm_tools/feature_builder.py +++ b/src/team_comm_tools/feature_builder.py @@ -9,6 +9,7 @@ import time import itertools import warnings +from collections.abc import Callable # Imports from feature files and classes from team_comm_tools.utils.download_resources import download @@ -90,6 +91,15 @@ class FeatureBuilder: :type ner_cutoff: int :param regenerate_vectors: If true, regenerates vector data even if it already exists. Defaults to False. :type regenerate_vectors: bool, optional + :param embedding_fn: Optional callable that maps a list of messages to a 2D embedding array. + Defaults to None, which preserves the built-in sentence-transformers backend. + :type embedding_fn: Callable[[list[str]], np.ndarray] | None, optional + :param embedding_backend_id: Optional identifier used to keep custom vector caches distinct across + embedding backends. Ignored when `embedding_fn` is None. + :type embedding_backend_id: str | None, optional + :param embedding_dim: Optional embedding dimension for validating custom encoder output and stabilizing + custom vector cache keys. + :type embedding_dim: int | None, optional :param compute_vectors_from_preprocessed: If true, computes vectors using preprocessed text (with capitalization and punctuation removed). Defaults to False. :type compute_vectors_from_preprocessed: bool, optional @@ -137,6 +147,9 @@ def __init__( ner_training_df: pd.DataFrame = None, ner_cutoff: int = 0.9, regenerate_vectors: bool = False, + embedding_fn: Callable[[list[str]], np.ndarray] | None = None, + embedding_backend_id: str | None = None, + embedding_dim: int | None = None, compute_vectors_from_preprocessed: bool = False, custom_liwc_dictionary_path: str = '', convo_aggregation = True, @@ -174,6 +187,9 @@ def __init__( self.within_task = within_task self.ner_cutoff = ner_cutoff self.regenerate_vectors = regenerate_vectors + self.embedding_fn = embedding_fn + self.embedding_backend_id = embedding_backend_id + self.embedding_dim = embedding_dim self.convo_aggregation = convo_aggregation self.convo_methods = convo_methods self.convo_columns = convo_columns @@ -389,10 +405,25 @@ def __init__( self.output_file_path_user_level = re.sub(r'/user/', r'/output/user/', self.output_file_path_user_level) # Logic for processing vector cache - self.vect_path = vector_directory + "sentence/" + ("turns" if self.turns else "chats") + "/" + base_file_name + self.vect_path = build_vector_cache_path( + vector_directory + "sentence/" + ("turns" if self.turns else "chats") + "/" + base_file_name, + embedding_fn=self.embedding_fn, + embedding_backend_id=self.embedding_backend_id, + embedding_dim=self.embedding_dim, + ) self.bert_path = vector_directory + "sentiment/" + ("turns" if self.turns else "chats") + "/" + base_file_name - check_embeddings(self.chat_data, self.vect_path, self.bert_path, need_sentence, need_sentiment, self.regenerate_vectors, message_col = self.vector_colname) + check_embeddings( + self.chat_data, + self.vect_path, + self.bert_path, + need_sentence, + need_sentiment, + self.regenerate_vectors, + message_col=self.vector_colname, + embedding_fn=self.embedding_fn, + embedding_dim=self.embedding_dim, + ) if(need_sentence): self.vect_data = pd.read_csv(self.vect_path, encoding='mac_roman') @@ -781,4 +812,4 @@ def verify_timestamp_format(self, timestamp_col) -> None: raise ValueError( f"Column '{timestamp_col}' contains values that are neither parseable as datetime " f"nor convertible to numeric format." - ) \ No newline at end of file + ) diff --git a/src/team_comm_tools/features/within_person_discursive_range.py b/src/team_comm_tools/features/within_person_discursive_range.py index 1d83b1fb..503cd2a9 100644 --- a/src/team_comm_tools/features/within_person_discursive_range.py +++ b/src/team_comm_tools/features/within_person_discursive_range.py @@ -5,7 +5,14 @@ import warnings warnings.filterwarnings('ignore') # We get empty slice warnings for short conversations -def get_nan_vector(): +def get_nan_vector(chat_data=None): + if chat_data is not None: + for value in chat_data["message_embedding"]: + if isinstance(value, np.ndarray): + return np.zeros(value.shape, dtype=float) + if isinstance(value, (list, tuple)): + return np.zeros(len(value), dtype=float) + current_dir = os.path.dirname(__file__) nan_vector_file_path = os.path.join(current_dir, './assets/nan_vector.txt') nan_vector_file_path = os.path.abspath(nan_vector_file_path) @@ -35,13 +42,13 @@ def get_nan_vector(): def get_within_person_disc_range(chat_data, num_chunks, conversation_id_col, speaker_id_col): # Get nan vector - nan_vector = get_nan_vector() + nan_vector = get_nan_vector(chat_data) #calculate mean vector per speaker per chunk mean_vec_speaker_chunks = pd.DataFrame(chat_data.groupby([conversation_id_col, speaker_id_col, 'chunk_num']).message_embedding.apply(np.mean)).unstack('chunk_num').rename(columns={'message_embedding': 'mean_chunk_vec'}) #collapse multi-index - mean_vec_speaker_chunks.columns = ["_c".join(col).strip() for col in mean_vec_speaker_chunks.columns.values] + mean_vec_speaker_chunks.columns = ["_c".join(str(part) for part in col).strip() for col in mean_vec_speaker_chunks.columns.values] actual_num_chunks = len(mean_vec_speaker_chunks[2:].columns) # omit the first two, which is conversation_num and speaker_nickname diff --git a/src/team_comm_tools/utils/check_embeddings.py b/src/team_comm_tools/utils/check_embeddings.py index f0b364c4..3ec4e429 100644 --- a/src/team_comm_tools/utils/check_embeddings.py +++ b/src/team_comm_tools/utils/check_embeddings.py @@ -4,6 +4,7 @@ import os import pickle import warnings +import hashlib from tqdm import tqdm from pathlib import Path @@ -17,18 +18,93 @@ logging.set_verbosity(40) # only log errors -model_vect = SentenceTransformer('all-MiniLM-L6-v2') +DEFAULT_VECTOR_MODEL_NAME = 'all-MiniLM-L6-v2' +model_vect = None MODEL = f"cardiffnlp/twitter-roberta-base-sentiment-latest" -tokenizer = AutoTokenizer.from_pretrained(MODEL) -model_bert = AutoModelForSequenceClassification.from_pretrained(MODEL) +tokenizer = None +model_bert = None os.environ["TOKENIZERS_PARALLELISM"] = "false" EMOJIS_TO_PRESERVE = { "(:", "(;", "):", "/:", ":(", ":)", ":/", ";)" } + +def get_vector_model(): + global model_vect + if model_vect is None: + model_vect = SentenceTransformer(DEFAULT_VECTOR_MODEL_NAME) + return model_vect + + +def get_sentiment_model(): + global tokenizer, model_bert + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(MODEL) + if model_bert is None: + model_bert = AutoModelForSequenceClassification.from_pretrained(MODEL) + return tokenizer, model_bert + + +def default_vector_encoder(texts): + return get_vector_model().encode(texts) + + +def get_embedding_cache_suffix(embedding_fn=None, embedding_backend_id=None, embedding_dim=None): + if embedding_fn is None: + return "" + + raw_backend_id = embedding_backend_id + if not raw_backend_id: + callable_name = getattr(embedding_fn, "__qualname__", type(embedding_fn).__qualname__) + raw_backend_id = f"{embedding_fn.__module__}.{callable_name}" + if embedding_dim is not None: + raw_backend_id = f"{raw_backend_id}-dim{embedding_dim}" + + safe_backend_id = re.sub(r"[^A-Za-z0-9]+", "-", raw_backend_id).strip("-").lower()[:48] or "custom" + digest = hashlib.sha1(raw_backend_id.encode("utf-8")).hexdigest()[:12] + return f"__{safe_backend_id}-{digest}" + + +def build_vector_cache_path(vect_path: str, embedding_fn=None, embedding_backend_id=None, embedding_dim=None) -> str: + suffix = get_embedding_cache_suffix( + embedding_fn=embedding_fn, + embedding_backend_id=embedding_backend_id, + embedding_dim=embedding_dim, + ) + if not suffix: + return vect_path + + path = Path(vect_path) + return str(path.with_name(f"{path.stem}{suffix}{path.suffix}")) + + +def validate_embeddings(embeddings, expected_rows, embedding_dim=None): + embeddings = np.asarray(embeddings, dtype=float) + + if embeddings.ndim == 1: + if expected_rows != 1: + raise ValueError("Custom embedding_fn must return one embedding per input string.") + embeddings = embeddings.reshape(1, -1) + + if embeddings.ndim != 2: + raise ValueError("Custom embedding_fn must return a 2D array-like object.") + + if embeddings.shape[0] != expected_rows: + raise ValueError( + "Custom embedding_fn must return the same number of embeddings as the number of input strings." + ) + + if embedding_dim is not None and embeddings.shape[1] != embedding_dim: + raise ValueError( + f"Custom embedding_fn returned vectors with dimension {embeddings.shape[1]}, expected {embedding_dim}." + ) + + return embeddings + # Check if embeddings exist def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, need_sentence: bool, - need_sentiment: bool, regenerate_vectors: bool, message_col: str = "message"): + need_sentiment: bool, regenerate_vectors: bool, message_col: str = "message", + embedding_fn=None, embedding_dim=None): """ Check if embeddings and required lexicons exist, and generate them if they don't. @@ -54,7 +130,13 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, ne :rtype: None """ if (regenerate_vectors or (not os.path.isfile(vect_path))) and need_sentence: - generate_vect(chat_data, vect_path, message_col) + generate_vect( + chat_data, + vect_path, + message_col, + embedding_fn=embedding_fn, + embedding_dim=embedding_dim, + ) if (regenerate_vectors or (not os.path.isfile(bert_path))) and need_sentiment: generate_bert(chat_data, bert_path, message_col) @@ -63,10 +145,22 @@ def check_embeddings(chat_data: pd.DataFrame, vect_path: str, bert_path: str, ne # check whether the given vector and bert data matches length of chat data if len(vector_df) != len(chat_data): print("ERROR: The length of the vector data does not match the length of the chat data. Regenerating...") - generate_vect(chat_data, vect_path, message_col) + generate_vect( + chat_data, + vect_path, + message_col, + embedding_fn=embedding_fn, + embedding_dim=embedding_dim, + ) except FileNotFoundError: # It's OK if we don't have the path, if the sentence vectors are not necessary if need_sentence: - generate_vect(chat_data, vect_path, message_col) + generate_vect( + chat_data, + vect_path, + message_col, + embedding_fn=embedding_fn, + embedding_dim=embedding_dim, + ) try: bert_df = pd.read_csv(bert_path) @@ -337,10 +431,16 @@ def str_to_vec(str_vec): vector_list = [float(e) for e in str_vec[1:-1].split(',')] return np.array(vector_list) -def get_nan_vector(): +def get_nan_vector(embedding_fn=None, embedding_dim=None): """ Get a default value for an empty string (the "NaN vector") and returns it as a 1D np array. """ + if embedding_dim is not None: + return np.zeros(embedding_dim, dtype=float) + + if embedding_fn is not None: + return validate_embeddings(embedding_fn([""]), expected_rows=1, embedding_dim=embedding_dim)[0] + current_dir = os.path.dirname(__file__) nan_vector_file_path = os.path.join(current_dir, '../features/assets/nan_vector.txt') nan_vector_file_path = os.path.abspath(nan_vector_file_path) @@ -348,7 +448,7 @@ def get_nan_vector(): with open(nan_vector_file_path, "r") as f: return str_to_vec(f.read()) -def generate_vect(chat_data, output_path, message_col, batch_size = 64): +def generate_vect(chat_data, output_path, message_col, batch_size = 64, embedding_fn=None, embedding_dim=None): """ Generates sentence vectors for the given chat data and saves them to a CSV file. @@ -364,12 +464,21 @@ def generate_vect(chat_data, output_path, message_col, batch_size = 64): :return: None :rtype: None """ - print(f"Generating SBERT sentence vectors...") + print("Generating sentence vectors...") - nan_vector = get_nan_vector() + encoder = embedding_fn or default_vector_encoder + nan_vector = get_nan_vector(embedding_fn=embedding_fn, embedding_dim=embedding_dim) empty_to_nan = [text if text and text.strip() else None for text in chat_data[message_col].tolist()] non_empty_texts = [text for text in empty_to_nan if text is not None] - all_embeddings = [emb for i in tqdm(range(0, len(non_empty_texts), batch_size)) for emb in model_vect.encode(non_empty_texts[i:i + batch_size])] + all_embeddings = [] + for i in tqdm(range(0, len(non_empty_texts), batch_size)): + batch = non_empty_texts[i:i + batch_size] + batch_embeddings = validate_embeddings( + encoder(batch), + expected_rows=len(batch), + embedding_dim=embedding_dim, + ) + all_embeddings.extend(batch_embeddings) embeddings = np.tile(nan_vector, (len(empty_to_nan), 1)) # default embeddings to the NAN vector non_empty_index = 0 for idx, text in enumerate(empty_to_nan): @@ -423,6 +532,8 @@ def get_sentiment(texts): :rtype: pd.DataFrame """ + tokenizer, model_bert = get_sentiment_model() + # Handle and tokenize non-null and non-empty texts texts_series = pd.Series(texts) non_null_non_empty_texts = texts_series[texts_series.apply(lambda x: pd.notnull(x) and x.strip() != '')].tolist() @@ -449,4 +560,4 @@ def get_sentiment(texts): sent_df = pd.DataFrame(np.nan, index=texts_series.index, columns=['positive_bert', 'negative_bert', 'neutral_bert']) sent_df.loc[texts_series.apply(lambda x: pd.notnull(x) and x.strip() != ''), ['positive_bert', 'negative_bert', 'neutral_bert']] = non_null_sent_df.values - return sent_df \ No newline at end of file + return sent_df diff --git a/tests/test_discursive_diversity_custom_embeddings.py b/tests/test_discursive_diversity_custom_embeddings.py new file mode 100644 index 00000000..8b2e1d8e --- /dev/null +++ b/tests/test_discursive_diversity_custom_embeddings.py @@ -0,0 +1,67 @@ +from pathlib import Path +import importlib.util +import sys +import types + +import numpy as np +import pandas as pd + + +REPO_ROOT = Path(__file__).resolve().parents[1] +DISC_DIVERSITY_PATH = REPO_ROOT / "src" / "team_comm_tools" / "features" / "discursive_diversity.py" +WITHIN_PERSON_PATH = REPO_ROOT / "src" / "team_comm_tools" / "features" / "within_person_discursive_range.py" + + +def load_discursive_diversity_modules(monkeypatch): + fake_package = types.ModuleType("team_comm_tools") + fake_package.__path__ = [] + fake_features_package = types.ModuleType("team_comm_tools.features") + fake_features_package.__path__ = [] + + monkeypatch.setitem(sys.modules, "team_comm_tools", fake_package) + monkeypatch.setitem(sys.modules, "team_comm_tools.features", fake_features_package) + + disc_spec = importlib.util.spec_from_file_location( + "team_comm_tools.features.discursive_diversity", + DISC_DIVERSITY_PATH, + ) + disc_module = importlib.util.module_from_spec(disc_spec) + sys.modules[disc_spec.name] = disc_module + disc_spec.loader.exec_module(disc_module) + + within_spec = importlib.util.spec_from_file_location( + "team_comm_tools.features.within_person_discursive_range", + WITHIN_PERSON_PATH, + ) + within_module = importlib.util.module_from_spec(within_spec) + sys.modules[within_spec.name] = within_module + within_spec.loader.exec_module(within_module) + + return disc_module, within_module + + +def test_within_person_disc_range_handles_custom_embedding_dimensions(monkeypatch): + _, within_module = load_discursive_diversity_modules(monkeypatch) + + vec_a = np.array([1.0] * 1536) + vec_b = np.array([0.5] * 1536) + + chat_data = pd.DataFrame( + { + "conversation_num": ["conv1", "conv1", "conv1"], + "speaker_nickname": ["alice", "alice", "bob"], + "chunk_num": [0, 1, 0], + "message_embedding": [vec_a, vec_b, vec_a], + } + ) + + result = within_module.get_within_person_disc_range( + chat_data, + num_chunks=2, + conversation_id_col="conversation_num", + speaker_id_col="speaker_nickname", + ) + + assert list(result.index) == ["conv1"] + assert np.isfinite(result["incongruent_modulation"].iloc[0]) + assert np.isfinite(result["within_person_disc_range"].iloc[0]) diff --git a/tests/test_pluggable_embeddings.py b/tests/test_pluggable_embeddings.py new file mode 100644 index 00000000..0ee192c2 --- /dev/null +++ b/tests/test_pluggable_embeddings.py @@ -0,0 +1,267 @@ +from pathlib import Path +import ast +import importlib.util +import sys +import types + +import numpy as np +import pandas as pd + + +REPO_ROOT = Path(__file__).resolve().parents[1] +CHECK_EMBEDDINGS_PATH = REPO_ROOT / "src" / "team_comm_tools" / "utils" / "check_embeddings.py" +FEATURE_BUILDER_PATH = REPO_ROOT / "src" / "team_comm_tools" / "feature_builder.py" + + +def load_check_embeddings_module(monkeypatch): + fake_torch = types.ModuleType("torch") + fake_sentence_transformers = types.ModuleType("sentence_transformers") + fake_sentence_transformers.SentenceTransformer = lambda *args, **kwargs: types.SimpleNamespace( + encode=lambda texts: np.array([[float(len(text))] for text in texts]) + ) + fake_sentence_transformers.util = object() + + fake_transformers = types.ModuleType("transformers") + fake_transformers.AutoTokenizer = types.SimpleNamespace( + from_pretrained=lambda *args, **kwargs: object() + ) + fake_transformers.AutoModelForSequenceClassification = types.SimpleNamespace( + from_pretrained=lambda *args, **kwargs: object() + ) + fake_transformers.logging = types.SimpleNamespace(set_verbosity=lambda *args, **kwargs: None) + + fake_scipy = types.ModuleType("scipy") + fake_scipy_special = types.ModuleType("scipy.special") + fake_scipy_special.softmax = lambda values, axis=None: values + fake_scipy.special = fake_scipy_special + + monkeypatch.setitem(sys.modules, "torch", fake_torch) + monkeypatch.setitem(sys.modules, "sentence_transformers", fake_sentence_transformers) + monkeypatch.setitem(sys.modules, "transformers", fake_transformers) + monkeypatch.setitem(sys.modules, "scipy", fake_scipy) + monkeypatch.setitem(sys.modules, "scipy.special", fake_scipy_special) + + spec = importlib.util.spec_from_file_location("check_embeddings_test_module", CHECK_EMBEDDINGS_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def load_feature_builder_module(monkeypatch, check_embeddings_module): + fake_package = types.ModuleType("team_comm_tools") + fake_package.__path__ = [] + fake_utils_package = types.ModuleType("team_comm_tools.utils") + fake_utils_package.__path__ = [] + + fake_download_resources = types.ModuleType("team_comm_tools.utils.download_resources") + fake_download_resources.download = lambda: None + + fake_chat_calc = types.ModuleType("team_comm_tools.utils.calculate_chat_level_features") + fake_chat_calc.ChatLevelFeaturesCalculator = object + + fake_user_calc = types.ModuleType("team_comm_tools.utils.calculate_user_level_features") + fake_user_calc.UserLevelFeaturesCalculator = object + + fake_conv_calc = types.ModuleType("team_comm_tools.utils.calculate_conversation_level_features") + fake_conv_calc.ConversationLevelFeaturesCalculator = object + + fake_preprocess = types.ModuleType("team_comm_tools.utils.preprocess") + fake_preprocess.preprocess_conversation_columns = ( + lambda df, column_names, grouping_keys, cumulative_grouping, within_task: df + ) + fake_preprocess.remove_unhashable_cols = lambda df, column_names: df + fake_preprocess.preprocess_text_lowercase_but_retain_punctuation = lambda text: str(text).lower() + fake_preprocess.preprocess_text = lambda text: str(text).lower() + fake_preprocess.preprocess_naive_turns = lambda df, column_names: df + + fake_check_embeddings = types.ModuleType("team_comm_tools.utils.check_embeddings") + fake_check_embeddings.build_vector_cache_path = check_embeddings_module.build_vector_cache_path + fake_check_embeddings.check_embeddings = lambda *args, **kwargs: None + + feature_names = [ + "Named Entity Recognition", + "Sentiment (RoBERTa)", + "Message Length", + "Message Quantity", + "Information Exchange", + "LIWC and Other Lexicons", + "Questions", + "Conversational Repair", + "Word Type-Token Ratio", + "Proportion of First-Person Pronouns", + "Function Word Accommodation", + "Content Word Accommodation", + "Hedge", + "TextBlob Subjectivity", + "TextBlob Polarity", + "Positivity Z-Score", + "Dale-Chall Score", + "Time Difference", + "Politeness Strategies", + "Politeness / Receptiveness Markers", + "Certainty", + "Online Discussion Tags", + "Turn-Taking Index", + "Equal Participation", + "Team Burstiness", + "Conversation Level Aggregates", + "User Level Aggregates", + "Information Diversity", + ] + + fake_feature_dict = types.ModuleType("team_comm_tools.feature_dict") + fake_feature_dict.feature_dict = {} + for name in feature_names: + fake_feature_dict.feature_dict[name] = { + "vect_data": name == "Information Diversity", + "bert_sentiment_data": name == "Sentiment (RoBERTa)", + "level": "Conversation" if name in { + "Turn-Taking Index", + "Equal Participation", + "Team Burstiness", + "Conversation Level Aggregates", + "User Level Aggregates", + "Information Diversity", + } else "Chat", + "function": f"func_{name}", + "columns": [], + } + + monkeypatch.setitem(sys.modules, "team_comm_tools", fake_package) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils", fake_utils_package) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.download_resources", fake_download_resources) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.calculate_chat_level_features", fake_chat_calc) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.calculate_user_level_features", fake_user_calc) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.calculate_conversation_level_features", fake_conv_calc) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.preprocess", fake_preprocess) + monkeypatch.setitem(sys.modules, "team_comm_tools.utils.check_embeddings", fake_check_embeddings) + monkeypatch.setitem(sys.modules, "team_comm_tools.feature_dict", fake_feature_dict) + + spec = importlib.util.spec_from_file_location("feature_builder_test_module", FEATURE_BUILDER_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module.feature_dict = fake_feature_dict.feature_dict + return module + + +def test_build_vector_cache_path_keeps_default_location(monkeypatch): + check_embeddings_module = load_check_embeddings_module(monkeypatch) + default_path = "vector_data/sentence/chats/output.csv" + + assert check_embeddings_module.build_vector_cache_path(default_path) == default_path + + +def test_build_vector_cache_path_changes_when_backend_changes(monkeypatch): + check_embeddings_module = load_check_embeddings_module(monkeypatch) + + def fake_encoder(texts): + return np.array([[float(len(text)), 1.0, 2.0] for text in texts]) + + path_one = check_embeddings_module.build_vector_cache_path( + "vector_data/sentence/chats/output.csv", + embedding_fn=fake_encoder, + embedding_backend_id="openai-text-embedding-3-small", + embedding_dim=3, + ) + path_two = check_embeddings_module.build_vector_cache_path( + "vector_data/sentence/chats/output.csv", + embedding_fn=fake_encoder, + embedding_backend_id="openai-text-embedding-3-large", + embedding_dim=3, + ) + + assert path_one != path_two + assert path_one.endswith(".csv") + assert path_two.endswith(".csv") + + +def test_generate_vect_uses_custom_embedding_fn(tmp_path, monkeypatch): + check_embeddings_module = load_check_embeddings_module(monkeypatch) + calls = [] + + def fake_encoder(texts): + calls.append(list(texts)) + return np.array([[float(len(text)), float(len(text)) + 1.0] for text in texts]) + + chat_data = pd.DataFrame({"message": ["hello", " ", "world"]}) + output_path = tmp_path / "vectors.csv" + + check_embeddings_module.generate_vect( + chat_data, + str(output_path), + "message", + batch_size=8, + embedding_fn=fake_encoder, + embedding_dim=2, + ) + + output_df = pd.read_csv(output_path) + first_vector = np.array(ast.literal_eval(output_df.iloc[0]["message_embedding"])) + empty_vector = np.array(ast.literal_eval(output_df.iloc[1]["message_embedding"])) + third_vector = np.array(ast.literal_eval(output_df.iloc[2]["message_embedding"])) + + assert calls == [["hello", "world"]] + assert np.array_equal(first_vector, np.array([5.0, 6.0])) + assert np.array_equal(empty_vector, np.array([0.0, 0.0])) + assert np.array_equal(third_vector, np.array([5.0, 6.0])) + + +def test_feature_builder_passes_custom_embedding_config_into_cache_and_generation(tmp_path, monkeypatch): + check_embeddings_module = load_check_embeddings_module(monkeypatch) + feature_builder_module = load_feature_builder_module(monkeypatch, check_embeddings_module) + captured = {} + + def fake_check_embeddings(chat_data, vect_path, bert_path, need_sentence, need_sentiment, + regenerate_vectors, message_col="message", embedding_fn=None, embedding_dim=None): + captured["vect_path"] = vect_path + captured["bert_path"] = bert_path + captured["embedding_fn"] = embedding_fn + captured["embedding_dim"] = embedding_dim + + Path(vect_path).parent.mkdir(parents=True, exist_ok=True) + pd.DataFrame( + { + "message": chat_data[message_col], + "message_embedding": ["[0.0, 0.0, 0.0]"] * len(chat_data), + } + ).to_csv(vect_path, index=False) + + Path(bert_path).parent.mkdir(parents=True, exist_ok=True) + pd.DataFrame( + { + "positive_bert": [0.0] * len(chat_data), + "negative_bert": [0.0] * len(chat_data), + "neutral_bert": [1.0] * len(chat_data), + } + ).to_csv(bert_path, index=False) + + feature_builder_module.check_embeddings = fake_check_embeddings + + def fake_encoder(texts): + return np.array([[1.0, 2.0, 3.0] for _ in texts]) + + input_df = pd.DataFrame( + { + "conversation_num": [1, 1], + "speaker_nickname": ["a", "b"], + "message": ["hello", "world"], + "timestamp": [1, 2], + } + ) + + builder = feature_builder_module.FeatureBuilder( + input_df=input_df, + vector_directory=f"{tmp_path}/", + output_file_path_chat_level=str(tmp_path / "chat" / "custom_vectors.csv"), + output_file_path_conv_level=str(tmp_path / "conv" / "custom_vectors.csv"), + output_file_path_user_level=str(tmp_path / "user" / "custom_vectors.csv"), + embedding_fn=fake_encoder, + embedding_backend_id="openai-text-embedding-3-small", + embedding_dim=3, + ) + + assert captured["embedding_fn"] is fake_encoder + assert captured["embedding_dim"] == 3 + assert captured["bert_path"] == f"{tmp_path}/sentiment/chats/custom_vectors.csv" + assert captured["vect_path"] != f"{tmp_path}/sentence/chats/custom_vectors.csv" + assert builder.vect_path == captured["vect_path"]