LLaVA Demo with LlamaIndex
In this example, we illustrate how to use LLaVA for the following tasks:
- Retrieval Augmented Image Captioning
- Pydantic Structured Output
- Multi-Modal Retrieval-Augmented Generation (RAG) using Llava-13b
Context for LLaVA: Large Language and Vision Assistant
For LlamaIndex: LLaVA served via Replicate enables us to run image understanding and combine that multi-modal knowledge with our RAG knowledge-based system.
Retrieval Augmented Image Captioning using Llava-13b
Using Replicate to serve the LLaVA model through LlamaIndex
%pip install llama-index-vector-stores-qdrant
%pip install llama-index-readers-file
%pip install llama-index-multi-modal-llms-replicate
%pip install unstructured replicate
%pip install llama_index ftfy regex tqdm
%pip install git+https://github.com/openai/CLIP.git
%pip install torch torchvision
%pip install matplotlib scikit-image
%pip install -U qdrant_client
import os
REPLICATE_API_TOKEN = "..." # Your Relicate API token hereos.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
Perform Data Extraction from Tesla 10K file
In this section, we use Unstructured to parse out the table and non-table elements.
Extract Elements
We use Unstructured to extract table and non-table elements from the 10-K filing.
!wget "https://www.dropbox.com/scl/fi/mlaymdy1ni1ovyeykhhuk/tesla_2021_10k.htm?rlkey=qf9k4zn0ejrbm716j0gg7r802&dl=1" -O tesla_2021_10k.htm!wget "https://docs.google.com/uc?export=download&id=1UU0xc3uLXs-WG0aDQSXjGacUkp142rLS" -O texas.jpg
from llama_index.readers.file import FlatReader
from pathlib import Path
from llama_index.core.node_parser import UnstructuredElementNodeParser

reader = FlatReader()
docs_2021 = reader.load_data(Path("tesla_2021_10k.htm"))
node_parser = UnstructuredElementNodeParser()
import openai
OPENAI_API_KEY = "..."openai.api_key = OPENAI_API_KEY # add your openai api key hereos.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
import os
import pickle

if not os.path.exists("2021_nodes.pkl"):
    raw_nodes_2021 = node_parser.get_nodes_from_documents(docs_2021)
    pickle.dump(raw_nodes_2021, open("2021_nodes.pkl", "wb"))
else:
    raw_nodes_2021 = pickle.load(open("2021_nodes.pkl", "rb"))
nodes_2021, objects_2021 = node_parser.get_nodes_and_objects(raw_nodes_2021)
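As a quick sanity check before indexing, you can peek at what the parser produced. This is an illustrative snippet; the exact counts and contents depend on the filing and your Unstructured version.

# Inspect the split between plain text nodes and table objects (illustrative)
print(f"text nodes: {len(nodes_2021)}, table objects: {len(objects_2021)}")
if objects_2021:
    # Each object wraps an extracted table; print the start of its content/summary
    print(objects_2021[0].get_content()[:200])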
Setup Composable Retriever
Now that we’ve extracted tables and their summaries, we can set up a composable retriever in LlamaIndex to query these tables.
Construct Retrievers
from llama_index.core import VectorStoreIndex

# construct top-level vector index + query engine
vector_index = VectorStoreIndex(nodes=nodes_2021, objects=objects_2021)
query_engine = vector_index.as_query_engine(similarity_top_k=5, verbose=True)
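Before wiring in the image side, you can sanity-check the text/table query engine with a plain question against the 10-K. The question below is only an illustration; the exact answer depends on what was parsed from the filing.

# Illustrative text-only query against the indexed 10-K tables and text
tables_response = query_engine.query(
    "What were Tesla's total revenues in 2021?"
)
print(tables_response)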
from PIL import Image
import matplotlib.pyplot as plt

imageUrl = "./texas.jpg"
image = Image.open(imageUrl).convert("RGB")

plt.figure(figsize=(16, 5))
plt.imshow(image)
<matplotlib.image.AxesImage at 0x7f1b2e09b790>
Running LLaVa model using Replicate through LlamaIndex for image understanding
from llama_index.multi_modal_llms.replicate import ReplicateMultiModal
from llama_index.core.schema import ImageDocument
from llama_index.multi_modal_llms.replicate.base import (
    REPLICATE_MULTI_MODAL_LLM_MODELS,
)
print(imageUrl)
llava_multi_modal_llm = ReplicateMultiModal(
    model=REPLICATE_MULTI_MODAL_LLM_MODELS["llava-13b"],
    max_new_tokens=200,
    temperature=0.1,
)
prompt = "which Tesla factory is shown in the image? Please answer just the name of the factory."
llava_response = llava_multi_modal_llm.complete(
    prompt=prompt,
    image_documents=[ImageDocument(image_path=imageUrl)],
)
./texas.jpg
print(llava_response.text)
Gigafactory
Retrieve relevant information from LlamaIndex knowledge base based on LLaVa image understanding to augment Image Captioning
rag_response = query_engine.query(llava_response.text)
print(rag_response)
Gigafactory is a term used by Tesla to describe its expansive manufacturing facilities that are strategically located in various regions worldwide. These factories are specifically designed to produce a range of Tesla products, including electric vehicles, battery cells, and energy storage solutions. Currently, Tesla operates Gigafactories in Nevada, New York, Shanghai, and Berlin, with plans to establish another one in Texas. The primary objective of these Gigafactories is to significantly enhance Tesla's production capabilities, drive down costs, and optimize operational efficiency across its manufacturing operations.
Multi-Modal Pydantic Program with LLaVa
Initialize the Instagram Ads Pydantic Class
input_image_path = Path("instagram_images")
if not input_image_path.exists():
    Path.mkdir(input_image_path)
!wget "https://docs.google.com/uc?export=download&id=12ZpBBFkYu-jzz1iz356U5kMikn4uN9ww" -O ./instagram_images/jordan.png
--2024-01-15 14:39:59--  https://docs.google.com/uc?export=download&id=12ZpBBFkYu-jzz1iz356U5kMikn4uN9ww
Resolving docs.google.com (docs.google.com)... 142.251.32.78, 2607:f8b0:400b:807::200e
Connecting to docs.google.com (docs.google.com)|142.251.32.78|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=12ZpBBFkYu-jzz1iz356U5kMikn4uN9ww&export=download [following]
--2024-01-15 14:40:00--  https://drive.usercontent.google.com/download?id=12ZpBBFkYu-jzz1iz356U5kMikn4uN9ww&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 142.251.32.65, 2607:f8b0:400b:802::2001
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|142.251.32.65|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2722061 (2.6M) [image/png]
Saving to: ‘./instagram_images/jordan.png’

./instagram_images/jordan.png   100%[===================>]   2.60M  10.7MB/s    in 0.2s

2024-01-15 14:40:01 (10.7 MB/s) - ‘./instagram_images/jordan.png’ saved [2722061/2722061]
from pydantic import BaseModel
class InsAds(BaseModel):
    """Data model for an Instagram ad."""

    account: str
    brand: str
    product: str
    category: str
    discount: str
    price: str
    comments: str
    review: str
    description: str
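For reference, `InsAds` is a plain Pydantic model, so you can instantiate it directly to see the shape of the structured output the program is expected to return. The field values below are made up, and `model_dump_json` assumes Pydantic v2 (use `.json()` on v1).

# Hypothetical example instance, just to show the expected output shape
example_ad = InsAds(
    account="some_account",
    brand="Some Brand",
    product="Some Product",
    category="Sneakers",
    discount="10%",
    price="$100",
    comments="",
    review="",
    description="Example description.",
)
print(example_ad.model_dump_json(indent=2))  # .json(indent=2) on Pydantic v1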
from PIL import Image
import matplotlib.pyplot as plt

ins_imageUrl = "./instagram_images/jordan.png"
image = Image.open(ins_imageUrl).convert("RGB")

plt.figure(figsize=(16, 5))
plt.imshow(image)
<matplotlib.image.AxesImage at 0x7f1b2e0fe710>
Using a Multi-Modal Pydantic Program to generate structured output with Llava-13b
from llama_index.multi_modal_llms.replicate import ReplicateMultiModal
from llama_index.core.program import MultiModalLLMCompletionProgram
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.multi_modal_llms.replicate.base import (
    REPLICATE_MULTI_MODAL_LLM_MODELS,
)
prompt_template_str = """\ can you summarize what is in the image\ and return the answer with json format \"""
def pydantic_llava(
    model_name, output_class, image_documents, prompt_template_str
):
    mm_llm = ReplicateMultiModal(
        model=REPLICATE_MULTI_MODAL_LLM_MODELS["llava-13b"],
        max_new_tokens=1000,
    )

    llm_program = MultiModalLLMCompletionProgram.from_defaults(
        output_parser=PydanticOutputParser(output_class),
        image_documents=image_documents,
        prompt_template_str=prompt_template_str,
        multi_modal_llm=mm_llm,
        verbose=True,
    )

    response = llm_program()
    print(f"Model: {model_name}")
    for res in response:
        print(res)
    return response
Output the Structured Pydantic Result
from llama_index.core import SimpleDirectoryReader
ins_image_documents = SimpleDirectoryReader("./instagram_images").load_data()
pydantic_response = pydantic_llava(
    "llava-13b", InsAds, ins_image_documents, prompt_template_str
)
Model: llava-13b
('account', 'airjordan')
('brand', 'Air Jordan')
('product', 'Air Jordan 2')
('category', 'Sneakers')
('discount', '15%')
('price', '$175')
('comments', '')
('review', '')
('description', 'Italy.')
print(pydantic_response.brand)
Air Jordan
Advanced Multi-Modal Retrieval using LLaVA and the Multi-Modal Index/Retriever/Query Engine
Downloading text and image data from raw files [Wikipedia] for the Multi-Modal Index/Retriever
from pathlib import Path
import requests

wiki_titles = [
    "batman",
    "Vincent van Gogh",
    "San Francisco",
    "iPhone",
    "Tesla Model S",
    "BTS",
    "Air Jordan",
]
data_path = Path("data_wiki")
for title in wiki_titles:
    response = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params={
            "action": "query",
            "format": "json",
            "titles": title,
            "prop": "extracts",
            "explaintext": True,
        },
    ).json()
    page = next(iter(response["query"]["pages"].values()))
    wiki_text = page["extract"]

    if not data_path.exists():
        Path.mkdir(data_path)

    with open(data_path / f"{title}.txt", "w") as fp:
        fp.write(wiki_text)
import wikipedia
import urllib.request

image_path = Path("data_wiki")
image_uuid = 0
# image_metadata_dict stores image metadata, including image uuid, filename and path
image_metadata_dict = {}
MAX_IMAGES_PER_WIKI = 30

wiki_titles = [
    "Air Jordan",
    "San Francisco",
    "Batman",
    "Vincent van Gogh",
    "iPhone",
    "Tesla Model S",
    "BTS band",
]

# create folder for images only
if not image_path.exists():
    Path.mkdir(image_path)
# Download images for wiki pages
# Assign a UUID to each image
for title in wiki_titles:
    images_per_wiki = 0
    print(title)
    try:
        page_py = wikipedia.page(title)
        list_img_urls = page_py.images
        for url in list_img_urls:
            if url.endswith(".jpg") or url.endswith(".png"):
                image_uuid += 1
                image_file_name = title + "_" + url.split("/")[-1]

                # img_path could be an s3 path pointing to the raw image file in the future
                image_metadata_dict[image_uuid] = {
                    "filename": image_file_name,
                    "img_path": "./" + str(image_path / f"{image_uuid}.jpg"),
                }
                urllib.request.urlretrieve(
                    url, image_path / f"{image_uuid}.jpg"
                )
                images_per_wiki += 1
                # Limit the number of images downloaded per wiki page to MAX_IMAGES_PER_WIKI
                if images_per_wiki > MAX_IMAGES_PER_WIKI:
                    break
    except Exception:
        print(str(Exception("No images found for Wikipedia page: ")) + title)
        continue
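You can spot-check a few of the downloaded records in `image_metadata_dict` before building the index. This is illustrative only; the IDs and filenames depend on which images were actually fetched.

# Peek at the first few image metadata entries (illustrative)
for image_id in list(image_metadata_dict)[:3]:
    meta = image_metadata_dict[image_id]
    print(image_id, meta["filename"], meta["img_path"])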
Build Multi-modal index and Vector Store to index both text and images
import qdrant_client
from llama_index.core import SimpleDirectoryReader
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.core.indices import MultiModalVectorStoreIndex

# Create a local Qdrant vector store
client = qdrant_client.QdrantClient(path="qdrant_mm_db")

text_store = QdrantVectorStore(
    client=client, collection_name="text_collection"
)
image_store = QdrantVectorStore(
    client=client, collection_name="image_collection"
)
storage_context = StorageContext.from_defaults(
    vector_store=text_store, image_store=image_store
)

# Create the MultiModal index
documents = SimpleDirectoryReader("./data_wiki/").load_data()
index = MultiModalVectorStoreIndex.from_documents(
    documents,
    storage_context=storage_context,
)
from PIL import Image
import matplotlib.pyplot as plt
import os

def plot_images(image_metadata_dict):
    original_images_urls = []
    images_shown = 0
    for image_id in image_metadata_dict:
        img_path = image_metadata_dict[image_id]["img_path"]
        if os.path.isfile(img_path):
            filename = image_metadata_dict[image_id]["filename"]
            image = Image.open(img_path).convert("RGB")

            plt.subplot(8, 8, len(original_images_urls) + 1)
            plt.imshow(image)
            plt.xticks([])
            plt.yticks([])

            original_images_urls.append(filename)
            images_shown += 1
            if images_shown >= 64:
                break

    plt.tight_layout()
plot_images(image_metadata_dict)
Multi-Modal RAG Retrieval and Querying using LLaVA Pydantic structured output
# generate retrieval results
retriever = index.as_retriever(similarity_top_k=3, image_similarity_top_k=5)
retrieval_results = retriever.retrieve(pydantic_response.brand)

from llama_index.core.response.notebook_utils import (
    display_source_node,
    display_image_uris,
)
from llama_index.core.schema import ImageNode

retrieved_image = []
for res_node in retrieval_results:
    if isinstance(res_node.node, ImageNode):
        retrieved_image.append(res_node.node.metadata["file_path"])
    else:
        display_source_node(res_node, source_length=200)
display_image_uris(retrieved_image)
Node ID: 88152a5f-b72d-4b97-915d-f4d27a233bc2
Similarity: 0.8671320192132904
Text: Air Jordan is a line of basketball shoes produced by Nike, Inc. Related apparel and accessories are marketed under Jordan Brand.
The first Air Jordan shoe was produced for basketball player Michael…
Node ID: b11f9bd5-ad86-4a53-8622-85418e6bb70b
Similarity: 0.8659507837211329
Text: === Jordan “6 Rings” shoe ===
The Jordan 6 Rings (aka Jordan Six Rings, Or Montells) is a combination of the seven Air Jordan shoes that Michael Jordan wore during his 6 championship seasons. That …
Node ID: 0d7d5283-d4d9-476b-a4cd-81d9949cc60b
Similarity: 0.8545030450835027
Text: === Sneaker collecting ===
The shoes have had a large impact on the rise of “sneakerhead” culture. In the 1980s, collecting sneakers became more common, as well as trading and reselling them. As n…
Synthesize the RAG results using retrieved texts and images
from llama_index.core import PromptTemplate
from llama_index.core.query_engine import SimpleMultiModalQueryEngine

qa_tmpl_str = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the query.\n"
    "Query: {query_str}\n"
    "Answer: "
)
qa_tmpl = PromptTemplate(qa_tmpl_str)
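If you want to see exactly what the LLM will receive, `PromptTemplate.format` fills in the placeholders; the context and query values below are made-up stand-ins.

# Illustrative: render the template with placeholder values to inspect the final prompt
print(qa_tmpl.format(context_str="<retrieved context>", query_str="<your question>"))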
query_engine = index.as_query_engine(
    llm=llava_multi_modal_llm,
    text_qa_template=qa_tmpl,
    similarity_top_k=2,
    image_similarity_top_k=1,
)

query_str = "Tell me more about the " + pydantic_response.brand + " brand."
response = query_engine.query(query_str)
print(response)
The Air Jordan brand is a line of basketball shoes produced by Nike, Inc. It was created for Michael Jordan, a basketball player who played for the Chicago Bulls during the 1980s and 1990s. The first Air Jordan shoe was released in 1985, and it has since become one of the most iconic and successful shoe lines in history. The shoes are known for their distinctive design, high-quality materials, and innovative technology, which has helped to establish the Air Jordan brand as a leader in the athletic footwear industry. The brand has also expanded to include apparel, accessories, and other products, and has become a cultural phenomenon, with a significant impact on fashion, music, and popular culture.