Skip to content

Dev parallelfunctioncall jsonmode vision dalle #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
function [text, message, response] = callOpenAIChatAPI(messages, functions, nvp)
% This function is undocumented and will change in a future release

%callOpenAIChatAPI Calls the openAI chat completions API.
%
% MESSAGES and FUNCTIONS should be structs matching the json format
% required by the OpenAI Chat Completions API.
% Ref: https://platform.openai.com/docs/guides/gpt/chat-completions-api
%
% Currently, the supported NVP are, including the equivalent name in the API:
% - FunctionCall (function_call)
% - ToolChoice (tool_choice)
% - ModelName (model)
% - Temperature (temperature)
% - TopProbabilityMass (top_p)
Expand All @@ -17,6 +15,8 @@
% - MaxNumTokens (max_tokens)
% - PresencePenalty (presence_penalty)
% - FrequencyPenalty (frequence_penalty)
% - ResponseFormat (response_format)
% - Seed (seed)
% - ApiKey
% - TimeOut
% - StreamFun
Expand Down Expand Up @@ -50,12 +50,12 @@
% % Send a request
% [text, message] = llms.internal.callOpenAIChatAPI(messages, functions, ApiKey=apiKey)

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

arguments
messages
functions
nvp.FunctionCall = []
nvp.ToolChoice = []
nvp.ModelName = "gpt-3.5-turbo"
nvp.Temperature = 1
nvp.TopProbabilityMass = 1
Expand All @@ -64,6 +64,8 @@
nvp.MaxNumTokens = inf
nvp.PresencePenalty = 0
nvp.FrequencyPenalty = 0
nvp.ResponseFormat = "text"
nvp.Seed = []
nvp.ApiKey = ""
nvp.TimeOut = 10
nvp.StreamFun = []
Expand All @@ -85,7 +87,7 @@
message = struct("role", "assistant", ...
"content", streamedText);
end
if isfield(message, "function_call")
if isfield(message, "tool_choice")
text = "";
else
text = string(message.content);
Expand All @@ -105,22 +107,36 @@

parameters.stream = ~isempty(nvp.StreamFun);

if ~isempty(functions)
parameters.functions = functions;
if ~isempty(functions) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tools = functions;
end

if ~isempty(nvp.ToolChoice) && ~strcmp(nvp.ModelName,'gpt-4-vision-preview')
parameters.tool_choice = nvp.ToolChoice;
end

if ismember(nvp.ModelName,["gpt-3.5-turbo-1106","gpt-4-1106-preview"])
if strcmp(nvp.ResponseFormat,"json")
parameters.response_format = struct('type','json_object');
end
end

if ~isempty(nvp.FunctionCall)
parameters.function_call = nvp.FunctionCall;
if ~isempty(nvp.Seed)
parameters.seed = nvp.Seed;
end

parameters.model = nvp.ModelName;

dict = mapNVPToParameters;

nvpOptions = keys(dict);
for i=1:length(nvpOptions)
if isfield(nvp, nvpOptions(i))
parameters.(dict(nvpOptions(i))) = nvp.(nvpOptions(i));
if strcmp(nvp.ModelName,'gpt-4-vision-preview')
nvpOptions(ismember(nvpOptions,["MaxNumTokens","StopSequences"])) = [];
end

for opt = nvpOptions.'
if isfield(nvp, opt)
parameters.(dict(opt)) = nvp.(opt);
end
end
end
Expand Down
21 changes: 12 additions & 9 deletions +llms/+utils/errorMessageCatalog.m
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
classdef errorMessageCatalog
% This class is undocumented and will change in a future release

%errorMessageCatalog Stores the error messages from this repository

% Copyright 2023 The MathWorks, Inc.
% Copyright 2023-2024 The MathWorks, Inc.

properties(Constant)
%CATALOG dictionary mapping error ids to error msgs
Catalog = buildErrorMessageCatalog;
end

methods(Static)
function msg = getMessage(messageId, slot)
% This function is undocumented and will change in a future release

%getMessage returns error message given a messageID and a SLOT.
% The value in SLOT should be ordered, where the n-th element
% will replace the value "{n}".
Expand Down Expand Up @@ -41,13 +38,19 @@
catalog("llms:parameterMustBeUnique") = "A parameter name equivalent to '{1}' already exists in Parameters. Redefining a parameter is not allowed.";
catalog("llms:mustBeAssistantCall") = "Input struct must contain field 'role' with value 'assistant', and field 'content'.";
catalog("llms:mustBeAssistantWithContent") = "Input struct must contain field 'content' containing text with one or more characters.";
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function_call' must be a struct with fields 'name' and 'arguments'.";
catalog("llms:mustBeAssistantWithIdAndFunction") = "Field 'tool_call' must be a struct with fields 'id' and 'function'.";
catalog("llms:mustBeAssistantWithNameAndArguments") = "Field 'function' must be a struct with fields 'name' and 'arguments'.";
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, ToolChoice must not be specified.";
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";
end

catalog("llms:invalidOptionAndValueForModel") = "'{1}' with value '{2}' is not supported for ModelName '{3}'";
catalog("llms:invalidOptionForModel") = "{1} is not supported for ModelName '{2}'";
catalog("llms:functionNotAvailableForModel") = "This function is not supported for ModelName '{1}'";
catalog("llms:promptLimitCharacter") = "Prompt must have a maximum length of {1} characters for ModelName '{2}'";
catalog("llms:pngExpected") = "Argument must be a PNG image.";
catalog("llms:warningJsonInstruction") = "When using JSON mode, you must also prompt the model to produce JSON yourself via a system or user message.";
end
28 changes: 28 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
*.fig binary
*.mat binary
*.mdl binary diff merge=mlAutoMerge
*.mdlp binary
*.mexa64 binary
*.mexw64 binary
*.mexmaci64 binary
*.mlapp binary
*.mldatx binary
*.mlproj binary
*.mlx binary
*.p binary
*.sfx binary
*.sldd binary
*.slreqx binary merge=mlAutoMerge
*.slmx binary merge=mlAutoMerge
*.sltx binary
*.slxc binary
*.slx binary merge=mlAutoMerge
*.slxp binary

## Other common binary file types
*.docx binary
*.exe binary
*.jpg binary
*.pdf binary
*.png binary
*.xlsx binary
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.env
*.asv
startup.m
Loading