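How to create a custom chat model class
Prerequisites
This guide assumes familiarity with the following concepts:
This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.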
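After extending the SimpleChatModel class, a chat model must implement the following:
- A _call method that takes a list of messages and call options (which include things like stop sequences) and returns a string.
- A _llmType method that returns a string. It is used for logging purposes only.
You can also implement the following optional method:
- A _streamResponseChunks method that returns an AsyncGenerator and yields ChatGenerationChunks one at a time. This allows the LLM to support streaming output.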
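The stop sequences mentioned above arrive in the call options, and it is up to _call to honor them. The sketch below shows one way to do that in plain TypeScript. `truncateAtStop` is a hypothetical helper, not part of LangChain, and it assumes the stop sequences are available as a string array (for example `options.stop`).

```typescript
// Hypothetical helper (not a LangChain API): cut `text` at the earliest
// occurrence of any stop sequence, excluding the stop sequence itself.
function truncateAtStop(text: string, stop: string[] = []): string {
  let cutoff = text.length;
  for (const sequence of stop) {
    // An empty stop sequence would match at index 0, so skip it.
    if (sequence.length === 0) continue;
    const index = text.indexOf(sequence);
    if (index !== -1 && index < cutoff) {
      cutoff = index;
    }
  }
  return text.slice(0, cutoff);
}
```

Inside _call, the generated string could then be passed through `truncateAtStop(result, options.stop)` before it is returned.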
Let's implement a very simple custom chat model that just echoes back the first n characters of the input.
import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}
We can now use this like any other chat model:
const chatModel = new CustomChatModel({ n: 4 });
await chatModel.invoke([["human", "I am an LLM"]]);
AIMessage {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I am',
    tool_calls: [],
    invalid_tool_calls: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I am',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  usage_metadata: undefined
}
It also supports streaming:
const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
  console.log(chunk);
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: ' ',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: ' ',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'a',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'a',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'm',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'm',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
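Each chunk carries only one piece of the output, so a consumer that wants the full text has to accumulate the pieces. The following is a minimal sketch of that pattern using a stand-in async generator. `echoLetters` and `collectStream` are hypothetical names; `echoLetters` only mimics the letter-by-letter behavior of _streamResponseChunks and does not use any LangChain types.

```typescript
// Stand-in for the streaming method: yields the first `n` characters one at a time.
async function* echoLetters(input: string, n: number): AsyncGenerator<string> {
  for (const letter of input.slice(0, n)) {
    yield letter;
  }
}

// Consume any async stream of strings and join the pieces into one string.
async function collectStream(stream: AsyncIterable<string>): Promise<string> {
  let full = "";
  for await (const piece of stream) {
    full += piece;
  }
  return full;
}
```

With the real model, the same loop would run over the chunks returned by chatModel.stream(...), appending each chunk's content instead of a plain string.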
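If you want to take advantage of LangChain's callback system for functionality such as token tracking, you can extend the BaseChatModel class and implement the lower-level _generate method. It also takes a list of BaseMessages as input, but requires you to construct and return ChatGeneration objects, which can carry additional metadata. Here's an example: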
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

interface AdvancedCustomChatModelOptions extends BaseChatModelCallOptions {}

interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
This passes the additional returned information through in callback events and in the streamEvents method:
const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v2",
});

for await (const event of eventStream) {
  if (event.event === "on_chat_model_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
{
  "event": "on_chat_model_end",
  "data": {
    "output": {
      "lc": 1,
      "type": "constructor",
      "id": [
        "langchain_core",
        "messages",
        "AIMessage"
      ],
      "kwargs": {
        "content": "I am",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "additional_kwargs": {},
        "response_metadata": {
          "tokenUsage": {
            "usedTokens": 4
          }
        }
      }
    }
  },
  "run_id": "11dbdef6-1b91-407e-a497-1a1ce2974788",
  "name": "AdvancedCustomChatModel",
  "tags": [],
  "metadata": {
    "ls_model_type": "chat"
  }
}
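In the serialized event above, the token usage returned through llmOutput appears under response_metadata. The sketch below shows one way to pull it out of such a JSON payload. `SerializedChatModelEndEvent` is a hypothetical, trimmed-down description of the JSON shown above, not a type exported by LangChain, and `usedTokensFromEvent` is likewise an illustrative helper.

```typescript
// Hypothetical, trimmed-down shape of the serialized `on_chat_model_end` event.
interface SerializedChatModelEndEvent {
  event: string;
  data: {
    output?: {
      kwargs?: {
        response_metadata?: { tokenUsage?: { usedTokens?: number } };
      };
    };
  };
}

// Returns the reported token count, or undefined if the event does not carry one.
function usedTokensFromEvent(
  event: SerializedChatModelEndEvent
): number | undefined {
  if (event.event !== "on_chat_model_end") return undefined;
  return event.data.output?.kwargs?.response_metadata?.tokenUsage?.usedTokens;
}
```

This operates on the JSON form printed by JSON.stringify; the in-memory event holds a message object rather than the serialized "kwargs" wrapper, so the access path there would differ.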
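Tracing (advanced)
If you are implementing a custom chat model and want to use it with a tracing service such as LangSmith, you can automatically log the parameters used for a given invocation by implementing an invocationParams() method on the model.
This method is entirely optional, but anything it returns will be logged as metadata for the trace.
Here's one pattern you might use: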
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
  BaseChatModel,
  type BaseChatModelCallOptions,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, type BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
  n: number;
}

class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  n: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      n: this.n,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    // `someAPIRequest` is a placeholder for your own API call; it is not defined here.
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
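The call to someAPIRequest in the example above is left undefined; it stands for whatever client call your model makes. To run the class locally, a stub along the following lines may be enough. This is a hypothetical sketch: the `RequestParams` type and the echo behavior are assumptions made for illustration, and a real implementation would call your model's API instead.

```typescript
// Hypothetical parameter shape, matching what invocationParams() returns above.
interface RequestParams {
  n: number;
  tools?: Record<string, unknown>[];
}

// Hypothetical stub for `someAPIRequest`: echoes the first `n` characters
// of the first message instead of contacting a real service.
async function someAPIRequest(
  messages: { content: unknown }[],
  params: RequestParams
): Promise<string> {
  const first = messages[0]?.content;
  if (typeof first !== "string") {
    throw new Error("Multimodal messages are not supported.");
  }
  return first.slice(0, params.n);
}
```

The messages parameter is typed structurally so the stub does not depend on LangChain; a BaseMessage array should satisfy it because each message exposes a content field.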