Python Client

The LoRAX Python client provides a convenient way to interface with a LoRAX instance running in your environment.

Install

pip install lorax-client

Usage

from lorax import Client

endpoint_url = "http://127.0.0.1:8080"

client = Client(endpoint_url)
text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'

Or, with the asynchronous client:

import asyncio

from lorax import AsyncClient

endpoint_url = "http://127.0.0.1:8080"


async def main():
    client = AsyncClient(endpoint_url)
    response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
    print(response.generated_text)
    # ' Rayleigh scattering'

    # Token Streaming
    text = ""
    async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
        if not response.token.special:
            text += response.token.text

    print(text)
    # ' Rayleigh scattering'


# `await` is only valid inside a coroutine; in a notebook you can `await main()` directly
asyncio.run(main())

See the API reference for full details.
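The token-streaming loops above share one pattern: append each token's text unless the token is marked special. A minimal sketch of that pattern as a reusable helper follows. `collect_text` and `fake_stream` are hypothetical names introduced for illustration and are not part of the client; the stand-in objects only mimic the `token.text` and `token.special` fields used above so the snippet runs without a server.

```python
from types import SimpleNamespace
from typing import Iterable


def collect_text(stream: Iterable) -> str:
    """Join the text of every non-special token in a stream of responses."""
    return "".join(r.token.text for r in stream if not r.token.special)


def fake_stream():
    """Hypothetical stand-in for client.generate_stream(...)."""
    tokens = [(" Rayleigh", False), (" scattering", False), ("</s>", True)]
    for text, special in tokens:
        yield SimpleNamespace(token=SimpleNamespace(text=text, special=special))


result = collect_text(fake_stream())
print(repr(result))
# ' Rayleigh scattering'
```

With a real client, the same helper could presumably be called as `collect_text(client.generate_stream(...))`, since it relies only on the two token fields shown in the examples above.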

Batch Inference

In some cases you may have a list of prompts that you wish to process in bulk ("batch processing").

Rather than processing the prompts one at a time, you can take advantage of the AsyncClient and LoRAX's native parallelism to submit all of your prompts at once and await the results:

import asyncio
import time
from lorax import AsyncClient

# Batch of prompts to submit
prompts = [
    "The quick brown fox",
    "The rain in Spain",
    "What comes up",
]

# Initialize the async client
endpoint_url = "http://127.0.0.1:8080"
async_client = AsyncClient(endpoint_url)

async def main():
    # Create one request coroutine per prompt without awaiting it yet
    t0 = time.time()
    futures = []
    for prompt in prompts:
        resp = async_client.generate(prompt, max_new_tokens=64)
        futures.append(resp)

    # Run all the requests concurrently and await their completion
    responses = await asyncio.gather(*futures)

    # Print responses
    # Responses will always come back in the same order as the original list
    for resp in responses:
        print(resp.generated_text)

    # Print duration to process all requests in batch
    print("duration (s):", time.time() - t0)


# Top-level `await` only works in a notebook or async REPL; a script needs asyncio.run
asyncio.run(main())

Output:

duration (s): 2.9093329906463623

Compare this against the duration of submitting the prompts one at a time. For 3 prompts, you should find that the async approach is about 2.5-3x faster than serial processing:

from lorax import Client

client = Client(endpoint_url)

t0 = time.time()
responses = []
for prompt in prompts:
    resp = client.generate(prompt, max_new_tokens=64)
    responses.append(resp)

for resp in responses:
    print(resp.generated_text)

print("duration (s):", time.time() - t0)

Output:

duration (s): 8.385080099105835
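For much larger batches, you may not want every request in flight at the same time. One common asyncio pattern, sketched below, caps concurrency with a semaphore while still returning results in prompt order. This is an illustration under stated assumptions rather than a LoRAX feature: `generate_all`, `fake_generate`, and `max_concurrency` are hypothetical names, and `fake_generate` stands in for `async_client.generate` so the snippet runs without a server.

```python
import asyncio


async def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for `await async_client.generate(prompt)`."""
    await asyncio.sleep(0.01)
    return prompt.upper()


async def generate_all(prompts, generate, max_concurrency=2):
    """Run `generate` over all prompts with at most `max_concurrency` in flight."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def bounded(prompt):
        # Wait for a free slot before sending the request
        async with semaphore:
            return await generate(prompt)

    # gather returns results in the same order as its arguments
    return await asyncio.gather(*(bounded(p) for p in prompts))


prompts = ["The quick brown fox", "The rain in Spain", "What comes up"]
results = asyncio.run(generate_all(prompts, fake_generate))
print(results)
```

To use this against a running server, you would presumably pass a small coroutine function that wraps `async_client.generate` in place of `fake_generate`.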

Predibase Inference Endpoints

The LoRAX client can also be used to connect to Predibase-managed LoRAX endpoints (including Predibase's serverless endpoints).

You need only make the following changes to the above examples:

  1. Change the endpoint_url to match the endpoint of your Predibase LLM of choice.
  2. Provide your Predibase API token in the headers provided to the client.

Example:

from lorax import Client

# You can get your Predibase API token by going to Settings > My Profile > Generate API Token
# You can get your Predibase Tenant short code by going to Settings > My Profile > Overview > Tenant ID
# The names in braces below are placeholders for values you supply
endpoint_url = f"https://serving.app.predibase.com/{predibase_tenant_short_code}/deployments/v2/llms/{llm_deployment_name}"
headers = {
    "Authorization": f"Bearer {api_token}"
}

client = Client(endpoint_url, headers=headers)

# same as above from here ...
response = client.generate("Why is the sky blue?", adapter_id=f"{model_repo}/{model_version}")

Note that Predibase uses its internal model repos as the default adapter_source. To use an adapter from the Hugging Face Hub instead:

response = client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")
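The two changes listed above (the endpoint URL and the Authorization header) can be collected in one small helper so the token stays out of source code. The sketch below is only an illustration: `predibase_connection` is a hypothetical helper, `PREDIBASE_API_TOKEN` is an assumed environment variable name that the client does not read itself, and the tenant and deployment values are placeholders.

```python
import os


def predibase_connection(tenant: str, deployment: str, api_token: str):
    """Build the endpoint URL and auth headers (hypothetical helper)."""
    endpoint_url = (
        f"https://serving.app.predibase.com/{tenant}/deployments/v2/llms/{deployment}"
    )
    headers = {"Authorization": f"Bearer {api_token}"}
    return endpoint_url, headers


# PREDIBASE_API_TOKEN is an assumed variable name; set it in your shell first
api_token = os.environ.get("PREDIBASE_API_TOKEN", "placeholder-token")
endpoint_url, headers = predibase_connection("my-tenant", "my-deployment", api_token)
print(endpoint_url)
```

The returned values can then be passed to `Client(endpoint_url, headers=headers)` exactly as in the example above.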