// chatDonePayload terminates a chat completions SSE stream.
var chatDonePayload = []byte("data: [DONE]\n\n")

// peekForNonSSE inspects up to this many leading bytes to classify the upstream
// response. SSE payloads begin with a field name (data:, event:, id:, retry:) or
// a ':' comment; a buffered JSON completion begins with '{'. 512 bytes comfortably
// clears any leading whitespace or comment lines without buffering real streams.
const peekForNonSSE = 512

// EnsureChatCompletionSSE normalizes a chat completions stream so the client
// always receives well-formed Server-Sent Events terminated by data: [DONE].
//
// Some OpenAI-compatible upstreams ignore stream:true and reply with a single
// buffered application/json completion (no data: framing, no [DONE]). Forwarding
// that verbatim under a text/event-stream content type leaves SSE clients waiting
// forever for an end-of-stream marker that never arrives. When the first
// non-whitespace byte of the body is '{' the whole body is read, re-emitted as
// one SSE event plus a terminal [DONE], and the upstream stream is closed before
// returning. Anything else (genuine SSE, or an empty body) is passed through
// byte-for-byte; closing the returned value closes the upstream stream.
//
// A nil stream yields nil. The call blocks until the first bytes of the body
// arrive, and, for a buffered JSON body, until the body has been read in full.
func EnsureChatCompletionSSE(stream io.ReadCloser) io.ReadCloser {
	if stream == nil {
		return nil
	}

	reader := bufio.NewReaderSize(stream, peekForNonSSE)
	if firstNonSpaceByte(reader, peekForNonSSE) != '{' {
		// Genuine SSE (or empty): stream through unchanged. The peeked bytes
		// stay in reader, so nothing is lost.
		return &bufferedReadCloser{Reader: reader, closer: stream}
	}

	// The '{' that classified this body is already buffered, so io.ReadAll
	// always returns at least that byte; a mid-read failure still yields the
	// partial bytes. The read error is deliberately dropped: whatever arrived
	// is forwarded (raw when the JSON is truncated) and followed by [DONE], so
	// generated content is never discarded and the client always receives a
	// terminator.
	// NOTE(review): the body is read without a size cap; consider bounding it
	// if upstreams cannot be trusted to return reasonably sized completions.
	body, _ := io.ReadAll(reader)
	_ = stream.Close() //nolint:errcheck // body already consumed; nothing to report to the caller
	return io.NopCloser(bytes.NewReader(bufferedCompletionToSSE(body)))
}

// firstNonSpaceByte reports the first non-whitespace byte available from r,
// peeking one byte further at a time so a genuine SSE stream is classified from
// its first token without waiting for a full buffer. It never consumes input,
// so a passed-through stream is forwarded byte-for-byte. It returns 0 when the
// stream ends, errors, or yields only whitespace within limit bytes.
func firstNonSpaceByte(r *bufio.Reader, limit int) byte {
	for i := 1; i <= limit; i++ {
		// A short prefix already signals EOF, a read error, or a full buffer,
		// so the error value itself carries no extra information here.
		prefix, _ := r.Peek(i)
		if len(prefix) < i {
			return 0
		}
		switch b := prefix[i-1]; b {
		case ' ', '\t', '\r', '\n':
			continue
		default:
			return b
		}
	}
	return 0
}

// bufferedCompletionToSSE wraps a buffered chat completion JSON object as a
// single SSE event followed by the terminal [DONE] marker. The object field is
// rewritten to chat.completion.chunk and each choice's message is moved to delta
// so OpenAI SSE clients parse it as a streaming chunk. If the body does not parse
// as a JSON object it is forwarded as-is so no data is lost, still followed by
// [DONE] so the client stops waiting.
func bufferedCompletionToSSE(body []byte) []byte {
	payload := body
	if obj, ok := decodeJSONObject(body); ok {
		obj["object"] = "chat.completion.chunk"
		if choices, ok := obj["choices"].([]any); ok {
			for _, c := range choices {
				choice, ok := c.(map[string]any)
				if !ok {
					continue
				}
				if msg, ok := choice["message"]; ok {
					choice["delta"] = msg
					delete(choice, "message")
				}
			}
		}
		// On a marshal failure fall back to the raw body rather than drop it.
		if encoded, err := json.Marshal(obj); err == nil {
			payload = encoded
		}
	}

	var out bytes.Buffer
	out.Grow(len(payload) + len(chatDonePayload) + len("data: \n\n"))
	writeSSEData(&out, payload)
	out.Write(chatDonePayload)
	return out.Bytes()
}

// decodeJSONObject decodes body as exactly one JSON object. It reports false
// for malformed or truncated input, trailing content, non-object values, and
// JSON null (which would otherwise decode into a nil map that cannot be
// written to). Numbers are kept as json.Number so large integers survive the
// decode/encode round trip without float64 rounding.
func decodeJSONObject(body []byte) (map[string]any, bool) {
	if !json.Valid(body) {
		return nil, false
	}
	dec := json.NewDecoder(bytes.NewReader(body))
	dec.UseNumber()
	var obj map[string]any
	if err := dec.Decode(&obj); err != nil || obj == nil {
		return nil, false
	}
	return obj, true
}

// writeSSEData appends payload to out as one SSE event. Every line of payload
// gets its own "data: " field, because a bare newline inside a single data
// field would end the field (and a blank line would end the event) early; SSE
// clients rejoin consecutive data fields with '\n'. CR and CRLF are treated as
// line breaks too, since SSE parsers split on them as well. A single-line
// payload therefore produces exactly "data: <payload>\n\n".
func writeSSEData(out *bytes.Buffer, payload []byte) {
	normalized := bytes.ReplaceAll(payload, []byte("\r\n"), []byte("\n"))
	normalized = bytes.ReplaceAll(normalized, []byte("\r"), []byte("\n"))
	for _, line := range bytes.Split(normalized, []byte("\n")) {
		out.WriteString("data: ")
		out.Write(line)
		out.WriteByte('\n')
	}
	out.WriteByte('\n')
}

// bufferedReadCloser pairs a buffered reader with the underlying stream's Close,
// so bytes already peeked into the buffer are still delivered to the caller.
type bufferedReadCloser struct {
	*bufio.Reader
	closer io.Closer
}

// Close closes the underlying upstream stream.
func (b *bufferedReadCloser) Close() error { return b.closer.Close() }
+ body := `{"id":"x","object":"chat.completion","choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"Hi there"}}]}` + stream := io.NopCloser(strings.NewReader(body)) + + got, err := io.ReadAll(EnsureChatCompletionSSE(stream)) + if err != nil { + t.Fatalf("read stream: %v", err) + } + + out := string(got) + if !strings.HasPrefix(out, "data: {") { + t.Fatalf("expected SSE data framing, got %q", out) + } + if !strings.HasSuffix(out, "data: [DONE]\n\n") { + t.Fatalf("expected terminal done marker, got %q", out) + } + if !strings.Contains(out, `"object":"chat.completion.chunk"`) { + t.Fatalf("expected object rewritten to chunk, got %q", out) + } + if !strings.Contains(out, `"delta":`) || strings.Contains(out, `"message":`) { + t.Fatalf("expected message rewritten to delta, got %q", out) + } +} + +func TestEnsureChatCompletionSSE_PassesThroughRealSSE(t *testing.T) { + chunks := [][]byte{ + []byte("data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n"), + []byte("data: {\"choices\":[{\"delta\":{\"content\":\" there\"}}]}\n\n"), + []byte("data: [DONE]\n\n"), + } + original := strings.Join([]string{string(chunks[0]), string(chunks[1]), string(chunks[2])}, "") + stream := &chunkedReadCloser{chunks: chunks} + + got, err := io.ReadAll(EnsureChatCompletionSSE(stream)) + if err != nil { + t.Fatalf("read stream: %v", err) + } + if string(got) != original { + t.Fatalf("expected genuine SSE passed through unchanged.\n got: %q\nwant: %q", string(got), original) + } +} + +func TestEnsureChatCompletionSSE_PassesThroughSSEWithLeadingComment(t *testing.T) { + // Providers like OpenRouter emit a leading ": ... PROCESSING" comment line. 
+ body := ": OPENROUTER PROCESSING\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n" + stream := io.NopCloser(strings.NewReader(body)) + + got, err := io.ReadAll(EnsureChatCompletionSSE(stream)) + if err != nil { + t.Fatalf("read stream: %v", err) + } + if string(got) != body { + t.Fatalf("expected comment-prefixed SSE unchanged, got %q", string(got)) + } +} + +func TestEnsureChatCompletionSSE_PreservesPartialBodyOnReadError(t *testing.T) { + // Upstream began a buffered JSON body, then the connection dropped mid-read. + // The partial content must still reach the client, followed by [DONE]. + partial := `{"id":"x","choices":[{"message":{"content":"Hel` + stream := &errAfterReadCloser{data: []byte(partial), err: io.ErrUnexpectedEOF} + + got, err := io.ReadAll(EnsureChatCompletionSSE(stream)) + if err != nil { + t.Fatalf("read stream: %v", err) + } + out := string(got) + if !strings.Contains(out, "Hel") { + t.Fatalf("expected partial content preserved, got %q", out) + } + if !strings.HasSuffix(out, "data: [DONE]\n\n") { + t.Fatalf("expected terminal done marker, got %q", out) + } +} + +func TestEnsureChatCompletionSSE_NilStream(t *testing.T) { + if EnsureChatCompletionSSE(nil) != nil { + t.Fatal("expected nil for nil stream") + } +} diff --git a/internal/providers/deepseek/deepseek.go b/internal/providers/deepseek/deepseek.go index 0557f74f..b4d2aba0 100644 --- a/internal/providers/deepseek/deepseek.go +++ b/internal/providers/deepseek/deepseek.go @@ -149,11 +149,15 @@ func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatReque if err != nil { return nil, err } - return p.client.DoStream(ctx, llmclient.Request{ + stream, err := p.client.DoStream(ctx, llmclient.Request{ Method: http.MethodPost, Endpoint: "/chat/completions", Body: body, }) + if err != nil { + return nil, err + } + return providers.EnsureChatCompletionSSE(stream), nil } // ListModels retrieves the list of available models from DeepSeek. 
diff --git a/internal/providers/openai/compatible_provider.go b/internal/providers/openai/compatible_provider.go index e017ddf1..0aee5bfc 100644 --- a/internal/providers/openai/compatible_provider.go +++ b/internal/providers/openai/compatible_provider.go @@ -125,11 +125,15 @@ func (p *CompatibleProvider) StreamChatCompletion(ctx context.Context, req *core if err != nil { return nil, err } - return p.client.DoStream(ctx, p.prepareRequest(llmclient.Request{ + stream, err := p.client.DoStream(ctx, p.prepareRequest(llmclient.Request{ Method: http.MethodPost, Endpoint: "/chat/completions", Body: body, })) + if err != nil { + return nil, err + } + return providers.EnsureChatCompletionSSE(stream), nil } func (p *CompatibleProvider) ListModels(ctx context.Context) (*core.ModelsResponse, error) { diff --git a/internal/providers/xai/xai.go b/internal/providers/xai/xai.go index ac6e2e71..37e1ed40 100644 --- a/internal/providers/xai/xai.go +++ b/internal/providers/xai/xai.go @@ -184,12 +184,19 @@ func (p *Provider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (* // StreamChatCompletion returns a raw response body for streaming (caller must close) func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) { - return p.client.DoStream(ctx, llmclient.Request{ + if req == nil { + return nil, core.NewInvalidRequestError("chat request is required", nil) + } + stream, err := p.client.DoStream(ctx, llmclient.Request{ Method: http.MethodPost, Endpoint: "/chat/completions", Body: req.WithStreaming(), Headers: xGrokConversationHeaders(ctx, req), }) + if err != nil { + return nil, err + } + return providers.EnsureChatCompletionSSE(stream), nil } // ListModels retrieves the list of available models from xAI