Skip to content

Commit

Permalink
fix: incrementalOutput not working (#17)
Browse files Browse the repository at this point in the history
1. set enable search to false for default chat configuration.
2. Fix incremental output issue and support streaming function call.
  • Loading branch information
robinyeeh authored Sep 23, 2024
1 parent 423142c commit 5840b73
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ public DashScopeChatModel dashscopeChatModel(DashScopeConnectionProperties commo
retryTemplate);
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = DashScopeEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public DashScopeApi dashscopeChatApi(DashScopeConnectionProperties commonProperties,
DashScopeChatProperties chatProperties, RestClient.Builder restClientBuilder,
WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class DashScopeChatProperties extends DashScopeParentProperties {
/**
* Default DashScope Chat model.
*/
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_TURBO;
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_PLUS;

/**
* Default temperature speed.
Expand All @@ -56,7 +56,6 @@ public class DashScopeChatProperties extends DashScopeParentProperties {
private DashScopeChatOptions options = DashScopeChatOptions.builder()
.withModel(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.withEnableSearch(true)
.build();

public DashScopeChatProperties() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@
* @author Ken
*/
public class DashScopeAiStreamFunctionCallingHelper {
private Boolean incrementalOutput = false;

public DashScopeAiStreamFunctionCallingHelper() {
}

public DashScopeAiStreamFunctionCallingHelper(Boolean incrementalOutput) {
this.incrementalOutput = incrementalOutput;
}

/**
* Merge the previous and current ChatCompletionChunk into a single one.
Expand All @@ -46,7 +54,6 @@ public class DashScopeAiStreamFunctionCallingHelper {
* @return the merged ChatCompletionChunk
*/
public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChunk current) {

if (previous == null) {
return current;
}
Expand All @@ -57,9 +64,18 @@ public ChatCompletionChunk merge(ChatCompletionChunk previous, ChatCompletionChu
Choice previousChoice0 = previous.output() == null ? null : previous.output().choices().get(0);
Choice currentChoice0 = current.output() == null ? null : current.output().choices().get(0);

//compatibility of incremental_output false for streaming function call
if (!incrementalOutput && isStreamingToolFunctionCall(current)) {
if (!isStreamingToolFunctionCallFinish(current)) {
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, List.of(new Choice(null, null))), usage);
} else {
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, List.of(currentChoice0)), usage);
}
}

Choice choice = merge(previousChoice0, currentChoice0);
List<Choice> chunkChoices = choice == null ? List.of() : List.of(choice);
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, chunkChoices), usage);
return new ChatCompletionChunk(id, new ChatCompletionOutput(null, chunkChoices), usage);
}

private Choice merge(Choice previous, Choice current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1237,8 +1237,6 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
.toEntity(ChatCompletion.class);
}

private final DashScopeAiStreamFunctionCallingHelper chunkMerger = new DashScopeAiStreamFunctionCallingHelper();

/**
* Creates a streaming chat response for the given chat conversation.
* @param chatRequest The chat completion request. Must have the stream property set
Expand All @@ -1251,6 +1249,8 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
Assert.isTrue(chatRequest.stream(), "Request must set the stream property to true.");

AtomicBoolean isInsideTool = new AtomicBoolean(false);
boolean incrementalOutput = chatRequest.parameters() != null && chatRequest.parameters().incrementalOutput != null && chatRequest.parameters().incrementalOutput;
DashScopeAiStreamFunctionCallingHelper chunkMerger = new DashScopeAiStreamFunctionCallingHelper(incrementalOutput);

return this.webClient.post()
.uri("/api/v1/services/aigc/text-generation/generation")
Expand All @@ -1262,21 +1262,21 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
.filter(SSE_DONE_PREDICATE.negate())
.map(content -> ModelOptionsUtils.jsonToObject(content, ChatCompletionChunk.class))
.map(chunk -> {
if (this.chunkMerger.isStreamingToolFunctionCall(chunk)) {
if (chunkMerger.isStreamingToolFunctionCall(chunk)) {
isInsideTool.set(true);
}
return chunk;
})
.windowUntil(chunk -> {
if (isInsideTool.get() && this.chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
if (isInsideTool.get() && chunkMerger.isStreamingToolFunctionCallFinish(chunk)) {
isInsideTool.set(false);
return true;
}
return !isInsideTool.get();
})
.concatMapIterable(window -> {
Mono<ChatCompletionChunk> monoChunk = window.reduce(new ChatCompletionChunk(null, null, null),
this.chunkMerger::merge);
chunkMerger::merge);
return List.of(monoChunk);
})
.flatMap(mono -> mono);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,7 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
options = ModelOptionsUtils.merge(options, this.defaultOptions, DashScopeChatOptions.class);

if (!CollectionUtils.isEmpty(enabledToolsToUse)) {
options = ModelOptionsUtils.merge(
DashScopeChatOptions.builder().withTools(this.getFunctionTools(enabledToolsToUse)).build(), options,
DashScopeChatOptions.class);
options.setTools(this.getFunctionTools(enabledToolsToUse));
}

List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(message -> {
Expand Down Expand Up @@ -338,7 +336,7 @@ private ChatCompletionRequestParameter toDashScopeRequestParameter(DashScopeChat
return new ChatCompletionRequestParameter();
}

Boolean incrementalOutput = stream || options.getIncrementalOutput();
Boolean incrementalOutput = options.getIncrementalOutput();
return new ChatCompletionRequestParameter("message", options.getSeed(), options.getMaxTokens(),
options.getTopP(), options.getTopK(), options.getRepetitionPenalty(), options.getPresencePenalty(),
options.getTemperature(), options.getStop(), options.getEnableSearch(), incrementalOutput,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public class DashScopeChatOptions implements FunctionCallingOptions, ChatOptions
/**
* 控制在流式输出模式下是否开启增量输出,即后续输出内容是否包含已输出的内容。设置为True时,将开启增量输出模式,后面输出不会包含已经输出的内容,您需要自行拼接整体输出;设置为False则会包含已输出的内容。
*/
private @JsonProperty("incremental_output") Boolean incrementalOutput = false;
private @JsonProperty("incremental_output") Boolean incrementalOutput = true;

/** 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。默认为1.1。 */
private @JsonProperty("repetition_penalty") Float repetitionPenalty;
Expand Down Expand Up @@ -341,6 +341,7 @@ public static DashScopeChatOptions fromOptions(DashScopeChatOptions fromOptions)
.withStop(fromOptions.getStop())
.withStream(fromOptions.getStream())
.withEnableSearch(fromOptions.enableSearch)
.withIncrementalOutput(fromOptions.getIncrementalOutput())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withFunctions(fromOptions.getFunctions())
.withRepetitionPenalty(fromOptions.getRepetitionPenalty())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.api.DashScopeApi.ChatCompletionFinishReason;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import com.alibaba.cloud.ai.dashscope.tool.DashScopeFunctionTestConfiguration;
import com.alibaba.cloud.ai.dashscope.chat.tool.MockOrderService;
import com.alibaba.cloud.ai.dashscope.chat.tool.MockWeatherService;
Expand Down Expand Up @@ -75,7 +76,7 @@ public class DashScopeChatClientIT {
private DashScopeChatModel dashscopeChatModel;

@Autowired
private DashScopeApi dashscopeApi;
private DashScopeApi dashscopeChatApi;

@Value("classpath:/prompts/rag/system-qa.st")
private Resource systemResource;
Expand All @@ -85,7 +86,7 @@ public class DashScopeChatClientIT {

@Test
void callTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -102,14 +103,20 @@ void callTest() throws IOException {

@Test
void streamTest() throws InterruptedException, IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());
ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
.defaultAdvisors(
new DocumentRetrievalAdvisor(retriever, systemResource.getContentAsString(StandardCharsets.UTF_8)))
.build();

Flux<ChatResponse> response = chatClient.prompt().user("如何快速开始百炼?").stream().chatResponse();
Flux<ChatResponse> response = chatClient.prompt()
.user("如何快速开始百炼?")
.options(DashScopeChatOptions.builder()
.withIncrementalOutput(true)
.build())
.stream()
.chatResponse();

CountDownLatch cdl = new CountDownLatch(1);
response.subscribe(data -> {
Expand Down Expand Up @@ -159,7 +166,7 @@ void callWithFunctionBeanTest() {

@Test
void callWithFunctionAndRagTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -178,7 +185,7 @@ void callWithFunctionAndRagTest() throws IOException {

@Test
void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -187,7 +194,13 @@ void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException
.defaultFunctions("weatherFunction")
.build();

Flux<ChatResponse> response = chatClient.prompt().user("上海今天的天气如何?").stream().chatResponse();
Flux<ChatResponse> response = chatClient.prompt()
.user("上海今天的天气如何?")
.options(DashScopeChatOptions.builder()
.withIncrementalOutput(true)
.build())
.stream()
.chatResponse();

CountDownLatch cdl = new CountDownLatch(1);
response.subscribe(data -> {
Expand All @@ -206,7 +219,7 @@ void streamCallWithFunctionAndRagTest() throws InterruptedException, IOException

@Test
void callWithReferencedRagTest() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand All @@ -232,7 +245,7 @@ void callWithReferencedRagTest() throws IOException {

@Test
void streamCallWithReferencedRagTest() throws IOException, InterruptedException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand Down Expand Up @@ -272,7 +285,7 @@ void streamCallWithReferencedRagTest() throws IOException, InterruptedException

@Test
void callWithMemory() throws IOException {
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeApi,
DocumentRetriever retriever = new DashScopeDocumentRetriever(dashscopeChatApi,
DashScopeDocumentRetrieverOptions.builder().withIndexName("spring-ai知识库").build());

ChatClient chatClient = ChatClient.builder(dashscopeChatModel)
Expand Down Expand Up @@ -309,24 +322,24 @@ void callWithMemory() throws IOException {
@Test
void reader() {
String filePath = "/Users/nuocheng.lxm/Desktop/新能源产业有哪些-36氪.pdf";
DashScopeDocumentCloudReader reader = new DashScopeDocumentCloudReader(filePath, dashscopeApi, null);
DashScopeDocumentCloudReader reader = new DashScopeDocumentCloudReader(filePath, dashscopeChatApi, null);
List<Document> documentList = reader.get();
DashScopeDocumentTransformer transformer = new DashScopeDocumentTransformer(dashscopeApi);
DashScopeDocumentTransformer transformer = new DashScopeDocumentTransformer(dashscopeChatApi);
List<Document> transformerList = transformer.apply(documentList);
System.out.println(transformerList.size());
}

@Test
void embed() {
DashScopeEmbeddingModel embeddingModel = new DashScopeEmbeddingModel(dashscopeApi);
DashScopeEmbeddingModel embeddingModel = new DashScopeEmbeddingModel(dashscopeChatApi);
Document document = new Document("你好阿里云");
float[] vectorList = embeddingModel.embed(document);
System.out.println(vectorList.length);
}

@Test
void vectorStore() {
DashScopeCloudStore cloudStore = new DashScopeCloudStore(dashscopeApi, new DashScopeStoreOptions("诺成SpringAI"));
DashScopeCloudStore cloudStore = new DashScopeCloudStore(dashscopeChatApi, new DashScopeStoreOptions("诺成SpringAI"));
List<Document> documentList = Arrays.asList(
new Document("file_f0b6b18b14994ed8a0b45648ce5d0da5_10001", "abc", new HashMap<>()),
new Document("file_d3083d64026d4864b4558d18f9ca2a6d_10001", "abc", new HashMap<>()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ import { Outlet } from 'react-router-dom';
(window as any).Vaadin ??= {};
(window as any).Vaadin.copilot ??= {};
(window as any).Vaadin.copilot._ref ??= {};
(window as any).Vaadin.copilot._ref.Outlet = Outlet;
(window as any).Vaadin.copilot._ref.Outlet = Outlet;

0 comments on commit 5840b73

Please sign in to comment.