4. Routing#
When we have multiple data sources, such as a graph database and PDF documents (i.e., a vector store), we may need to answer user queries from the correct data source. For example, if the user wants to know about reviews of a hospital, the query should be redirected to the vector store containing embeddings of hospital reviews. On the other hand, if the user asks about information such as the doctors, patients, and their visits to the hospital, the query should probably be sent to a graph database that contains the hospital information. To provide this functionality, we will now focus on “Routing” in RAG with LangChain.
In this section we discuss two main routing techniques, namely logical routing and semantic routing.
First, let’s import our libraries and create the two vector stores to which we will redirect user queries.
%load_ext dotenv
%dotenv secrets/secrets.env
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain import hub
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
def generate_vectorstores(file, dir):
    loader = PyPDFLoader(file)
    documents = loader.load()
    # Split text into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
    text_chunks = text_splitter.split_documents(documents)
    vectorstore = Chroma.from_documents(documents=text_chunks,
                                        embedding=OpenAIEmbeddings(),
                                        persist_directory=dir)
    vectorstore.persist()
    return vectorstore
# Create a vectorstore to answer questions about LoRA
vectorstore_lora = generate_vectorstores("data/LoRA.pdf","data/vectorstore_lora")
# Create a vectorstore to answer questions about BERT
vectorstore_bert = generate_vectorstores("data/BERT.pdf","data/vectorstore_bert")
/Users/sakunaharinda/Documents/Repositories/ragatouille/venv/lib/python3.12/site-packages/langchain_core/_api/deprecation.py:119: LangChainDeprecationWarning: Since Chroma 0.4.x the manual persistence method is no longer supported as docs are automatically persisted.
warn_deprecated(
retriever_lora = vectorstore_lora.as_retriever(search_kwargs={'k':5})
retriever_bert = vectorstore_bert.as_retriever(search_kwargs={'k':5})
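As an optional sanity check (the sample question below is only an assumption, not taken from the tutorial), we can query one of the retrievers directly to confirm that it returns the five chunks configured via k=5.
# Hypothetical sanity check: retrieve the top-5 BERT chunks for a sample question
docs = retriever_bert.invoke("What is masked language modelling?")
print(len(docs))                   # expected: 5
print(docs[0].page_content[:200])  # preview of the most relevant chunk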
4.1. Logical Routing#
In logical routing we let the LLM decide the route based on a set of pre-defined options/routes. To do that, we first define our router as a Pydantic model with three main routes. In the QueryRouter model, we define two fields: datasource, indicating the data source the query is redirected to, and question, representing the user query. For the datasource field, we allow three values: “lora” and “bert”, which represent the two vector stores we created earlier, and “general”, which routes the query directly to the LLM as a fallback mechanism.
After specifying our router, we initialize our LLM as GPT-4 and have it return its output as a QueryRouter object using the with_structured_output() method.
Finally, we create our router chain using LCEL.
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate
from typing import Literal
class QueryRouter(BaseModel):
    """Route a user query to the appropriate datasource that will help answer the query accurately"""

    datasource: Literal['lora', 'bert', 'general'] = Field(...,
        description="Given a user question choose which datasource would be most relevant for answering their question"
    )
    question: str = Field(..., description="User question to be routed to the appropriate datasource")
llm = ChatOpenAI(model='gpt-4',temperature=0)
structured_llm = llm.with_structured_output(QueryRouter)
router_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an expert router that can direct user queries to the appropriate datasource. Route the following user question about a topic in NLP and LLMs to the appropriate datasource.\nIf it is a general question not related to the provided datasources, route it to the general datasource.\n"),
        ("user", "{question}")
    ]
)

router = (
    {'question': RunnablePassthrough()}
    | router_prompt
    | structured_llm
)
After invoking our router chain, we can see that it logically decides which datasource the query should be redirected to and outputs the decision as a QueryRouter object.
question = "How does the BERT work?"
result = router.invoke(question)
result
QueryRouter(datasource='bert', question='How does the BERT work?')
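A question that is unrelated to either paper should fall back to the general route. As a quick check (the routing is decided by the LLM, so the expected output below is an assumption):
result_general = router.invoke("What is the capital of France?")
result_general
# Expected (approximately): QueryRouter(datasource='general', question='What is the capital of France?')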
Then, to use the router output and perform QA accordingly, we define a new method, choose_route. choose_route inspects the router chain’s result to extract the datasource and defines three chains to answer questions about BERT, LoRA, and the general domain.
We complete our RAG pipeline with one final chain, full_chain, which puts all the methods and chains together.
qa_prompt = hub.pull('rlm/rag-prompt')
def choose_route(result):
    llm_route = ChatOpenAI(model='gpt-4', temperature=0)
    if "bert" in result.datasource.lower():
        print(f"> Asking about BERT ...\nQuestion: {result.question}\nAnswer:")
        bert_chain = (
            {'context': retriever_bert, 'question': RunnablePassthrough()}
            | qa_prompt
            | llm_route
            | StrOutputParser()
        )
        return bert_chain.invoke(result.question)
    elif "lora" in result.datasource.lower():
        print(f"> Asking about LoRA ...\nQuestion: {result.question}\nAnswer:")
        lora_chain = (
            {'context': retriever_lora, 'question': RunnablePassthrough()}
            | qa_prompt
            | llm_route
            | StrOutputParser()
        )
        return lora_chain.invoke(result.question)
    else:
        print(f"> Asking about a general question ...\nQuestion: {result.question}\nAnswer:")
        general_chain = llm_route | StrOutputParser()
        return general_chain.invoke(result.question)
from langchain_core.runnables import RunnableLambda
full_chain = router | RunnableLambda(choose_route)
full_chain.invoke("What are the benefits of LoRA?")
> Asking about LoRA ...
Question: What are the benefits of LoRA?
Answer:
'LoRA, or Low-Rank Adaptation, offers several benefits. It is an efficient adaptation strategy that does not introduce inference latency or reduce input sequence length, while maintaining high model quality. It allows for quick task-switching by sharing the majority of the model parameters. LoRA also makes training more efficient and lowers the hardware barrier to entry by up to 3 times when using adaptive optimizers. Furthermore, it can reduce the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times, while performing on-par or better than fine-tuning in model quality.'
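We can also exercise the fallback branch end to end. With a question outside NLP, the router should pick the general datasource and the answer comes straight from the LLM (the expected routing below is an assumption, and the answer text will vary between runs):
full_chain.invoke("Who wrote the novel Nineteen Eighty-Four?")
# Expected routing: "> Asking about a general question ..." followed by the LLM's answer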
The LangSmith trace for our logical router will look like this.
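If no trace appears in LangSmith, tracing is probably not enabled. It is configured through environment variables; a minimal sketch, assuming they are set alongside the other keys in secrets/secrets.env (the project name is a placeholder):
import os
# Assumed LangSmith tracing configuration (placeholder values)
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = "<your-langsmith-api-key>"
os.environ["LANGCHAIN_PROJECT"] = "rag-routing"  # hypothetical project name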
4.2. Semantic Routing#
In contrast to logical routing, semantic routing relies on the semantic similarity between the user query and the route prompts to decide which route to take. Let’s implement it for RAG!
First, we define two prompts representing the two routes of our semantic router.
physics_template = """You are a very smart physics professor. \
You are great at answering questions about physics in a concise and easy to understand manner. \
When you don't know the answer to a question you admit that you don't know.
Here is a question:
{question}"""
math_template = """You are a very good mathematician. You are great at answering math questions. \
You are so good because you are able to break down hard problems into their component parts, \
answer the component parts, and then put them together to answer the broader question.
Here is a question:
{question}"""
Next, we generate embedding vectors for both of these prompts using OpenAIEmbeddings.
embeddings = OpenAIEmbeddings()
routes = [physics_template, math_template]
route_embeddings = embeddings.embed_documents(routes)
len(route_embeddings)
2
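Each route prompt is now represented by a single embedding vector. The dimensionality depends on the embedding model; with the default OpenAI embedding model it is typically 1536.
len(route_embeddings[0])  # typically 1536 for the default OpenAI embedding model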
We now create the router, which first embeds the user query and computes the cosine similarity scores between the query embedding and the embeddings of each prompt. The router then returns the prompt with the highest similarity to the query, to be used as the prompt for the LLM.
from langchain.utils.math import cosine_similarity
from langchain.prompts import PromptTemplate
def router(input):
    # Generate embeddings for the user query
    query_embedding = embeddings.embed_query(input['question'])
    # Compute cosine similarity scores between the user query and each of the two route prompts
    similarity = cosine_similarity([query_embedding], route_embeddings)[0]
    # Find the route that gives the maximum similarity score
    route_id = similarity.argmax()
    if route_id == 0:
        print(f"> Asking a physics question ...\nQuestion: {input['question']}\nAnswer:")
    else:
        print(f"> Asking a math question ...\nQuestion: {input['question']}\nAnswer:")
    return PromptTemplate.from_template(routes[route_id])
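Before wiring the router into a chain, we can call it directly to inspect which prompt template it selects (the sample question and the expected route are assumptions):
selected_prompt = router({'question': "Why do objects fall to the ground?"})
print(selected_prompt.template[:80])  # should show the beginning of the physics template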
Finally, we create our RAG chain, which takes the user query and answers it using the appropriate prompt decided by the router.
semantic_router_chain = (
    {'question': RunnablePassthrough()}
    | RunnableLambda(router)
    | ChatOpenAI(model='gpt-4', temperature=0)
    | StrOutputParser()
)
semantic_router_chain.invoke("What is the formula for the area of a circle?")
> Asking a math question ...
Question: What is the formula for the area of a circle?
Answer:
'The formula for the area of a circle is A = πr², where A is the area and r is the radius of the circle.'
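The physics route works the same way: a physics-flavoured question should select the physics professor prompt (the expected routing is an assumption, and the answer text will vary):
semantic_router_chain.invoke("Why is the sky blue?")
# Expected routing: "> Asking a physics question ..." followed by the model's explanation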
This technique is much simpler than logical routing. The LangSmith trace for our semantic router will look like this.