Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions packages/backend/src/api/services/MoocletExperimentService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,16 @@ export class MoocletExperimentService extends ExperimentService {

updatedExperiment.moocletPolicyParameters = policyParameterResponse.parameters;

// Transform prior keys from Mooclet version IDs back to UpGrade condition codes,
// so the PUT response is consistent with the GET response from attachPolicyParamsToExperimentDTO.
const updatedTsParams = updatedExperiment.moocletPolicyParameters as MoocletTSConfigurablePolicyParametersDTO;
if (updatedTsParams?.prior) {
updatedTsParams.prior = this.translateVersionIdsToConditionCodes(
updatedTsParams.prior,
currentMoocletExperimentRef.versionConditionMaps
);
}

// --------- update versions ----------------------

if (!versionEdits) {
Expand Down Expand Up @@ -522,6 +532,18 @@ export class MoocletExperimentService extends ExperimentService {
currentMoocletExperimentRef: MoocletExperimentRef,
logger: UpgradeLogger
): Promise<MoocletPolicyParametersResponseDetails> {
const tsParams = newPolicyParameters as MoocletTSConfigurablePolicyParametersDTO;
if (tsParams.prior) {
// Translate conditionCode keys to Mooclet version IDs before sending to the Mooclet API
newPolicyParameters = {
...tsParams,
prior: this.translateConditionCodesToVersionIds(
tsParams.prior,
currentMoocletExperimentRef.versionConditionMaps
),
} as MoocletPolicyParametersDTO;
}

return this.moocletDataService.updatePolicyParameters(
currentMoocletExperimentRef.policyParametersId,
{
Expand Down Expand Up @@ -1045,10 +1067,9 @@ export class MoocletExperimentService extends ExperimentService {

moocletExperimentRef.variableId = moocletVariableResponse?.id;
moocletExperimentRef.policyId = newMoocletRequest.policy;
moocletExperimentRef.outcomeVariableName =
upgradeExperiment.assignmentAlgorithm === ASSIGNMENT_ALGORITHM.MOOCLET_TS_CONFIGURABLE
? (moocletPolicyParameters as MoocletTSConfigurablePolicyParametersDTO).outcome_variable_name
: undefined;
moocletExperimentRef.outcomeVariableName = (
moocletPolicyParameters as MoocletTSConfigurablePolicyParametersDTO
).outcome_variable_name;
} catch (err) {
await this.orchestrateDeleteMoocletResources(moocletExperimentRef, logger);
throw err;
Expand Down Expand Up @@ -1168,24 +1189,20 @@ export class MoocletExperimentService extends ExperimentService {
logger
);

// Transform current_posteriors keys from Mooclet version IDs to UpGrade condition codes
// Transform current_posteriors and prior keys from Mooclet version IDs to UpGrade condition codes
const tsConfigurableParams = policyParameters.parameters as MoocletTSConfigurablePolicyParametersDTO;
if (tsConfigurableParams.current_posteriors) {
const transformedPosteriors = {};
for (const [versionId, posteriorData] of Object.entries(tsConfigurableParams.current_posteriors)) {
const versionConditionMap = moocletExperimentRef.versionConditionMaps.find(
(map) => map.moocletVersionId === parseInt(versionId, 10)
);

if (versionConditionMap?.experimentCondition?.conditionCode) {
transformedPosteriors[versionConditionMap.experimentCondition.conditionCode] = posteriorData;
} else {
logger.warn({
message: `No condition mapping found for Mooclet version ${versionId} in experiment ${experiment.id}`,
});
}
}
tsConfigurableParams.current_posteriors = transformedPosteriors;
tsConfigurableParams.current_posteriors = this.translateVersionIdsToConditionCodes(
tsConfigurableParams.current_posteriors,
moocletExperimentRef.versionConditionMaps
);
}

if (tsConfigurableParams.prior) {
tsConfigurableParams.prior = this.translateVersionIdsToConditionCodes(
tsConfigurableParams.prior,
moocletExperimentRef.versionConditionMaps
);
}

experiment.moocletPolicyParameters = policyParameters.parameters;
Expand Down Expand Up @@ -1350,6 +1367,37 @@ export class MoocletExperimentService extends ExperimentService {
return SUPPORTED_MOOCLET_ALGORITHMS.includes(assignmentAlgorithm);
}

private translateConditionCodesToVersionIds<T>(
record: Record<string, T>,
versionConditionMaps: MoocletVersionConditionMap[]
): Record<string, T> {
const result: Record<string, T> = {};
for (const [conditionCode, value] of Object.entries(record)) {
const map = versionConditionMaps?.find((m) => m.experimentCondition?.conditionCode === conditionCode);
if (map?.moocletVersionId) {
result[String(map.moocletVersionId)] = value;
}
}
return result;
}

private translateVersionIdsToConditionCodes<T>(
record: Record<string, T>,
versionConditionMaps: MoocletVersionConditionMap[]
): Record<string, T> {
const result: Record<string, T> = {};
for (const [versionId, value] of Object.entries(record)) {
const map = versionConditionMaps?.find((m) => String(m.moocletVersionId) === versionId);
if (!map?.experimentCondition?.conditionCode) {
throw new MoocletError(
`Reward feedback summary could not be processed: no condition mapping found for Mooclet version ${versionId}`
);
}
result[map.experimentCondition.conditionCode] = value;
}
return result;
}

/**
* Generate a unique outcome variable name based on experiment name and timestamp.
* Returns a string in this format:
Expand Down
137 changes: 132 additions & 5 deletions packages/backend/src/api/services/MoocletRewardsService.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { UpgradeLogger } from '../../lib/logger/UpgradeLogger';
import { EXPERIMENT_STATE, SERVER_ERROR, BinaryRewardValueMap } from 'upgrade_types';
import {
EXPERIMENT_STATE,
SERVER_ERROR,
BinaryRewardValueMap,
MoocletTSConfigurablePolicyParametersDTO,
Prior,
} from 'upgrade_types';
import { RequestedExperimentUser } from '../controllers/validators/ExperimentUserValidator';
import { MoocletExperimentRef } from '../models/MoocletExperimentRef';
import { MoocletDataService } from './MoocletDataService';
Expand Down Expand Up @@ -243,7 +249,24 @@ export class MoocletRewardsService {
}
}

return this.createExperimentRewardsSummary(moocletExperimentRef, rewards, logger);
let tsConfigurableParams: MoocletTSConfigurablePolicyParametersDTO | undefined;
if (moocletExperimentRef.policyParametersId) {
try {
const policyParametersResponse = await this.moocletDataService.getPolicyParameters(
moocletExperimentRef.policyParametersId,
logger
);
tsConfigurableParams = policyParametersResponse.parameters as MoocletTSConfigurablePolicyParametersDTO;
} catch (policyError) {
logger.warn({
message: 'Could not fetch policy parameters for Thompson sampling estimate',
experimentId,
error: policyError,
});
}
}

return this.createExperimentRewardsSummary(moocletExperimentRef, rewards, logger, tsConfigurableParams);
} catch (error) {
logger.error({ message: 'Error fetching rewards summary for experiment', experimentId, error });
throw error;
Expand All @@ -266,7 +289,8 @@ export class MoocletRewardsService {
public async createExperimentRewardsSummary(
moocletExperimentRef: MoocletExperimentRef,
rewardsData: MoocletValueResponseDetails[],
logger: UpgradeLogger
logger: UpgradeLogger,
policyParameters?: MoocletTSConfigurablePolicyParametersDTO
): Promise<ExperimentRewardsSummary> {
const rewards: MoocletValueResponseDetails[] = rewardsData;

Expand All @@ -278,22 +302,42 @@ export class MoocletRewardsService {
return [];
}

const versionConditionPairs = moocletExperimentRef.versionConditionMaps.map(
({ experimentCondition, moocletVersionId }) => ({
conditionCode: experimentCondition.conditionCode,
moocletVersionId,
})
);
const DEFAULT_PRIOR: Prior = { success: 1, failure: 1 };
const estimatedWeightMap = policyParameters
? this.computeThompsonWeightsMap(versionConditionPairs, policyParameters)
: null;

const rewardsSummaries = moocletExperimentRef.versionConditionMaps.map(
({ experimentCondition, moocletVersionId }) => {
const conditionCode = experimentCondition.conditionCode;
const versionIdKey = String(moocletVersionId);
const versionRewards = rewards.filter((reward) => reward.version === moocletVersionId);
const successes = versionRewards.filter((reward) => reward.value === 1.0).length;
const failures = versionRewards.filter((reward) => reward.value === 0.0).length;
const total = successes + failures;
const percentSuccess = total > 0 ? (successes / total) * 100 : 0.0;
const successRate = percentSuccess.toFixed(1) + '%';

const conditionPrior: Prior = policyParameters?.prior?.[versionIdKey] ?? DEFAULT_PRIOR;
const conditionPosteriors = policyParameters?.current_posteriors?.[versionIdKey];

const rewardsForCondition: ExperimentRewardsByCondition = {
conditionCode: experimentCondition.conditionCode,
conditionCode,
successes,
failures,
total,
successRate,
order: experimentCondition.order,
estimatedWeight: estimatedWeightMap?.get(conditionCode),
priorSuccess: conditionPrior.success,
priorFailure: conditionPrior.failure,
posteriorSuccesses: conditionPosteriors?.successes ?? 0,
posteriorFailures: conditionPosteriors?.failures ?? 0,
Comment thread
danoswaltCL marked this conversation as resolved.
};
return rewardsForCondition;
}
Expand All @@ -303,6 +347,89 @@ export class MoocletRewardsService {
return orderedRewardsSummary;
}

/**
* Computes Thompson Sampling estimated weight percentages per condition.
* Combines per-condition priors with current_posteriors as Beta distribution inputs.
* Returns a Map of conditionCode → integer estimated weight (all values sum to 100).
*/
private computeThompsonWeightsMap(
versionConditionPairs: { conditionCode: string; moocletVersionId: number }[],
params: MoocletTSConfigurablePolicyParametersDTO
): Map<string, number> {
const DEFAULT_PRIOR: Prior = { success: 1, failure: 1 };

const arms = versionConditionPairs.map(({ conditionCode, moocletVersionId }) => {
const versionIdKey = String(moocletVersionId);
const posteriors = params.current_posteriors?.[versionIdKey];
return {
conditionCode,
alpha: posteriors?.successes ?? DEFAULT_PRIOR.success,
beta: posteriors?.failures ?? DEFAULT_PRIOR.failure,
Comment thread
danoswaltCL marked this conversation as resolved.
};
});

const iterations = 10_000;
const wins = new Array(arms.length).fill(0);
for (let i = 0; i < iterations; i++) {
let maxSample = -1;
Comment thread
danoswaltCL marked this conversation as resolved.
let maxIdx = 0;
for (let j = 0; j < arms.length; j++) {
const sample = this.randBeta(arms[j].alpha, arms[j].beta);
if (sample > maxSample) {
maxSample = sample;
maxIdx = j;
}
}
wins[maxIdx]++;
}

// Largest Remainder Method: normalize to integer percentages summing to exactly 100
const raw = arms.map((_, i) => (wins[i] / iterations) * 100);
const floored = raw.map(Math.floor);
const remainder = 100 - floored.reduce((a, b) => a + b, 0);
const indices = raw
.map((v, i) => ({ diff: v - floored[i], i }))
.sort((a, b) => b.diff - a.diff)
.map(({ i }) => i);
for (let k = 0; k < remainder; k++) floored[indices[k]]++;

return new Map(arms.map((arm, i) => [arm.conditionCode, floored[i]]));
}

// --- Beta distribution sampling (Marsaglia-Tsang / Box-Muller) ---

private randNormal(): number {
const u = Math.random() || Number.EPSILON; // guard against log(0)
return Math.sqrt(-2 * Math.log(u)) * Math.cos(2 * Math.PI * Math.random());
}

private randGamma(shape: number): number {
if (shape < 1) {
return this.randGamma(1 + shape) * Math.pow(Math.random(), 1 / shape);
}
const d = shape - 1 / 3;
const c = 1 / Math.sqrt(9 * d);
let x: number;
let v: number;
let u: number;
do {
do {
x = this.randNormal();
v = 1 + c * x;
} while (v <= 0);
v = v * v * v;
u = Math.random();
// eslint-disable-next-line no-constant-condition
} while (!(u < 1 - 0.0331 * x * x * x * x) && !(Math.log(u) < 0.5 * x * x + d * (1 - v + Math.log(v))));
return d * v;
}

private randBeta(alpha: number, beta: number): number {
const g1 = this.randGamma(alpha);
const g2 = this.randGamma(beta);
return g1 / (g1 + g2);
}

/**
* Throws a 409 data-conflict error for most unexpected cases
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,10 @@ describe('#MoocletExperimentService', () => {
const mockPolicyParams = {
parameters: {
assignmentAlgorithm: ASSIGNMENT_ALGORITHM.MOOCLET_TS_CONFIGURABLE,
prior: { success: 1, failure: 1 },
prior: {
'1': { success: 1, failure: 1 },
'2': { success: 1, failure: 1 },
},
},
};

Expand Down Expand Up @@ -1667,7 +1670,16 @@ describe('#MoocletExperimentService', () => {
mockMoocletExperimentRef.id = 'ref-123';
mockMoocletExperimentRef.policyParametersId = 1;
mockMoocletExperimentRef.variableId = 2;
mockMoocletExperimentRef.versionConditionMaps = [];
mockMoocletExperimentRef.versionConditionMaps = [
{
moocletVersionId: 10,
experimentCondition: { conditionCode: 'control' } as any,
} as MoocletVersionConditionMap,
{
moocletVersionId: 20,
experimentCondition: { conditionCode: 'treatment' } as any,
} as MoocletVersionConditionMap,
];

const mockExperiment = {
id: 'exp-123',
Expand Down Expand Up @@ -1713,7 +1725,13 @@ describe('#MoocletExperimentService', () => {
updateSpy.mockResolvedValue(updatedExperiment);

jest.spyOn(moocletExperimentService as any, 'doRevertablePolicyParameterChange').mockResolvedValue({
parameters: mockExperiment.moocletPolicyParameters,
parameters: {
...mockExperiment.moocletPolicyParameters,
prior: {
'10': { success: 1, failure: 1 },
'20': { success: 1, failure: 1 },
},
},
});

const result = await (moocletExperimentService as any).handleEditMoocletTransaction(manager, params);
Expand Down
Loading