From d127fec9f7e5dedcde7989c32ae568ecbf574d34 Mon Sep 17 00:00:00 2001 From: KennyWu Date: Tue, 5 Nov 2024 15:02:57 +0800 Subject: [PATCH] feat: support for visualizing citation results (via embeddings) (#461) * feat:support for visualizing citation results (via embeddings) Signed-off-by: Kennywu * fix: remove ktem dependency in visualize_cited * fix: limit onnx version for fastembed * fix: test case of indexing * fix: minor update * fix: chroma req * fix: chroma req --------- Signed-off-by: Kennywu Co-authored-by: Tadashi --- libs/kotaemon/kotaemon/indices/vectorindex.py | 6 +- libs/kotaemon/pyproject.toml | 5 +- libs/ktem/ktem/reasoning/simple.py | 48 +++++- libs/ktem/ktem/utils/visualize_cited.py | 142 ++++++++++++++++++ 4 files changed, 196 insertions(+), 5 deletions(-) create mode 100644 libs/ktem/ktem/utils/visualize_cited.py diff --git a/libs/kotaemon/kotaemon/indices/vectorindex.py b/libs/kotaemon/kotaemon/indices/vectorindex.py index 1906091..e8f79a6 100644 --- a/libs/kotaemon/kotaemon/indices/vectorindex.py +++ b/libs/kotaemon/kotaemon/indices/vectorindex.py @@ -53,7 +53,11 @@ class VectorIndexing(BaseIndexing): def write_chunk_to_file(self, docs: list[Document]): # save the chunks content into markdown format if self.cache_dir: - file_name = Path(docs[0].metadata["file_name"]) + file_name = docs[0].metadata.get("file_name") + if not file_name: + return + + file_name = Path(file_name) for i in range(len(docs)): markdown_content = "" if "page_label" in docs[i].metadata: diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index 6ee2aa7..ee5c468 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "langchain-cohere>=0.2.4,<0.3.0", "llama-hub>=0.0.79,<0.1.0", "llama-index>=0.10.40,<0.11.0", + "chromadb<=0.5.16", "llama-index-vector-stores-chroma>=0.1.9", "llama-index-vector-stores-lancedb", "openai>=1.23.6,<2", @@ -52,7 +53,8 @@ dependencies = [ "python-dotenv>=1.0.1,<1.1",
"tenacity>=8.2.3,<8.3", "theflow>=0.8.6,<0.9.0", - "trogon>=0.5.0,<0.6" + "trogon>=0.5.0,<0.6", + "umap-learn==0.5.5", ] readme = "README.md" authors = [ @@ -71,6 +73,7 @@ adv = [ "duckduckgo-search>=6.1.0,<6.2", "elasticsearch>=8.13.0,<8.14", "fastembed", + "onnxruntime<v1.20", "googlesearch-python>=1.2.4,<1.3", "llama-cpp-python<0.2.8", "llama-index>=0.10.40,<0.11.0", diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 9702efd..80bebad 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -8,6 +8,7 @@ from typing import Generator import numpy as np import tiktoken +from ktem.embeddings.manager import embedding_models_manager as embeddings from ktem.llms.manager import llms from ktem.reasoning.prompt_optimization import ( CreateMindmapPipeline, @@ -16,6 +17,8 @@ from ktem.reasoning.prompt_optimization import ( ) from ktem.utils.plantuml import PlantUML from ktem.utils.render import Render +from ktem.utils.visualize_cited import CreateCitationVizPipeline +from plotly.io import to_json from theflow.settings import settings as flowsettings from kotaemon.base import ( @@ -240,6 +243,7 @@ class AnswerWithContextPipeline(BaseComponent): enable_citation: bool = False enable_mindmap: bool = False + enable_citation_viz: bool = False system_prompt: str = "" lang: str = "English" # support English and Japanese @@ -409,7 +413,12 @@ class AnswerWithContextPipeline(BaseComponent): answer = Document( text=output, - metadata={"mindmap": mindmap, "citation": citation, "qa_score": qa_score}, + metadata={ + "citation_viz": self.enable_citation_viz, + "mindmap": mindmap, + "citation": citation, + "qa_score": qa_score, + }, ) return answer @@ -474,6 +483,11 @@ class FullQAPipeline(BaseReasoning): evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx() answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx() rewrite_pipeline: RewriteQuestionPipeline | None = None + create_citation_viz_pipeline:
CreateCitationVizPipeline = Node( + default_callback=lambda _: CreateCitationVizPipeline( + embedding=embeddings.get_default() + ) + ) add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx() def retrieve( @@ -641,10 +655,28 @@ class FullQAPipeline(BaseReasoning): return mindmap_content - def show_citations_and_addons(self, answer, docs): + def prepare_citation_viz(self, answer, question, docs) -> Document | None: + doc_texts = [doc.text for doc in docs] + citation_plot = None + plot_content = None + + if answer.metadata["citation_viz"] and len(docs) > 1: + try: + citation_plot = self.create_citation_viz_pipeline(doc_texts, question) + except Exception as e: + print("Failed to create citation plot:", e) + + if citation_plot: + plot = to_json(citation_plot) + plot_content = Document(channel="plot", content=plot) + + return plot_content + + def show_citations_and_addons(self, answer, docs, question): # show the evidence with_citation, without_citation = self.prepare_citations(answer, docs) mindmap_output = self.prepare_mindmap(answer) + citation_plot_output = self.prepare_citation_viz(answer, question, docs) if not with_citation and not without_citation: yield Document(channel="info", content="
<h5><b>No evidence found.</b></h5>
") @@ -661,6 +693,10 @@ class FullQAPipeline(BaseReasoning): if mindmap_output: yield mindmap_output + # yield citation plot output + if citation_plot_output: + yield citation_plot_output + # yield warning message if has_llm_score and max_llm_rerank_score < CONTEXT_RELEVANT_WARNING_SCORE: yield Document( @@ -733,7 +769,7 @@ class FullQAPipeline(BaseReasoning): if scoring_thread: scoring_thread.join() - yield from self.show_citations_and_addons(answer, docs) + yield from self.show_citations_and_addons(answer, docs, message) return answer @@ -767,6 +803,7 @@ class FullQAPipeline(BaseReasoning): answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"] answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"] answer_pipeline.enable_mindmap = settings[f"{prefix}.create_mindmap"] + answer_pipeline.enable_citation_viz = settings[f"{prefix}.create_citation_viz"] answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"] answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"] answer_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get( @@ -820,6 +857,11 @@ class FullQAPipeline(BaseReasoning): "value": False, "component": "checkbox", }, + "create_citation_viz": { + "name": "Create Embeddings Visualization", + "value": False, + "component": "checkbox", + }, "system_prompt": { "name": "System Prompt", "value": "This is a question answering system", diff --git a/libs/ktem/ktem/utils/visualize_cited.py b/libs/ktem/ktem/utils/visualize_cited.py new file mode 100644 index 0000000..a9602c7 --- /dev/null +++ b/libs/ktem/ktem/utils/visualize_cited.py @@ -0,0 +1,142 @@ +""" +This module aims to project high-dimensional embeddings +into a lower-dimensional space for visualization. + +Refs: +1. [RAGxplorer](https://github.com/gabrielchua/RAGxplorer) +2. 
[RAGVizExpander](https://github.com/KKenny0/RAGVizExpander) +""" +from typing import List, Tuple + +import numpy as np +import pandas as pd +import plotly.graph_objs as go +import umap + +from kotaemon.base import BaseComponent +from kotaemon.embeddings import BaseEmbeddings + +VISUALIZATION_SETTINGS = { + "Original Query": {"color": "red", "opacity": 1, "symbol": "cross", "size": 15}, + "Retrieved": {"color": "green", "opacity": 1, "symbol": "circle", "size": 10}, + "Chunks": {"color": "blue", "opacity": 0.4, "symbol": "circle", "size": 10}, + "Sub-Questions": {"color": "purple", "opacity": 1, "symbol": "star", "size": 15}, +} + + +class CreateCitationVizPipeline(BaseComponent): + """Creating PlotData for visualizing query results""" + + embedding: BaseEmbeddings + projector: umap.UMAP = None + + def _set_up_umap(self, embeddings: np.ndarray): + umap_transform = umap.UMAP().fit(embeddings) + return umap_transform + + def _project_embeddings(self, embeddings, umap_transform) -> np.ndarray: + umap_embeddings = np.empty((len(embeddings), 2)) + for i, embedding in enumerate(embeddings): + umap_embeddings[i] = umap_transform.transform([embedding]) + return umap_embeddings + + def _get_projections(self, embeddings, umap_transform): + projections = self._project_embeddings(embeddings, umap_transform) + x = projections[:, 0] + y = projections[:, 1] + return x, y + + def _prepare_projection_df( + self, + document_projections: Tuple[np.ndarray, np.ndarray], + document_text: List[str], + plot_size: int = 3, + ) -> pd.DataFrame: + """Prepares a DataFrame for visualization from projections and texts. + + Args: + document_projections (Tuple[np.ndarray, np.ndarray]): + Tuple of X and Y coordinates of document projections. + document_text (List[str]): List of document texts. + """ + df = pd.DataFrame({"x": document_projections[0], "y": document_projections[1]}) + df["document"] = document_text + df["document_cleaned"] = df.document.str.wrap(50).apply( + lambda x: x.replace("\n", "<br>")[:512] + "..." + ) + df["size"] = plot_size + df["category"] = "Retrieved" + return df + + def _plot_embeddings(self, df: pd.DataFrame) -> go.Figure: + """ + Creates a Plotly figure to visualize the embeddings. + + Args: + df (pd.DataFrame): DataFrame containing the data to visualize. + + Returns: + go.Figure: A Plotly figure object for visualization. + """ + fig = go.Figure() + + for category in df["category"].unique(): + category_df = df[df["category"] == category] + settings = VISUALIZATION_SETTINGS.get( + category, + {"color": "grey", "opacity": 1, "symbol": "circle", "size": 10}, + ) + fig.add_trace( + go.Scatter( + x=category_df["x"], + y=category_df["y"], + mode="markers", + name=category, + marker=dict( + color=settings["color"], + opacity=settings["opacity"], + symbol=settings["symbol"], + size=settings["size"], + line_width=0, + ), + hoverinfo="text", + text=category_df["document_cleaned"], + ) + ) + + fig.update_layout( + height=500, + legend=dict(y=100, x=0.5, xanchor="center", yanchor="top", orientation="h"), + ) + return fig + + def run(self, context: List[str], question: str): + embed_contexts = self.embedding(context) + context_embeddings = np.array([d.embedding for d in embed_contexts]) + + self.projector = self._set_up_umap(embeddings=context_embeddings) + + embed_query = self.embedding(question) + query_projection = self._get_projections( + embeddings=[embed_query[0].embedding], umap_transform=self.projector + ) + viz_query_df = pd.DataFrame( + { + "x": [query_projection[0][0]], + "y": [query_projection[1][0]], + "document_cleaned": question, + "category": "Original Query", + "size": 5, + } + ) + + context_projections = self._get_projections( + embeddings=context_embeddings, umap_transform=self.projector + ) + viz_base_df = self._prepare_projection_df( + document_projections=context_projections, document_text=context + ) + + visualization_df = pd.concat([viz_base_df, viz_query_df], axis=0) + fig = self._plot_embeddings(visualization_df) +
return fig