Skip to content

Commit efa591b

Browse files
authored
Merge pull request #10 from matlab-deep-learning/dev-fix-streaming-bug
Support function calls in streaming
2 parents ab154ef + 0b4b546 commit efa591b

File tree

4 files changed

+51
-7
lines changed

4 files changed

+51
-7
lines changed

+llms/+internal/callOpenAIChatAPI.m

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,20 @@
8484
if isempty(nvp.StreamFun)
8585
message = response.Body.Data.choices(1).message;
8686
else
87-
message = struct("role", "assistant", ...
88-
"content", streamedText);
87+
pat = '{"' + wildcardPattern + '":';
88+
if contains(streamedText,pat)
89+
s = jsondecode(streamedText);
90+
if contains(s.function.arguments,pat)
91+
prompt = jsondecode(s.function.arguments);
92+
s.function.arguments = prompt;
93+
end
94+
message = struct("role", "assistant", ...
95+
"content",[], ...
96+
"tool_calls",jsondecode(streamedText));
97+
else
98+
message = struct("role", "assistant", ...
99+
"content", streamedText);
100+
end
89101
end
90102
if isfield(message, "tool_choice")
91103
text = "";

+llms/+stream/responseStreamer.m

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,44 @@
3636
str = erase(str,"data: ");
3737

3838
for i = 1:length(str)
39-
json = jsondecode(str{i});
40-
if strcmp(json.choices.finish_reason,'stop')
39+
if strcmp(str{i},'[DONE]')
4140
stop = true;
4241
return
4342
else
44-
txt = json.choices.delta.content;
45-
this.StreamFun(txt);
46-
this.ResponseText = [this.ResponseText txt];
43+
try
44+
json = jsondecode(str{i});
45+
catch ME
46+
errID = 'llms:stream:responseStreamer:InvalidInput';
47+
msg = "Input does not have the expected json format. " + str{i};
48+
causeException = MException(errID,msg);
49+
ME = addCause(ME,causeException);
50+
rethrow(ME)
51+
end
52+
if ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"])
53+
stop = true;
54+
return
55+
else
56+
if isfield(json.choices.delta,"tool_calls")
57+
if isfield(json.choices.delta.tool_calls,"id")
58+
id = json.choices.delta.tool_calls.id;
59+
type = json.choices.delta.tool_calls.type;
60+
fcn = json.choices.delta.tool_calls.function;
61+
s = struct('id',id,'type',type,'function',fcn);
62+
txt = jsonencode(s);
63+
else
64+
s = jsondecode(this.ResponseText);
65+
args = json.choices.delta.tool_calls.function.arguments;
66+
s.function.arguments = [s.function.arguments args];
67+
txt = jsonencode(s);
68+
end
69+
this.StreamFun('');
70+
this.ResponseText = txt;
71+
else
72+
txt = json.choices.delta.content;
73+
this.StreamFun(txt);
74+
this.ResponseText = [this.ResponseText txt];
75+
end
76+
end
4777
end
4878
end
4979
end

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
*.asv
33
*.mat
44
startup.m
5+
papers_to_read.csv
6+
data/*
-197 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)