feat: sso login, demo mode & new mindmap support (#644) bump:minor

* fix: update .env.example

* feat: add SSO login

* fix: update flowsetting

* fix: add requirement

* fix: refine UI

* fix: update group id-based operation

* fix: improve citation logics

* fix: UI enhancement

* fix: user_id to string in models

* fix: improve chat suggestion UI and flow

* fix: improve group id handling

* fix: improve chat suggestion

* fix: secure download for single file

* fix: file limiting in docstore

* fix: improve chat suggestion logics & language conform

* feat: add markmap and select text to highlight function

* fix: update Dockerfile

* fix: user id auto generate

* fix: default user id

* feat: add demo mode

* fix: update flowsetting

* fix: revise default params for demo

* feat: sso_app alternative

* feat: sso login demo

* feat: demo specific customization

* feat: add login using API key

* fix: disable key-based login

* fix: optimize duplicate upload

* fix: gradio routing

* fix: disable arm build for demo

* fix: revise full-text search js logic

* feat: add rate limit

* fix: update Dockerfile with new launch script

* fix: update Dockerfile

* fix: update Dockerignore

* fix: update ratelimit logic

* fix: user_id in user management page

* fix: rename conv logic

* feat: update demo hint

* fix: minor fix

* fix: highlight on long PDF load

* feat: add HF paper list

* fix: update HF papers load logic

* feat: fly config

* fix: update fly config

* fix: update paper list pull api

* fix: minor update root routing

* fix: minor update root routing

* fix: simplify login flow & paper list UI

* feat: add paper recommendation

* fix: update Dockerfile

* fix: update Dockerfile

* fix: update default model

* feat: add long context Ollama through LCOllama

* feat: expose Gradio share to env

* fix: revert customized changes

* fix: list group at app load

* fix: relocate share conv button

* fix: update launch script

* fix: update Docker CI

* feat: add Ollama model selection at first setup

* docs: update README
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2025-02-02 15:19:48 +07:00 committed by GitHub
parent 3006402d7e
commit 3bd3830b8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 2488 additions and 937 deletions

View File

@ -11,3 +11,5 @@ env/
README.md README.md
*.zip *.zip
*.sh *.sh
!/launch.sh

View File

@ -3,8 +3,8 @@
# settings for OpenAI # settings for OpenAI
OPENAI_API_BASE=https://api.openai.com/v1 OPENAI_API_BASE=https://api.openai.com/v1
OPENAI_API_KEY=<YOUR_OPENAI_KEY> OPENAI_API_KEY=<YOUR_OPENAI_KEY>
OPENAI_CHAT_MODEL=gpt-3.5-turbo OPENAI_CHAT_MODEL=gpt-4o-mini
OPENAI_EMBEDDINGS_MODEL=text-embedding-ada-002 OPENAI_EMBEDDINGS_MODEL=text-embedding-3-large
# settings for Azure OpenAI # settings for Azure OpenAI
AZURE_OPENAI_ENDPOINT= AZURE_OPENAI_ENDPOINT=
@ -17,10 +17,8 @@ AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=text-embedding-ada-002
COHERE_API_KEY=<COHERE_API_KEY> COHERE_API_KEY=<COHERE_API_KEY>
# settings for local models # settings for local models
LOCAL_MODEL=llama3.1:8b LOCAL_MODEL=qwen2.5:7b
LOCAL_MODEL_EMBEDDINGS=nomic-embed-text LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
LOCAL_EMBEDDING_MODEL_DIM = 768
LOCAL_EMBEDDING_MODEL_MAX_TOKENS = 8192
# settings for GraphRAG # settings for GraphRAG
GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY> GRAPHRAG_API_KEY=<YOUR_OPENAI_KEY>

View File

@ -28,6 +28,7 @@ jobs:
target: target:
- lite - lite
- full - full
- ollama
steps: steps:
- name: Free Disk Space (Ubuntu) - name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main uses: jlumbroso/free-disk-space@main

18
.github/workflows/fly-deploy.yml vendored Normal file
View File

@ -0,0 +1,18 @@
# See https://fly.io/docs/app-guides/continuous-deployment-with-github-actions/
name: Fly Deploy
on:
push:
branches:
- main
jobs:
deploy:
name: Deploy app
runs-on: ubuntu-latest
concurrency: deploy-group # optional: ensure only one action runs at a time
steps:
- uses: actions/checkout@v4
- uses: superfly/flyctl-actions/setup-flyctl@master
- run: flyctl deploy --remote-only
env:
FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }}

View File

@ -57,6 +57,7 @@ repos:
"types-requests", "types-requests",
"sqlmodel", "sqlmodel",
"types-Markdown", "types-Markdown",
"types-cachetools",
types-tzlocal, types-tzlocal,
] ]
args: ["--check-untyped-defs", "--ignore-missing-imports"] args: ["--check-untyped-defs", "--ignore-missing-imports"]

View File

@ -35,6 +35,7 @@ RUN bash scripts/download_pdfjs.sh $PDFJS_PREBUILT_DIR
# Copy contents # Copy contents
COPY . /app COPY . /app
COPY launch.sh /app/launch.sh
COPY .env.example /app/.env COPY .env.example /app/.env
# Install pip packages # Install pip packages
@ -54,7 +55,7 @@ RUN apt-get autoremove \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& rm -rf ~/.cache && rm -rf ~/.cache
CMD ["python", "app.py"] ENTRYPOINT ["sh", "/app/launch.sh"]
# Full version # Full version
FROM lite AS full FROM lite AS full
@ -97,7 +98,17 @@ RUN apt-get autoremove \
&& rm -rf /var/lib/apt/lists/* \ && rm -rf /var/lib/apt/lists/* \
&& rm -rf ~/.cache && rm -rf ~/.cache
# Download nltk packages as required for unstructured ENTRYPOINT ["sh", "/app/launch.sh"]
# RUN python -c "from unstructured.nlp.tokenize import _download_nltk_packages_if_not_present; _download_nltk_packages_if_not_present()"
CMD ["python", "app.py"] # Ollama-bundled version
FROM full AS ollama
# Install ollama
RUN --mount=type=ssh \
--mount=type=cache,target=/root/.cache/pip \
curl -fsSL https://ollama.com/install.sh | sh
# RUN nohup bash -c "ollama serve &" && sleep 4 && ollama pull qwen2.5:7b
RUN nohup bash -c "ollama serve &" && sleep 4 && ollama pull nomic-embed-text
ENTRYPOINT ["sh", "/app/launch.sh"]

View File

@ -96,18 +96,7 @@ documents and developers who want to build their own RAG pipeline.
### With Docker (recommended) ### With Docker (recommended)
1. We support both `lite` & `full` version of Docker images. With `full`, the extra packages of `unstructured` will be installed as well, it can support additional file types (`.doc`, `.docx`, ...) but the cost is larger docker image size. For most users, the `lite` image should work well in most cases. 1. We support both `lite` & `full` version of Docker images. With `full` version, the extra packages of `unstructured` will be installed, which can support additional file types (`.doc`, `.docx`, ...) but the cost is larger docker image size. For most users, the `lite` image should work well in most cases.
- To use the `lite` version.
```bash
docker run \
-e GRADIO_SERVER_NAME=0.0.0.0 \
-e GRADIO_SERVER_PORT=7860 \
-v ./ktem_app_data:/app/ktem_app_data \
-p 7860:7860 -it --rm \
ghcr.io/cinnamon/kotaemon:main-lite
```
- To use the `full` version. - To use the `full` version.
@ -124,7 +113,14 @@ documents and developers who want to build their own RAG pipeline.
```bash ```bash
# change image name to # change image name to
ghcr.io/cinnamon/kotaemon:feat-ollama_docker-full docker run <...> ghcr.io/cinnamon/kotaemon:main-ollama
```
- To use the `lite` version.
```bash
# change image name to
docker run <...> ghcr.io/cinnamon/kotaemon:main-lite
``` ```
2. We currently support and test two platforms: `linux/amd64` and `linux/arm64` (for newer Mac). You can specify the platform by passing `--platform` in the `docker run` command. For example: 2. We currently support and test two platforms: `linux/amd64` and `linux/arm64` (for newer Mac). You can specify the platform by passing `--platform` in the `docker run` command. For example:

2
app.py
View File

@ -3,6 +3,7 @@ import os
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".") KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".")
KH_GRADIO_SHARE = getattr(flowsettings, "KH_GRADIO_SHARE", False)
GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None) GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None)
# override GRADIO_TEMP_DIR if it's not set # override GRADIO_TEMP_DIR if it's not set
if GRADIO_TEMP_DIR is None: if GRADIO_TEMP_DIR is None:
@ -21,4 +22,5 @@ demo.queue().launch(
"libs/ktem/ktem/assets", "libs/ktem/ktem/assets",
GRADIO_TEMP_DIR, GRADIO_TEMP_DIR,
], ],
share=KH_GRADIO_SHARE,
) )

View File

@ -4,8 +4,8 @@ An open-source tool for chatting with your documents. Built with both end users
developers in mind. developers in mind.
[Source Code](https://github.com/Cinnamon/kotaemon) | [Source Code](https://github.com/Cinnamon/kotaemon) |
[Live Demo](https://huggingface.co/spaces/cin-model/kotaemon-demo) [HF Space](https://huggingface.co/spaces/cin-model/kotaemon-demo)
[User Guide](https://cinnamon.github.io/kotaemon/) | [Installation Guide](https://cinnamon.github.io/kotaemon/) |
[Developer Guide](https://cinnamon.github.io/kotaemon/development/) | [Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
[Feedback](https://github.com/Cinnamon/kotaemon/issues) [Feedback](https://github.com/Cinnamon/kotaemon/issues)

View File

@ -1,7 +1,7 @@
## Installation (Online HuggingFace Space) ## Installation (Online HuggingFace Space)
1. Go to [HF kotaemon_template](https://huggingface.co/spaces/cin-model/kotaemon_template). 1. Go to [HF kotaemon_template](https://huggingface.co/spaces/cin-model/kotaemon_template).
2. Use Duplicate function to create your own space. 2. Use Duplicate function to create your own space. Or use this [direct link](https://huggingface.co/spaces/cin-model/kotaemon_template?duplicate=true).
![Duplicate space](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/duplicate_space.png) ![Duplicate space](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/duplicate_space.png)
![Change space params](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/change_space_params.png) ![Change space params](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/change_space_params.png)
3. Wait for the build to complete and start up (approx. 10 mins). 3. Wait for the build to complete and start up (approx. 10 mins).

View File

@ -25,7 +25,8 @@ if not KH_APP_VERSION:
except Exception: except Exception:
KH_APP_VERSION = "local" KH_APP_VERSION = "local"
KH_ENABLE_FIRST_SETUP = True KH_GRADIO_SHARE = config("KH_GRADIO_SHARE", default=False, cast=bool)
KH_ENABLE_FIRST_SETUP = config("KH_ENABLE_FIRST_SETUP", default=True, cast=bool)
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool) KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/") KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")
@ -65,6 +66,8 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
KH_DOC_DIR = this_dir / "docs" KH_DOC_DIR = this_dir / "docs"
KH_MODE = "dev" KH_MODE = "dev"
KH_SSO_ENABLED = config("KH_SSO_ENABLED", default=False, cast=bool)
KH_FEATURE_CHAT_SUGGESTION = config( KH_FEATURE_CHAT_SUGGESTION = config(
"KH_FEATURE_CHAT_SUGGESTION", default=False, cast=bool "KH_FEATURE_CHAT_SUGGESTION", default=False, cast=bool
) )
@ -137,31 +140,36 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
"default": False, "default": False,
} }
if config("OPENAI_API_KEY", default=""): OPENAI_DEFAULT = "<YOUR_OPENAI_KEY>"
OPENAI_API_KEY = config("OPENAI_API_KEY", default=OPENAI_DEFAULT)
GOOGLE_API_KEY = config("GOOGLE_API_KEY", default="your-key")
IS_OPENAI_DEFAULT = len(OPENAI_API_KEY) > 0 and OPENAI_API_KEY != OPENAI_DEFAULT
if OPENAI_API_KEY:
KH_LLMS["openai"] = { KH_LLMS["openai"] = {
"spec": { "spec": {
"__type__": "kotaemon.llms.ChatOpenAI", "__type__": "kotaemon.llms.ChatOpenAI",
"temperature": 0, "temperature": 0,
"base_url": config("OPENAI_API_BASE", default="") "base_url": config("OPENAI_API_BASE", default="")
or "https://api.openai.com/v1", or "https://api.openai.com/v1",
"api_key": config("OPENAI_API_KEY", default=""), "api_key": OPENAI_API_KEY,
"model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"), "model": config("OPENAI_CHAT_MODEL", default="gpt-4o-mini"),
"timeout": 20, "timeout": 20,
}, },
"default": True, "default": IS_OPENAI_DEFAULT,
} }
KH_EMBEDDINGS["openai"] = { KH_EMBEDDINGS["openai"] = {
"spec": { "spec": {
"__type__": "kotaemon.embeddings.OpenAIEmbeddings", "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": config("OPENAI_API_BASE", default="https://api.openai.com/v1"), "base_url": config("OPENAI_API_BASE", default="https://api.openai.com/v1"),
"api_key": config("OPENAI_API_KEY", default=""), "api_key": OPENAI_API_KEY,
"model": config( "model": config(
"OPENAI_EMBEDDINGS_MODEL", default="text-embedding-ada-002" "OPENAI_EMBEDDINGS_MODEL", default="text-embedding-3-large"
), ),
"timeout": 10, "timeout": 10,
"context_length": 8191, "context_length": 8191,
}, },
"default": True, "default": IS_OPENAI_DEFAULT,
} }
if config("LOCAL_MODEL", default=""): if config("LOCAL_MODEL", default=""):
@ -169,11 +177,21 @@ if config("LOCAL_MODEL", default=""):
"spec": { "spec": {
"__type__": "kotaemon.llms.ChatOpenAI", "__type__": "kotaemon.llms.ChatOpenAI",
"base_url": KH_OLLAMA_URL, "base_url": KH_OLLAMA_URL,
"model": config("LOCAL_MODEL", default="llama3.1:8b"), "model": config("LOCAL_MODEL", default="qwen2.5:7b"),
"api_key": "ollama", "api_key": "ollama",
}, },
"default": False, "default": False,
} }
KH_LLMS["ollama-long-context"] = {
"spec": {
"__type__": "kotaemon.llms.LCOllamaChat",
"base_url": KH_OLLAMA_URL.replace("v1/", ""),
"model": config("LOCAL_MODEL", default="qwen2.5:7b"),
"num_ctx": 8192,
},
"default": False,
}
KH_EMBEDDINGS["ollama"] = { KH_EMBEDDINGS["ollama"] = {
"spec": { "spec": {
"__type__": "kotaemon.embeddings.OpenAIEmbeddings", "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
@ -183,7 +201,6 @@ if config("LOCAL_MODEL", default=""):
}, },
"default": False, "default": False,
} }
KH_EMBEDDINGS["fast_embed"] = { KH_EMBEDDINGS["fast_embed"] = {
"spec": { "spec": {
"__type__": "kotaemon.embeddings.FastEmbedEmbeddings", "__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
@ -205,9 +222,9 @@ KH_LLMS["google"] = {
"spec": { "spec": {
"__type__": "kotaemon.llms.chats.LCGeminiChat", "__type__": "kotaemon.llms.chats.LCGeminiChat",
"model_name": "gemini-1.5-flash", "model_name": "gemini-1.5-flash",
"api_key": config("GOOGLE_API_KEY", default="your-key"), "api_key": GOOGLE_API_KEY,
}, },
"default": False, "default": not IS_OPENAI_DEFAULT,
} }
KH_LLMS["groq"] = { KH_LLMS["groq"] = {
"spec": { "spec": {
@ -241,8 +258,9 @@ KH_EMBEDDINGS["google"] = {
"spec": { "spec": {
"__type__": "kotaemon.embeddings.LCGoogleEmbeddings", "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
"model": "models/text-embedding-004", "model": "models/text-embedding-004",
"google_api_key": config("GOOGLE_API_KEY", default="your-key"), "google_api_key": GOOGLE_API_KEY,
} },
"default": not IS_OPENAI_DEFAULT,
} }
# KH_EMBEDDINGS["huggingface"] = { # KH_EMBEDDINGS["huggingface"] = {
# "spec": { # "spec": {
@ -301,9 +319,12 @@ SETTINGS_REASONING = {
USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool) USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool)
USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool) USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool)
USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool)
GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"] GRAPHRAG_INDEX_TYPES = []
if USE_MS_GRAPHRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.GraphRAGIndex")
if USE_NANO_GRAPHRAG: if USE_NANO_GRAPHRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex") GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
if USE_LIGHTRAG: if USE_LIGHTRAG:
@ -323,7 +344,7 @@ GRAPHRAG_INDICES = [
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip" ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
), ),
"private": False, "private": True,
}, },
"index_type": graph_type, "index_type": graph_type,
} }
@ -338,7 +359,7 @@ KH_INDICES = [
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip" ".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
), ),
"private": False, "private": True,
}, },
"index_type": "ktem.index.file.FileIndex", "index_type": "ktem.index.file.FileIndex",
}, },

26
fly.toml Normal file
View File

@ -0,0 +1,26 @@
# fly.toml app configuration file generated for kotaemon on 2024-12-24T20:56:32+07:00
#
# See https://fly.io/docs/reference/configuration/ for information about how to use this file.
#
app = 'kotaemon'
primary_region = 'sin'
[build]
[mounts]
destination = "/app/ktem_app_data"
source = "ktem_volume"
[http_service]
internal_port = 7860
force_https = true
auto_stop_machines = 'suspend'
auto_start_machines = true
min_machines_running = 0
processes = ['app']
[[vm]]
memory = '4gb'
cpu_kind = 'shared'
cpus = 4

23
launch.sh Executable file
View File

@ -0,0 +1,23 @@
#!/bin/bash
if [ -z "$GRADIO_SERVER_NAME" ]; then
export GRADIO_SERVER_NAME="0.0.0.0"
fi
if [ -z "$GRADIO_SERVER_PORT" ]; then
export GRADIO_SERVER_PORT="7860"
fi
# Check if environment variable KH_DEMO_MODE is set to true
if [ "$KH_DEMO_MODE" = "true" ]; then
echo "KH_DEMO_MODE is true. Launching in demo mode..."
# Command to launch in demo mode
GR_FILE_ROOT_PATH="/app" KH_FEATURE_USER_MANAGEMENT=false USE_LIGHTRAG=false uvicorn sso_app_demo:app --host "$GRADIO_SERVER_NAME" --port "$GRADIO_SERVER_PORT"
else
if [ "$KH_SSO_ENABLED" = "true" ]; then
echo "KH_SSO_ENABLED is true. Launching in SSO mode..."
GR_FILE_ROOT_PATH="/app" KH_SSO_ENABLED=true uvicorn sso_app:app --host "$GRADIO_SERVER_NAME" --port "$GRADIO_SERVER_PORT"
else
ollama serve &
python app.py
fi
fi

View File

@ -3,6 +3,7 @@ from collections import defaultdict
from typing import Generator from typing import Generator
import numpy as np import numpy as np
from decouple import config
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
from kotaemon.base import ( from kotaemon.base import (
@ -32,7 +33,9 @@ except ImportError:
MAX_IMAGES = 10 MAX_IMAGES = 10
CITATION_TIMEOUT = 5.0 CITATION_TIMEOUT = 5.0
CONTEXT_RELEVANT_WARNING_SCORE = 0.7 CONTEXT_RELEVANT_WARNING_SCORE = config(
"CONTEXT_RELEVANT_WARNING_SCORE", 0.3, cast=float
)
DEFAULT_QA_TEXT_PROMPT = ( DEFAULT_QA_TEXT_PROMPT = (
"Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501 "Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501
@ -385,7 +388,9 @@ class AnswerWithContextPipeline(BaseComponent):
doc = id2docs[id_] doc = id2docs[id_]
doc_score = doc.metadata.get("llm_trulens_score", 0.0) doc_score = doc.metadata.get("llm_trulens_score", 0.0)
is_open = not has_llm_score or ( is_open = not has_llm_score or (
doc_score > CONTEXT_RELEVANT_WARNING_SCORE and len(with_citation) == 0 doc_score
> CONTEXT_RELEVANT_WARNING_SCORE
# and len(with_citation) == 0
) )
without_citation.append( without_citation.append(
Document( Document(

View File

@ -2,6 +2,8 @@ from difflib import SequenceMatcher
def find_text(search_span, context, min_length=5): def find_text(search_span, context, min_length=5):
search_span, context = search_span.lower(), context.lower()
sentence_list = search_span.split("\n") sentence_list = search_span.split("\n")
context = context.replace("\n", " ") context = context.replace("\n", " ")
@ -18,7 +20,7 @@ def find_text(search_span, context, min_length=5):
matched_blocks = [] matched_blocks = []
for _, start, length in match_results: for _, start, length in match_results:
if length > max(len(sentence) * 0.2, min_length): if length > max(len(sentence) * 0.25, min_length):
matched_blocks.append((start, start + length)) matched_blocks.append((start, start + length))
if matched_blocks: if matched_blocks:
@ -42,6 +44,9 @@ def find_text(search_span, context, min_length=5):
def find_start_end_phrase( def find_start_end_phrase(
start_phrase, end_phrase, context, min_length=5, max_excerpt_length=300 start_phrase, end_phrase, context, min_length=5, max_excerpt_length=300
): ):
start_phrase, end_phrase = start_phrase.lower(), end_phrase.lower()
context = context.lower()
context = context.replace("\n", " ") context = context.replace("\n", " ")
matches = [] matches = []

View File

@ -177,7 +177,11 @@ class VectorRetrieval(BaseRetrieval):
] ]
elif self.retrieval_mode == "text": elif self.retrieval_mode == "text":
query = text.text if isinstance(text, Document) else text query = text.text if isinstance(text, Document) else text
docs = self.doc_store.query(query, top_k=top_k_first_round, doc_ids=scope) docs = []
if scope:
docs = self.doc_store.query(
query, top_k=top_k_first_round, doc_ids=scope
)
result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs] result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs]
elif self.retrieval_mode == "hybrid": elif self.retrieval_mode == "hybrid":
# similarity search section # similarity search section
@ -206,6 +210,7 @@ class VectorRetrieval(BaseRetrieval):
assert self.doc_store is not None assert self.doc_store is not None
query = text.text if isinstance(text, Document) else text query = text.text if isinstance(text, Document) else text
if scope:
ds_docs = self.doc_store.query( ds_docs = self.doc_store.query(
query, top_k=top_k_first_round, doc_ids=scope query, top_k=top_k_first_round, doc_ids=scope
) )

View File

@ -12,6 +12,7 @@ from .chats import (
LCChatOpenAI, LCChatOpenAI,
LCCohereChat, LCCohereChat,
LCGeminiChat, LCGeminiChat,
LCOllamaChat,
LlamaCppChat, LlamaCppChat,
) )
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
@ -33,6 +34,7 @@ __all__ = [
"LCAnthropicChat", "LCAnthropicChat",
"LCGeminiChat", "LCGeminiChat",
"LCCohereChat", "LCCohereChat",
"LCOllamaChat",
"LCAzureChatOpenAI", "LCAzureChatOpenAI",
"LCChatOpenAI", "LCChatOpenAI",
"LlamaCppChat", "LlamaCppChat",

View File

@ -7,6 +7,7 @@ from .langchain_based import (
LCChatOpenAI, LCChatOpenAI,
LCCohereChat, LCCohereChat,
LCGeminiChat, LCGeminiChat,
LCOllamaChat,
) )
from .llamacpp import LlamaCppChat from .llamacpp import LlamaCppChat
from .openai import AzureChatOpenAI, ChatOpenAI from .openai import AzureChatOpenAI, ChatOpenAI
@ -20,6 +21,7 @@ __all__ = [
"LCAnthropicChat", "LCAnthropicChat",
"LCGeminiChat", "LCGeminiChat",
"LCCohereChat", "LCCohereChat",
"LCOllamaChat",
"LCChatOpenAI", "LCChatOpenAI",
"LCAzureChatOpenAI", "LCAzureChatOpenAI",
"LCChatMixin", "LCChatMixin",

View File

@ -358,3 +358,40 @@ class LCCohereChat(LCChatMixin, ChatLLM): # type: ignore
raise ImportError("Please install langchain-cohere") raise ImportError("Please install langchain-cohere")
return ChatCohere return ChatCohere
class LCOllamaChat(LCChatMixin, ChatLLM): # type: ignore
base_url: str = Param(
help="Base Ollama URL. (default: http://localhost:11434/api/)", # noqa
required=True,
)
model: str = Param(
help="Model name to use (https://ollama.com/library)",
required=True,
)
num_ctx: int = Param(
help="The size of the context window (default: 8192)",
required=True,
)
def __init__(
self,
model: str | None = None,
base_url: str | None = None,
num_ctx: int | None = None,
**params,
):
super().__init__(
base_url=base_url,
model=model,
num_ctx=num_ctx,
**params,
)
def _get_lc_class(self):
try:
from langchain_ollama import ChatOllama
except ImportError:
raise ImportError("Please install langchain-ollama")
return ChatOllama

View File

@ -3,15 +3,18 @@ from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional
from decouple import config
from fsspec import AbstractFileSystem from fsspec import AbstractFileSystem
from llama_index.readers.file import PDFReader from llama_index.readers.file import PDFReader
from PIL import Image from PIL import Image
from kotaemon.base import Document from kotaemon.base import Document
PDF_LOADER_DPI = config("PDF_LOADER_DPI", default=40, cast=int)
def get_page_thumbnails( def get_page_thumbnails(
file_path: Path, pages: list[int], dpi: int = 80 file_path: Path, pages: list[int], dpi: int = PDF_LOADER_DPI
) -> List[Image.Image]: ) -> List[Image.Image]:
"""Get image thumbnails of the pages in the PDF file. """Get image thumbnails of the pages in the PDF file.

View File

@ -35,6 +35,7 @@ dependencies = [
"langchain-openai>=0.1.4,<0.2.0", "langchain-openai>=0.1.4,<0.2.0",
"langchain-google-genai>=1.0.3,<2.0.0", "langchain-google-genai>=1.0.3,<2.0.0",
"langchain-anthropic", "langchain-anthropic",
"langchain-ollama",
"langchain-cohere>=0.2.4,<0.3.0", "langchain-cohere>=0.2.4,<0.3.0",
"llama-hub>=0.0.79,<0.1.0", "llama-hub>=0.0.79,<0.1.0",
"llama-index>=0.10.40,<0.11.0", "llama-index>=0.10.40,<0.11.0",

View File

@ -13,7 +13,7 @@ from ktem.settings import BaseSettingGroup, SettingGroup, SettingReasoningGroup
from theflow.settings import settings from theflow.settings import settings
from theflow.utils.modules import import_dotted_string from theflow.utils.modules import import_dotted_string
BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "") BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "")
class BaseApp: class BaseApp:
@ -57,7 +57,7 @@ class BaseApp:
self._pdf_view_js = self._pdf_view_js.replace( self._pdf_view_js = self._pdf_view_js.replace(
"PDFJS_PREBUILT_DIR", "PDFJS_PREBUILT_DIR",
pdf_js_dist_dir, pdf_js_dist_dir,
).replace("GRADIO_ROOT_PATH", BASE_PATH) ).replace("GR_FILE_ROOT_PATH", BASE_PATH)
with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi: with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi:
self._svg_js = fi.read() self._svg_js = fi.read()
@ -79,7 +79,7 @@ class BaseApp:
self.default_settings.index.finalize() self.default_settings.index.finalize()
self.settings_state = gr.State(self.default_settings.flatten()) self.settings_state = gr.State(self.default_settings.flatten())
self.user_id = gr.State(1 if not self.f_user_management else None) self.user_id = gr.State("default" if not self.f_user_management else None)
def initialize_indices(self): def initialize_indices(self):
"""Create the index manager, start indices, and register to app settings""" """Create the index manager, start indices, and register to app settings"""
@ -173,15 +173,25 @@ class BaseApp:
"""Called when the app is created""" """Called when the app is created"""
def make(self): def make(self):
markmap_js = """
<script>
window.markmap = {
/** @type AutoLoaderOptions */
autoLoader: {
toolbar: true, // Enable toolbar
},
};
</script>
"""
external_js = ( external_js = (
"<script type='module' " "<script type='module' "
"src='https://cdn.skypack.dev/pdfjs-viewer-element'>" "src='https://cdn.skypack.dev/pdfjs-viewer-element'>"
"</script>" "</script>"
"<script>"
f"{self._svg_js}"
"</script>"
"<script type='module' " "<script type='module' "
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa "src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
f"{markmap_js}"
"<script src='https://cdn.jsdelivr.net/npm/markmap-autoloader@0.16'></script>" # noqa
"<script src='https://cdn.jsdelivr.net/npm/minisearch@7.1.1/dist/umd/index.min.js'></script>" # noqa
"</script>" "</script>"
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa "<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
) )

View File

@ -326,7 +326,12 @@ pdfjs-viewer-element {
/* Switch checkbox styles */ /* Switch checkbox styles */
#is-public-checkbox { /* #is-public-checkbox {
position: relative;
top: 4px;
} */
#suggest-chat-checkbox {
position: relative; position: relative;
top: 4px; top: 4px;
} }
@ -411,3 +416,43 @@ details.evidence {
tbody:not(.row_odd) { tbody:not(.row_odd) {
background: var(--table-even-background-fill); background: var(--table-even-background-fill);
} }
#chat-suggestion {
max-height: 350px;
}
#chat-suggestion table {
overflow: hidden;
}
#chat-suggestion table thead {
display: none;
}
#paper-suggestion table {
overflow: hidden;
}
svg.markmap {
width: 100%;
height: 100%;
font-family: Quicksand, sans-serif;
font-size: 15px;
}
div.markmap {
height: 400px;
}
#google-login {
max-width: 450px;
}
#user-api-key-wrapper {
max-width: 450px;
}
#login-row {
display: grid;
place-items: center;
}

View File

@ -11,10 +11,25 @@ function run() {
version_node.style = "position: fixed; top: 10px; right: 10px;"; version_node.style = "position: fixed; top: 10px; right: 10px;";
main_parent.appendChild(version_node); main_parent.appendChild(version_node);
// add favicon
const favicon = document.createElement("link");
// set favicon attributes
favicon.rel = "icon";
favicon.type = "image/svg+xml";
favicon.href = "/favicon.ico";
document.head.appendChild(favicon);
// setup conversation dropdown placeholder
let conv_dropdown = document.querySelector("#conversation-dropdown input");
conv_dropdown.placeholder = "Browse conversation";
// move info-expand-button // move info-expand-button
let info_expand_button = document.getElementById("info-expand-button"); let info_expand_button = document.getElementById("info-expand-button");
let chat_info_panel = document.getElementById("info-expand"); let chat_info_panel = document.getElementById("info-expand");
chat_info_panel.insertBefore(info_expand_button, chat_info_panel.childNodes[2]); chat_info_panel.insertBefore(
info_expand_button,
chat_info_panel.childNodes[2]
);
// move toggle-side-bar button // move toggle-side-bar button
let chat_expand_button = document.getElementById("chat-expand-button"); let chat_expand_button = document.getElementById("chat-expand-button");
@ -24,22 +39,24 @@ function run() {
// move setting close button // move setting close button
let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav"); let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav");
let setting_close_button = document.getElementById("save-setting-btn"); let setting_close_button = document.getElementById("save-setting-btn");
if (setting_close_button) {
setting_tab_nav_bar.appendChild(setting_close_button); setting_tab_nav_bar.appendChild(setting_close_button);
}
let default_conv_column_min_width = "min(300px, 100%)"; let default_conv_column_min_width = "min(300px, 100%)";
conv_column.style.minWidth = default_conv_column_min_width conv_column.style.minWidth = default_conv_column_min_width;
globalThis.toggleChatColumn = (() => { globalThis.toggleChatColumn = () => {
/* get flex-grow value of chat_column */ /* get flex-grow value of chat_column */
let flex_grow = conv_column.style.flexGrow; let flex_grow = conv_column.style.flexGrow;
if (flex_grow == '0') { if (flex_grow == "0") {
conv_column.style.flexGrow = '1'; conv_column.style.flexGrow = "1";
conv_column.style.minWidth = default_conv_column_min_width; conv_column.style.minWidth = default_conv_column_min_width;
} else { } else {
conv_column.style.flexGrow = '0'; conv_column.style.flexGrow = "0";
conv_column.style.minWidth = "0px"; conv_column.style.minWidth = "0px";
} }
}); };
chat_column.insertBefore(chat_expand_button, chat_column.firstChild); chat_column.insertBefore(chat_expand_button, chat_column.firstChild);
@ -47,22 +64,34 @@ function run() {
let mindmap_checkbox = document.getElementById("use-mindmap-checkbox"); let mindmap_checkbox = document.getElementById("use-mindmap-checkbox");
let citation_dropdown = document.getElementById("citation-dropdown"); let citation_dropdown = document.getElementById("citation-dropdown");
let chat_setting_panel = document.getElementById("chat-settings-expand"); let chat_setting_panel = document.getElementById("chat-settings-expand");
chat_setting_panel.insertBefore(mindmap_checkbox, chat_setting_panel.childNodes[2]); chat_setting_panel.insertBefore(
mindmap_checkbox,
chat_setting_panel.childNodes[2]
);
chat_setting_panel.insertBefore(citation_dropdown, mindmap_checkbox); chat_setting_panel.insertBefore(citation_dropdown, mindmap_checkbox);
// move share conv checkbox
let report_div = document.querySelector(
"#report-accordion > div:nth-child(3) > div:nth-child(1)"
);
let share_conv_checkbox = document.getElementById("is-public-checkbox");
if (share_conv_checkbox) {
report_div.insertBefore(share_conv_checkbox, report_div.querySelector("button"));
}
// create slider toggle // create slider toggle
const is_public_checkbox = document.getElementById("is-public-checkbox"); const is_public_checkbox = document.getElementById("suggest-chat-checkbox");
const label_element = is_public_checkbox.getElementsByTagName("label")[0]; const label_element = is_public_checkbox.getElementsByTagName("label")[0];
const checkbox_span = is_public_checkbox.getElementsByTagName("span")[0]; const checkbox_span = is_public_checkbox.getElementsByTagName("span")[0];
new_div = document.createElement("div"); new_div = document.createElement("div");
label_element.classList.add("switch"); label_element.classList.add("switch");
is_public_checkbox.appendChild(checkbox_span); is_public_checkbox.appendChild(checkbox_span);
label_element.appendChild(new_div) label_element.appendChild(new_div);
// clpse // clpse
globalThis.clpseFn = (id) => { globalThis.clpseFn = (id) => {
var obj = document.getElementById('clpse-btn-' + id); var obj = document.getElementById("clpse-btn-" + id);
obj.classList.toggle("clpse-active"); obj.classList.toggle("clpse-active");
var content = obj.nextElementSibling; var content = obj.nextElementSibling;
if (content.style.display === "none") { if (content.style.display === "none") {
@ -70,29 +99,29 @@ function run() {
} else { } else {
content.style.display = "none"; content.style.display = "none";
} }
} };
// store info in local storage // store info in local storage
globalThis.setStorage = (key, value) => { globalThis.setStorage = (key, value) => {
localStorage.setItem(key, value) localStorage.setItem(key, value);
} };
globalThis.getStorage = (key, value) => { globalThis.getStorage = (key, value) => {
item = localStorage.getItem(key); item = localStorage.getItem(key);
return item ? item : value; return item ? item : value;
} };
globalThis.removeFromStorage = (key) => { globalThis.removeFromStorage = (key) => {
localStorage.removeItem(key) localStorage.removeItem(key);
} };
// Function to scroll to given citation with ID // Function to scroll to given citation with ID
// Sleep function using Promise and setTimeout // Sleep function using Promise and setTimeout
function sleep(ms) { function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms)); return new Promise((resolve) => setTimeout(resolve, ms));
} }
globalThis.scrollToCitation = async (event) => { globalThis.scrollToCitation = async (event) => {
event.preventDefault(); // Prevent the default link behavior event.preventDefault(); // Prevent the default link behavior
var citationId = event.target.getAttribute('id'); var citationId = event.target.getAttribute("id");
await sleep(100); // Sleep for 100 milliseconds await sleep(100); // Sleep for 100 milliseconds
@ -110,8 +139,148 @@ function run() {
detail_elem.getElementsByClassName("pdf-link").item(0).click(); detail_elem.getElementsByClassName("pdf-link").item(0).click();
} else { } else {
if (citation) { if (citation) {
citation.scrollIntoView({ behavior: 'smooth' }); citation.scrollIntoView({ behavior: "smooth" });
}
}
};
globalThis.fullTextSearch = () => {
// Assign text selection event to last bot message
var bot_messages = document.querySelectorAll(
"div#main-chat-bot div.message-row.bot-row"
);
var last_bot_message = bot_messages[bot_messages.length - 1];
// check if the last bot message has class "text_selection"
if (last_bot_message.classList.contains("text_selection")) {
return;
}
// assign new class to last message
last_bot_message.classList.add("text_selection");
// Get sentences from evidence div
var evidences = document.querySelectorAll(
"#html-info-panel > div:last-child > div > details.evidence div.evidence-content"
);
console.log("Indexing evidences", evidences);
const segmenterEn = new Intl.Segmenter("en", { granularity: "sentence" });
// Split sentences and save to all_segments list
var all_segments = [];
for (var evidence of evidences) {
// check if <details> tag is open
if (!evidence.parentElement.open) {
continue;
}
var markmap_div = evidence.querySelector("div.markmap");
if (markmap_div) {
continue;
}
var evidence_content = evidence.textContent.replace(/[\r\n]+/g, " ");
sentence_it = segmenterEn.segment(evidence_content)[Symbol.iterator]();
while ((sentence = sentence_it.next().value)) {
segment = sentence.segment.trim();
if (segment) {
all_segments.push({
id: all_segments.length,
text: segment,
});
}
}
}
let miniSearch = new MiniSearch({
fields: ["text"], // fields to index for full-text search
storeFields: ["text"],
});
// Index all documents
miniSearch.addAll(all_segments);
last_bot_message.addEventListener("mouseup", () => {
let selection = window.getSelection().toString();
let results = miniSearch.search(selection);
if (results.length == 0) {
return;
}
let matched_text = results[0].text;
console.log("query\n", selection, "\nmatched text\n", matched_text);
var evidences = document.querySelectorAll(
"#html-info-panel > div:last-child > div > details.evidence div.evidence-content"
);
// check if modal is open
var modal = document.getElementById("pdf-modal");
// convert all <mark> in evidences to normal text
evidences.forEach((evidence) => {
evidence.querySelectorAll("mark").forEach((mark) => {
mark.outerHTML = mark.innerText;
});
});
// highlight matched_text in evidences
for (var evidence of evidences) {
var evidence_content = evidence.textContent.replace(/[\r\n]+/g, " ");
if (evidence_content.includes(matched_text)) {
// select all p and li elements
paragraphs = evidence.querySelectorAll("p, li");
for (var p of paragraphs) {
var p_content = p.textContent.replace(/[\r\n]+/g, " ");
if (p_content.includes(matched_text)) {
p.innerHTML = p_content.replace(
matched_text,
"<mark>" + matched_text + "</mark>"
);
console.log("highlighted", matched_text, "in", p);
if (modal.style.display == "block") {
// trigger on click event of PDF Preview link
var detail_elem = p;
// traverse up the DOM tree to find the parent element with tag detail
while (detail_elem.tagName.toLowerCase() != "details") {
detail_elem = detail_elem.parentElement;
}
detail_elem.getElementsByClassName("pdf-link").item(0).click();
} else {
p.scrollIntoView({ behavior: "smooth", block: "center" });
}
break;
} }
} }
} }
} }
});
};
globalThis.spawnDocument = (content, options) => {
let opt = {
window: "",
closeChild: true,
childId: "_blank",
};
Object.assign(opt, options);
// minimal error checking
if (
content &&
typeof content.toString == "function" &&
content.toString().length
) {
let child = window.open("", opt.childId, opt.window);
child.document.write(content.toString());
if (opt.closeChild) child.document.close();
return child;
}
};
globalThis.fillChatInput = (event) => {
let chatInput = document.querySelector("#chat-input textarea");
// fill the chat input with the clicked div text
chatInput.value = "Explain " + event.target.textContent;
var evt = new Event("change");
chatInput.dispatchEvent(new Event("input", { bubbles: true }));
chatInput.focus();
};
}

View File

@ -17,7 +17,7 @@ function onBlockLoad () {
<span class="close" id="modal-expand">&#x26F6;</span> <span class="close" id="modal-expand">&#x26F6;</span>
</div> </div>
<div class="modal-body"> <div class="modal-body">
<pdfjs-viewer-element id="pdf-viewer" viewer-path="GRADIO_ROOT_PATH/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true"> <pdfjs-viewer-element id="pdf-viewer" viewer-path="GR_FILE_ROOT_PATH/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
</pdfjs-viewer-element> </pdfjs-viewer-element>
</div> </div>
</div> </div>
@ -51,28 +51,65 @@ function onBlockLoad () {
modal.style.height = "85dvh"; modal.style.height = "85dvh";
} }
}; };
};
function matchRatio(str1, str2) {
let n = str1.length;
let m = str2.length;
let lcs = [];
for (let i = 0; i <= n; i++) {
lcs[i] = [];
for (let j = 0; j <= m; j++) {
lcs[i][j] = 0;
}
} }
globalThis.compareText = (search_phrase, page_label) => { let result = "";
var iframe = document.querySelector("#pdf-viewer").iframe; let max = 0;
var innerDoc = (iframe.contentDocument) ? iframe.contentDocument : iframe.contentWindow.document; for (let i = 0; i < n; i++) {
for (let j = 0; j < m; j++) {
if (str1[i] === str2[j]) {
lcs[i + 1][j + 1] = lcs[i][j] + 1;
if (lcs[i + 1][j + 1] > max) {
max = lcs[i + 1][j + 1];
result = str1.substring(i - max + 1, i + 1);
}
}
}
}
var query_selector = ( return result.length / Math.min(n, m);
}
globalThis.compareText = (search_phrases, page_label) => {
var iframe = document.querySelector("#pdf-viewer").iframe;
var innerDoc = iframe.contentDocument
? iframe.contentDocument
: iframe.contentWindow.document;
var renderedPages = innerDoc.querySelectorAll("div#viewer div.page");
if (renderedPages.length == 0) {
// if pages are not rendered yet, wait and try again
setTimeout(() => compareText(search_phrases, page_label), 2000);
return;
}
var query_selector =
"#viewer > div[data-page-number='" + "#viewer > div[data-page-number='" +
page_label + page_label +
"'] > div.textLayer > span" "'] > div.textLayer > span";
);
var page_spans = innerDoc.querySelectorAll(query_selector); var page_spans = innerDoc.querySelectorAll(query_selector);
for (var i = 0; i < page_spans.length; i++) { for (var i = 0; i < page_spans.length; i++) {
var span = page_spans[i]; var span = page_spans[i];
if ( if (
span.textContent.length > 4 && span.textContent.length > 4 &&
( search_phrases.some(
search_phrase.includes(span.textContent) || (phrase) => matchRatio(phrase, span.textContent) > 0.5
span.textContent.includes(search_phrase)
) )
) { ) {
span.innerHTML = "<span class='highlight selected'>" + span.textContent + "</span>"; span.innerHTML =
"<span class='highlight selected'>" + span.textContent + "</span>";
} else { } else {
// if span is already highlighted, remove it // if span is already highlighted, remove it
if (span.querySelector(".highlight")) { if (span.querySelector(".highlight")) {
@ -80,11 +117,11 @@ function onBlockLoad () {
} }
} }
} }
} };
// Sleep function using Promise and setTimeout // Sleep function using Promise and setTimeout
function sleep(ms) { function sleep(ms) {
return new Promise(resolve => setTimeout(resolve, ms)); return new Promise((resolve) => setTimeout(resolve, ms));
} }
// Function to open modal and display PDF // Function to open modal and display PDF
@ -94,7 +131,19 @@ function onBlockLoad () {
var src = target.getAttribute("data-src"); var src = target.getAttribute("data-src");
var page = target.getAttribute("data-page"); var page = target.getAttribute("data-page");
var search = target.getAttribute("data-search"); var search = target.getAttribute("data-search");
var phrase = target.getAttribute("data-phrase"); var highlighted_spans =
target.parentElement.parentElement.querySelectorAll("mark");
// Get text from highlighted spans
var search_phrases = Array.from(highlighted_spans).map(
(span) => span.textContent
);
// Use regex to strip 【id】from search phrases
search_phrases = search_phrases.map((phrase) =>
phrase.replace(/【\d+】/g, "")
);
// var phrase = target.getAttribute("data-phrase");
var pdfViewer = document.getElementById("pdf-viewer"); var pdfViewer = document.getElementById("pdf-viewer");
@ -109,7 +158,7 @@ function onBlockLoad () {
var scrollableDiv = document.getElementById("chat-info-panel"); var scrollableDiv = document.getElementById("chat-info-panel");
infor_panel_scroll_pos = scrollableDiv.scrollTop; infor_panel_scroll_pos = scrollableDiv.scrollTop;
var modal = document.getElementById("pdf-modal") var modal = document.getElementById("pdf-modal");
modal.style.display = "block"; modal.style.display = "block";
var info_panel = document.getElementById("html-info-panel"); var info_panel = document.getElementById("html-info-panel");
if (info_panel) { if (info_panel) {
@ -119,8 +168,8 @@ function onBlockLoad () {
/* search for text inside PDF page */ /* search for text inside PDF page */
await sleep(500); await sleep(500);
compareText(search, page); compareText(search_phrases, page);
} };
globalThis.assignPdfOnclickEvent = () => { globalThis.assignPdfOnclickEvent = () => {
// Get all links and attach click event // Get all links and attach click event
@ -128,11 +177,10 @@ function onBlockLoad () {
for (var i = 0; i < links.length; i++) { for (var i = 0; i < links.length; i++) {
links[i].onclick = openModal; links[i].onclick = openModal;
} }
} };
var created_modal = document.getElementById("pdf-viewer"); var created_modal = document.getElementById("pdf-viewer");
if (!created_modal) { if (!created_modal) {
createModal(); createModal();
} }
} }

View File

@ -29,7 +29,7 @@ class BaseConversation(SQLModel):
datetime.datetime.now(get_localzone()).strftime("%Y-%m-%d %H:%M:%S") datetime.datetime.now(get_localzone()).strftime("%Y-%m-%d %H:%M:%S")
) )
) )
user: int = Field(default=0) # For now we only have one user user: str = Field(default="") # For now we only have one user
is_public: bool = Field(default=False) is_public: bool = Field(default=False)
@ -55,7 +55,9 @@ class BaseUser(SQLModel):
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
id: Optional[int] = Field(default=None, primary_key=True) id: str = Field(
default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
)
username: str = Field(unique=True) username: str = Field(unique=True)
username_lower: str = Field(unique=True) username_lower: str = Field(unique=True)
password: str password: str
@ -76,7 +78,7 @@ class BaseSettings(SQLModel):
id: str = Field( id: str = Field(
default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
) )
user: int = Field(default=0) user: str = Field(default="")
setting: dict = Field(default={}, sa_column=Column(JSON)) setting: dict = Field(default={}, sa_column=Column(JSON))
@ -97,4 +99,4 @@ class BaseIssueReport(SQLModel):
issues: dict = Field(default={}, sa_column=Column(JSON)) issues: dict = Field(default={}, sa_column=Column(JSON))
chat: Optional[dict] = Field(default=None, sa_column=Column(JSON)) chat: Optional[dict] = Field(default=None, sa_column=Column(JSON))
settings: Optional[dict] = Field(default=None, sa_column=Column(JSON)) settings: Optional[dict] = Field(default=None, sa_column=Column(JSON))
user: Optional[int] = Field(default=None) user: Optional[str] = Field(default=None)

View File

@ -17,6 +17,10 @@ from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
def generate_uuid():
return str(uuid.uuid4())
class FileIndex(BaseIndex): class FileIndex(BaseIndex):
""" """
File index to store and allow retrieval of files File index to store and allow retrieval of files
@ -76,7 +80,7 @@ class FileIndex(BaseIndex):
"date_created": Column( "date_created": Column(
DateTime(timezone=True), default=datetime.now(get_localzone()) DateTime(timezone=True), default=datetime.now(get_localzone())
), ),
"user": Column(Integer, default=1), "user": Column(String, default=""),
"note": Column( "note": Column(
MutableDict.as_mutable(JSON), # type: ignore MutableDict.as_mutable(JSON), # type: ignore
default={}, default={},
@ -101,7 +105,7 @@ class FileIndex(BaseIndex):
"date_created": Column( "date_created": Column(
DateTime(timezone=True), default=datetime.now(get_localzone()) DateTime(timezone=True), default=datetime.now(get_localzone())
), ),
"user": Column(Integer, default=1), "user": Column(String, default=""),
"note": Column( "note": Column(
MutableDict.as_mutable(JSON), # type: ignore MutableDict.as_mutable(JSON), # type: ignore
default={}, default={},
@ -117,7 +121,7 @@ class FileIndex(BaseIndex):
"source_id": Column(String), "source_id": Column(String),
"target_id": Column(String), "target_id": Column(String),
"relation_type": Column(String), "relation_type": Column(String),
"user": Column(Integer, default=1), "user": Column(String, default=""),
}, },
) )
FileGroup = type( FileGroup = type(
@ -125,12 +129,20 @@ class FileIndex(BaseIndex):
(Base,), (Base,),
{ {
"__tablename__": f"index__{self.id}__group", "__tablename__": f"index__{self.id}__group",
"id": Column(Integer, primary_key=True, autoincrement=True), "__table_args__": (
UniqueConstraint("name", "user", name="_name_user_uc"),
),
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"date_created": Column( "date_created": Column(
DateTime(timezone=True), default=datetime.now(get_localzone()) DateTime(timezone=True), default=datetime.now(get_localzone())
), ),
"name": Column(String, unique=True), "name": Column(String),
"user": Column(Integer, default=1), "user": Column(String, default=""),
"data": Column( "data": Column(
MutableDict.as_mutable(JSON), # type: ignore MutableDict.as_mutable(JSON), # type: ignore
default={"files": []}, default={"files": []},

View File

@ -20,9 +20,14 @@ from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
from ...utils.commands import WEB_SEARCH_COMMAND from ...utils.commands import WEB_SEARCH_COMMAND
from ...utils.rate_limit import check_rate_limit
from .utils import download_arxiv_pdf, is_arxiv_url
DOWNLOAD_MESSAGE = "Press again to download" KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
DOWNLOAD_MESSAGE = "Start download"
MAX_FILENAME_LENGTH = 20 MAX_FILENAME_LENGTH = 20
MAX_FILE_COUNT = 200
chat_input_focus_js = """ chat_input_focus_js = """
function() { function() {
@ -31,6 +36,15 @@ function() {
} }
""" """
chat_input_focus_js_with_submit = """
function() {
let chatInput = document.querySelector("#chat-input textarea");
let chatInputSubmit = document.querySelector("#chat-input button.submit-button");
chatInputSubmit.click();
chatInput.focus();
}
"""
update_file_list_js = """ update_file_list_js = """
function(file_list) { function(file_list) {
var values = []; var values = [];
@ -53,6 +67,7 @@ function(file_list) {
allowSpaces: true, allowSpaces: true,
}) })
input_box = document.querySelector('#chat-input textarea'); input_box = document.querySelector('#chat-input textarea');
tribute.detach(input_box);
tribute.attach(input_box); tribute.attach(input_box);
} }
""".replace( """.replace(
@ -128,6 +143,8 @@ class FileIndexPage(BasePage):
# TODO: on_building_ui is not correctly named if it's always called in # TODO: on_building_ui is not correctly named if it's always called in
# the constructor # the constructor
self.public_events = [f"onFileIndex{index.id}Changed"] self.public_events = [f"onFileIndex{index.id}Changed"]
if not KH_DEMO_MODE:
self.on_building_ui() self.on_building_ui()
def upload_instruction(self) -> str: def upload_instruction(self) -> str:
@ -201,9 +218,9 @@ class FileIndexPage(BasePage):
with gr.Accordion("Advance options", open=False): with gr.Accordion("Advance options", open=False):
with gr.Row(): with gr.Row():
if not KH_SSO_ENABLED:
self.download_all_button = gr.DownloadButton( self.download_all_button = gr.DownloadButton(
"Download all files", "Download all files",
visible=True,
) )
self.delete_all_button = gr.Button( self.delete_all_button = gr.Button(
"Delete all files", "Delete all files",
@ -249,13 +266,13 @@ class FileIndexPage(BasePage):
) )
with gr.Column(visible=False) as self._group_info_panel: with gr.Column(visible=False) as self._group_info_panel:
self.selected_group_id = gr.State(value=None)
self.group_label = gr.Markdown() self.group_label = gr.Markdown()
self.group_name = gr.Textbox( self.group_name = gr.Textbox(
label="Group name", label="Group name",
placeholder="Group name", placeholder="Group name",
lines=1, lines=1,
max_lines=1, max_lines=1,
interactive=False,
) )
self.group_files = gr.Dropdown( self.group_files = gr.Dropdown(
label="Attached files", label="Attached files",
@ -290,7 +307,7 @@ class FileIndexPage(BasePage):
) )
gr.Markdown("(separated by new line)") gr.Markdown("(separated by new line)")
with gr.Accordion("Advanced indexing options", open=True): with gr.Accordion("Advanced indexing options", open=False):
with gr.Row(): with gr.Row():
self.reindex = gr.Checkbox( self.reindex = gr.Checkbox(
value=False, label="Force reindex file", container=False value=False, label="Force reindex file", container=False
@ -324,6 +341,9 @@ class FileIndexPage(BasePage):
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app""" """Subscribe to the declared public event of the app"""
if KH_DEMO_MODE:
return
self._app.subscribe_event( self._app.subscribe_event(
name=f"onFileIndex{self._index.id}Changed", name=f"onFileIndex{self._index.id}Changed",
definition={ definition={
@ -500,6 +520,34 @@ class FileIndexPage(BasePage):
return not is_zipped_state, new_button return not is_zipped_state, new_button
def download_single_file_simple(self, is_zipped_state, file_html, file_id):
with Session(engine) as session:
source = session.execute(
select(self._index._resources["Source"]).where(
self._index._resources["Source"].id == file_id
)
).first()
if source:
target_file_name = Path(source[0].name)
# create a temporary file with a path to export
output_file_path = os.path.join(
flowsettings.KH_ZIP_OUTPUT_DIR, target_file_name.stem + ".html"
)
with open(output_file_path, "w") as f:
f.write(file_html)
if is_zipped_state:
new_button = gr.DownloadButton(label="Download", value=None)
else:
# export the file path
new_button = gr.DownloadButton(
label=DOWNLOAD_MESSAGE,
value=output_file_path,
)
return not is_zipped_state, new_button
def download_all_files(self): def download_all_files(self):
if self._index.config.get("private", False): if self._index.config.get("private", False):
raise gr.Error("This feature is not available for private collection.") raise gr.Error("This feature is not available for private collection.")
@ -543,8 +591,145 @@ class FileIndexPage(BasePage):
gr.update(visible=True), gr.update(visible=True),
] ]
def on_register_quick_uploads(self):
try:
# quick file upload event registration of first Index only
if self._index.id == 1:
self.quick_upload_state = gr.State(value=[])
print("Setting up quick upload event")
# override indexing function from chat page
self._app.chat_page.first_indexing_url_fn = (
self.index_fn_url_with_default_loaders
)
if not KH_DEMO_MODE:
quickUploadedEvent = (
self._app.chat_page.quick_file_upload.upload(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_file_with_default_loaders,
inputs=[
self._app.chat_page.quick_file_upload,
gr.State(value=False),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
concurrency_limit=10,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_file_upload,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(
f"onFileIndex{self._index.id}Changed"
):
quickUploadedEvent = quickUploadedEvent.then(**event)
quickUploadedEvent = (
quickUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
)
.then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
.then(
fn=lambda: True,
inputs=None,
outputs=None,
js=chat_input_focus_js_with_submit,
)
)
quickURLUploadedEvent = (
self._app.chat_page.quick_urls.submit(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_url_with_default_loaders,
inputs=[
self._app.chat_page.quick_urls,
gr.State(value=False),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
concurrency_limit=10,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_urls,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
quickURLUploadedEvent = quickURLUploadedEvent.then(**event)
quickURLUploadedEvent = quickURLUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
).then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
)
if not KH_DEMO_MODE:
quickURLUploadedEvent = quickURLUploadedEvent.then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
quickURLUploadedEvent = quickURLUploadedEvent.then(
fn=lambda: True,
inputs=None,
outputs=None,
js=chat_input_focus_js_with_submit,
)
except Exception as e:
print(e)
def on_register_events(self): def on_register_events(self):
"""Register all events to the app""" """Register all events to the app"""
self.on_register_quick_uploads()
if KH_DEMO_MODE:
return
onDeleted = ( onDeleted = (
self.delete_button.click( self.delete_button.click(
fn=self.delete_event, fn=self.delete_event,
@ -606,6 +791,7 @@ class FileIndexPage(BasePage):
], ],
) )
if not KH_SSO_ENABLED:
self.download_all_button.click( self.download_all_button.click(
fn=self.download_all_files, fn=self.download_all_files,
inputs=[], inputs=[],
@ -659,12 +845,20 @@ class FileIndexPage(BasePage):
], ],
) )
if not KH_SSO_ENABLED:
self.download_single_button.click( self.download_single_button.click(
fn=self.download_single_file, fn=self.download_single_file,
inputs=[self.is_zipped_state, self.selected_file_id], inputs=[self.is_zipped_state, self.selected_file_id],
outputs=[self.is_zipped_state, self.download_single_button], outputs=[self.is_zipped_state, self.download_single_button],
show_progress="hidden", show_progress="hidden",
) )
else:
self.download_single_button.click(
fn=self.download_single_file_simple,
inputs=[self.is_zipped_state, self.chunks, self.selected_file_id],
outputs=[self.is_zipped_state, self.download_single_button],
show_progress="hidden",
)
onUploaded = ( onUploaded = (
self.upload_button.click( self.upload_button.click(
@ -689,121 +883,6 @@ class FileIndexPage(BasePage):
) )
) )
try:
# quick file upload event registration of first Index only
if self._index.id == 1:
self.quick_upload_state = gr.State(value=[])
print("Setting up quick upload event")
# override indexing function from chat page
self._app.chat_page.first_indexing_url_fn = (
self.index_fn_url_with_default_loaders
)
quickUploadedEvent = (
self._app.chat_page.quick_file_upload.upload(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_file_with_default_loaders,
inputs=[
self._app.chat_page.quick_file_upload,
gr.State(value=False),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_file_upload,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
quickUploadedEvent = quickUploadedEvent.then(**event)
quickURLUploadedEvent = (
self._app.chat_page.quick_urls.submit(
fn=lambda: gr.update(
value="Please wait for the indexing process "
"to complete before adding your question."
),
outputs=self._app.chat_page.quick_file_upload_status,
)
.then(
fn=self.index_fn_url_with_default_loaders,
inputs=[
self._app.chat_page.quick_urls,
gr.State(value=True),
self._app.settings_state,
self._app.user_id,
],
outputs=self.quick_upload_state,
)
.success(
fn=lambda: [
gr.update(value=None),
gr.update(value="select"),
],
outputs=[
self._app.chat_page.quick_urls,
self._app.chat_page._indices_input[0],
],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
quickURLUploadedEvent = quickURLUploadedEvent.then(**event)
quickUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
).then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
).then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
).then(
fn=lambda: True,
inputs=None,
outputs=None,
js=chat_input_focus_js,
)
quickURLUploadedEvent.success(
fn=lambda x: x,
inputs=self.quick_upload_state,
outputs=self._app.chat_page._indices_input[1],
).then(
fn=lambda: gr.update(value="Indexing completed."),
outputs=self._app.chat_page.quick_file_upload_status,
).then(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
).then(
fn=lambda: True,
inputs=None,
outputs=None,
js=chat_input_focus_js,
)
except Exception as e:
print(e)
uploadedEvent = onUploaded.then( uploadedEvent = onUploaded.then(
fn=self.list_file, fn=self.list_file,
inputs=[self._app.user_id, self.filter], inputs=[self._app.user_id, self.filter],
@ -844,7 +923,12 @@ class FileIndexPage(BasePage):
self.group_list.select( self.group_list.select(
fn=self.interact_group_list, fn=self.interact_group_list,
inputs=[self.group_list_state], inputs=[self.group_list_state],
outputs=[self.group_label, self.group_name, self.group_files], outputs=[
self.group_label,
self.selected_group_id,
self.group_name,
self.group_files,
],
show_progress="hidden", show_progress="hidden",
).then( ).then(
fn=lambda: ( fn=lambda: (
@ -875,8 +959,9 @@ class FileIndexPage(BasePage):
gr.update(visible=False), gr.update(visible=False),
gr.update(value="### Add new group"), gr.update(value="### Add new group"),
gr.update(visible=True), gr.update(visible=True),
gr.update(value="", interactive=True), gr.update(value=""),
gr.update(value=[]), gr.update(value=[]),
None,
], ],
outputs=[ outputs=[
self.group_add_button, self.group_add_button,
@ -884,12 +969,13 @@ class FileIndexPage(BasePage):
self._group_info_panel, self._group_info_panel,
self.group_name, self.group_name,
self.group_files, self.group_files,
self.selected_group_id,
], ],
) )
self.group_chat_button.click( self.group_chat_button.click(
fn=self.set_group_id_selector, fn=self.set_group_id_selector,
inputs=[self.group_name], inputs=[self.selected_group_id],
outputs=[ outputs=[
self._index.get_selector_component_ui().selector, self._index.get_selector_component_ui().selector,
self._index.get_selector_component_ui().mode, self._index.get_selector_component_ui().mode,
@ -897,45 +983,54 @@ class FileIndexPage(BasePage):
], ],
) )
onGroupSaved = ( onGroupClosedEvent = {
self.group_save_button.click( "fn": lambda: [
fn=self.save_group,
inputs=[self.group_name, self.group_files, self._app.user_id],
)
.then(
self.list_group,
inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list],
)
.then(
fn=lambda: gr.update(visible=False),
outputs=[self._group_info_panel],
)
)
self.group_close_button.click(
fn=lambda: [
gr.update(visible=True), gr.update(visible=True),
gr.update(visible=False), gr.update(visible=False),
gr.update(visible=False), gr.update(visible=False),
gr.update(visible=False), gr.update(visible=False),
gr.update(visible=False), gr.update(visible=False),
None,
], ],
outputs=[ "outputs": [
self.group_add_button, self.group_add_button,
self._group_info_panel, self._group_info_panel,
self.group_close_button, self.group_close_button,
self.group_delete_button, self.group_delete_button,
self.group_chat_button, self.group_chat_button,
self.selected_group_id,
],
}
self.group_close_button.click(**onGroupClosedEvent)
onGroupSaved = (
self.group_save_button.click(
fn=self.save_group,
inputs=[
self.selected_group_id,
self.group_name,
self.group_files,
self._app.user_id,
], ],
) )
onGroupDeleted = self.group_delete_button.click( .then(
fn=self.delete_group,
inputs=[self.group_name],
).then(
self.list_group, self.list_group,
inputs=[self._app.user_id, self.file_list_state], inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list], outputs=[self.group_list_state, self.group_list],
) )
.then(**onGroupClosedEvent)
)
onGroupDeleted = (
self.group_delete_button.click(
fn=self.delete_group,
inputs=[self.selected_group_id],
)
.then(
self.list_group,
inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list],
)
.then(**onGroupClosedEvent)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onGroupDeleted = onGroupDeleted.then(**event) onGroupDeleted = onGroupDeleted.then(**event)
@ -943,10 +1038,21 @@ class FileIndexPage(BasePage):
def _on_app_created(self): def _on_app_created(self):
"""Called when the app is created""" """Called when the app is created"""
if KH_DEMO_MODE:
return
self._app.app.load( self._app.app.load(
self.list_file, self.list_file,
inputs=[self._app.user_id, self.filter], inputs=[self._app.user_id, self.filter],
outputs=[self.file_list_state, self.file_list], outputs=[self.file_list_state, self.file_list],
).then(
self.list_group,
inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list],
).then(
self.list_file_names,
inputs=[self.file_list_state],
outputs=[self.group_files],
) )
def _may_extract_zip(self, files, zip_dir: str): def _may_extract_zip(self, files, zip_dir: str):
@ -1089,12 +1195,60 @@ class FileIndexPage(BasePage):
return exist_ids + returned_ids return exist_ids + returned_ids
def index_fn_url_with_default_loaders(self, urls, reindex: bool, settings, user_id): def index_fn_url_with_default_loaders(
returned_ids = [] self,
urls,
reindex: bool,
settings,
user_id,
request: gr.Request,
):
if KH_DEMO_MODE:
check_rate_limit("file_upload", request)
returned_ids: list[str] = []
settings = deepcopy(settings) settings = deepcopy(settings)
settings[f"index.options.{self._index.id}.reader_mode"] = "default" settings[f"index.options.{self._index.id}.reader_mode"] = "default"
settings[f"index.options.{self._index.id}.quick_index_mode"] = True settings[f"index.options.{self._index.id}.quick_index_mode"] = True
if KH_DEMO_MODE:
urls_splitted = urls.split("\n")
if not all(is_arxiv_url(url) for url in urls_splitted):
raise ValueError("All URLs must be valid arXiv URLs")
output_files = [
download_arxiv_pdf(
url,
output_path=os.environ.get("GRADIO_TEMP_DIR", "/tmp"),
)
for url in urls_splitted
]
exist_ids = []
to_process_files = []
for str_file_path in output_files:
file_path = Path(str_file_path)
exist_id = (
self._index.get_indexing_pipeline(settings, user_id)
.route(file_path)
.get_id_if_exists(file_path)
)
if exist_id:
exist_ids.append(exist_id)
else:
to_process_files.append(str_file_path)
returned_ids = []
if to_process_files:
_iter = self.index_fn(to_process_files, [], reindex, settings, user_id)
try:
while next(_iter):
pass
except StopIteration as e:
returned_ids = e.value
returned_ids = exist_ids + returned_ids
else:
if urls: if urls:
_iter = self.index_fn([], urls, reindex, settings, user_id) _iter = self.index_fn([], urls, reindex, settings, user_id)
try: try:
@ -1254,6 +1408,7 @@ class FileIndexPage(BasePage):
return gr.update(choices=file_names) return gr.update(choices=file_names)
def list_group(self, user_id, file_list): def list_group(self, user_id, file_list):
# supply file_list to display the file names in the group
if file_list: if file_list:
file_id_to_name = {item["id"]: item["name"] for item in file_list} file_id_to_name = {item["id"]: item["name"] for item in file_list}
else: else:
@ -1319,27 +1474,42 @@ class FileIndexPage(BasePage):
return results, group_list return results, group_list
def set_group_id_selector(self, selected_group_name): def set_group_id_selector(self, selected_group_id):
FileGroup = self._index._resources["FileGroup"] FileGroup = self._index._resources["FileGroup"]
# check if group_name exist # check if group_name exist
with Session(engine) as session: with Session(engine) as session:
current_group = ( current_group = (
session.query(FileGroup).filter_by(name=selected_group_name).first() session.query(FileGroup).filter_by(id=selected_group_id).first()
) )
file_ids = [json.dumps(current_group.data["files"])] file_ids = [json.dumps(current_group.data["files"])]
return [file_ids, "select", gr.Tabs(selected="chat-tab")] return [file_ids, "select", gr.Tabs(selected="chat-tab")]
def save_group(self, group_name, group_files, user_id): def save_group(self, group_id, group_name, group_files, user_id):
FileGroup = self._index._resources["FileGroup"] FileGroup = self._index._resources["FileGroup"]
current_group = None current_group = None
# check if group_name exist # check if group_name exist
with Session(engine) as session: with Session(engine) as session:
current_group = session.query(FileGroup).filter_by(name=group_name).first() if group_id:
current_group = session.query(FileGroup).filter_by(id=group_id).first()
# update current group with new info
current_group.name = group_name
current_group.data["files"] = group_files # Update the files
session.commit()
else:
current_group = (
session.query(FileGroup)
.filter_by(
name=group_name,
user=user_id,
)
.first()
)
if current_group:
raise gr.Error(f"Group {group_name} already exists")
if not current_group:
current_group = FileGroup( current_group = FileGroup(
name=group_name, name=group_name,
data={"files": group_files}, # type: ignore data={"files": group_files}, # type: ignore
@ -1347,34 +1517,31 @@ class FileIndexPage(BasePage):
) )
session.add(current_group) session.add(current_group)
session.commit() session.commit()
else:
# update current group with new info
current_group.name = group_name
current_group.data["files"] = group_files # Update the files
session.commit()
group_id = current_group.id group_id = current_group.id
gr.Info(f"Group {group_name} has been saved") gr.Info(f"Group {group_name} has been saved")
return group_id return group_id
def delete_group(self, group_name): def delete_group(self, group_id):
if not group_id:
raise gr.Error("No group is selected")
FileGroup = self._index._resources["FileGroup"] FileGroup = self._index._resources["FileGroup"]
group_id = None
with Session(engine) as session: with Session(engine) as session:
group = session.execute( group = session.execute(
select(FileGroup).where(FileGroup.name == group_name) select(FileGroup).where(FileGroup.id == group_id)
).first() ).first()
if group: if group:
item = group[0] item = group[0]
group_id = item.id group_name = item.name
session.delete(item) session.delete(item)
session.commit() session.commit()
gr.Info(f"Group {group_name} has been deleted") gr.Info(f"Group {group_name} has been deleted")
else: else:
raise gr.Error(f"Group {group_name} not found") raise gr.Error("No group found")
return group_id return None
def interact_file_list(self, list_files, ev: gr.SelectData): def interact_file_list(self, list_files, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0: if ev.value == "-" and ev.index[0] == 0:
@ -1394,9 +1561,11 @@ class FileIndexPage(BasePage):
raise gr.Error("No group is selected") raise gr.Error("No group is selected")
selected_item = list_groups[selected_id] selected_item = list_groups[selected_id]
selected_group_id = selected_item["id"]
return ( return (
"### Group Information", "### Group Information",
gr.update(value=selected_item["name"], interactive=False), selected_group_id,
selected_item["name"],
selected_item["files"], selected_item["files"],
) )
@ -1525,6 +1694,10 @@ class FileSelector(BasePage):
self._index._resources["Source"].user == user_id self._index._resources["Source"].user == user_id
) )
if KH_DEMO_MODE:
# limit query by MAX_FILE_COUNT
statement = statement.limit(MAX_FILE_COUNT)
results = session.execute(statement).all() results = session.execute(statement).all()
for result in results: for result in results:
available_ids.append(result[0].id) available_ids.append(result[0].id)

View File

@ -0,0 +1,58 @@
import os
import requests
# URL prefixes that identify arXiv links (abstract and PDF pages).
# NOTE: these are plain string prefixes matched with str.startswith,
# not regular expressions.
ARXIV_URL_PATTERNS = [
    "https://arxiv.org/abs/",
    "https://arxiv.org/pdf/",
]
# Characters that are not allowed in file names on common filesystems
# (Windows is the strictest; this list covers it).
ILLEGAL_NAME_CHARS = ["\\", "/", ":", "*", "?", '"', "<", ">", "|"]

# Translation table mapping every illegal character to "_", built once at
# import time so clean_name does the whole substitution in a single C-level
# pass (str.translate) instead of N chained str.replace calls.
_ILLEGAL_CHARS_TABLE = str.maketrans({char: "_" for char in ILLEGAL_NAME_CHARS})


def clean_name(name):
    """Return *name* with filesystem-illegal characters replaced by "_".

    Used to turn a scraped paper title into a safe file name.
    """
    return name.translate(_ILLEGAL_CHARS_TABLE)
def is_arxiv_url(url):
    """Return True if *url* points at an arXiv abstract or PDF page."""
    # str.startswith accepts a tuple of prefixes: one C-level call instead
    # of a Python-level any() over a generator expression.
    return url.startswith(tuple(ARXIV_URL_PATTERNS))
# download PDF from Arxiv URL
def download_arxiv_pdf(url, output_path):
    """Download the PDF of an arXiv paper and return the local file path.

    Accepts either an abstract URL (https://arxiv.org/abs/...) or a PDF URL
    (https://arxiv.org/pdf/...). The paper title, scraped from the abstract
    page, becomes the file name. If the target file already exists on disk
    the download is skipped and the existing path is returned.

    Args:
        url: arXiv abstract or PDF URL.
        output_path: directory in which to store the downloaded PDF.

    Returns:
        Path of the (possibly pre-existing) PDF file.

    Raises:
        ValueError: if the URL is not a recognized arXiv URL or the paper
            title cannot be extracted from the abstract page.
        requests.HTTPError: if either HTTP request returns an error status.
    """
    if not is_arxiv_url(url):
        raise ValueError("Invalid Arxiv URL")

    # Derive both forms of the URL: the abstract page supplies the title,
    # the pdf endpoint supplies the document. arXiv IDs are numeric, so
    # replacing the "abs"/"pdf" path segment is safe here.
    is_abstract_url = "abs" in url
    if is_abstract_url:
        pdf_url = url.replace("abs", "pdf")
        abstract_url = url
    else:
        pdf_url = url
        abstract_url = url.replace("pdf", "abs")

    # Fetch the abstract page; timeout prevents a stalled connection from
    # hanging the caller, and raise_for_status stops us from parsing an
    # error page as if it were the paper.
    response = requests.get(abstract_url, timeout=30)
    response.raise_for_status()

    # Local import: keeps bs4 an optional dependency of this module.
    from bs4 import BeautifulSoup

    # parse HTML response and get h1.title
    soup = BeautifulSoup(response.content, "html.parser")
    title_el = soup.find("h1", class_="title")
    if title_el is None:
        # raise the documented error instead of an opaque AttributeError
        raise ValueError("Failed to get paper name")
    name = clean_name(title_el.text.strip().replace("Title:", ""))

    if not name:
        raise ValueError("Failed to get paper name")

    output_file_path = os.path.join(output_path, name + ".pdf")

    # prevent downloading if file already exists (acts as an on-disk cache)
    if not os.path.exists(output_file_path):
        pdf_response = requests.get(pdf_url, timeout=60)
        pdf_response.raise_for_status()
        with open(output_file_path, "wb") as f:
            f.write(pdf_response.content)

    return output_file_path

View File

@ -60,6 +60,7 @@ class LLMManager:
LCAnthropicChat, LCAnthropicChat,
LCCohereChat, LCCohereChat,
LCGeminiChat, LCGeminiChat,
LCOllamaChat,
LlamaCppChat, LlamaCppChat,
) )
@ -69,6 +70,7 @@ class LLMManager:
LCAnthropicChat, LCAnthropicChat,
LCGeminiChat, LCGeminiChat,
LCCohereChat, LCCohereChat,
LCOllamaChat,
LlamaCppChat, LlamaCppChat,
] ]

View File

@ -9,6 +9,7 @@ from ktem.pages.setup import SetupPage
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False) KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False)
KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True) KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True)
@ -19,7 +20,7 @@ if config("KH_FIRST_SETUP", default=False, cast=bool):
def toggle_first_setup_visibility(): def toggle_first_setup_visibility():
global KH_APP_DATA_EXISTS global KH_APP_DATA_EXISTS
is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS is_first_setup = not KH_DEMO_MODE and not KH_APP_DATA_EXISTS
KH_APP_DATA_EXISTS = True KH_APP_DATA_EXISTS = True
return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup) return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup)
@ -70,7 +71,7 @@ class App(BaseApp):
"indices-tab", "indices-tab",
], ],
id="indices-tab", id="indices-tab",
visible=not self.f_user_management, visible=not self.f_user_management and not KH_DEMO_MODE,
) as self._tabs[f"{index.id}-tab"]: ) as self._tabs[f"{index.id}-tab"]:
page = index.get_index_page_ui() page = index.get_index_page_ui()
setattr(self, f"_index_{index.id}", page) setattr(self, f"_index_{index.id}", page)
@ -80,7 +81,7 @@ class App(BaseApp):
elem_id="indices-tab", elem_id="indices-tab",
elem_classes=["fill-main-area-height", "scrollable", "indices-tab"], elem_classes=["fill-main-area-height", "scrollable", "indices-tab"],
id="indices-tab", id="indices-tab",
visible=not self.f_user_management, visible=not self.f_user_management and not KH_DEMO_MODE,
) as self._tabs["indices-tab"]: ) as self._tabs["indices-tab"]:
for index in self.index_manager.indices: for index in self.index_manager.indices:
with gr.Tab( with gr.Tab(
@ -90,6 +91,8 @@ class App(BaseApp):
page = index.get_index_page_ui() page = index.get_index_page_ui()
setattr(self, f"_index_{index.id}", page) setattr(self, f"_index_{index.id}", page)
if not KH_DEMO_MODE:
if not KH_SSO_ENABLED:
with gr.Tab( with gr.Tab(
"Resources", "Resources",
elem_id="resources-tab", elem_id="resources-tab",

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import importlib
import json import json
import re import re
from copy import deepcopy from copy import deepcopy
@ -10,6 +9,7 @@ from ktem.app import BasePage
from ktem.components import reasonings from ktem.components import reasonings
from ktem.db.models import Conversation, engine from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File from ktem.index.file.ui import File
from ktem.reasoning.prompt_optimization.mindmap import MINDMAP_HTML_EXPORT_TEMPLATE
from ktem.reasoning.prompt_optimization.suggest_conversation_name import ( from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline, SuggestConvNamePipeline,
) )
@ -19,29 +19,41 @@ from ktem.reasoning.prompt_optimization.suggest_followup_chat import (
from plotly.io import from_json from plotly.io import from_json
from sqlmodel import Session, select from sqlmodel import Session, select
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
from kotaemon.base import Document from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls
from ...utils.commands import WEB_SEARCH_COMMAND from ...utils.commands import WEB_SEARCH_COMMAND
from ...utils.hf_papers import get_recommended_papers
from ...utils.rate_limit import check_rate_limit
from .chat_panel import ChatPanel from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE from .common import STATE
from .control import ConversationControl from .control import ConversationControl
from .demo_hint import HintPage
from .paper_list import PaperListPage
from .report import ReportIssue from .report import ReportIssue
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None) KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None)
WebSearch = None WebSearch = None
if KH_WEB_SEARCH_BACKEND: if KH_WEB_SEARCH_BACKEND:
try: try:
module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1) WebSearch = import_dotted_string(KH_WEB_SEARCH_BACKEND, safe=False)
module = importlib.import_module(module_name)
WebSearch = getattr(module, class_name)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
print(f"Error importing {class_name} from {module_name}: {e}") print(f"Error importing {KH_WEB_SEARCH_BACKEND}: {e}")
REASONING_LIMITS = 2 if KH_DEMO_MODE else 10
DEFAULT_SETTING = "(default)" DEFAULT_SETTING = "(default)"
INFO_PANEL_SCALES = {True: 8, False: 4} INFO_PANEL_SCALES = {True: 8, False: 4}
DEFAULT_QUESTION = (
"What is the summary of this document?"
if not KH_DEMO_MODE
else "What is the summary of this paper?"
)
chat_input_focus_js = """ chat_input_focus_js = """
function() { function() {
@ -50,8 +62,59 @@ function() {
} }
""" """
quick_urls_submit_js = """
function() {
let urlInput = document.querySelector("#quick-url-demo textarea");
console.log("URL input:", urlInput);
urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'}));
}
"""
recommended_papers_js = """
function() {
// Get all links and attach click event
var links = document.querySelectorAll("#related-papers a");
function submitPaper(event) {
event.preventDefault();
var target = event.currentTarget;
var url = target.getAttribute("href");
console.log("URL:", url);
let newChatButton = document.querySelector("#new-conv-button");
newChatButton.click();
setTimeout(() => {
let urlInput = document.querySelector("#quick-url-demo textarea");
// Fill the URL input
urlInput.value = url;
urlInput.dispatchEvent(new Event("input", { bubbles: true }));
urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'}));
}, 500
);
}
for (var i = 0; i < links.length; i++) {
links[i].onclick = submitPaper;
}
}
"""
clear_bot_message_selection_js = """
function() {
var bot_messages = document.querySelectorAll(
"div#main-chat-bot div.message-row.bot-row"
);
bot_messages.forEach(message => {
message.classList.remove("text_selection");
});
}
"""
pdfview_js = """ pdfview_js = """
function() { function() {
setTimeout(fullTextSearch(), 100);
// Get all links and attach click event // Get all links and attach click event
var links = document.getElementsByClassName("pdf-link"); var links = document.getElementsByClassName("pdf-link");
for (var i = 0; i < links.length; i++) { for (var i = 0; i < links.length; i++) {
@ -64,35 +127,42 @@ function() {
links[i].onclick = scrollToCitation; links[i].onclick = scrollToCitation;
} }
var mindmap_el = document.getElementById('mindmap'); var markmap_div = document.querySelector("div.markmap");
var mindmap_el_script = document.querySelector('div.markmap script');
if (mindmap_el_script) {
markmap_div_html = markmap_div.outerHTML;
}
// render the mindmap if the script tag is present
if (mindmap_el_script) {
markmap.autoLoader.renderAll();
}
setTimeout(() => {
var mindmap_el = document.querySelector('svg.markmap');
var text_nodes = document.querySelectorAll("svg.markmap div");
for (var i = 0; i < text_nodes.length; i++) {
text_nodes[i].onclick = fillChatInput;
}
if (mindmap_el) { if (mindmap_el) {
var output = svgPanZoom(mindmap_el);
const svg = mindmap_el.cloneNode(true);
function on_svg_export(event) { function on_svg_export(event) {
event.preventDefault(); // Prevent the default link behavior html = "{html_template}";
// convert to a valid XML source html = html.replace("{markmap_div}", markmap_div_html);
const as_text = new XMLSerializer().serializeToString(svg); spawnDocument(html, {window: "width=1000,height=1000"});
// store in a Blob
const blob = new Blob([as_text], { type: "image/svg+xml" });
// create an URI pointing to that blob
const url = URL.createObjectURL(blob);
const win = open(url);
// so the Garbage Collector can collect the blob
win.onload = (evt) => URL.revokeObjectURL(url);
} }
var link = document.getElementById("mindmap-toggle"); var link = document.getElementById("mindmap-toggle");
if (link) { if (link) {
link.onclick = function(event) { link.onclick = function(event) {
event.preventDefault(); // Prevent the default link behavior event.preventDefault(); // Prevent the default link behavior
var div = document.getElementById("mindmap-wrapper"); var div = document.querySelector("div.markmap");
if (div) { if (div) {
var currentHeight = div.style.height; var currentHeight = div.style.height;
if (currentHeight === '400px') { if (currentHeight === '400px' || (currentHeight === '')) {
var contentHeight = div.scrollHeight; div.style.height = '650px';
div.style.height = contentHeight + 'px';
} else { } else {
div.style.height = '400px' div.style.height = '400px'
} }
@ -100,14 +170,28 @@ function() {
}; };
} }
if (markmap_div_html) {
var link = document.getElementById("mindmap-export"); var link = document.getElementById("mindmap-export");
if (link) { if (link) {
link.addEventListener('click', on_svg_export); link.addEventListener('click', on_svg_export);
} }
} }
}
}, 250);
return [links.length] return [links.length]
} }
""".replace(
"{html_template}",
MINDMAP_HTML_EXPORT_TEMPLATE.replace("\n", "").replace('"', '\\"'),
)
fetch_api_key_js = """
function(_, __) {
api_key = getStorage('google_api_key', '');
console.log('session API key:', api_key);
return [api_key, _];
}
""" """
@ -126,6 +210,7 @@ class ChatPage(BasePage):
) )
self._info_panel_expanded = gr.State(value=True) self._info_panel_expanded = gr.State(value=True)
self._command_state = gr.State(value=None) self._command_state = gr.State(value=None)
self._user_api_key = gr.Text(value="", visible=False)
def on_building_ui(self): def on_building_ui(self):
with gr.Row(): with gr.Row():
@ -146,7 +231,17 @@ class ChatPage(BasePage):
continue continue
index_ui.unrender() # need to rerender later within Accordion index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=index.name, open=index_id < 1): is_first_index = index_id == 0
index_name = index.name
if KH_DEMO_MODE and is_first_index:
index_name = "Select from Paper Collection"
with gr.Accordion(
label=index_name,
open=is_first_index,
elem_id=f"index-{index_id}",
):
index_ui.render() index_ui.render()
gr_index = index_ui.as_gradio_component() gr_index = index_ui.as_gradio_component()
@ -171,8 +266,16 @@ class ChatPage(BasePage):
self._indices_input.append(gr_index) self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui) setattr(self, f"_index_{index.id}", index_ui)
self.chat_suggestion = ChatSuggestion(self._app)
if len(self._app.index_manager.indices) > 0: if len(self._app.index_manager.indices) > 0:
with gr.Accordion(label="Quick Upload") as _: quick_upload_label = (
"Quick Upload" if not KH_DEMO_MODE else "Or input new paper URL"
)
with gr.Accordion(label=quick_upload_label) as _:
self.quick_file_upload_status = gr.Markdown()
if not KH_DEMO_MODE:
self.quick_file_upload = File( self.quick_file_upload = File(
file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()), file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
file_count="multiple", file_count="multiple",
@ -181,89 +284,95 @@ class ChatPage(BasePage):
elem_id="quick-file", elem_id="quick-file",
) )
self.quick_urls = gr.Textbox( self.quick_urls = gr.Textbox(
placeholder="Or paste URLs here", placeholder=(
"Or paste URLs"
if not KH_DEMO_MODE
else "Paste Arxiv URLs\n(https://arxiv.org/abs/xxx)"
),
lines=1, lines=1,
container=False, container=False,
show_label=False, show_label=False,
elem_id="quick-url", elem_id=(
"quick-url" if not KH_DEMO_MODE else "quick-url-demo"
),
) )
self.quick_file_upload_status = gr.Markdown()
if not KH_DEMO_MODE:
self.report_issue = ReportIssue(self._app) self.report_issue = ReportIssue(self._app)
else:
with gr.Accordion(label="Related papers", open=False):
self.related_papers = gr.Markdown(elem_id="related-papers")
self.hint_page = HintPage(self._app)
with gr.Column(scale=6, elem_id="chat-area"): with gr.Column(scale=6, elem_id="chat-area"):
if KH_DEMO_MODE:
self.paper_list = PaperListPage(self._app)
self.chat_panel = ChatPanel(self._app) self.chat_panel = ChatPanel(self._app)
with gr.Row():
with gr.Accordion( with gr.Accordion(
label="Chat settings", label="Chat settings",
elem_id="chat-settings-expand", elem_id="chat-settings-expand",
open=False, open=False,
): visible=not KH_DEMO_MODE,
) as self.chat_settings:
with gr.Row(elem_id="quick-setting-labels"): with gr.Row(elem_id="quick-setting-labels"):
gr.HTML("Reasoning method") gr.HTML("Reasoning method")
gr.HTML("Model") gr.HTML(
"Model", visible=not KH_DEMO_MODE and not KH_SSO_ENABLED
)
gr.HTML("Language") gr.HTML("Language")
gr.HTML("Suggestion")
with gr.Row(): with gr.Row():
reasoning_type_values = [ reasoning_setting = (
(DEFAULT_SETTING, DEFAULT_SETTING) self._app.default_settings.reasoning.settings["use"]
] + self._app.default_settings.reasoning.settings[ )
"use" model_setting = self._app.default_settings.reasoning.options[
].choices "simple"
].settings["llm"]
language_setting = (
self._app.default_settings.reasoning.settings["lang"]
)
citation_setting = self._app.default_settings.reasoning.options[
"simple"
].settings["highlight_citation"]
self.reasoning_type = gr.Dropdown( self.reasoning_type = gr.Dropdown(
choices=reasoning_type_values, choices=reasoning_setting.choices[:REASONING_LIMITS],
value=DEFAULT_SETTING, value=reasoning_setting.value,
container=False, container=False,
show_label=False, show_label=False,
) )
self.model_type = gr.Dropdown( self.model_type = gr.Dropdown(
choices=self._app.default_settings.reasoning.options[ choices=model_setting.choices,
"simple" value=model_setting.value,
]
.settings["llm"]
.choices,
value="",
container=False, container=False,
show_label=False, show_label=False,
visible=not KH_DEMO_MODE and not KH_SSO_ENABLED,
) )
self.language = gr.Dropdown( self.language = gr.Dropdown(
choices=[ choices=language_setting.choices,
(DEFAULT_SETTING, DEFAULT_SETTING), value=language_setting.value,
]
+ self._app.default_settings.reasoning.settings[
"lang"
].choices,
value=DEFAULT_SETTING,
container=False, container=False,
show_label=False, show_label=False,
) )
self.use_chat_suggestion = gr.Checkbox(
label="Chat suggestion",
container=False,
elem_id="use-suggestion-checkbox",
)
self.citation = gr.Dropdown( self.citation = gr.Dropdown(
choices=[ choices=citation_setting.choices,
(DEFAULT_SETTING, DEFAULT_SETTING), value=citation_setting.value,
]
+ self._app.default_settings.reasoning.options["simple"]
.settings["highlight_citation"]
.choices,
value=DEFAULT_SETTING,
container=False, container=False,
show_label=False, show_label=False,
interactive=True, interactive=True,
elem_id="citation-dropdown", elem_id="citation-dropdown",
) )
self.use_mindmap = gr.State(value=DEFAULT_SETTING) self.use_mindmap = gr.State(value=True)
self.use_mindmap_check = gr.Checkbox( self.use_mindmap_check = gr.Checkbox(
label="Mindmap (default)", label="Mindmap (on)",
container=False, container=False,
elem_id="use-mindmap-checkbox", elem_id="use-mindmap-checkbox",
value=True,
) )
with gr.Column( with gr.Column(
@ -276,6 +385,9 @@ class ChatPage(BasePage):
self.plot_panel = gr.Plot(visible=False) self.plot_panel = gr.Plot(visible=False)
self.info_panel = gr.HTML(elem_id="html-info-panel") self.info_panel = gr.HTML(elem_id="html-info-panel")
self.followup_questions = self.chat_suggestion.examples
self.followup_questions_ui = self.chat_suggestion.accordion
def _json_to_plot(self, json_dict: dict | None): def _json_to_plot(self, json_dict: dict | None):
if json_dict: if json_dict:
plot = from_json(json_dict) plot = from_json(json_dict)
@ -285,8 +397,18 @@ class ChatPage(BasePage):
return plot return plot
def on_register_events(self): def on_register_events(self):
self.followup_questions = self.chat_control.chat_suggestion.examples # first index paper recommendation
self.followup_questions_ui = self.chat_control.chat_suggestion.accordion if KH_DEMO_MODE and len(self._indices_input) > 0:
self._indices_input[1].change(
self.get_recommendations,
inputs=[self.first_selector_choices, self._indices_input[1]],
outputs=[self.related_papers],
).then(
fn=None,
inputs=None,
outputs=None,
js=recommended_papers_js,
)
chat_event = ( chat_event = (
gr.on( gr.on(
@ -374,32 +496,25 @@ class ChatPage(BasePage):
) )
) )
# chat suggestion toggle onSuggestChatEvent = {
chat_event = chat_event.success( "fn": self.suggest_chat_conv,
fn=self.suggest_chat_conv, "inputs": [
inputs=[
self._app.settings_state, self._app.settings_state,
self.language,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self._use_suggestion, self._use_suggestion,
], ],
outputs=[ "outputs": [
self.followup_questions_ui, self.followup_questions_ui,
self.followup_questions, self.followup_questions,
], ],
show_progress="hidden", "show_progress": "hidden",
) }
# .success( # chat suggestion toggle
# self.chat_control.persist_chat_suggestions, chat_event = chat_event.success(**onSuggestChatEvent)
# inputs=[
# self.chat_control.conversation_id,
# self.followup_questions,
# self._use_suggestion,
# self._app.user_id,
# ],
# show_progress="hidden",
# )
# final data persist # final data persist
if not KH_DEMO_MODE:
chat_event = chat_event.then( chat_event = chat_event.then(
fn=self.persist_data_source, fn=self.persist_data_source,
inputs=[ inputs=[
@ -432,15 +547,44 @@ class ChatPage(BasePage):
fn=None, inputs=None, js="function() {toggleChatColumn();}" fn=None, inputs=None, js="function() {toggleChatColumn();}"
) )
self.chat_panel.chatbot.like( if KH_DEMO_MODE:
fn=self.is_liked, self.chat_control.btn_demo_logout.click(
inputs=[self.chat_control.conversation_id], fn=None,
outputs=None, js=self.chat_control.logout_js,
) )
self.chat_control.btn_new.click(
fn=lambda: self.chat_control.select_conv("", None),
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.followup_questions,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
self.state_plot_history,
self.chat_control.cb_is_public,
self.state_chat,
]
+ self._indices_input,
).then(
lambda: (gr.update(visible=False), gr.update(visible=True)),
outputs=[self.paper_list.accordion, self.chat_settings],
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
)
if not KH_DEMO_MODE:
self.chat_control.btn_new.click( self.chat_control.btn_new.click(
self.chat_control.new_conv, self.chat_control.new_conv,
inputs=self._app.user_id, inputs=self._app.user_id,
outputs=[self.chat_control.conversation_id, self.chat_control.conversation], outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
],
show_progress="hidden", show_progress="hidden",
).then( ).then(
self.chat_control.select_conv, self.chat_control.select_conv,
@ -473,12 +617,18 @@ class ChatPage(BasePage):
self.chat_control.btn_del.click( self.chat_control.btn_del.click(
lambda id: self.toggle_delete(id), lambda id: self.toggle_delete(id),
inputs=[self.chat_control.conversation_id], inputs=[self.chat_control.conversation_id],
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], outputs=[
self.chat_control._new_delete,
self.chat_control._delete_confirm,
],
) )
self.chat_control.btn_del_conf.click( self.chat_control.btn_del_conf.click(
self.chat_control.delete_conv, self.chat_control.delete_conv,
inputs=[self.chat_control.conversation_id, self._app.user_id], inputs=[self.chat_control.conversation_id, self._app.user_id],
outputs=[self.chat_control.conversation_id, self.chat_control.conversation], outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
],
show_progress="hidden", show_progress="hidden",
).then( ).then(
self.chat_control.select_conv, self.chat_control.select_conv,
@ -504,11 +654,17 @@ class ChatPage(BasePage):
outputs=self.plot_panel, outputs=self.plot_panel,
).then( ).then(
lambda: self.toggle_delete(""), lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], outputs=[
self.chat_control._new_delete,
self.chat_control._delete_confirm,
],
) )
self.chat_control.btn_del_cnl.click( self.chat_control.btn_del_cnl.click(
lambda: self.toggle_delete(""), lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], outputs=[
self.chat_control._new_delete,
self.chat_control._delete_confirm,
],
) )
self.chat_control.btn_conversation_rn.click( self.chat_control.btn_conversation_rn.click(
lambda: gr.update(visible=True), lambda: gr.update(visible=True),
@ -532,6 +688,7 @@ class ChatPage(BasePage):
show_progress="hidden", show_progress="hidden",
) )
onConvSelect = (
self.chat_control.conversation.select( self.chat_control.conversation.select(
self.chat_control.select_conv, self.chat_control.select_conv,
inputs=[self.chat_control.conversation, self._app.user_id], inputs=[self.chat_control.conversation, self._app.user_id],
@ -550,22 +707,42 @@ class ChatPage(BasePage):
] ]
+ self._indices_input, + self._indices_input,
show_progress="hidden", show_progress="hidden",
).then( )
.then(
fn=self._json_to_plot, fn=self._json_to_plot,
inputs=self.state_plot_panel, inputs=self.state_plot_panel,
outputs=self.plot_panel, outputs=self.plot_panel,
).then( )
.then(
lambda: self.toggle_delete(""), lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], outputs=[
).then( self.chat_control._new_delete,
self.chat_control._delete_confirm,
],
)
)
if KH_DEMO_MODE:
onConvSelect = onConvSelect.then(
lambda: (gr.update(visible=False), gr.update(visible=True)),
outputs=[self.paper_list.accordion, self.chat_settings],
)
onConvSelect = (
onConvSelect.then(
fn=lambda: True,
js=clear_bot_message_selection_js,
)
.then(
fn=lambda: True, fn=lambda: True,
inputs=None, inputs=None,
outputs=[self._preview_links], outputs=[self._preview_links],
js=pdfview_js, js=pdfview_js,
).then( )
fn=None, inputs=None, outputs=None, js=chat_input_focus_js .then(fn=None, inputs=None, outputs=None, js=chat_input_focus_js)
) )
if not KH_DEMO_MODE:
# evidence display on message selection # evidence display on message selection
self.chat_panel.chatbot.select( self.chat_panel.chatbot.select(
self.message_selected, self.message_selected,
@ -586,8 +763,6 @@ class ChatPage(BasePage):
inputs=None, inputs=None,
outputs=[self._preview_links], outputs=[self._preview_links],
js=pdfview_js, js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
) )
self.chat_control.cb_is_public.change( self.chat_control.cb_is_public.change(
@ -597,6 +772,13 @@ class ChatPage(BasePage):
show_progress="hidden", show_progress="hidden",
) )
if not KH_DEMO_MODE:
# user feedback events
self.chat_panel.chatbot.like(
fn=self.is_liked,
inputs=[self.chat_control.conversation_id],
outputs=None,
)
self.report_issue.report_btn.click( self.report_issue.report_btn.click(
self.report_issue.report, self.report_issue.report,
inputs=[ inputs=[
@ -613,6 +795,7 @@ class ChatPage(BasePage):
+ self._indices_input, + self._indices_input,
outputs=None, outputs=None,
) )
self.reasoning_type.change( self.reasoning_type.change(
self.reasoning_changed, self.reasoning_changed,
inputs=[self.reasoning_type], inputs=[self.reasoning_type],
@ -624,11 +807,25 @@ class ChatPage(BasePage):
outputs=[self.use_mindmap, self.use_mindmap_check], outputs=[self.use_mindmap, self.use_mindmap_check],
show_progress="hidden", show_progress="hidden",
) )
self.use_chat_suggestion.change(
lambda x: (x, gr.update(visible=x)), def toggle_chat_suggestion(current_state):
inputs=[self.use_chat_suggestion], return current_state, gr.update(visible=current_state)
def raise_error_on_state(state):
if not state:
raise ValueError("Chat suggestion disabled")
self.chat_control.cb_suggest_chat.change(
fn=toggle_chat_suggestion,
inputs=[self.chat_control.cb_suggest_chat],
outputs=[self._use_suggestion, self.followup_questions_ui], outputs=[self._use_suggestion, self.followup_questions_ui],
show_progress="hidden", show_progress="hidden",
).then(
fn=raise_error_on_state,
inputs=[self._use_suggestion],
show_progress="hidden",
).success(
**onSuggestChatEvent
) )
self.chat_control.conversation_id.change( self.chat_control.conversation_id.change(
lambda: gr.update(visible=False), lambda: gr.update(visible=False),
@ -636,7 +833,7 @@ class ChatPage(BasePage):
) )
self.followup_questions.select( self.followup_questions.select(
self.chat_control.chat_suggestion.select_example, self.chat_suggestion.select_example,
outputs=[self.chat_panel.text_input], outputs=[self.chat_panel.text_input],
show_progress="hidden", show_progress="hidden",
).then( ).then(
@ -646,6 +843,22 @@ class ChatPage(BasePage):
js=chat_input_focus_js, js=chat_input_focus_js,
) )
if KH_DEMO_MODE:
self.paper_list.examples.select(
self.paper_list.select_example,
inputs=[self.paper_list.papers_state],
outputs=[self.quick_urls],
show_progress="hidden",
).then(
lambda: (gr.update(visible=False), gr.update(visible=True)),
outputs=[self.paper_list.accordion, self.chat_settings],
).then(
fn=None,
inputs=None,
outputs=None,
js=quick_urls_submit_js,
)
def submit_msg( def submit_msg(
self, self,
chat_input, chat_input,
@ -655,8 +868,13 @@ class ChatPage(BasePage):
conv_id, conv_id,
conv_name, conv_name,
first_selector_choices, first_selector_choices,
request: gr.Request,
): ):
"""Submit a message to the chatbot""" """Submit a message to the chatbot"""
if KH_DEMO_MODE:
sso_user_id = check_rate_limit("chat", request)
print("User ID:", sso_user_id)
if not chat_input: if not chat_input:
raise ValueError("Input is empty") raise ValueError("Input is empty")
@ -685,6 +903,7 @@ class ChatPage(BasePage):
True, True,
settings, settings,
user_id, user_id,
request=None,
) )
elif file_names: elif file_names:
for file_name in file_names: for file_name in file_names:
@ -698,7 +917,11 @@ class ChatPage(BasePage):
# if file_ids is not empty and chat_input_text is empty # if file_ids is not empty and chat_input_text is empty
# set the input to summary # set the input to summary
if not chat_input_text and file_ids: if not chat_input_text and file_ids:
chat_input_text = "Summary" chat_input_text = DEFAULT_QUESTION
# if start of conversation and no query is specified
if not chat_input_text and not chat_history:
chat_input_text = DEFAULT_QUESTION
if file_ids: if file_ids:
selector_output = [ selector_output = [
@ -716,6 +939,7 @@ class ChatPage(BasePage):
raise gr.Error("Empty chat") raise gr.Error("Empty chat")
if not conv_id: if not conv_id:
if not KH_DEMO_MODE:
id_, update = self.chat_control.new_conv(user_id) id_, update = self.chat_control.new_conv(user_id)
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == id_) statement = select(Conversation).where(Conversation.id == id_)
@ -723,6 +947,8 @@ class ChatPage(BasePage):
new_conv_id = id_ new_conv_id = id_
conv_update = update conv_update = update
new_conv_name = name new_conv_name = name
else:
new_conv_id, new_conv_name, conv_update = None, None, gr.update()
else: else:
new_conv_id = conv_id new_conv_id = conv_id
conv_update = gr.update() conv_update = gr.update()
@ -740,6 +966,17 @@ class ChatPage(BasePage):
+ [used_command] + [used_command]
) )
def get_recommendations(self, first_selector_choices, file_ids):
first_selector_choices_map = {
item[1]: item[0] for item in first_selector_choices
}
file_names = [first_selector_choices_map[file_id] for file_id in file_ids]
if not file_names:
return ""
first_file_name = file_names[0].split(".")[0].replace("_", " ")
return get_recommended_papers(first_file_name)
def toggle_delete(self, conv_id): def toggle_delete(self, conv_id):
if conv_id: if conv_id:
return gr.update(visible=False), gr.update(visible=True) return gr.update(visible=False), gr.update(visible=True)
@ -789,17 +1026,41 @@ class ChatPage(BasePage):
self.chat_control.conversation, self.chat_control.conversation,
self.chat_control.conversation_rn, self.chat_control.conversation_rn,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self.followup_questions,
self.info_panel, self.info_panel,
self.state_plot_panel, self.state_plot_panel,
self.state_retrieval_history, self.state_retrieval_history,
self.state_plot_history, self.state_plot_history,
self.chat_control.cb_is_public, self.chat_control.cb_is_public,
self.state_chat,
] ]
+ self._indices_input, + self._indices_input,
"show_progress": "hidden", "show_progress": "hidden",
}, },
) )
def _on_app_created(self):
if KH_DEMO_MODE:
self._app.app.load(
fn=lambda x: x,
inputs=[self._user_api_key],
outputs=[self._user_api_key],
js=fetch_api_key_js,
).then(
fn=self.chat_control.toggle_demo_login_visibility,
inputs=[self._user_api_key],
outputs=[
self.chat_control.cb_suggest_chat,
self.chat_control.btn_new,
self.chat_control.btn_demo_logout,
self.chat_control.btn_demo_login,
],
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
)
def persist_data_source( def persist_data_source(
self, self,
convo_id, convo_id,
@ -1106,13 +1367,24 @@ class ChatPage(BasePage):
return new_name, renamed return new_name, renamed
def suggest_chat_conv(self, settings, chat_history, use_suggestion): def suggest_chat_conv(
self,
settings,
session_language,
chat_history,
use_suggestion,
):
target_language = (
session_language
if session_language not in (DEFAULT_SETTING, None)
else settings["reasoning.lang"]
)
if use_suggestion: if use_suggestion:
suggest_pipeline = SuggestFollowupQuesPipeline() suggest_pipeline = SuggestFollowupQuesPipeline()
suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get( suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
settings["reasoning.lang"], "English" target_language, "English"
) )
suggested_questions = [] suggested_questions = [[each] for each in ChatSuggestion.CHAT_SAMPLES]
if len(chat_history) >= 1: if len(chat_history) >= 1:
suggested_resp = suggest_pipeline(chat_history).text suggested_resp = suggest_pipeline(chat_history).text

View File

@ -1,5 +1,21 @@
import gradio as gr import gradio as gr
from ktem.app import BasePage from ktem.app import BasePage
from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
if not KH_DEMO_MODE:
PLACEHOLDER_TEXT = (
"This is the beginning of a new conversation.\n"
"Start by uploading a file or a web URL. "
"Visit Files tab for more options (e.g: GraphRAG)."
)
else:
PLACEHOLDER_TEXT = (
"Welcome to Kotaemon Demo. "
"Start by browsing preloaded conversations to get onboard.\n"
"Check out Hint section for more tips."
)
class ChatPanel(BasePage): class ChatPanel(BasePage):
@ -10,10 +26,7 @@ class ChatPanel(BasePage):
def on_building_ui(self): def on_building_ui(self):
self.chatbot = gr.Chatbot( self.chatbot = gr.Chatbot(
label=self._app.app_name, label=self._app.app_name,
placeholder=( placeholder=PLACEHOLDER_TEXT,
"This is the beginning of a new conversation.\nIf you are new, "
"visit the Help tab for quick instructions."
),
show_label=False, show_label=False,
elem_id="main-chat-bot", elem_id="main-chat-bot",
show_copy_button=True, show_copy_button=True,

View File

@ -4,29 +4,34 @@ from theflow.settings import settings as flowsettings
class ChatSuggestion(BasePage): class ChatSuggestion(BasePage):
def __init__(self, app): CHAT_SAMPLES = getattr(
self._app = app
self.on_building_ui()
def on_building_ui(self):
chat_samples = getattr(
flowsettings, flowsettings,
"KH_FEATURE_CHAT_SUGGESTION_SAMPLES", "KH_FEATURE_CHAT_SUGGESTION_SAMPLES",
[ [
"Summary this document", "Summary this document",
"Generate a FAQ for this document", "Generate a FAQ for this document",
"Identify the main highlights in this text", "Identify the main highlights in bullet points",
], ],
) )
self.chat_samples = [[each] for each in chat_samples]
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
self.chat_samples = [[each] for each in self.CHAT_SAMPLES]
with gr.Accordion( with gr.Accordion(
label="Chat Suggestion", label="Chat Suggestion",
visible=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False), visible=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False),
) as self.accordion: ) as self.accordion:
self.default_example = gr.State(
value=self.chat_samples,
)
self.examples = gr.DataFrame( self.examples = gr.DataFrame(
value=self.chat_samples, value=self.chat_samples,
headers=["Next Question"], headers=["Next Question"],
interactive=False, interactive=False,
elem_id="chat-suggestion",
wrap=True, wrap=True,
) )

View File

@ -14,11 +14,22 @@ from .chat_suggestion import ChatSuggestion
from .common import STATE from .common import STATE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
ASSETS_DIR = "assets/icons" ASSETS_DIR = "assets/icons"
if not os.path.isdir(ASSETS_DIR): if not os.path.isdir(ASSETS_DIR):
ASSETS_DIR = "libs/ktem/ktem/assets/icons" ASSETS_DIR = "libs/ktem/ktem/assets/icons"
logout_js = """
function () {
removeFromStorage('google_api_key');
window.location.href = "/logout";
}
"""
def is_conv_name_valid(name): def is_conv_name_valid(name):
"""Check if the conversation name is valid""" """Check if the conversation name is valid"""
errors = [] errors = []
@ -35,11 +46,13 @@ class ConversationControl(BasePage):
def __init__(self, app): def __init__(self, app):
self._app = app self._app = app
self.logout_js = logout_js
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
with gr.Row(): with gr.Row():
gr.Markdown("## Conversations") title_text = "Conversations" if not KH_DEMO_MODE else "Kotaemon Papers"
gr.Markdown("## {}".format(title_text))
self.btn_toggle_dark_mode = gr.Button( self.btn_toggle_dark_mode = gr.Button(
value="", value="",
icon=f"{ASSETS_DIR}/dark_mode.svg", icon=f"{ASSETS_DIR}/dark_mode.svg",
@ -83,17 +96,28 @@ class ConversationControl(BasePage):
filterable=True, filterable=True,
interactive=True, interactive=True,
elem_classes=["unset-overflow"], elem_classes=["unset-overflow"],
elem_id="conversation-dropdown",
) )
with gr.Row() as self._new_delete: with gr.Row() as self._new_delete:
self.cb_suggest_chat = gr.Checkbox(
value=False,
label="Suggest chat",
min_width=10,
scale=6,
elem_id="suggest-chat-checkbox",
container=False,
visible=not KH_DEMO_MODE,
)
self.cb_is_public = gr.Checkbox( self.cb_is_public = gr.Checkbox(
value=False, value=False,
label="Shared", label="Share this conversation",
min_width=10,
scale=4,
elem_id="is-public-checkbox", elem_id="is-public-checkbox",
container=False, container=False,
visible=not KH_DEMO_MODE and not KH_SSO_ENABLED,
) )
if not KH_DEMO_MODE:
self.btn_conversation_rn = gr.Button( self.btn_conversation_rn = gr.Button(
value="", value="",
icon=f"{ASSETS_DIR}/rename.svg", icon=f"{ASSETS_DIR}/rename.svg",
@ -119,6 +143,41 @@ class ConversationControl(BasePage):
elem_classes=["no-background", "body-text-color"], elem_classes=["no-background", "body-text-color"],
elem_id="new-conv-button", elem_id="new-conv-button",
) )
else:
self.btn_new = gr.Button(
value="New chat",
min_width=120,
size="sm",
scale=1,
variant="primary",
elem_id="new-conv-button",
visible=False,
)
if KH_DEMO_MODE:
with gr.Row():
self.btn_demo_login = gr.Button(
"Sign-in to create new chat",
min_width=120,
size="sm",
scale=1,
variant="primary",
)
_js_redirect = """
() => {
url = '/login' + window.location.search;
window.open(url, '_blank');
}
"""
self.btn_demo_login.click(None, js=_js_redirect)
self.btn_demo_logout = gr.Button(
"Sign-out",
min_width=120,
size="sm",
scale=1,
visible=False,
)
with gr.Row(visible=False) as self._delete_confirm: with gr.Row(visible=False) as self._delete_confirm:
self.btn_del_conf = gr.Button( self.btn_del_conf = gr.Button(
@ -139,8 +198,6 @@ class ConversationControl(BasePage):
visible=False, visible=False,
) )
self.chat_suggestion = ChatSuggestion(self._app)
def load_chat_history(self, user_id): def load_chat_history(self, user_id):
"""Reload chat history""" """Reload chat history"""
@ -241,6 +298,8 @@ class ConversationControl(BasePage):
def select_conv(self, conversation_id, user_id): def select_conv(self, conversation_id, user_id):
"""Select the conversation""" """Select the conversation"""
default_chat_suggestions = [[each] for each in ChatSuggestion.CHAT_SAMPLES]
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id) statement = select(Conversation).where(Conversation.id == conversation_id)
try: try:
@ -257,7 +316,9 @@ class ConversationControl(BasePage):
selected = {} selected = {}
chats = result.data_source.get("messages", []) chats = result.data_source.get("messages", [])
chat_suggestions = result.data_source.get("chat_suggestions", []) chat_suggestions = result.data_source.get(
"chat_suggestions", default_chat_suggestions
)
retrieval_history: list[str] = result.data_source.get( retrieval_history: list[str] = result.data_source.get(
"retrieval_messages", [] "retrieval_messages", []
@ -282,7 +343,7 @@ class ConversationControl(BasePage):
name = "" name = ""
selected = {} selected = {}
chats = [] chats = []
chat_suggestions = [] chat_suggestions = default_chat_suggestions
retrieval_history = [] retrieval_history = []
plot_history = [] plot_history = []
info_panel = "" info_panel = ""
@ -317,25 +378,21 @@ class ConversationControl(BasePage):
def rename_conv(self, conversation_id, new_name, is_renamed, user_id): def rename_conv(self, conversation_id, new_name, is_renamed, user_id):
"""Rename the conversation""" """Rename the conversation"""
if not is_renamed: if not is_renamed or KH_DEMO_MODE or user_id is None or not conversation_id:
return ( return (
gr.update(), gr.update(),
conversation_id, conversation_id,
gr.update(visible=False), gr.update(visible=False),
) )
if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), ""
if not conversation_id:
gr.Warning("No conversation selected.")
return gr.update(), ""
errors = is_conv_name_valid(new_name) errors = is_conv_name_valid(new_name)
if errors: if errors:
gr.Warning(errors) gr.Warning(errors)
return gr.update(), conversation_id return (
gr.update(),
conversation_id,
gr.update(visible=False),
)
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id) statement = select(Conversation).where(Conversation.id == conversation_id)
@ -382,6 +439,29 @@ class ConversationControl(BasePage):
gr.Info("Chat suggestions updated.") gr.Info("Chat suggestions updated.")
def toggle_demo_login_visibility(self, user_api_key, request: gr.Request):
try:
import gradiologin as grlogin
user = grlogin.get_user(request)
except (ImportError, AssertionError):
user = None
if user: # or user_api_key:
return [
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=True),
gr.update(visible=False),
]
else:
return [
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=True),
]
def _on_app_created(self): def _on_app_created(self):
"""Reload the conversation once the app is created""" """Reload the conversation once the app is created"""
self._app.app.load( self._app.app.load(

View File

@ -0,0 +1,23 @@
from textwrap import dedent
import gradio as gr
from ktem.app import BasePage
class HintPage(BasePage):
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Hint", open=False):
gr.Markdown(
dedent(
"""
- You can select any text from the chat answer to **highlight relevant citation(s)** on the right panel.
- **Citations** can be viewed on both PDF viewer and raw text.
- You can tweak the citation format and use advance (CoT) reasoning in **Chat settings** menu.
- Want to **explore more**? Check out the **Help** section to create your private space.
""" # noqa
)
)

View File

@ -0,0 +1,41 @@
import gradio as gr
from ktem.app import BasePage
from pandas import DataFrame
from ...utils.hf_papers import fetch_papers
class PaperListPage(BasePage):
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
self.papers_state = gr.State(None)
with gr.Accordion(
label="Browse popular daily papers",
open=True,
) as self.accordion:
self.examples = gr.DataFrame(
value=[],
headers=["title", "url", "upvotes"],
column_widths=[60, 30, 10],
interactive=False,
elem_id="paper-suggestion",
wrap=True,
)
return self.examples
def load(self):
papers = fetch_papers(top_n=5)
papers_df = DataFrame(papers)
return (papers_df, papers)
def _on_app_created(self):
self._app.app.load(
self.load,
outputs=[self.examples, self.papers_state],
)
def select_example(self, state, ev: gr.SelectData):
return state[ev.index[0]]["url"]

View File

@ -12,7 +12,7 @@ class ReportIssue(BasePage):
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
with gr.Accordion(label="Feedback", open=False): with gr.Accordion(label="Feedback", open=False, elem_id="report-accordion"):
self.correctness = gr.Radio( self.correctness = gr.Radio(
choices=[ choices=[
("The answer is correct", "correct"), ("The answer is correct", "correct"),

View File

@ -3,8 +3,12 @@ from pathlib import Path
import gradio as gr import gradio as gr
import requests import requests
from decouple import config
from theflow.settings import settings from theflow.settings import settings
KH_DEMO_MODE = getattr(settings, "KH_DEMO_MODE", False)
HF_SPACE_URL = config("HF_SPACE_URL", default="")
def get_remote_doc(url: str) -> str: def get_remote_doc(url: str) -> str:
try: try:
@ -59,6 +63,22 @@ class HelpPage:
about_md = f"Version: {self.app_version}\n\n{about_md}" about_md = f"Version: {self.app_version}\n\n{about_md}"
gr.Markdown(about_md) gr.Markdown(about_md)
if KH_DEMO_MODE:
with gr.Accordion("Create Your Own Space"):
gr.Markdown(
"This is a demo with limited functionality. "
"Use **Create space** button to install Kotaemon "
"in your own space with all features "
"(including upload and manage your private "
"documents securely)."
)
gr.Button(
value="Create Your Own Space",
link=HF_SPACE_URL,
variant="primary",
size="lg",
)
user_guide_md_dir = self.doc_dir / "usage.md" user_guide_md_dir = self.doc_dir / "usage.md"
if user_guide_md_dir.exists(): if user_guide_md_dir.exists():
with (self.doc_dir / "usage.md").open(encoding="utf-8") as fi: with (self.doc_dir / "usage.md").open(encoding="utf-8") as fi:
@ -68,7 +88,7 @@ class HelpPage:
f"{self.remote_content_url}/v{self.app_version}/docs/usage.md" f"{self.remote_content_url}/v{self.app_version}/docs/usage.md"
) )
if user_guide_md: if user_guide_md:
with gr.Accordion("User Guide"): with gr.Accordion("User Guide", open=not KH_DEMO_MODE):
gr.Markdown(user_guide_md) gr.Markdown(user_guide_md)
if self.app_version: if self.app_version:

View File

@ -3,6 +3,7 @@ import hashlib
import gradio as gr import gradio as gr
from ktem.app import BasePage from ktem.app import BasePage
from ktem.db.models import User, engine from ktem.db.models import User, engine
from ktem.pages.resources.user import create_user
from sqlmodel import Session, select from sqlmodel import Session, select
fetch_creds = """ fetch_creds = """
@ -85,7 +86,35 @@ class LoginPage(BasePage):
}, },
) )
def login(self, usn, pwd): def login(self, usn, pwd, request: gr.Request):
try:
import gradiologin as grlogin
user = grlogin.get_user(request)
except (ImportError, AssertionError):
user = None
if user:
user_id = user["sub"]
with Session(engine) as session:
stmt = select(User).where(
User.id == user_id,
)
result = session.exec(stmt).all()
if result:
print("Existing user:", user)
return user_id, "", ""
else:
print("Creating new user:", user)
create_user(
usn=user["email"],
pwd="",
user_id=user_id,
is_admin=False,
)
return user_id, "", ""
else:
if not usn or not pwd: if not usn or not pwd:
return None, usn, pwd return None, usn, pwd

View File

@ -94,7 +94,7 @@ def validate_password(pwd, pwd_cnf):
return "" return ""
def create_user(usn, pwd) -> bool: def create_user(usn, pwd, user_id=None, is_admin=True) -> bool:
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower()) statement = select(User).where(User.username_lower == usn.lower())
result = session.exec(statement).all() result = session.exec(statement).all()
@ -105,10 +105,11 @@ def create_user(usn, pwd) -> bool:
else: else:
hashed_password = hashlib.sha256(pwd.encode()).hexdigest() hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
user = User( user = User(
id=user_id,
username=usn, username=usn,
username_lower=usn.lower(), username_lower=usn.lower(),
password=hashed_password, password=hashed_password,
admin=True, admin=is_admin,
) )
session.add(user) session.add(user)
session.commit() session.commit()
@ -136,11 +137,12 @@ class UserManagement(BasePage):
self.state_user_list = gr.State(value=None) self.state_user_list = gr.State(value=None)
self.user_list = gr.DataFrame( self.user_list = gr.DataFrame(
headers=["id", "name", "admin"], headers=["id", "name", "admin"],
column_widths=[0, 50, 50],
interactive=False, interactive=False,
) )
with gr.Group(visible=False) as self._selected_panel: with gr.Group(visible=False) as self._selected_panel:
self.selected_user_id = gr.Number(value=-1, visible=False) self.selected_user_id = gr.State(value=-1)
self.usn_edit = gr.Textbox(label="Username") self.usn_edit = gr.Textbox(label="Username")
with gr.Row(): with gr.Row():
self.pwd_edit = gr.Textbox(label="Change password", type="password") self.pwd_edit = gr.Textbox(label="Change password", type="password")
@ -346,7 +348,7 @@ class UserManagement(BasePage):
if not ev.selected: if not ev.selected:
return -1 return -1
return int(user_list["id"][ev.index[0]]) return user_list["id"][ev.index[0]]
def on_selected_user_change(self, selected_user_id): def on_selected_user_change(self, selected_user_id):
if selected_user_id == -1: if selected_user_id == -1:
@ -367,7 +369,7 @@ class UserManagement(BasePage):
btn_delete_no = gr.update(visible=False) btn_delete_no = gr.update(visible=False)
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id)) statement = select(User).where(User.id == selected_user_id)
user = session.exec(statement).one() user = session.exec(statement).one()
usn_edit = gr.update(value=user.username) usn_edit = gr.update(value=user.username)
@ -414,7 +416,7 @@ class UserManagement(BasePage):
return pwd, pwd_cnf return pwd, pwd_cnf
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id)) statement = select(User).where(User.id == selected_user_id)
user = session.exec(statement).one() user = session.exec(statement).one()
user.username = usn user.username = usn
user.username_lower = usn.lower() user.username_lower = usn.lower()
@ -432,7 +434,7 @@ class UserManagement(BasePage):
return selected_user_id return selected_user_id
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id)) statement = select(User).where(User.id == selected_user_id)
user = session.exec(statement).one() user = session.exec(statement).one()
session.delete(user) session.delete(user)
session.commit() session.commit()

View File

@ -5,6 +5,10 @@ from ktem.app import BasePage
from ktem.components import reasonings from ktem.components import reasonings
from ktem.db.models import Settings, User, engine from ktem.db.models import Settings, User, engine
from sqlmodel import Session, select from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False)
signout_js = """ signout_js = """
function(u, c, pw, pwc) { function(u, c, pw, pwc) {
@ -80,11 +84,14 @@ class SettingsPage(BasePage):
# render application page if there are application settings # render application page if there are application settings
self._render_app_tab = False self._render_app_tab = False
if self._default_settings.application.settings:
if not KH_SSO_ENABLED and self._default_settings.application.settings:
self._render_app_tab = True self._render_app_tab = True
# render index page if there are index settings (general and/or specific) # render index page if there are index settings (general and/or specific)
self._render_index_tab = False self._render_index_tab = False
if not KH_SSO_ENABLED:
if self._default_settings.index.settings: if self._default_settings.index.settings:
self._render_index_tab = True self._render_index_tab = True
else: else:
@ -95,6 +102,8 @@ class SettingsPage(BasePage):
# render reasoning page if there are reasoning settings # render reasoning page if there are reasoning settings
self._render_reasoning_tab = False self._render_reasoning_tab = False
if not KH_SSO_ENABLED:
if len(self._default_settings.reasoning.settings) > 1: if len(self._default_settings.reasoning.settings) > 1:
self._render_reasoning_tab = True self._render_reasoning_tab = True
else: else:
@ -106,6 +115,7 @@ class SettingsPage(BasePage):
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
if not KH_SSO_ENABLED:
self.setting_save_btn = gr.Button( self.setting_save_btn = gr.Button(
"Save & Close", "Save & Close",
variant="primary", variant="primary",
@ -175,6 +185,7 @@ class SettingsPage(BasePage):
) )
def on_register_events(self): def on_register_events(self):
if not KH_SSO_ENABLED:
self.setting_save_btn.click( self.setting_save_btn.click(
self.save_setting, self.save_setting,
inputs=[self._user_id] + self.components(), inputs=[self._user_id] + self.components(),
@ -189,7 +200,7 @@ class SettingsPage(BasePage):
outputs=list(self._reasoning_mode.values()), outputs=list(self._reasoning_mode.values()),
show_progress="hidden", show_progress="hidden",
) )
if self._app.f_user_management: if self._app.f_user_management and not KH_SSO_ENABLED:
self.password_change_btn.click( self.password_change_btn.click(
self.change_password, self.change_password,
inputs=[ inputs=[
@ -223,6 +234,12 @@ class SettingsPage(BasePage):
def user_tab(self): def user_tab(self):
# user management # user management
self.current_name = gr.Markdown("Current user: ___") self.current_name = gr.Markdown("Current user: ___")
if KH_SSO_ENABLED:
import gradiologin as grlogin
self.sso_signout = grlogin.LogoutButton("Logout")
else:
self.signout = gr.Button("Logout") self.signout = gr.Button("Logout")
self.password_change = gr.Textbox( self.password_change = gr.Textbox(

View File

@ -2,13 +2,13 @@ import json
import gradio as gr import gradio as gr
import requests import requests
from decouple import config
from ktem.app import BasePage from ktem.app import BasePage
from ktem.embeddings.manager import embedding_models_manager as embeddings from ktem.embeddings.manager import embedding_models_manager as embeddings
from ktem.llms.manager import llms from ktem.llms.manager import llms
from ktem.rerankings.manager import reranking_models_manager as rerankers from ktem.rerankings.manager import reranking_models_manager as rerankers
from theflow.settings import settings as flowsettings from theflow.settings import settings as flowsettings
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/") KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api") DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
if DEFAULT_OLLAMA_URL.endswith("/"): if DEFAULT_OLLAMA_URL.endswith("/"):
@ -113,9 +113,18 @@ class SetupPage(BasePage):
( (
"#### Setup Ollama\n\n" "#### Setup Ollama\n\n"
"Download and install Ollama from " "Download and install Ollama from "
"https://ollama.com/" "https://ollama.com/. Check out latest models at "
"https://ollama.com/library. "
) )
) )
self.ollama_model_name = gr.Textbox(
label="LLM model name",
value=config("LOCAL_MODEL", default="qwen2.5:7b"),
)
self.ollama_emb_model_name = gr.Textbox(
label="Embedding model name",
value=config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
)
self.setup_log = gr.HTML( self.setup_log = gr.HTML(
show_label=False, show_label=False,
@ -139,12 +148,13 @@ class SetupPage(BasePage):
self.cohere_api_key, self.cohere_api_key,
self.openai_api_key, self.openai_api_key,
self.google_api_key, self.google_api_key,
self.ollama_model_name,
self.ollama_emb_model_name,
self.radio_model, self.radio_model,
], ],
outputs=[self.setup_log], outputs=[self.setup_log],
show_progress="hidden", show_progress="hidden",
) )
if not KH_DEMO_MODE:
onSkipSetup = gr.on( onSkipSetup = gr.on(
triggers=[self.btn_skip.click], triggers=[self.btn_skip.click],
fn=lambda: None, fn=lambda: None,
@ -181,12 +191,10 @@ class SetupPage(BasePage):
cohere_api_key, cohere_api_key,
openai_api_key, openai_api_key,
google_api_key, google_api_key,
ollama_model_name,
ollama_emb_model_name,
radio_model_value, radio_model_value,
): ):
# skip if KH_DEMO_MODE
if KH_DEMO_MODE:
raise gr.Error(DEMO_MESSAGE)
log_content = "" log_content = ""
if not radio_model_value: if not radio_model_value:
gr.Info("Skip setup models.") gr.Info("Skip setup models.")
@ -274,7 +282,7 @@ class SetupPage(BasePage):
spec={ spec={
"__type__": "kotaemon.llms.ChatOpenAI", "__type__": "kotaemon.llms.ChatOpenAI",
"base_url": KH_OLLAMA_URL, "base_url": KH_OLLAMA_URL,
"model": "llama3.1:8b", "model": ollama_model_name,
"api_key": "ollama", "api_key": "ollama",
}, },
default=True, default=True,
@ -284,7 +292,7 @@ class SetupPage(BasePage):
spec={ spec={
"__type__": "kotaemon.embeddings.OpenAIEmbeddings", "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
"base_url": KH_OLLAMA_URL, "base_url": KH_OLLAMA_URL,
"model": "nomic-embed-text", "model": ollama_emb_model_name,
"api_key": "ollama", "api_key": "ollama",
}, },
default=True, default=True,

View File

@ -1,4 +1,5 @@
import logging import logging
from textwrap import dedent
from ktem.llms.manager import llms from ktem.llms.manager import llms
@ -8,6 +9,31 @@ from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MINDMAP_HTML_EXPORT_TEMPLATE = dedent(
"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Mindmap</title>
<style>
svg.markmap {
width: 100%;
height: 100vh;
}
</style>
<script src="https://cdn.jsdelivr.net/npm/markmap-autoloader@0.16"></script>
</head>
<body>
{markmap_div}
</body>
</html>
"""
)
class CreateMindmapPipeline(BaseComponent): class CreateMindmapPipeline(BaseComponent):
"""Create a mindmap from the question and context""" """Create a mindmap from the question and context"""
@ -37,6 +63,20 @@ Use the template like this:
""" # noqa: E501 """ # noqa: E501
prompt_template: str = MINDMAP_PROMPT_TEMPLATE prompt_template: str = MINDMAP_PROMPT_TEMPLATE
@classmethod
def convert_uml_to_markdown(cls, text: str) -> str:
start_phrase = "@startmindmap"
end_phrase = "@endmindmap"
try:
text = text.split(start_phrase)[-1]
text = text.split(end_phrase)[0]
text = text.strip().replace("*", "#")
except IndexError:
text = ""
return text
def run(self, question: str, context: str) -> Document: # type: ignore def run(self, question: str, context: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.prompt_template) prompt_template = PromptTemplate(self.prompt_template)
prompt = prompt_template.populate( prompt = prompt_template.populate(
@ -49,4 +89,9 @@ Use the template like this:
HumanMessage(content=prompt), HumanMessage(content=prompt),
] ]
return self.llm(messages) uml_text = self.llm(messages).text
markdown_text = self.convert_uml_to_markdown(uml_text)
return Document(
text=markdown_text,
)

View File

@ -15,12 +15,10 @@ class SuggestFollowupQuesPipeline(BaseComponent):
SUGGEST_QUESTIONS_PROMPT_TEMPLATE = ( SUGGEST_QUESTIONS_PROMPT_TEMPLATE = (
"Based on the chat history above. " "Based on the chat history above. "
"your task is to generate 3 to 5 relevant follow-up questions. " "your task is to generate 3 to 5 relevant follow-up questions. "
"These questions should be simple, clear, " "These questions should be simple, very concise, "
"and designed to guide the conversation further. " "and designed to guide the conversation further. "
"Ensure that the questions are open-ended to encourage detailed responses. "
"Respond in JSON format with 'questions' key. " "Respond in JSON format with 'questions' key. "
"Answer using the language {lang} same as the question. " "Answer using the language {lang} same as the question. "
"If the question uses Chinese, the answer should be in Chinese.\n"
) )
prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE
extra_prompt: str = """Example of valid response: extra_prompt: str = """Example of valid response:

View File

@ -1,5 +1,6 @@
import logging import logging
import threading import threading
from textwrap import dedent
from typing import Generator from typing import Generator
from ktem.embeddings.manager import embedding_models_manager as embeddings from ktem.embeddings.manager import embedding_models_manager as embeddings
@ -8,7 +9,6 @@ from ktem.reasoning.prompt_optimization import (
DecomposeQuestionPipeline, DecomposeQuestionPipeline,
RewriteQuestionPipeline, RewriteQuestionPipeline,
) )
from ktem.utils.plantuml import PlantUML
from ktem.utils.render import Render from ktem.utils.render import Render
from ktem.utils.visualize_cited import CreateCitationVizPipeline from ktem.utils.visualize_cited import CreateCitationVizPipeline
from plotly.io import to_json from plotly.io import to_json
@ -165,21 +165,23 @@ class FullQAPipeline(BaseReasoning):
mindmap = answer.metadata["mindmap"] mindmap = answer.metadata["mindmap"]
if mindmap: if mindmap:
mindmap_text = mindmap.text mindmap_text = mindmap.text
uml_renderer = PlantUML() mindmap_svg = dedent(
"""
try: <div class="markmap">
mindmap_svg = uml_renderer.process(mindmap_text) <script type="text/template">
except Exception as e: ---
print("Failed to process mindmap:", e) markmap:
mindmap_svg = "<svg></svg>" colorFreezeLevel: 2
activeNode:
# post-process the mindmap SVG placement: center
mindmap_svg = ( initialExpandLevel: 4
mindmap_svg.replace("sans-serif", "Quicksand, sans-serif") maxWidth: 200
.replace("#181818", "#cecece") ---
.replace("background:#FFFFF", "background:none") {}
.replace("stroke-width:1", "stroke-width:2") </script>
) </div>
"""
).format(mindmap_text)
mindmap_content = Document( mindmap_content = Document(
channel="info", channel="info",
@ -323,7 +325,7 @@ class FullQAPipeline(BaseReasoning):
def prepare_pipeline_instance(cls, settings, retrievers): def prepare_pipeline_instance(cls, settings, retrievers):
return cls( return cls(
retrievers=retrievers, retrievers=retrievers,
rewrite_pipeline=RewriteQuestionPipeline(), rewrite_pipeline=None,
) )
@classmethod @classmethod
@ -411,8 +413,8 @@ class FullQAPipeline(BaseReasoning):
"value": "highlight", "value": "highlight",
"component": "radio", "component": "radio",
"choices": [ "choices": [
("highlight (verbose)", "highlight"), ("citation: highlight", "highlight"),
("inline (concise)", "inline"), ("citation: inline", "inline"),
("no citation", "off"), ("no citation", "off"),
], ],
}, },
@ -433,7 +435,7 @@ class FullQAPipeline(BaseReasoning):
}, },
"system_prompt": { "system_prompt": {
"name": "System Prompt", "name": "System Prompt",
"value": "This is a question answering system", "value": ("This is a question answering system."),
}, },
"qa_prompt": { "qa_prompt": {
"name": "QA Prompt (contains {context}, {question}, {lang})", "name": "QA Prompt (contains {context}, {question}, {lang})",

View File

@ -0,0 +1,114 @@
from datetime import datetime, timedelta
import requests
from cachetools import TTLCache, cached
# Hugging Face daily-papers feed endpoint.
HF_API_URL = "https://huggingface.co/api/daily_papers"
# arXiv abstract-page URL template, filled with the arXiv paper id.
ARXIV_URL = "https://arxiv.org/abs/{paper_id}"
# Semantic Scholar title-match search endpoint (resolves a title to a paper id).
SEMANTIC_SCHOLAR_QUERY_URL = "https://api.semanticscholar.org/graph/v1/paper/search/match?query={paper_name}"  # noqa
# Semantic Scholar recommendations endpoint (papers similar to given ids).
SEMANTIC_SCHOLAR_RECOMMEND_URL = (
    "https://api.semanticscholar.org/recommendations/v1/papers/"  # noqa
)
# TTL for the recommendations cache, in seconds.
CACHE_TIME = 60 * 60 * 6  # 6 hours

# Function to parse the date string
def parse_date(date_str):
    """Parse a timestamp like ``2024-01-02T03:04:05.123Z`` into a datetime."""
    timestamp_format = "%Y-%m-%dT%H:%M:%S.%fZ"
    return datetime.strptime(date_str, timestamp_format)
@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME))
def get_recommendations_from_semantic_scholar(semantic_scholar_id: str):
    """Fetch up to 14 papers related to *semantic_scholar_id*.

    Calls the Semantic Scholar recommendations API and returns its
    ``recommendedPapers`` list. Best-effort: any network, HTTP, JSON, or
    schema error is logged and an empty list is returned — matching the
    error style of the other helpers in this module (the original only
    caught KeyError, so connection/HTTP errors escaped the cache wrapper).
    """
    try:
        r = requests.post(
            SEMANTIC_SCHOLAR_RECOMMEND_URL,
            json={
                "positivePaperIds": [semantic_scholar_id],
            },
            params={"fields": "externalIds,title,year", "limit": 14},  # type: ignore
        )
        # Surface HTTP errors explicitly instead of a confusing KeyError
        # when the error body lacks "recommendedPapers".
        r.raise_for_status()
        return r.json()["recommendedPapers"]
    except Exception as e:  # best-effort, consistent with sibling helpers
        print(e)
        return []
def filter_recommendations(recommendations, max_paper_count=5):
    """Keep only recommendations with an ArXiv id, capped at *max_paper_count*."""
    with_arxiv = [
        rec
        for rec in recommendations
        if rec["externalIds"].get("ArXiv", None) is not None
    ]
    # Slicing is a no-op when the list is already short enough.
    return with_arxiv[:max_paper_count]
def format_recommendation_into_markdown(recommendations):
    """Render the recommended papers as a Markdown bullet list."""
    parts = ["(recommended by the Semantic Scholar API)\n\n"]
    for rec in recommendations:
        arxiv_url = f"https://arxiv.org/abs/{rec['externalIds']['ArXiv']}"
        parts.append(f"* [{rec['title']}]({arxiv_url}) ({rec['year']})\n")
    return "".join(parts)
def get_paper_id_from_name(paper_name):
    """Resolve a paper title to a Semantic Scholar paper id, or None on failure."""
    try:
        resp = requests.get(
            SEMANTIC_SCHOLAR_QUERY_URL.format(paper_name=paper_name)
        )
        resp.raise_for_status()
        # An empty "data" list raises IndexError, handled below as "not found".
        best_match = resp.json().get("data", [])[0]
        return best_match.get("paperId")
    except Exception as e:
        print(e)
        return None
def get_recommended_papers(paper_name):
    """Return Markdown-formatted recommendations for *paper_name* ("" when unresolvable)."""
    paper_id = get_paper_id_from_name(paper_name)
    if paper_id is None:
        return ""
    recommendations = get_recommendations_from_semantic_scholar(paper_id)
    top_picks = filter_recommendations(recommendations)
    return format_recommendation_into_markdown(top_picks)
def fetch_papers(top_n=5):
    """Fetch trending papers from the Hugging Face daily-papers API.

    Returns up to *top_n* dicts with ``title``, ``url`` (arXiv abstract
    page), and ``upvotes``, restricted to papers published in the last
    3 days and sorted by upvotes (descending). Best-effort: any network
    or parsing error is logged and an empty list is returned.
    """
    try:
        response = requests.get(f"{HF_API_URL}?limit=100")
        response.raise_for_status()
        items = response.json()

        # Calculate the date 3 days ago from now
        three_days_ago = datetime.now() - timedelta(days=3)

        # Filter items from the last 3 days. Items without a publishedAt
        # field are skipped rather than failing the whole batch (the
        # original passed None to parse_date, which raised and emptied
        # the result via the broad except below).
        recent_items = [
            item
            for item in items
            if item.get("publishedAt")
            and parse_date(item["publishedAt"]) >= three_days_ago
        ]
        recent_items.sort(
            key=lambda x: x.get("paper", {}).get("upvotes", 0), reverse=True
        )

        return [
            {
                "title": item.get("paper", {}).get("title"),
                "url": ARXIV_URL.format(paper_id=item.get("paper", {}).get("id")),
                "upvotes": item.get("paper", {}).get("upvotes"),
            }
            for item in recent_items[:top_n]
        ]
    except Exception as e:
        print(e)
        return []

View File

@ -0,0 +1,48 @@
from collections import defaultdict
from datetime import datetime, timedelta
import gradio as gr
from decouple import config
# In-memory store for rate limiting (for demonstration purposes)
# NOTE(review): per-process and unbounded — resets on restart and is not
# shared across workers; a TTL-backed shared store would be needed for a
# multi-worker deployment.
rate_limit_store: dict[str, dict] = defaultdict(dict)

# Rate limit configuration
RATE_LIMIT = config("RATE_LIMIT", default=20, cast=int)  # requests per window
RATE_LIMIT_PERIOD = timedelta(hours=24)  # quota window per user
def check_rate_limit(limit_type: str, request: gr.Request):
    """Enforce a per-user daily quota for *limit_type*.

    Identifies the caller through the gradiologin SSO session (keyed by
    e-mail). Raises ValueError when the request is unavailable, the user
    is not signed in, or the quota is exhausted; otherwise increments the
    counter and returns the user id.
    """
    if request is None:
        raise ValueError("This feature is not available")

    # Resolve the caller's identity from the SSO session, if any.
    user_id = None
    try:
        import gradiologin as grlogin

        sso_user = grlogin.get_user(request)
        if sso_user:
            user_id = sso_user.get("email")
    except (ImportError, AssertionError):
        # gradiologin not installed, or no login session configured.
        pass

    if not user_id:
        raise ValueError("Please sign-in to use this feature")

    now = datetime.now()
    record = rate_limit_store[limit_type].get(
        user_id, {"count": 0, "reset_time": now + RATE_LIMIT_PERIOD}
    )
    if now >= record["reset_time"]:
        # Quota window elapsed: start a fresh period for this user.
        record = {"count": 0, "reset_time": now + RATE_LIMIT_PERIOD}
    if record["count"] >= RATE_LIMIT:
        raise ValueError("Rate limit exceeded. Please try again later.")

    # Record the successful request.
    record["count"] += 1
    rate_limit_store[limit_type][user_id] = record
    return user_id

View File

@ -5,7 +5,7 @@ from fast_langdetect import detect
from kotaemon.base import RetrievedDocument from kotaemon.base import RetrievedDocument
BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "") BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "")
def is_close(val1, val2, tolerance=1e-9): def is_close(val1, val2, tolerance=1e-9):
@ -44,7 +44,8 @@ class Render:
o = " open" if open else "" o = " open" if open else ""
return ( return (
f"<details class='evidence' {o}><summary>" f"<details class='evidence' {o}><summary>"
f"{header}</summary>{content}</details><br>" f"{header}</summary>{content}"
"</details><br>"
) )
@staticmethod @staticmethod
@ -225,6 +226,9 @@ class Render:
doc, doc,
highlight_text=highlight_text, highlight_text=highlight_text,
) )
rendered_doc_content = (
f"<div class='evidence-content'>{rendered_doc_content}</div>"
)
return Render.collapsible( return Render.collapsible(
header=rendered_header, header=rendered_header,

View File

@ -27,6 +27,7 @@ dependencies = [
"sqlmodel>=0.0.16,<0.1", "sqlmodel>=0.0.16,<0.1",
"tiktoken>=0.6.0,<1", "tiktoken>=0.6.0,<1",
"gradio>=4.31.0,<5", "gradio>=4.31.0,<5",
"gradiologin",
"python-multipart==0.0.12", # required for gradio, pinning to avoid yanking issues with micropip (fixed in gradio >= 5.4.0) "python-multipart==0.0.12", # required for gradio, pinning to avoid yanking issues with micropip (fixed in gradio >= 5.4.0)
"markdown>=3.6,<4", "markdown>=3.6,<4",
"tzlocal>=5.0", "tzlocal>=5.0",

51
sso_app.py Normal file
View File

@ -0,0 +1,51 @@
import os
import gradiologin as grlogin
from decouple import config
from fastapi import FastAPI
from fastapi.responses import FileResponse
from theflow.settings import settings as flowsettings
# Writable app-data directory from flow settings (defaults to CWD).
KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".")
GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None)
# override GRADIO_TEMP_DIR if it's not set
if GRADIO_TEMP_DIR is None:
    GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp")
    os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR

# Google OAuth client credentials, read from the environment / .env.
GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID", default="")
GOOGLE_CLIENT_SECRET = config("GOOGLE_CLIENT_SECRET", default="")

# Imported here (after GRADIO_TEMP_DIR is set) so gradio picks up the override.
from ktem.main import App  # noqa

gradio_app = App()
demo = gradio_app.make()

app = FastAPI()

# Register Google as the OpenID Connect provider for gradiologin.
grlogin.register(
    name="google",
    server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
    client_id=GOOGLE_CLIENT_ID,
    client_secret=GOOGLE_CLIENT_SECRET,
    client_kwargs={
        "scope": "openid email profile",
    },
)


@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
    # Serve the favicon directly, outside the gradio mount.
    return FileResponse(gradio_app._favicon)


# Mount the gradio UI at /app behind the SSO login; allowed_paths lets
# gradio serve static assets and temp files outside its own directory.
grlogin.mount_gradio_app(
    app,
    demo,
    "/app",
    allowed_paths=[
        "libs/ktem/ktem/assets",
        GRADIO_TEMP_DIR,
    ],
)

97
sso_app_demo.py Normal file
View File

@ -0,0 +1,97 @@
import os
import gradio as gr
from authlib.integrations.starlette_client import OAuth, OAuthError
from decouple import config
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import RedirectResponse
from theflow.settings import settings as flowsettings
# Demo-mode flag and writable app-data directory from flow settings.
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".")
GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None)
# override GRADIO_TEMP_DIR if it's not set
if GRADIO_TEMP_DIR is None:
    GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp")
    os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR

# Google OAuth credentials and the session-cookie signing key.
GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID", default="")
GOOGLE_CLIENT_SECRET = config("GOOGLE_CLIENT_SECRET", default="")
# NOTE(review): the fallback secret is predictable — set SECRET_KEY in
# production so session cookies cannot be forged.
SECRET_KEY = config("SECRET_KEY", default="default-secret-key")
def add_session_middleware(app):
    """Attach cookie-session support and a Google OAuth client to *app*.

    Credentials are passed through a starlette Config so authlib resolves
    them by provider name ("google"). Returns the configured OAuth registry.
    """
    provider_env = {
        "GOOGLE_CLIENT_ID": GOOGLE_CLIENT_ID,
        "GOOGLE_CLIENT_SECRET": GOOGLE_CLIENT_SECRET,
    }
    oauth = OAuth(Config(environ=provider_env))
    oauth.register(
        name="google",
        server_metadata_url=(
            "https://accounts.google.com/.well-known/openid-configuration"
        ),
        client_kwargs={"scope": "openid email profile"},
    )

    # Sessions back the "user" key written by the /auth callback.
    app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)
    return oauth
# Imported here (after GRADIO_TEMP_DIR is set) so gradio picks up the override.
from ktem.main import App  # noqa

gradio_app = App()
main_demo = gradio_app.make()

app = FastAPI()
# Session middleware + Google OAuth client used by the routes below.
oauth = add_session_middleware(app)
@app.get("/")
def public(request: Request):
    """Redirect the bare root URL to the mounted gradio app."""
    base_url = gr.route_utils.get_root_url(request, "/", None)
    return RedirectResponse(url=f"{base_url}/app/")
@app.get("/favicon.ico", include_in_schema=False)
async def favicon():
    """Serve the application's favicon, outside the gradio mount."""
    icon_path = gradio_app._favicon
    return FileResponse(icon_path)
@app.route("/logout")
async def logout(request: Request):
    """Drop the login session and send the user back to the public page."""
    request.session.pop("user", None)
    return RedirectResponse(url="/")
@app.route("/login")
async def login(request: Request):
    """Start the Google OAuth flow; the provider returns the user to /auth."""
    base_url = gr.route_utils.get_root_url(request, "/login", None)
    callback_uri = f"{base_url}/auth"
    return await oauth.google.authorize_redirect(request, callback_uri)
@app.route("/auth")
async def auth(request: Request):
    """OAuth callback: persist the user info in the session, then go home."""
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError:
        # Login failed or was cancelled; land on the public page unauthenticated.
        return RedirectResponse(url="/")
    request.session["user"] = dict(token)["userinfo"]
    return RedirectResponse(url="/")
# Mount the gradio UI at /app; allowed_paths lets gradio serve static
# assets and temp files from outside its own directory.
app = gr.mount_gradio_app(
    app,
    main_demo,
    path="/app",
    allowed_paths=[
        "libs/ktem/ktem/assets",
        GRADIO_TEMP_DIR,
    ],
)