deploy: current vibn theia state
Made-with: Cursor
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { FrontendApplicationContribution } from '@theia/core/lib/browser';
|
||||
import { inject, injectable } from '@theia/core/shared/inversify';
|
||||
import { OpenAiLanguageModelsManager, OpenAiModelDescription, OPENAI_PROVIDER_ID } from '../common';
|
||||
import { API_KEY_PREF, CUSTOM_ENDPOINTS_PREF, MODELS_PREF, USE_RESPONSE_API_PREF } from '../common/openai-preferences';
|
||||
import { AICorePreferences, PREFERENCE_NAME_MAX_RETRIES } from '@theia/ai-core/lib/common/ai-core-preferences';
|
||||
import { PreferenceService } from '@theia/core';
|
||||
|
||||
@injectable()
|
||||
export class OpenAiFrontendApplicationContribution implements FrontendApplicationContribution {
|
||||
|
||||
@inject(PreferenceService)
|
||||
protected preferenceService: PreferenceService;
|
||||
|
||||
@inject(OpenAiLanguageModelsManager)
|
||||
protected manager: OpenAiLanguageModelsManager;
|
||||
|
||||
@inject(AICorePreferences)
|
||||
protected aiCorePreferences: AICorePreferences;
|
||||
|
||||
protected prevModels: string[] = [];
|
||||
protected prevCustomModels: Partial<OpenAiModelDescription>[] = [];
|
||||
|
||||
onStart(): void {
|
||||
this.preferenceService.ready.then(() => {
|
||||
const apiKey = this.preferenceService.get<string>(API_KEY_PREF, undefined);
|
||||
this.manager.setApiKey(apiKey);
|
||||
|
||||
const proxyUri = this.preferenceService.get<string>('http.proxy', undefined);
|
||||
this.manager.setProxyUrl(proxyUri);
|
||||
|
||||
const models = this.preferenceService.get<string[]>(MODELS_PREF, []);
|
||||
this.manager.createOrUpdateLanguageModels(...models.map(modelId => this.createOpenAIModelDescription(modelId)));
|
||||
this.prevModels = [...models];
|
||||
|
||||
const customModels = this.preferenceService.get<Partial<OpenAiModelDescription>[]>(CUSTOM_ENDPOINTS_PREF, []);
|
||||
this.manager.createOrUpdateLanguageModels(...this.createCustomModelDescriptionsFromPreferences(customModels));
|
||||
this.prevCustomModels = [...customModels];
|
||||
|
||||
this.preferenceService.onPreferenceChanged(event => {
|
||||
if (event.preferenceName === API_KEY_PREF) {
|
||||
this.manager.setApiKey(this.preferenceService.get<string>(API_KEY_PREF, undefined));
|
||||
this.updateAllModels();
|
||||
} else if (event.preferenceName === MODELS_PREF) {
|
||||
this.handleModelChanges(this.preferenceService.get<string[]>(MODELS_PREF, []));
|
||||
} else if (event.preferenceName === CUSTOM_ENDPOINTS_PREF) {
|
||||
this.handleCustomModelChanges(this.preferenceService.get<Partial<OpenAiModelDescription>[]>(CUSTOM_ENDPOINTS_PREF, []));
|
||||
} else if (event.preferenceName === USE_RESPONSE_API_PREF) {
|
||||
this.updateAllModels();
|
||||
} else if (event.preferenceName === 'http.proxy') {
|
||||
this.manager.setProxyUrl(this.preferenceService.get<string>('http.proxy', undefined));
|
||||
}
|
||||
});
|
||||
|
||||
this.aiCorePreferences.onPreferenceChanged(event => {
|
||||
if (event.preferenceName === PREFERENCE_NAME_MAX_RETRIES) {
|
||||
this.updateAllModels();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
protected handleModelChanges(newModels: string[]): void {
|
||||
const oldModels = new Set(this.prevModels);
|
||||
const updatedModels = new Set(newModels);
|
||||
|
||||
const modelsToRemove = [...oldModels].filter(model => !updatedModels.has(model));
|
||||
const modelsToAdd = [...updatedModels].filter(model => !oldModels.has(model));
|
||||
|
||||
this.manager.removeLanguageModels(...modelsToRemove.map(model => `openai/${model}`));
|
||||
this.manager.createOrUpdateLanguageModels(...modelsToAdd.map(modelId => this.createOpenAIModelDescription(modelId)));
|
||||
this.prevModels = newModels;
|
||||
}
|
||||
|
||||
protected handleCustomModelChanges(newCustomModels: Partial<OpenAiModelDescription>[]): void {
|
||||
const oldModels = this.createCustomModelDescriptionsFromPreferences(this.prevCustomModels);
|
||||
const newModels = this.createCustomModelDescriptionsFromPreferences(newCustomModels);
|
||||
|
||||
const modelsToRemove = oldModels.filter(model => !newModels.some(newModel => newModel.id === model.id));
|
||||
const modelsToAddOrUpdate = newModels.filter(newModel =>
|
||||
!oldModels.some(model =>
|
||||
model.id === newModel.id &&
|
||||
model.model === newModel.model &&
|
||||
model.url === newModel.url &&
|
||||
model.deployment === newModel.deployment &&
|
||||
model.apiKey === newModel.apiKey &&
|
||||
model.apiVersion === newModel.apiVersion &&
|
||||
model.developerMessageSettings === newModel.developerMessageSettings &&
|
||||
model.supportsStructuredOutput === newModel.supportsStructuredOutput &&
|
||||
model.enableStreaming === newModel.enableStreaming &&
|
||||
model.useResponseApi === newModel.useResponseApi));
|
||||
|
||||
this.manager.removeLanguageModels(...modelsToRemove.map(model => model.id));
|
||||
this.manager.createOrUpdateLanguageModels(...modelsToAddOrUpdate);
|
||||
this.prevCustomModels = [...newCustomModels];
|
||||
}
|
||||
|
||||
protected updateAllModels(): void {
|
||||
const models = this.preferenceService.get<string[]>(MODELS_PREF, []);
|
||||
this.manager.createOrUpdateLanguageModels(...models.map(modelId => this.createOpenAIModelDescription(modelId)));
|
||||
|
||||
const customModels = this.preferenceService.get<Partial<OpenAiModelDescription>[]>(CUSTOM_ENDPOINTS_PREF, []);
|
||||
this.manager.createOrUpdateLanguageModels(...this.createCustomModelDescriptionsFromPreferences(customModels));
|
||||
}
|
||||
|
||||
protected createOpenAIModelDescription(modelId: string): OpenAiModelDescription {
|
||||
const id = `${OPENAI_PROVIDER_ID}/${modelId}`;
|
||||
const maxRetries = this.aiCorePreferences.get(PREFERENCE_NAME_MAX_RETRIES) ?? 3;
|
||||
const useResponseApi = this.preferenceService.get<boolean>(USE_RESPONSE_API_PREF, false);
|
||||
return {
|
||||
id: id,
|
||||
model: modelId,
|
||||
apiKey: true,
|
||||
apiVersion: true,
|
||||
developerMessageSettings: openAIModelsNotSupportingDeveloperMessages.includes(modelId) ? 'user' : 'developer',
|
||||
enableStreaming: !openAIModelsWithDisabledStreaming.includes(modelId),
|
||||
supportsStructuredOutput: !openAIModelsWithoutStructuredOutput.includes(modelId),
|
||||
maxRetries: maxRetries,
|
||||
useResponseApi: useResponseApi
|
||||
};
|
||||
}
|
||||
|
||||
protected createCustomModelDescriptionsFromPreferences(
|
||||
preferences: Partial<OpenAiModelDescription>[]
|
||||
): OpenAiModelDescription[] {
|
||||
const maxRetries = this.aiCorePreferences.get(PREFERENCE_NAME_MAX_RETRIES) ?? 3;
|
||||
return preferences.reduce((acc, pref) => {
|
||||
if (!pref.model || !pref.url || typeof pref.model !== 'string' || typeof pref.url !== 'string') {
|
||||
return acc;
|
||||
}
|
||||
|
||||
return [
|
||||
...acc,
|
||||
{
|
||||
id: pref.id && typeof pref.id === 'string' ? pref.id : pref.model,
|
||||
model: pref.model,
|
||||
url: pref.url,
|
||||
deployment: typeof pref.deployment === 'string' && pref.deployment ? pref.deployment : undefined,
|
||||
apiKey: typeof pref.apiKey === 'string' || pref.apiKey === true ? pref.apiKey : undefined,
|
||||
apiVersion: typeof pref.apiVersion === 'string' || pref.apiVersion === true ? pref.apiVersion : undefined,
|
||||
developerMessageSettings: pref.developerMessageSettings ?? 'developer',
|
||||
supportsStructuredOutput: pref.supportsStructuredOutput ?? true,
|
||||
enableStreaming: pref.enableStreaming ?? true,
|
||||
maxRetries: pref.maxRetries ?? maxRetries,
|
||||
useResponseApi: pref.useResponseApi ?? false
|
||||
}
|
||||
];
|
||||
}, []);
|
||||
}
|
||||
}
|
||||
|
||||
const openAIModelsWithDisabledStreaming: string[] = [];
|
||||
const openAIModelsNotSupportingDeveloperMessages = ['o1-preview', 'o1-mini'];
|
||||
const openAIModelsWithoutStructuredOutput = ['o1-preview', 'gpt-4-turbo', 'gpt-4', 'gpt-3.5-turbo', 'o1-mini', 'gpt-4o-2024-05-13'];
|
||||
32
packages/ai-openai/src/browser/openai-frontend-module.ts
Normal file
32
packages/ai-openai/src/browser/openai-frontend-module.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { ContainerModule } from '@theia/core/shared/inversify';
|
||||
import { OpenAiPreferencesSchema } from '../common/openai-preferences';
|
||||
import { FrontendApplicationContribution, RemoteConnectionProvider, ServiceConnectionProvider } from '@theia/core/lib/browser';
|
||||
import { OpenAiFrontendApplicationContribution } from './openai-frontend-application-contribution';
|
||||
import { OPENAI_LANGUAGE_MODELS_MANAGER_PATH, OpenAiLanguageModelsManager } from '../common';
|
||||
import { PreferenceContribution } from '@theia/core';
|
||||
|
||||
export default new ContainerModule(bind => {
    // Contribute the OpenAI preference schema so the preference system knows these keys.
    bind(PreferenceContribution).toConstantValue({ schema: OpenAiPreferencesSchema });
    bind(OpenAiFrontendApplicationContribution).toSelf().inSingletonScope();
    bind(FrontendApplicationContribution).toService(OpenAiFrontendApplicationContribution);
    // The language models manager is implemented in the backend; the frontend talks
    // to it through an RPC proxy created on the shared service path.
    bind(OpenAiLanguageModelsManager).toDynamicValue(ctx => {
        const provider = ctx.container.get<ServiceConnectionProvider>(RemoteConnectionProvider);
        return provider.createProxy<OpenAiLanguageModelsManager>(OPENAI_LANGUAGE_MODELS_MANAGER_PATH);
    }).inSingletonScope();
});
|
||||
16
packages/ai-openai/src/common/index.ts
Normal file
16
packages/ai-openai/src/common/index.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
export * from './openai-language-models-manager';
|
||||
@@ -0,0 +1,79 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
// RPC service path shared between frontend proxy and backend connection handler.
export const OPENAI_LANGUAGE_MODELS_MANAGER_PATH = '/services/open-ai/language-model-manager';
// DI identifier for the OpenAiLanguageModelsManager service.
export const OpenAiLanguageModelsManager = Symbol('OpenAiLanguageModelsManager');

// Provider id used as the prefix of official OpenAI model ids (e.g. 'openai/gpt-4o').
export const OPENAI_PROVIDER_ID = 'openai';
|
||||
|
||||
/**
 * Describes a single OpenAI (or OpenAI-compatible) language model as configured
 * via preferences, in the form consumed by the {@link OpenAiLanguageModelsManager}.
 */
export interface OpenAiModelDescription {
    /**
     * The identifier of the model which will be shown in the UI.
     */
    id: string;
    /**
     * The model ID as used by the OpenAI API.
     */
    model: string;
    /**
     * The OpenAI API compatible endpoint where the model is hosted. If not provided the default OpenAI endpoint will be used.
     */
    url?: string;
    /**
     * The key for the model. If 'true' is provided the global OpenAI API key will be used.
     */
    apiKey: string | true | undefined;
    /**
     * The version for the api. If 'true' is provided the global OpenAI version will be used.
     */
    apiVersion: string | true | undefined;
    /**
     * Optional deployment name for Azure OpenAI.
     */
    deployment?: string;
    /**
     * Indicate whether the streaming API shall be used.
     */
    enableStreaming: boolean;
    /**
     * Property to configure the developer message of the model. Setting this property to 'user', 'system', or 'developer' will use that string as the role for the system message.
     * Setting it to 'mergeWithFollowingUserMessage' will prefix the following user message with the system message or convert the system message to user if the following message
     * is not a user message. 'skip' will remove the system message altogether.
     * Defaults to 'developer'.
     */
    developerMessageSettings?: 'user' | 'system' | 'developer' | 'mergeWithFollowingUserMessage' | 'skip';
    /**
     * Flag to configure whether the OpenAPI model supports structured output. Default is `true`.
     */
    supportsStructuredOutput: boolean;
    /**
     * Maximum number of retry attempts when a request fails. Default is 3.
     */
    maxRetries: number;
    /**
     * Flag to configure whether to use the newer OpenAI Response API instead of the Chat Completion API.
     * For official OpenAI models, this defaults to `true`. For custom providers, users must explicitly enable it.
     * Default is `false` for custom models.
     * NOTE(review): the preference schema default for official models (USE_RESPONSE_API_PREF)
     * is currently `false`, which contradicts the "defaults to `true`" statement above — confirm intent.
     */
    useResponseApi?: boolean;
}
|
||||
/**
 * Backend service managing the lifecycle of OpenAI language models.
 * The frontend accesses it via an RPC proxy on OPENAI_LANGUAGE_MODELS_MANAGER_PATH.
 */
export interface OpenAiLanguageModelsManager {
    /** The currently configured global OpenAI API key, if any. */
    apiKey: string | undefined;
    /** Sets (or clears) the global OpenAI API key used by models with `apiKey: true`. */
    setApiKey(key: string | undefined): void;
    /** Sets (or clears) the global API version used by models with `apiVersion: true`. */
    setApiVersion(version: string | undefined): void;
    /** Sets (or clears) the proxy URL used for requests (fed from the `http.proxy` preference). */
    setProxyUrl(proxyUrl: string | undefined): void;
    /** Creates new language models or updates the existing ones with the same ids. */
    createOrUpdateLanguageModels(...models: OpenAiModelDescription[]): Promise<void>;
    /** Removes the language models with the given ids. */
    removeLanguageModels(...modelIds: string[]): void
}
|
||||
152
packages/ai-openai/src/common/openai-preferences.ts
Normal file
152
packages/ai-openai/src/common/openai-preferences.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { AI_CORE_PREFERENCES_TITLE } from '@theia/ai-core/lib/common/ai-core-preferences';
|
||||
import { nls, PreferenceSchema } from '@theia/core';
|
||||
|
||||
// Preference keys for the OpenAI integration. The 'openAiOfficial' keys configure
// access to official OpenAI models; 'openAiCustom' configures OpenAI-compatible
// custom endpoints.
export const API_KEY_PREF = 'ai-features.openAiOfficial.openAiApiKey';
export const MODELS_PREF = 'ai-features.openAiOfficial.officialOpenAiModels';
export const USE_RESPONSE_API_PREF = 'ai-features.openAiOfficial.useResponseApi';
export const CUSTOM_ENDPOINTS_PREF = 'ai-features.openAiCustom.customOpenAiModels';
|
||||
|
||||
/**
 * Preference schema for the OpenAI provider: global API key, the list of
 * official models, the Response-API toggle, and custom OpenAI-compatible endpoints.
 */
export const OpenAiPreferencesSchema: PreferenceSchema = {
    properties: {
        [API_KEY_PREF]: {
            type: 'string',
            markdownDescription: nls.localize('theia/ai/openai/apiKey/mdDescription',
                'Enter an API Key of your official OpenAI Account. **Please note:** By using this preference the Open AI API key will be stored in clear text \
on the machine running Theia. Use the environment variable `OPENAI_API_KEY` to set the key securely.'),
            title: AI_CORE_PREFERENCES_TITLE,
        },
        [MODELS_PREF]: {
            type: 'array',
            description: nls.localize('theia/ai/openai/models/description', 'Official OpenAI models to use'),
            title: AI_CORE_PREFERENCES_TITLE,
            default: [
                'gpt-5.2',
                'gpt-5.2-pro',
                'gpt-5.1',
                'gpt-5',
                'gpt-5-mini',
                'gpt-4.1',
                'gpt-4.1-mini',
                'gpt-4o'
            ],
            items: {
                type: 'string'
            }
        },
        [USE_RESPONSE_API_PREF]: {
            type: 'boolean',
            default: false,
            title: AI_CORE_PREFERENCES_TITLE,
            markdownDescription: nls.localize('theia/ai/openai/useResponseApi/mdDescription',
                'Use the newer OpenAI Response API instead of the Chat Completion API for official OpenAI models.\
\
This setting only applies to official OpenAI models - custom providers must configure this individually.\
\
Note that for the response API, tool call definitions must satisfy Open AI\'s [strict schema definition](https://platform.openai.com/docs/guides/function-calling#strict-mode).\
Best effort is made to convert non-conformant schemas, but errors are still possible.')
        },
        [CUSTOM_ENDPOINTS_PREF]: {
            type: 'array',
            title: AI_CORE_PREFERENCES_TITLE,
            markdownDescription: nls.localize('theia/ai/openai/customEndpoints/mdDescription',
                'Integrate custom models compatible with the OpenAI API, for example via `vllm`. The required attributes are `model` and `url`.\
\n\
Optionally, you can\
\n\
- specify a unique `id` to identify the custom model in the UI. If none is given `model` will be used as `id`.\
\n\
- provide an `apiKey` to access the API served at the given url. Use `true` to indicate the use of the global OpenAI API key.\
\n\
- provide an `apiVersion` to access the API served at the given url in Azure. Use `true` to indicate the use of the global OpenAI API version.\
\n\
- provide a `deployment` name for your Azure deployment.\
\n\
- set `developerMessageSettings` to one of `user`, `system`, `developer`, `mergeWithFollowingUserMessage`, or `skip` to control how the developer message is\
included (where `user`, `system`, and `developer` will be used as a role, `mergeWithFollowingUserMessage` will prefix the following user message with the system\
message or convert the system message to user message if the next message is not a user message. `skip` will just remove the system message).\
Defaulting to `developer`.\
\n\
- specify `supportsStructuredOutput: false` to indicate that structured output shall not be used.\
\n\
- specify `enableStreaming: false` to indicate that streaming shall not be used.\
\n\
- specify `useResponseApi: true` to use the newer OpenAI Response API instead of the Chat Completion API (requires compatible endpoint).\
\n\
Refer to [our documentation](https://theia-ide.org/docs/user_ai/#openai-compatible-models-eg-via-vllm) for more information.'),
            default: [],
            items: {
                type: 'object',
                properties: {
                    model: {
                        type: 'string',
                        title: nls.localize('theia/ai/openai/customEndpoints/modelId/title', 'Model ID')
                    },
                    url: {
                        type: 'string',
                        title: nls.localize('theia/ai/openai/customEndpoints/url/title', 'The Open AI API compatible endpoint where the model is hosted')
                    },
                    id: {
                        type: 'string',
                        title: nls.localize('theia/ai/openai/customEndpoints/id/title', 'A unique identifier which is used in the UI to identify the custom model'),
                    },
                    apiKey: {
                        type: ['string', 'boolean'],
                        title: nls.localize('theia/ai/openai/customEndpoints/apiKey/title',
                            'Either the key to access the API served at the given url or `true` to use the global OpenAI API key'),
                    },
                    apiVersion: {
                        type: ['string', 'boolean'],
                        title: nls.localize('theia/ai/openai/customEndpoints/apiVersion/title',
                            'Either the version to access the API served at the given url in Azure or `true` to use the global OpenAI API version'),
                    },
                    deployment: {
                        type: 'string',
                        title: nls.localize('theia/ai/openai/customEndpoints/deployment/title',
                            'The deployment name to access the API served at the given url in Azure'),
                    },
                    developerMessageSettings: {
                        type: 'string',
                        enum: ['user', 'system', 'developer', 'mergeWithFollowingUserMessage', 'skip'],
                        default: 'developer',
                        title: nls.localize('theia/ai/openai/customEndpoints/developerMessageSettings/title',
                            'Controls the handling of system messages: `user`, `system`, and `developer` will be used as a role, `mergeWithFollowingUserMessage` will prefix\
the following user message with the system message or convert the system message to user message if the next message is not a user message.\
`skip` will just remove the system message), defaulting to `developer`.')
                    },
                    supportsStructuredOutput: {
                        type: 'boolean',
                        title: nls.localize('theia/ai/openai/customEndpoints/supportsStructuredOutput/title',
                            'Indicates whether the model supports structured output. `true` by default.'),
                    },
                    enableStreaming: {
                        type: 'boolean',
                        title: nls.localize('theia/ai/openai/customEndpoints/enableStreaming/title',
                            'Indicates whether the streaming API shall be used. `true` by default.'),
                    },
                    useResponseApi: {
                        type: 'boolean',
                        title: nls.localize('theia/ai/openai/customEndpoints/useResponseApi/title',
                            'Use the newer OpenAI Response API instead of the Chat Completion API. `false` by default for custom providers.'
                            + 'Note: Will automatically fall back to Chat Completions API when tools are used.'),
                    }
                }
            }
        }
    }
};
|
||||
42
packages/ai-openai/src/node/openai-backend-module.ts
Normal file
42
packages/ai-openai/src/node/openai-backend-module.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { ContainerModule } from '@theia/core/shared/inversify';
|
||||
import { OPENAI_LANGUAGE_MODELS_MANAGER_PATH, OpenAiLanguageModelsManager } from '../common/openai-language-models-manager';
|
||||
import { ConnectionHandler, PreferenceContribution, RpcConnectionHandler } from '@theia/core';
|
||||
import { OpenAiLanguageModelsManagerImpl } from './openai-language-models-manager-impl';
|
||||
import { ConnectionContainerModule } from '@theia/core/lib/node/messaging/connection-container-module';
|
||||
import { OpenAiModelUtils } from './openai-language-model';
|
||||
import { OpenAiResponseApiUtils } from './openai-response-api-utils';
|
||||
import { OpenAiPreferencesSchema } from '../common/openai-preferences';
|
||||
|
||||
export const OpenAiModelFactory = Symbol('OpenAiModelFactory');

// We use a connection module to handle AI services separately for each frontend.
const openAiConnectionModule = ConnectionContainerModule.create(({ bind, bindBackendService, bindFrontendService }) => {
    bind(OpenAiLanguageModelsManagerImpl).toSelf().inSingletonScope();
    bind(OpenAiLanguageModelsManager).toService(OpenAiLanguageModelsManagerImpl);
    // Expose the manager over RPC on the shared service path used by the frontend proxy.
    bind(ConnectionHandler).toDynamicValue(ctx =>
        new RpcConnectionHandler(OPENAI_LANGUAGE_MODELS_MANAGER_PATH, () => ctx.container.get(OpenAiLanguageModelsManager))
    ).inSingletonScope();
});

export default new ContainerModule(bind => {
    // Contribute the preference schema on the backend side as well.
    bind(PreferenceContribution).toConstantValue({ schema: OpenAiPreferencesSchema });
    bind(OpenAiModelUtils).toSelf().inSingletonScope();
    bind(OpenAiResponseApiUtils).toSelf().inSingletonScope();
    bind(ConnectionContainerModule).toConstantValue(openAiConnectionModule);
});
|
||||
392
packages/ai-openai/src/node/openai-language-model.ts
Normal file
392
packages/ai-openai/src/node/openai-language-model.ts
Normal file
@@ -0,0 +1,392 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import {
|
||||
LanguageModel,
|
||||
LanguageModelParsedResponse,
|
||||
LanguageModelRequest,
|
||||
LanguageModelMessage,
|
||||
LanguageModelResponse,
|
||||
LanguageModelTextResponse,
|
||||
TokenUsageService,
|
||||
UserRequest,
|
||||
ImageContent,
|
||||
LanguageModelStatus
|
||||
} from '@theia/ai-core';
|
||||
import { CancellationToken } from '@theia/core';
|
||||
import { injectable } from '@theia/core/shared/inversify';
|
||||
import { OpenAI, AzureOpenAI } from 'openai';
|
||||
import { ChatCompletionStream } from 'openai/lib/ChatCompletionStream';
|
||||
import { RunnableToolFunctionWithoutParse } from 'openai/lib/RunnableFunction';
|
||||
import { ChatCompletionMessageParam } from 'openai/resources';
|
||||
import { StreamingAsyncIterator } from './openai-streaming-iterator';
|
||||
import { OPENAI_PROVIDER_ID } from '../common';
|
||||
import type { FinalRequestOptions } from 'openai/internal/request-options';
|
||||
import type { RunnerOptions } from 'openai/lib/AbstractChatCompletionRunner';
|
||||
import { OpenAiResponseApiUtils, processSystemMessages } from './openai-response-api-utils';
|
||||
import * as undici from 'undici';
|
||||
|
||||
export class MistralFixedOpenAI extends OpenAI {
|
||||
protected override async prepareOptions(options: FinalRequestOptions): Promise<void> {
|
||||
const messages = (options.body as { messages: Array<ChatCompletionMessageParam> }).messages;
|
||||
if (Array.isArray(messages)) {
|
||||
(options.body as { messages: Array<ChatCompletionMessageParam> }).messages.forEach(m => {
|
||||
if (m.role === 'assistant' && m.tool_calls) {
|
||||
// Mistral OpenAI Endpoint expects refusal to be undefined and not null for optional properties
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
if (m.refusal === null) {
|
||||
m.refusal = undefined;
|
||||
}
|
||||
// Mistral OpenAI Endpoint expects parsed to be undefined and not null for optional properties
|
||||
// eslint-disable-next-line no-null/no-null
|
||||
if ((m as unknown as { parsed: null | undefined }).parsed === null) {
|
||||
(m as unknown as { parsed: null | undefined }).parsed = undefined;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
return super.prepareOptions(options);
|
||||
};
|
||||
}
|
||||
|
||||
// DI symbol for injecting an OpenAI model identifier.
export const OpenAiModelIdentifier = Symbol('OpenAiModelIdentifier');

/**
 * How the system ("developer") message is delivered to the model:
 * 'user' | 'system' | 'developer' use that string as the message role;
 * 'mergeWithFollowingUserMessage' folds it into the next user message;
 * 'skip' drops it entirely.
 */
export type DeveloperMessageSettings = 'user' | 'system' | 'developer' | 'mergeWithFollowingUserMessage' | 'skip';
|
||||
|
||||
export class OpenAiModel implements LanguageModel {
|
||||
|
||||
    /**
     * The options for the OpenAI runner.
     */
    protected runnerOptions: RunnerOptions = {
        // The maximum number of chat completions to return in a single request.
        // Each function call counts as a chat completion.
        // To support use cases with many function calls (e.g. @Coder), we set this to a high value.
        maxChatCompletions: 100,
    };
|
||||
|
||||
    /**
     * @param id the unique id for this language model. It will be used to identify the model in the UI.
     * @param model the model id as it is used by the OpenAI API
     * @param status the current status of this language model
     * @param enableStreaming whether the streaming API shall be used
     * @param apiKey a function that returns the API key to use for this model, called on each request
     * @param apiVersion a function that returns the OpenAPI version to use for this model, called on each request
     * @param supportsStructuredOutput whether the model supports structured output
     * @param url the OpenAI API compatible endpoint where the model is hosted. If not provided the default OpenAI endpoint will be used.
     * @param deployment optional deployment name for Azure OpenAI
     * @param openAiModelUtils shared OpenAI model utilities
     * @param responseApiUtils utilities for the OpenAI Response API
     * @param developerMessageSettings how to handle system messages
     * @param maxRetries the maximum number of retry attempts when a request fails
     * @param useResponseApi whether to use the newer OpenAI Response API instead of the Chat Completion API
     * @param tokenUsageService optional service to record token usage
     * @param proxy optional proxy URL for outgoing requests
     */
    constructor(
        public readonly id: string,
        public model: string,
        public status: LanguageModelStatus,
        public enableStreaming: boolean,
        public apiKey: () => string | undefined,
        public apiVersion: () => string | undefined,
        public supportsStructuredOutput: boolean,
        public url: string | undefined,
        public deployment: string | undefined,
        public openAiModelUtils: OpenAiModelUtils,
        public responseApiUtils: OpenAiResponseApiUtils,
        public developerMessageSettings: DeveloperMessageSettings = 'developer',
        public maxRetries: number = 3,
        public useResponseApi: boolean = false,
        protected readonly tokenUsageService?: TokenUsageService,
        protected proxy?: string
    ) { }
|
||||
|
||||
protected getSettings(request: LanguageModelRequest): Record<string, unknown> {
|
||||
return request.settings ?? {};
|
||||
}
|
||||
|
||||
async request(request: UserRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
|
||||
const openai = this.initializeOpenAi();
|
||||
|
||||
return this.useResponseApi ?
|
||||
this.handleResponseApiRequest(openai, request, cancellationToken)
|
||||
: this.handleChatCompletionsRequest(openai, request, cancellationToken);
|
||||
}
|
||||
|
||||
protected async handleChatCompletionsRequest(openai: OpenAI, request: UserRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
|
||||
const settings = this.getSettings(request);
|
||||
|
||||
if (request.response_format?.type === 'json_schema' && this.supportsStructuredOutput) {
|
||||
return this.handleStructuredOutputRequest(openai, request);
|
||||
}
|
||||
|
||||
if (this.isNonStreamingModel(this.model) || (typeof settings.stream === 'boolean' && !settings.stream)) {
|
||||
return this.handleNonStreamingRequest(openai, request);
|
||||
}
|
||||
|
||||
if (this.id.startsWith(`${OPENAI_PROVIDER_ID}/`)) {
|
||||
settings['stream_options'] = { include_usage: true };
|
||||
}
|
||||
|
||||
if (cancellationToken?.isCancellationRequested) {
|
||||
return { text: '' };
|
||||
}
|
||||
let runner: ChatCompletionStream;
|
||||
const tools = this.createTools(request);
|
||||
|
||||
if (tools) {
|
||||
runner = openai.chat.completions.runTools({
|
||||
model: this.model,
|
||||
messages: this.processMessages(request.messages),
|
||||
stream: true,
|
||||
tools: tools,
|
||||
tool_choice: 'auto',
|
||||
...settings
|
||||
}, {
|
||||
...this.runnerOptions, maxRetries: this.maxRetries
|
||||
});
|
||||
} else {
|
||||
runner = openai.chat.completions.stream({
|
||||
model: this.model,
|
||||
messages: this.processMessages(request.messages),
|
||||
stream: true,
|
||||
...settings
|
||||
});
|
||||
}
|
||||
|
||||
return { stream: new StreamingAsyncIterator(runner, request.requestId, cancellationToken, this.tokenUsageService, this.id) };
|
||||
}
|
||||
|
||||
protected async handleNonStreamingRequest(openai: OpenAI, request: UserRequest): Promise<LanguageModelTextResponse> {
|
||||
const settings = this.getSettings(request);
|
||||
const response = await openai.chat.completions.create({
|
||||
model: this.model,
|
||||
messages: this.processMessages(request.messages),
|
||||
...settings
|
||||
});
|
||||
|
||||
const message = response.choices[0].message;
|
||||
|
||||
// Record token usage if token usage service is available
|
||||
if (this.tokenUsageService && response.usage) {
|
||||
await this.tokenUsageService.recordTokenUsage(
|
||||
this.id,
|
||||
{
|
||||
inputTokens: response.usage.prompt_tokens,
|
||||
outputTokens: response.usage.completion_tokens,
|
||||
requestId: request.requestId
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
text: message.content ?? ''
|
||||
};
|
||||
}
|
||||
|
||||
protected isNonStreamingModel(_model: string): boolean {
|
||||
return !this.enableStreaming;
|
||||
}
|
||||
|
||||
protected async handleStructuredOutputRequest(openai: OpenAI, request: UserRequest): Promise<LanguageModelParsedResponse> {
|
||||
const settings = this.getSettings(request);
|
||||
// TODO implement tool support for structured output (parse() seems to require different tool format)
|
||||
const result = await openai.chat.completions.parse({
|
||||
model: this.model,
|
||||
messages: this.processMessages(request.messages),
|
||||
response_format: request.response_format,
|
||||
...settings
|
||||
});
|
||||
const message = result.choices[0].message;
|
||||
if (message.refusal || message.parsed === undefined) {
|
||||
console.error('Error in OpenAI chat completion stream:', JSON.stringify(message));
|
||||
}
|
||||
|
||||
// Record token usage if token usage service is available
|
||||
if (this.tokenUsageService && result.usage) {
|
||||
await this.tokenUsageService.recordTokenUsage(
|
||||
this.id,
|
||||
{
|
||||
inputTokens: result.usage.prompt_tokens,
|
||||
outputTokens: result.usage.completion_tokens,
|
||||
requestId: request.requestId
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
content: message.content ?? '',
|
||||
parsed: message.parsed
|
||||
};
|
||||
}
|
||||
|
||||
protected createTools(request: LanguageModelRequest): RunnableToolFunctionWithoutParse[] | undefined {
|
||||
return request.tools?.map(tool => ({
|
||||
type: 'function',
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
function: (args_string: string) => tool.handler(args_string)
|
||||
}
|
||||
} as RunnableToolFunctionWithoutParse));
|
||||
}
|
||||
|
||||
protected initializeOpenAi(): OpenAI {
|
||||
const apiKey = this.apiKey();
|
||||
if (!apiKey && !(this.url)) {
|
||||
throw new Error('Please provide OPENAI_API_KEY in preferences or via environment variable');
|
||||
}
|
||||
|
||||
const apiVersion = this.apiVersion();
|
||||
// We need to hand over "some" key, even if a custom url is not key protected as otherwise the OpenAI client will throw an error
|
||||
const key = apiKey ?? 'no-key';
|
||||
|
||||
let fo;
|
||||
if (this.proxy) {
|
||||
const proxyAgent = new undici.ProxyAgent(this.proxy);
|
||||
fo = {
|
||||
dispatcher: proxyAgent,
|
||||
};
|
||||
}
|
||||
|
||||
if (apiVersion) {
|
||||
return new AzureOpenAI({ apiKey: key, baseURL: this.url, apiVersion: apiVersion, deployment: this.deployment, fetchOptions: fo });
|
||||
} else {
|
||||
return new MistralFixedOpenAI({ apiKey: key, baseURL: this.url, fetchOptions: fo });
|
||||
}
|
||||
}
|
||||
|
||||
protected async handleResponseApiRequest(openai: OpenAI, request: UserRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
|
||||
const settings = this.getSettings(request);
|
||||
const isStreamingRequest = this.enableStreaming && !(typeof settings.stream === 'boolean' && !settings.stream);
|
||||
|
||||
try {
|
||||
return await this.responseApiUtils.handleRequest(
|
||||
openai,
|
||||
request,
|
||||
settings,
|
||||
this.model,
|
||||
this.openAiModelUtils,
|
||||
this.developerMessageSettings,
|
||||
this.runnerOptions,
|
||||
this.id,
|
||||
isStreamingRequest,
|
||||
this.tokenUsageService,
|
||||
cancellationToken
|
||||
);
|
||||
} catch (error) {
|
||||
// If Response API fails, fall back to Chat Completions API
|
||||
if (error instanceof Error) {
|
||||
console.warn(`Response API failed for model ${this.id}, falling back to Chat Completions API:`, error.message);
|
||||
return this.handleChatCompletionsRequest(openai, request, cancellationToken);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
protected processMessages(messages: LanguageModelMessage[]): ChatCompletionMessageParam[] {
|
||||
return this.openAiModelUtils.processMessages(messages, this.developerMessageSettings, this.model);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Utility class for processing messages for the OpenAI language model.
|
||||
*
|
||||
* Adopters can rebind this class to implement custom message processing behavior.
|
||||
*/
|
||||
@injectable()
|
||||
export class OpenAiModelUtils {
|
||||
|
||||
protected processSystemMessages(
|
||||
messages: LanguageModelMessage[],
|
||||
developerMessageSettings: DeveloperMessageSettings
|
||||
): LanguageModelMessage[] {
|
||||
return processSystemMessages(messages, developerMessageSettings);
|
||||
}
|
||||
|
||||
protected toOpenAiRole(
|
||||
message: LanguageModelMessage,
|
||||
developerMessageSettings: DeveloperMessageSettings
|
||||
): 'developer' | 'user' | 'assistant' | 'system' {
|
||||
if (message.actor === 'system') {
|
||||
if (developerMessageSettings === 'user' || developerMessageSettings === 'system' || developerMessageSettings === 'developer') {
|
||||
return developerMessageSettings;
|
||||
} else {
|
||||
return 'developer';
|
||||
}
|
||||
} else if (message.actor === 'ai') {
|
||||
return 'assistant';
|
||||
}
|
||||
return 'user';
|
||||
}
|
||||
|
||||
protected toOpenAIMessage(
|
||||
message: LanguageModelMessage,
|
||||
developerMessageSettings: DeveloperMessageSettings
|
||||
): ChatCompletionMessageParam {
|
||||
if (LanguageModelMessage.isTextMessage(message)) {
|
||||
return {
|
||||
role: this.toOpenAiRole(message, developerMessageSettings),
|
||||
content: message.text
|
||||
};
|
||||
}
|
||||
if (LanguageModelMessage.isToolUseMessage(message)) {
|
||||
return {
|
||||
role: 'assistant',
|
||||
tool_calls: [{ id: message.id, function: { name: message.name, arguments: JSON.stringify(message.input) }, type: 'function' }]
|
||||
};
|
||||
}
|
||||
if (LanguageModelMessage.isToolResultMessage(message)) {
|
||||
return {
|
||||
role: 'tool',
|
||||
tool_call_id: message.tool_use_id,
|
||||
// content only supports text content so we need to stringify any potential data we have, e.g., images
|
||||
content: typeof message.content === 'string' ? message.content : JSON.stringify(message.content)
|
||||
};
|
||||
}
|
||||
if (LanguageModelMessage.isImageMessage(message) && message.actor === 'user') {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url:
|
||||
ImageContent.isBase64(message.image) ?
|
||||
`data:${message.image.mimeType};base64,${message.image.base64data}` :
|
||||
message.image.url
|
||||
}
|
||||
}]
|
||||
};
|
||||
}
|
||||
throw new Error(`Unknown message type:'${JSON.stringify(message)}'`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Processes the provided list of messages by applying system message adjustments and converting
|
||||
* them to the format expected by the OpenAI API.
|
||||
*
|
||||
* Adopters can rebind this processing to implement custom behavior.
|
||||
*
|
||||
* @param messages the list of messages to process.
|
||||
* @param developerMessageSettings how system and developer messages are handled during processing.
|
||||
* @param model the OpenAI model identifier. Currently not used, but allows subclasses to implement model-specific behavior.
|
||||
* @returns an array of messages formatted for the OpenAI API.
|
||||
*/
|
||||
processMessages(
|
||||
messages: LanguageModelMessage[],
|
||||
developerMessageSettings: DeveloperMessageSettings,
|
||||
model?: string
|
||||
): ChatCompletionMessageParam[] {
|
||||
const processed = this.processSystemMessages(messages, developerMessageSettings);
|
||||
return processed.filter(m => m.type !== 'thinking').map(m => this.toOpenAIMessage(m, developerMessageSettings));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { LanguageModelRegistry, LanguageModelStatus, TokenUsageService } from '@theia/ai-core';
|
||||
import { inject, injectable } from '@theia/core/shared/inversify';
|
||||
import { OpenAiModel, OpenAiModelUtils } from './openai-language-model';
|
||||
import { OpenAiResponseApiUtils } from './openai-response-api-utils';
|
||||
import { OpenAiLanguageModelsManager, OpenAiModelDescription } from '../common';
|
||||
|
||||
@injectable()
|
||||
export class OpenAiLanguageModelsManagerImpl implements OpenAiLanguageModelsManager {
|
||||
|
||||
@inject(OpenAiModelUtils)
|
||||
protected readonly openAiModelUtils: OpenAiModelUtils;
|
||||
|
||||
@inject(OpenAiResponseApiUtils)
|
||||
protected readonly responseApiUtils: OpenAiResponseApiUtils;
|
||||
|
||||
protected _apiKey: string | undefined;
|
||||
protected _apiVersion: string | undefined;
|
||||
protected _proxyUrl: string | undefined;
|
||||
|
||||
@inject(LanguageModelRegistry)
|
||||
protected readonly languageModelRegistry: LanguageModelRegistry;
|
||||
|
||||
@inject(TokenUsageService)
|
||||
protected readonly tokenUsageService: TokenUsageService;
|
||||
|
||||
get apiKey(): string | undefined {
|
||||
return this._apiKey ?? process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
get apiVersion(): string | undefined {
|
||||
return this._apiVersion ?? process.env.OPENAI_API_VERSION;
|
||||
}
|
||||
|
||||
protected calculateStatus(modelDescription: OpenAiModelDescription, effectiveApiKey: string | undefined): LanguageModelStatus {
|
||||
// Always mark custom models (models with url) as ready for now as we do not know about API Key requirements
|
||||
if (modelDescription.url) {
|
||||
return { status: 'ready' };
|
||||
}
|
||||
return effectiveApiKey
|
||||
? { status: 'ready' }
|
||||
: { status: 'unavailable', message: 'No OpenAI API key set' };
|
||||
}
|
||||
|
||||
// Triggered from frontend. In case you want to use the models on the backend
|
||||
// without a frontend then call this yourself
|
||||
async createOrUpdateLanguageModels(...modelDescriptions: OpenAiModelDescription[]): Promise<void> {
|
||||
for (const modelDescription of modelDescriptions) {
|
||||
const model = await this.languageModelRegistry.getLanguageModel(modelDescription.id);
|
||||
const apiKeyProvider = () => {
|
||||
if (modelDescription.apiKey === true) {
|
||||
return this.apiKey;
|
||||
}
|
||||
if (modelDescription.apiKey) {
|
||||
return modelDescription.apiKey;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
const apiVersionProvider = () => {
|
||||
if (modelDescription.apiVersion === true) {
|
||||
return this.apiVersion;
|
||||
}
|
||||
if (modelDescription.apiVersion) {
|
||||
return modelDescription.apiVersion;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
const proxyUrlProvider = (url: string | undefined) => {
|
||||
// first check if the proxy url is provided via Theia settings
|
||||
if (this._proxyUrl) {
|
||||
return this._proxyUrl;
|
||||
}
|
||||
|
||||
// if not fall back to the environment variables
|
||||
let protocolVar;
|
||||
if (url && url.startsWith('http:')) {
|
||||
protocolVar = 'http_proxy';
|
||||
} else if (url && url.startsWith('https:')) {
|
||||
protocolVar = 'https_proxy';
|
||||
}
|
||||
|
||||
if (protocolVar) {
|
||||
// Get the environment variable
|
||||
return process.env[protocolVar];
|
||||
}
|
||||
|
||||
// neither the settings nor the environment variable is set
|
||||
return undefined;
|
||||
};
|
||||
|
||||
// Determine the effective API key for status
|
||||
const status = this.calculateStatus(modelDescription, apiKeyProvider());
|
||||
|
||||
if (model) {
|
||||
if (!(model instanceof OpenAiModel)) {
|
||||
console.warn(`OpenAI: model ${modelDescription.id} is not an OpenAI model`);
|
||||
continue;
|
||||
}
|
||||
await this.languageModelRegistry.patchLanguageModel<OpenAiModel>(modelDescription.id, {
|
||||
model: modelDescription.model,
|
||||
enableStreaming: modelDescription.enableStreaming,
|
||||
url: modelDescription.url,
|
||||
apiKey: apiKeyProvider,
|
||||
apiVersion: apiVersionProvider,
|
||||
deployment: modelDescription.deployment,
|
||||
developerMessageSettings: modelDescription.developerMessageSettings || 'developer',
|
||||
supportsStructuredOutput: modelDescription.supportsStructuredOutput,
|
||||
status,
|
||||
maxRetries: modelDescription.maxRetries,
|
||||
useResponseApi: modelDescription.useResponseApi ?? false
|
||||
});
|
||||
} else {
|
||||
this.languageModelRegistry.addLanguageModels([
|
||||
new OpenAiModel(
|
||||
modelDescription.id,
|
||||
modelDescription.model,
|
||||
status,
|
||||
modelDescription.enableStreaming,
|
||||
apiKeyProvider,
|
||||
apiVersionProvider,
|
||||
modelDescription.supportsStructuredOutput,
|
||||
modelDescription.url,
|
||||
modelDescription.deployment,
|
||||
this.openAiModelUtils,
|
||||
this.responseApiUtils,
|
||||
modelDescription.developerMessageSettings,
|
||||
modelDescription.maxRetries,
|
||||
modelDescription.useResponseApi ?? false,
|
||||
this.tokenUsageService,
|
||||
proxyUrlProvider(modelDescription.url)
|
||||
)
|
||||
]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removeLanguageModels(...modelIds: string[]): void {
|
||||
this.languageModelRegistry.removeLanguageModels(modelIds);
|
||||
}
|
||||
|
||||
setApiKey(apiKey: string | undefined): void {
|
||||
if (apiKey) {
|
||||
this._apiKey = apiKey;
|
||||
} else {
|
||||
this._apiKey = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
setApiVersion(apiVersion: string | undefined): void {
|
||||
if (apiVersion) {
|
||||
this._apiVersion = apiVersion;
|
||||
} else {
|
||||
this._apiVersion = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
setProxyUrl(proxyUrl: string | undefined): void {
|
||||
if (proxyUrl) {
|
||||
this._proxyUrl = proxyUrl;
|
||||
} else {
|
||||
this._proxyUrl = undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
502
packages/ai-openai/src/node/openai-model-utils.spec.ts
Normal file
502
packages/ai-openai/src/node/openai-model-utils.spec.ts
Normal file
@@ -0,0 +1,502 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2024 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
import { expect } from 'chai';
|
||||
import { OpenAiModelUtils } from './openai-language-model';
|
||||
import { LanguageModelMessage } from '@theia/ai-core';
|
||||
import { OpenAiResponseApiUtils, recursiveStrictJSONSchema } from './openai-response-api-utils';
|
||||
import type { JSONSchema, JSONSchemaDefinition } from 'openai/lib/jsonschema';
|
||||
|
||||
// Instances under test, shared across all test cases below.
const utils = new OpenAiModelUtils();
const responseUtils = new OpenAiResponseApiUtils();
|
||||
|
||||
// Verifies OpenAiModelUtils.processMessages: how system messages are skipped,
// merged into user messages, or assigned roles depending on DeveloperMessageSettings.
describe('OpenAiModelUtils - processMessages', () => {
    describe("when developerMessageSettings is 'skip'", () => {
        it('should remove all system messages', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system message' },
                { actor: 'user', type: 'text', text: 'user message' },
                { actor: 'system', type: 'text', text: 'another system message' },
            ];
            const result = utils.processMessages(messages, 'skip', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'user message' }
            ]);
        });

        it('should do nothing if there is no system message', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'user', type: 'text', text: 'user message' },
                { actor: 'user', type: 'text', text: 'another user message' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            const result = utils.processMessages(messages, 'skip', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'user message' },
                { role: 'user', content: 'another user message' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });
    });

    describe("when developerMessageSettings is 'mergeWithFollowingUserMessage'", () => {
        it('should merge the system message with the next user message, assign role user, and remove the system message', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system msg' },
                { actor: 'user', type: 'text', text: 'user msg' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            const result = utils.processMessages(messages, 'mergeWithFollowingUserMessage', 'gpt-4');
            // System text is prepended to the following user message, newline-separated.
            expect(result).to.deep.equal([
                { role: 'user', content: 'system msg\nuser msg' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });

        it('should create a new user message if no user message exists, and remove the system message', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system only msg' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            const result = utils.processMessages(messages, 'mergeWithFollowingUserMessage', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'system only msg' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });

        it('should create a merge multiple system message with the next user message', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'user', type: 'text', text: 'user message' },
                { actor: 'system', type: 'text', text: 'system message' },
                { actor: 'system', type: 'text', text: 'system message2' },
                { actor: 'user', type: 'text', text: 'user message2' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            const result = utils.processMessages(messages, 'mergeWithFollowingUserMessage', 'gpt-4');
            // Consecutive system messages are all folded into the next user message.
            expect(result).to.deep.equal([
                { role: 'user', content: 'user message' },
                { role: 'user', content: 'system message\nsystem message2\nuser message2' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });

        it('should create a new user message from several system messages if the next message is not a user message', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'user', type: 'text', text: 'user message' },
                { actor: 'system', type: 'text', text: 'system message' },
                { actor: 'system', type: 'text', text: 'system message2' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            const result = utils.processMessages(messages, 'mergeWithFollowingUserMessage', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'user message' },
                { role: 'user', content: 'system message\nsystem message2' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });
    });

    describe('when no special merging or skipping is needed', () => {
        it('should leave messages unchanged in ordering and assign roles based on developerMessageSettings', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'user', type: 'text', text: 'user message' },
                { actor: 'system', type: 'text', text: 'system message' },
                { actor: 'ai', type: 'text', text: 'ai message' }
            ];
            // Using a developerMessageSettings that is not merge/skip, e.g., 'developer'
            const result = utils.processMessages(messages, 'developer', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'user message' },
                { role: 'developer', content: 'system message' },
                { role: 'assistant', content: 'ai message' }
            ]);
        });
    });

    describe('role assignment for system messages when developerMessageSettings is one of the role strings', () => {
        it('should assign role as specified for a system message when developerMessageSettings is "user"', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system msg' },
                { actor: 'ai', type: 'text', text: 'ai msg' }
            ];
            // Since the first message is system and developerMessageSettings is not merge/skip, ordering is not adjusted
            const result = utils.processMessages(messages, 'user', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'user', content: 'system msg' },
                { role: 'assistant', content: 'ai msg' }
            ]);
        });

        it('should assign role as specified for a system message when developerMessageSettings is "system"', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system msg' },
                { actor: 'ai', type: 'text', text: 'ai msg' }
            ];
            const result = utils.processMessages(messages, 'system', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'system', content: 'system msg' },
                { role: 'assistant', content: 'ai msg' }
            ]);
        });

        it('should assign role as specified for a system message when developerMessageSettings is "developer"', () => {
            const messages: LanguageModelMessage[] = [
                { actor: 'system', type: 'text', text: 'system msg' },
                { actor: 'user', type: 'text', text: 'user msg' },
                { actor: 'ai', type: 'text', text: 'ai msg' }
            ];
            const result = utils.processMessages(messages, 'developer', 'gpt-4');
            expect(result).to.deep.equal([
                { role: 'developer', content: 'system msg' },
                { role: 'user', content: 'user msg' },
                { role: 'assistant', content: 'ai msg' }
            ]);
        });
    });
});
|
||||
|
||||
describe('OpenAiModelUtils - processMessagesForResponseApi', () => {
|
||||
describe("when developerMessageSettings is 'skip'", () => {
|
||||
it('should remove all system messages and return no instructions', () => {
|
||||
const messages: LanguageModelMessage[] = [
|
||||
{ actor: 'system', type: 'text', text: 'system message' },
|
||||
{ actor: 'user', type: 'text', text: 'user message' },
|
||||
{ actor: 'system', type: 'text', text: 'another system message' },
|
||||
];
|
||||
const result = responseUtils.processMessages(messages, 'skip', 'gpt-4');
|
||||
expect(result.instructions).to.be.undefined;
|
||||
expect(result.input).to.deep.equal([
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'user message' }]
|
||||
}
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("when developerMessageSettings is 'mergeWithFollowingUserMessage'", () => {
|
||||
it('should merge system message with user message and return no instructions', () => {
|
||||
const messages: LanguageModelMessage[] = [
|
||||
{ actor: 'system', type: 'text', text: 'system msg' },
|
||||
{ actor: 'user', type: 'text', text: 'user msg' },
|
||||
{ actor: 'ai', type: 'text', text: 'ai message' }
|
||||
];
|
||||
const result = responseUtils.processMessages(messages, 'mergeWithFollowingUserMessage', 'gpt-4');
|
||||
expect(result.instructions).to.be.undefined;
|
||||
expect(result.input).to.have.lengthOf(2);
|
||||
expect(result.input[0]).to.deep.equal({
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'system msg\nuser msg' }]
|
||||
});
|
||||
const assistantMessage = result.input[1];
|
||||
expect(assistantMessage).to.deep.include({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
status: 'completed',
|
||||
content: [{ type: 'output_text', text: 'ai message', annotations: [] }]
|
||||
});
|
||||
if (assistantMessage.type === 'message' && 'id' in assistantMessage) {
|
||||
expect(assistantMessage.id).to.be.a('string').and.to.match(/^msg_/);
|
||||
} else {
|
||||
throw new Error('Expected assistant message to have an id');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('when system messages should be converted to instructions', () => {
|
||||
it('should extract system messages as instructions and convert other messages to input items', () => {
|
||||
const messages: LanguageModelMessage[] = [
|
||||
{ actor: 'system', type: 'text', text: 'You are a helpful assistant' },
|
||||
{ actor: 'user', type: 'text', text: 'Hello!' },
|
||||
{ actor: 'ai', type: 'text', text: 'Hi there!' }
|
||||
];
|
||||
const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
|
||||
expect(result.instructions).to.equal('You are a helpful assistant');
|
||||
expect(result.input).to.have.lengthOf(2);
|
||||
expect(result.input[0]).to.deep.equal({
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'Hello!' }]
|
||||
});
|
||||
const assistantMessage = result.input[1];
|
||||
expect(assistantMessage).to.deep.include({
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
status: 'completed',
|
||||
content: [{ type: 'output_text', text: 'Hi there!', annotations: [] }]
|
||||
});
|
||||
if (assistantMessage.type === 'message' && 'id' in assistantMessage) {
|
||||
expect(assistantMessage.id).to.be.a('string').and.to.match(/^msg_/);
|
||||
} else {
|
||||
throw new Error('Expected assistant message to have an id');
|
||||
}
|
||||
});
|
||||
|
||||
it('should combine multiple system messages into instructions', () => {
|
||||
const messages: LanguageModelMessage[] = [
|
||||
{ actor: 'system', type: 'text', text: 'You are helpful' },
|
||||
{ actor: 'system', type: 'text', text: 'Be concise' },
|
||||
{ actor: 'user', type: 'text', text: 'What is 2+2?' }
|
||||
];
|
||||
const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
|
||||
expect(result.instructions).to.equal('You are helpful\nBe concise');
|
||||
expect(result.input).to.deep.equal([
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'What is 2+2?' }]
|
||||
}
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tool use and tool result messages', () => {
    // AI tool_use messages must map to Response API `function_call` input
    // items with JSON-stringified arguments.
    it('should convert tool use messages to function calls', () => {
        const messages: LanguageModelMessage[] = [
            { actor: 'user', type: 'text', text: 'Calculate 2+2' },
            {
                actor: 'ai',
                type: 'tool_use',
                id: 'call_123',
                name: 'calculator',
                input: { expression: '2+2' }
            }
        ];
        const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
        expect(result.input).to.deep.equal([
            {
                type: 'message',
                role: 'user',
                content: [{ type: 'input_text', text: 'Calculate 2+2' }]
            },
            {
                type: 'function_call',
                call_id: 'call_123',
                name: 'calculator',
                arguments: '{"expression":"2+2"}'
            }
        ]);
    });

    // tool_result messages must map to `function_call_output` items keyed by
    // the originating call id.
    it('should convert tool result messages to function call outputs', () => {
        const messages: LanguageModelMessage[] = [
            {
                actor: 'user',
                type: 'tool_result',
                name: 'calculator',
                tool_use_id: 'call_123',
                content: '4'
            }
        ];
        const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
        expect(result.input).to.deep.equal([
            {
                type: 'function_call_output',
                call_id: 'call_123',
                output: '4'
            }
        ]);
    });

    // Non-string tool results are serialized with JSON.stringify before being
    // placed in the output field.
    it('should stringify non-string tool result content', () => {
        const messages: LanguageModelMessage[] = [
            {
                actor: 'user',
                type: 'tool_result',
                name: 'data_processor',
                tool_use_id: 'call_456',
                content: { result: 'success', data: [1, 2, 3] }
            }
        ];
        const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
        expect(result.input).to.deep.equal([
            {
                type: 'function_call_output',
                call_id: 'call_456',
                output: '{"result":"success","data":[1,2,3]}'
            }
        ]);
    });
});
|
||||
|
||||
describe('image messages', () => {
    // Base64 images become `input_image` items whose image_url is a data: URL
    // assembled from the mime type and the base64 payload.
    it('should convert base64 image messages to input image items', () => {
        const messages: LanguageModelMessage[] = [
            {
                actor: 'user',
                type: 'image',
                image: {
                    mimeType: 'image/png',
                    base64data: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=='
                }
            }
        ];
        const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
        expect(result.input).to.deep.equal([
            {
                type: 'message',
                role: 'user',
                content: [{
                    type: 'input_image',
                    detail: 'auto',
                    image_url: 'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=='
                }]
            }
        ]);
    });

    // URL-based images are passed through unchanged as the image_url.
    it('should convert URL image messages to input image items', () => {
        const messages: LanguageModelMessage[] = [
            {
                actor: 'user',
                type: 'image',
                image: {
                    url: 'https://example.com/image.png'
                }
            }
        ];
        const result = responseUtils.processMessages(messages, 'developer', 'gpt-4');
        expect(result.input).to.deep.equal([
            {
                type: 'message',
                role: 'user',
                content: [{
                    type: 'input_image',
                    detail: 'auto',
                    image_url: 'https://example.com/image.png'
                }]
            }
        ]);
    });
});
|
||||
|
||||
describe('error handling', () => {
    // processMessages is exhaustive over the known message types; an
    // unrecognized type must fall through to the unreachable() branch and throw.
    it('should throw error for unknown message types', () => {
        const invalidMessage = {
            actor: 'user',
            type: 'unknown_type',
            someProperty: 'value'
        };
        const messages = [invalidMessage] as unknown as LanguageModelMessage[];
        expect(() => responseUtils.processMessages(messages, 'developer', 'gpt-4'))
            .to.throw('unhandled case');
    });
});
|
||||
|
||||
describe('recursiveStrictJSONSchema', () => {
    // A schema with no object/array structure needs no strictification and
    // must be returned as the same (unmodified) object.
    it('should return the same object and not modify it when schema has no properties to strictify', () => {
        const schema: JSONSchema = { type: 'string', description: 'Simple string' };
        const originalJson = JSON.stringify(schema);

        const result = recursiveStrictJSONSchema(schema);

        expect(result).to.equal(schema);
        expect(JSON.stringify(schema)).to.equal(originalJson);
        const resultObj = result as JSONSchema;
        expect(resultObj).to.not.have.property('additionalProperties');
        expect(resultObj).to.not.have.property('required');
    });

    // Strictification must be copy-on-write: the result is a new object with
    // additionalProperties=false and all properties required (recursively,
    // through array items), while the input schema stays untouched.
    it('should not mutate original but return a new strictified schema when branching applies (properties/items)', () => {
        const original: JSONSchema = {
            type: 'object',
            properties: {
                path: { type: 'string' },
                data: {
                    type: 'array',
                    items: {
                        type: 'object',
                        properties: {
                            a: { type: 'string' }
                        }
                    }
                }
            }
        };
        const originalClone = JSON.parse(JSON.stringify(original));

        const resultDef = recursiveStrictJSONSchema(original);
        const result = resultDef as JSONSchema;

        expect(result).to.not.equal(original);
        expect(original).to.deep.equal(originalClone);

        expect(result.additionalProperties).to.equal(false);
        expect(result.required).to.have.members(['path', 'data']);

        const itemsDef = (result.properties?.data as JSONSchema).items as JSONSchemaDefinition;
        expect(itemsDef).to.be.ok;
        const itemsObj = itemsDef as JSONSchema;
        expect(itemsObj.additionalProperties).to.equal(false);
        expect(itemsObj.required).to.have.members(['a']);

        // The nested items of the input must remain non-strict.
        const originalItems = ((original.properties!.data as JSONSchema).items) as JSONSchema;
        expect(originalItems).to.not.have.property('additionalProperties');
        expect(originalItems).to.not.have.property('required');
    });

    // Realistic tool-parameter schema: strict mode promotes every property
    // (including previously-optional ones like 'multiple' and 'reset') to
    // required at every nesting level, again without mutating the input.
    it('should strictify nested parameters schema and not mutate the original', () => {
        const replacementProperties: Record<string, JSONSchema> = {
            oldContent: { type: 'string', description: 'The exact content to be replaced. Must match exactly, including whitespace, comments, etc.' },
            newContent: { type: 'string', description: 'The new content to insert in place of matched old content.' },
            multiple: { type: 'boolean', description: 'Set to true if multiple occurrences of the oldContent are expected to be replaced.' }
        };

        const parameters: JSONSchema = {
            type: 'object',
            properties: {
                path: { type: 'string', description: 'The path of the file where content will be replaced.' },
                replacements: {
                    type: 'array',
                    items: {
                        type: 'object',
                        properties: replacementProperties,
                        required: ['oldContent', 'newContent']
                    },
                    description: 'An array of replacement objects, each containing oldContent and newContent strings.'
                },
                reset: {
                    type: 'boolean',
                    description: 'Set to true to clear any existing pending changes for this file and start fresh. Default is false, which merges with existing changes.'
                }
            },
            required: ['path', 'replacements']
        };

        const originalClone = JSON.parse(JSON.stringify(parameters));

        const strictifiedDef = recursiveStrictJSONSchema(parameters);
        const strictified = strictifiedDef as JSONSchema;

        expect(strictified).to.not.equal(parameters);
        expect(parameters).to.deep.equal(originalClone);

        expect(strictified.additionalProperties).to.equal(false);
        expect(strictified.required).to.have.members(['path', 'replacements', 'reset']);

        const items = (strictified.properties!.replacements as JSONSchema).items as JSONSchema;
        expect(items.additionalProperties).to.equal(false);
        expect(items.required).to.have.members(['oldContent', 'newContent', 'multiple']);

        // Original keeps its partial 'required' list and stays non-strict.
        const origItems = ((parameters.properties!.replacements as JSONSchema).items) as JSONSchema;
        expect(origItems.required).to.deep.equal(['oldContent', 'newContent']);
        expect(origItems).to.not.have.property('additionalProperties');
    });
});
|
||||
|
||||
});
|
||||
23
packages/ai-openai/src/node/openai-request-api-context.ts
Normal file
23
packages/ai-openai/src/node/openai-request-api-context.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2025 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
/**
 * Context attached to an OpenAI API request.
 *
 * NOTE(review): currently a stub — only the optional `parent` link exists,
 * presumably chaining a follow-up request to the request that spawned it.
 * Confirm the intended shape before extending (a commented-out class version
 * exists below).
 */
interface OpenAIRequestApiContext {
    parent?: OpenAIRequestApiContext;
}
|
||||
|
||||
// export class OpenAIRequestApiContext {
|
||||
// }
|
||||
841
packages/ai-openai/src/node/openai-response-api-utils.ts
Normal file
841
packages/ai-openai/src/node/openai-response-api-utils.ts
Normal file
@@ -0,0 +1,841 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2025 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import {
|
||||
createToolCallError,
|
||||
ImageContent,
|
||||
LanguageModelMessage,
|
||||
LanguageModelResponse,
|
||||
LanguageModelStreamResponsePart,
|
||||
TextMessage,
|
||||
TokenUsageService,
|
||||
ToolInvocationContext,
|
||||
ToolRequest,
|
||||
ToolRequestParameters,
|
||||
UserRequest
|
||||
} from '@theia/ai-core';
|
||||
import { CancellationToken, unreachable } from '@theia/core';
|
||||
import { Deferred } from '@theia/core/lib/common/promise-util';
|
||||
import { injectable } from '@theia/core/shared/inversify';
|
||||
import { OpenAI } from 'openai';
|
||||
import type { RunnerOptions } from 'openai/lib/AbstractChatCompletionRunner';
|
||||
import type {
|
||||
FunctionTool,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInputItem,
|
||||
ResponseStreamEvent
|
||||
} from 'openai/resources/responses/responses';
|
||||
import type { ResponsesModel } from 'openai/resources/shared';
|
||||
import { DeveloperMessageSettings, OpenAiModelUtils } from './openai-language-model';
|
||||
import { JSONSchema, JSONSchemaDefinition } from 'openai/lib/jsonschema';
|
||||
|
||||
/**
 * Accumulated state for a single function/tool call observed in a Response
 * API response or event stream.
 */
interface ToolCall {
    /** Output item id (`item_id` in stream events); used as the map key. */
    id: string;
    /**
     * API-level call id used to correlate `function_call_output` items.
     * Optional because it may only become known on the `output_item.done` event.
     */
    call_id?: string;
    /** Name of the tool (function) being called. */
    name: string;
    /** JSON-encoded arguments, possibly accumulated from streamed deltas. */
    arguments: string;
    /** Value returned by the tool handler, if execution succeeded. */
    result?: unknown;
    /** Error from the tool handler, or "tool not found". */
    error?: Error;
    /** Whether the handler has already been invoked for this call. */
    executed: boolean;
}
|
||||
|
||||
/**
 * Utility class for handling OpenAI Response API requests and tool calling cycles.
 *
 * This class encapsulates the complexity of the Response API's multi-turn conversation
 * patterns for tool calling, keeping the main language model class clean and focused.
 */
@injectable()
export class OpenAiResponseApiUtils {

    /**
     * Handles Response API requests with proper tool calling cycles.
     * Works for both streaming and non-streaming cases.
     *
     * Tool-free requests are answered with a single API call (streamed or
     * not); requests carrying tools are delegated to a
     * ResponseApiToolCallIterator, which repeats request → tool execution →
     * follow-up request cycles and exposes the whole run as one stream.
     */
    async handleRequest(
        openai: OpenAI,
        request: UserRequest,
        settings: Record<string, unknown>,
        model: string,
        modelUtils: OpenAiModelUtils,
        developerMessageSettings: DeveloperMessageSettings,
        runnerOptions: RunnerOptions,
        modelId: string,
        isStreaming: boolean,
        tokenUsageService?: TokenUsageService,
        cancellationToken?: CancellationToken
    ): Promise<LanguageModelResponse> {
        // An already-cancelled request short-circuits to an empty response.
        if (cancellationToken?.isCancellationRequested) {
            return { text: '' };
        }

        const { instructions, input } = this.processMessages(request.messages, developerMessageSettings, model);
        const tools = this.convertToolsForResponseApi(request.tools);

        // If no tools are provided, use simple response handling
        if (!tools || tools.length === 0) {
            if (isStreaming) {
                const stream = openai.responses.stream({
                    model: model as ResponsesModel,
                    instructions,
                    input,
                    ...settings
                });
                return { stream: this.createSimpleResponseApiStreamIterator(stream, request.requestId, modelId, tokenUsageService, cancellationToken) };
            } else {
                const response = await openai.responses.create({
                    model: model as ResponsesModel,
                    instructions,
                    input,
                    ...settings
                });

                // Record token usage if available
                if (tokenUsageService && response.usage) {
                    await tokenUsageService.recordTokenUsage(
                        modelId,
                        {
                            inputTokens: response.usage.input_tokens,
                            outputTokens: response.usage.output_tokens,
                            requestId: request.requestId
                        }
                    );
                }

                return { text: response.output_text || '' };
            }
        }

        // Handle tool calling with multi-turn conversation using the unified iterator
        const iterator = new ResponseApiToolCallIterator(
            openai,
            request,
            settings,
            model,
            modelUtils,
            developerMessageSettings,
            runnerOptions,
            modelId,
            this,
            isStreaming,
            tokenUsageService,
            cancellationToken
        );

        return { stream: iterator };
    }

    /**
     * Converts ToolRequest objects to the format expected by the Response API.
     * Returns undefined when no tools are given so callers can take the
     * simple (tool-free) code path.
     */
    convertToolsForResponseApi(tools?: ToolRequest[]): FunctionTool[] | undefined {
        if (!tools || tools.length === 0) {
            return undefined;
        }

        const converted = tools.map(tool => ({
            type: 'function' as const,
            name: tool.name,
            description: tool.description || '',
            // The Response API is very strict re: JSON schema: all properties must be listed as required,
            // and additional properties must be disallowed.
            // https://platform.openai.com/docs/guides/function-calling#strict-mode
            parameters: this.recursiveStrictToolCallParameters(tool.parameters),
            strict: true
        }));
        console.debug(`Converted ${tools.length} tools for Response API:`, converted.map(t => t.name));
        return converted;
    }

    // Strictifies a tool's JSON schema for the Response API. Per the unit
    // tests, recursiveStrictJSONSchema is copy-on-write and does not mutate
    // the input schema.
    recursiveStrictToolCallParameters(schema: ToolRequestParameters): FunctionTool['parameters'] {
        return recursiveStrictJSONSchema(schema) as FunctionTool['parameters'];
    }

    /**
     * Wraps a tool-free Response API event stream as a stream of
     * LanguageModelStreamResponsePart items: text deltas are forwarded,
     * token usage is recorded on completion, and API 'error' events are
     * re-thrown as Errors. Other event types are ignored.
     */
    protected createSimpleResponseApiStreamIterator(
        stream: AsyncIterable<ResponseStreamEvent>,
        requestId: string,
        modelId: string,
        tokenUsageService?: TokenUsageService,
        cancellationToken?: CancellationToken
    ): AsyncIterable<LanguageModelStreamResponsePart> {
        return {
            async *[Symbol.asyncIterator](): AsyncIterator<LanguageModelStreamResponsePart> {
                try {
                    for await (const event of stream) {
                        // Cancellation stops consumption between events.
                        if (cancellationToken?.isCancellationRequested) {
                            break;
                        }

                        if (event.type === 'response.output_text.delta') {
                            yield {
                                content: event.delta
                            };
                        } else if (event.type === 'response.completed') {
                            if (tokenUsageService && event.response?.usage) {
                                await tokenUsageService.recordTokenUsage(
                                    modelId,
                                    {
                                        inputTokens: event.response.usage.input_tokens,
                                        outputTokens: event.response.usage.output_tokens,
                                        requestId
                                    }
                                );
                            }
                        } else if (event.type === 'error') {
                            console.error('Response API error:', event.message);
                            throw new Error(`Response API error: ${event.message}`);
                        }
                    }
                } catch (error) {
                    console.error('Error in Response API stream:', error);
                    throw error;
                }
            }
        };
    }

    /**
     * Processes the provided list of messages by applying system message adjustments and converting
     * them directly to the format expected by the OpenAI Response API.
     *
     * This method converts messages directly without going through ChatCompletionMessageParam types.
     *
     * @param messages the list of messages to process.
     * @param developerMessageSettings how system and developer messages are handled during processing.
     * @param model the OpenAI model identifier. Currently not used, but allows subclasses to implement model-specific behavior.
     * @returns an object containing instructions and input formatted for the Response API.
     */
    processMessages(
        messages: LanguageModelMessage[],
        developerMessageSettings: DeveloperMessageSettings,
        model: string
    ): { instructions?: string; input: ResponseInputItem[] } {
        // Thinking messages are dropped up front; they have no Response API representation here.
        const processed = this.processSystemMessages(messages, developerMessageSettings)
            .filter(m => m.type !== 'thinking');

        // Extract system/developer messages for instructions
        const systemMessages = processed.filter((m): m is TextMessage => m.type === 'text' && m.actor === 'system');
        const instructions = systemMessages.length > 0
            ? systemMessages.map(m => m.text).join('\n')
            : undefined;

        // Convert non-system messages to Response API input items
        const nonSystemMessages = processed.filter(m => m.actor !== 'system');
        const input: ResponseInputItem[] = [];

        for (const message of nonSystemMessages) {
            if (LanguageModelMessage.isTextMessage(message)) {
                if (message.actor === 'ai') {
                    // Assistant messages use ResponseOutputMessage format
                    input.push({
                        // Synthetic id for the assistant output item (the unit tests
                        // only assert the 'msg_' prefix).
                        id: `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
                        type: 'message',
                        role: 'assistant',
                        status: 'completed',
                        content: [{
                            type: 'output_text',
                            text: message.text,
                            annotations: []
                        }]
                    });
                } else {
                    // User messages use input format
                    input.push({
                        type: 'message',
                        role: 'user',
                        content: [{
                            type: 'input_text',
                            text: message.text
                        }]
                    });
                }
            } else if (LanguageModelMessage.isToolUseMessage(message)) {
                input.push({
                    type: 'function_call',
                    call_id: message.id,
                    name: message.name,
                    arguments: JSON.stringify(message.input)
                });
            } else if (LanguageModelMessage.isToolResultMessage(message)) {
                // Non-string results are serialized so the API receives plain text.
                const content = typeof message.content === 'string' ? message.content : JSON.stringify(message.content);
                input.push({
                    type: 'function_call_output',
                    call_id: message.tool_use_id,
                    output: content
                });
            } else if (LanguageModelMessage.isImageMessage(message)) {
                input.push({
                    type: 'message',
                    role: 'user',
                    content: [{
                        type: 'input_image',
                        detail: 'auto',
                        image_url: ImageContent.isBase64(message.image) ?
                            `data:${message.image.mimeType};base64,${message.image.base64data}` :
                            message.image.url
                    }]
                });
            } else if (LanguageModelMessage.isThinkingMessage(message)) {
                // Pass — thinking messages were already filtered out above;
                // this branch is defensive and keeps the final else exhaustive.
            } else {
                unreachable(message);
            }
        }

        return { instructions, input };
    }

    // Overridable hook delegating to the module-level processSystemMessages
    // function, so subclasses can adjust system/developer message handling.
    protected processSystemMessages(
        messages: LanguageModelMessage[],
        developerMessageSettings: DeveloperMessageSettings
    ): LanguageModelMessage[] {
        return processSystemMessages(messages, developerMessageSettings);
    }
}
|
||||
|
||||
/**
|
||||
* Iterator for handling Response API streaming with tool calls.
|
||||
* Based on the pattern from openai-streaming-iterator.ts but adapted for Response API.
|
||||
*/
|
||||
class ResponseApiToolCallIterator implements AsyncIterableIterator<LanguageModelStreamResponsePart> {
|
||||
protected readonly requestQueue = new Array<Deferred<IteratorResult<LanguageModelStreamResponsePart>>>();
|
||||
protected readonly messageCache = new Array<LanguageModelStreamResponsePart>();
|
||||
protected done = false;
|
||||
protected terminalError: Error | undefined = undefined;
|
||||
|
||||
// Current iteration state
|
||||
protected currentInput: ResponseInputItem[];
|
||||
protected currentToolCalls = new Map<string, ToolCall>();
|
||||
protected totalInputTokens = 0;
|
||||
protected totalOutputTokens = 0;
|
||||
protected iteration = 0;
|
||||
protected readonly maxIterations: number;
|
||||
protected readonly tools: FunctionTool[] | undefined;
|
||||
protected readonly instructions?: string;
|
||||
protected currentResponseText = '';
|
||||
|
||||
    /**
     * Captures the request state, derives the initial Response API input and
     * tool definitions once, then immediately kicks off the first iteration.
     */
    constructor(
        protected readonly openai: OpenAI,
        protected readonly request: UserRequest,
        protected readonly settings: Record<string, unknown>,
        protected readonly model: string,
        protected readonly modelUtils: OpenAiModelUtils,
        protected readonly developerMessageSettings: DeveloperMessageSettings,
        protected readonly runnerOptions: RunnerOptions,
        protected readonly modelId: string,
        protected readonly utils: OpenAiResponseApiUtils,
        protected readonly isStreaming: boolean,
        protected readonly tokenUsageService?: TokenUsageService,
        protected readonly cancellationToken?: CancellationToken
    ) {
        const { instructions, input } = utils.processMessages(request.messages, developerMessageSettings, model);
        this.instructions = instructions;
        this.currentInput = input;
        this.tools = utils.convertToolsForResponseApi(request.tools);
        // Safety cap on request/tool-execution cycles.
        this.maxIterations = runnerOptions.maxChatCompletions || 100;

        // Start the first iteration
        // NOTE(review): fire-and-forget async call from the constructor —
        // startIteration catches failures into terminalError instead of
        // producing an unhandled rejection here.
        this.startIteration();
    }
|
||||
|
||||
    // The iterator is its own iterable, so it can be returned directly as
    // the `stream` of a LanguageModelResponse.
    [Symbol.asyncIterator](): AsyncIterableIterator<LanguageModelStreamResponsePart> {
        return this;
    }
|
||||
|
||||
    /**
     * Delivers the next stream part. Buffered messages in `messageCache` are
     * drained first (even after termination); then a stored terminal error is
     * thrown; then completion is reported; otherwise the caller is parked on
     * a Deferred in `requestQueue` that the producer resolves later.
     * Invariant: the cache and the queue are never both non-empty.
     */
    async next(): Promise<IteratorResult<LanguageModelStreamResponsePart>> {
        if (this.messageCache.length && this.requestQueue.length) {
            throw new Error('Assertion error: cache and queue should not both be populated.');
        }

        // Deliver all the messages we got, even if we've since terminated.
        if (this.messageCache.length) {
            return {
                done: false,
                value: this.messageCache.shift()!
            };
        } else if (this.terminalError) {
            throw this.terminalError;
        } else if (this.done) {
            return {
                done: true,
                value: undefined
            };
        } else {
            const deferred = new Deferred<IteratorResult<LanguageModelStreamResponsePart>>();
            this.requestQueue.push(deferred);
            return deferred.promise;
        }
    }
|
||||
|
||||
    /**
     * Drives the request → tool-execution loop until the model stops issuing
     * tool calls, cancellation is requested, or `maxIterations` is reached.
     * Any error is stored in `terminalError` and surfaced to consumers via
     * next(); finalize() is always invoked on exit.
     */
    protected async startIteration(): Promise<void> {
        try {
            while (this.iteration < this.maxIterations && !this.cancellationToken?.isCancellationRequested) {
                console.debug(`Starting Response API iteration ${this.iteration} with ${this.currentInput.length} input messages`);

                await this.processStream();

                // Check if we have tool calls that need execution
                if (this.currentToolCalls.size === 0) {
                    // No tool calls, we're done
                    this.finalize();
                    return;
                }

                // Execute all tool calls
                await this.executeToolCalls();

                // Prepare for next iteration
                this.prepareNextIteration();
                this.iteration++;
            }

            // Max iterations reached
            this.finalize();
        } catch (error) {
            this.terminalError = error instanceof Error ? error : new Error(String(error));
            this.finalize();
        }
    }
|
||||
|
||||
    /**
     * Performs one API round-trip, resetting the per-iteration state first.
     * Streaming mode consumes events as they arrive (honoring cancellation
     * between events); non-streaming mode fetches the complete response and
     * replays it through the same incremental handlers.
     */
    protected async processStream(): Promise<void> {
        this.currentToolCalls.clear();
        this.currentResponseText = '';

        if (this.isStreaming) {
            // Use streaming API
            const stream = this.openai.responses.stream({
                model: this.model as ResponsesModel,
                instructions: this.instructions,
                input: this.currentInput,
                tools: this.tools,
                ...this.settings
            });

            for await (const event of stream) {
                if (this.cancellationToken?.isCancellationRequested) {
                    break;
                }
                await this.handleStreamEvent(event);
            }
        } else {
            // Use non-streaming API but yield results incrementally
            await this.processNonStreamingResponse();
        }
    }
|
||||
|
||||
    /**
     * Non-streaming variant of one round-trip: fetches the full response,
     * accumulates token usage, forwards the output text as a single part,
     * then registers and announces every function call from the output.
     *
     * NOTE(review): function-call items lacking `id` or `name` are silently
     * skipped — confirm the Response API always populates both for
     * non-streaming responses.
     */
    protected async processNonStreamingResponse(): Promise<void> {
        const response = await this.openai.responses.create({
            model: this.model as ResponsesModel,
            instructions: this.instructions,
            input: this.currentInput,
            tools: this.tools,
            ...this.settings
        });

        // Record token usage
        if (response.usage) {
            this.totalInputTokens += response.usage.input_tokens;
            this.totalOutputTokens += response.usage.output_tokens;
        }

        // First, yield any text content from the response
        this.currentResponseText = response.output_text || '';
        if (this.currentResponseText) {
            this.handleIncoming({ content: this.currentResponseText });
        }

        // Find function calls in the response
        const functionCalls = response.output?.filter((item): item is ResponseFunctionToolCall => item.type === 'function_call') || [];

        // Process each function call
        for (const functionCall of functionCalls) {
            if (functionCall.id && functionCall.name) {
                const toolCall: ToolCall = {
                    id: functionCall.id,
                    // Fall back to the item id when the API did not supply a call_id.
                    call_id: functionCall.call_id || functionCall.id,
                    name: functionCall.name,
                    arguments: functionCall.arguments || '',
                    executed: false
                };

                this.currentToolCalls.set(functionCall.id, toolCall);

                // Yield the tool call initiation
                this.handleIncoming({
                    tool_calls: [{
                        id: functionCall.id,
                        finished: false,
                        function: {
                            name: functionCall.name,
                            arguments: functionCall.arguments || ''
                        }
                    }]
                });
            }
        }
    }
|
||||
|
||||
    /**
     * Dispatches one Response API stream event: text deltas are forwarded
     * immediately, function-call events build up entries in
     * `currentToolCalls`, token usage accumulates on completion, and API
     * 'error' events abort the run by throwing. Unlisted event types are
     * ignored.
     */
    protected async handleStreamEvent(event: ResponseStreamEvent): Promise<void> {
        switch (event.type) {
            case 'response.output_text.delta':
                this.currentResponseText += event.delta;
                this.handleIncoming({ content: event.delta });
                break;

            case 'response.output_item.added':
                if (event.item?.type === 'function_call') {
                    this.handleFunctionCallAdded(event.item);
                }
                break;

            case 'response.function_call_arguments.delta':
                this.handleFunctionCallArgsDelta(event);
                break;

            case 'response.function_call_arguments.done':
                await this.handleFunctionCallArgsDone(event);
                break;

            case 'response.output_item.done':
                if (event.item?.type === 'function_call') {
                    this.handleFunctionCallDone(event.item);
                }
                break;

            case 'response.completed':
                if (event.response?.usage) {
                    this.totalInputTokens += event.response.usage.input_tokens;
                    this.totalOutputTokens += event.response.usage.output_tokens;
                }
                break;

            case 'error':
                console.error('Response API error:', event.message);
                throw new Error(`Response API error: ${event.message}`);
        }
    }
|
||||
|
||||
    /**
     * Registers a newly announced function call and emits an unfinished
     * tool-call part downstream. Items missing `id` or `call_id` are ignored
     * here; handleFunctionCallArgsDone creates the entry late in that case.
     */
    protected handleFunctionCallAdded(functionCall: ResponseFunctionToolCall): void {
        if (functionCall.id && functionCall.call_id) {
            console.debug(`Function call added: ${functionCall.name} with id ${functionCall.id} and call_id ${functionCall.call_id}`);

            const toolCall: ToolCall = {
                id: functionCall.id,
                call_id: functionCall.call_id,
                name: functionCall.name || '',
                arguments: functionCall.arguments || '',
                executed: false
            };

            this.currentToolCalls.set(functionCall.id, toolCall);

            this.handleIncoming({
                tool_calls: [{
                    id: functionCall.id,
                    finished: false,
                    function: {
                        name: functionCall.name || '',
                        arguments: functionCall.arguments || ''
                    }
                }]
            });
        }
    }
|
||||
|
||||
protected handleFunctionCallArgsDelta(event: ResponseFunctionCallArgumentsDeltaEvent): void {
|
||||
const toolCall = this.currentToolCalls.get(event.item_id);
|
||||
if (toolCall) {
|
||||
toolCall.arguments += event.delta;
|
||||
|
||||
if (event.delta) {
|
||||
this.handleIncoming({
|
||||
tool_calls: [{
|
||||
id: event.item_id,
|
||||
argumentsDelta: true,
|
||||
function: {
|
||||
arguments: event.delta
|
||||
}
|
||||
}]
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
    /**
     * Applies the final name/arguments of a function call. If no entry exists
     * yet (the 'added' event was never seen, e.g. because it lacked a
     * call_id), the entry is created late and announced downstream.
     */
    protected async handleFunctionCallArgsDone(event: ResponseFunctionCallArgumentsDoneEvent): Promise<void> {
        let toolCall = this.currentToolCalls.get(event.item_id);
        if (!toolCall) {
            // Create if we didn't see the added event
            toolCall = {
                id: event.item_id,
                name: event.name || '',
                arguments: event.arguments || '',
                executed: false
            };
            this.currentToolCalls.set(event.item_id, toolCall);

            this.handleIncoming({
                tool_calls: [{
                    id: event.item_id,
                    finished: false,
                    function: {
                        name: event.name || '',
                        arguments: event.arguments || ''
                    }
                }]
            });
        } else {
            // Update with final values
            toolCall.name = event.name || toolCall.name;
            toolCall.arguments = event.arguments || toolCall.arguments;
        }
    }
|
||||
|
||||
protected handleFunctionCallDone(functionCall: ResponseFunctionToolCall): void {
|
||||
if (!functionCall.id) { console.warn('Unexpected absence of ID for call ID', functionCall.call_id); return; }
|
||||
const toolCall = this.currentToolCalls.get(functionCall.id);
|
||||
if (toolCall && !toolCall.call_id && functionCall.call_id) {
|
||||
toolCall.call_id = functionCall.call_id;
|
||||
}
|
||||
}
|
||||
|
||||
    /**
     * Executes every pending tool call sequentially, in map-insertion order.
     * Handler failures and unknown tool names are never thrown: they are
     * recorded on the ToolCall and emitted downstream as finished tool-call
     * parts carrying a createToolCallError result. Each call is marked
     * `executed` exactly once, so re-entry skips completed calls.
     */
    protected async executeToolCalls(): Promise<void> {
        for (const [itemId, toolCall] of this.currentToolCalls) {
            if (toolCall.executed) {
                continue;
            }

            const tool = this.request.tools?.find(t => t.name === toolCall.name);
            if (tool) {
                try {
                    // Arguments are passed as the raw JSON string accumulated from the stream.
                    const result = await tool.handler(toolCall.arguments, ToolInvocationContext.create(itemId));
                    toolCall.result = result;

                    // Yield the tool call completion
                    this.handleIncoming({
                        tool_calls: [{
                            id: itemId,
                            finished: true,
                            function: {
                                name: toolCall.name,
                                arguments: toolCall.arguments
                            },
                            result
                        }]
                    });
                } catch (error) {
                    console.error(`Error executing tool ${toolCall.name}:`, error);
                    toolCall.error = error instanceof Error ? error : new Error(String(error));

                    // Yield the tool call error
                    this.handleIncoming({
                        tool_calls: [{
                            id: itemId,
                            finished: true,
                            function: {
                                name: toolCall.name,
                                arguments: toolCall.arguments
                            },
                            result: createToolCallError(error instanceof Error ? error.message : String(error))
                        }]
                    });
                }
            } else {
                console.warn(`Tool ${toolCall.name} not found in request tools`);
                toolCall.error = new Error(`Tool ${toolCall.name} not found`);

                // Yield the tool call error
                this.handleIncoming({
                    tool_calls: [{
                        id: itemId,
                        finished: true,
                        function: {
                            name: toolCall.name,
                            arguments: toolCall.arguments
                        },
                        result: createToolCallError(`Tool '${toolCall.name}' not found in the available tools for this request.`, 'tool-not-available')
                    }]
                });
            }

            toolCall.executed = true;
        }
    }
|
||||
|
||||
protected prepareNextIteration(): void {
|
||||
// Add assistant response with the actual text that was streamed
|
||||
const assistantMessage: ResponseInputItem = {
|
||||
role: 'assistant',
|
||||
content: this.currentResponseText
|
||||
};
|
||||
|
||||
// Add the function calls that were made by the assistant
|
||||
const functionCalls: ResponseInputItem[] = [];
|
||||
for (const [itemId, toolCall] of this.currentToolCalls) {
|
||||
functionCalls.push({
|
||||
type: 'function_call',
|
||||
call_id: toolCall.call_id || itemId,
|
||||
name: toolCall.name,
|
||||
arguments: toolCall.arguments
|
||||
});
|
||||
}
|
||||
|
||||
// Add tool results
|
||||
const toolResults: ResponseInputItem[] = [];
|
||||
for (const [itemId, toolCall] of this.currentToolCalls) {
|
||||
const callId = toolCall.call_id || itemId;
|
||||
|
||||
if (toolCall.result !== undefined) {
|
||||
const resultContent = typeof toolCall.result === 'string' ? toolCall.result : JSON.stringify(toolCall.result);
|
||||
toolResults.push({
|
||||
type: 'function_call_output',
|
||||
call_id: callId,
|
||||
output: resultContent
|
||||
});
|
||||
} else if (toolCall.error) {
|
||||
toolResults.push({
|
||||
type: 'function_call_output',
|
||||
call_id: callId,
|
||||
output: `Error: ${toolCall.error.message}`
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this.currentInput = [...this.currentInput, assistantMessage, ...functionCalls, ...toolResults];
|
||||
}
|
||||
|
||||
/**
 * Routes an incoming stream part either to the oldest waiting consumer or,
 * when nobody is waiting, into the message cache for later delivery.
 * Invariant: the cache and the request queue are never both non-empty.
 */
protected handleIncoming(message: LanguageModelStreamResponsePart): void {
    if (this.messageCache.length && this.requestQueue.length) {
        throw new Error('Assertion error: cache and queue should not both be populated.');
    }
    const waiting = this.requestQueue.shift();
    if (waiting) {
        waiting.resolve({ done: false, value: message });
    } else {
        this.messageCache.push(message);
    }
}
|
||||
|
||||
protected async finalize(): Promise<void> {
|
||||
this.done = true;
|
||||
|
||||
// Record final token usage
|
||||
if (this.tokenUsageService && (this.totalInputTokens > 0 || this.totalOutputTokens > 0)) {
|
||||
try {
|
||||
await this.tokenUsageService.recordTokenUsage(
|
||||
this.modelId,
|
||||
{
|
||||
inputTokens: this.totalInputTokens,
|
||||
outputTokens: this.totalOutputTokens,
|
||||
requestId: this.request.requestId
|
||||
}
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Error recording token usage:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve any outstanding requests
|
||||
if (this.terminalError) {
|
||||
this.requestQueue.forEach(request => request.reject(this.terminalError));
|
||||
} else {
|
||||
this.requestQueue.forEach(request => request.resolve({ done: true, value: undefined }));
|
||||
}
|
||||
this.requestQueue.length = 0;
|
||||
}
|
||||
}
|
||||
|
||||
export function processSystemMessages(
|
||||
messages: LanguageModelMessage[],
|
||||
developerMessageSettings: DeveloperMessageSettings
|
||||
): LanguageModelMessage[] {
|
||||
if (developerMessageSettings === 'skip') {
|
||||
return messages.filter(message => message.actor !== 'system');
|
||||
} else if (developerMessageSettings === 'mergeWithFollowingUserMessage') {
|
||||
const updated = messages.slice();
|
||||
for (let i = updated.length - 1; i >= 0; i--) {
|
||||
if (updated[i].actor === 'system') {
|
||||
const systemMessage = updated[i] as TextMessage;
|
||||
if (i + 1 < updated.length && updated[i + 1].actor === 'user') {
|
||||
// Merge system message with the next user message
|
||||
const userMessage = updated[i + 1] as TextMessage;
|
||||
updated[i + 1] = {
|
||||
...updated[i + 1],
|
||||
text: systemMessage.text + '\n' + userMessage.text
|
||||
} as TextMessage;
|
||||
updated.splice(i, 1);
|
||||
} else {
|
||||
// The message directly after is not a user message (or none exists), so create a new user message right after
|
||||
updated.splice(i + 1, 0, { actor: 'user', type: 'text', text: systemMessage.text });
|
||||
updated.splice(i, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return updated;
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
export function recursiveStrictJSONSchema(schema: JSONSchemaDefinition): JSONSchemaDefinition {
|
||||
if (typeof schema === 'boolean') { return schema; }
|
||||
let result: JSONSchema | undefined = undefined;
|
||||
if (schema.properties) {
|
||||
result ??= { ...schema };
|
||||
result.additionalProperties = false;
|
||||
result.required = Object.keys(schema.properties);
|
||||
result.properties = Object.fromEntries(Object.entries(schema.properties).map(([key, props]) => [key, recursiveStrictJSONSchema(props)]));
|
||||
}
|
||||
if (schema.items) {
|
||||
result ??= { ...schema };
|
||||
result.items = Array.isArray(schema.items)
|
||||
? schema.items.map(recursiveStrictJSONSchema)
|
||||
: recursiveStrictJSONSchema(schema.items);
|
||||
}
|
||||
if (schema.oneOf) {
|
||||
result ??= { ...schema };
|
||||
result.oneOf = schema.oneOf.map(recursiveStrictJSONSchema);
|
||||
}
|
||||
if (schema.anyOf) {
|
||||
result ??= { ...schema };
|
||||
result.anyOf = schema.anyOf.map(recursiveStrictJSONSchema);
|
||||
}
|
||||
if (schema.allOf) {
|
||||
result ??= { ...schema };
|
||||
result.allOf = schema.allOf.map(recursiveStrictJSONSchema);
|
||||
}
|
||||
if (schema.if) {
|
||||
result ??= { ...schema };
|
||||
result.if = recursiveStrictJSONSchema(schema.if);
|
||||
}
|
||||
if (schema.then) {
|
||||
result ??= { ...schema };
|
||||
result.then = recursiveStrictJSONSchema(schema.then);
|
||||
}
|
||||
if (schema.else) {
|
||||
result ??= { ...schema };
|
||||
result.else = recursiveStrictJSONSchema(schema.else);
|
||||
}
|
||||
if (schema.not) {
|
||||
result ??= { ...schema };
|
||||
result.not = recursiveStrictJSONSchema(schema.not);
|
||||
}
|
||||
|
||||
return result ?? schema;
|
||||
}
|
||||
255
packages/ai-openai/src/node/openai-streaming-iterator.spec.ts
Normal file
255
packages/ai-openai/src/node/openai-streaming-iterator.spec.ts
Normal file
@@ -0,0 +1,255 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2025 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { expect } from 'chai';
|
||||
import * as sinon from 'sinon';
|
||||
import { StreamingAsyncIterator } from './openai-streaming-iterator';
|
||||
import { ChatCompletionStream } from 'openai/lib/ChatCompletionStream';
|
||||
import { CancellationTokenSource, CancellationError } from '@theia/core';
|
||||
import { LanguageModelStreamResponsePart, isTextResponsePart, isToolCallResponsePart } from '@theia/ai-core';
|
||||
import { EventEmitter } from 'events';
|
||||
import { ChatCompletionToolMessageParam } from 'openai/resources';
|
||||
|
||||
// Unit tests for StreamingAsyncIterator. The OpenAI stream is mocked as a bare
// EventEmitter so each test can hand-drive the 'chunk' / 'message' / 'end' /
// 'error' / 'abort' events that the iterator subscribes to.
describe('StreamingAsyncIterator', () => {
    let mockStream: ChatCompletionStream & EventEmitter;
    let iterator: StreamingAsyncIterator;
    let cts: CancellationTokenSource;
    // Saved so afterEach can restore console.error after tests that silence it.
    const consoleError = console.error;

    beforeEach(() => {
        mockStream = new EventEmitter() as ChatCompletionStream & EventEmitter;
        // abort() is stubbed so the cancellation test can assert it was invoked.
        mockStream.abort = sinon.stub();

        cts = new CancellationTokenSource();
    });

    afterEach(() => {
        if (iterator) {
            iterator.dispose();
        }
        cts.dispose();
        console.error = consoleError;
    });

    // Creates the iterator under test; the requestId is irrelevant here ('').
    function createIterator(withCancellationToken = false): StreamingAsyncIterator {
        return new StreamingAsyncIterator(mockStream, '', withCancellationToken ? cts.token : undefined);
    }

    it('should yield messages in the correct order when consumed immediately', async () => {
        iterator = createIterator();

        // Emit asynchronously so the consumer loop below is already awaiting next().
        setTimeout(() => {
            mockStream.emit('chunk', { choices: [{ delta: { content: 'Hello' } }] });
            mockStream.emit('chunk', { choices: [{ delta: { content: ' ' } }] });
            mockStream.emit('chunk', { choices: [{ delta: { content: 'World' } }] });
            mockStream.emit('end');
        }, 10);

        const results: LanguageModelStreamResponsePart[] = [];

        while (true) {
            const { value, done } = await iterator.next();
            if (done) {
                break;
            }
            results.push(value);
        }

        expect(results).to.deep.equal([
            { content: 'Hello' },
            { content: ' ' },
            { content: 'World' }
        ]);
    });

    it('should buffer messages if consumer is slower (messages arrive before .next() is called)', async () => {
        iterator = createIterator();

        // All events fire before the first next() call, so parts must be cached.
        mockStream.emit('chunk', { choices: [{ delta: { content: 'A' } }] });
        mockStream.emit('chunk', { choices: [{ delta: { content: 'B' } }] });
        mockStream.emit('chunk', { choices: [{ delta: { content: 'C' } }] });
        mockStream.emit('end');

        const results: string[] = [];
        while (true) {
            const { value, done } = await iterator.next();
            if (done) {
                break;
            }
            results.push((isTextResponsePart(value) && value.content) || '');
        }

        expect(results).to.deep.equal(['A', 'B', 'C']);
    });

    it('should resolve queued next() call when a message arrives (consumer is waiting first)', async () => {
        iterator = createIterator();

        // next() is called before any event, so the promise must be queued.
        const nextPromise = iterator.next();

        setTimeout(() => {
            mockStream.emit('chunk', { choices: [{ delta: { content: 'Hello from queue' } }] });
            mockStream.emit('end');
        }, 10);

        const first = await nextPromise;
        expect(first.done).to.be.false;
        expect(first.value.content).to.equal('Hello from queue');

        // After 'end', the iterator reports completion.
        const second = await iterator.next();
        expect(second.done).to.be.true;
        expect(second.value).to.be.undefined;
    });

    it('should handle the end event correctly', async () => {
        iterator = createIterator();

        mockStream.emit('chunk', { choices: [{ delta: { content: 'EndTest1' } }] });
        mockStream.emit('chunk', { choices: [{ delta: { content: 'EndTest2' } }] });
        mockStream.emit('end');

        // Parts emitted before 'end' must still all be delivered.
        const results: string[] = [];
        while (true) {
            const { value, done } = await iterator.next();
            if (done) {
                break;
            }
            results.push((isTextResponsePart(value) && value.content) || '');
        }

        expect(results).to.deep.equal(['EndTest1', 'EndTest2']);
    });

    it('should reject pending .next() call with an error if error event occurs', async () => {
        iterator = createIterator();

        const pendingNext = iterator.next();

        // Suppress console.error output
        console.error = () => { };

        const error = new Error('Stream error occurred');
        mockStream.emit('error', error);

        try {
            await pendingNext;
            expect.fail('The promise should have been rejected with an error.');
        } catch (err) {
            // The very error emitted by the stream must surface to the consumer.
            expect(err).to.equal(error);
        }
    });

    it('should reject pending .next() call with a CancellationError if "abort" event occurs', async () => {
        iterator = createIterator();

        const pendingNext = iterator.next();

        // Suppress console.error output
        console.error = () => { };

        mockStream.emit('abort');

        try {
            await pendingNext;
            expect.fail('The promise should have been rejected with a CancellationError.');
        } catch (err) {
            expect(err).to.be.instanceOf(CancellationError);
        }
    });

    it('should call stream.abort() when cancellation token is triggered', async () => {
        // true: wire the iterator to the cancellation token created in beforeEach.
        iterator = createIterator(true);

        cts.cancel();

        sinon.assert.calledOnce(mockStream.abort as sinon.SinonSpy);
    });

    it('should not lose unconsumed messages after disposal, but no new ones arrive', async () => {
        iterator = createIterator();

        mockStream.emit('chunk', { choices: [{ delta: { content: 'Msg1' } }] });
        mockStream.emit('chunk', { choices: [{ delta: { content: 'Msg2' } }] });

        // Disposal must not drop parts that were already cached.
        iterator.dispose();

        let result = await iterator.next();
        expect(result.done).to.be.false;
        expect(result.value.content).to.equal('Msg1');

        result = await iterator.next();
        expect(result.done).to.be.false;
        expect(result.value.content).to.equal('Msg2');

        result = await iterator.next();
        expect(result.done).to.be.true;
        expect(result.value).to.be.undefined;
    });

    it('should reject all pending requests with an error if disposal occurs after stream error', async () => {
        iterator = createIterator();

        // Two consumers queued before the failure: both must be rejected.
        const pendingNext1 = iterator.next();
        const pendingNext2 = iterator.next();

        // Suppress console.error output
        console.error = () => { };

        const error = new Error('Critical error');
        mockStream.emit('error', error);

        try {
            await pendingNext1;
            expect.fail('expected to be rejected');
        } catch (err) {
            expect(err).to.equal(error);
        }

        try {
            await pendingNext2;
            expect.fail('expected to be rejected');
        } catch (err) {
            expect(err).to.equal(error);
        }
    });

    it('should handle receiving a "message" event with role="tool"', async () => {
        iterator = createIterator();

        setTimeout(() => {
            mockStream.emit('message', {
                role: 'tool',
                tool_call_id: 'tool-123',
                content: [{ type: 'text', text: 'Part1' }, { type: 'text', text: 'Part2' }]
            } satisfies ChatCompletionToolMessageParam);
            mockStream.emit('end');
        }, 10);

        const results: LanguageModelStreamResponsePart[] = [];
        for await (const part of iterator) {
            results.push(part);
        }

        // A tool message is converted into a single finished tool_calls part.
        expect(results).to.have.lengthOf(1);
        expect(isToolCallResponsePart(results[0]) && results[0].tool_calls).to.deep.equal([
            {
                id: 'tool-123',
                finished: true,
                result: { content: [{ type: 'text', text: 'Part1' }, { type: 'text', text: 'Part2' }] }
            }
        ]);
    });
});
|
||||
186
packages/ai-openai/src/node/openai-streaming-iterator.ts
Normal file
186
packages/ai-openai/src/node/openai-streaming-iterator.ts
Normal file
@@ -0,0 +1,186 @@
|
||||
// *****************************************************************************
|
||||
// Copyright (C) 2025 EclipseSource GmbH.
|
||||
//
|
||||
// This program and the accompanying materials are made available under the
|
||||
// terms of the Eclipse Public License v. 2.0 which is available at
|
||||
// http://www.eclipse.org/legal/epl-2.0.
|
||||
//
|
||||
// This Source Code may also be made available under the following Secondary
|
||||
// Licenses when the conditions for such availability set forth in the Eclipse
|
||||
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
|
||||
// with the GNU Classpath Exception which is available at
|
||||
// https://www.gnu.org/software/classpath/license.html.
|
||||
//
|
||||
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
|
||||
// *****************************************************************************
|
||||
|
||||
import { LanguageModelStreamResponsePart, TokenUsageService, TokenUsageParams, ToolCallResult, ToolCallTextResult } from '@theia/ai-core';
|
||||
import { CancellationError, CancellationToken, Disposable, DisposableCollection } from '@theia/core';
|
||||
import { Deferred } from '@theia/core/lib/common/promise-util';
|
||||
import { ChatCompletionStream, ChatCompletionStreamEvents } from 'openai/lib/ChatCompletionStream';
|
||||
import { ChatCompletionContentPartText } from 'openai/resources';
|
||||
|
||||
/** Shorthand for the iterator results produced by {@link StreamingAsyncIterator}. */
type IterResult = IteratorResult<LanguageModelStreamResponsePart>;

/**
 * Adapts an OpenAI {@link ChatCompletionStream} (an event emitter) into an
 * async iterator of {@link LanguageModelStreamResponsePart}s.
 *
 * Producer/consumer bridging: parts arriving while no consumer is waiting are
 * buffered in `messageCache`; `next()` calls arriving while no part is
 * available are queued in `requestQueue`. By construction at most one of the
 * two is ever non-empty (asserted in `next()` and `handleIncoming()`).
 */
export class StreamingAsyncIterator implements AsyncIterableIterator<LanguageModelStreamResponsePart>, Disposable {
    // Consumers awaiting a part that has not arrived yet.
    protected readonly requestQueue = new Array<Deferred<IterResult>>();
    // Parts that arrived before any consumer asked for them.
    protected readonly messageCache = new Array<LanguageModelStreamResponsePart>();
    // Set once the stream has ended, errored, aborted, or been disposed.
    protected done = false;
    // Error (or CancellationError) that fails subsequent next() calls once the cache is drained.
    protected terminalError: Error | undefined = undefined;
    protected readonly toDispose = new DisposableCollection();

    /**
     * @param stream the OpenAI chat completion stream to consume
     * @param requestId identifier reported with recorded token usage
     * @param cancellationToken optional token; cancelling it aborts the stream
     * @param tokenUsageService optional sink for token accounting (used only together with `model`)
     * @param model model identifier under which token usage is recorded
     */
    constructor(
        protected readonly stream: ChatCompletionStream,
        protected readonly requestId: string,
        cancellationToken?: CancellationToken,
        protected readonly tokenUsageService?: TokenUsageService,
        protected readonly model?: string,
    ) {
        // A stream error terminates the iterator and is re-thrown to consumers.
        this.registerStreamListener('error', error => {
            console.error('Error in OpenAI chat completion stream:', error);
            this.terminalError = error;
            this.dispose();
        });
        // Abort surfaces as a CancellationError; 'once' since it can only happen at the end.
        this.registerStreamListener('abort', () => {
            this.terminalError = new CancellationError();
            this.dispose();
        }, true);
        // Completed tool messages are forwarded as finished tool_calls parts.
        this.registerStreamListener('message', message => {
            if (message.role === 'tool') {
                this.handleIncoming({
                    tool_calls: [{
                        id: message.tool_call_id,
                        finished: true,
                        result: tryParseToolResult(message.content)
                    }]
                });
            }
            console.debug('Received Open AI message', JSON.stringify(message));
        });
        this.registerStreamListener('end', () => {
            this.dispose();
        }, true);
        this.registerStreamListener('chunk', (chunk, snapshot) => {
            // Handle token usage reporting
            if (chunk.usage && this.tokenUsageService && this.model) {
                const inputTokens = chunk.usage.prompt_tokens || 0;
                const outputTokens = chunk.usage.completion_tokens || 0;
                if (inputTokens > 0 || outputTokens > 0) {
                    const tokenUsageParams: TokenUsageParams = {
                        inputTokens,
                        outputTokens,
                        requestId
                    };
                    // Fire-and-forget: recording failures must not break streaming.
                    this.tokenUsageService.recordTokenUsage(this.model, tokenUsageParams)
                        .catch(error => console.error('Error recording token usage:', error));
                }
            }
            // Patch missing fields that OpenAI SDK requires but some providers (e.g., Copilot) don't send
            for (const choice of snapshot?.choices ?? []) {
                // Ensure role is set (required by finalizeChatCompletion)
                if (choice?.message && !choice.message.role) {
                    choice.message.role = 'assistant';
                }
                // Ensure tool_calls have type set (required by #emitToolCallDoneEvent and finalizeChatCompletion)
                if (choice?.message?.tool_calls) {
                    for (const call of choice.message.tool_calls) {
                        if (call.type === undefined) {
                            call.type = 'function';
                        }
                    }
                }
            }
            // OpenAI can push out reasoning tokens, but can't handle it as part of messages
            if (snapshot?.choices[0]?.message && Object.keys(snapshot.choices[0].message).includes('reasoning')) {
                const reasoning = (snapshot.choices[0].message as { reasoning: string }).reasoning;
                this.handleIncoming({ thought: reasoning, signature: '' });
                // delete message parts which cannot be handled by openai
                delete (snapshot.choices[0].message as { reasoning?: string }).reasoning;
                delete (snapshot.choices[0].message as { channel?: string }).channel;
                return;
            }
            // Forward the raw delta as a stream part (NOTE(review): an empty
            // 'choices' array yields an empty part object here — presumably harmless; verify).
            this.handleIncoming({ ...chunk.choices[0]?.delta as LanguageModelStreamResponsePart });
        });
        if (cancellationToken) {
            this.toDispose.push(cancellationToken.onCancellationRequested(() => stream.abort()));
        }
    }

    [Symbol.asyncIterator](): AsyncIterableIterator<LanguageModelStreamResponsePart> { return this; }

    /**
     * Returns the next stream part: cached parts first (even after
     * termination), then the terminal error (rejected), then done, otherwise
     * queues a deferred until a part arrives.
     */
    next(): Promise<IterResult> {
        if (this.messageCache.length && this.requestQueue.length) {
            throw new Error('Assertion error: cache and queue should not both be populated.');
        }
        // Deliver all the messages we got, even if we've since terminated.
        if (this.messageCache.length) {
            return Promise.resolve({
                done: false,
                value: this.messageCache.shift()!
            });
        } else if (this.terminalError) {
            return Promise.reject(this.terminalError);
        } else if (this.done) {
            return Promise.resolve({
                done: true,
                value: undefined
            });
        } else {
            const toQueue = new Deferred<IterResult>();
            this.requestQueue.push(toQueue);
            return toQueue.promise;
        }
    }

    /**
     * Delivers an incoming part to the oldest waiting consumer, or caches it
     * when no consumer is waiting.
     */
    protected handleIncoming(message: LanguageModelStreamResponsePart): void {
        if (this.messageCache.length && this.requestQueue.length) {
            throw new Error('Assertion error: cache and queue should not both be populated.');
        }
        if (this.requestQueue.length) {
            this.requestQueue.shift()!.resolve({
                done: false,
                value: message
            });
        } else {
            this.messageCache.push(message);
        }
    }

    /**
     * Subscribes `handler` to `eventType` on the stream ('once' optional) and
     * registers the unsubscription with `toDispose`.
     */
    protected registerStreamListener<Event extends keyof ChatCompletionStreamEvents>(eventType: Event, handler: ChatCompletionStreamEvents[Event], once?: boolean): void {
        if (once) {
            this.stream.once(eventType, handler);
        } else {
            this.stream.on(eventType, handler);
        }
        this.toDispose.push({ dispose: () => this.stream.off(eventType, handler) });
    }

    /**
     * Detaches all stream listeners and settles every outstanding next()
     * request (reject on terminal error, resolve done otherwise). Cached,
     * not-yet-consumed parts remain deliverable after disposal.
     */
    dispose(): void {
        this.done = true;
        this.toDispose.dispose();
        // We will be receiving no more messages. Any outstanding requests have to be handled.
        if (this.terminalError) {
            this.requestQueue.forEach(request => request.reject(this.terminalError));
        } else {
            this.requestQueue.forEach(request => request.resolve({ done: true, value: undefined }));
        }
        // Leave the message cache alone - if it was populated, then the request queue was empty, but we'll still try to deliver the messages if asked.
        this.requestQueue.length = 0;
    }
}
|
||||
|
||||
function tryParseToolResult(result: string | ChatCompletionContentPartText[]): ToolCallResult {
|
||||
try {
|
||||
if (typeof result === 'string') {
|
||||
return JSON.parse(result);
|
||||
}
|
||||
return {
|
||||
content: result.map<ToolCallTextResult>(part => ({
|
||||
type: 'text',
|
||||
text: part.text
|
||||
}))
|
||||
};
|
||||
} catch (error) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user