Implement accumulate-then-act streaming for thinking models
This commit is contained in:
@@ -39,10 +39,11 @@ export interface ToolDefinition {
|
||||
}
|
||||
|
||||
export interface ChatChunk {
|
||||
type: "text" | "thinking" | "tool_call" | "done" | "error";
|
||||
type: "text" | "thinking" | "text_delta" | "thinking_delta" | "tool_calls" | "done" | "error";
|
||||
text?: string;
|
||||
toolCall?: ToolCall;
|
||||
toolCalls?: ToolCall[];
|
||||
error?: string;
|
||||
finishReason?: string;
|
||||
}
|
||||
|
||||
type GeminiPart = Record<string, unknown>;
|
||||
@@ -313,17 +314,26 @@ export async function* streamGeminiChat(opts: {
|
||||
if (part.text) {
|
||||
if (isPartThought(part as Record<string, unknown>)) {
|
||||
thoughts += part.text;
|
||||
yield { type: "thinking", text: part.text };
|
||||
yield { type: "thinking_delta", text: part.text };
|
||||
} else {
|
||||
text += part.text;
|
||||
yield { type: "text", text: part.text };
|
||||
yield { type: "text_delta", text: part.text };
|
||||
}
|
||||
}
|
||||
if (part.functionCall) {
|
||||
toolCalls.push(part.functionCall);
|
||||
toolCalls.push({
|
||||
id: `tc-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
name: part.functionCall.name,
|
||||
args: (part.functionCall.args as Record<string, unknown>) ?? {},
|
||||
thoughtSignature: (part as { thoughtSignature?: string }).thoughtSignature,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
yield { type: "tool_calls", toolCalls };
|
||||
}
|
||||
const durationMs = Date.now() - startTime;
|
||||
|
||||
logTrainingTelemetryDb({
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
* We normalize them to JSON Schema before sending.
|
||||
*/
|
||||
|
||||
import type { ChatMessage, ToolCall, ToolDefinition } from "./gemini-chat";
|
||||
import type { ChatMessage, ToolCall, ToolDefinition, ChatChunk } from "./gemini-chat";
|
||||
|
||||
const DEFAULT_CHAT_URL = "https://api.deepseek.com/chat/completions";
|
||||
|
||||
@@ -389,3 +389,149 @@ export async function callOpenAiCompatibleChat(opts: {
|
||||
|
||||
return { text, thoughts, toolCalls, finishReason };
|
||||
}
|
||||
|
||||
|
||||
export async function* streamOpenAiCompatibleChat(opts: {
|
||||
systemPrompt: string;
|
||||
messages: ChatMessage[];
|
||||
tools?: ToolDefinition[];
|
||||
temperature?: number;
|
||||
includeThoughts?: boolean;
|
||||
signal?: AbortSignal;
|
||||
}): AsyncGenerator<ChatChunk> {
|
||||
const apiKey = resolveApiKey();
|
||||
if (!apiKey) {
|
||||
yield {
|
||||
type: "error",
|
||||
error: "No API key: set DEEPSEEK_API_KEY or VIBN_OPENAI_COMPATIBLE_API_KEY for OpenAI-compatible chat.",
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
const url = resolveChatUrl();
|
||||
const model = resolveModel();
|
||||
const tools = toOpenAiTools(opts.tools);
|
||||
const oaiMessages = toOpenAiMessages(opts.systemPrompt, opts.messages);
|
||||
const body: Record<string, unknown> = {
|
||||
model,
|
||||
messages: oaiMessages,
|
||||
temperature: opts.temperature ?? 0.7,
|
||||
max_tokens: 8192,
|
||||
stream: true,
|
||||
};
|
||||
if (tools?.length) body.tools = tools;
|
||||
|
||||
let res: Response;
|
||||
try {
|
||||
res = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
signal: opts.signal,
|
||||
});
|
||||
} catch (e) {
|
||||
const aborted = opts.signal?.aborted || (e instanceof Error && e.name === "AbortError");
|
||||
yield {
|
||||
type: "error",
|
||||
error: aborted ? "aborted" : `Network error: ${e instanceof Error ? e.message : String(e)}`,
|
||||
};
|
||||
return;
|
||||
}
|
||||
|
||||
if (!res.ok) {
|
||||
const text = await res.text().catch(() => "");
|
||||
yield { type: "error", error: `Chat API error ${res.status}: ${text}` };
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = res.body?.getReader();
|
||||
if (!reader) {
|
||||
yield { type: "error", error: "No response body stream." };
|
||||
return;
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
let buffer = "";
|
||||
|
||||
// Accumulated tool calls
|
||||
const toolCallsAcc: Record<number, { id: string; name: string; argsStr: string }> = {};
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() ?? "";
|
||||
|
||||
for (const line of lines) {
|
||||
const tLine = line.trim();
|
||||
if (!tLine || !tLine.startsWith("data: ")) continue;
|
||||
const dataStr = tLine.slice(6);
|
||||
if (dataStr === "[DONE]") continue;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(dataStr);
|
||||
const delta = parsed.choices?.[0]?.delta;
|
||||
if (!delta) continue;
|
||||
|
||||
if (typeof delta.reasoning_content === "string" && delta.reasoning_content.length > 0) {
|
||||
yield { type: "thinking_delta", text: delta.reasoning_content };
|
||||
}
|
||||
if (typeof delta.content === "string" && delta.content.length > 0) {
|
||||
yield { type: "text_delta", text: delta.content };
|
||||
}
|
||||
|
||||
if (delta.tool_calls && Array.isArray(delta.tool_calls)) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const idx = tc.index;
|
||||
if (idx === undefined) continue;
|
||||
if (!toolCallsAcc[idx]) {
|
||||
toolCallsAcc[idx] = { id: "", name: "", argsStr: "" };
|
||||
}
|
||||
if (tc.id) toolCallsAcc[idx].id = tc.id;
|
||||
if (tc.function?.name) toolCallsAcc[idx].name += tc.function.name;
|
||||
if (tc.function?.arguments) toolCallsAcc[idx].argsStr += tc.function.arguments;
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// ignore unparseable chunks
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
const aborted = opts.signal?.aborted || (e instanceof Error && e.name === "AbortError");
|
||||
yield {
|
||||
type: "error",
|
||||
error: aborted ? "aborted" : `Stream read error: ${e instanceof Error ? e.message : String(e)}`,
|
||||
};
|
||||
}
|
||||
|
||||
const toolCalls: ToolCall[] = [];
|
||||
for (const idx of Object.keys(toolCallsAcc).sort((a,b) => Number(a) - Number(b))) {
|
||||
const acc = toolCallsAcc[Number(idx)];
|
||||
let args = {};
|
||||
try {
|
||||
if (acc.argsStr) args = JSON.parse(acc.argsStr);
|
||||
} catch {
|
||||
// ignore bad json
|
||||
}
|
||||
if (acc.name) {
|
||||
toolCalls.push({
|
||||
id: acc.id || `tc-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
name: acc.name,
|
||||
args
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
yield { type: "tool_calls", toolCalls };
|
||||
}
|
||||
|
||||
yield { type: "done" };
|
||||
}
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
*/
|
||||
|
||||
import type { ChatMessage, ToolDefinition } from "./gemini-chat";
|
||||
import { callGeminiChat } from "./gemini-chat";
|
||||
import { callOpenAiCompatibleChat } from "./openai-compatible-chat";
|
||||
import { callGeminiChat, streamGeminiChat } from "./gemini-chat";
|
||||
import { callOpenAiCompatibleChat, streamOpenAiCompatibleChat } from "./openai-compatible-chat";
|
||||
|
||||
export type VibnChatCallOpts = {
|
||||
systemPrompt: string;
|
||||
@@ -33,3 +33,13 @@ export async function callVibnChat(opts: VibnChatCallOpts) {
|
||||
}
|
||||
return callGeminiChat(opts);
|
||||
}
|
||||
|
||||
|
||||
export async function* streamVibnChat(opts: VibnChatCallOpts) {
|
||||
const p = (process.env.VIBN_CHAT_PROVIDER || "gemini").toLowerCase().trim();
|
||||
if (p === "deepseek" || p === "openai_compatible") {
|
||||
yield* streamOpenAiCompatibleChat(opts);
|
||||
return;
|
||||
}
|
||||
yield* streamGeminiChat(opts);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user