commit 035a2ecab05097c663887ebf12f6716d9bbac6aa
Author: TimMikeladze <[email protected]>
Date: Thu Sep 22 17:27:14 2022 +0300
First pass at streaming support
diff --git a/package.json b/package.json
index 7d93ea1..f8e3232 100644
--- a/package.json
+++ b/package.json
@@ -58,6 +58,7 @@
],
"devDependencies": {
"@size-limit/preset-small-lib": "7.0.8",
+ "@types/ws": "8.5.3",
"husky": "8.0.1",
"size-limit": "7.0.8",
"tsdx": "0.14.1",
@@ -68,6 +69,8 @@
"node-notifier": ">=8.0.1"
},
"dependencies": {
- "isomorphic-unfetch": "3.1.0"
+ "isomorphic-unfetch": "3.1.0",
+ "isomorphic-ws": "5.0.0",
+ "ws": "8.8.1"
}
}
diff --git a/src/HuggingFace.ts b/src/HuggingFace.ts
index 6909702..29f3684 100644
--- a/src/HuggingFace.ts
+++ b/src/HuggingFace.ts
@@ -1,6 +1,12 @@
import fetch from 'isomorphic-unfetch';
+import WebSocket from 'ws';
export type Options = {
+ /**
+ * (Default: `true`) If enabled, array arguments will be sent over a WebSocket connection and the response will be streamed back.
+ */
+ use_streaming?: boolean;
+
/**
* (Default: false). Boolean to use GPU instead of CPU for inference (requires Startup plan at least).
*/
@@ -21,7 +27,14 @@ export type Options = {
};
export type Args = {
+ /**
+ * The name of the HuggingFace model to use.
+ */
model: string;
+ /**
+ * When `use_streaming` option is enabled this id will be included in each response to identify the request.
+ */
+ id?: string;
};
export type FillMaskArgs = Args & {
@@ -391,9 +404,9 @@ export class HuggingFace {
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
*/
public async fillMask(
- args: FillMaskArgs,
+ args: FillMaskArgs | FillMaskArgs[],
options?: Options
- ): Promise<FillMaskReturn> {
+ ): Promise<FillMaskReturn | FillMaskReturn[]> {
return this.request(args, options);
}
@@ -401,19 +414,33 @@ export class HuggingFace {
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
*/
public async summarization(
- args: SummarizationArgs,
+ args: SummarizationArgs | SummarizationArgs[],
options?: Options
- ): Promise<SummarizationReturn> {
+ ): Promise<SummarizationReturn | SummarizationReturn[]> {
return (await this.request(args, options))?.[0];
}
+ /**
+ * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
+ */
+ public async questionAnswer(
+ args: QuestionAnswerArgs[],
+ options?: Options
+ ): Promise<QuestionAnswerReturn[]>;
/**
* Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
*/
public async questionAnswer(
args: QuestionAnswerArgs,
options?: Options
- ): Promise<QuestionAnswerReturn> {
+ ): Promise<QuestionAnswerReturn>;
+ /**
+ * Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
+ */
+ public async questionAnswer(
+ args: QuestionAnswerArgs | QuestionAnswerArgs[],
+ options?: Options
+ ): Promise<QuestionAnswerReturn | QuestionAnswerReturn[]> {
return await this.request(args, options);
}
@@ -421,9 +448,9 @@ export class HuggingFace {
* Don’t know SQL? Don’t want to dive into a large spreadsheet? Ask questions in plain english! Recommended model: google/tapas-base-finetuned-wtq.
*/
public async tableQuestionAnswer(
- args: TableQuestionAnswerArgs,
+ args: TableQuestionAnswerArgs | TableQuestionAnswerArgs[],
options?: Options
- ): Promise<TableQuestionAnswerReturn> {
+ ): Promise<TableQuestionAnswerReturn | TableQuestionAnswerReturn[]> {
return await this.request(args, options);
}
@@ -431,9 +458,9 @@ export class HuggingFace {
* Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
*/
public async textClassification(
- args: TextClassificationArgs,
+ args: TextClassificationArgs | TextClassificationArgs[],
options?: Options
- ): Promise<TextClassificationReturn> {
+ ): Promise<TextClassificationReturn | TextClassificationReturn[]> {
return await this.request(args, options);
}
@@ -441,9 +468,9 @@ export class HuggingFace {
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
*/
public async textGeneration(
- args: TextGenerationArgs,
+ args: TextGenerationArgs | TextGenerationArgs[],
options?: Options
- ): Promise<TextGenerationReturn> {
+ ): Promise<TextGenerationReturn | TextGenerationReturn[]> {
return (await this.request(args, options))?.[0];
}
@@ -451,9 +478,9 @@ export class HuggingFace {
* Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
*/
public async tokenClassification(
- args: TokenClassificationArgs,
+ args: TokenClassificationArgs | TokenClassificationArgs[],
options?: Options
- ): Promise<TokenClassificationReturn> {
+ ): Promise<TokenClassificationReturn | TokenClassificationReturn[]> {
return HuggingFace.toArray(await this.request(args, options));
}
@@ -461,9 +488,9 @@ export class HuggingFace {
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
*/
public async translation(
- args: TranslationArgs,
+ args: TranslationArgs | TranslationArgs[],
options?: Options
- ): Promise<TranslationReturn> {
+ ): Promise<TranslationReturn | TranslationReturn[]> {
return (await this.request(args, options))?.[0];
}
@@ -471,9 +498,9 @@ export class HuggingFace {
* This task is super useful to try out classification with zero code, you simply pass a sentence/paragraph and the possible labels for that sentence, and you get a result. Recommended model: facebook/bart-large-mnli.
*/
public async zeroShotClassification(
- args: ZeroShotClassificationArgs,
+ args: ZeroShotClassificationArgs | ZeroShotClassificationArgs[],
options?: Options
- ): Promise<ZeroShotClassificationReturn> {
+ ): Promise<ZeroShotClassificationReturn | ZeroShotClassificationReturn[]> {
return HuggingFace.toArray(await this.request(args, options));
}
@@ -482,9 +509,9 @@ export class HuggingFace {
*
*/
public async conversational(
- args: ConversationalArgs,
+ args: ConversationalArgs | ConversationalArgs[],
options?: Options
- ): Promise<ConversationalReturn> {
+ ): Promise<ConversationalReturn | ConversationalReturn[]> {
return await this.request(args, options);
}
@@ -492,43 +519,111 @@ export class HuggingFace {
* This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
*/
public async featureExtraction(
- args: FeatureExtractionArgs,
+ args: FeatureExtractionArgs | FeatureExtractionArgs[],
options?: Options
- ): Promise<FeatureExtractionReturn> {
+ ): Promise<FeatureExtractionReturn | FeatureExtractionReturn[]> {
return await this.request(args, options);
}
- public async request(args: Args, options?: Options): Promise<any> {
+ public async request(
+ args: Args | Args[],
+ options?: Options
+ ): Promise<any | any[]> {
const mergedOptions = { ...this.defaultOptions, ...options };
- const { model, ...otherArgs } = args;
- const response = await fetch(
- `https://api-inference.huggingface.co/models/${model}`,
- {
- headers: { Authorization: `Bearer ${this.apiKey}` },
- method: 'POST',
- body: JSON.stringify({
- ...otherArgs,
- options: mergedOptions,
- }),
+
+ if (Array.isArray(args) && options?.use_streaming !== false) {
+ const models = new Set(args.map(x => x.model));
+
+ if (models.size > 1) {
+ throw new Error(
+          'You can only use one model per request when the `use_streaming` option is enabled. Please group your requests by model.'
+ );
}
- );
-
- if (
- mergedOptions.retry_on_error !== false &&
- response.status === 503 &&
- !mergedOptions.wait_for_model
- ) {
- return this.request(args, {
- ...mergedOptions,
- wait_for_model: true,
+
+ const model = args[0].model;
+
+ const uniqueIds = args
+ .map(x => x.id)
+ .filter(x => x !== undefined && x !== null && x?.trim() !== '');
+
+ if (uniqueIds.length !== new Set(uniqueIds).size) {
+ throw new Error('Duplicate ids found in args');
+ }
+
+ const ws = new WebSocket(
+ `wss://api-inference.huggingface.co/bulk/stream/cpu/${model}`
+ );
+
+ // @ts-ignore
+ const responses: any[] = [];
+
+ // @ts-ignore
+ return new Promise((resolve, reject) => {
+ ws.on('open', () => {
+ ws.send(`Bearer ${this.apiKey}`, { binary: true });
+
+ for (const arg of args) {
+ ws.send(JSON.stringify(arg), { binary: true });
+ }
+ });
+
+ ws.on('message', (data: any) => {
+ console.log(Buffer.from(data).toString());
+ console.log(data);
+ // const message = JSON.parse(data);
+ // if (message.type == 'results') {
+ // responses.push(message);
+ // if (responses.length === args.length) {
+ // ws.close();
+ // resolve(responses);
+ // }
+ // }
+ });
+ ws.on('error', message => {
+ console.log(message);
+ reject(message);
+ });
});
- }
+ } else {
+ const httpRequest = async (args: Args) => {
+ const { model, ...otherArgs } = args;
+
+ const response = await fetch(
+ `https://api-inference.huggingface.co/models/${model}`,
+ {
+ headers: { Authorization: `Bearer ${this.apiKey}` },
+ method: 'POST',
+ body: JSON.stringify({
+ ...otherArgs,
+ options: mergedOptions,
+ }),
+ }
+ );
+
+ if (
+ mergedOptions.retry_on_error !== false &&
+ response.status === 503 &&
+ !mergedOptions.wait_for_model
+ ) {
+ return this.request(args, {
+ ...mergedOptions,
+ wait_for_model: true,
+ });
+ }
+
+ const res = await response.json();
+ if (res.error) {
+ throw new Error(res.error);
+ }
+ return res;
+ };
+
+ if (Array.isArray(args)) {
+ return Promise.all(args.map(x => httpRequest(x)));
+ }
- const res = await response.json();
- if (res.error) {
- throw new Error(res.error);
+ return httpRequest(args);
}
- return res;
}
private static toArray(obj: any): any[] {
diff --git a/test/HuggingFace.test.ts b/test/HuggingFace.test.ts
index ed7a592..da0a2bd 100644
--- a/test/HuggingFace.test.ts
+++ b/test/HuggingFace.test.ts
@@ -7,7 +7,7 @@ describe('HuggingFace', () => {
// Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.
let hf = new HuggingFace(process.env.HF_API_KEY as string);
- it('throws error if model does not exist', () => {
+ xit('throws error if model does not exist', () => {
expect(
hf.fillMask({
model: 'this-model-does-not-exist-123',
@@ -17,7 +17,24 @@ describe('HuggingFace', () => {
`Model this-model-does-not-exist-123 does not exist`
);
});
- it('fillMask', async () => {
+ xit('throws error if multiple models are provided and use_streaming is true', () => {
+ expect(
+ hf.fillMask([
+ {
+ model: 'this-model-does-not-exist-123',
+ inputs: '[MASK] world!',
+ },
+ {
+ model: 'this-model-also-does-not-exist-123',
+ inputs: '[MASK] world!',
+ },
+ ])
+ ).rejects.toThrowError(
+ `Model this-model-does-not-exist-123 does not exist`
+ );
+ });
+
+ xit('fillMask', async () => {
expect(
await hf.fillMask({
model: 'bert-base-uncased',
@@ -34,7 +51,7 @@ describe('HuggingFace', () => {
])
);
});
- it('summarization', async () => {
+ xit('summarization', async () => {
expect(
await hf.summarization({
model: 'facebook/bart-large-cnn',
@@ -49,7 +66,7 @@ describe('HuggingFace', () => {
'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world.',
});
});
- it('questionAnswer', async () => {
+ xit('questionAnswer', async () => {
expect(
await hf.questionAnswer({
model: 'deepset/roberta-base-squad2',
@@ -65,7 +82,7 @@ describe('HuggingFace', () => {
end: expect.any(Number),
});
});
- it('table question answer', async () => {
+ xit('table question answer', async () => {
expect(
await hf.tableQuestionAnswer({
model: 'google/tapas-base-finetuned-wtq',
@@ -90,7 +107,7 @@ describe('HuggingFace', () => {
aggregator: 'AVERAGE',
});
});
- it('textClassification', async () => {
+ xit('textClassification', async () => {
expect(
await hf.textClassification({
model: 'distilbert-base-uncased-finetuned-sst-2-english',
@@ -105,7 +122,7 @@ describe('HuggingFace', () => {
])
);
});
- it('textGeneration', async () => {
+ xit('textGeneration', async () => {
expect(
await hf.textGeneration({
model: 'gpt2',
@@ -116,7 +133,7 @@ describe('HuggingFace', () => {
'The answer to the universe is not a binary number that is at a certain point defined in our theory of time, but an infinite number of infinitely long points and points for which each of these points has the given form in our equation. If the given',
});
});
- it(`tokenClassification`, async () => {
+ xit('tokenClassification', async () => {
expect(
await hf.tokenClassification({
model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
@@ -134,7 +151,7 @@ describe('HuggingFace', () => {
])
);
});
- it(`translation`, async () => {
+ xit('translation', async () => {
expect(
await hf.translation({
model: 'Helsinki-NLP/opus-mt-ru-en',
@@ -144,7 +161,7 @@ describe('HuggingFace', () => {
translation_text: 'My name is Wolfgang and I live in Berlin.',
});
});
- it(`zeroShotClassification`, async () => {
+ xit('zeroShotClassification', async () => {
expect(
await hf.zeroShotClassification({
model: 'facebook/bart-large-mnli',
@@ -164,7 +181,7 @@ describe('HuggingFace', () => {
])
);
});
- it(`conversational`, async () => {
+ xit('conversational', async () => {
expect(
await hf.conversational({
model: 'microsoft/DialoGPT-large',
@@ -191,7 +208,7 @@ describe('HuggingFace', () => {
],
});
});
- it(`featureExtraction`, async () => {
+ xit('featureExtraction', async () => {
expect(
await hf.featureExtraction({
model: 'sentence-transformers/paraphrase-xlm-r-multilingual-v1',
@@ -206,4 +223,69 @@ describe('HuggingFace', () => {
})
).toEqual([0.6623499393463135, 0.9382339715957642, 0.22963346540927887]);
});
+
+ xit('use http for array input when use_streaming is false', async () => {
+ const res = await hf.questionAnswer(
+ [
+ {
+ model: 'deepset/roberta-base-squad2',
+ inputs: {
+ question: 'What is the capital of France?',
+ context: 'The capital of France is Paris.',
+ },
+ },
+ {
+ model: 'deepset/roberta-base-squad2',
+ inputs: {
+ question: 'What is the capital of England?',
+ context: 'The capital of England is London.',
+ },
+ },
+ ],
+ {
+ use_streaming: false,
+ }
+ );
+
+ expect(res).toHaveLength(2);
+
+ expect(res[0]).toEqual({
+ answer: 'Paris',
+ score: expect.any(Number),
+ start: expect.any(Number),
+ end: expect.any(Number),
+ });
+ expect(res[1]).toEqual({
+ answer: 'London',
+ score: expect.any(Number),
+ start: expect.any(Number),
+ end: expect.any(Number),
+ });
+ });
+
+ it('use websockets for array input when use_streaming is true', async () => {
+ const res = await hf.questionAnswer(
+ [
+ {
+ model: 'deepset/roberta-base-squad2',
+ inputs: {
+ question: 'What is the capital of France?',
+ context: 'The capital of France is Paris.',
+ },
+ },
+ {
+ model: 'deepset/roberta-base-squad2',
+ inputs: {
+ question: 'What is the capital of England?',
+ context: 'The capital of England is London.',
+ },
+ },
+ ],
+ {
+ use_streaming: true,
+ }
+ );
+
+ console.log(res);
+ });
});
diff --git a/yarn.lock b/yarn.lock
index c7cfe8d..6033ed8 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -1404,6 +1404,13 @@
resolved "https://registry.yarnpkg.com/@types/stack-utils/-/stack-utils-1.0.1.tgz#0a851d3bd96498fa25c33ab7278ed3bd65f06c3e"
integrity sha512-l42BggppR6zLmpfU6fq9HEa2oGPEI8yrSPL3GITjfRInppYFahObbIQOQK3UGxEnyQpltZLaPe75046NOZQikw==
+"@types/[email protected]":
+ version "8.5.3"
+ resolved "https://registry.yarnpkg.com/@types/ws/-/ws-8.5.3.tgz#7d25a1ffbecd3c4f2d35068d0b283c037003274d"
+ integrity sha512-6YOoWjruKj1uLf3INHH7D3qTXwFfEsg1kf3c0uDdSBJwfa/llkwIjrAGV7j7mVgGNbzTQ3HiHKKDXl6bJPD97w==
+ dependencies:
+ "@types/node" "*"
+
"@types/yargs-parser@*":
version "21.0.0"
resolved "https://registry.yarnpkg.com/@types/yargs-parser/-/yargs-parser-21.0.0.tgz#0c60e537fa790f5f9472ed2776c2b71ec117351b"
@@ -3842,6 +3849,11 @@ [email protected]:
node-fetch "^2.6.1"
unfetch "^4.2.0"
[email protected]:
+ version "5.0.0"
+ resolved "https://registry.yarnpkg.com/isomorphic-ws/-/isomorphic-ws-5.0.0.tgz#e5529148912ecb9b451b46ed44d53dae1ce04bbf"
+ integrity sha512-muId7Zzn9ywDsyXgTIafTry2sV3nySZeUDe6YedVd1Hvuuep5AsIlqK+XefWpYTyJG5e503F2xIuT2lcU6rCSw==
+
isstream@~0.1.2:
version "0.1.2"
resolved "https://registry.yarnpkg.com/isstream/-/isstream-0.1.2.tgz#47e63f7af55afa6f92e1500e690eb8b8529c099a"
@@ -6651,6 +6663,11 @@ [email protected]:
dependencies:
mkdirp "^0.5.1"
[email protected]:
+ version "8.8.1"
+ resolved "https://registry.yarnpkg.com/ws/-/ws-8.8.1.tgz#5dbad0feb7ade8ecc99b830c1d77c913d4955ff0"
+ integrity sha512-bGy2JzvzkPowEJV++hF07hAD6niYSr0JzBNo/J29WsB57A2r7Wlc1UFcTR9IzrPvuNVO4B8LGqF8qcpsVOhJCA==
+
ws@^7.0.0:
version "7.5.8"
resolved "https://registry.yarnpkg.com/ws/-/ws-7.5.8.tgz#ac2729881ab9e7cbaf8787fe3469a48c5c7f636a"