Text Generation with Gradio

Gradio is a Python library that allows you to rapidly create user interfaces for deep learning and machine learning models. In this tutorial, we will walk through the process of creating a Gradio app for text generation using a pre-trained language model.

Step 1: Prerequisites

Before you begin, make sure you have the following installed:

  • Python (3.8 or later)

For text generation, you can use a popular open model such as GPT-2, or any other causal language model from the Hugging Face Hub that suits your needs. The model weights are downloaded automatically the first time you load them, so you only need to install the libraries:

pip install transformers==4.35
pip install torch==2.1.1
pip install gradio==4.4.1
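Before moving on, you may want to confirm that your environment matches these pins. The script below is a minimal sketch that uses only the standard library's importlib.metadata; the helper names (release_tuple, same_release, check_environment) are hypothetical, and its version comparison is deliberately simpler than pip's full PEP 440 rules (for example, it ignores pre-release tags).

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the pip commands above
PINNED = {"transformers": "4.35", "torch": "2.1.1", "gradio": "4.4.1"}


def release_tuple(v):
    """Numeric release part of a version string, e.g. "2.1.1+cu118" -> (2, 1, 1).

    Simplified: drops the local label after "+" and stops at the first
    non-numeric piece, so pre-release tags are ignored.
    """
    parts = []
    for piece in v.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def same_release(installed, pinned):
    a, b = release_tuple(installed), release_tuple(pinned)
    n = max(len(a), len(b))
    # Zero-pad so "4.35" and "4.35.0" compare equal, as pip's == specifier does
    return a + (0,) * (n - len(a)) == b + (0,) * (n - len(b))


def check_environment(pinned=PINNED, min_python=(3, 8)):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if sys.version_info < min_python:
        needed = ".".join(map(str, min_python))
        problems.append(f"Python {needed}+ required, found {sys.version.split()[0]}")
    for package, wanted in pinned.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if not same_release(installed, wanted):
            problems.append(f"{package} {installed} installed, {wanted} pinned")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```

Run it as a script: it prints one warning per missing or mismatched package and nothing when everything matches.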

Step 2: Create a Text Generation Function

Write a Python function that takes a prompt as input and generates text using the pre-trained language model. Here’s a simple example using the GPT-2 model from the transformers library:

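import torch
import gradio as gr
from transformers import AutoModelForCausalLM, GPT2Tokenizer

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 has no padding token, so reuse the end-of-sequence token
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def generate_text(prompt, max_length=100, top_p=0.95, top_k=50):
    # Tokenize the prompt into a batch of one sequence (PyTorch tensors)
    inputs = tokenizer([prompt], return_tensors="pt")
    # Unpack input_ids and attention_mask; cast settings because UI components may pass strings
    output = model.generate(
        **inputs,
        max_length=int(max_length),
        do_sample=True,
        top_k=int(top_k),
        top_p=float(top_p),
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode the first (and only) generated sequence back into text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text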
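The generate() call above samples (do_sample=True) instead of always choosing the most likely next token, and top_k=50 and top_p=0.95 limit which tokens can be drawn. The pure-Python sketch below illustrates the idea on a made-up next-token distribution: keep the k most likely tokens, then keep the smallest set of those whose renormalized probability reaches p, and sample from what remains. It is an illustration only, not the transformers implementation, which applies these filters to logit tensors at every generation step.

```python
import random


def top_k_top_p_filter(probs, top_k=50, top_p=0.95):
    """Filter a {token: probability} dict with top-k, then top-p, and renormalize."""
    # Top-k: keep only the k most likely tokens, then renormalize them
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:top_k]
    k_total = sum(p for _, p in ranked)
    ranked = [(token, p / k_total) for token, p in ranked]

    # Top-p (nucleus): keep the smallest high-probability prefix whose mass reaches top_p
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break

    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}


def sample_next_token(probs, top_k=50, top_p=0.95, rng=random):
    """Draw one token from the filtered distribution (one sampling step)."""
    dist = top_k_top_p_filter(probs, top_k, top_p)
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]


# Made-up next-token probabilities for illustration
toy = {"the": 0.5, "a": 0.3, "an": 0.15, "zebra": 0.05}
print(top_k_top_p_filter(toy, top_k=3, top_p=0.8))  # only "the" and "a" survive
print(sample_next_token(toy, top_k=3, top_p=0.8, rng=random.Random(0)))
```

Lowering top_p or top_k shrinks the candidate set and makes the output more predictable; raising them lets rarer tokens through.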

Step 3: Create a Gradio Interface

Now, create a Gradio interface for your text generation function. Define an interface with an input text box for the prompt and an output text box to display the generated text.

demo = gr.Interface(fn=generate_text, inputs="text", outputs="text")
demo.launch(server_name="0.0.0.0", server_port=7861)

Step 4: Run Your Gradio App

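Save the script as app.py and run it with the command below. Gradio will launch a local web interface for your text generation app at localhost:7861, the port set in demo.launch(). Running the script with the gradio command, rather than python app.py, enables reload mode, so the app reloads automatically when you save changes to it. Open your web browser and navigate to the URL printed in the terminal to interact with the app.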

gradio app.py

Step 5: Interact with Your Text Generation App

Enter a prompt in the input text box and see the generated text in the output text box. Experiment with different prompts to observe how the model responds. The first version of the app looks like this:

Initial Look

Step 6: (Pro Tip) Let's Give the App a Better Look

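Start by adding a description. You can also add more inputs so users can control generation: the maximum length, top-p, and top-k. Gradio passes the value of each input component to generate_text as a positional argument, in the order the inputs are listed, so generate_text must accept four parameters (prompt, max_length, top_p, top_k). Update your Interface with the following changes: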

description = """
# Deploy your first ML app using Gradio
"""
inputs = [
    gr.Textbox(label="Prompt text"),
    # precision=0 makes Number return an int
    gr.Number(label="Max length", value=100, precision=0),
    gr.Slider(0.0, 1.0, label="Top-p value", value=0.95),
    gr.Number(label="Top-k", value=50, precision=0),
]
outputs = [gr.Textbox(label="Generated Text")]

# allow_flagging takes "never", "auto", or "manual"
demo = gr.Interface(fn=generate_text, inputs=inputs, outputs=outputs, allow_flagging="never", description=description)

demo.launch(server_name="0.0.0.0", server_port=7861)
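Because users can now type arbitrary settings, it can be worth converting and range-checking them before they reach model.generate(). Values arrive as plain Python objects (a Textbox delivers a string such as "100"), so the sketch below casts explicitly. The helper name parse_generation_settings is hypothetical; the bounds reflect that top_p must lie between 0 and 1, that top_k=0 disables top-k filtering, and that max_length counts the prompt tokens as well as the generated ones, which for the "gpt2" checkpoint cannot exceed its 1024-token context.

```python
GPT2_CONTEXT_WINDOW = 1024  # maximum sequence length of the "gpt2" checkpoint


def parse_generation_settings(max_length, top_p, top_k):
    """Convert raw UI values into validated keyword arguments for model.generate().

    Hypothetical helper: values may arrive as strings (e.g. "100" from a Textbox)
    or as numbers (from a Slider or Number), so cast everything explicitly.
    """
    max_length = int(max_length)
    top_p = float(top_p)
    top_k = int(top_k)
    if not 1 <= max_length <= GPT2_CONTEXT_WINDOW:
        raise ValueError(f"max length must be between 1 and {GPT2_CONTEXT_WINDOW}")
    if not 0.0 <= top_p <= 1.0:
        raise ValueError("top-p must be between 0 and 1")
    if top_k < 0:
        raise ValueError("top-k must be non-negative (0 disables top-k filtering)")
    return {"max_length": max_length, "top_p": top_p, "top_k": top_k}


print(parse_generation_settings("100", 0.95, "50"))
# {'max_length': 100, 'top_p': 0.95, 'top_k': 50}
```

Inside generate_text you could then call model.generate(**inputs, do_sample=True, **parse_generation_settings(max_length, top_p, top_k)) and re-raise a ValueError as gr.Error(...) so the message appears in the Gradio UI.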

Step 7: (Pro Tip) Animated Text Generation

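Text generation can take a while, so it is nicer to show the output word by word as it is produced. You can do this by running model.generate in a background thread and reading the decoded text from a TextIteratorStreamer. Because generate_text now uses yield, Gradio treats it as a generator function and updates the output box with each value it yields. Update the generate_text function with the following changes: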

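from threading import Thread
from transformers import TextIteratorStreamer

def generate_text(prompt, max_length, top_p, top_k):
    inputs = tokenizer([prompt], return_tensors="pt")
    # Create a new streamer per request so concurrent users never share one token queue
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,  # input_ids and attention_mask from the tokenizer
        max_length=int(max_length),
        top_p=float(top_p),
        do_sample=True,
        top_k=int(top_k),
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )

    # generate() blocks until it finishes, so run it in a background thread
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    generated_text = []
    # Each yield sends the text so far to Gradio, which updates the output box
    for text in streamer:
        generated_text.append(text)
        yield "".join(generated_text)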
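Why the thread? model.generate() does not return until the whole sequence is finished, so the code above runs it in the background while TextIteratorStreamer, which is essentially an iterator over a thread-safe queue, hands decoded text to the foreground generator. The stdlib-only sketch below reproduces that producer/consumer pattern so it can run without a model; ToyStreamer and fake_generate are hypothetical stand-ins for TextIteratorStreamer and model.generate, and they exchange plain text rather than token ids.

```python
import queue
import threading
import time

_END = object()  # sentinel that marks the end of the stream


class ToyStreamer:
    """Hypothetical stand-in for TextIteratorStreamer: a producer thread pushes text
    chunks into a queue, and iterating the streamer yields them until the sentinel."""

    def __init__(self, timeout=None):
        self._queue = queue.Queue()
        self.timeout = timeout

    def push(self, text):
        self._queue.put(text)

    def close(self):
        self._queue.put(_END)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until a chunk arrives; raises queue.Empty if none comes within `timeout` seconds
        item = self._queue.get(timeout=self.timeout)
        if item is _END:
            raise StopIteration
        return item


def fake_generate(words, streamer, delay=0.0):
    """Hypothetical stand-in for model.generate(): emits one word per step."""
    for word in words:
        time.sleep(delay)  # simulate the compute time of one generation step
        streamer.push(word)
    streamer.close()


def stream_text(words):
    streamer = ToyStreamer(timeout=10.0)
    # Same shape as Step 7: produce in a background thread, consume in this generator
    worker = threading.Thread(target=fake_generate, kwargs={"words": words, "streamer": streamer})
    worker.start()
    pieces = []
    for chunk in streamer:
        pieces.append(chunk)
        yield "".join(pieces)  # in Gradio, each yield re-renders the output box
    worker.join()


for partial in stream_text(["Hello", ",", " world", "!"]):
    print(partial)
```

Each printed line is one step longer than the last, which is exactly the sequence of values Gradio receives from the streaming generate_text.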

Conclusion

Congratulations! You have created a simple Gradio app for text generation.

Final Look

You can further customize the interface, explore different pre-trained models, and enhance the user experience. Gradio simplifies the process of creating interactive machine learning applications, making it easy for users to interact with your models.

Exporting an ONNX Model from PyTorch

PyTorch provides a straightforward way to export models to the ONNX format. Follow these steps:

  1. Ensure that you have PyTorch and torchvision installed (the example model below comes from torchvision). You can install both using pip:

    pip install torch torchvision
  2. Import the necessary libraries in your Python script:

    import torch
    import torchvision
  3. Load your PyTorch model:

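    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
    model.eval()  # put dropout and batch-norm layers in inference mode before exporting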
  4. Export the model to ONNX:

    torch.onnx.export(model, torch.randn(1, 3, 224, 224), "path/to/exported_model.onnx", export_params=True)

    Make sure to replace “path/to/exported_model.onnx” with the desired path to save the ONNX file.

Exporting an ONNX Model from Caffe2

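Caffe2 can export models to ONNX through its ONNX frontend (the caffe2.python.onnx.frontend module). Here are the steps:

  1. Install Caffe2 by following the instructions provided in the official Caffe2 documentation.

  2. Import the necessary libraries in your Python script:

    import onnx
    from caffe2.proto import caffe2_pb2
    from caffe2.python.onnx import frontend
  3. Load your Caffe2 model, which is stored as two serialized NetDef protobufs (the network definition and the initial weights):

    predict_net = caffe2_pb2.NetDef()
    predict_net.ParseFromString(open("path/to/predict_net.pb", "rb").read())
    init_net = caffe2_pb2.NetDef()
    init_net.ParseFromString(open("path/to/init_net.pb", "rb").read())
  4. Export the model to ONNX:

    value_info = {"data": (onnx.TensorProto.FLOAT, (1, 3, 224, 224))}
    onnx.save(frontend.caffe2_net_to_onnx_model(predict_net, init_net, value_info), "path/to/exported_model.onnx")

    value_info maps the model's input blob name to its element type and shape. Adjust the input name ("data" is a placeholder) and shape, and replace "path/to/exported_model.onnx" with the desired path to save the ONNX file.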