Build a Streaming API for a Large Language Model
Large language models can power a wide range of applications, from chatbots to assistants and intelligent systems. However, these models can be heavy and slow, and your users want systems that are both intelligent and fast!
Large language models work by turning your questions into tokens and then generating new tokens one at a time until the model decides that generation should stop. Streaming each output token to the client as soon as it is generated lets users see results immediately instead of waiting for the full response. In this tutorial, we will discuss how to achieve this with Streaming Endpoints in Jina.
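The generate-until-stop loop described above can be illustrated without a real model. The sketch below is a toy stand-in, not how Jina or any language model is implemented: `NEXT_TOKEN` is a hypothetical lookup table playing the role of the model, and `generate` is an ordinary Python generator that hands each token to the caller as soon as it is produced.

```python
from typing import Iterator

# Hypothetical stand-in for a language model: maps the last token to the next one.
NEXT_TOKEN = {
    "<start>": "The",
    "The": "capital",
    "capital": "is",
    "is": "Paris",
    "Paris": "<eos>",
}


def generate(max_tokens: int) -> Iterator[str]:
    """Yield one token at a time until <eos> or the token budget is reached."""
    token = "<start>"
    for _ in range(max_tokens):
        token = NEXT_TOKEN[token]
        if token == "<eos>":
            break  # the "model" decided that generation should stop
        yield token  # the caller can display this token immediately


if __name__ == "__main__":
    for tok in generate(max_tokens=10):
        print(tok, end=" ", flush=True)
    print()
```

The caller receives the first token before the last one has been computed, which is the property a streaming endpoint exposes over the network.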
Service Schemas
The first step is to define the streaming service schemas, as you would do in any other service framework. The input to the service is the prompt and the maximum number of tokens to generate, while the output is the ID of the newly generated token together with the text generated so far:
from docarray import BaseDoc


class PromptDocument(BaseDoc):
    prompt: str
    max_tokens: int


class ModelOutputDocument(BaseDoc):
    token_id: int
    generated_text: str
Note
DocArray schemas are flexible enough to support richer services. For instance, you can use Tensor types to efficiently stream token logits back to the client and implement complex token sampling strategies on the client side.
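As a rough illustration of what client-side sampling could look like, the sketch below picks a token ID from a list of logits using greedy, temperature, or top-k sampling. It is a minimal sketch under stated assumptions: the logits are taken to be a plain list of floats (in a real service they would arrive in a tensor field), and `softmax` and `sample_token` are hypothetical helpers, not part of Jina or DocArray.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Convert raw logits into probabilities, scaled by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(
    logits: List[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token ID from logits using greedy, temperature, or top-k sampling."""
    if temperature <= 0:
        # Treat a non-positive temperature as greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    # Token IDs ordered from most to least likely.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]  # keep only the k most likely tokens
    probs = softmax([logits[i] for i in candidates], temperature)
    return rng.choices(candidates, weights=probs, k=1)[0]
```

Keeping the sampling on the client lets each caller choose its own temperature or `top_k` without changing the service, at the cost of sending a full logits vector for every generated token.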
Service initialization
Our service depends on a large language model. As an example, we will use the gpt2 model. This is how you would load such a model in your Executor:
from jina import Executor, requests
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
Implement the streaming endpoint
Our streaming endpoint accepts a PromptDocument as input and streams ModelOutputDocument instances. To stream a document back to the client, use the yield keyword in the endpoint implementation. Therefore, we use the model to generate up to max_tokens tokens and yield them until the generation stops:
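class TokenStreamingExecutor(Executor):
    ...

    @requests(on='/stream')
    async def task(self, doc: PromptDocument, **kwargs) -> ModelOutputDocument:
        # Tokenize the prompt and remember its length so that only newly
        # generated tokens are decoded into generated_text
        inputs = tokenizer(doc.prompt, return_tensors='pt')
        input_len = inputs['input_ids'].shape[1]
        for _ in range(doc.max_tokens):
            # Generate exactly one new token per iteration
            output = self.model.generate(**inputs, max_new_tokens=1)
            token_id = output[0][-1].item()  # convert the 0-d tensor to a plain int
            if token_id == tokenizer.eos_token_id:
                break
            yield ModelOutputDocument(
                token_id=token_id,
                generated_text=tokenizer.decode(
                    output[0][input_len:], skip_special_tokens=True
                ),
            )
            # Feed the extended sequence back in for the next step
            inputs = {
                'input_ids': output,
                'attention_mask': torch.ones(1, len(output[0]), dtype=torch.long),
            }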
Learn more about streaming endpoints from the Executor documentation.
Serve and send requests
The final step is to serve the Executor and send requests using the client. To serve the Executor using gRPC:
from jina import Deployment

with Deployment(uses=TokenStreamingExecutor, port=12345, protocol='grpc') as dep:
    dep.block()
To send requests from a client:
import asyncio
from jina import Client


async def main():
    client = Client(port=12345, protocol='grpc', asyncio=True)
    async for doc in client.stream_doc(
        on='/stream',
        inputs=PromptDocument(prompt='what is the capital of France ?', max_tokens=10),
        return_type=ModelOutputDocument,
    ):
        print(doc.generated_text)


asyncio.run(main())
The
The capital
The capital of
The capital of France
The capital of France is
The capital of France is Paris
The capital of France is Paris.
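Each line printed above is the full text generated so far, because generated_text is cumulative. A client that prefers to display only the newly added characters could diff consecutive updates. The sketch below shows one way to do this, assuming each update normally extends the previous one. `fake_stream` is a hypothetical stand-in for the values of `doc.generated_text` received from `client.stream_doc`, and `collect_deltas` is an illustrative helper, not a Jina API.

```python
import asyncio
from typing import AsyncIterator, List


async def fake_stream() -> AsyncIterator[str]:
    """Hypothetical stand-in for the cumulative `generated_text` values."""
    for text in ["The", "The capital", "The capital of", "The capital of France"]:
        await asyncio.sleep(0)  # yield control, as a network stream would
        yield text


async def collect_deltas(stream: AsyncIterator[str]) -> List[str]:
    """Return only the characters added by each streamed update."""
    deltas = []
    previous = ""
    async for text in stream:
        if text.startswith(previous):
            deltas.append(text[len(previous):])
        else:
            deltas.append(text)  # fallback when the text is not a plain extension
        previous = text
    return deltas


if __name__ == "__main__":
    print(asyncio.run(collect_deltas(fake_stream())))
```

In a real client, the same comparison would run inside the `async for` loop and print each delta with `end=''` so the text appears to grow in place.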