Skip to content

Commit a9f788e

Browse files
authored
Merge pull request #9 from cnblogs/supports-tools-in-text-completion
feat: support tools in text generation
2 parents 4f32693 + 14a853d commit a9f788e

17 files changed

+263
-7
lines changed

src/Cnblogs.DashScope.Sdk/ChatMessage.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,7 @@ namespace Cnblogs.DashScope.Sdk;
55
/// <summary>
66
/// Represents a chat message between the user and the model.
77
/// </summary>
8-
public record ChatMessage(string Role, string Content) : IMessage<string>;
8+
/// <param name="Role">The role of this message.</param>
9+
/// <param name="Content">The content of this message.</param>
10+
/// <param name="ToolCalls">Calls to the function.</param>
11+
public record ChatMessage(string Role, string Content, List<ToolCall>? ToolCalls = null) : IMessage<string>;

src/Cnblogs.DashScope.Sdk/Cnblogs.DashScope.Sdk.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@
44
<GenerateDocumentationFile>true</GenerateDocumentationFile>
55
<PackageTags>Cnblogs;Dashscope;AI;Sdk;Embedding;</PackageTags>
66
</PropertyGroup>
7+
<ItemGroup>
8+
<PackageReference Include="JsonSchema.Net.Generation" Version="4.1.1" />
9+
</ItemGroup>
710
</Project>
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Cnblogs.DashScope.Sdk;
2+
3+
/// <summary>
4+
/// Represents a call to function.
5+
/// </summary>
6+
/// <param name="Name">Name of the function to call.</param>
7+
/// <param name="Arguments">Arguments of this call, usually a json string.</param>
8+
public record FunctionCall(string Name, string? Arguments);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Json.Schema;
2+
3+
namespace Cnblogs.DashScope.Sdk;
4+
5+
/// <summary>
6+
/// Definition of function that can be called by model.
7+
/// </summary>
8+
/// <param name="Name">The name of the function.</param>
9+
/// <param name="Description">Descriptions about this function that help model to decide when to call this function.</param>
10+
/// <param name="Parameters">The parameters JSON schema.</param>
11+
public record FunctionDefinition(string Name, string Description, JsonSchema? Parameters);

src/Cnblogs.DashScope.Sdk/ITextGenerationParameters.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,9 @@ public interface ITextGenerationParameters : IIncrementalOutputParameter, ISeedP
4242
/// Enable internet search when generation. Defaults to false.
4343
/// </summary>
4444
public bool? EnableSearch { get; }
45+
46+
/// <summary>
47+
/// Available tools for model to call.
48+
/// </summary>
49+
public List<ToolDefinition>? Tools { get; }
4550
}

src/Cnblogs.DashScope.Sdk/TextGenerationInput.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,9 @@ public class TextGenerationInput
1414
/// The collection of context messages associated with this chat completions request.
1515
/// </summary>
1616
public IEnumerable<ChatMessage>? Messages { get; set; }
17+
18+
/// <summary>
19+
/// Available tools for model to use.
20+
/// </summary>
21+
public IEnumerable<ToolDefinition>? Tools { get; set; }
1722
}

src/Cnblogs.DashScope.Sdk/TextGenerationParameters.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ public class TextGenerationParameters : ITextGenerationParameters
3232
/// <inheritdoc />
3333
public bool? EnableSearch { get; set; }
3434

35+
/// <inheritdoc />
36+
public List<ToolDefinition>? Tools { get; set; }
37+
3538
/// <inheritdoc />
3639
public bool? IncrementalOutput { get; set; }
3740
}

src/Cnblogs.DashScope.Sdk/ToolCall.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
namespace Cnblogs.DashScope.Sdk;
2+
3+
/// <summary>
4+
/// Represents a call to tool.
5+
/// </summary>
6+
/// <param name="Id">Id of this tool call.</param>
7+
/// <param name="Type">Type of the tool.</param>
8+
/// <param name="Function">Not null if type is function.</param>
9+
public record ToolCall(string? Id, string Type, FunctionCall? Function);
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Cnblogs.DashScope.Sdk;
2+
3+
/// <summary>
4+
/// Definition of a tool that model can call during generation.
5+
/// </summary>
6+
/// <param name="Type">The type of this tool. Use <see cref="ToolTypes"/> to get all available options.</param>
7+
/// <param name="Function">Not null when <paramref name="Type"/> is tool.</param>
8+
public record ToolDefinition(string Type, FunctionDefinition? Function);
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
namespace Cnblogs.DashScope.Sdk;
2+
3+
/// <summary>
4+
/// Available tool types for <see cref="ToolDefinition"/>.
5+
/// </summary>
6+
public static class ToolTypes
7+
{
8+
/// <summary>
9+
/// Function type.
10+
/// </summary>
11+
public const string Function = "function";
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"model": "qwen-max",
3+
"input": {
4+
"messages": [
5+
{
6+
"role": "user",
7+
"content": "杭州现在的天气如何?"
8+
}
9+
]
10+
},
11+
"parameters": {
12+
"result_format": "message",
13+
"seed": 1234,
14+
"max_tokens": 1500,
15+
"top_p": 0.8,
16+
"top_k": 100,
17+
"repetition_penalty": 1.1,
18+
"temperature": 0.85,
19+
"stop": [[37763, 367]],
20+
"enable_search": false,
21+
"incremental_output": false,
22+
"tools": [
23+
{
24+
"type": "function",
25+
"function": {
26+
"name": "get_current_weather",
27+
"description": "获取现在的天气",
28+
"parameters": {
29+
"type": "object",
30+
"properties": {
31+
"location": {
32+
"type": "string",
33+
"description": "要获取天气的省市名称,例如浙江省杭州市"
34+
},
35+
"unit": {
36+
"description": "温度单位",
37+
"enum": [
38+
"Celsius",
39+
"Fahrenheit"
40+
]
41+
}
42+
},
43+
"required": [
44+
"location"
45+
]
46+
}
47+
}
48+
}
49+
]
50+
}
51+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"output":{"choices":[{"finish_reason":"tool_calls","message":{"role":"assistant","tool_calls":[{"function":{"name":"get_current_weather","arguments":"{\"location\": \"浙江省杭州市\", \"unit\": \"Celsius\"}"},"id":"","type":"function"}],"content":""}}]},"usage":{"total_tokens":36,"output_tokens":31,"input_tokens":5},"request_id":"40b4361e-e936-91b5-879d-355a45d670f8"}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
HTTP/1.1 200 OK
2+
eagleeye-traceid: 7328e5207abf69133abfe3a68446fc2d
3+
content-type: application/json
4+
x-dashscope-call-gateway: true
5+
x-dashscope-experiments: 33e6d810-qwen-max-base-default-imbalance-fix-lua
6+
req-cost-time: 3898
7+
req-arrive-time: 1710324737299
8+
resp-start-time: 1710324741198
9+
x-envoy-upstream-service-time: 3893
10+
content-encoding: gzip
11+
vary: Accept-Encoding
12+
date: Wed, 13 Mar 2024 10:12:21 GMT
13+
server: istio-envoy
14+
transfer-encoding: chunked

test/Cnblogs.DashScope.Sdk.UnitTests/ServiceCollectionInjectorTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,27 @@ public void Configuration_CustomSectionName_Inject()
103103
httpClient.BaseAddress.Should().BeEquivalentTo(new Uri(ProxyApi));
104104
}
105105

106+
[Fact]
107+
public void Configuration_AddMultipleTime_Replace()
108+
{
109+
// Arrange
110+
var services = new ServiceCollection();
111+
112+
// Act
113+
services.AddDashScopeClient(ApiKey, ProxyApi);
114+
services.AddDashScopeClient(ApiKey, ProxyApi);
115+
var provider = services.BuildServiceProvider();
116+
var httpClient = provider.GetRequiredService<IHttpClientFactory>().CreateClient(nameof(IDashScopeClient));
117+
118+
// Assert
119+
provider.GetRequiredService<IDashScopeClient>().Should().NotBeNull().And
120+
.BeOfType<DashScopeClientCore>();
121+
httpClient.Should().NotBeNull();
122+
httpClient.DefaultRequestHeaders.Authorization.Should()
123+
.BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey));
124+
httpClient.BaseAddress.Should().BeEquivalentTo(new Uri(ProxyApi));
125+
}
126+
106127
[Fact]
107128
public void Configuration_NoApiKey_Throw()
108129
{

test/Cnblogs.DashScope.Sdk.UnitTests/TextGenerationSerializationTests.cs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,14 @@ public async Task SingleCompletion_TextFormatSse_SuccessAsync()
4747
message.ToString().Should().Be(testCase.ResponseModel.Output.Text);
4848
}
4949

50-
[Fact]
51-
public async Task SingleCompletion_MessageFormatNoSse_SuccessAsync()
50+
[Theory]
51+
[MemberData(nameof(SingleGenerationMessageFormatData))]
52+
public async Task SingleCompletion_MessageFormatNoSse_SuccessAsync(
53+
RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
54+
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> testCase)
5255
{
5356
// Arrange
5457
const bool sse = false;
55-
var testCase = Snapshots.TextGeneration.MessageFormat.SingleMessage;
5658
var (client, handler) = await Sut.GetTestClientAsync(sse, testCase);
5759

5860
// Act
@@ -83,7 +85,9 @@ public async Task SingleCompletion_MessageFormatSse_SuccessAsync()
8385
Arg.Is<HttpRequestMessage>(m => Checkers.IsJsonEquivalent(m.Content!, testCase.GetRequestJson(sse))),
8486
Arg.Any<CancellationToken>());
8587
outputs.SkipLast(1).Should().AllSatisfy(x => x.Output.Choices![0].FinishReason.Should().Be("null"));
86-
outputs.Last().Should().BeEquivalentTo(testCase.ResponseModel, o => o.Excluding(y => y.Output.Choices![0].Message.Content));
88+
outputs.Last().Should().BeEquivalentTo(
89+
testCase.ResponseModel,
90+
o => o.Excluding(y => y.Output.Choices![0].Message.Content));
8791
message.ToString().Should().Be(testCase.ResponseModel.Output.Choices![0].Message.Content);
8892
}
8993

@@ -105,7 +109,14 @@ public async Task ConversationCompletion_MessageFormatSse_SuccessAsync()
105109
Arg.Is<HttpRequestMessage>(m => Checkers.IsJsonEquivalent(m.Content!, testCase.GetRequestJson(sse))),
106110
Arg.Any<CancellationToken>());
107111
outputs.SkipLast(1).Should().AllSatisfy(x => x.Output.Choices![0].FinishReason.Should().Be("null"));
108-
outputs.Last().Should().BeEquivalentTo(testCase.ResponseModel, o => o.Excluding(y => y.Output.Choices![0].Message.Content));
112+
outputs.Last().Should().BeEquivalentTo(
113+
testCase.ResponseModel,
114+
o => o.Excluding(y => y.Output.Choices![0].Message.Content));
109115
message.ToString().Should().Be(testCase.ResponseModel.Output.Choices![0].Message.Content);
110116
}
117+
118+
public static readonly TheoryData<RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
119+
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>>> SingleGenerationMessageFormatData = new(
120+
Snapshots.TextGeneration.MessageFormat.SingleMessage,
121+
Snapshots.TextGeneration.MessageFormat.SingleMessageWithTools);
111122
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System.Text.Json.Serialization;
2+
using Json.More;
3+
using Json.Schema.Generation;
4+
5+
namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;
6+
7+
public record GetCurrentWeatherParameters(
8+
[property: Required]
9+
[property: Description("要获取天气的省市名称,例如浙江省杭州市")]
10+
string Location,
11+
[property: JsonConverter(typeof(EnumStringConverter<TemperatureUnit>))]
12+
[property: Description("温度单位")]
13+
TemperatureUnit Unit = TemperatureUnit.Celsius);
14+
15+
public enum TemperatureUnit
16+
{
17+
Celsius,
18+
Fahrenheit
19+
}

test/Cnblogs.DashScope.Sdk.UnitTests/Utils/Snapshots.cs

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;
1+
using Json.Schema;
2+
using Json.Schema.Generation;
3+
4+
namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;
25

36
public static class Snapshots
47
{
@@ -267,6 +270,75 @@ public static class MessageFormat
267270
}
268271
});
269272

273+
public static readonly
274+
RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
275+
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> SingleMessageWithTools =
276+
new(
277+
"single-generation-message-with-tools",
278+
new()
279+
{
280+
Model = "qwen-max",
281+
Input = new() { Messages = [new("user", "杭州现在的天气如何?")] },
282+
Parameters = new TextGenerationParameters()
283+
{
284+
ResultFormat = "message",
285+
Seed = 1234,
286+
MaxTokens = 1500,
287+
TopP = 0.8f,
288+
TopK = 100,
289+
RepetitionPenalty = 1.1f,
290+
Temperature = 0.85f,
291+
Stop = new([[37763, 367]]),
292+
EnableSearch = false,
293+
IncrementalOutput = false,
294+
Tools =
295+
[
296+
new ToolDefinition(
297+
"function",
298+
new FunctionDefinition(
299+
"get_current_weather",
300+
"获取现在的天气",
301+
new JsonSchemaBuilder().FromType<GetCurrentWeatherParameters>(
302+
new()
303+
{
304+
PropertyNameResolver = PropertyNameResolvers.LowerSnakeCase
305+
})
306+
.Build()))
307+
]
308+
}
309+
},
310+
new()
311+
{
312+
Output = new()
313+
{
314+
Choices =
315+
[
316+
new()
317+
{
318+
FinishReason = "tool_calls",
319+
Message = new(
320+
"assistant",
321+
string.Empty,
322+
[
323+
new(
324+
string.Empty,
325+
ToolTypes.Function,
326+
new(
327+
"get_current_weather",
328+
"""{"location": "浙江省杭州市", "unit": "Celsius"}"""))
329+
])
330+
}
331+
]
332+
},
333+
RequestId = "40b4361e-e936-91b5-879d-355a45d670f8",
334+
Usage = new()
335+
{
336+
InputTokens = 5,
337+
OutputTokens = 31,
338+
TotalTokens = 36
339+
}
340+
});
341+
270342
public static readonly RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
271343
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>>
272344
ConversationMessageIncremental = new(

0 commit comments

Comments
 (0)