Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions internal/providers/chat_stream_normalize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package providers

import (
"bufio"
"bytes"
"encoding/json"
"io"
)

// chatDonePayload terminates a chat completions SSE stream.
var chatDonePayload = []byte("data: [DONE]\n\n")

// peekForNonSSE inspects up to this many leading bytes to classify the upstream
// response. SSE payloads begin with a field name (data:, event:, id:, retry:) or
// a ':' comment; a buffered JSON completion begins with '{'. 512 bytes comfortably
// clears any leading whitespace or comment lines without buffering real streams.
const peekForNonSSE = 512

// EnsureChatCompletionSSE normalizes a chat completions stream so the client
// always receives well-formed Server-Sent Events terminated by data: [DONE].
//
// Some OpenAI-compatible upstreams ignore stream:true and reply with a single
// buffered application/json completion (no data: framing, no [DONE]). Forwarding
// that verbatim under a text/event-stream content type leaves SSE clients waiting
// forever for an end-of-stream marker that never arrives. When the upstream body
// is detected as a buffered JSON object it is re-emitted as one SSE chunk plus a
// terminal [DONE]; genuine SSE streams pass through untouched with no buffering.
func EnsureChatCompletionSSE(stream io.ReadCloser) io.ReadCloser {
if stream == nil {
return nil
}

reader := bufio.NewReaderSize(stream, peekForNonSSE)
if firstNonSpaceByte(reader, peekForNonSSE) != '{' {
// Genuine SSE (or empty): stream through unchanged, no buffering.
return &bufferedReadCloser{Reader: reader, closer: stream}
}

// The '{' that classified this body is already buffered, so io.ReadAll
// always returns at least that byte; a mid-read failure still yields the
// partial bytes. Either way bufferedCompletionToSSE forwards what arrived
// (raw when the JSON is truncated) and appends [DONE], so generated content
// is never dropped and the client always receives a terminator.
body, _ := io.ReadAll(reader)
_ = stream.Close() //nolint:errcheck
return io.NopCloser(bytes.NewReader(bufferedCompletionToSSE(body)))
}

// firstNonSpaceByte reports the first non-whitespace byte buffered by reader,
// peeking one byte further at a time so a genuine SSE stream is classified from
// its first token without blocking until a full buffer fills. It never consumes
// input, so a passed-through stream is forwarded byte-for-byte. Returns 0 when
// the stream ends, errors, or yields only whitespace within max bytes.
func firstNonSpaceByte(r *bufio.Reader, max int) byte {
for i := 1; i <= max; i++ {
prefix, err := r.Peek(i)
if len(prefix) < i {
_ = err // EOF or error before any non-space byte was found
return 0
}
switch b := prefix[i-1]; b {
case ' ', '\t', '\r', '\n':
continue
default:
return b
}
}
return 0
}

// bufferedCompletionToSSE wraps a buffered chat completion JSON object as a
// single SSE chunk followed by the terminal [DONE] marker. The object field is
// rewritten to chat.completion.chunk and each choice's message is moved to delta
// so OpenAI SSE clients parse it as a streaming chunk. If the body does not parse
// as a JSON object it is forwarded as-is so no data is lost, still followed by
// [DONE] so the client stops waiting.
func bufferedCompletionToSSE(body []byte) []byte {
payload := body
var obj map[string]any
if err := json.Unmarshal(body, &obj); err == nil {
obj["object"] = "chat.completion.chunk"
if choices, ok := obj["choices"].([]any); ok {
for _, c := range choices {
choice, ok := c.(map[string]any)
if !ok {
continue
}
if msg, ok := choice["message"]; ok {
choice["delta"] = msg
delete(choice, "message")
}
}
}
if encoded, err := json.Marshal(obj); err == nil {
payload = encoded
}
}

var out bytes.Buffer
out.WriteString("data: ")
out.Write(payload)
out.WriteString("\n\n")
out.Write(chatDonePayload)
return out.Bytes()
}

// bufferedReadCloser pairs a buffered reader with the underlying stream's Close.
type bufferedReadCloser struct {
*bufio.Reader
closer io.Closer
}

func (b *bufferedReadCloser) Close() error { return b.closer.Close() }
111 changes: 111 additions & 0 deletions internal/providers/chat_stream_normalize_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package providers

import (
"io"
"strings"
"testing"
)

// errAfterReadCloser yields its data once, then fails — simulating a connection
// that drops mid-body after some bytes have arrived.
type errAfterReadCloser struct {
data []byte
err error
done bool
}

func (r *errAfterReadCloser) Read(p []byte) (int, error) {
if r.done {
return 0, r.err
}
n := copy(p, r.data)
r.data = r.data[n:]
if len(r.data) == 0 {
r.done = true
}
return n, nil
}

func (r *errAfterReadCloser) Close() error { return nil }

func TestEnsureChatCompletionSSE_ConvertsBufferedJSON(t *testing.T) {
// Upstream ignored stream:true and returned a buffered, non-SSE completion.
body := `{"id":"x","object":"chat.completion","choices":[{"finish_reason":"stop","message":{"role":"assistant","content":"Hi there"}}]}`
stream := io.NopCloser(strings.NewReader(body))

got, err := io.ReadAll(EnsureChatCompletionSSE(stream))
if err != nil {
t.Fatalf("read stream: %v", err)
}

out := string(got)
if !strings.HasPrefix(out, "data: {") {
t.Fatalf("expected SSE data framing, got %q", out)
}
if !strings.HasSuffix(out, "data: [DONE]\n\n") {
t.Fatalf("expected terminal done marker, got %q", out)
}
if !strings.Contains(out, `"object":"chat.completion.chunk"`) {
t.Fatalf("expected object rewritten to chunk, got %q", out)
}
if !strings.Contains(out, `"delta":`) || strings.Contains(out, `"message":`) {
t.Fatalf("expected message rewritten to delta, got %q", out)
}
}

func TestEnsureChatCompletionSSE_PassesThroughRealSSE(t *testing.T) {
chunks := [][]byte{
[]byte("data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n"),
[]byte("data: {\"choices\":[{\"delta\":{\"content\":\" there\"}}]}\n\n"),
[]byte("data: [DONE]\n\n"),
}
original := strings.Join([]string{string(chunks[0]), string(chunks[1]), string(chunks[2])}, "")
stream := &chunkedReadCloser{chunks: chunks}

got, err := io.ReadAll(EnsureChatCompletionSSE(stream))
if err != nil {
t.Fatalf("read stream: %v", err)
}
if string(got) != original {
t.Fatalf("expected genuine SSE passed through unchanged.\n got: %q\nwant: %q", string(got), original)
}
}

func TestEnsureChatCompletionSSE_PassesThroughSSEWithLeadingComment(t *testing.T) {
// Providers like OpenRouter emit a leading ": ... PROCESSING" comment line.
body := ": OPENROUTER PROCESSING\n\ndata: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n"
stream := io.NopCloser(strings.NewReader(body))

got, err := io.ReadAll(EnsureChatCompletionSSE(stream))
if err != nil {
t.Fatalf("read stream: %v", err)
}
if string(got) != body {
t.Fatalf("expected comment-prefixed SSE unchanged, got %q", string(got))
}
}

func TestEnsureChatCompletionSSE_PreservesPartialBodyOnReadError(t *testing.T) {
// Upstream began a buffered JSON body, then the connection dropped mid-read.
// The partial content must still reach the client, followed by [DONE].
partial := `{"id":"x","choices":[{"message":{"content":"Hel`
stream := &errAfterReadCloser{data: []byte(partial), err: io.ErrUnexpectedEOF}

got, err := io.ReadAll(EnsureChatCompletionSSE(stream))
if err != nil {
t.Fatalf("read stream: %v", err)
}
out := string(got)
if !strings.Contains(out, "Hel") {
t.Fatalf("expected partial content preserved, got %q", out)
}
if !strings.HasSuffix(out, "data: [DONE]\n\n") {
t.Fatalf("expected terminal done marker, got %q", out)
}
}

func TestEnsureChatCompletionSSE_NilStream(t *testing.T) {
if EnsureChatCompletionSSE(nil) != nil {
t.Fatal("expected nil for nil stream")
}
}
6 changes: 5 additions & 1 deletion internal/providers/deepseek/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,15 @@ func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatReque
if err != nil {
return nil, err
}
return p.client.DoStream(ctx, llmclient.Request{
stream, err := p.client.DoStream(ctx, llmclient.Request{
Method: http.MethodPost,
Endpoint: "/chat/completions",
Body: body,
})
if err != nil {
return nil, err
}
return providers.EnsureChatCompletionSSE(stream), nil
}

// ListModels retrieves the list of available models from DeepSeek.
Expand Down
6 changes: 5 additions & 1 deletion internal/providers/openai/compatible_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,15 @@ func (p *CompatibleProvider) StreamChatCompletion(ctx context.Context, req *core
if err != nil {
return nil, err
}
return p.client.DoStream(ctx, p.prepareRequest(llmclient.Request{
stream, err := p.client.DoStream(ctx, p.prepareRequest(llmclient.Request{
Method: http.MethodPost,
Endpoint: "/chat/completions",
Body: body,
}))
if err != nil {
return nil, err
}
return providers.EnsureChatCompletionSSE(stream), nil
Comment thread
greptile-apps[bot] marked this conversation as resolved.
}

func (p *CompatibleProvider) ListModels(ctx context.Context) (*core.ModelsResponse, error) {
Expand Down
9 changes: 8 additions & 1 deletion internal/providers/xai/xai.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,19 @@ func (p *Provider) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*

// StreamChatCompletion returns a raw response body for streaming (caller must close)
func (p *Provider) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) {
return p.client.DoStream(ctx, llmclient.Request{
if req == nil {
return nil, core.NewInvalidRequestError("chat request is required", nil)
}
stream, err := p.client.DoStream(ctx, llmclient.Request{
Method: http.MethodPost,
Endpoint: "/chat/completions",
Body: req.WithStreaming(),
Headers: xGrokConversationHeaders(ctx, req),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
})
if err != nil {
return nil, err
}
return providers.EnsureChatCompletionSSE(stream), nil
}

// ListModels retrieves the list of available models from xAI
Expand Down