From 3bd3830b8d64c6e813fb55c2fb8afbe687c7b55a Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Sun, 2 Feb 2025 15:19:48 +0700 Subject: [PATCH] feat: sso login, demo mode & new mindmap support (#644) bump:minor * fix: update .env.example * feat: add SSO login * fix: update flowsetting * fix: add requirement * fix: refine UI * fix: update group id-based operation * fix: improve citation logics * fix: UI enhancement * fix: user_id to string in models * fix: improve chat suggestion UI and flow * fix: improve group id handling * fix: improve chat suggestion * fix: secure download for single file * fix: file limiting in docstore * fix: improve chat suggestion logics & language conform * feat: add markmap and select text to highlight function * fix: update Dockerfile * fix: user id auto generate * fix: default user id * feat: add demo mode * fix: update flowsetting * fix: revise default params for demo * feat: sso_app alternative * feat: sso login demo * feat: demo specific customization * feat: add login using API key * fix: disable key-based login * fix: optimize duplicate upload * fix: gradio routing * fix: disable arm build for demo * fix: revise full-text search js logic * feat: add rate limit * fix: update Dockerfile with new launch script * fix: update Dockerfile * fix: update Dockerignore * fix: update ratelimit logic * fix: user_id in user management page * fix: rename conv logic * feat: update demo hint * fix: minor fix * fix: highlight on long PDF load * feat: add HF paper list * fix: update HF papers load logic * feat: fly config * fix: update fly config * fix: update paper list pull api * fix: minor update root routing * fix: minor update root routing * fix: simplify login flow & paper list UI * feat: add paper recommendation * fix: update Dockerfile * fix: update Dockerfile * fix: update default model * feat: add long context Ollama through LCOllama * feat: espose Gradio share to env * fix: revert customized changes * fix: list group at app load * fix: relocate share conv button * fix: update launch script * fix: update Docker CI * feat: add Ollama model selection at first setup * docs: update README --- .dockerignore | 2 + .env.example | 8 +- .github/workflows/build-push-docker.yaml | 1 + .github/workflows/fly-deploy.yml | 18 + .pre-commit-config.yaml | 1 + Dockerfile | 19 +- README.md | 22 +- app.py | 2 + docs/about.md | 4 +- docs/online_install.md | 2 +- flowsettings.py | 55 +- fly.toml | 26 + launch.sh | 23 + .../kotaemon/indices/qa/citation_qa.py | 9 +- libs/kotaemon/kotaemon/indices/qa/utils.py | 7 +- libs/kotaemon/kotaemon/indices/vectorindex.py | 13 +- libs/kotaemon/kotaemon/llms/__init__.py | 2 + libs/kotaemon/kotaemon/llms/chats/__init__.py | 2 + .../kotaemon/llms/chats/langchain_based.py | 37 + libs/kotaemon/kotaemon/loaders/pdf_loader.py | 5 +- libs/kotaemon/pyproject.toml | 1 + libs/ktem/ktem/app.py | 22 +- libs/ktem/ktem/assets/css/main.css | 47 +- libs/ktem/ktem/assets/js/main.js | 243 ++++- libs/ktem/ktem/assets/js/pdf_viewer.js | 276 ++--- libs/ktem/ktem/db/base_models.py | 10 +- libs/ktem/ktem/index/file/index.py | 24 +- libs/ktem/ktem/index/file/ui.py | 557 ++++++---- libs/ktem/ktem/index/file/utils.py | 58 ++ libs/ktem/ktem/llms/manager.py | 2 + libs/ktem/ktem/main.py | 41 +- libs/ktem/ktem/pages/chat/__init__.py | 966 +++++++++++------- libs/ktem/ktem/pages/chat/chat_panel.py | 21 +- libs/ktem/ktem/pages/chat/chat_suggestion.py | 25 +- libs/ktem/ktem/pages/chat/control.py | 166 ++- libs/ktem/ktem/pages/chat/demo_hint.py | 23 + libs/ktem/ktem/pages/chat/paper_list.py | 41 + libs/ktem/ktem/pages/chat/report.py | 2 +- libs/ktem/ktem/pages/help.py | 22 +- libs/ktem/ktem/pages/login.py | 55 +- libs/ktem/ktem/pages/resources/user.py | 16 +- libs/ktem/ktem/pages/settings.py | 93 +- libs/ktem/ktem/pages/setup.py | 44 +- .../reasoning/prompt_optimization/mindmap.py | 47 +- .../suggest_followup_chat.py | 4 +- libs/ktem/ktem/reasoning/simple.py | 42 +- libs/ktem/ktem/utils/hf_papers.py | 114 +++ libs/ktem/ktem/utils/rate_limit.py | 48 + libs/ktem/ktem/utils/render.py | 8 +- libs/ktem/pyproject.toml | 1 + sso_app.py | 51 + sso_app_demo.py | 97 ++ 52 files changed, 2488 insertions(+), 937 deletions(-) create mode 100644 .github/workflows/fly-deploy.yml create mode 100644 fly.toml create mode 100755 launch.sh create mode 100644 libs/ktem/ktem/index/file/utils.py create mode 100644 libs/ktem/ktem/pages/chat/demo_hint.py create mode 100644 libs/ktem/ktem/pages/chat/paper_list.py create mode 100644 libs/ktem/ktem/utils/hf_papers.py create mode 100644 libs/ktem/ktem/utils/rate_limit.py create mode 100644 sso_app.py create mode 100644 sso_app_demo.py diff --git a/.dockerignore b/.dockerignore index 0a57312..079881b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,3 +11,5 @@ env/ README.md *.zip *.sh + +!/launch.sh diff --git a/.env.example b/.env.example index c92d511..66d80bd 100644 --- a/.env.example +++ b/.env.example @@ -3,8 +3,8 @@ # settings for OpenAI OPENAI_API_BASE=https://api.openai.com/v1 OPENAI_API_KEY= -OPENAI_CHAT_MODEL=gpt-3.5-turbo -OPENAI_EMBEDDINGS_MODEL=text-embedding-ada-002 +OPENAI_CHAT_MODEL=gpt-4o-mini +OPENAI_EMBEDDINGS_MODEL=text-embedding-3-large # settings for Azure OpenAI AZURE_OPENAI_ENDPOINT= @@ -17,10 +17,8 @@ AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=text-embedding-ada-002 COHERE_API_KEY= # settings for local models -LOCAL_MODEL=llama3.1:8b +LOCAL_MODEL=qwen2.5:7b LOCAL_MODEL_EMBEDDINGS=nomic-embed-text -LOCAL_EMBEDDING_MODEL_DIM = 768 -LOCAL_EMBEDDING_MODEL_MAX_TOKENS = 8192 # settings for GraphRAG GRAPHRAG_API_KEY= diff --git a/.github/workflows/build-push-docker.yaml b/.github/workflows/build-push-docker.yaml index ee151a9..f5c8c10 100644 --- a/.github/workflows/build-push-docker.yaml +++ b/.github/workflows/build-push-docker.yaml @@ -28,6 +28,7 @@ jobs: target: - lite - full + - ollama steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main diff --git a/.github/workflows/fly-deploy.yml b/.github/workflows/fly-deploy.yml new file mode 100644 index 0000000..0b4beb5 --- /dev/null +++ b/.github/workflows/fly-deploy.yml @@ -0,0 +1,18 @@ +# See https://fly.io/docs/app-guides/continuous-deployment-with-github-actions/ + +name: Fly Deploy +on: + push: + branches: + - main +jobs: + deploy: + name: Deploy app + runs-on: ubuntu-latest + concurrency: deploy-group # optional: ensure only one action runs at a time + steps: + - uses: actions/checkout@v4 + - uses: superfly/flyctl-actions/setup-flyctl@master + - run: flyctl deploy --remote-only + env: + FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 514991d..11a82ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,6 +57,7 @@ repos: "types-requests", "sqlmodel", "types-Markdown", + "types-cachetools", types-tzlocal, ] args: ["--check-untyped-defs", "--ignore-missing-imports"] diff --git a/Dockerfile b/Dockerfile index 714d0ef..caedcec 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,6 +35,7 @@ RUN bash scripts/download_pdfjs.sh $PDFJS_PREBUILT_DIR # Copy contents COPY . /app +COPY launch.sh /app/launch.sh COPY .env.example /app/.env # Install pip packages @@ -54,7 +55,7 @@ RUN apt-get autoremove \ && rm -rf /var/lib/apt/lists/* \ && rm -rf ~/.cache -CMD ["python", "app.py"] +ENTRYPOINT ["sh", "/app/launch.sh"] # Full version FROM lite AS full @@ -97,7 +98,17 @@ RUN apt-get autoremove \ && rm -rf /var/lib/apt/lists/* \ && rm -rf ~/.cache -# Download nltk packages as required for unstructured -# RUN python -c "from unstructured.nlp.tokenize import _download_nltk_packages_if_not_present; _download_nltk_packages_if_not_present()" +ENTRYPOINT ["sh", "/app/launch.sh"] -CMD ["python", "app.py"] +# Ollama-bundled version +FROM full AS ollama + +# Install ollama +RUN --mount=type=ssh \ + --mount=type=cache,target=/root/.cache/pip \ + curl -fsSL https://ollama.com/install.sh | sh + +# RUN nohup bash -c "ollama serve &" && sleep 4 && ollama pull qwen2.5:7b +RUN nohup bash -c "ollama serve &" && sleep 4 && ollama pull nomic-embed-text + +ENTRYPOINT ["sh", "/app/launch.sh"] diff --git a/README.md b/README.md index a900f34..9374800 100644 --- a/README.md +++ b/README.md @@ -96,18 +96,7 @@ documents and developers who want to build their own RAG pipeline. ### With Docker (recommended) -1. We support both `lite` & `full` version of Docker images. With `full`, the extra packages of `unstructured` will be installed as well, it can support additional file types (`.doc`, `.docx`, ...) but the cost is larger docker image size. For most users, the `lite` image should work well in most cases. - - - To use the `lite` version. - - ```bash - docker run \ - -e GRADIO_SERVER_NAME=0.0.0.0 \ - -e GRADIO_SERVER_PORT=7860 \ - -v ./ktem_app_data:/app/ktem_app_data \ - -p 7860:7860 -it --rm \ - ghcr.io/cinnamon/kotaemon:main-lite - ``` +1. We support both `lite` & `full` version of Docker images. With `full` version, the extra packages of `unstructured` will be installed, which can support additional file types (`.doc`, `.docx`, ...) but the cost is larger docker image size. For most users, the `lite` image should work well in most cases. - To use the `full` version. @@ -124,9 +113,16 @@ documents and developers who want to build their own RAG pipeline. ```bash # change image name to - ghcr.io/cinnamon/kotaemon:feat-ollama_docker-full + docker run <...> ghcr.io/cinnamon/kotaemon:main-ollama ``` + - To use the `lite` version. + + ```bash + # change image name to + docker run <...> ghcr.io/cinnamon/kotaemon:main-lite + ``` + 2. We currently support and test two platforms: `linux/amd64` and `linux/arm64` (for newer Mac). You can specify the platform by passing `--platform` in the `docker run` command. For example: ```bash diff --git a/app.py b/app.py index a432d17..393abb3 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,7 @@ import os from theflow.settings import settings as flowsettings KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".") +KH_GRADIO_SHARE = getattr(flowsettings, "KH_GRADIO_SHARE", False) GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None) # override GRADIO_TEMP_DIR if it's not set if GRADIO_TEMP_DIR is None: @@ -21,4 +22,5 @@ demo.queue().launch( "libs/ktem/ktem/assets", GRADIO_TEMP_DIR, ], + share=KH_GRADIO_SHARE, ) diff --git a/docs/about.md b/docs/about.md index 757dc61..afc0a90 100644 --- a/docs/about.md +++ b/docs/about.md @@ -4,8 +4,8 @@ An open-source tool for chatting with your documents. Built with both end users developers in mind. [Source Code](https://github.com/Cinnamon/kotaemon) | -[Live Demo](https://huggingface.co/spaces/cin-model/kotaemon-demo) +[HF Space](https://huggingface.co/spaces/cin-model/kotaemon-demo) -[User Guide](https://cinnamon.github.io/kotaemon/) | +[Installation Guide](https://cinnamon.github.io/kotaemon/) | [Developer Guide](https://cinnamon.github.io/kotaemon/development/) | [Feedback](https://github.com/Cinnamon/kotaemon/issues) diff --git a/docs/online_install.md b/docs/online_install.md index 9f7eaf4..c5e55ca 100644 --- a/docs/online_install.md +++ b/docs/online_install.md @@ -1,7 +1,7 @@ ## Installation (Online HuggingFace Space) 1. Go to [HF kotaemon_template](https://huggingface.co/spaces/cin-model/kotaemon_template). -2. Use Duplicate function to create your own space. +2. Use Duplicate function to create your own space. Or use this [direct link](https://huggingface.co/spaces/cin-model/kotaemon_template?duplicate=true). ![Duplicate space](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/duplicate_space.png) ![Change space params](https://raw.githubusercontent.com/Cinnamon/kotaemon/main/docs/images/change_space_params.png) 3. Wait for the build to complete and start up (apprx 10 mins). diff --git a/flowsettings.py b/flowsettings.py index 0962eef..6abeea2 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -25,7 +25,8 @@ if not KH_APP_VERSION: except Exception: KH_APP_VERSION = "local" -KH_ENABLE_FIRST_SETUP = True +KH_GRADIO_SHARE = config("KH_GRADIO_SHARE", default=False, cast=bool) +KH_ENABLE_FIRST_SETUP = config("KH_ENABLE_FIRST_SETUP", default=True, cast=bool) KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool) KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/") @@ -65,6 +66,8 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface") KH_DOC_DIR = this_dir / "docs" KH_MODE = "dev" +KH_SSO_ENABLED = config("KH_SSO_ENABLED", default=False, cast=bool) + KH_FEATURE_CHAT_SUGGESTION = config( "KH_FEATURE_CHAT_SUGGESTION", default=False, cast=bool ) @@ -137,31 +140,36 @@ if config("AZURE_OPENAI_API_KEY", default="") and config( "default": False, } -if config("OPENAI_API_KEY", default=""): +OPENAI_DEFAULT = "" +OPENAI_API_KEY = config("OPENAI_API_KEY", default=OPENAI_DEFAULT) +GOOGLE_API_KEY = config("GOOGLE_API_KEY", default="your-key") +IS_OPENAI_DEFAULT = len(OPENAI_API_KEY) > 0 and OPENAI_API_KEY != OPENAI_DEFAULT + +if OPENAI_API_KEY: KH_LLMS["openai"] = { "spec": { "__type__": "kotaemon.llms.ChatOpenAI", "temperature": 0, "base_url": config("OPENAI_API_BASE", default="") or "https://api.openai.com/v1", - "api_key": config("OPENAI_API_KEY", default=""), - "model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"), + "api_key": OPENAI_API_KEY, + "model": config("OPENAI_CHAT_MODEL", default="gpt-4o-mini"), "timeout": 20, }, - "default": True, + "default": IS_OPENAI_DEFAULT, } KH_EMBEDDINGS["openai"] = { "spec": { "__type__": "kotaemon.embeddings.OpenAIEmbeddings", "base_url": config("OPENAI_API_BASE", default="https://api.openai.com/v1"), - "api_key": config("OPENAI_API_KEY", default=""), + "api_key": OPENAI_API_KEY, "model": config( - "OPENAI_EMBEDDINGS_MODEL", default="text-embedding-ada-002" + "OPENAI_EMBEDDINGS_MODEL", default="text-embedding-3-large" ), "timeout": 10, "context_length": 8191, }, - "default": True, + "default": IS_OPENAI_DEFAULT, } if config("LOCAL_MODEL", default=""): @@ -169,11 +177,21 @@ if config("LOCAL_MODEL", default=""): "spec": { "__type__": "kotaemon.llms.ChatOpenAI", "base_url": KH_OLLAMA_URL, - "model": config("LOCAL_MODEL", default="llama3.1:8b"), + "model": config("LOCAL_MODEL", default="qwen2.5:7b"), "api_key": "ollama", }, "default": False, } + KH_LLMS["ollama-long-context"] = { + "spec": { + "__type__": "kotaemon.llms.LCOllamaChat", + "base_url": KH_OLLAMA_URL.replace("v1/", ""), + "model": config("LOCAL_MODEL", default="qwen2.5:7b"), + "num_ctx": 8192, + }, + "default": False, + } + KH_EMBEDDINGS["ollama"] = { "spec": { "__type__": "kotaemon.embeddings.OpenAIEmbeddings", @@ -183,7 +201,6 @@ if config("LOCAL_MODEL", default=""): }, "default": False, } - KH_EMBEDDINGS["fast_embed"] = { "spec": { "__type__": "kotaemon.embeddings.FastEmbedEmbeddings", @@ -205,9 +222,9 @@ KH_LLMS["google"] = { "spec": { "__type__": "kotaemon.llms.chats.LCGeminiChat", "model_name": "gemini-1.5-flash", - "api_key": config("GOOGLE_API_KEY", default="your-key"), + "api_key": GOOGLE_API_KEY, }, - "default": False, + "default": not IS_OPENAI_DEFAULT, } KH_LLMS["groq"] = { "spec": { @@ -241,8 +258,9 @@ KH_EMBEDDINGS["google"] = { "spec": { "__type__": "kotaemon.embeddings.LCGoogleEmbeddings", "model": "models/text-embedding-004", - "google_api_key": config("GOOGLE_API_KEY", default="your-key"), - } + "google_api_key": GOOGLE_API_KEY, + }, + "default": not IS_OPENAI_DEFAULT, } # KH_EMBEDDINGS["huggingface"] = { # "spec": { @@ -301,9 +319,12 @@ SETTINGS_REASONING = { USE_NANO_GRAPHRAG = config("USE_NANO_GRAPHRAG", default=False, cast=bool) USE_LIGHTRAG = config("USE_LIGHTRAG", default=True, cast=bool) +USE_MS_GRAPHRAG = config("USE_MS_GRAPHRAG", default=True, cast=bool) -GRAPHRAG_INDEX_TYPES = ["ktem.index.file.graph.GraphRAGIndex"] +GRAPHRAG_INDEX_TYPES = [] +if USE_MS_GRAPHRAG: + GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.GraphRAGIndex") if USE_NANO_GRAPHRAG: GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex") if USE_LIGHTRAG: @@ -323,7 +344,7 @@ GRAPHRAG_INDICES = [ ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " ".pptx, .csv, .html, .mhtml, .txt, .md, .zip" ), - "private": False, + "private": True, }, "index_type": graph_type, } @@ -338,7 +359,7 @@ KH_INDICES = [ ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " ".pptx, .csv, .html, .mhtml, .txt, .md, .zip" ), - "private": False, + "private": True, }, "index_type": "ktem.index.file.FileIndex", }, diff --git a/fly.toml b/fly.toml new file mode 100644 index 0000000..b9c37ac --- /dev/null +++ b/fly.toml @@ -0,0 +1,26 @@ +# fly.toml app configuration file generated for kotaemon on 2024-12-24T20:56:32+07:00 +# +# See https://fly.io/docs/reference/configuration/ for information about how to use this file. +# + +app = 'kotaemon' +primary_region = 'sin' + +[build] + +[mounts] + destination = "/app/ktem_app_data" + source = "ktem_volume" + +[http_service] + internal_port = 7860 + force_https = true + auto_stop_machines = 'suspend' + auto_start_machines = true + min_machines_running = 0 + processes = ['app'] + +[[vm]] + memory = '4gb' + cpu_kind = 'shared' + cpus = 4 diff --git a/launch.sh b/launch.sh new file mode 100755 index 0000000..08b3a59 --- /dev/null +++ b/launch.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +if [ -z "$GRADIO_SERVER_NAME" ]; then + export GRADIO_SERVER_NAME="0.0.0.0" +fi +if [ -z "$GRADIO_SERVER_PORT" ]; then + export GRADIO_SERVER_PORT="7860" +fi + +# Check if environment variable KH_DEMO_MODE is set to true +if [ "$KH_DEMO_MODE" = "true" ]; then + echo "KH_DEMO_MODE is true. Launching in demo mode..." + # Command to launch in demo mode + GR_FILE_ROOT_PATH="/app" KH_FEATURE_USER_MANAGEMENT=false USE_LIGHTRAG=false uvicorn sso_app_demo:app --host "$GRADIO_SERVER_NAME" --port "$GRADIO_SERVER_PORT" +else + if [ "$KH_SSO_ENABLED" = "true" ]; then + echo "KH_SSO_ENABLED is true. Launching in SSO mode..." + GR_FILE_ROOT_PATH="/app" KH_SSO_ENABLED=true uvicorn sso_app:app --host "$GRADIO_SERVER_NAME" --port "$GRADIO_SERVER_PORT" + else + ollama serve & + python app.py + fi +fi diff --git a/libs/kotaemon/kotaemon/indices/qa/citation_qa.py b/libs/kotaemon/kotaemon/indices/qa/citation_qa.py index efe34a2..b565632 100644 --- a/libs/kotaemon/kotaemon/indices/qa/citation_qa.py +++ b/libs/kotaemon/kotaemon/indices/qa/citation_qa.py @@ -3,6 +3,7 @@ from collections import defaultdict from typing import Generator import numpy as np +from decouple import config from theflow.settings import settings as flowsettings from kotaemon.base import ( @@ -32,7 +33,9 @@ except ImportError: MAX_IMAGES = 10 CITATION_TIMEOUT = 5.0 -CONTEXT_RELEVANT_WARNING_SCORE = 0.7 +CONTEXT_RELEVANT_WARNING_SCORE = config( + "CONTEXT_RELEVANT_WARNING_SCORE", 0.3, cast=float +) DEFAULT_QA_TEXT_PROMPT = ( "Use the following pieces of context to answer the question at the end in detail with clear explanation. " # noqa: E501 @@ -385,7 +388,9 @@ class AnswerWithContextPipeline(BaseComponent): doc = id2docs[id_] doc_score = doc.metadata.get("llm_trulens_score", 0.0) is_open = not has_llm_score or ( - doc_score > CONTEXT_RELEVANT_WARNING_SCORE and len(with_citation) == 0 + doc_score + > CONTEXT_RELEVANT_WARNING_SCORE + # and len(with_citation) == 0 ) without_citation.append( Document( diff --git a/libs/kotaemon/kotaemon/indices/qa/utils.py b/libs/kotaemon/kotaemon/indices/qa/utils.py index d64fb20..c961f4f 100644 --- a/libs/kotaemon/kotaemon/indices/qa/utils.py +++ b/libs/kotaemon/kotaemon/indices/qa/utils.py @@ -2,6 +2,8 @@ from difflib import SequenceMatcher def find_text(search_span, context, min_length=5): + search_span, context = search_span.lower(), context.lower() + sentence_list = search_span.split("\n") context = context.replace("\n", " ") @@ -18,7 +20,7 @@ def find_text(search_span, context, min_length=5): matched_blocks = [] for _, start, length in match_results: - if length > max(len(sentence) * 0.2, min_length): + if length > max(len(sentence) * 0.25, min_length): matched_blocks.append((start, start + length)) if matched_blocks: @@ -42,6 +44,9 @@ def find_text(search_span, context, min_length=5): def find_start_end_phrase( start_phrase, end_phrase, context, min_length=5, max_excerpt_length=300 ): + start_phrase, end_phrase = start_phrase.lower(), end_phrase.lower() + context = context.lower() + context = context.replace("\n", " ") matches = [] diff --git a/libs/kotaemon/kotaemon/indices/vectorindex.py b/libs/kotaemon/kotaemon/indices/vectorindex.py index 5bf77c3..bc28215 100644 --- a/libs/kotaemon/kotaemon/indices/vectorindex.py +++ b/libs/kotaemon/kotaemon/indices/vectorindex.py @@ -177,7 +177,11 @@ class VectorRetrieval(BaseRetrieval): ] elif self.retrieval_mode == "text": query = text.text if isinstance(text, Document) else text - docs = self.doc_store.query(query, top_k=top_k_first_round, doc_ids=scope) + docs = [] + if scope: + docs = self.doc_store.query( + query, top_k=top_k_first_round, doc_ids=scope + ) result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs] elif self.retrieval_mode == "hybrid": # similarity search section @@ -206,9 +210,10 @@ class VectorRetrieval(BaseRetrieval): assert self.doc_store is not None query = text.text if isinstance(text, Document) else text - ds_docs = self.doc_store.query( - query, top_k=top_k_first_round, doc_ids=scope - ) + if scope: + ds_docs = self.doc_store.query( + query, top_k=top_k_first_round, doc_ids=scope + ) vs_query_thread = threading.Thread(target=query_vectorstore) ds_query_thread = threading.Thread(target=query_docstore) diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index d48e418..e7ddfbf 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -12,6 +12,7 @@ from .chats import ( LCChatOpenAI, LCCohereChat, LCGeminiChat, + LCOllamaChat, LlamaCppChat, ) from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI @@ -33,6 +34,7 @@ __all__ = [ "LCAnthropicChat", "LCGeminiChat", "LCCohereChat", + "LCOllamaChat", "LCAzureChatOpenAI", "LCChatOpenAI", "LlamaCppChat", diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 5585fba..2581356 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -7,6 +7,7 @@ from .langchain_based import ( LCChatOpenAI, LCCohereChat, LCGeminiChat, + LCOllamaChat, ) from .llamacpp import LlamaCppChat from .openai import AzureChatOpenAI, ChatOpenAI @@ -20,6 +21,7 @@ __all__ = [ "LCAnthropicChat", "LCGeminiChat", "LCCohereChat", + "LCOllamaChat", "LCChatOpenAI", "LCAzureChatOpenAI", "LCChatMixin", diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index bb98332..43daa76 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -358,3 +358,40 @@ class LCCohereChat(LCChatMixin, ChatLLM): # type: ignore raise ImportError("Please install langchain-cohere") return ChatCohere + + +class LCOllamaChat(LCChatMixin, ChatLLM): # type: ignore + base_url: str = Param( + help="Base Ollama URL. (default: http://localhost:11434/api/)", # noqa + required=True, + ) + model: str = Param( + help="Model name to use (https://ollama.com/library)", + required=True, + ) + num_ctx: int = Param( + help="The size of the context window (default: 8192)", + required=True, + ) + + def __init__( + self, + model: str | None = None, + base_url: str | None = None, + num_ctx: int | None = None, + **params, + ): + super().__init__( + base_url=base_url, + model=model, + num_ctx=num_ctx, + **params, + ) + + def _get_lc_class(self): + try: + from langchain_ollama import ChatOllama + except ImportError: + raise ImportError("Please install langchain-ollama") + + return ChatOllama diff --git a/libs/kotaemon/kotaemon/loaders/pdf_loader.py b/libs/kotaemon/kotaemon/loaders/pdf_loader.py index ecba89d..0a60197 100644 --- a/libs/kotaemon/kotaemon/loaders/pdf_loader.py +++ b/libs/kotaemon/kotaemon/loaders/pdf_loader.py @@ -3,15 +3,18 @@ from io import BytesIO from pathlib import Path from typing import Dict, List, Optional +from decouple import config from fsspec import AbstractFileSystem from llama_index.readers.file import PDFReader from PIL import Image from kotaemon.base import Document +PDF_LOADER_DPI = config("PDF_LOADER_DPI", default=40, cast=int) + def get_page_thumbnails( - file_path: Path, pages: list[int], dpi: int = 80 + file_path: Path, pages: list[int], dpi: int = PDF_LOADER_DPI ) -> List[Image.Image]: """Get image thumbnails of the pages in the PDF file. diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index 2e60886..9791724 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "langchain-openai>=0.1.4,<0.2.0", "langchain-google-genai>=1.0.3,<2.0.0", "langchain-anthropic", + "langchain-ollama", "langchain-cohere>=0.2.4,<0.3.0", "llama-hub>=0.0.79,<0.1.0", "llama-index>=0.10.40,<0.11.0", diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py index c4dce35..4ade819 100644 --- a/libs/ktem/ktem/app.py +++ b/libs/ktem/ktem/app.py @@ -13,7 +13,7 @@ from ktem.settings import BaseSettingGroup, SettingGroup, SettingReasoningGroup from theflow.settings import settings from theflow.utils.modules import import_dotted_string -BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "") +BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "") class BaseApp: @@ -57,7 +57,7 @@ class BaseApp: self._pdf_view_js = self._pdf_view_js.replace( "PDFJS_PREBUILT_DIR", pdf_js_dist_dir, - ).replace("GRADIO_ROOT_PATH", BASE_PATH) + ).replace("GR_FILE_ROOT_PATH", BASE_PATH) with (dir_assets / "js" / "svg-pan-zoom.min.js").open() as fi: self._svg_js = fi.read() @@ -79,7 +79,7 @@ class BaseApp: self.default_settings.index.finalize() self.settings_state = gr.State(self.default_settings.flatten()) - self.user_id = gr.State(1 if not self.f_user_management else None) + self.user_id = gr.State("default" if not self.f_user_management else None) def initialize_indices(self): """Create the index manager, start indices, and register to app settings""" @@ -173,15 +173,25 @@ class BaseApp: """Called when the app is created""" def make(self): + markmap_js = """ + + """ external_js = ( "" - "" "" # noqa + "" # noqa "" "" # noqa ) diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css index 95b1f63..5a272e6 100644 --- a/libs/ktem/ktem/assets/css/main.css +++ b/libs/ktem/ktem/assets/css/main.css @@ -326,7 +326,12 @@ pdfjs-viewer-element { /* Switch checkbox styles */ -#is-public-checkbox { +/* #is-public-checkbox { + position: relative; + top: 4px; +} */ + +#suggest-chat-checkbox { position: relative; top: 4px; } @@ -411,3 +416,43 @@ details.evidence { tbody:not(.row_odd) { background: var(--table-even-background-fill); } + +#chat-suggestion { + max-height: 350px; +} + +#chat-suggestion table { + overflow: hidden; +} + +#chat-suggestion table thead { + display: none; +} + +#paper-suggestion table { + overflow: hidden; +} + +svg.markmap { + width: 100%; + height: 100%; + font-family: Quicksand, sans-serif; + font-size: 15px; +} + +div.markmap { + height: 400px; +} + +#google-login { + max-width: 450px; +} + +#user-api-key-wrapper { + max-width: 450px; +} + +#login-row { + display: grid; + place-items: center; +} diff --git a/libs/ktem/ktem/assets/js/main.js b/libs/ktem/ktem/assets/js/main.js index ad3c991..8e63853 100644 --- a/libs/ktem/ktem/assets/js/main.js +++ b/libs/ktem/ktem/assets/js/main.js @@ -11,10 +11,25 @@ function run() { version_node.style = "position: fixed; top: 10px; right: 10px;"; main_parent.appendChild(version_node); + // add favicon + const favicon = document.createElement("link"); + // set favicon attributes + favicon.rel = "icon"; + favicon.type = "image/svg+xml"; + favicon.href = "/favicon.ico"; + document.head.appendChild(favicon); + + // setup conversation dropdown placeholder + let conv_dropdown = document.querySelector("#conversation-dropdown input"); + conv_dropdown.placeholder = "Browse conversation"; + // move info-expand-button let info_expand_button = document.getElementById("info-expand-button"); let chat_info_panel = document.getElementById("info-expand"); - chat_info_panel.insertBefore(info_expand_button, chat_info_panel.childNodes[2]); + chat_info_panel.insertBefore( + info_expand_button, + chat_info_panel.childNodes[2] + ); // move toggle-side-bar button let chat_expand_button = document.getElementById("chat-expand-button"); @@ -24,22 +39,24 @@ function run() { // move setting close button let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav"); let setting_close_button = document.getElementById("save-setting-btn"); - setting_tab_nav_bar.appendChild(setting_close_button); + if (setting_close_button) { + setting_tab_nav_bar.appendChild(setting_close_button); + } let default_conv_column_min_width = "min(300px, 100%)"; - conv_column.style.minWidth = default_conv_column_min_width + conv_column.style.minWidth = default_conv_column_min_width; - globalThis.toggleChatColumn = (() => { + globalThis.toggleChatColumn = () => { /* get flex-grow value of chat_column */ let flex_grow = conv_column.style.flexGrow; - if (flex_grow == '0') { - conv_column.style.flexGrow = '1'; + if (flex_grow == "0") { + conv_column.style.flexGrow = "1"; conv_column.style.minWidth = default_conv_column_min_width; } else { - conv_column.style.flexGrow = '0'; + conv_column.style.flexGrow = "0"; conv_column.style.minWidth = "0px"; } - }); + }; chat_column.insertBefore(chat_expand_button, chat_column.firstChild); @@ -47,22 +64,34 @@ function run() { let mindmap_checkbox = document.getElementById("use-mindmap-checkbox"); let citation_dropdown = document.getElementById("citation-dropdown"); let chat_setting_panel = document.getElementById("chat-settings-expand"); - chat_setting_panel.insertBefore(mindmap_checkbox, chat_setting_panel.childNodes[2]); + chat_setting_panel.insertBefore( + mindmap_checkbox, + chat_setting_panel.childNodes[2] + ); chat_setting_panel.insertBefore(citation_dropdown, mindmap_checkbox); + // move share conv checkbox + let report_div = document.querySelector( + "#report-accordion > div:nth-child(3) > div:nth-child(1)" + ); + let share_conv_checkbox = document.getElementById("is-public-checkbox"); + if (share_conv_checkbox) { + report_div.insertBefore(share_conv_checkbox, report_div.querySelector("button")); + } + // create slider toggle - const is_public_checkbox = document.getElementById("is-public-checkbox"); + const is_public_checkbox = document.getElementById("suggest-chat-checkbox"); const label_element = is_public_checkbox.getElementsByTagName("label")[0]; const checkbox_span = is_public_checkbox.getElementsByTagName("span")[0]; new_div = document.createElement("div"); label_element.classList.add("switch"); is_public_checkbox.appendChild(checkbox_span); - label_element.appendChild(new_div) + label_element.appendChild(new_div); // clpse globalThis.clpseFn = (id) => { - var obj = document.getElementById('clpse-btn-' + id); + var obj = document.getElementById("clpse-btn-" + id); obj.classList.toggle("clpse-active"); var content = obj.nextElementSibling; if (content.style.display === "none") { @@ -70,48 +99,188 @@ function run() { } else { content.style.display = "none"; } - } + }; // store info in local storage globalThis.setStorage = (key, value) => { - localStorage.setItem(key, value) - } + localStorage.setItem(key, value); + }; globalThis.getStorage = (key, value) => { item = localStorage.getItem(key); return item ? item : value; - } + }; globalThis.removeFromStorage = (key) => { - localStorage.removeItem(key) - } + localStorage.removeItem(key); + }; // Function to scroll to given citation with ID // Sleep function using Promise and setTimeout function sleep(ms) { - return new Promise(resolve => setTimeout(resolve, ms)); + return new Promise((resolve) => setTimeout(resolve, ms)); } globalThis.scrollToCitation = async (event) => { - event.preventDefault(); // Prevent the default link behavior - var citationId = event.target.getAttribute('id'); + event.preventDefault(); // Prevent the default link behavior + var citationId = event.target.getAttribute("id"); - await sleep(100); // Sleep for 100 milliseconds + await sleep(100); // Sleep for 100 milliseconds - // check if modal is open - var modal = document.getElementById("pdf-modal"); - var citation = document.querySelector('mark[id="' + citationId + '"]'); + // check if modal is open + var modal = document.getElementById("pdf-modal"); + var citation = document.querySelector('mark[id="' + citationId + '"]'); - if (modal.style.display == "block") { - // trigger on click event of PDF Preview link - var detail_elem = citation; - // traverse up the DOM tree to find the parent element with tag detail - while (detail_elem.tagName.toLowerCase() != "details") { - detail_elem = detail_elem.parentElement; - } - detail_elem.getElementsByClassName("pdf-link").item(0).click(); - } else { - if (citation) { - citation.scrollIntoView({ behavior: 'smooth' }); + if (modal.style.display == "block") { + // trigger on click event of PDF Preview link + var detail_elem = citation; + // traverse up the DOM tree to find the parent element with tag detail + while (detail_elem.tagName.toLowerCase() != "details") { + detail_elem = detail_elem.parentElement; + } + detail_elem.getElementsByClassName("pdf-link").item(0).click(); + } else { + if (citation) { + citation.scrollIntoView({ behavior: "smooth" }); + } + } + }; + + globalThis.fullTextSearch = () => { + // Assign text selection event to last bot message + var bot_messages = document.querySelectorAll( + "div#main-chat-bot div.message-row.bot-row" + ); + var last_bot_message = bot_messages[bot_messages.length - 1]; + + // check if the last bot message has class "text_selection" + if (last_bot_message.classList.contains("text_selection")) { + return; + } + + // assign new class to last message + last_bot_message.classList.add("text_selection"); + + // Get sentences from evidence div + var evidences = document.querySelectorAll( + "#html-info-panel > div:last-child > div > details.evidence div.evidence-content" + ); + console.log("Indexing evidences", evidences); + + const segmenterEn = new Intl.Segmenter("en", { granularity: "sentence" }); + // Split sentences and save to all_segments list + var all_segments = []; + for (var evidence of evidences) { + // check if
tag is open + if (!evidence.parentElement.open) { + continue; + } + var markmap_div = evidence.querySelector("div.markmap"); + if (markmap_div) { + continue; + } + + var evidence_content = evidence.textContent.replace(/[\r\n]+/g, " "); + sentence_it = segmenterEn.segment(evidence_content)[Symbol.iterator](); + while ((sentence = sentence_it.next().value)) { + segment = sentence.segment.trim(); + if (segment) { + all_segments.push({ + id: all_segments.length, + text: segment, + }); } } - } + } + + let miniSearch = new MiniSearch({ + fields: ["text"], // fields to index for full-text search + storeFields: ["text"], + }); + + // Index all documents + miniSearch.addAll(all_segments); + + last_bot_message.addEventListener("mouseup", () => { + let selection = window.getSelection().toString(); + let results = miniSearch.search(selection); + + if (results.length == 0) { + return; + } + let matched_text = results[0].text; + console.log("query\n", selection, "\nmatched text\n", matched_text); + + var evidences = document.querySelectorAll( + "#html-info-panel > div:last-child > div > details.evidence div.evidence-content" + ); + // check if modal is open + var modal = document.getElementById("pdf-modal"); + + // convert all in evidences to normal text + evidences.forEach((evidence) => { + evidence.querySelectorAll("mark").forEach((mark) => { + mark.outerHTML = mark.innerText; + }); + }); + + // highlight matched_text in evidences + for (var evidence of evidences) { + var evidence_content = evidence.textContent.replace(/[\r\n]+/g, " "); + if (evidence_content.includes(matched_text)) { + // select all p and li elements + paragraphs = evidence.querySelectorAll("p, li"); + for (var p of paragraphs) { + var p_content = p.textContent.replace(/[\r\n]+/g, " "); + if (p_content.includes(matched_text)) { + p.innerHTML = p_content.replace( + matched_text, + "" + matched_text + "" + ); + console.log("highlighted", matched_text, "in", p); + if (modal.style.display == "block") { + // trigger on click event of PDF Preview link + var detail_elem = p; + // traverse up the DOM tree to find the parent element with tag detail + while (detail_elem.tagName.toLowerCase() != "details") { + detail_elem = detail_elem.parentElement; + } + detail_elem.getElementsByClassName("pdf-link").item(0).click(); + } else { + p.scrollIntoView({ behavior: "smooth", block: "center" }); + } + break; + } + } + } + } + }); + }; + + globalThis.spawnDocument = (content, options) => { + let opt = { + window: "", + closeChild: true, + childId: "_blank", + }; + Object.assign(opt, options); + // minimal error checking + if ( + content && + typeof content.toString == "function" && + content.toString().length + ) { + let child = window.open("", opt.childId, opt.window); + child.document.write(content.toString()); + if (opt.closeChild) child.document.close(); + return child; + } + }; + + globalThis.fillChatInput = (event) => { + let chatInput = document.querySelector("#chat-input textarea"); + // fill the chat input with the clicked div text + chatInput.value = "Explain " + event.target.textContent; + var evt = new Event("change"); + chatInput.dispatchEvent(new Event("input", { bubbles: true })); + chatInput.focus(); + }; } diff --git a/libs/ktem/ktem/assets/js/pdf_viewer.js b/libs/ktem/ktem/assets/js/pdf_viewer.js index 2166edb..eeec7ba 100644 --- a/libs/ktem/ktem/assets/js/pdf_viewer.js +++ b/libs/ktem/ktem/assets/js/pdf_viewer.js @@ -1,138 +1,186 @@ -function onBlockLoad () { - var infor_panel_scroll_pos = 0; - globalThis.createModal = () => { - // Create modal for the 1st time if it does not exist - var modal = document.getElementById("pdf-modal"); - var old_position = null; - var old_width = null; - var old_left = null; - var expanded = false; +function onBlockLoad() { + var infor_panel_scroll_pos = 0; + globalThis.createModal = () => { + // Create modal for the 1st time if it does not exist + var modal = document.getElementById("pdf-modal"); + var old_position = null; + var old_width = null; + var old_left = null; + var expanded = false; - modal.id = "pdf-modal"; - modal.className = "modal"; - modal.innerHTML = ` + modal.id = "pdf-modal"; + modal.className = "modal"; + modal.innerHTML = ` `; - modal.querySelector("#modal-close").onclick = function() { - modal.style.display = "none"; - var info_panel = document.getElementById("html-info-panel"); - if (info_panel) { - info_panel.style.display = "block"; - } - var scrollableDiv = document.getElementById("chat-info-panel"); - scrollableDiv.scrollTop = infor_panel_scroll_pos; - }; - - modal.querySelector("#modal-expand").onclick = function () { - expanded = !expanded; - if (expanded) { - old_position = modal.style.position; - old_left = modal.style.left; - old_width = modal.style.width; - - modal.style.position = "fixed"; - modal.style.width = "70%"; - modal.style.left = "15%"; - modal.style.height = "100dvh"; - } else { - modal.style.position = old_position; - modal.style.width = old_width; - modal.style.left = old_left; - modal.style.height = "85dvh"; - } - }; - } - - globalThis.compareText = (search_phrase, page_label) => { - var iframe = document.querySelector("#pdf-viewer").iframe; - var innerDoc = (iframe.contentDocument) ? iframe.contentDocument : iframe.contentWindow.document; - - var query_selector = ( - "#viewer > div[data-page-number='" + - page_label + - "'] > div.textLayer > span" - ); - var page_spans = innerDoc.querySelectorAll(query_selector); - for (var i = 0; i < page_spans.length; i++) { - var span = page_spans[i]; - if ( - span.textContent.length > 4 && - ( - search_phrase.includes(span.textContent) || - span.textContent.includes(search_phrase) - ) - ) { - span.innerHTML = "" + span.textContent + ""; - } else { - // if span is already highlighted, remove it - if (span.querySelector(".highlight")) { - span.innerHTML = span.textContent; - } - } - } - } - - // Sleep function using Promise and setTimeout - function sleep(ms) { - return new Promise(resolve => setTimeout(resolve, ms)); - } - - // Function to open modal and display PDF - globalThis.openModal = async (event) => { - event.preventDefault(); - var target = event.currentTarget; - var src = target.getAttribute("data-src"); - var page = target.getAttribute("data-page"); - var search = target.getAttribute("data-search"); - var phrase = target.getAttribute("data-phrase"); - - var pdfViewer = document.getElementById("pdf-viewer"); - - current_src = pdfViewer.getAttribute("src"); - if (current_src != src) { - pdfViewer.setAttribute("src", src); - } - // pdfViewer.setAttribute("phrase", phrase); - // pdfViewer.setAttribute("search", search); - pdfViewer.setAttribute("page", page); - - var scrollableDiv = document.getElementById("chat-info-panel"); - infor_panel_scroll_pos = scrollableDiv.scrollTop; - - var modal = document.getElementById("pdf-modal") - modal.style.display = "block"; + modal.querySelector("#modal-close").onclick = function () { + modal.style.display = "none"; var info_panel = document.getElementById("html-info-panel"); if (info_panel) { - info_panel.style.display = "none"; + info_panel.style.display = "block"; } - scrollableDiv.scrollTop = 0; + var scrollableDiv = document.getElementById("chat-info-panel"); + scrollableDiv.scrollTop = infor_panel_scroll_pos; + }; - /* search for text inside PDF page */ - await sleep(500); - compareText(search, page); + modal.querySelector("#modal-expand").onclick = function () { + expanded = !expanded; + if (expanded) { + old_position = modal.style.position; + old_left = modal.style.left; + old_width = modal.style.width; + + modal.style.position = "fixed"; + modal.style.width = "70%"; + modal.style.left = "15%"; + modal.style.height = "100dvh"; + } else { + modal.style.position = old_position; + modal.style.width = old_width; + modal.style.left = old_left; + modal.style.height = "85dvh"; + } + }; + }; + + function matchRatio(str1, str2) { + let n = str1.length; + let m = str2.length; + + let lcs = []; + for (let i = 0; i <= n; i++) { + lcs[i] = []; + for (let j = 0; j <= m; j++) { + lcs[i][j] = 0; + } } - globalThis.assignPdfOnclickEvent = () => { - // Get all links and attach click event - var links = document.getElementsByClassName("pdf-link"); - for (var i = 0; i < links.length; i++) { - links[i].onclick = openModal; + let result = ""; + let max = 0; + for (let i = 0; i < n; i++) { + for (let j = 0; j < m; j++) { + if (str1[i] === str2[j]) { + lcs[i + 1][j + 1] = lcs[i][j] + 1; + if (lcs[i + 1][j + 1] > max) { + max = lcs[i + 1][j + 1]; + result = str1.substring(i - max + 1, i + 1); + } } + } } - var created_modal = document.getElementById("pdf-viewer"); - if (!created_modal) { - createModal(); + return result.length / Math.min(n, m); + } + + globalThis.compareText = (search_phrases, page_label) => { + var iframe = document.querySelector("#pdf-viewer").iframe; + var innerDoc = iframe.contentDocument + ? iframe.contentDocument + : iframe.contentWindow.document; + + var renderedPages = innerDoc.querySelectorAll("div#viewer div.page"); + if (renderedPages.length == 0) { + // if pages are not rendered yet, wait and try again + setTimeout(() => compareText(search_phrases, page_label), 2000); + return; } + var query_selector = + "#viewer > div[data-page-number='" + + page_label + + "'] > div.textLayer > span"; + var page_spans = innerDoc.querySelectorAll(query_selector); + for (var i = 0; i < page_spans.length; i++) { + var span = page_spans[i]; + if ( + span.textContent.length > 4 && + search_phrases.some( + (phrase) => matchRatio(phrase, span.textContent) > 0.5 + ) + ) { + span.innerHTML = + "" + span.textContent + ""; + } else { + // if span is already highlighted, remove it + if (span.querySelector(".highlight")) { + span.innerHTML = span.textContent; + } + } + } + }; + + // Sleep function using Promise and setTimeout + function sleep(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + // Function to open modal and display PDF + globalThis.openModal = async (event) => { + event.preventDefault(); + var target = event.currentTarget; + var src = target.getAttribute("data-src"); + var page = target.getAttribute("data-page"); + var search = target.getAttribute("data-search"); + var highlighted_spans = + target.parentElement.parentElement.querySelectorAll("mark"); + + // Get text from highlighted spans + var search_phrases = Array.from(highlighted_spans).map( + (span) => span.textContent + ); + // Use regex to strip 【id】from search phrases + search_phrases = search_phrases.map((phrase) => + phrase.replace(/【\d+】/g, "") + ); + + // var phrase = target.getAttribute("data-phrase"); + + var pdfViewer = document.getElementById("pdf-viewer"); + + current_src = pdfViewer.getAttribute("src"); + if (current_src != src) { + pdfViewer.setAttribute("src", src); + } + // pdfViewer.setAttribute("phrase", phrase); + // pdfViewer.setAttribute("search", search); + pdfViewer.setAttribute("page", page); + + var scrollableDiv = document.getElementById("chat-info-panel"); + infor_panel_scroll_pos = scrollableDiv.scrollTop; + + var modal = document.getElementById("pdf-modal"); + modal.style.display = "block"; + var info_panel = document.getElementById("html-info-panel"); + if (info_panel) { + info_panel.style.display = "none"; + } + scrollableDiv.scrollTop = 0; + + /* search for text inside PDF page */ + await sleep(500); + compareText(search_phrases, page); + }; + + globalThis.assignPdfOnclickEvent = () => { + // Get all links and attach click event + var links = document.getElementsByClassName("pdf-link"); + for (var i = 0; i < links.length; i++) { + links[i].onclick = openModal; + } + }; + + var created_modal = document.getElementById("pdf-viewer"); + if (!created_modal) { + createModal(); + } } diff --git a/libs/ktem/ktem/db/base_models.py b/libs/ktem/ktem/db/base_models.py index 7c49705..1cec718 100644 --- a/libs/ktem/ktem/db/base_models.py +++ b/libs/ktem/ktem/db/base_models.py @@ -29,7 +29,7 @@ class BaseConversation(SQLModel): datetime.datetime.now(get_localzone()).strftime("%Y-%m-%d %H:%M:%S") ) ) - user: int = Field(default=0) # For now we only have one user + user: str = Field(default="") # For now we only have one user is_public: bool = Field(default=False) @@ -55,7 +55,9 @@ class BaseUser(SQLModel): __table_args__ = {"extend_existing": True} - id: Optional[int] = Field(default=None, primary_key=True) + id: str = Field( + default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True + ) username: str = Field(unique=True) username_lower: str = Field(unique=True) password: str @@ -76,7 +78,7 @@ class BaseSettings(SQLModel): id: str = Field( default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True ) - user: int = Field(default=0) + user: str = Field(default="") setting: dict = Field(default={}, sa_column=Column(JSON)) @@ -97,4 +99,4 @@ class BaseIssueReport(SQLModel): issues: dict = Field(default={}, sa_column=Column(JSON)) chat: Optional[dict] = Field(default=None, sa_column=Column(JSON)) settings: Optional[dict] = Field(default=None, sa_column=Column(JSON)) - user: Optional[int] = Field(default=None) + user: Optional[str] = Field(default=None) diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py index 9092d48..7409f3c 100644 --- a/libs/ktem/ktem/index/file/index.py +++ b/libs/ktem/ktem/index/file/index.py @@ -17,6 +17,10 @@ from kotaemon.storages import BaseDocumentStore, BaseVectorStore from .base import BaseFileIndexIndexing, BaseFileIndexRetriever +def generate_uuid(): + return str(uuid.uuid4()) + + class FileIndex(BaseIndex): """ File index to store and allow retrieval of files @@ -76,7 +80,7 @@ class FileIndex(BaseIndex): "date_created": Column( DateTime(timezone=True), default=datetime.now(get_localzone()) ), - "user": Column(Integer, default=1), + "user": Column(String, default=""), "note": Column( MutableDict.as_mutable(JSON), # type: ignore default={}, @@ -101,7 +105,7 @@ class FileIndex(BaseIndex): "date_created": Column( DateTime(timezone=True), default=datetime.now(get_localzone()) ), - "user": Column(Integer, default=1), + "user": Column(String, default=""), "note": Column( MutableDict.as_mutable(JSON), # type: ignore default={}, @@ -117,7 +121,7 @@ class FileIndex(BaseIndex): "source_id": Column(String), "target_id": Column(String), "relation_type": Column(String), - "user": Column(Integer, default=1), + "user": Column(String, default=""), }, ) FileGroup = type( @@ -125,12 +129,20 @@ class FileIndex(BaseIndex): (Base,), { "__tablename__": f"index__{self.id}__group", - "id": Column(Integer, primary_key=True, autoincrement=True), + "__table_args__": ( + UniqueConstraint("name", "user", name="_name_user_uc"), + ), + "id": Column( + String, + primary_key=True, + default=lambda: str(uuid.uuid4()), + unique=True, + ), "date_created": Column( DateTime(timezone=True), default=datetime.now(get_localzone()) ), - "name": Column(String, unique=True), - "user": Column(Integer, default=1), + "name": Column(String), + "user": Column(String, default=""), "data": Column( MutableDict.as_mutable(JSON), # type: ignore default={"files": []}, diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index aa2da36..410e315 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -20,9 +20,14 @@ from sqlalchemy.orm import Session from theflow.settings import settings as flowsettings from ...utils.commands import WEB_SEARCH_COMMAND +from ...utils.rate_limit import check_rate_limit +from .utils import download_arxiv_pdf, is_arxiv_url -DOWNLOAD_MESSAGE = "Press again to download" +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False) +DOWNLOAD_MESSAGE = "Start download" MAX_FILENAME_LENGTH = 20 +MAX_FILE_COUNT = 200 chat_input_focus_js = """ function() { @@ -31,6 +36,15 @@ function() { } """ +chat_input_focus_js_with_submit = """ +function() { + let chatInput = document.querySelector("#chat-input textarea"); + let chatInputSubmit = document.querySelector("#chat-input button.submit-button"); + chatInputSubmit.click(); + chatInput.focus(); +} +""" + update_file_list_js = """ function(file_list) { var values = []; @@ -53,6 +67,7 @@ function(file_list) { allowSpaces: true, }) input_box = document.querySelector('#chat-input textarea'); + tribute.detach(input_box); tribute.attach(input_box); } """.replace( @@ -128,7 +143,9 @@ class FileIndexPage(BasePage): # TODO: on_building_ui is not correctly named if it's always called in # the constructor self.public_events = [f"onFileIndex{index.id}Changed"] - self.on_building_ui() + + if not KH_DEMO_MODE: + self.on_building_ui() def upload_instruction(self) -> str: msgs = [] @@ -201,10 +218,10 @@ class FileIndexPage(BasePage): with gr.Accordion("Advance options", open=False): with gr.Row(): - self.download_all_button = gr.DownloadButton( - "Download all files", - visible=True, - ) + if not KH_SSO_ENABLED: + self.download_all_button = gr.DownloadButton( + "Download all files", + ) self.delete_all_button = gr.Button( "Delete all files", variant="stop", @@ -249,13 +266,13 @@ class FileIndexPage(BasePage): ) with gr.Column(visible=False) as self._group_info_panel: + self.selected_group_id = gr.State(value=None) self.group_label = gr.Markdown() self.group_name = gr.Textbox( label="Group name", placeholder="Group name", lines=1, max_lines=1, - interactive=False, ) self.group_files = gr.Dropdown( label="Attached files", @@ -290,7 +307,7 @@ class FileIndexPage(BasePage): ) gr.Markdown("(separated by new line)") - with gr.Accordion("Advanced indexing options", open=True): + with gr.Accordion("Advanced indexing options", open=False): with gr.Row(): self.reindex = gr.Checkbox( value=False, label="Force reindex file", container=False @@ -324,6 +341,9 @@ class FileIndexPage(BasePage): def on_subscribe_public_events(self): """Subscribe to the declared public event of the app""" + if KH_DEMO_MODE: + return + self._app.subscribe_event( name=f"onFileIndex{self._index.id}Changed", definition={ @@ -500,6 +520,34 @@ class FileIndexPage(BasePage): return not is_zipped_state, new_button + def download_single_file_simple(self, is_zipped_state, file_html, file_id): + with Session(engine) as session: + source = session.execute( + select(self._index._resources["Source"]).where( + self._index._resources["Source"].id == file_id + ) + ).first() + if source: + target_file_name = Path(source[0].name) + + # create a temporary file with a path to export + output_file_path = os.path.join( + flowsettings.KH_ZIP_OUTPUT_DIR, target_file_name.stem + ".html" + ) + with open(output_file_path, "w") as f: + f.write(file_html) + + if is_zipped_state: + new_button = gr.DownloadButton(label="Download", value=None) + else: + # export the file path + new_button = gr.DownloadButton( + label=DOWNLOAD_MESSAGE, + value=output_file_path, + ) + + return not is_zipped_state, new_button + def download_all_files(self): if self._index.config.get("private", False): raise gr.Error("This feature is not available for private collection.") @@ -543,8 +591,145 @@ class FileIndexPage(BasePage): gr.update(visible=True), ] + def on_register_quick_uploads(self): + try: + # quick file upload event registration of first Index only + if self._index.id == 1: + self.quick_upload_state = gr.State(value=[]) + print("Setting up quick upload event") + + # override indexing function from chat page + self._app.chat_page.first_indexing_url_fn = ( + self.index_fn_url_with_default_loaders + ) + + if not KH_DEMO_MODE: + quickUploadedEvent = ( + self._app.chat_page.quick_file_upload.upload( + fn=lambda: gr.update( + value="Please wait for the indexing process " + "to complete before adding your question." + ), + outputs=self._app.chat_page.quick_file_upload_status, + ) + .then( + fn=self.index_fn_file_with_default_loaders, + inputs=[ + self._app.chat_page.quick_file_upload, + gr.State(value=False), + self._app.settings_state, + self._app.user_id, + ], + outputs=self.quick_upload_state, + concurrency_limit=10, + ) + .success( + fn=lambda: [ + gr.update(value=None), + gr.update(value="select"), + ], + outputs=[ + self._app.chat_page.quick_file_upload, + self._app.chat_page._indices_input[0], + ], + ) + ) + for event in self._app.get_event( + f"onFileIndex{self._index.id}Changed" + ): + quickUploadedEvent = quickUploadedEvent.then(**event) + + quickUploadedEvent = ( + quickUploadedEvent.success( + fn=lambda x: x, + inputs=self.quick_upload_state, + outputs=self._app.chat_page._indices_input[1], + ) + .then( + fn=lambda: gr.update(value="Indexing completed."), + outputs=self._app.chat_page.quick_file_upload_status, + ) + .then( + fn=self.list_file, + inputs=[self._app.user_id, self.filter], + outputs=[self.file_list_state, self.file_list], + concurrency_limit=20, + ) + .then( + fn=lambda: True, + inputs=None, + outputs=None, + js=chat_input_focus_js_with_submit, + ) + ) + + quickURLUploadedEvent = ( + self._app.chat_page.quick_urls.submit( + fn=lambda: gr.update( + value="Please wait for the indexing process " + "to complete before adding your question." + ), + outputs=self._app.chat_page.quick_file_upload_status, + ) + .then( + fn=self.index_fn_url_with_default_loaders, + inputs=[ + self._app.chat_page.quick_urls, + gr.State(value=False), + self._app.settings_state, + self._app.user_id, + ], + outputs=self.quick_upload_state, + concurrency_limit=10, + ) + .success( + fn=lambda: [ + gr.update(value=None), + gr.update(value="select"), + ], + outputs=[ + self._app.chat_page.quick_urls, + self._app.chat_page._indices_input[0], + ], + ) + ) + for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): + quickURLUploadedEvent = quickURLUploadedEvent.then(**event) + + quickURLUploadedEvent = quickURLUploadedEvent.success( + fn=lambda x: x, + inputs=self.quick_upload_state, + outputs=self._app.chat_page._indices_input[1], + ).then( + fn=lambda: gr.update(value="Indexing completed."), + outputs=self._app.chat_page.quick_file_upload_status, + ) + + if not KH_DEMO_MODE: + quickURLUploadedEvent = quickURLUploadedEvent.then( + fn=self.list_file, + inputs=[self._app.user_id, self.filter], + outputs=[self.file_list_state, self.file_list], + concurrency_limit=20, + ) + + quickURLUploadedEvent = quickURLUploadedEvent.then( + fn=lambda: True, + inputs=None, + outputs=None, + js=chat_input_focus_js_with_submit, + ) + + except Exception as e: + print(e) + def on_register_events(self): """Register all events to the app""" + self.on_register_quick_uploads() + + if KH_DEMO_MODE: + return + onDeleted = ( self.delete_button.click( fn=self.delete_event, @@ -606,12 +791,13 @@ class FileIndexPage(BasePage): ], ) - self.download_all_button.click( - fn=self.download_all_files, - inputs=[], - outputs=self.download_all_button, - show_progress="hidden", - ) + if not KH_SSO_ENABLED: + self.download_all_button.click( + fn=self.download_all_files, + inputs=[], + outputs=self.download_all_button, + show_progress="hidden", + ) self.delete_all_button.click( self.show_delete_all_confirm, @@ -659,12 +845,20 @@ class FileIndexPage(BasePage): ], ) - self.download_single_button.click( - fn=self.download_single_file, - inputs=[self.is_zipped_state, self.selected_file_id], - outputs=[self.is_zipped_state, self.download_single_button], - show_progress="hidden", - ) + if not KH_SSO_ENABLED: + self.download_single_button.click( + fn=self.download_single_file, + inputs=[self.is_zipped_state, self.selected_file_id], + outputs=[self.is_zipped_state, self.download_single_button], + show_progress="hidden", + ) + else: + self.download_single_button.click( + fn=self.download_single_file_simple, + inputs=[self.is_zipped_state, self.chunks, self.selected_file_id], + outputs=[self.is_zipped_state, self.download_single_button], + show_progress="hidden", + ) onUploaded = ( self.upload_button.click( @@ -689,121 +883,6 @@ class FileIndexPage(BasePage): ) ) - try: - # quick file upload event registration of first Index only - if self._index.id == 1: - self.quick_upload_state = gr.State(value=[]) - print("Setting up quick upload event") - - # override indexing function from chat page - self._app.chat_page.first_indexing_url_fn = ( - self.index_fn_url_with_default_loaders - ) - quickUploadedEvent = ( - self._app.chat_page.quick_file_upload.upload( - fn=lambda: gr.update( - value="Please wait for the indexing process " - "to complete before adding your question." - ), - outputs=self._app.chat_page.quick_file_upload_status, - ) - .then( - fn=self.index_fn_file_with_default_loaders, - inputs=[ - self._app.chat_page.quick_file_upload, - gr.State(value=False), - self._app.settings_state, - self._app.user_id, - ], - outputs=self.quick_upload_state, - ) - .success( - fn=lambda: [ - gr.update(value=None), - gr.update(value="select"), - ], - outputs=[ - self._app.chat_page.quick_file_upload, - self._app.chat_page._indices_input[0], - ], - ) - ) - for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): - quickUploadedEvent = quickUploadedEvent.then(**event) - - quickURLUploadedEvent = ( - self._app.chat_page.quick_urls.submit( - fn=lambda: gr.update( - value="Please wait for the indexing process " - "to complete before adding your question." - ), - outputs=self._app.chat_page.quick_file_upload_status, - ) - .then( - fn=self.index_fn_url_with_default_loaders, - inputs=[ - self._app.chat_page.quick_urls, - gr.State(value=True), - self._app.settings_state, - self._app.user_id, - ], - outputs=self.quick_upload_state, - ) - .success( - fn=lambda: [ - gr.update(value=None), - gr.update(value="select"), - ], - outputs=[ - self._app.chat_page.quick_urls, - self._app.chat_page._indices_input[0], - ], - ) - ) - for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): - quickURLUploadedEvent = quickURLUploadedEvent.then(**event) - - quickUploadedEvent.success( - fn=lambda x: x, - inputs=self.quick_upload_state, - outputs=self._app.chat_page._indices_input[1], - ).then( - fn=lambda: gr.update(value="Indexing completed."), - outputs=self._app.chat_page.quick_file_upload_status, - ).then( - fn=self.list_file, - inputs=[self._app.user_id, self.filter], - outputs=[self.file_list_state, self.file_list], - concurrency_limit=20, - ).then( - fn=lambda: True, - inputs=None, - outputs=None, - js=chat_input_focus_js, - ) - - quickURLUploadedEvent.success( - fn=lambda x: x, - inputs=self.quick_upload_state, - outputs=self._app.chat_page._indices_input[1], - ).then( - fn=lambda: gr.update(value="Indexing completed."), - outputs=self._app.chat_page.quick_file_upload_status, - ).then( - fn=self.list_file, - inputs=[self._app.user_id, self.filter], - outputs=[self.file_list_state, self.file_list], - concurrency_limit=20, - ).then( - fn=lambda: True, - inputs=None, - outputs=None, - js=chat_input_focus_js, - ) - - except Exception as e: - print(e) - uploadedEvent = onUploaded.then( fn=self.list_file, inputs=[self._app.user_id, self.filter], @@ -844,7 +923,12 @@ class FileIndexPage(BasePage): self.group_list.select( fn=self.interact_group_list, inputs=[self.group_list_state], - outputs=[self.group_label, self.group_name, self.group_files], + outputs=[ + self.group_label, + self.selected_group_id, + self.group_name, + self.group_files, + ], show_progress="hidden", ).then( fn=lambda: ( @@ -875,8 +959,9 @@ class FileIndexPage(BasePage): gr.update(visible=False), gr.update(value="### Add new group"), gr.update(visible=True), - gr.update(value="", interactive=True), + gr.update(value=""), gr.update(value=[]), + None, ], outputs=[ self.group_add_button, @@ -884,12 +969,13 @@ class FileIndexPage(BasePage): self._group_info_panel, self.group_name, self.group_files, + self.selected_group_id, ], ) self.group_chat_button.click( fn=self.set_group_id_selector, - inputs=[self.group_name], + inputs=[self.selected_group_id], outputs=[ self._index.get_selector_component_ui().selector, self._index.get_selector_component_ui().mode, @@ -897,44 +983,53 @@ class FileIndexPage(BasePage): ], ) + onGroupClosedEvent = { + "fn": lambda: [ + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + None, + ], + "outputs": [ + self.group_add_button, + self._group_info_panel, + self.group_close_button, + self.group_delete_button, + self.group_chat_button, + self.selected_group_id, + ], + } + self.group_close_button.click(**onGroupClosedEvent) onGroupSaved = ( self.group_save_button.click( fn=self.save_group, - inputs=[self.group_name, self.group_files, self._app.user_id], + inputs=[ + self.selected_group_id, + self.group_name, + self.group_files, + self._app.user_id, + ], ) .then( self.list_group, inputs=[self._app.user_id, self.file_list_state], outputs=[self.group_list_state, self.group_list], ) - .then( - fn=lambda: gr.update(visible=False), - outputs=[self._group_info_panel], + .then(**onGroupClosedEvent) + ) + onGroupDeleted = ( + self.group_delete_button.click( + fn=self.delete_group, + inputs=[self.selected_group_id], ) - ) - self.group_close_button.click( - fn=lambda: [ - gr.update(visible=True), - gr.update(visible=False), - gr.update(visible=False), - gr.update(visible=False), - gr.update(visible=False), - ], - outputs=[ - self.group_add_button, - self._group_info_panel, - self.group_close_button, - self.group_delete_button, - self.group_chat_button, - ], - ) - onGroupDeleted = self.group_delete_button.click( - fn=self.delete_group, - inputs=[self.group_name], - ).then( - self.list_group, - inputs=[self._app.user_id, self.file_list_state], - outputs=[self.group_list_state, self.group_list], + .then( + self.list_group, + inputs=[self._app.user_id, self.file_list_state], + outputs=[self.group_list_state, self.group_list], + ) + .then(**onGroupClosedEvent) ) for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): @@ -943,10 +1038,21 @@ class FileIndexPage(BasePage): def _on_app_created(self): """Called when the app is created""" + if KH_DEMO_MODE: + return + self._app.app.load( self.list_file, inputs=[self._app.user_id, self.filter], outputs=[self.file_list_state, self.file_list], + ).then( + self.list_group, + inputs=[self._app.user_id, self.file_list_state], + outputs=[self.group_list_state, self.group_list], + ).then( + self.list_file_names, + inputs=[self.file_list_state], + outputs=[self.group_files], ) def _may_extract_zip(self, files, zip_dir: str): @@ -1089,19 +1195,67 @@ class FileIndexPage(BasePage): return exist_ids + returned_ids - def index_fn_url_with_default_loaders(self, urls, reindex: bool, settings, user_id): - returned_ids = [] + def index_fn_url_with_default_loaders( + self, + urls, + reindex: bool, + settings, + user_id, + request: gr.Request, + ): + if KH_DEMO_MODE: + check_rate_limit("file_upload", request) + + returned_ids: list[str] = [] settings = deepcopy(settings) settings[f"index.options.{self._index.id}.reader_mode"] = "default" settings[f"index.options.{self._index.id}.quick_index_mode"] = True - if urls: - _iter = self.index_fn([], urls, reindex, settings, user_id) - try: - while next(_iter): - pass - except StopIteration as e: - returned_ids = e.value + if KH_DEMO_MODE: + urls_splitted = urls.split("\n") + if not all(is_arxiv_url(url) for url in urls_splitted): + raise ValueError("All URLs must be valid arXiv URLs") + + output_files = [ + download_arxiv_pdf( + url, + output_path=os.environ.get("GRADIO_TEMP_DIR", "/tmp"), + ) + for url in urls_splitted + ] + + exist_ids = [] + to_process_files = [] + for str_file_path in output_files: + file_path = Path(str_file_path) + exist_id = ( + self._index.get_indexing_pipeline(settings, user_id) + .route(file_path) + .get_id_if_exists(file_path) + ) + if exist_id: + exist_ids.append(exist_id) + else: + to_process_files.append(str_file_path) + + returned_ids = [] + if to_process_files: + _iter = self.index_fn(to_process_files, [], reindex, settings, user_id) + try: + while next(_iter): + pass + except StopIteration as e: + returned_ids = e.value + + returned_ids = exist_ids + returned_ids + else: + if urls: + _iter = self.index_fn([], urls, reindex, settings, user_id) + try: + while next(_iter): + pass + except StopIteration as e: + returned_ids = e.value return returned_ids @@ -1254,6 +1408,7 @@ class FileIndexPage(BasePage): return gr.update(choices=file_names) def list_group(self, user_id, file_list): + # supply file_list to display the file names in the group if file_list: file_id_to_name = {item["id"]: item["name"] for item in file_list} else: @@ -1319,27 +1474,42 @@ class FileIndexPage(BasePage): return results, group_list - def set_group_id_selector(self, selected_group_name): + def set_group_id_selector(self, selected_group_id): FileGroup = self._index._resources["FileGroup"] # check if group_name exist with Session(engine) as session: current_group = ( - session.query(FileGroup).filter_by(name=selected_group_name).first() + session.query(FileGroup).filter_by(id=selected_group_id).first() ) file_ids = [json.dumps(current_group.data["files"])] return [file_ids, "select", gr.Tabs(selected="chat-tab")] - def save_group(self, group_name, group_files, user_id): + def save_group(self, group_id, group_name, group_files, user_id): FileGroup = self._index._resources["FileGroup"] current_group = None # check if group_name exist with Session(engine) as session: - current_group = session.query(FileGroup).filter_by(name=group_name).first() + if group_id: + current_group = session.query(FileGroup).filter_by(id=group_id).first() + # update current group with new info + current_group.name = group_name + current_group.data["files"] = group_files # Update the files + session.commit() + else: + current_group = ( + session.query(FileGroup) + .filter_by( + name=group_name, + user=user_id, + ) + .first() + ) + if current_group: + raise gr.Error(f"Group {group_name} already exists") - if not current_group: current_group = FileGroup( name=group_name, data={"files": group_files}, # type: ignore @@ -1347,34 +1517,31 @@ class FileIndexPage(BasePage): ) session.add(current_group) session.commit() - else: - # update current group with new info - current_group.name = group_name - current_group.data["files"] = group_files # Update the files - session.commit() group_id = current_group.id gr.Info(f"Group {group_name} has been saved") return group_id - def delete_group(self, group_name): + def delete_group(self, group_id): + if not group_id: + raise gr.Error("No group is selected") + FileGroup = self._index._resources["FileGroup"] - group_id = None with Session(engine) as session: group = session.execute( - select(FileGroup).where(FileGroup.name == group_name) + select(FileGroup).where(FileGroup.id == group_id) ).first() if group: item = group[0] - group_id = item.id + group_name = item.name session.delete(item) session.commit() gr.Info(f"Group {group_name} has been deleted") else: - raise gr.Error(f"Group {group_name} not found") + raise gr.Error("No group found") - return group_id + return None def interact_file_list(self, list_files, ev: gr.SelectData): if ev.value == "-" and ev.index[0] == 0: @@ -1394,9 +1561,11 @@ class FileIndexPage(BasePage): raise gr.Error("No group is selected") selected_item = list_groups[selected_id] + selected_group_id = selected_item["id"] return ( "### Group Information", - gr.update(value=selected_item["name"], interactive=False), + selected_group_id, + selected_item["name"], selected_item["files"], ) @@ -1525,6 +1694,10 @@ class FileSelector(BasePage): self._index._resources["Source"].user == user_id ) + if KH_DEMO_MODE: + # limit query by MAX_FILE_COUNT + statement = statement.limit(MAX_FILE_COUNT) + results = session.execute(statement).all() for result in results: available_ids.append(result[0].id) diff --git a/libs/ktem/ktem/index/file/utils.py b/libs/ktem/ktem/index/file/utils.py new file mode 100644 index 0000000..83e6742 --- /dev/null +++ b/libs/ktem/ktem/index/file/utils.py @@ -0,0 +1,58 @@ +import os + +import requests + +# regex patterns for Arxiv URL +ARXIV_URL_PATTERNS = [ + "https://arxiv.org/abs/", + "https://arxiv.org/pdf/", +] + +ILLEGAL_NAME_CHARS = ["\\", "/", ":", "*", "?", '"', "<", ">", "|"] + + +def clean_name(name): + for char in ILLEGAL_NAME_CHARS: + name = name.replace(char, "_") + return name + + +def is_arxiv_url(url): + return any(url.startswith(pattern) for pattern in ARXIV_URL_PATTERNS) + + +# download PDF from Arxiv URL +def download_arxiv_pdf(url, output_path): + if not is_arxiv_url(url): + raise ValueError("Invalid Arxiv URL") + + is_abstract_url = "abs" in url + if is_abstract_url: + pdf_url = url.replace("abs", "pdf") + abstract_url = url + else: + pdf_url = url + abstract_url = url.replace("pdf", "abs") + + # get paper name from abstract url + response = requests.get(abstract_url) + + # parse HTML response and get h1.title + from bs4 import BeautifulSoup + + soup = BeautifulSoup(response.content, "html.parser") + name = clean_name( + soup.find("h1", class_="title").text.strip().replace("Title:", "") + ) + if not name: + raise ValueError("Failed to get paper name") + + output_file_path = os.path.join(output_path, name + ".pdf") + # prevent downloading if file already exists + if not os.path.exists(output_file_path): + response = requests.get(pdf_url) + + with open(output_file_path, "wb") as f: + f.write(response.content) + + return output_file_path diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index 829bcae..317d970 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -60,6 +60,7 @@ class LLMManager: LCAnthropicChat, LCCohereChat, LCGeminiChat, + LCOllamaChat, LlamaCppChat, ) @@ -69,6 +70,7 @@ class LLMManager: LCAnthropicChat, LCGeminiChat, LCCohereChat, + LCOllamaChat, LlamaCppChat, ] diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py index deeb415..f9ff6be 100644 --- a/libs/ktem/ktem/main.py +++ b/libs/ktem/ktem/main.py @@ -9,6 +9,7 @@ from ktem.pages.setup import SetupPage from theflow.settings import settings as flowsettings KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False) KH_ENABLE_FIRST_SETUP = getattr(flowsettings, "KH_ENABLE_FIRST_SETUP", False) KH_APP_DATA_EXISTS = getattr(flowsettings, "KH_APP_DATA_EXISTS", True) @@ -19,7 +20,7 @@ if config("KH_FIRST_SETUP", default=False, cast=bool): def toggle_first_setup_visibility(): global KH_APP_DATA_EXISTS - is_first_setup = KH_DEMO_MODE or not KH_APP_DATA_EXISTS + is_first_setup = not KH_DEMO_MODE and not KH_APP_DATA_EXISTS KH_APP_DATA_EXISTS = True return gr.update(visible=is_first_setup), gr.update(visible=not is_first_setup) @@ -70,7 +71,7 @@ class App(BaseApp): "indices-tab", ], id="indices-tab", - visible=not self.f_user_management, + visible=not self.f_user_management and not KH_DEMO_MODE, ) as self._tabs[f"{index.id}-tab"]: page = index.get_index_page_ui() setattr(self, f"_index_{index.id}", page) @@ -80,7 +81,7 @@ class App(BaseApp): elem_id="indices-tab", elem_classes=["fill-main-area-height", "scrollable", "indices-tab"], id="indices-tab", - visible=not self.f_user_management, + visible=not self.f_user_management and not KH_DEMO_MODE, ) as self._tabs["indices-tab"]: for index in self.index_manager.indices: with gr.Tab( @@ -90,23 +91,25 @@ class App(BaseApp): page = index.get_index_page_ui() setattr(self, f"_index_{index.id}", page) - with gr.Tab( - "Resources", - elem_id="resources-tab", - id="resources-tab", - visible=not self.f_user_management, - elem_classes=["fill-main-area-height", "scrollable"], - ) as self._tabs["resources-tab"]: - self.resources_page = ResourcesTab(self) + if not KH_DEMO_MODE: + if not KH_SSO_ENABLED: + with gr.Tab( + "Resources", + elem_id="resources-tab", + id="resources-tab", + visible=not self.f_user_management, + elem_classes=["fill-main-area-height", "scrollable"], + ) as self._tabs["resources-tab"]: + self.resources_page = ResourcesTab(self) - with gr.Tab( - "Settings", - elem_id="settings-tab", - id="settings-tab", - visible=not self.f_user_management, - elem_classes=["fill-main-area-height", "scrollable"], - ) as self._tabs["settings-tab"]: - self.settings_page = SettingsPage(self) + with gr.Tab( + "Settings", + elem_id="settings-tab", + id="settings-tab", + visible=not self.f_user_management, + elem_classes=["fill-main-area-height", "scrollable"], + ) as self._tabs["settings-tab"]: + self.settings_page = SettingsPage(self) with gr.Tab( "Help", diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 7696027..2f9ab94 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -1,5 +1,4 @@ import asyncio -import importlib import json import re from copy import deepcopy @@ -10,6 +9,7 @@ from ktem.app import BasePage from ktem.components import reasonings from ktem.db.models import Conversation, engine from ktem.index.file.ui import File +from ktem.reasoning.prompt_optimization.mindmap import MINDMAP_HTML_EXPORT_TEMPLATE from ktem.reasoning.prompt_optimization.suggest_conversation_name import ( SuggestConvNamePipeline, ) @@ -19,29 +19,41 @@ from ktem.reasoning.prompt_optimization.suggest_followup_chat import ( from plotly.io import from_json from sqlmodel import Session, select from theflow.settings import settings as flowsettings +from theflow.utils.modules import import_dotted_string from kotaemon.base import Document from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex, get_urls from ...utils.commands import WEB_SEARCH_COMMAND +from ...utils.hf_papers import get_recommended_papers +from ...utils.rate_limit import check_rate_limit from .chat_panel import ChatPanel +from .chat_suggestion import ChatSuggestion from .common import STATE from .control import ConversationControl +from .demo_hint import HintPage +from .paper_list import PaperListPage from .report import ReportIssue +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False) KH_WEB_SEARCH_BACKEND = getattr(flowsettings, "KH_WEB_SEARCH_BACKEND", None) WebSearch = None if KH_WEB_SEARCH_BACKEND: try: - module_name, class_name = KH_WEB_SEARCH_BACKEND.rsplit(".", 1) - module = importlib.import_module(module_name) - WebSearch = getattr(module, class_name) + WebSearch = import_dotted_string(KH_WEB_SEARCH_BACKEND, safe=False) except (ImportError, AttributeError) as e: - print(f"Error importing {class_name} from {module_name}: {e}") + print(f"Error importing {KH_WEB_SEARCH_BACKEND}: {e}") +REASONING_LIMITS = 2 if KH_DEMO_MODE else 10 DEFAULT_SETTING = "(default)" INFO_PANEL_SCALES = {True: 8, False: 4} +DEFAULT_QUESTION = ( + "What is the summary of this document?" + if not KH_DEMO_MODE + else "What is the summary of this paper?" +) chat_input_focus_js = """ function() { @@ -50,8 +62,59 @@ function() { } """ +quick_urls_submit_js = """ +function() { + let urlInput = document.querySelector("#quick-url-demo textarea"); + console.log("URL input:", urlInput); + urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'})); +} +""" + +recommended_papers_js = """ +function() { + // Get all links and attach click event + var links = document.querySelectorAll("#related-papers a"); + + function submitPaper(event) { + event.preventDefault(); + var target = event.currentTarget; + var url = target.getAttribute("href"); + console.log("URL:", url); + + let newChatButton = document.querySelector("#new-conv-button"); + newChatButton.click(); + + setTimeout(() => { + let urlInput = document.querySelector("#quick-url-demo textarea"); + // Fill the URL input + urlInput.value = url; + urlInput.dispatchEvent(new Event("input", { bubbles: true })); + urlInput.dispatchEvent(new KeyboardEvent('keypress', {'key': 'Enter'})); + }, 500 + ); + } + + for (var i = 0; i < links.length; i++) { + links[i].onclick = submitPaper; + } +} +""" + +clear_bot_message_selection_js = """ +function() { + var bot_messages = document.querySelectorAll( + "div#main-chat-bot div.message-row.bot-row" + ); + bot_messages.forEach(message => { + message.classList.remove("text_selection"); + }); +} +""" + pdfview_js = """ function() { + setTimeout(fullTextSearch(), 100); + // Get all links and attach click event var links = document.getElementsByClassName("pdf-link"); for (var i = 0; i < links.length; i++) { @@ -64,50 +127,71 @@ function() { links[i].onclick = scrollToCitation; } - var mindmap_el = document.getElementById('mindmap'); + var markmap_div = document.querySelector("div.markmap"); + var mindmap_el_script = document.querySelector('div.markmap script'); - if (mindmap_el) { - var output = svgPanZoom(mindmap_el); - const svg = mindmap_el.cloneNode(true); - - function on_svg_export(event) { - event.preventDefault(); // Prevent the default link behavior - // convert to a valid XML source - const as_text = new XMLSerializer().serializeToString(svg); - // store in a Blob - const blob = new Blob([as_text], { type: "image/svg+xml" }); - // create an URI pointing to that blob - const url = URL.createObjectURL(blob); - const win = open(url); - // so the Garbage Collector can collect the blob - win.onload = (evt) => URL.revokeObjectURL(url); - } - - var link = document.getElementById("mindmap-toggle"); - if (link) { - link.onclick = function(event) { - event.preventDefault(); // Prevent the default link behavior - var div = document.getElementById("mindmap-wrapper"); - if (div) { - var currentHeight = div.style.height; - if (currentHeight === '400px') { - var contentHeight = div.scrollHeight; - div.style.height = contentHeight + 'px'; - } else { - div.style.height = '400px' - } - } - }; - } - - var link = document.getElementById("mindmap-export"); - if (link) { - link.addEventListener('click', on_svg_export); - } + if (mindmap_el_script) { + markmap_div_html = markmap_div.outerHTML; } + // render the mindmap if the script tag is present + if (mindmap_el_script) { + markmap.autoLoader.renderAll(); + } + + setTimeout(() => { + var mindmap_el = document.querySelector('svg.markmap'); + + var text_nodes = document.querySelectorAll("svg.markmap div"); + for (var i = 0; i < text_nodes.length; i++) { + text_nodes[i].onclick = fillChatInput; + } + + if (mindmap_el) { + function on_svg_export(event) { + html = "{html_template}"; + html = html.replace("{markmap_div}", markmap_div_html); + spawnDocument(html, {window: "width=1000,height=1000"}); + } + + var link = document.getElementById("mindmap-toggle"); + if (link) { + link.onclick = function(event) { + event.preventDefault(); // Prevent the default link behavior + var div = document.querySelector("div.markmap"); + if (div) { + var currentHeight = div.style.height; + if (currentHeight === '400px' || (currentHeight === '')) { + div.style.height = '650px'; + } else { + div.style.height = '400px' + } + } + }; + } + + if (markmap_div_html) { + var link = document.getElementById("mindmap-export"); + if (link) { + link.addEventListener('click', on_svg_export); + } + } + } + }, 250); + return [links.length] } +""".replace( + "{html_template}", + MINDMAP_HTML_EXPORT_TEMPLATE.replace("\n", "").replace('"', '\\"'), +) + +fetch_api_key_js = """ +function(_, __) { + api_key = getStorage('google_api_key', ''); + console.log('session API key:', api_key); + return [api_key, _]; +} """ @@ -126,6 +210,7 @@ class ChatPage(BasePage): ) self._info_panel_expanded = gr.State(value=True) self._command_state = gr.State(value=None) + self._user_api_key = gr.Text(value="", visible=False) def on_building_ui(self): with gr.Row(): @@ -146,7 +231,17 @@ class ChatPage(BasePage): continue index_ui.unrender() # need to rerender later within Accordion - with gr.Accordion(label=index.name, open=index_id < 1): + is_first_index = index_id == 0 + index_name = index.name + + if KH_DEMO_MODE and is_first_index: + index_name = "Select from Paper Collection" + + with gr.Accordion( + label=index_name, + open=is_first_index, + elem_id=f"index-{index_id}", + ): index_ui.render() gr_index = index_ui.as_gradio_component() @@ -171,100 +266,114 @@ class ChatPage(BasePage): self._indices_input.append(gr_index) setattr(self, f"_index_{index.id}", index_ui) + self.chat_suggestion = ChatSuggestion(self._app) + if len(self._app.index_manager.indices) > 0: - with gr.Accordion(label="Quick Upload") as _: - self.quick_file_upload = File( - file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()), - file_count="multiple", - container=True, - show_label=False, - elem_id="quick-file", - ) + quick_upload_label = ( + "Quick Upload" if not KH_DEMO_MODE else "Or input new paper URL" + ) + + with gr.Accordion(label=quick_upload_label) as _: + self.quick_file_upload_status = gr.Markdown() + if not KH_DEMO_MODE: + self.quick_file_upload = File( + file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()), + file_count="multiple", + container=True, + show_label=False, + elem_id="quick-file", + ) self.quick_urls = gr.Textbox( - placeholder="Or paste URLs here", + placeholder=( + "Or paste URLs" + if not KH_DEMO_MODE + else "Paste Arxiv URLs\n(https://arxiv.org/abs/xxx)" + ), lines=1, container=False, show_label=False, - elem_id="quick-url", + elem_id=( + "quick-url" if not KH_DEMO_MODE else "quick-url-demo" + ), ) - self.quick_file_upload_status = gr.Markdown() - self.report_issue = ReportIssue(self._app) + if not KH_DEMO_MODE: + self.report_issue = ReportIssue(self._app) + else: + with gr.Accordion(label="Related papers", open=False): + self.related_papers = gr.Markdown(elem_id="related-papers") + + self.hint_page = HintPage(self._app) with gr.Column(scale=6, elem_id="chat-area"): + if KH_DEMO_MODE: + self.paper_list = PaperListPage(self._app) + self.chat_panel = ChatPanel(self._app) - with gr.Row(): - with gr.Accordion( - label="Chat settings", - elem_id="chat-settings-expand", - open=False, - ): - with gr.Row(elem_id="quick-setting-labels"): - gr.HTML("Reasoning method") - gr.HTML("Model") - gr.HTML("Language") - gr.HTML("Suggestion") + with gr.Accordion( + label="Chat settings", + elem_id="chat-settings-expand", + open=False, + visible=not KH_DEMO_MODE, + ) as self.chat_settings: + with gr.Row(elem_id="quick-setting-labels"): + gr.HTML("Reasoning method") + gr.HTML( + "Model", visible=not KH_DEMO_MODE and not KH_SSO_ENABLED + ) + gr.HTML("Language") - with gr.Row(): - reasoning_type_values = [ - (DEFAULT_SETTING, DEFAULT_SETTING) - ] + self._app.default_settings.reasoning.settings[ - "use" - ].choices - self.reasoning_type = gr.Dropdown( - choices=reasoning_type_values, - value=DEFAULT_SETTING, - container=False, - show_label=False, - ) - self.model_type = gr.Dropdown( - choices=self._app.default_settings.reasoning.options[ - "simple" - ] - .settings["llm"] - .choices, - value="", - container=False, - show_label=False, - ) - self.language = gr.Dropdown( - choices=[ - (DEFAULT_SETTING, DEFAULT_SETTING), - ] - + self._app.default_settings.reasoning.settings[ - "lang" - ].choices, - value=DEFAULT_SETTING, - container=False, - show_label=False, - ) - self.use_chat_suggestion = gr.Checkbox( - label="Chat suggestion", - container=False, - elem_id="use-suggestion-checkbox", - ) + with gr.Row(): + reasoning_setting = ( + self._app.default_settings.reasoning.settings["use"] + ) + model_setting = self._app.default_settings.reasoning.options[ + "simple" + ].settings["llm"] + language_setting = ( + self._app.default_settings.reasoning.settings["lang"] + ) + citation_setting = self._app.default_settings.reasoning.options[ + "simple" + ].settings["highlight_citation"] - self.citation = gr.Dropdown( - choices=[ - (DEFAULT_SETTING, DEFAULT_SETTING), - ] - + self._app.default_settings.reasoning.options["simple"] - .settings["highlight_citation"] - .choices, - value=DEFAULT_SETTING, - container=False, - show_label=False, - interactive=True, - elem_id="citation-dropdown", - ) + self.reasoning_type = gr.Dropdown( + choices=reasoning_setting.choices[:REASONING_LIMITS], + value=reasoning_setting.value, + container=False, + show_label=False, + ) + self.model_type = gr.Dropdown( + choices=model_setting.choices, + value=model_setting.value, + container=False, + show_label=False, + visible=not KH_DEMO_MODE and not KH_SSO_ENABLED, + ) + self.language = gr.Dropdown( + choices=language_setting.choices, + value=language_setting.value, + container=False, + show_label=False, + ) - self.use_mindmap = gr.State(value=DEFAULT_SETTING) - self.use_mindmap_check = gr.Checkbox( - label="Mindmap (default)", - container=False, - elem_id="use-mindmap-checkbox", - ) + self.citation = gr.Dropdown( + choices=citation_setting.choices, + value=citation_setting.value, + container=False, + show_label=False, + interactive=True, + elem_id="citation-dropdown", + ) + + self.use_mindmap = gr.State(value=True) + self.use_mindmap_check = gr.Checkbox( + label="Mindmap (on)", + container=False, + elem_id="use-mindmap-checkbox", + value=True, + ) with gr.Column( scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel" @@ -276,6 +385,9 @@ class ChatPage(BasePage): self.plot_panel = gr.Plot(visible=False) self.info_panel = gr.HTML(elem_id="html-info-panel") + self.followup_questions = self.chat_suggestion.examples + self.followup_questions_ui = self.chat_suggestion.accordion + def _json_to_plot(self, json_dict: dict | None): if json_dict: plot = from_json(json_dict) @@ -285,8 +397,18 @@ class ChatPage(BasePage): return plot def on_register_events(self): - self.followup_questions = self.chat_control.chat_suggestion.examples - self.followup_questions_ui = self.chat_control.chat_suggestion.accordion + # first index paper recommendation + if KH_DEMO_MODE and len(self._indices_input) > 0: + self._indices_input[1].change( + self.get_recommendations, + inputs=[self.first_selector_choices, self._indices_input[1]], + outputs=[self.related_papers], + ).then( + fn=None, + inputs=None, + outputs=None, + js=recommended_papers_js, + ) chat_event = ( gr.on( @@ -374,51 +496,44 @@ class ChatPage(BasePage): ) ) - # chat suggestion toggle - chat_event = chat_event.success( - fn=self.suggest_chat_conv, - inputs=[ + onSuggestChatEvent = { + "fn": self.suggest_chat_conv, + "inputs": [ self._app.settings_state, + self.language, self.chat_panel.chatbot, self._use_suggestion, ], - outputs=[ + "outputs": [ self.followup_questions_ui, self.followup_questions, ], - show_progress="hidden", - ) - # .success( - # self.chat_control.persist_chat_suggestions, - # inputs=[ - # self.chat_control.conversation_id, - # self.followup_questions, - # self._use_suggestion, - # self._app.user_id, - # ], - # show_progress="hidden", - # ) + "show_progress": "hidden", + } + # chat suggestion toggle + chat_event = chat_event.success(**onSuggestChatEvent) # final data persist - chat_event = chat_event.then( - fn=self.persist_data_source, - inputs=[ - self.chat_control.conversation_id, - self._app.user_id, - self.info_panel, - self.state_plot_panel, - self.state_retrieval_history, - self.state_plot_history, - self.chat_panel.chatbot, - self.state_chat, - ] - + self._indices_input, - outputs=[ - self.state_retrieval_history, - self.state_plot_history, - ], - concurrency_limit=20, - ) + if not KH_DEMO_MODE: + chat_event = chat_event.then( + fn=self.persist_data_source, + inputs=[ + self.chat_control.conversation_id, + self._app.user_id, + self.info_panel, + self.state_plot_panel, + self.state_retrieval_history, + self.state_plot_history, + self.chat_panel.chatbot, + self.state_chat, + ] + + self._indices_input, + outputs=[ + self.state_retrieval_history, + self.state_plot_history, + ], + concurrency_limit=20, + ) self.chat_control.btn_info_expand.click( fn=lambda is_expanded: ( @@ -432,163 +547,223 @@ class ChatPage(BasePage): fn=None, inputs=None, js="function() {toggleChatColumn();}" ) - self.chat_panel.chatbot.like( - fn=self.is_liked, - inputs=[self.chat_control.conversation_id], - outputs=None, - ) - self.chat_control.btn_new.click( - self.chat_control.new_conv, - inputs=self._app.user_id, - outputs=[self.chat_control.conversation_id, self.chat_control.conversation], - show_progress="hidden", - ).then( - self.chat_control.select_conv, - inputs=[self.chat_control.conversation, self._app.user_id], - outputs=[ - self.chat_control.conversation_id, - self.chat_control.conversation, - self.chat_control.conversation_rn, - self.chat_panel.chatbot, - self.followup_questions, - self.info_panel, - self.state_plot_panel, - self.state_retrieval_history, - self.state_plot_history, - self.chat_control.cb_is_public, - self.state_chat, - ] - + self._indices_input, - show_progress="hidden", - ).then( - fn=self._json_to_plot, - inputs=self.state_plot_panel, - outputs=self.plot_panel, - ).then( - fn=None, - inputs=None, - js=chat_input_focus_js, + if KH_DEMO_MODE: + self.chat_control.btn_demo_logout.click( + fn=None, + js=self.chat_control.logout_js, + ) + self.chat_control.btn_new.click( + fn=lambda: self.chat_control.select_conv("", None), + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.followup_questions, + self.info_panel, + self.state_plot_panel, + self.state_retrieval_history, + self.state_plot_history, + self.chat_control.cb_is_public, + self.state_chat, + ] + + self._indices_input, + ).then( + lambda: (gr.update(visible=False), gr.update(visible=True)), + outputs=[self.paper_list.accordion, self.chat_settings], + ).then( + fn=None, + inputs=None, + js=chat_input_focus_js, + ) + + if not KH_DEMO_MODE: + self.chat_control.btn_new.click( + self.chat_control.new_conv, + inputs=self._app.user_id, + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + ], + show_progress="hidden", + ).then( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation, self._app.user_id], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.followup_questions, + self.info_panel, + self.state_plot_panel, + self.state_retrieval_history, + self.state_plot_history, + self.chat_control.cb_is_public, + self.state_chat, + ] + + self._indices_input, + show_progress="hidden", + ).then( + fn=self._json_to_plot, + inputs=self.state_plot_panel, + outputs=self.plot_panel, + ).then( + fn=None, + inputs=None, + js=chat_input_focus_js, + ) + + self.chat_control.btn_del.click( + lambda id: self.toggle_delete(id), + inputs=[self.chat_control.conversation_id], + outputs=[ + self.chat_control._new_delete, + self.chat_control._delete_confirm, + ], + ) + self.chat_control.btn_del_conf.click( + self.chat_control.delete_conv, + inputs=[self.chat_control.conversation_id, self._app.user_id], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + ], + show_progress="hidden", + ).then( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation, self._app.user_id], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.followup_questions, + self.info_panel, + self.state_plot_panel, + self.state_retrieval_history, + self.state_plot_history, + self.chat_control.cb_is_public, + self.state_chat, + ] + + self._indices_input, + show_progress="hidden", + ).then( + fn=self._json_to_plot, + inputs=self.state_plot_panel, + outputs=self.plot_panel, + ).then( + lambda: self.toggle_delete(""), + outputs=[ + self.chat_control._new_delete, + self.chat_control._delete_confirm, + ], + ) + self.chat_control.btn_del_cnl.click( + lambda: self.toggle_delete(""), + outputs=[ + self.chat_control._new_delete, + self.chat_control._delete_confirm, + ], + ) + self.chat_control.btn_conversation_rn.click( + lambda: gr.update(visible=True), + outputs=[ + self.chat_control.conversation_rn, + ], + ) + self.chat_control.conversation_rn.submit( + self.chat_control.rename_conv, + inputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation_rn, + gr.State(value=True), + self._app.user_id, + ], + outputs=[ + self.chat_control.conversation, + self.chat_control.conversation, + self.chat_control.conversation_rn, + ], + show_progress="hidden", + ) + + onConvSelect = ( + self.chat_control.conversation.select( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation, self._app.user_id], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.followup_questions, + self.info_panel, + self.state_plot_panel, + self.state_retrieval_history, + self.state_plot_history, + self.chat_control.cb_is_public, + self.state_chat, + ] + + self._indices_input, + show_progress="hidden", + ) + .then( + fn=self._json_to_plot, + inputs=self.state_plot_panel, + outputs=self.plot_panel, + ) + .then( + lambda: self.toggle_delete(""), + outputs=[ + self.chat_control._new_delete, + self.chat_control._delete_confirm, + ], + ) ) - self.chat_control.btn_del.click( - lambda id: self.toggle_delete(id), - inputs=[self.chat_control.conversation_id], - outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], - ) - self.chat_control.btn_del_conf.click( - self.chat_control.delete_conv, - inputs=[self.chat_control.conversation_id, self._app.user_id], - outputs=[self.chat_control.conversation_id, self.chat_control.conversation], - show_progress="hidden", - ).then( - self.chat_control.select_conv, - inputs=[self.chat_control.conversation, self._app.user_id], - outputs=[ - self.chat_control.conversation_id, - self.chat_control.conversation, - self.chat_control.conversation_rn, - self.chat_panel.chatbot, - self.followup_questions, - self.info_panel, - self.state_plot_panel, - self.state_retrieval_history, - self.state_plot_history, - self.chat_control.cb_is_public, - self.state_chat, - ] - + self._indices_input, - show_progress="hidden", - ).then( - fn=self._json_to_plot, - inputs=self.state_plot_panel, - outputs=self.plot_panel, - ).then( - lambda: self.toggle_delete(""), - outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], - ) - self.chat_control.btn_del_cnl.click( - lambda: self.toggle_delete(""), - outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], - ) - self.chat_control.btn_conversation_rn.click( - lambda: gr.update(visible=True), - outputs=[ - self.chat_control.conversation_rn, - ], - ) - self.chat_control.conversation_rn.submit( - self.chat_control.rename_conv, - inputs=[ - self.chat_control.conversation_id, - self.chat_control.conversation_rn, - gr.State(value=True), - self._app.user_id, - ], - outputs=[ - self.chat_control.conversation, - self.chat_control.conversation, - self.chat_control.conversation_rn, - ], - show_progress="hidden", + if KH_DEMO_MODE: + onConvSelect = onConvSelect.then( + lambda: (gr.update(visible=False), gr.update(visible=True)), + outputs=[self.paper_list.accordion, self.chat_settings], + ) + + onConvSelect = ( + onConvSelect.then( + fn=lambda: True, + js=clear_bot_message_selection_js, + ) + .then( + fn=lambda: True, + inputs=None, + outputs=[self._preview_links], + js=pdfview_js, + ) + .then(fn=None, inputs=None, outputs=None, js=chat_input_focus_js) ) - self.chat_control.conversation.select( - self.chat_control.select_conv, - inputs=[self.chat_control.conversation, self._app.user_id], - outputs=[ - self.chat_control.conversation_id, - self.chat_control.conversation, - self.chat_control.conversation_rn, - self.chat_panel.chatbot, - self.followup_questions, - self.info_panel, - self.state_plot_panel, - self.state_retrieval_history, - self.state_plot_history, - self.chat_control.cb_is_public, - self.state_chat, - ] - + self._indices_input, - show_progress="hidden", - ).then( - fn=self._json_to_plot, - inputs=self.state_plot_panel, - outputs=self.plot_panel, - ).then( - lambda: self.toggle_delete(""), - outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], - ).then( - fn=lambda: True, - inputs=None, - outputs=[self._preview_links], - js=pdfview_js, - ).then( - fn=None, inputs=None, outputs=None, js=chat_input_focus_js - ) - - # evidence display on message selection - self.chat_panel.chatbot.select( - self.message_selected, - inputs=[ - self.state_retrieval_history, - self.state_plot_history, - ], - outputs=[ - self.info_panel, - self.state_plot_panel, - ], - ).then( - fn=self._json_to_plot, - inputs=self.state_plot_panel, - outputs=self.plot_panel, - ).then( - fn=lambda: True, - inputs=None, - outputs=[self._preview_links], - js=pdfview_js, - ).then( - fn=None, inputs=None, outputs=None, js=chat_input_focus_js - ) + if not KH_DEMO_MODE: + # evidence display on message selection + self.chat_panel.chatbot.select( + self.message_selected, + inputs=[ + self.state_retrieval_history, + self.state_plot_history, + ], + outputs=[ + self.info_panel, + self.state_plot_panel, + ], + ).then( + fn=self._json_to_plot, + inputs=self.state_plot_panel, + outputs=self.plot_panel, + ).then( + fn=lambda: True, + inputs=None, + outputs=[self._preview_links], + js=pdfview_js, + ) self.chat_control.cb_is_public.change( self.on_set_public_conversation, @@ -597,22 +772,30 @@ class ChatPage(BasePage): show_progress="hidden", ) - self.report_issue.report_btn.click( - self.report_issue.report, - inputs=[ - self.report_issue.correctness, - self.report_issue.issues, - self.report_issue.more_detail, - self.chat_control.conversation_id, - self.chat_panel.chatbot, - self._app.settings_state, - self._app.user_id, - self.info_panel, - self.state_chat, - ] - + self._indices_input, - outputs=None, - ) + if not KH_DEMO_MODE: + # user feedback events + self.chat_panel.chatbot.like( + fn=self.is_liked, + inputs=[self.chat_control.conversation_id], + outputs=None, + ) + self.report_issue.report_btn.click( + self.report_issue.report, + inputs=[ + self.report_issue.correctness, + self.report_issue.issues, + self.report_issue.more_detail, + self.chat_control.conversation_id, + self.chat_panel.chatbot, + self._app.settings_state, + self._app.user_id, + self.info_panel, + self.state_chat, + ] + + self._indices_input, + outputs=None, + ) + self.reasoning_type.change( self.reasoning_changed, inputs=[self.reasoning_type], @@ -624,11 +807,25 @@ class ChatPage(BasePage): outputs=[self.use_mindmap, self.use_mindmap_check], show_progress="hidden", ) - self.use_chat_suggestion.change( - lambda x: (x, gr.update(visible=x)), - inputs=[self.use_chat_suggestion], + + def toggle_chat_suggestion(current_state): + return current_state, gr.update(visible=current_state) + + def raise_error_on_state(state): + if not state: + raise ValueError("Chat suggestion disabled") + + self.chat_control.cb_suggest_chat.change( + fn=toggle_chat_suggestion, + inputs=[self.chat_control.cb_suggest_chat], outputs=[self._use_suggestion, self.followup_questions_ui], show_progress="hidden", + ).then( + fn=raise_error_on_state, + inputs=[self._use_suggestion], + show_progress="hidden", + ).success( + **onSuggestChatEvent ) self.chat_control.conversation_id.change( lambda: gr.update(visible=False), @@ -636,7 +833,7 @@ class ChatPage(BasePage): ) self.followup_questions.select( - self.chat_control.chat_suggestion.select_example, + self.chat_suggestion.select_example, outputs=[self.chat_panel.text_input], show_progress="hidden", ).then( @@ -646,6 +843,22 @@ class ChatPage(BasePage): js=chat_input_focus_js, ) + if KH_DEMO_MODE: + self.paper_list.examples.select( + self.paper_list.select_example, + inputs=[self.paper_list.papers_state], + outputs=[self.quick_urls], + show_progress="hidden", + ).then( + lambda: (gr.update(visible=False), gr.update(visible=True)), + outputs=[self.paper_list.accordion, self.chat_settings], + ).then( + fn=None, + inputs=None, + outputs=None, + js=quick_urls_submit_js, + ) + def submit_msg( self, chat_input, @@ -655,8 +868,13 @@ class ChatPage(BasePage): conv_id, conv_name, first_selector_choices, + request: gr.Request, ): """Submit a message to the chatbot""" + if KH_DEMO_MODE: + sso_user_id = check_rate_limit("chat", request) + print("User ID:", sso_user_id) + if not chat_input: raise ValueError("Input is empty") @@ -685,6 +903,7 @@ class ChatPage(BasePage): True, settings, user_id, + request=None, ) elif file_names: for file_name in file_names: @@ -698,7 +917,11 @@ class ChatPage(BasePage): # if file_ids is not empty and chat_input_text is empty # set the input to summary if not chat_input_text and file_ids: - chat_input_text = "Summary" + chat_input_text = DEFAULT_QUESTION + + # if start of conversation and no query is specified + if not chat_input_text and not chat_history: + chat_input_text = DEFAULT_QUESTION if file_ids: selector_output = [ @@ -716,13 +939,16 @@ class ChatPage(BasePage): raise gr.Error("Empty chat") if not conv_id: - id_, update = self.chat_control.new_conv(user_id) - with Session(engine) as session: - statement = select(Conversation).where(Conversation.id == id_) - name = session.exec(statement).one().name - new_conv_id = id_ - conv_update = update - new_conv_name = name + if not KH_DEMO_MODE: + id_, update = self.chat_control.new_conv(user_id) + with Session(engine) as session: + statement = select(Conversation).where(Conversation.id == id_) + name = session.exec(statement).one().name + new_conv_id = id_ + conv_update = update + new_conv_name = name + else: + new_conv_id, new_conv_name, conv_update = None, None, gr.update() else: new_conv_id = conv_id conv_update = gr.update() @@ -740,6 +966,17 @@ class ChatPage(BasePage): + [used_command] ) + def get_recommendations(self, first_selector_choices, file_ids): + first_selector_choices_map = { + item[1]: item[0] for item in first_selector_choices + } + file_names = [first_selector_choices_map[file_id] for file_id in file_ids] + if not file_names: + return "" + + first_file_name = file_names[0].split(".")[0].replace("_", " ") + return get_recommended_papers(first_file_name) + def toggle_delete(self, conv_id): if conv_id: return gr.update(visible=False), gr.update(visible=True) @@ -789,17 +1026,41 @@ class ChatPage(BasePage): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.followup_questions, self.info_panel, self.state_plot_panel, self.state_retrieval_history, self.state_plot_history, self.chat_control.cb_is_public, + self.state_chat, ] + self._indices_input, "show_progress": "hidden", }, ) + def _on_app_created(self): + if KH_DEMO_MODE: + self._app.app.load( + fn=lambda x: x, + inputs=[self._user_api_key], + outputs=[self._user_api_key], + js=fetch_api_key_js, + ).then( + fn=self.chat_control.toggle_demo_login_visibility, + inputs=[self._user_api_key], + outputs=[ + self.chat_control.cb_suggest_chat, + self.chat_control.btn_new, + self.chat_control.btn_demo_logout, + self.chat_control.btn_demo_login, + ], + ).then( + fn=None, + inputs=None, + js=chat_input_focus_js, + ) + def persist_data_source( self, convo_id, @@ -1106,13 +1367,24 @@ class ChatPage(BasePage): return new_name, renamed - def suggest_chat_conv(self, settings, chat_history, use_suggestion): + def suggest_chat_conv( + self, + settings, + session_language, + chat_history, + use_suggestion, + ): + target_language = ( + session_language + if session_language not in (DEFAULT_SETTING, None) + else settings["reasoning.lang"] + ) if use_suggestion: suggest_pipeline = SuggestFollowupQuesPipeline() suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get( - settings["reasoning.lang"], "English" + target_language, "English" ) - suggested_questions = [] + suggested_questions = [[each] for each in ChatSuggestion.CHAT_SAMPLES] if len(chat_history) >= 1: suggested_resp = suggest_pipeline(chat_history).text diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py index 4b54648..7ced90b 100644 --- a/libs/ktem/ktem/pages/chat/chat_panel.py +++ b/libs/ktem/ktem/pages/chat/chat_panel.py @@ -1,5 +1,21 @@ import gradio as gr from ktem.app import BasePage +from theflow.settings import settings as flowsettings + +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) + +if not KH_DEMO_MODE: + PLACEHOLDER_TEXT = ( + "This is the beginning of a new conversation.\n" + "Start by uploading a file or a web URL. " + "Visit Files tab for more options (e.g: GraphRAG)." + ) +else: + PLACEHOLDER_TEXT = ( + "Welcome to Kotaemon Demo. " + "Start by browsing preloaded conversations to get onboard.\n" + "Check out Hint section for more tips." + ) class ChatPanel(BasePage): @@ -10,10 +26,7 @@ class ChatPanel(BasePage): def on_building_ui(self): self.chatbot = gr.Chatbot( label=self._app.app_name, - placeholder=( - "This is the beginning of a new conversation.\nIf you are new, " - "visit the Help tab for quick instructions." - ), + placeholder=PLACEHOLDER_TEXT, show_label=False, elem_id="main-chat-bot", show_copy_button=True, diff --git a/libs/ktem/ktem/pages/chat/chat_suggestion.py b/libs/ktem/ktem/pages/chat/chat_suggestion.py index 2bfe03c..5676556 100644 --- a/libs/ktem/ktem/pages/chat/chat_suggestion.py +++ b/libs/ktem/ktem/pages/chat/chat_suggestion.py @@ -4,29 +4,34 @@ from theflow.settings import settings as flowsettings class ChatSuggestion(BasePage): + CHAT_SAMPLES = getattr( + flowsettings, + "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", + [ + "Summary this document", + "Generate a FAQ for this document", + "Identify the main highlights in bullet points", + ], + ) + def __init__(self, app): self._app = app self.on_building_ui() def on_building_ui(self): - chat_samples = getattr( - flowsettings, - "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", - [ - "Summary this document", - "Generate a FAQ for this document", - "Identify the main highlights in this text", - ], - ) - self.chat_samples = [[each] for each in chat_samples] + self.chat_samples = [[each] for each in self.CHAT_SAMPLES] with gr.Accordion( label="Chat Suggestion", visible=getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False), ) as self.accordion: + self.default_example = gr.State( + value=self.chat_samples, + ) self.examples = gr.DataFrame( value=self.chat_samples, headers=["Next Question"], interactive=False, + elem_id="chat-suggestion", wrap=True, ) diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py index 44b8835..fb2b50c 100644 --- a/libs/ktem/ktem/pages/chat/control.py +++ b/libs/ktem/ktem/pages/chat/control.py @@ -14,11 +14,22 @@ from .chat_suggestion import ChatSuggestion from .common import STATE logger = logging.getLogger(__name__) + +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False) ASSETS_DIR = "assets/icons" if not os.path.isdir(ASSETS_DIR): ASSETS_DIR = "libs/ktem/ktem/assets/icons" +logout_js = """ +function () { + removeFromStorage('google_api_key'); + window.location.href = "/logout"; +} +""" + + def is_conv_name_valid(name): """Check if the conversation name is valid""" errors = [] @@ -35,11 +46,13 @@ class ConversationControl(BasePage): def __init__(self, app): self._app = app + self.logout_js = logout_js self.on_building_ui() def on_building_ui(self): with gr.Row(): - gr.Markdown("## Conversations") + title_text = "Conversations" if not KH_DEMO_MODE else "Kotaemon Papers" + gr.Markdown("## {}".format(title_text)) self.btn_toggle_dark_mode = gr.Button( value="", icon=f"{ASSETS_DIR}/dark_mode.svg", @@ -83,42 +96,88 @@ class ConversationControl(BasePage): filterable=True, interactive=True, elem_classes=["unset-overflow"], + elem_id="conversation-dropdown", ) with gr.Row() as self._new_delete: + self.cb_suggest_chat = gr.Checkbox( + value=False, + label="Suggest chat", + min_width=10, + scale=6, + elem_id="suggest-chat-checkbox", + container=False, + visible=not KH_DEMO_MODE, + ) self.cb_is_public = gr.Checkbox( value=False, - label="Shared", - min_width=10, - scale=4, + label="Share this conversation", elem_id="is-public-checkbox", container=False, + visible=not KH_DEMO_MODE and not KH_SSO_ENABLED, ) - self.btn_conversation_rn = gr.Button( - value="", - icon=f"{ASSETS_DIR}/rename.svg", - min_width=2, - scale=1, - size="sm", - elem_classes=["no-background", "body-text-color"], - ) - self.btn_del = gr.Button( - value="", - icon=f"{ASSETS_DIR}/delete.svg", - min_width=2, - scale=1, - size="sm", - elem_classes=["no-background", "body-text-color"], - ) - self.btn_new = gr.Button( - value="", - icon=f"{ASSETS_DIR}/new.svg", - min_width=2, - scale=1, - size="sm", - elem_classes=["no-background", "body-text-color"], - elem_id="new-conv-button", - ) + + if not KH_DEMO_MODE: + self.btn_conversation_rn = gr.Button( + value="", + icon=f"{ASSETS_DIR}/rename.svg", + min_width=2, + scale=1, + size="sm", + elem_classes=["no-background", "body-text-color"], + ) + self.btn_del = gr.Button( + value="", + icon=f"{ASSETS_DIR}/delete.svg", + min_width=2, + scale=1, + size="sm", + elem_classes=["no-background", "body-text-color"], + ) + self.btn_new = gr.Button( + value="", + icon=f"{ASSETS_DIR}/new.svg", + min_width=2, + scale=1, + size="sm", + elem_classes=["no-background", "body-text-color"], + elem_id="new-conv-button", + ) + else: + self.btn_new = gr.Button( + value="New chat", + min_width=120, + size="sm", + scale=1, + variant="primary", + elem_id="new-conv-button", + visible=False, + ) + + if KH_DEMO_MODE: + with gr.Row(): + self.btn_demo_login = gr.Button( + "Sign-in to create new chat", + min_width=120, + size="sm", + scale=1, + variant="primary", + ) + _js_redirect = """ + () => { + url = '/login' + window.location.search; + window.open(url, '_blank'); + } + """ + self.btn_demo_login.click(None, js=_js_redirect) + + self.btn_demo_logout = gr.Button( + "Sign-out", + min_width=120, + size="sm", + scale=1, + visible=False, + ) with gr.Row(visible=False) as self._delete_confirm: self.btn_del_conf = gr.Button( @@ -139,8 +198,6 @@ class ConversationControl(BasePage): visible=False, ) - self.chat_suggestion = ChatSuggestion(self._app) - def load_chat_history(self, user_id): """Reload chat history""" @@ -241,6 +298,8 @@ class ConversationControl(BasePage): def select_conv(self, conversation_id, user_id): """Select the conversation""" + default_chat_suggestions = [[each] for each in ChatSuggestion.CHAT_SAMPLES] + with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) try: @@ -257,7 +316,9 @@ class ConversationControl(BasePage): selected = {} chats = result.data_source.get("messages", []) - chat_suggestions = result.data_source.get("chat_suggestions", []) + chat_suggestions = result.data_source.get( + "chat_suggestions", default_chat_suggestions + ) retrieval_history: list[str] = result.data_source.get( "retrieval_messages", [] @@ -282,7 +343,7 @@ class ConversationControl(BasePage): name = "" selected = {} chats = [] - chat_suggestions = [] + chat_suggestions = default_chat_suggestions retrieval_history = [] plot_history = [] info_panel = "" @@ -317,25 +378,21 @@ class ConversationControl(BasePage): def rename_conv(self, conversation_id, new_name, is_renamed, user_id): """Rename the conversation""" - if not is_renamed: + if not is_renamed or KH_DEMO_MODE or user_id is None or not conversation_id: return ( gr.update(), conversation_id, gr.update(visible=False), ) - if user_id is None: - gr.Warning("Please sign in first (Settings → User Settings)") - return gr.update(), "" - - if not conversation_id: - gr.Warning("No conversation selected.") - return gr.update(), "" - errors = is_conv_name_valid(new_name) if errors: gr.Warning(errors) - return gr.update(), conversation_id + return ( + gr.update(), + conversation_id, + gr.update(visible=False), + ) with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) @@ -382,6 +439,29 @@ class ConversationControl(BasePage): gr.Info("Chat suggestions updated.") + def toggle_demo_login_visibility(self, user_api_key, request: gr.Request): + try: + import gradiologin as grlogin + + user = grlogin.get_user(request) + except (ImportError, AssertionError): + user = None + + if user: # or user_api_key: + return [ + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=True), + gr.update(visible=False), + ] + else: + return [ + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=True), + ] + def _on_app_created(self): """Reload the conversation once the app is created""" self._app.app.load( diff --git a/libs/ktem/ktem/pages/chat/demo_hint.py b/libs/ktem/ktem/pages/chat/demo_hint.py new file mode 100644 index 0000000..4870546 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/demo_hint.py @@ -0,0 +1,23 @@ +from textwrap import dedent + +import gradio as gr +from ktem.app import BasePage + + +class HintPage(BasePage): + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + with gr.Accordion(label="Hint", open=False): + gr.Markdown( + dedent( + """ + - You can select any text from the chat answer to **highlight relevant citation(s)** on the right panel. + - **Citations** can be viewed on both PDF viewer and raw text. + - You can tweak the citation format and use advance (CoT) reasoning in **Chat settings** menu. + - Want to **explore more**? Check out the **Help** section to create your private space. + """ # noqa + ) + ) diff --git a/libs/ktem/ktem/pages/chat/paper_list.py b/libs/ktem/ktem/pages/chat/paper_list.py new file mode 100644 index 0000000..bddf4a4 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/paper_list.py @@ -0,0 +1,41 @@ +import gradio as gr +from ktem.app import BasePage +from pandas import DataFrame + +from ...utils.hf_papers import fetch_papers + + +class PaperListPage(BasePage): + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + self.papers_state = gr.State(None) + with gr.Accordion( + label="Browse popular daily papers", + open=True, + ) as self.accordion: + self.examples = gr.DataFrame( + value=[], + headers=["title", "url", "upvotes"], + column_widths=[60, 30, 10], + interactive=False, + elem_id="paper-suggestion", + wrap=True, + ) + return self.examples + + def load(self): + papers = fetch_papers(top_n=5) + papers_df = DataFrame(papers) + return (papers_df, papers) + + def _on_app_created(self): + self._app.app.load( + self.load, + outputs=[self.examples, self.papers_state], + ) + + def select_example(self, state, ev: gr.SelectData): + return state[ev.index[0]]["url"] diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py index f404743..b269f4e 100644 --- a/libs/ktem/ktem/pages/chat/report.py +++ b/libs/ktem/ktem/pages/chat/report.py @@ -12,7 +12,7 @@ class ReportIssue(BasePage): self.on_building_ui() def on_building_ui(self): - with gr.Accordion(label="Feedback", open=False): + with gr.Accordion(label="Feedback", open=False, elem_id="report-accordion"): self.correctness = gr.Radio( choices=[ ("The answer is correct", "correct"), diff --git a/libs/ktem/ktem/pages/help.py b/libs/ktem/ktem/pages/help.py index e3438d1..2ecdf7e 100644 --- a/libs/ktem/ktem/pages/help.py +++ b/libs/ktem/ktem/pages/help.py @@ -3,8 +3,12 @@ from pathlib import Path import gradio as gr import requests +from decouple import config from theflow.settings import settings +KH_DEMO_MODE = getattr(settings, "KH_DEMO_MODE", False) +HF_SPACE_URL = config("HF_SPACE_URL", default="") + def get_remote_doc(url: str) -> str: try: @@ -59,6 +63,22 @@ class HelpPage: about_md = f"Version: {self.app_version}\n\n{about_md}" gr.Markdown(about_md) + if KH_DEMO_MODE: + with gr.Accordion("Create Your Own Space"): + gr.Markdown( + "This is a demo with limited functionality. " + "Use **Create space** button to install Kotaemon " + "in your own space with all features " + "(including upload and manage your private " + "documents securely)." + ) + gr.Button( + value="Create Your Own Space", + link=HF_SPACE_URL, + variant="primary", + size="lg", + ) + user_guide_md_dir = self.doc_dir / "usage.md" if user_guide_md_dir.exists(): with (self.doc_dir / "usage.md").open(encoding="utf-8") as fi: @@ -68,7 +88,7 @@ class HelpPage: f"{self.remote_content_url}/v{self.app_version}/docs/usage.md" ) if user_guide_md: - with gr.Accordion("User Guide"): + with gr.Accordion("User Guide", open=not KH_DEMO_MODE): gr.Markdown(user_guide_md) if self.app_version: diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py index 9dc4839..1455cb8 100644 --- a/libs/ktem/ktem/pages/login.py +++ b/libs/ktem/ktem/pages/login.py @@ -3,6 +3,7 @@ import hashlib import gradio as gr from ktem.app import BasePage from ktem.db.models import User, engine +from ktem.pages.resources.user import create_user from sqlmodel import Session, select fetch_creds = """ @@ -85,19 +86,47 @@ class LoginPage(BasePage): }, ) - def login(self, usn, pwd): - if not usn or not pwd: - return None, usn, pwd + def login(self, usn, pwd, request: gr.Request): + try: + import gradiologin as grlogin + + user = grlogin.get_user(request) + except (ImportError, AssertionError): + user = None + + if user: + user_id = user["sub"] + with Session(engine) as session: + stmt = select(User).where( + User.id == user_id, + ) + result = session.exec(stmt).all() - hashed_password = hashlib.sha256(pwd.encode()).hexdigest() - with Session(engine) as session: - stmt = select(User).where( - User.username_lower == usn.lower().strip(), - User.password == hashed_password, - ) - result = session.exec(stmt).all() if result: - return result[0].id, "", "" + print("Existing user:", user) + return user_id, "", "" + else: + print("Creating new user:", user) + create_user( + usn=user["email"], + pwd="", + user_id=user_id, + is_admin=False, + ) + return user_id, "", "" + else: + if not usn or not pwd: + return None, usn, pwd - gr.Warning("Invalid username or password") - return None, usn, pwd + hashed_password = hashlib.sha256(pwd.encode()).hexdigest() + with Session(engine) as session: + stmt = select(User).where( + User.username_lower == usn.lower().strip(), + User.password == hashed_password, + ) + result = session.exec(stmt).all() + if result: + return result[0].id, "", "" + + gr.Warning("Invalid username or password") + return None, usn, pwd diff --git a/libs/ktem/ktem/pages/resources/user.py b/libs/ktem/ktem/pages/resources/user.py index 106c268..49835ce 100644 --- a/libs/ktem/ktem/pages/resources/user.py +++ b/libs/ktem/ktem/pages/resources/user.py @@ -94,7 +94,7 @@ def validate_password(pwd, pwd_cnf): return "" -def create_user(usn, pwd) -> bool: +def create_user(usn, pwd, user_id=None, is_admin=True) -> bool: with Session(engine) as session: statement = select(User).where(User.username_lower == usn.lower()) result = session.exec(statement).all() @@ -105,10 +105,11 @@ def create_user(usn, pwd) -> bool: else: hashed_password = hashlib.sha256(pwd.encode()).hexdigest() user = User( + id=user_id, username=usn, username_lower=usn.lower(), password=hashed_password, - admin=True, + admin=is_admin, ) session.add(user) session.commit() @@ -136,11 +137,12 @@ class UserManagement(BasePage): self.state_user_list = gr.State(value=None) self.user_list = gr.DataFrame( headers=["id", "name", "admin"], + column_widths=[0, 50, 50], interactive=False, ) with gr.Group(visible=False) as self._selected_panel: - self.selected_user_id = gr.Number(value=-1, visible=False) + self.selected_user_id = gr.State(value=-1) self.usn_edit = gr.Textbox(label="Username") with gr.Row(): self.pwd_edit = gr.Textbox(label="Change password", type="password") @@ -346,7 +348,7 @@ class UserManagement(BasePage): if not ev.selected: return -1 - return int(user_list["id"][ev.index[0]]) + return user_list["id"][ev.index[0]] def on_selected_user_change(self, selected_user_id): if selected_user_id == -1: @@ -367,7 +369,7 @@ class UserManagement(BasePage): btn_delete_no = gr.update(visible=False) with Session(engine) as session: - statement = select(User).where(User.id == int(selected_user_id)) + statement = select(User).where(User.id == selected_user_id) user = session.exec(statement).one() usn_edit = gr.update(value=user.username) @@ -414,7 +416,7 @@ class UserManagement(BasePage): return pwd, pwd_cnf with Session(engine) as session: - statement = select(User).where(User.id == int(selected_user_id)) + statement = select(User).where(User.id == selected_user_id) user = session.exec(statement).one() user.username = usn user.username_lower = usn.lower() @@ -432,7 +434,7 @@ class UserManagement(BasePage): return selected_user_id with Session(engine) as session: - statement = select(User).where(User.id == int(selected_user_id)) + statement = select(User).where(User.id == selected_user_id) user = session.exec(statement).one() session.delete(user) session.commit() diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index f899a45..0a211d0 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -5,6 +5,10 @@ from ktem.app import BasePage from ktem.components import reasonings from ktem.db.models import Settings, User, engine from sqlmodel import Session, select +from theflow.settings import settings as flowsettings + +KH_SSO_ENABLED = getattr(flowsettings, "KH_SSO_ENABLED", False) + signout_js = """ function(u, c, pw, pwc) { @@ -80,38 +84,44 @@ class SettingsPage(BasePage): # render application page if there are application settings self._render_app_tab = False - if self._default_settings.application.settings: + + if not KH_SSO_ENABLED and self._default_settings.application.settings: self._render_app_tab = True # render index page if there are index settings (general and/or specific) self._render_index_tab = False - if self._default_settings.index.settings: - self._render_index_tab = True - else: - for sig in self._default_settings.index.options.values(): - if sig.settings: - self._render_index_tab = True - break + + if not KH_SSO_ENABLED: + if self._default_settings.index.settings: + self._render_index_tab = True + else: + for sig in self._default_settings.index.options.values(): + if sig.settings: + self._render_index_tab = True + break # render reasoning page if there are reasoning settings self._render_reasoning_tab = False - if len(self._default_settings.reasoning.settings) > 1: - self._render_reasoning_tab = True - else: - for sig in self._default_settings.reasoning.options.values(): - if sig.settings: - self._render_reasoning_tab = True - break + + if not KH_SSO_ENABLED: + if len(self._default_settings.reasoning.settings) > 1: + self._render_reasoning_tab = True + else: + for sig in self._default_settings.reasoning.options.values(): + if sig.settings: + self._render_reasoning_tab = True + break self.on_building_ui() def on_building_ui(self): - self.setting_save_btn = gr.Button( - "Save & Close", - variant="primary", - elem_classes=["right-button"], - elem_id="save-setting-btn", - ) + if not KH_SSO_ENABLED: + self.setting_save_btn = gr.Button( + "Save & Close", + variant="primary", + elem_classes=["right-button"], + elem_id="save-setting-btn", + ) if self._app.f_user_management: with gr.Tab("User settings"): self.user_tab() @@ -175,21 +185,22 @@ class SettingsPage(BasePage): ) def on_register_events(self): - self.setting_save_btn.click( - self.save_setting, - inputs=[self._user_id] + self.components(), - outputs=self._settings_state, - ).then( - lambda: gr.Tabs(selected="chat-tab"), - outputs=self._app.tabs, - ) + if not KH_SSO_ENABLED: + self.setting_save_btn.click( + self.save_setting, + inputs=[self._user_id] + self.components(), + outputs=self._settings_state, + ).then( + lambda: gr.Tabs(selected="chat-tab"), + outputs=self._app.tabs, + ) self._components["reasoning.use"].change( self.change_reasoning_mode, inputs=[self._components["reasoning.use"]], outputs=list(self._reasoning_mode.values()), show_progress="hidden", ) - if self._app.f_user_management: + if self._app.f_user_management and not KH_SSO_ENABLED: self.password_change_btn.click( self.change_password, inputs=[ @@ -223,15 +234,21 @@ class SettingsPage(BasePage): def user_tab(self): # user management self.current_name = gr.Markdown("Current user: ___") - self.signout = gr.Button("Logout") - self.password_change = gr.Textbox( - label="New password", interactive=True, type="password" - ) - self.password_change_confirm = gr.Textbox( - label="Confirm password", interactive=True, type="password" - ) - self.password_change_btn = gr.Button("Change password", interactive=True) + if KH_SSO_ENABLED: + import gradiologin as grlogin + + self.sso_signout = grlogin.LogoutButton("Logout") + else: + self.signout = gr.Button("Logout") + + self.password_change = gr.Textbox( + label="New password", interactive=True, type="password" + ) + self.password_change_confirm = gr.Textbox( + label="Confirm password", interactive=True, type="password" + ) + self.password_change_btn = gr.Button("Change password", interactive=True) def change_password(self, user_id, password, password_confirm): from ktem.pages.resources.user import validate_password diff --git a/libs/ktem/ktem/pages/setup.py b/libs/ktem/ktem/pages/setup.py index 21efa5d..09e3b6a 100644 --- a/libs/ktem/ktem/pages/setup.py +++ b/libs/ktem/ktem/pages/setup.py @@ -2,13 +2,13 @@ import json import gradio as gr import requests +from decouple import config from ktem.app import BasePage from ktem.embeddings.manager import embedding_models_manager as embeddings from ktem.llms.manager import llms from ktem.rerankings.manager import reranking_models_manager as rerankers from theflow.settings import settings as flowsettings -KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/") DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api") if DEFAULT_OLLAMA_URL.endswith("/"): @@ -113,9 +113,18 @@ class SetupPage(BasePage): ( "#### Setup Ollama\n\n" "Download and install Ollama from " - "https://ollama.com/" + "https://ollama.com/. Check out latest models at " + "https://ollama.com/library. " ) ) + self.ollama_model_name = gr.Textbox( + label="LLM model name", + value=config("LOCAL_MODEL", default="qwen2.5:7b"), + ) + self.ollama_emb_model_name = gr.Textbox( + label="Embedding model name", + value=config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"), + ) self.setup_log = gr.HTML( show_label=False, @@ -139,22 +148,23 @@ class SetupPage(BasePage): self.cohere_api_key, self.openai_api_key, self.google_api_key, + self.ollama_model_name, + self.ollama_emb_model_name, self.radio_model, ], outputs=[self.setup_log], show_progress="hidden", ) - if not KH_DEMO_MODE: - onSkipSetup = gr.on( - triggers=[self.btn_skip.click], - fn=lambda: None, - inputs=[], - show_progress="hidden", - outputs=[self.radio_model], - ) + onSkipSetup = gr.on( + triggers=[self.btn_skip.click], + fn=lambda: None, + inputs=[], + show_progress="hidden", + outputs=[self.radio_model], + ) - for event in self._app.get_event("onFirstSetupComplete"): - onSkipSetup = onSkipSetup.success(**event) + for event in self._app.get_event("onFirstSetupComplete"): + onSkipSetup = onSkipSetup.success(**event) onFirstSetupComplete = onFirstSetupComplete.success( fn=self.update_default_settings, @@ -181,12 +191,10 @@ class SetupPage(BasePage): cohere_api_key, openai_api_key, google_api_key, + ollama_model_name, + ollama_emb_model_name, radio_model_value, ): - # skip if KH_DEMO_MODE - if KH_DEMO_MODE: - raise gr.Error(DEMO_MESSAGE) - log_content = "" if not radio_model_value: gr.Info("Skip setup models.") @@ -274,7 +282,7 @@ class SetupPage(BasePage): spec={ "__type__": "kotaemon.llms.ChatOpenAI", "base_url": KH_OLLAMA_URL, - "model": "llama3.1:8b", + "model": ollama_model_name, "api_key": "ollama", }, default=True, @@ -284,7 +292,7 @@ class SetupPage(BasePage): spec={ "__type__": "kotaemon.embeddings.OpenAIEmbeddings", "base_url": KH_OLLAMA_URL, - "model": "nomic-embed-text", + "model": ollama_emb_model_name, "api_key": "ollama", }, default=True, diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py b/libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py index 9e134c4..42049ed 100644 --- a/libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py +++ b/libs/ktem/ktem/reasoning/prompt_optimization/mindmap.py @@ -1,4 +1,5 @@ import logging +from textwrap import dedent from ktem.llms.manager import llms @@ -8,6 +9,31 @@ from kotaemon.llms import ChatLLM, PromptTemplate logger = logging.getLogger(__name__) +MINDMAP_HTML_EXPORT_TEMPLATE = dedent( + """ + + + + + + + Mindmap + + + + + {markmap_div} + + +""" +) + + class CreateMindmapPipeline(BaseComponent): """Create a mindmap from the question and context""" @@ -37,6 +63,20 @@ Use the template like this: """ # noqa: E501 prompt_template: str = MINDMAP_PROMPT_TEMPLATE + @classmethod + def convert_uml_to_markdown(cls, text: str) -> str: + start_phrase = "@startmindmap" + end_phrase = "@endmindmap" + + try: + text = text.split(start_phrase)[-1] + text = text.split(end_phrase)[0] + text = text.strip().replace("*", "#") + except IndexError: + text = "" + + return text + def run(self, question: str, context: str) -> Document: # type: ignore prompt_template = PromptTemplate(self.prompt_template) prompt = prompt_template.populate( @@ -49,4 +89,9 @@ Use the template like this: HumanMessage(content=prompt), ] - return self.llm(messages) + uml_text = self.llm(messages).text + markdown_text = self.convert_uml_to_markdown(uml_text) + + return Document( + text=markdown_text, + ) diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py b/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py index 53d46b3..0e07057 100644 --- a/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py +++ b/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py @@ -15,12 +15,10 @@ class SuggestFollowupQuesPipeline(BaseComponent): SUGGEST_QUESTIONS_PROMPT_TEMPLATE = ( "Based on the chat history above. " "your task is to generate 3 to 5 relevant follow-up questions. " - "These questions should be simple, clear, " + "These questions should be simple, very concise, " "and designed to guide the conversation further. " - "Ensure that the questions are open-ended to encourage detailed responses. " "Respond in JSON format with 'questions' key. " "Answer using the language {lang} same as the question. " - "If the question uses Chinese, the answer should be in Chinese.\n" ) prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE extra_prompt: str = """Example of valid response: diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index f90b39c..d11495d 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -1,5 +1,6 @@ import logging import threading +from textwrap import dedent from typing import Generator from ktem.embeddings.manager import embedding_models_manager as embeddings @@ -8,7 +9,6 @@ from ktem.reasoning.prompt_optimization import ( DecomposeQuestionPipeline, RewriteQuestionPipeline, ) -from ktem.utils.plantuml import PlantUML from ktem.utils.render import Render from ktem.utils.visualize_cited import CreateCitationVizPipeline from plotly.io import to_json @@ -165,21 +165,23 @@ class FullQAPipeline(BaseReasoning): mindmap = answer.metadata["mindmap"] if mindmap: mindmap_text = mindmap.text - uml_renderer = PlantUML() - - try: - mindmap_svg = uml_renderer.process(mindmap_text) - except Exception as e: - print("Failed to process mindmap:", e) - mindmap_svg = "" - - # post-process the mindmap SVG - mindmap_svg = ( - mindmap_svg.replace("sans-serif", "Quicksand, sans-serif") - .replace("#181818", "#cecece") - .replace("background:#FFFFF", "background:none") - .replace("stroke-width:1", "stroke-width:2") - ) + mindmap_svg = dedent( + """ +
+ +
+ """ + ).format(mindmap_text) mindmap_content = Document( channel="info", @@ -323,7 +325,7 @@ class FullQAPipeline(BaseReasoning): def prepare_pipeline_instance(cls, settings, retrievers): return cls( retrievers=retrievers, - rewrite_pipeline=RewriteQuestionPipeline(), + rewrite_pipeline=None, ) @classmethod @@ -411,8 +413,8 @@ class FullQAPipeline(BaseReasoning): "value": "highlight", "component": "radio", "choices": [ - ("highlight (verbose)", "highlight"), - ("inline (concise)", "inline"), + ("citation: highlight", "highlight"), + ("citation: inline", "inline"), ("no citation", "off"), ], }, @@ -433,7 +435,7 @@ class FullQAPipeline(BaseReasoning): }, "system_prompt": { "name": "System Prompt", - "value": "This is a question answering system", + "value": ("This is a question answering system."), }, "qa_prompt": { "name": "QA Prompt (contains {context}, {question}, {lang})", diff --git a/libs/ktem/ktem/utils/hf_papers.py b/libs/ktem/ktem/utils/hf_papers.py new file mode 100644 index 0000000..f755671 --- /dev/null +++ b/libs/ktem/ktem/utils/hf_papers.py @@ -0,0 +1,114 @@ +from datetime import datetime, timedelta + +import requests +from cachetools import TTLCache, cached + +HF_API_URL = "https://huggingface.co/api/daily_papers" +ARXIV_URL = "https://arxiv.org/abs/{paper_id}" +SEMANTIC_SCHOLAR_QUERY_URL = "https://api.semanticscholar.org/graph/v1/paper/search/match?query={paper_name}" # noqa +SEMANTIC_SCHOLAR_RECOMMEND_URL = ( + "https://api.semanticscholar.org/recommendations/v1/papers/" # noqa +) +CACHE_TIME = 60 * 60 * 6 # 6 hours + + +# Function to parse the date string +def parse_date(date_str): + return datetime.strptime(date_str, "%Y-%m-%dT%H:%M:%S.%fZ") + + +@cached(cache=TTLCache(maxsize=500, ttl=CACHE_TIME)) +def get_recommendations_from_semantic_scholar(semantic_scholar_id: str): + try: + r = requests.post( + SEMANTIC_SCHOLAR_RECOMMEND_URL, + json={ + "positivePaperIds": [semantic_scholar_id], + }, + params={"fields": "externalIds,title,year", "limit": 14}, # type: ignore + ) + return r.json()["recommendedPapers"] + except KeyError as e: + print(e) + return [] + + +def filter_recommendations(recommendations, max_paper_count=5): + # include only arxiv papers + arxiv_paper = [ + r for r in recommendations if r["externalIds"].get("ArXiv", None) is not None + ] + if len(arxiv_paper) > max_paper_count: + arxiv_paper = arxiv_paper[:max_paper_count] + return arxiv_paper + + +def format_recommendation_into_markdown(recommendations): + comment = "(recommended by the Semantic Scholar API)\n\n" + for r in recommendations: + hub_paper_url = f"https://arxiv.org/abs/{r['externalIds']['ArXiv']}" + comment += f"* [{r['title']}]({hub_paper_url}) ({r['year']})\n" + + return comment + + +def get_paper_id_from_name(paper_name): + try: + response = requests.get( + SEMANTIC_SCHOLAR_QUERY_URL.format(paper_name=paper_name) + ) + response.raise_for_status() + items = response.json() + paper_id = items.get("data", [])[0].get("paperId") + except Exception as e: + print(e) + return None + + return paper_id + + +def get_recommended_papers(paper_name): + paper_id = get_paper_id_from_name(paper_name) + recommended_content = "" + if paper_id is None: + return recommended_content + + recommended_papers = get_recommendations_from_semantic_scholar(paper_id) + filtered_recommendations = filter_recommendations(recommended_papers) + + recommended_content = format_recommendation_into_markdown(filtered_recommendations) + return recommended_content + + +def fetch_papers(top_n=5): + try: + response = requests.get(f"{HF_API_URL}?limit=100") + response.raise_for_status() + items = response.json() + + # Calculate the date 3 days ago from now + three_days_ago = datetime.now() - timedelta(days=3) + + # Filter items from the last 3 days + recent_items = [ + item + for item in items + if parse_date(item.get("publishedAt")) >= three_days_ago + ] + + recent_items.sort( + key=lambda x: x.get("paper", {}).get("upvotes", 0), reverse=True + ) + output_items = [ + { + "title": item.get("paper", {}).get("title"), + "url": ARXIV_URL.format(paper_id=item.get("paper", {}).get("id")), + "upvotes": item.get("paper", {}).get("upvotes"), + } + for item in recent_items[:top_n] + ] + except Exception as e: + print(e) + return [] + + return output_items diff --git a/libs/ktem/ktem/utils/rate_limit.py b/libs/ktem/ktem/utils/rate_limit.py new file mode 100644 index 0000000..d1290f0 --- /dev/null +++ b/libs/ktem/ktem/utils/rate_limit.py @@ -0,0 +1,48 @@ +from collections import defaultdict +from datetime import datetime, timedelta + +import gradio as gr +from decouple import config + +# In-memory store for rate limiting (for demonstration purposes) +rate_limit_store: dict[str, dict] = defaultdict(dict) + +# Rate limit configuration +RATE_LIMIT = config("RATE_LIMIT", default=20, cast=int) +RATE_LIMIT_PERIOD = timedelta(hours=24) + + +def check_rate_limit(limit_type: str, request: gr.Request): + if request is None: + raise ValueError("This feature is not available") + + user_id = None + try: + import gradiologin as grlogin + + user = grlogin.get_user(request) + if user: + user_id = user.get("email") + except (ImportError, AssertionError): + pass + + if not user_id: + raise ValueError("Please sign-in to use this feature") + + now = datetime.now() + user_data = rate_limit_store[limit_type].get( + user_id, {"count": 0, "reset_time": now + RATE_LIMIT_PERIOD} + ) + + if now >= user_data["reset_time"]: + # Reset the rate limit for the user + user_data = {"count": 0, "reset_time": now + RATE_LIMIT_PERIOD} + + if user_data["count"] >= RATE_LIMIT: + raise ValueError("Rate limit exceeded. Please try again later.") + + # Increment the request count + user_data["count"] += 1 + rate_limit_store[limit_type][user_id] = user_data + + return user_id diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py index 49c2f79..18b09fe 100644 --- a/libs/ktem/ktem/utils/render.py +++ b/libs/ktem/ktem/utils/render.py @@ -5,7 +5,7 @@ from fast_langdetect import detect from kotaemon.base import RetrievedDocument -BASE_PATH = os.environ.get("GRADIO_ROOT_PATH", "") +BASE_PATH = os.environ.get("GR_FILE_ROOT_PATH", "") def is_close(val1, val2, tolerance=1e-9): @@ -44,7 +44,8 @@ class Render: o = " open" if open else "" return ( f"
" - f"{header}{content}

" + f"{header}{content}" + "

" ) @staticmethod @@ -225,6 +226,9 @@ class Render: doc, highlight_text=highlight_text, ) + rendered_doc_content = ( + f"
{rendered_doc_content}
" + ) return Render.collapsible( header=rendered_header, diff --git a/libs/ktem/pyproject.toml b/libs/ktem/pyproject.toml index 1fc420b..2add23b 100644 --- a/libs/ktem/pyproject.toml +++ b/libs/ktem/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "sqlmodel>=0.0.16,<0.1", "tiktoken>=0.6.0,<1", "gradio>=4.31.0,<5", + "gradiologin", "python-multipart==0.0.12", # required for gradio, pinning to avoid yanking issues with micropip (fixed in gradio >= 5.4.0) "markdown>=3.6,<4", "tzlocal>=5.0", diff --git a/sso_app.py b/sso_app.py new file mode 100644 index 0000000..780b8ca --- /dev/null +++ b/sso_app.py @@ -0,0 +1,51 @@ +import os + +import gradiologin as grlogin +from decouple import config +from fastapi import FastAPI +from fastapi.responses import FileResponse +from theflow.settings import settings as flowsettings + +KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".") +GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None) +# override GRADIO_TEMP_DIR if it's not set +if GRADIO_TEMP_DIR is None: + GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp") + os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR + + +GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID", default="") +GOOGLE_CLIENT_SECRET = config("GOOGLE_CLIENT_SECRET", default="") + + +from ktem.main import App # noqa + +gradio_app = App() +demo = gradio_app.make() + +app = FastAPI() +grlogin.register( + name="google", + server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + client_kwargs={ + "scope": "openid email profile", + }, +) + + +@app.get("/favicon.ico", include_in_schema=False) +async def favicon(): + return FileResponse(gradio_app._favicon) + + +grlogin.mount_gradio_app( + app, + demo, + "/app", + allowed_paths=[ + "libs/ktem/ktem/assets", + GRADIO_TEMP_DIR, + ], +) diff --git a/sso_app_demo.py b/sso_app_demo.py new file mode 100644 index 0000000..afe3488 --- /dev/null +++ b/sso_app_demo.py @@ -0,0 +1,97 @@ +import os + +import gradio as gr +from authlib.integrations.starlette_client import OAuth, OAuthError +from decouple import config +from fastapi import FastAPI, Request +from fastapi.responses import FileResponse +from starlette.config import Config +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import RedirectResponse +from theflow.settings import settings as flowsettings + +KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) +KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".") +GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None) +# override GRADIO_TEMP_DIR if it's not set +if GRADIO_TEMP_DIR is None: + GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp") + os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR + + +GOOGLE_CLIENT_ID = config("GOOGLE_CLIENT_ID", default="") +GOOGLE_CLIENT_SECRET = config("GOOGLE_CLIENT_SECRET", default="") +SECRET_KEY = config("SECRET_KEY", default="default-secret-key") + + +def add_session_middleware(app): + config_data = { + "GOOGLE_CLIENT_ID": GOOGLE_CLIENT_ID, + "GOOGLE_CLIENT_SECRET": GOOGLE_CLIENT_SECRET, + } + starlette_config = Config(environ=config_data) + oauth = OAuth(starlette_config) + oauth.register( + name="google", + server_metadata_url=( + "https://accounts.google.com/" ".well-known/openid-configuration" + ), + client_kwargs={"scope": "openid email profile"}, + ) + + app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY) + return oauth + + +from ktem.main import App # noqa + +gradio_app = App() +main_demo = gradio_app.make() + +app = FastAPI() +oauth = add_session_middleware(app) + + +@app.get("/") +def public(request: Request): + root_url = gr.route_utils.get_root_url(request, "/", None) + return RedirectResponse(url=f"{root_url}/app/") + + +@app.get("/favicon.ico", include_in_schema=False) +async def favicon(): + return FileResponse(gradio_app._favicon) + + +@app.route("/logout") +async def logout(request: Request): + request.session.pop("user", None) + return RedirectResponse(url="/") + + +@app.route("/login") +async def login(request: Request): + root_url = gr.route_utils.get_root_url(request, "/login", None) + redirect_uri = f"{root_url}/auth" + return await oauth.google.authorize_redirect(request, redirect_uri) + + +@app.route("/auth") +async def auth(request: Request): + try: + access_token = await oauth.google.authorize_access_token(request) + except OAuthError: + return RedirectResponse(url="/") + request.session["user"] = dict(access_token)["userinfo"] + return RedirectResponse(url="/") + + +app = gr.mount_gradio_app( + app, + main_demo, + path="/app", + allowed_paths=[ + "libs/ktem/ktem/assets", + GRADIO_TEMP_DIR, + ], +)