Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"build:app": "vite build",
"dev": "node scripts/dev-server.mjs",
"start": "node dist/cli.js serve",
"test": "tsx src/config.test.ts && tsx src/roots.test.ts && tsx src/skills.test.ts && tsx src/workspaces.test.ts && tsx src/review-checkpoints.test.ts",
"test": "tsx src/config.test.ts && tsx src/roots.test.ts && tsx src/skills.test.ts && tsx src/workspaces.test.ts && tsx src/oauth-store.test.ts && tsx src/review-checkpoints.test.ts",
"typecheck": "tsc -p tsconfig.json --noEmit"
},
"keywords": [],
Expand Down
7 changes: 5 additions & 2 deletions src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async function serve(): Promise<void> {

const { createServer } = await import("./server.js");
const config = loadConfig();
const { app } = createServer(config);
const { app, close } = createServer(config);
const httpServer = app.listen(config.port, config.host, () => {
console.log(`devspace listening on http://${config.host}:${config.port}/mcp`);
console.log(`public base url: ${config.publicBaseUrl}`);
Expand All @@ -192,7 +192,10 @@ async function serve(): Promise<void> {
});

const shutdown = () => {
httpServer.close(() => process.exit(0));
httpServer.close(() => {
close();
process.exit(0);
});
};
process.once("SIGINT", shutdown);
process.once("SIGTERM", shutdown);
Expand Down
21 changes: 21 additions & 0 deletions src/config.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,24 @@ assert.deepEqual(fileConfig.allowedHosts, [
"::1",
"devspace.example.com",
]);

const bomConfigDir = mkdtempSync(join(tmpdir(), "devspace-bom-config-test-"));
writeFileSync(
join(bomConfigDir, "config.json"),
`\uFEFF${JSON.stringify({
port: 8989,
allowedRoots: [process.cwd()],
publicBaseUrl: "https://bom.example.com",
})}`,
"utf8",
);
writeFileSync(
join(bomConfigDir, "auth.json"),
`\uFEFF${JSON.stringify({ ownerToken: "bom-owner-password-long-enough" })}`,
"utf8",
);

const bomConfig = loadConfig({ DEVSPACE_CONFIG_DIR: bomConfigDir });
assert.equal(bomConfig.port, 8989);
assert.equal(bomConfig.oauth.ownerToken, "bom-owner-password-long-enough");
assert.equal(bomConfig.publicBaseUrl, "https://bom.example.com");
39 changes: 38 additions & 1 deletion src/db/schema.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { index, primaryKey, sqliteTable, text } from "drizzle-orm/sqlite-core";
import { index, integer, primaryKey, sqliteTable, text } from "drizzle-orm/sqlite-core";

export const workspaceSessions = sqliteTable(
"workspace_sessions",
Expand Down Expand Up @@ -40,5 +40,42 @@ export const loadedAgentFiles = sqliteTable(

export type WorkspaceSessionRow = typeof workspaceSessions.$inferSelect;
export type NewWorkspaceSessionRow = typeof workspaceSessions.$inferInsert;
export const oauthClients = sqliteTable("oauth_clients", {
clientId: text("client_id").primaryKey(),
clientJson: text("client_json").notNull(),
createdAt: integer("created_at").notNull(),
});

export const oauthAuthorizationCodes = sqliteTable(
"oauth_authorization_codes",
{
codeHash: text("code_hash").primaryKey(),
clientId: text("client_id")
.notNull()
.references(() => oauthClients.clientId, { onDelete: "cascade" }),
paramsJson: text("params_json").notNull(),
expiresAtMs: integer("expires_at_ms").notNull(),
},
(table) => [index("oauth_authorization_codes_expiry_idx").on(table.expiresAtMs)],
);

export const oauthTokens = sqliteTable(
"oauth_tokens",
{
tokenHash: text("token_hash").notNull(),
tokenKind: text("token_kind").notNull(),
clientId: text("client_id")
.notNull()
.references(() => oauthClients.clientId, { onDelete: "cascade" }),
scopesJson: text("scopes_json").notNull(),
expiresAt: integer("expires_at").notNull(),
resource: text("resource"),
},
(table) => [
primaryKey({ columns: [table.tokenHash, table.tokenKind] }),
index("oauth_tokens_expiry_idx").on(table.expiresAt),
],
);

export type LoadedAgentFileRow = typeof loadedAgentFiles.$inferSelect;
export type NewLoadedAgentFileRow = typeof loadedAgentFiles.$inferInsert;
73 changes: 29 additions & 44 deletions src/oauth-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import type {
OAuthTokens,
} from "@modelcontextprotocol/sdk/shared/auth.js";
import { checkResourceAllowed, resourceUrlFromServerUrl } from "@modelcontextprotocol/sdk/shared/auth-utils.js";
import {
SqliteOAuthStore,
type AuthorizationCodeRecord,
} from "./oauth-store.js";

export interface OAuthConfig {
ownerToken: string;
Expand All @@ -19,28 +23,6 @@ export interface OAuthConfig {
allowedRedirectHosts: string[];
}

interface AuthorizationCodeRecord {
clientId: string;
params: AuthorizationParams;
expiresAtMs: number;
}

interface AccessTokenRecord {
token: string;
clientId: string;
scopes: string[];
expiresAt: number;
resource?: URL;
}

interface RefreshTokenRecord {
token: string;
clientId: string;
scopes: string[];
expiresAt: number;
resource?: URL;
}

const CODE_TTL_MS = 5 * 60 * 1000;

function randomToken(): string {
Expand Down Expand Up @@ -138,13 +120,14 @@ function redirectHostAllowed(redirectUri: string, allowedHosts: string[]): boole
return allowedHosts.includes(parsed.hostname);
}

export class InMemoryOAuthClientsStore implements OAuthRegisteredClientsStore {
private readonly clients = new Map<string, OAuthClientInformationFull>();

constructor(private readonly allowedRedirectHosts: string[]) {}
export class SqliteOAuthClientsStore implements OAuthRegisteredClientsStore {
constructor(
private readonly allowedRedirectHosts: string[],
private readonly store: SqliteOAuthStore,
) {}

getClient(clientId: string): OAuthClientInformationFull | undefined {
return this.clients.get(clientId);
return this.store.getClient(clientId);
}

registerClient(
Expand All @@ -163,24 +146,24 @@ export class InMemoryOAuthClientsStore implements OAuthRegisteredClientsStore {
grant_types: client.grant_types ?? ["authorization_code", "refresh_token"],
response_types: client.response_types ?? ["code"],
};
this.clients.set(registered.client_id, registered);
this.store.saveClient(registered);
return registered;
}
}

export class SingleUserOAuthProvider implements OAuthServerProvider {
readonly clientsStore: OAuthRegisteredClientsStore;
private readonly codes = new Map<string, AuthorizationCodeRecord>();
private readonly accessTokens = new Map<string, AccessTokenRecord>();
private readonly refreshTokens = new Map<string, RefreshTokenRecord>();
private readonly store: SqliteOAuthStore;
private readonly resourceServerUrl: URL;

constructor(
private readonly config: OAuthConfig,
resourceServerUrl: URL,
stateDir: string,
) {
this.resourceServerUrl = resourceUrlFromServerUrl(resourceServerUrl);
this.clientsStore = new InMemoryOAuthClientsStore(config.allowedRedirectHosts);
this.store = new SqliteOAuthStore(stateDir);
this.clientsStore = new SqliteOAuthClientsStore(config.allowedRedirectHosts, this.store);
}

async authorize(
Expand Down Expand Up @@ -224,7 +207,7 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {
}

const code = `code-${randomUUID()}`;
this.codes.set(code, {
this.store.saveAuthorizationCode(hashToken(code), {
clientId: client.client_id,
params,
expiresAtMs: Date.now() + CODE_TTL_MS,
Expand Down Expand Up @@ -259,7 +242,7 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {
throw new InvalidGrantError("Invalid resource");
}

this.codes.delete(authorizationCode);
this.store.deleteAuthorizationCode(hashToken(authorizationCode));
return this.issueTokens(client.client_id, record.params.scopes ?? this.config.scopes, record.params.resource);
}

Expand All @@ -269,7 +252,8 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {
scopes?: string[],
resource?: URL,
): Promise<OAuthTokens> {
const record = this.refreshTokens.get(hashToken(refreshToken));
const refreshTokenHash = hashToken(refreshToken);
const record = this.store.getRefreshToken(refreshTokenHash);
if (!record || record.clientId !== client.client_id || record.expiresAt < Math.floor(Date.now() / 1000)) {
throw new InvalidGrantError("Invalid refresh token");
}
Expand All @@ -282,12 +266,12 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {
throw new AccessDeniedError("Refresh token cannot grant requested scopes");
}

this.refreshTokens.delete(hashToken(refreshToken));
this.store.deleteRefreshToken(refreshTokenHash);
return this.issueTokens(client.client_id, requestedScopes, resource ?? record.resource);
}

async verifyAccessToken(token: string): Promise<AuthInfo> {
const record = this.accessTokens.get(hashToken(token));
const record = this.store.getAccessToken(hashToken(token));
if (!record || record.expiresAt < Math.floor(Date.now() / 1000)) {
throw new InvalidTokenError("Invalid or expired access token");
}
Expand All @@ -303,15 +287,18 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {

async revokeToken(_client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise<void> {
const hashed = hashToken(request.token);
this.accessTokens.delete(hashed);
this.refreshTokens.delete(hashed);
this.store.revokeToken(hashed);
}

close(): void {
this.store.close();
}

private validCodeRecord(
client: OAuthClientInformationFull,
authorizationCode: string,
): AuthorizationCodeRecord {
const record = this.codes.get(authorizationCode);
const record = this.store.getAuthorizationCode(hashToken(authorizationCode));
if (!record || record.clientId !== client.client_id || record.expiresAtMs < Date.now()) {
throw new InvalidGrantError("Invalid authorization code");
}
Expand All @@ -325,15 +312,13 @@ export class SingleUserOAuthProvider implements OAuthServerProvider {
const accessExpiresAt = now + this.config.accessTokenTtlSeconds;
const refreshExpiresAt = now + this.config.refreshTokenTtlSeconds;

this.accessTokens.set(hashToken(accessToken), {
token: accessToken,
this.store.saveAccessToken(hashToken(accessToken), {
clientId,
scopes,
expiresAt: accessExpiresAt,
resource,
});
this.refreshTokens.set(hashToken(refreshToken), {
token: refreshToken,
this.store.saveRefreshToken(hashToken(refreshToken), {
clientId,
scopes,
expiresAt: refreshExpiresAt,
Expand Down
Loading