RAG Workflow with Reranking
This notebook walks through setting up a Workflow
to perform basic RAG with reranking.
!pip install -U llama-index
import os
os.environ["OPENAI_API_KEY"] = "sk-proj-..."
[Optional] Set up observability with Llamatrace
Set up tracing to visualize each step in the workflow.
%pip install "openinference-instrumentation-llama-index>=3.0.0" "opentelemetry-proto>=1.12.0" opentelemetry-exporter-otlp opentelemetry-sdk
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
    OTLPSpanExporter as HTTPSpanExporter,
)
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor

# Add Phoenix API Key for tracing
PHOENIX_API_KEY = "<YOUR-PHOENIX-API-KEY>"
os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"api_key={PHOENIX_API_KEY}"

# Add Phoenix
span_phoenix_processor = SimpleSpanProcessor(
    HTTPSpanExporter(endpoint="https://app.phoenix.arize.com/v1/traces")
)

# Add them to the tracer
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(span_processor=span_phoenix_processor)

# Instrument the application
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
!mkdir -p data
!wget --user-agent "Mozilla" "https://arxiv.org/pdf/2307.09288.pdf" -O "data/llama2.pdf"
Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use asyncio.run()
to start an async event loop if one isn’t already running.
async def main():
    <async code>


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
Designing the Workflow
RAG + Reranking consists of some clearly defined steps:
- Indexing data, creating an index
- Using that index + a query to retrieve relevant text chunks
- Reranking the retrieved text chunks using the original query
- Synthesizing a final response
With this in mind, we can create events and workflow steps to follow this process!
The Workflow Events
To handle these steps, we need to define a few events:
- An event to pass retrieved nodes to the reranker
- An event to pass reranked nodes to the synthesizer
The other steps will use the built-in StartEvent and StopEvent events.
from llama_index.core.workflow import Event
from llama_index.core.schema import NodeWithScore


class RetrieverEvent(Event):
    """Result of running retrieval"""

    nodes: list[NodeWithScore]


class RerankEvent(Event):
    """Result of running reranking on retrieved nodes"""

    nodes: list[NodeWithScore]
The Workflow Itself
With our events defined, we can construct our workflow and steps.
Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.postprocessor.llm_rerank import LLMRerank
from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)

from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
class RAGWorkflow(Workflow):
    @step
    async def ingest(self, ctx: Context, ev: StartEvent) -> StopEvent | None:
        """Entry point to ingest a document, triggered by a StartEvent with `dirname`."""
        dirname = ev.get("dirname")
        if not dirname:
            return None

        documents = SimpleDirectoryReader(dirname).load_data()
        index = VectorStoreIndex.from_documents(
            documents=documents,
            embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
        )
        return StopEvent(result=index)

    @step
    async def retrieve(
        self, ctx: Context, ev: StartEvent
    ) -> RetrieverEvent | None:
        "Entry point for RAG, triggered by a StartEvent with `query`."
        query = ev.get("query")
        index = ev.get("index")

        if not query:
            return None

        print(f"Query the database with: {query}")

        # store the query in the global context
        await ctx.store.set("query", query)
        # the index is passed in via the StartEvent; bail out if it's missing
        if index is None:
            print("Index is empty, load some documents before querying!")
            return None
        retriever = index.as_retriever(similarity_top_k=2)
        nodes = await retriever.aretrieve(query)
        print(f"Retrieved {len(nodes)} nodes.")
        return RetrieverEvent(nodes=nodes)

    @step
    async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
        # Rerank the nodes
        ranker = LLMRerank(
            choice_batch_size=5, top_n=3, llm=OpenAI(model="gpt-4o-mini")
        )
        print(await ctx.store.get("query", default=None), flush=True)
        new_nodes = ranker.postprocess_nodes(
            ev.nodes, query_str=await ctx.store.get("query", default=None)
        )
        print(f"Reranked nodes to {len(new_nodes)}")
        return RerankEvent(nodes=new_nodes)

    @step
    async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
        """Return a streaming response using reranked nodes."""
        llm = OpenAI(model="gpt-4o-mini")
        summarizer = CompactAndRefine(llm=llm, streaming=True, verbose=True)
        query = await ctx.store.get("query", default=None)

        response = await summarizer.asynthesize(query, nodes=ev.nodes)
        return StopEvent(result=response)
And that's it! Let's explore the workflow we wrote a bit.
- We have two entry points (the steps that accept StartEvent)
- The steps themselves decide when they can run
- The workflow context is used to store the user query
- The nodes are passed around, and finally a streaming response is returned
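If you'd like to see these entry points and events laid out as a graph, LlamaIndex ships an optional drawing helper. The sketch below assumes the separate llama-index-utils-workflow package and its draw_all_possible_flows helper are available in your installed version; check your version's docs if the import fails.

%pip install llama-index-utils-workflow

from llama_index.utils.workflow import draw_all_possible_flows

# Write an HTML rendering of every possible path through the workflow
draw_all_possible_flows(RAGWorkflow, filename="rag_workflow.html")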
Run the Workflow!
w = RAGWorkflow()
# Ingest the documents
index = await w.run(dirname="data")

# Run a query
result = await w.run(query="How was Llama2 trained?", index=index)

async for chunk in result.async_response_gen():
    print(chunk, end="", flush=True)
Query the database with: How was Llama2 trained?
Retrieved 2 nodes.
How was Llama2 trained?
Reranked nodes to 2

Llama 2 was trained through a multi-step process that began with pretraining using publicly available online sources. This was followed by the creation of an initial version of Llama 2-Chat through supervised fine-tuning. The model was then iteratively refined using Reinforcement Learning with Human Feedback (RLHF) methodologies, which included rejection sampling and Proximal Policy Optimization (PPO).
During pretraining, the model utilized an optimized auto-regressive transformer architecture, incorporating robust data cleaning, updated data mixes, and training on a significantly larger dataset of 2 trillion tokens. The training process also involved increased context length and the use of grouped-query attention (GQA) to enhance inference scalability.
The training employed the AdamW optimizer with specific hyperparameters, including a cosine learning rate schedule and gradient clipping. The models were pretrained on Meta’s Research SuperCluster and internal production clusters, utilizing NVIDIA A100 GPUs.
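As noted earlier, the bare await calls above work because the notebook already has an event loop running. As a rough sketch (the main wrapper is just illustrative, not part of the library), the same ingest-and-query sequence could run as a standalone script like this:

import asyncio


async def main():
    w = RAGWorkflow()

    # Build the index once, then query it
    index = await w.run(dirname="data")
    result = await w.run(query="How was Llama2 trained?", index=index)

    # Stream the final answer to stdout
    async for chunk in result.async_response_gen():
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(main())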