Skip to content

Commit 3ec2414

Browse files
committed
Addressing comments by ccreutzi
1 parent b3de243 commit 3ec2414

File tree

6 files changed

+36
-11
lines changed

6 files changed

+36
-11
lines changed

+llms/+internal/callOpenAIChatAPI.m

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@
134134
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
135135
end
136136

137-
for i=1:length(nvpOptions)
138-
if isfield(nvp, nvpOptions(i))
139-
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
137+
for opt = nvpOptions.'
138+
if isfield(nvp, opt)
139+
parameters.(dict(opt)) = nvp.(opt);
140140
end
141141
end
142142
end

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Large Language Models (LLMs) with MATLAB® [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/llms-with-matlab)
22

3-
This repository contains example code to demonstrate how to connect MATLAB to the OpenAI™ Chat Completions API (which powers ChatGPT™) as well as OpenAI Images API (which powers DALL-E™). This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment.
3+
This repository contains example code to demonstrate how to connect MATLAB to the OpenAI™ Chat Completions API (which powers ChatGPT™) as well as OpenAI Images API (which powers DALL·E™). This allows you to leverage the natural language processing capabilities of large language models directly within your MATLAB environment.
44

5-
The functionality shown here serves as an interface to the ChatGPT and DALL-E APIs. To start using the OpenAI APIs, you first need to obtain the OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs.
5+
The functionality shown here serves as an interface to the ChatGPT and DALL·E APIs. To start using the OpenAI APIs, you first need to obtain OpenAI API keys. You are responsible for any fees OpenAI may charge for the use of their APIs. You should be familiar with the limitations and risks associated with using this technology, and you agree that you shall be solely responsible for full compliance with any terms that may apply to your use of the OpenAI APIs.
66

77
Some of the current LLMs supported are:
88
- gpt-3.5-turbo, gpt-3.5-turbo-1106
@@ -127,6 +127,7 @@ You can specify the streaming function when you create the chat assistant. Th
127127
sf = @(x)fprintf("%s",x);
128128
chat = openAIChat(StreamFun=sf);
129129
txt = generate(chat,"What is Model-Based Design and how is it related to Digital Twin?")
130+
% Should stream the response token by token
130131
```
131132
132133
### Calling MATLAB functions with the API
@@ -280,6 +281,7 @@ image_path = "peppers.png";
280281
messages = openAIMessages;
281282
messages = addUserMessageWithImages(messages,"What is in the image?",image_path);
282283
[txt,response] = generate(chat,messages);
284+
% Should output the description of the image
283285
```
284286

285287
### Obtaining embeddings
@@ -308,6 +310,7 @@ mdl = openAIImages(ModelName="dall-e-3");
308310
images = generate(mdl,"Create a 3D avatar of a whimsical sushi on the beach. He is decorated with various sushi elements and is playfully interacting with the beach environment.");
309311
figure
310312
imshow(images{1})
313+
% Should output an image based on the prompt
311314
```
312315

313316
## Examples

openAIChat.m

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,21 @@
131131

132132
if isfield(nvp,"StreamFun")
133133
this.StreamFun = nvp.StreamFun;
134+
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
135+
error("llms:invalidOptionAndValueForModel", ...
136+
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StreamFun", nvp.ModelName));
137+
end
134138
else
135139
this.StreamFun = [];
136140
end
137141

138142
if ~isempty(nvp.Tools)
139143
this.Tools = nvp.Tools;
140144
[this.FunctionsStruct, this.FunctionNames] = functionAsStruct(nvp.Tools);
145+
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
146+
error("llms:invalidOptionAndValueForModel", ...
147+
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "Tools", nvp.ModelName));
148+
end
141149
else
142150
this.Tools = [];
143151
this.FunctionsStruct = [];
@@ -155,6 +163,11 @@
155163
this.Temperature = nvp.Temperature;
156164
this.TopProbabilityMass = nvp.TopProbabilityMass;
157165
this.StopSequences = nvp.StopSequences;
166+
if ~isempty(nvp.StopSequences) && strcmp(nvp.ModelName,'gpt-4-vision-preview')
167+
error("llms:invalidOptionAndValueForModel", ...
168+
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "StopSequences", nvp.ModelName));
169+
end
170+
158171

159172
% ResponseFormat is only supported in the latest models only
160173
if (nvp.ResponseFormat == "json")
@@ -208,7 +221,17 @@
208221
nvp.Seed {mustBeIntegerOrEmpty(nvp.Seed)} = []
209222
end
210223

224+
if nvp.MaxNumTokens ~= Inf && strcmp(this.ModelName,'gpt-4-vision-preview')
225+
error("llms:invalidOptionAndValueForModel", ...
226+
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "MaxNumTokens", this.ModelName));
227+
end
228+
211229
toolChoice = convertToolChoice(this, nvp.ToolChoice);
230+
if ~isempty(nvp.ToolChoice) && strcmp(this.ModelName,'gpt-4-vision-preview')
231+
error("llms:invalidOptionAndValueForModel", ...
232+
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionForModel", "ToolChoice", this.ModelName));
233+
end
234+
212235
if isstring(messages) && isscalar(messages)
213236
messagesStruct = {struct("role", "user", "content", messages)};
214237
else

openAIMessages.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ function validateRegularAssistant(content)
297297
end
298298

299299
function validateAssistantWithToolCalls(toolCallStruct)
300-
if ~isstruct(toolCallStruct)||~isfield(toolCallStruct, "id")||~isfield(toolCallStruct, "function")
300+
if ~(isstruct(toolCallStruct) && isfield(toolCallStruct, "id") && isfield(toolCallStruct, "function"))
301301
error("llms:mustBeAssistantWithIdAndFunction", ...
302302
llms.utils.errorMessageCatalog.getMessage("llms:mustBeAssistantWithIdAndFunction"))
303303
else

tests/topenAIImages.m

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ function saveEnvVar(testCase)
99
openAIEnvVar = "OPENAI_API_KEY";
1010
if isenv(openAIEnvVar)
1111
key = getenv(openAIEnvVar);
12+
testCase.addTeardown(@() setenv(openAIEnvVar, key));
1213
unsetenv(openAIEnvVar);
13-
testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key);
1414
end
1515
end
1616
end
@@ -223,7 +223,7 @@ function invalidInputsVariation(testCase, InvalidVariationInput)
223223

224224
function invalidEditInput = iGetInvalidEditInput
225225
validImage = string(which("peppers.png"));
226-
nonPNGImage = which("corn.tif");
226+
nonPNGImage = string(which("corn.tif"));
227227
invalidEditInput = struct( ...
228228
"EmptyImage",struct( ...
229229
"Input",{{ [], "prompt" }},...

tests/topenAIMessages.m

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,8 @@ function assistantToolCallMessageWithoutArgsIsAdded(testCase)
9797
functionName = "functionName";
9898
funCall = struct("name", functionName, "arguments", "{}");
9999
toolCall = struct("id", "123", "type", "function", "function", funCall);
100-
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", []);
101100
% tool_calls is an array of struct in API response
102-
toolCallPrompt.tool_calls = toolCall;
101+
toolCallPrompt = struct("role", "assistant", "content", "","tool_calls", toolCall);
103102
msgs = addResponseMessage(msgs, toolCallPrompt);
104103
% to include in msgs, tool_calls must be a cell
105104
testCase.verifyEqual(fieldnames(msgs.Messages{1}), fieldnames(toolCallPrompt));
@@ -112,8 +111,8 @@ function assistantParallelToolCallMessageIsAdded(testCase)
112111
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
113112
funCall = struct("name", functionName, "arguments", args);
114113
toolCall = struct("id", "123", "type", "function", "function", funCall);
115-
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", []);
116114
% tool_calls is an array of struct in API response
115+
toolCallPrompt = struct("role", "assistant", "content", "", "tool_calls", toolCall);
117116
toolCallPrompt.tool_calls = [toolCall,toolCall,toolCall];
118117
msgs = addResponseMessage(msgs, toolCallPrompt);
119118
testCase.verifyEqual(msgs.Messages{1}, toolCallPrompt);

0 commit comments

Comments
 (0)