
How to create a custom chat model class

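Prerequisites

This guide assumes familiarity with the following concepts:

This notebook covers how to create a custom chat model wrapper, in case you want to use your own chat model or a wrapper other than the ones LangChain supports directly.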

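After extending SimpleChatModel, a chat model must implement the following:

  • A _call method that takes a list of messages and call options (including things like stop sequences) and returns a string.

  • A _llmType method that returns a string. It is used for logging purposes only.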
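The call options mentioned above can carry stop sequences, and it is up to the model's _call implementation to honor them. The following is a minimal sketch of one way to do that. The helper name truncateAtStop is hypothetical and not part of LangChain; it assumes the stop sequences arrive as a plain array of strings.

```typescript
// Hypothetical helper (not part of LangChain): cut `text` at the earliest
// occurrence of any stop sequence, excluding the stop sequence itself.
function truncateAtStop(text: string, stop: string[] = []): string {
  let cut = text.length;
  for (const s of stop) {
    // Skip empty stop sequences, which would otherwise match at index 0.
    if (s.length === 0) continue;
    const idx = text.indexOf(s);
    if (idx !== -1 && idx < cut) cut = idx;
  }
  return text.slice(0, cut);
}

// Inside _call, this might be used as (assumption):
// return truncateAtStop(rawText, options.stop);
```

Keeping this logic in a small pure function makes it easy to test separately from the model class.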

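You can also implement the following optional method:

  • A _streamResponseChunks method that returns an AsyncGenerator and yields ChatGenerationChunks one at a time. This allows the LLM to support streaming output.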

Let's implement a very simple custom chat model that just echoes back the first n characters of the input.

import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}

We can now use it like any other chat model:

const chatModel = new CustomChatModel({ n: 4 });

await chatModel.invoke([["human", "I am an LLM"]]);

AIMessage {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I am',
    tool_calls: [],
    invalid_tool_calls: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I am',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  usage_metadata: undefined
}

It also supports streaming:

const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
  console.log(chunk);
}

AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: ' ',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: ' ',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'a',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'a',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'm',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'm',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
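Each streamed chunk carries one character, so a caller who wants the full text has to join the chunks. Below is a dependency-free sketch of that pattern. ChunkLike, collectText, and fakeStream are illustrative names rather than LangChain APIs, and fakeStream merely stands in for the stream returned by chatModel.stream so the snippet runs on its own.

```typescript
// Structural stand-in for the part of AIMessageChunk used here (assumption).
interface ChunkLike {
  content: unknown;
}

// Concatenate the string content of every chunk in a stream.
// Chunks whose content is not a string are skipped.
async function collectText(stream: AsyncIterable<ChunkLike>): Promise<string> {
  let text = "";
  for await (const chunk of stream) {
    if (typeof chunk.content === "string") {
      text += chunk.content;
    }
  }
  return text;
}

// Hypothetical stand-in stream that yields one chunk per character of "I am",
// mirroring the four chunks printed above.
async function* fakeStream(): AsyncGenerator<ChunkLike> {
  for (const letter of "I am") {
    yield { content: letter };
  }
}
```

Because collectText only reads the content field, it should also accept the stream produced by the custom model above, for example collectText(await chatModel.stream([["human", "I am an LLM"]])).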

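If you want to take advantage of LangChain's callback system for functionality such as token tracking, you can extend the BaseChatModel class and implement the lower-level _generate method. It also takes a list of BaseMessages as input, but requires you to construct and return a ChatResult whose ChatGeneration objects can carry additional metadata. Here is an example: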

import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

interface AdvancedCustomChatModelOptions extends BaseChatModelCallOptions {}

interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}

This passes the additional returned information through in callback events and in the streamEvents method:

const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v2",
});

for await (const event of eventStream) {
  if (event.event === "on_chat_model_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}

{
  "event": "on_chat_model_end",
  "data": {
    "output": {
      "lc": 1,
      "type": "constructor",
      "id": [
        "langchain_core",
        "messages",
        "AIMessage"
      ],
      "kwargs": {
        "content": "I am",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "additional_kwargs": {},
        "response_metadata": {
          "tokenUsage": {
            "usedTokens": 4
          }
        }
      }
    }
  },
  "run_id": "11dbdef6-1b91-407e-a497-1a1ce2974788",
  "name": "AdvancedCustomChatModel",
  "tags": [],
  "metadata": {
    "ls_model_type": "chat"
  }
}
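To illustrate where the token usage ends up, here is a small sketch that reads it back out of the serialized event printed above. The interface and the function name usedTokensFromLog are hypothetical. The sketch works on the JSON string as logged, not on the live event object, and it assumes the field layout shown in that output.

```typescript
// Shape of the serialized event as printed above, limited to the fields
// this sketch reads (assumption based on that output).
interface SerializedChatModelEndEvent {
  event: string;
  data?: {
    output?: {
      kwargs?: {
        response_metadata?: { tokenUsage?: { usedTokens?: number } };
      };
    };
  };
}

// Hypothetical helper: return usedTokens from a logged on_chat_model_end
// event, or undefined if the event type or the field is missing.
function usedTokensFromLog(json: string): number | undefined {
  const parsed = JSON.parse(json) as SerializedChatModelEndEvent;
  if (parsed.event !== "on_chat_model_end") {
    return undefined;
  }
  return parsed.data?.output?.kwargs?.response_metadata?.tokenUsage?.usedTokens;
}
```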

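Tracing (advanced)

If you are implementing a custom chat model and want to use it with a tracing service such as LangSmith, you can automatically log the parameters used for a given invocation by implementing an invocationParams() method on the model.

This method is entirely optional, but anything it returns is logged as metadata for the trace.

Here is one pattern you could use: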

import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
  BaseChatModel,
  type BaseChatModelCallOptions,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, type BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
  n: number;
}

class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  n: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      n: this.n,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    const additionalParams = this.invocationParams(options);
    // `someAPIRequest` is a placeholder for your own API call; it is not defined here.
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
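The example above calls someAPIRequest without defining it. To try the pattern locally, a stand-in along the following lines could be used. This is only a sketch: MessageLike and RequestParams are assumed structural types, and the echo behavior is borrowed from the earlier examples rather than from any real provider API.

```typescript
// Minimal structural types so the sketch has no dependencies (assumptions).
interface MessageLike {
  content: unknown;
}

interface RequestParams {
  n: number;
  tools?: Record<string, unknown>[];
}

// Hypothetical stand-in for a real provider call: echoes the first `n`
// characters of the first message, like the earlier examples.
async function someAPIRequest(
  messages: MessageLike[],
  params: RequestParams
): Promise<string> {
  if (messages.length === 0) {
    throw new Error("No messages provided.");
  }
  const first = messages[0].content;
  if (typeof first !== "string") {
    throw new Error("Multimodal messages are not supported.");
  }
  return first.slice(0, params.n);
}
```

A real implementation would replace the body with a request to your model provider and would typically forward params.tools as well.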
