Skip to content

Commit 3fd64f2

Browse files
committed
Merge branch 'embeddings' into 'main'
Adding embedding model. See merge request dferreir/llms-with-matlab!11
2 parents 76565a5 + 4244345 commit 3fd64f2

File tree

4 files changed

+129
-1
lines changed

4 files changed

+129
-1
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
function mustBeNonzeroLengthTextScalar(content)
2+
mustBeNonzeroLengthText(content)
3+
mustBeTextScalar(content)
4+
end

extractOpenAIEmbeddings.m

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
function [emb, response] = extractOpenAIEmbeddings(text, nvp)
2+
% EXTRACTOPENAIEMBEDDINGS Generate text embeddings using the OpenAI API
3+
%
4+
% emb = EXTRACTOPENAIEMBEDDINGS(text) generates an embedding of the input
5+
% TEXT using the OpenAI API.
6+
%
7+
% emb = EXTRACTOPENAIEMBEDDINGS(text,Name=Value) specifies optional
8+
% specifies additional options using one or more name-value pairs:
9+
%
10+
% 'ModelName' - The ID of the model to use.
11+
%
12+
% 'ApiKey' - OpenAI API token. It can also be specified by
13+
% setting the environment variable OPENAI_API_KEY
14+
%
15+
% 'TimeOut' - Connection Timeout in seconds (default: 10 secs)
16+
%
17+
% [emb, response] = EXTRACTOPENAIEMBEDDINGS(...) also returns the full
18+
% response from the OpenAI API call.
19+
%
20+
% Copyright 2023 The MathWorks, Inc.
21+
22+
arguments
23+
text (1,:) {mustBeText}
24+
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,"text-embedding-ada-002")} = "text-embedding-ada-002"
25+
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10
26+
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
27+
end
28+
29+
END_POINT = "https://api.openai.com/v1/embeddings";
30+
31+
key = llms.internal.getApiKeyFromNvpOrEnv(nvp);
32+
33+
parameters = struct("input",text,"model",nvp.ModelName);
34+
35+
response = llms.internal.sendRequest(parameters,key, END_POINT, nvp.TimeOut);
36+
37+
if isfield(response, "data")
38+
emb = [response.Body.Data.data.embedding];
39+
emb = emb';
40+
else
41+
emb = [];
42+
end

openAIChat.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
nvp.Temperature {mustBeValidTemperature} = 1
109109
nvp.TopProbabilityMass {mustBeValidTopP} = 1
110110
nvp.StopSequences {mustBeValidStop} = {}
111-
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
111+
nvp.ApiKey {llms.utils.mustBeNonzeroLengthTextScalar}
112112
nvp.PresencePenalty {mustBeValidPenalty} = 0
113113
nvp.FrequencyPenalty {mustBeValidPenalty} = 0
114114
nvp.TimeOut (1,1) {mustBeReal,mustBePositive} = 10

tests/textractOpenAIEmbeddings.m

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
classdef textractOpenAIEmbeddings < matlab.unittest.TestCase
2+
% Tests for extractOpenAIEmbeddings
3+
4+
% Copyright 2023 The MathWorks, Inc.
5+
6+
methods (TestClassSetup)
7+
function saveEnvVar(testCase)
8+
% Ensures key is not in environment variable for tests
9+
openAIEnvVar = "OPENAI_API_KEY";
10+
if isenv(openAIEnvVar)
11+
key = getenv(openAIEnvVar);
12+
unsetenv(openAIEnvVar);
13+
testCase.addTeardown(@(x) setenv(openAIEnvVar, x), key);
14+
end
15+
end
16+
end
17+
18+
properties(TestParameter)
19+
InvalidInput = iGetInvalidInput;
20+
end
21+
22+
methods(Test)
23+
% Test methods
24+
function embedsDifferentStringTypes(testCase)
25+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ApiKey="this-is-not-a-real-key"));
26+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings('bla', ApiKey="this-is-not-a-real-key"));
27+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings({'bla'}, ApiKey="this-is-not-a-real-key"));
28+
end
29+
30+
function keyNotFound(testCase)
31+
testCase.verifyError(@()extractOpenAIEmbeddings("bla"), "llms:keyMustBeSpecified");
32+
end
33+
34+
function useAllNVP(testCase)
35+
testCase.verifyWarningFree(@()extractOpenAIEmbeddings("bla", ModelName="text-embedding-ada-002", ...
36+
ApiKey="this-is-not-a-real-key", TimeOut=10));
37+
end
38+
39+
function verySmallTimeOutErrors(testCase)
40+
testCase.verifyError(@()extractOpenAIEmbeddings("bla", TimeOut=0.0001, ApiKey="false-key"), "MATLAB:webservices:Timeout")
41+
end
42+
43+
function testInvalidInputs(testCase, InvalidInput)
44+
testCase.verifyError(@()extractOpenAIEmbeddings(InvalidInput.Input{:}), InvalidInput.Error);
45+
end
46+
end
47+
end
48+
49+
function invalidInput = iGetInvalidInput
50+
invalidInput = struct( ...
51+
"InvalidTimeOutType", struct( ...
52+
"Input",{{ "bla", "TimeOut", "2" }},...
53+
"Error", "MATLAB:validators:mustBeReal"), ...
54+
...
55+
"InvalidTimeOutSize", struct( ...
56+
"Input",{{ "bla", "TimeOut", [1 1 1] }},...
57+
"Error", "MATLAB:validation:IncompatibleSize"), ...
58+
...
59+
"WrongTypeText",struct( ...
60+
"Input",{{ 123 }},...
61+
"Error","MATLAB:validators:mustBeText"),...
62+
...
63+
"InvalidModelNameType",struct( ...
64+
"Input",{{"bla", "ModelName", 0 }},...
65+
"Error","MATLAB:validators:mustBeMember"),...
66+
...
67+
"InvalidModelNameSize",struct( ...
68+
"Input",{{"bla", "ModelName", ["gpt-3.5-turbo", "gpt-3.5-turbo"] }},...
69+
"Error","MATLAB:validation:IncompatibleSize"),...
70+
...
71+
"InvalidModelNameOption",struct( ...
72+
"Input",{{"bla", "ModelName", "gpt" }},...
73+
"Error","MATLAB:validators:mustBeMember"),...
74+
...
75+
"InvalidApiKeyType",struct( ...
76+
"Input",{{"bla", "ApiKey" 123 }},...
77+
"Error","MATLAB:validators:mustBeNonzeroLengthText"),...
78+
...
79+
"InvalidApiKeySize",struct( ...
80+
"Input",{{"bla", "ApiKey" ["abc" "abc"] }},...
81+
"Error","MATLAB:validators:mustBeTextScalar"));
82+
end

0 commit comments

Comments
 (0)