Skip to content

Commit 177f546

Browse files
authored
AI Hybrid Inference: migrate to LanguageModelMessage (#9027)
1 parent 72cd626 commit 177f546

File tree

4 files changed

+165
-48
lines changed

4 files changed

+165
-48
lines changed

common/api-review/ai.api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,10 +705,10 @@ export interface LanguageModelMessage {
705705

706706
// @public (undocumented)
707707
export interface LanguageModelMessageContent {
708-
// (undocumented)
709-
content: LanguageModelMessageContentValue;
710708
// (undocumented)
711709
type: LanguageModelMessageType;
710+
// (undocumented)
711+
value: LanguageModelMessageContentValue;
712712
}
713713

714714
// @public (undocumented)

packages/ai/src/methods/chrome-adapter.test.ts

Lines changed: 118 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {
2424
Availability,
2525
LanguageModel,
2626
LanguageModelCreateOptions,
27-
LanguageModelMessageContent
27+
LanguageModelMessage
2828
} from '../types/language-model';
2929
import { match, stub } from 'sinon';
3030
import { GenerateContentRequest, AIErrorCode } from '../types';
@@ -146,7 +146,7 @@ describe('ChromeAdapter', () => {
146146
})
147147
).to.be.false;
148148
});
149-
it('returns false if request content has non-user role', async () => {
149+
it('returns false if request content has "function" role', async () => {
150150
const adapter = new ChromeAdapter(
151151
{
152152
availability: async () => Availability.available
@@ -157,7 +157,7 @@ describe('ChromeAdapter', () => {
157157
await adapter.isAvailable({
158158
contents: [
159159
{
160-
role: 'model',
160+
role: 'function',
161161
parts: []
162162
}
163163
]
@@ -320,7 +320,7 @@ describe('ChromeAdapter', () => {
320320
} as LanguageModel;
321321
const languageModel = {
322322
// eslint-disable-next-line @typescript-eslint/no-unused-vars
323-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
323+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
324324
} as LanguageModel;
325325
const createStub = stub(languageModelProvider, 'create').resolves(
326326
languageModel
@@ -345,8 +345,13 @@ describe('ChromeAdapter', () => {
345345
// Asserts Vertex input type is mapped to Chrome type.
346346
expect(promptStub).to.have.been.calledOnceWith([
347347
{
348-
type: 'text',
349-
content: request.contents[0].parts[0].text
348+
role: request.contents[0].role,
349+
content: [
350+
{
351+
type: 'text',
352+
value: request.contents[0].parts[0].text
353+
}
354+
]
350355
}
351356
]);
352357
// Asserts expected output.
@@ -366,7 +371,7 @@ describe('ChromeAdapter', () => {
366371
} as LanguageModel;
367372
const languageModel = {
368373
// eslint-disable-next-line @typescript-eslint/no-unused-vars
369-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
374+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
370375
} as LanguageModel;
371376
const createStub = stub(languageModelProvider, 'create').resolves(
372377
languageModel
@@ -404,12 +409,17 @@ describe('ChromeAdapter', () => {
404409
// Asserts Vertex input type is mapped to Chrome type.
405410
expect(promptStub).to.have.been.calledOnceWith([
406411
{
407-
type: 'text',
408-
content: request.contents[0].parts[0].text
409-
},
410-
{
411-
type: 'image',
412-
content: match.instanceOf(ImageBitmap)
412+
role: request.contents[0].role,
413+
content: [
414+
{
415+
type: 'text',
416+
value: request.contents[0].parts[0].text
417+
},
418+
{
419+
type: 'image',
420+
value: match.instanceOf(ImageBitmap)
421+
}
422+
]
413423
}
414424
]);
415425
// Asserts expected output.
@@ -426,7 +436,7 @@ describe('ChromeAdapter', () => {
426436
it('honors prompt options', async () => {
427437
const languageModel = {
428438
// eslint-disable-next-line @typescript-eslint/no-unused-vars
429-
prompt: (p: LanguageModelMessageContent[]) => Promise.resolve('')
439+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('')
430440
} as LanguageModel;
431441
const languageModelProvider = {
432442
create: () => Promise.resolve(languageModel)
@@ -450,13 +460,48 @@ describe('ChromeAdapter', () => {
450460
expect(promptStub).to.have.been.calledOnceWith(
451461
[
452462
{
453-
type: 'text',
454-
content: request.contents[0].parts[0].text
463+
role: request.contents[0].role,
464+
content: [
465+
{
466+
type: 'text',
467+
value: request.contents[0].parts[0].text
468+
}
469+
]
455470
}
456471
],
457472
promptOptions
458473
);
459474
});
475+
it('normalizes roles', async () => {
476+
const languageModel = {
477+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
478+
prompt: (p: LanguageModelMessage[]) => Promise.resolve('unused')
479+
} as LanguageModel;
480+
const promptStub = stub(languageModel, 'prompt').resolves('unused');
481+
const languageModelProvider = {
482+
create: () => Promise.resolve(languageModel)
483+
} as LanguageModel;
484+
const adapter = new ChromeAdapter(
485+
languageModelProvider,
486+
'prefer_on_device'
487+
);
488+
const request = {
489+
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
490+
} as GenerateContentRequest;
491+
await adapter.generateContent(request);
492+
expect(promptStub).to.have.been.calledOnceWith([
493+
{
494+
// Asserts Vertex's "model" role is normalized to Chrome's "assistant" role.
495+
role: 'assistant',
496+
content: [
497+
{
498+
type: 'text',
499+
value: request.contents[0].parts[0].text
500+
}
501+
]
502+
}
503+
]);
504+
});
460505
});
461506
describe('countTokens', () => {
462507
it('counts tokens is not yet available', async () => {
@@ -528,8 +573,13 @@ describe('ChromeAdapter', () => {
528573
expect(createStub).to.have.been.calledOnceWith(createOptions);
529574
expect(promptStub).to.have.been.calledOnceWith([
530575
{
531-
type: 'text',
532-
content: request.contents[0].parts[0].text
576+
role: request.contents[0].role,
577+
content: [
578+
{
579+
type: 'text',
580+
value: request.contents[0].parts[0].text
581+
}
582+
]
533583
}
534584
]);
535585
const actual = await toStringArray(response.body!);
@@ -584,12 +634,17 @@ describe('ChromeAdapter', () => {
584634
expect(createStub).to.have.been.calledOnceWith(createOptions);
585635
expect(promptStub).to.have.been.calledOnceWith([
586636
{
587-
type: 'text',
588-
content: request.contents[0].parts[0].text
589-
},
590-
{
591-
type: 'image',
592-
content: match.instanceOf(ImageBitmap)
637+
role: request.contents[0].role,
638+
content: [
639+
{
640+
type: 'text',
641+
value: request.contents[0].parts[0].text
642+
},
643+
{
644+
type: 'image',
645+
value: match.instanceOf(ImageBitmap)
646+
}
647+
]
593648
}
594649
]);
595650
const actual = await toStringArray(response.body!);
@@ -625,13 +680,50 @@ describe('ChromeAdapter', () => {
625680
expect(promptStub).to.have.been.calledOnceWith(
626681
[
627682
{
628-
type: 'text',
629-
content: request.contents[0].parts[0].text
683+
role: request.contents[0].role,
684+
content: [
685+
{
686+
type: 'text',
687+
value: request.contents[0].parts[0].text
688+
}
689+
]
630690
}
631691
],
632692
promptOptions
633693
);
634694
});
695+
it('normalizes roles', async () => {
696+
const languageModel = {
697+
// eslint-disable-next-line @typescript-eslint/no-unused-vars
698+
promptStreaming: p => new ReadableStream()
699+
} as LanguageModel;
700+
const promptStub = stub(languageModel, 'promptStreaming').returns(
701+
new ReadableStream()
702+
);
703+
const languageModelProvider = {
704+
create: () => Promise.resolve(languageModel)
705+
} as LanguageModel;
706+
const adapter = new ChromeAdapter(
707+
languageModelProvider,
708+
'prefer_on_device'
709+
);
710+
const request = {
711+
contents: [{ role: 'model', parts: [{ text: 'unused' }] }]
712+
} as GenerateContentRequest;
713+
await adapter.generateContentStream(request);
714+
expect(promptStub).to.have.been.calledOnceWith([
715+
{
716+
// Asserts Vertex's "model" role is normalized to Chrome's "assistant" role.
717+
role: 'assistant',
718+
content: [
719+
{
720+
type: 'text',
721+
value: request.contents[0].parts[0].text
722+
}
723+
]
724+
}
725+
]);
726+
});
635727
});
636728
});
637729

packages/ai/src/methods/chrome-adapter.ts

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@ import {
2323
InferenceMode,
2424
Part,
2525
AIErrorCode,
26-
OnDeviceParams
26+
OnDeviceParams,
27+
Content,
28+
Role
2729
} from '../types';
2830
import {
2931
Availability,
3032
LanguageModel,
31-
LanguageModelMessageContent
33+
LanguageModelMessage,
34+
LanguageModelMessageContent,
35+
LanguageModelMessageRole
3236
} from '../types/language-model';
3337

3438
/**
@@ -115,10 +119,8 @@ export class ChromeAdapter {
115119
*/
116120
async generateContent(request: GenerateContentRequest): Promise<Response> {
117121
const session = await this.createSession();
118-
// TODO: support multiple content objects when Chrome supports
119-
// sequence<LanguageModelMessage>
120122
const contents = await Promise.all(
121-
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
123+
request.contents.map(ChromeAdapter.toLanguageModelMessage)
122124
);
123125
const text = await session.prompt(
124126
contents,
@@ -139,10 +141,8 @@ export class ChromeAdapter {
139141
request: GenerateContentRequest
140142
): Promise<Response> {
141143
const session = await this.createSession();
142-
// TODO: support multiple content objects when Chrome supports
143-
// sequence<LanguageModelMessage>
144144
const contents = await Promise.all(
145-
request.contents[0].parts.map(ChromeAdapter.toLanguageModelMessageContent)
145+
request.contents.map(ChromeAdapter.toLanguageModelMessage)
146146
);
147147
const stream = await session.promptStreaming(
148148
contents,
@@ -169,12 +169,8 @@ export class ChromeAdapter {
169169
}
170170

171171
for (const content of request.contents) {
172-
// Returns false if the request contains multiple roles, eg a chat history.
173-
// TODO: remove this guard once LanguageModelMessage is supported.
174-
if (content.role !== 'user') {
175-
logger.debug(
176-
`Non-user role "${content.role}" rejected for on-device inference.`
177-
);
172+
if (content.role === 'function') {
173+
logger.debug(`"Function" role rejected for on-device inference.`);
178174
return false;
179175
}
180176

@@ -233,6 +229,21 @@ export class ChromeAdapter {
233229
});
234230
}
235231

232+
/**
233+
* Converts a Vertex {@link Content} object to a Chrome {@link LanguageModelMessage} object.
234+
*/
235+
private static async toLanguageModelMessage(
236+
content: Content
237+
): Promise<LanguageModelMessage> {
238+
const languageModelMessageContents = await Promise.all(
239+
content.parts.map(ChromeAdapter.toLanguageModelMessageContent)
240+
);
241+
return {
242+
role: ChromeAdapter.toLanguageModelMessageRole(content.role),
243+
content: languageModelMessageContents
244+
};
245+
}
246+
236247
/**
237248
* Converts a Vertex Part object to a Chrome LanguageModelMessageContent object.
238249
*/
@@ -242,7 +253,7 @@ export class ChromeAdapter {
242253
if (part.text) {
243254
return {
244255
type: 'text',
245-
content: part.text
256+
value: part.text
246257
};
247258
} else if (part.inlineData) {
248259
const formattedImageContent = await fetch(
@@ -252,14 +263,24 @@ export class ChromeAdapter {
252263
const imageBitmap = await createImageBitmap(imageBlob);
253264
return {
254265
type: 'image',
255-
content: imageBitmap
266+
value: imageBitmap
256267
};
257268
}
258269
// Assumes contents have been verified to contain only text or inline-data (image) parts.
259270
// TODO: support other input types
260271
throw new Error('Not yet implemented');
261272
}
262273

274+
/**
275+
* Converts a Vertex {@link Role} string to a {@link LanguageModelMessageRole} string.
276+
*/
277+
private static toLanguageModelMessageRole(
278+
role: Role
279+
): LanguageModelMessageRole {
280+
// Assumes 'function' role has been filtered by isOnDeviceRequest
281+
return role === 'model' ? 'assistant' : 'user';
282+
}
283+
263284
/**
264285
* Abstracts Chrome session creation.
265286
*

packages/ai/src/types/language-model.ts

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
17+
/**
18+
* @see {@link https://github.com/webmachinelearning/prompt-api#full-api-surface-in-web-idl}
19+
*/
1820
export interface LanguageModel extends EventTarget {
1921
create(options?: LanguageModelCreateOptions): Promise<LanguageModel>;
2022
availability(options?: LanguageModelCreateCoreOptions): Promise<Availability>;
@@ -57,8 +59,10 @@ export interface LanguageModelExpectedInput {
5759
type: LanguageModelMessageType;
5860
languages?: string[];
5961
}
60-
// TODO: revert to type from Prompt API explainer once it's supported.
61-
export type LanguageModelPrompt = LanguageModelMessageContent[];
62+
export type LanguageModelPrompt =
63+
| LanguageModelMessage[]
64+
| LanguageModelMessageShorthand[]
65+
| string;
6266
export type LanguageModelInitialPrompts =
6367
| LanguageModelMessage[]
6468
| LanguageModelMessageShorthand[];
@@ -72,7 +76,7 @@ export interface LanguageModelMessageShorthand {
7276
}
7377
export interface LanguageModelMessageContent {
7478
type: LanguageModelMessageType;
75-
content: LanguageModelMessageContentValue;
79+
value: LanguageModelMessageContentValue;
7680
}
7781
export type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
7882
export type LanguageModelMessageType = 'text' | 'image' | 'audio';

0 commit comments

Comments
 (0)