diff --git a/.gitignore b/.gitignore index 5eaa7cdd9526e1f38be67985eecc76da96db8712..cd41b5712ce26752b1c121841bbcad74996fd843 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,8 @@ # Temporary files tmp/ temp/ + +# Local generated artifacts +doc-2.0-sources/ +data/ +__pycache__/ diff --git a/README.md b/README.md index 5416b4e8ba292e75ac0c868edde97b19e7dd3667..ffa2aaabd86859bc43628523548e5f0b0eacad2a 100644 --- a/README.md +++ b/README.md @@ -110,3 +110,78 @@ docker compose up -d - [ClickHouse MCP](https://github.com/ClickHouse/mcp-clickhouse) — MCP server for ClickHouse - [LibreChat](https://github.com/danny-avila/LibreChat) — Chat UI - [Langfuse](https://langfuse.com) — LLM observability + +## Alternative: Local RAG System (ClickHouse + Ollama) + +This repository also includes an alternative Python-based RAG workflow focused on local inference with Ollama and vector search in ClickHouse (see `rag_system_alternative.py` and `ReadMe.txt`). + +### Features + +- Vector search using embeddings in ClickHouse +- Optional reranking for improved answer quality +- Query caching for faster repeated requests +- Multi-document PDF processing +- Benchmarking and export helpers (CSV/JSON/Parquet) + +### Requirements + +#### System + +| Component | Minimum | Recommended | +|---|---|---| +| CPU | 4 cores | 8+ cores | +| RAM | 16 GB | 32+ GB | +| Disk | 10 GB | 20+ GB | +| GPU | Optional | NVIDIA 8 GB VRAM | + +#### Software + +- Python 3.10+ +- ClickHouse (Cloud or local) +- Ollama installed locally + +### Quick Start (Local RAG) + +1. Install dependencies: + +```bash +pip install clickhouse-connect ollama pypdf pandas numpy +pip install ipywidgets tqdm scikit-learn pyarrow +``` + +2. Install and start Ollama: + +```bash +ollama pull nomic-embed-text +ollama pull llama3.2:3b +ollama serve +``` + +3. Configure ClickHouse credentials in your script or environment. + +4. 
Run: + +```bash +python rag_system_alternative.py +``` + +### Suggested Environment Variables (Local RAG) + +```bash +# ClickHouse +CLICKHOUSE_HOST=your-host.clickhouse.cloud +CLICKHOUSE_USER=default +CLICKHOUSE_PASSWORD=your-password + +# Ollama +EMBED_MODEL=nomic-embed-text +LLM_MODEL=llama3.2:3b + +# RAG tuning +CHUNK_SIZE=1000 +CHUNK_OVERLAP=150 +TOP_K=8 +NUM_CTX=4096 +NUM_PREDICT=400 +TEMPERATURE=0.1 +``` diff --git a/bash.exe.stackdump b/bash.exe.stackdump new file mode 100644 index 0000000000000000000000000000000000000000..5fe7e37ebdfa6556bd5c196661984446a2e3c969 --- /dev/null +++ b/bash.exe.stackdump @@ -0,0 +1,32 @@ +Stack trace: +Frame Function Args +000FFFEFF60 00210062F57 (00000000002, 00000000002, 00000000000, 000FFFFDE50) +00000000000 00210065045 (000FFFF0910, 00000000000, 00000000744, 00000000000) +000FFFF0670 0021013AB68 (00000000000, 00000000000, 00000000000, 00000000000) +000000000C1 0021013619B (00000000000, 00200000000, 00000000000, 00000000000) +00000000000 002101365A5 (00210199B0B, 000FFFFFFFF, 0000000000B, 00000080000) +00000000000 00210199B0B (00210199B0B, 000FFFFFFFF, 0000000000B, 00000080000) +00000000000 0010044E724 (00000000000, 00000000000, 00000000000, 00000000000) +00000000000 0010044E8A9 (00000000000, 00000000000, 00000000000, 000FFFFCE00) +00000000000 0010044EA42 (00000000210, 00000738888, 000006F6A90, 000FFFF0BC0) +00000000000 00210065045 (000FFFF1320, 00000000000, 00000000744, 00000000000) +000FFFF11B0 0021013AB68 (00000000000, 00000000000, 00000000030, 00000000000) +000FFFF1320 00210063243 (000FFFF1490, 002102D35D0, 002100477C4, 000FFFF19E0) +000FFFF1990 7FFC12C861CF (00210040000, 002100477C4, 002102D35D0, 000FFFF1490) +000FFFF1990 7FFC12B323A7 (000FFFF22BC, 00210373AA0, 00210132B05, 00000000000) +000FFFFC730 7FFC12C85B0E (00000000000, 00000000000, 00000000000, 00000000000) +000FFFFC730 002100C98E9 (000FFFFC550, 00600000000, 0000000045C, 00000000000) +000FFFFC730 002100CAFF0 (00000000000, 00210135DE8, 000FFFFC660, 00210273880) 
@dataclass
class Config:
    """Runtime settings for the local RAG pipeline.

    Scalar defaults come from environment variables. They are read once,
    when this class body is executed, so changing the environment afterwards
    does not affect existing or new instances.

    After construction ``doc_files`` holds one ``(full_path, source_name)``
    tuple for every ``.txt`` / ``.md`` file found under ``docs_folder``.
    """

    # ClickHouse connection
    ch_host: str = os.getenv('CLICKHOUSE_HOST', 'ug1o26imbr.eu-central-1.aws.clickhouse.cloud')
    ch_user: str = os.getenv('CLICKHOUSE_USER', 'default')
    # SECURITY: never ship a real password as the fallback. The secret must
    # come from the environment / .env file; an empty default makes a missing
    # variable fail loudly at connect time instead of leaking credentials.
    # NOTE(review): the password that used to be hard-coded here has been
    # published in version control and should be rotated.
    ch_password: str = os.getenv('CLICKHOUSE_PASSWORD', '')
    ch_secure: bool = os.getenv('CLICKHOUSE_SECURE', 'true').lower() == 'true'

    # Ollama models
    embed_model: str = os.getenv('EMBED_MODEL', 'nomic-embed-text')
    llm_model: str = os.getenv('LLM_MODEL', 'llama3.2:3b')

    # Retrieval / chunking
    chunk_size: int = int(os.getenv('CHUNK_SIZE', '1000'))
    chunk_overlap: int = int(os.getenv('CHUNK_OVERLAP', '150'))
    top_k: int = int(os.getenv('TOP_K', '8'))
    rerank_top_k: int = int(os.getenv('RERANK_TOP_K', '3'))
    similarity_threshold: float = float(os.getenv('SIMILARITY_THRESHOLD', '0.35'))
    batch_size: int = int(os.getenv('BATCH_SIZE', '32'))

    # Generation options
    num_ctx: int = int(os.getenv('NUM_CTX', '4096'))
    num_predict: int = int(os.getenv('NUM_PREDICT', '400'))
    temperature: float = float(os.getenv('TEMPERATURE', '0.1'))
    top_p: float = float(os.getenv('TOP_P', '0.9'))
    repeat_penalty: float = float(os.getenv('REPEAT_PENALTY', '1.1'))

    # Limits
    max_text_length: int = int(os.getenv('MAX_TEXT_LENGTH', '3072'))
    min_chunk_size: int = int(os.getenv('MIN_CHUNK_SIZE', '100'))
    max_chunks_per_doc: int = int(os.getenv('MAX_CHUNKS_PER_DOC', '2000'))

    # Query cache
    cache_enabled: bool = os.getenv('CACHE_ENABLED', 'true').lower() == 'true'
    cache_ttl: int = int(os.getenv('CACHE_TTL', '3600'))

    # Paths
    docs_folder: str = './doc-2.0-sources'
    few_shot_folder: str = './data/few_shot_examples'
    results_folder: str = './data/results'

    doc_files: List[Tuple[str, str]] = field(default_factory=list)

    def __post_init__(self):
        """Discover source documents and create the output folders.

        Side effects: creates ``few_shot_folder`` and ``results_folder`` if
        they are missing and prints a two-line summary.
        """
        if os.path.exists(self.docs_folder):
            for root, dirs, files in os.walk(self.docs_folder):
                # os.walk yields entries in arbitrary order; sorting in place
                # makes the traversal (and any ids derived from the order of
                # doc_files) reproducible across runs and platforms.
                dirs.sort()
                rel_path = os.path.relpath(root, self.docs_folder)
                for file in sorted(files):
                    if not file.endswith(('.txt', '.md')):
                        continue
                    full_path = os.path.join(root, file)
                    if rel_path == '.':
                        # Strip only the real extension; chained
                        # str.replace() also removed ".txt"/".md" from the
                        # middle of a name.
                        source_name = os.path.splitext(file)[0]
                    else:
                        # Nested files keep their extension, as before.
                        source_name = f"{rel_path}/{file}"
                    self.doc_files.append((full_path, source_name))

        os.makedirs(self.few_shot_folder, exist_ok=True)
        os.makedirs(self.results_folder, exist_ok=True)

        print(f"[DOCS] Folder: {self.docs_folder}")
        print(f"[DOCS] Files found: {len(self.doc_files)}")
def init_database(self, force_recreate: bool = False):
    """Make sure the ``default.rag_chunks`` table exists.

    Args:
        force_recreate: When True the table is dropped and created again,
            which deletes all stored chunks. When False the table is
            created only if it does not exist yet.
    """
    client = self.get_client()

    # Check whether the table is already there. A failed check is treated as
    # "missing"; the CREATE below uses IF NOT EXISTS, so that is safe.
    try:
        result = client.query("EXISTS TABLE default.rag_chunks")
        table_exists = bool(result.result_rows[0][0]) if result.result_rows else False
    except Exception as exc:  # narrowed from a bare except
        print(f"[WARN] Could not check table existence: {exc}")
        table_exists = False

    if table_exists and not force_recreate:
        print("[OK] Database already exists, reusing existing data")
        # Report how many chunks are stored; purely informational, so a
        # failure here is reported but never fatal.
        try:
            count_result = client.query("SELECT count(*) FROM default.rag_chunks")
            chunk_count = count_result.result_rows[0][0] if count_result.result_rows else 0
            print(f"[OK] Existing chunks in database: {chunk_count}")
        except Exception as exc:  # was a silent ``except: pass``
            print(f"[WARN] Could not count existing chunks: {exc}")
        return

    # Table is missing, or the caller asked for a clean rebuild.
    if force_recreate:
        print("[WARN] Force recreating database...")
        client.command("DROP TABLE IF EXISTS default.rag_chunks")

    # NOTE(review): PARTITION BY source creates one partition per document;
    # confirm this stays within ClickHouse partition limits for large corpora.
    client.command("""
        CREATE TABLE IF NOT EXISTS default.rag_chunks (
            id UInt64,
            source String,
            page UInt32,
            chunk String,
            embedding Array(Float32),
            chunk_hash String,
            char_count UInt32,
            created_at DateTime DEFAULT now()
        ) ENGINE = MergeTree()
        PARTITION BY source
        ORDER BY id
    """)
    print("[OK] Database initialized")
class DocumentProcessor:
    """Turn plain-text / Markdown files into chunk records for embedding."""

    @staticmethod
    def load_document(file_path: str, source_name: str) -> List[Tuple[int, str]]:
        """Read a UTF-8 text file and cut it into line-aligned "pages".

        Lines are accumulated until adding the next one would push the
        character count above ``config.chunk_size``; newline separators are
        not counted. Parts whose stripped length is not greater than
        ``config.min_chunk_size`` are dropped, so page numbers may have gaps.

        Args:
            file_path: File to read.
            source_name: Unused here; kept for signature compatibility.

        Returns:
            ``(page_number, text)`` tuples, or ``[]`` if the file cannot be
            read or decoded.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError; best-effort: report, skip.
            print(f" Error loading {file_path}: {e}")
            return []

        if len(content) <= config.chunk_size:
            stripped = content.strip()
            if len(stripped) > config.min_chunk_size:
                return [(1, stripped)]
            return []

        parts = []
        current = []
        current_len = 0
        for line in content.split('\n'):
            # ``current and`` prevents flushing an empty buffer when the very
            # first line is already longer than chunk_size; that used to
            # emit a phantom empty part and shift page numbers.
            if current and current_len + len(line) > config.chunk_size:
                parts.append('\n'.join(current))
                current = [line]
                current_len = len(line)
            else:
                current.append(line)
                current_len += len(line)
        if current:
            parts.append('\n'.join(current))

        return [
            (index, part.strip())
            for index, part in enumerate(parts, start=1)
            if len(part.strip()) > config.min_chunk_size
        ]

    @staticmethod
    def split_chunks(text: str, *, size=None, overlap=None,
                     min_chars=None, max_chunks=None) -> List[str]:
        """Split *text* into overlapping windows of whitespace-separated words.

        Note the units: the window size and overlap are counted in words,
        while the minimum chunk length is counted in characters.

        Args:
            text: Text to split.
            size: Words per window; defaults to ``config.chunk_size``.
            overlap: Words shared by consecutive windows; defaults to
                ``config.chunk_overlap``.
            min_chars: Windows whose joined length is not greater than this
                are skipped; defaults to ``config.min_chunk_size``.
            max_chunks: Stop after this many chunks; defaults to
                ``config.max_chunks_per_doc``.

        Returns:
            The chunks, in document order.
        """
        size = config.chunk_size if size is None else size
        overlap = config.chunk_overlap if overlap is None else overlap
        min_chars = config.min_chunk_size if min_chars is None else min_chars
        max_chunks = config.max_chunks_per_doc if max_chunks is None else max_chunks

        words = text.split()
        # overlap >= size used to give step <= 0: range() raises on 0 and
        # yields nothing on negatives. Always advance by at least one word.
        step = max(1, size - overlap)

        chunks = []
        for start in range(0, len(words), step):
            chunk = ' '.join(words[start:start + size])
            if len(chunk) > min_chars:
                chunks.append(chunk)
            if len(chunks) >= max_chunks:
                break
        return chunks

    @staticmethod
    def process_document(file_path: str, source_name: str, start_id: int) -> List[Dict]:
        """Load a document and build the chunk records for it.

        Each record carries ``id``, ``source``, ``page``, ``chunk``,
        ``chunk_hash`` (MD5 hex digest of the chunk) and ``char_count``.
        Ids are consecutive, starting at *start_id*. Embeddings are not
        computed here.

        Returns:
            At most ``config.max_chunks_per_doc`` records; ``[]`` when the
            document yields no pages.
        """
        print(f"\n Processing: {source_name}")
        pages = DocumentProcessor.load_document(file_path, source_name)
        if not pages:
            return []

        limit = config.max_chunks_per_doc
        chunks = []
        for page_num, text in pages:
            for chunk in DocumentProcessor.split_chunks(text):
                # Slicing is a no-op for chunks that are already short enough.
                chunk = chunk[:config.max_text_length]
                chunks.append({
                    'id': start_id + len(chunks),
                    'source': source_name,
                    'page': page_num,
                    'chunk': chunk,
                    'chunk_hash': hashlib.md5(chunk.encode()).hexdigest(),
                    'char_count': len(chunk)
                })
                if len(chunks) >= limit:
                    break
            if len(chunks) >= limit:
                break

        print(f" Created {len(chunks)} chunks from {source_name}")
        return chunks
class EmbeddingGenerator:
    """Thin wrapper around ``ollama.embed`` with batching and a query cache."""

    # Length of the zero vector returned when the embedding call fails.
    # NOTE(review): 768 was hard-coded; presumably it matches the default
    # embedding model - confirm when switching models.
    fallback_dim = 768
    # Maximum number of single-text embeddings kept per instance.
    cache_size = 256

    def __init__(self):
        self.model = config.embed_model
        self.batch_size = config.batch_size
        self.max_length = config.max_text_length
        # Per-instance cache (text -> embedding tuple). A class-level
        # lru_cache on a method is keyed on ``self`` and keeps every
        # instance alive for the lifetime of the cache.
        self._single_cache = {}

    def _truncate_text(self, text: str) -> str:
        """Clip *text* to ``max_length`` characters.

        If the clipped text contains a period in its second half, the cut is
        moved back to just after that period so it ends on a sentence.
        """
        if len(text) <= self.max_length:
            return text
        truncated = text[:self.max_length]
        last_period = truncated.rfind('.')
        if last_period > self.max_length // 2:
            truncated = truncated[:last_period + 1]
        return truncated.strip()

    def generate_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* in a single call.

        Best-effort: on any failure the error is printed and one zero vector
        of ``fallback_dim`` floats is returned per input text.
        """
        if not texts:
            return []
        safe_texts = [self._truncate_text(t) for t in texts]
        try:
            response = ollama.embed(model=self.model, input=safe_texts)
            return response['embeddings']
        except Exception as e:
            print(f" Embedding error: {e}")
            return [[0.0] * self.fallback_dim for _ in safe_texts]

    def generate(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* in batches of ``batch_size`` and return all vectors."""
        # A non-positive batch size would make range() raise or yield nothing.
        step = max(1, self.batch_size)
        all_embeddings = []
        for start in range(0, len(texts), step):
            all_embeddings.extend(self.generate_batch(texts[start:start + step]))
        return all_embeddings

    def generate_cached(self, text: str) -> tuple:
        """Return the embedding of *text* as a tuple, cached per instance.

        The cache holds at most ``cache_size`` entries and evicts the least
        recently used one. All-zero vectors (the failure fallback of
        ``generate_batch``) are never cached, so a transient backend error
        does not stick to that query.
        """
        cache = self._single_cache
        if text in cache:
            # Re-insert to mark as most recently used; dicts keep insertion
            # order.
            embedding = cache.pop(text)
            cache[text] = embedding
            return embedding

        embedding = tuple(self.generate_batch([text])[0])
        if any(embedding):
            cache[text] = embedding
            if len(cache) > self.cache_size:
                del cache[next(iter(cache))]
        return embedding
@staticmethod
def extract_pdf(pdf_path: str, source_name: str) -> List[Tuple[int, str]]:
    """Extract the text of a PDF page by page.

    Pages whose stripped text is not longer than ``config.min_chunk_size``
    are skipped. Runs of newlines inside a page become single spaces.

    Args:
        pdf_path: Path of the PDF file.
        source_name: Unused here; kept for signature compatibility.

    Returns:
        ``(page_number, text)`` tuples with 1-based page numbers, or ``[]``
        if the file cannot be opened.
    """
    try:
        reader = PdfReader(pdf_path)
        total_pages = len(reader.pages)
        print(f" Total pages: {total_pages}")

        pages = []
        for i in range(total_pages):
            # Pages are isolated from each other: one broken page must not
            # discard the rest of the document.
            try:
                text = reader.pages[i].extract_text()
            except Exception as exc:  # was a bare ``except: pass``
                print(f" Skipping page {i + 1}: {exc}")
                continue
            if text and len(text.strip()) > config.min_chunk_size:
                text = re.sub(r'\n+', ' ', text)
                pages.append((i + 1, text.strip()))
        return pages
    except Exception as e:
        print(f" Error: {e}")
        return []
class Reranker:
    """Re-orders vector-search hits by mixing similarity with keyword overlap."""

    # Words of four or more word characters; compiled once instead of once
    # per chunk.
    _WORD_RE = re.compile(r'\b\w{4,}\b')

    @staticmethod
    def rerank(question: str, results: List[tuple], top_k=None,
               semantic_weight: float = 0.6,
               lexical_weight: float = 0.4) -> List[tuple]:
        """Return the best rows of *results*, highest score first.

        Each row is a ``(chunk, source, page, distance)`` tuple. Its score is
        ``semantic_weight * (1 - distance) + lexical_weight * overlap``,
        where *overlap* is the fraction of the question's 4+ character words
        that also occur in the chunk (case-insensitive).

        Args:
            question: The user query.
            results: Rows to rank.
            top_k: Number of rows to keep; defaults to
                ``config.rerank_top_k``.
            semantic_weight: Weight of the vector similarity.
            lexical_weight: Weight of the keyword overlap.

        Returns:
            The original row tuples, best first. Rows with equal scores keep
            their incoming order. Empty input is returned unchanged.
        """
        if not results:
            return results
        if top_k is None:
            top_k = config.rerank_top_k

        find_words = Reranker._WORD_RE.findall
        q_words = set(find_words(question.lower()))
        denominator = max(len(q_words), 1)  # avoid dividing by zero

        scored = []
        for result in results:
            chunk, _source, _page, distance = result
            overlap = len(q_words & set(find_words(chunk.lower()))) / denominator
            similarity = 1 - distance
            scored.append((similarity * semantic_weight + overlap * lexical_weight, result))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [row for _, row in scored[:top_k]]
class QALoader:
    """Loads evaluation questions and reference answers from txt/csv/json."""

    @staticmethod
    def _load_field(file_path: str, json_key: str, column_hints: Tuple[str, ...]) -> List[str]:
        """Shared loader behind ``load_questions`` and ``load_answers``.

        * ``.txt``  - one entry per non-blank line, stripped.
        * ``.csv``  - the first column whose name contains one of
          *column_hints* (case-insensitive), otherwise the first column;
          missing values are dropped and the rest converted to ``str``.
        * ``.json`` - for a top-level list, ``item[json_key]`` for dict
          items (``str(item)`` if the key is absent) and ``str(item)`` for
          anything else.

        Any other extension, or a JSON document that is not a list, gives
        ``[]``. The extension is matched case-insensitively.
        """
        path = Path(file_path)
        suffix = path.suffix.lower()

        if suffix == '.txt':
            with open(path, 'r', encoding='utf-8') as f:
                return [line.strip() for line in f if line.strip()]

        if suffix == '.csv':
            df = pd.read_csv(path)
            for col in df.columns:
                # str() guards against non-string headers.
                name = str(col).lower()
                if any(hint in name for hint in column_hints):
                    return df[col].dropna().astype(str).tolist()
            return df.iloc[:, 0].dropna().astype(str).tolist()

        if suffix == '.json':
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if isinstance(data, list):
                # Plain strings/numbers in the list have no .get(); calling
                # it on them used to raise AttributeError.
                return [
                    item.get(json_key, str(item)) if isinstance(item, dict) else str(item)
                    for item in data
                ]

        return []

    @staticmethod
    def load_questions(file_path: str) -> List[str]:
        """Return the questions in *file_path* (see ``_load_field``)."""
        return QALoader._load_field(file_path, 'question', ('question', 'query'))

    @staticmethod
    def load_answers(file_path: str) -> List[str]:
        """Return the reference answers in *file_path* (see ``_load_field``)."""
        return QALoader._load_field(file_path, 'answer', ('answer', 'response'))

    @staticmethod
    def load_qa_pairs(questions_file: str, answers_file: str) -> List[Tuple[str, str]]:
        """Pair questions with answers by position.

        The longer list is truncated to the length of the shorter one.
        NOTE(review): blank lines / missing CSV values are dropped
        independently in each file, so a gap in one file shifts the pairing;
        verify the inputs are gap-free.
        """
        questions = QALoader.load_questions(questions_file)
        answers = QALoader.load_answers(answers_file)
        return list(zip(questions, answers))
df.to_csv(path, index=False, encoding='utf-8') + print(f" Results saved to: {path}") + return df \ No newline at end of file diff --git a/prompts/rag_api_en.txt b/prompts/rag_api_en.txt new file mode 100644 index 0000000000000000000000000000000000000000..6670be8083a608359a9078be96f8d8ace2b1810e --- /dev/null +++ b/prompts/rag_api_en.txt @@ -0,0 +1,29 @@ +The client wants to receive general describing information about API interface. Your task is to +provide well-structured and detailed information about API to the client, which will give an initial understanding +of how the API is used. Below is a collection of information found in the documentation to answer a customer request, +surrounded by --------------------- + +--------------------- +{context} +--------------------- + +This information is a set of blocks, divided into headers in .md format with the following information. + +Customer query: {query} + +Your task is to respond to the client's request as follows: + +1. Select information blocks that, based on the title and content, best match the API name that the client +sent in the request. Generally, there are two types of API - non-form API and form API. If in the customer query you +see any mentions form or forms, please select for answer only information blocks in the headers of which or +in content of which there are mentions form or forms. If in the customer query you don`t see any mentions form or forms, +please select for answer only information blocks where you don`t see any mentions form or forms. + +2. 
From the selected blocks of information, create an answer that has the following structure: +first a brief introduction with a description of what the API is used for, then - flow for this api, step by step as it is if +it was found, then the URLs that are used for this API with a mention of what environment they belong to - sandbox or +production, then you can give a brief set of basic request parameters with their short description and a brief set of +basic response parameters with their short description, if they are in the selected information. Then, you can provide +one example of the request and response used in this API. + +3. In your answer, you do not need to describe the names of the information blocks that were selected for the answer. diff --git a/prompts/rag_api_parameter_en.txt b/prompts/rag_api_parameter_en.txt new file mode 100644 index 0000000000000000000000000000000000000000..362c00a73bc9cecd3d9cc893fddb1d3e606fbe66 --- /dev/null +++ b/prompts/rag_api_parameter_en.txt @@ -0,0 +1,35 @@ +The client wants to receive information about parameter or parameters specified in query. Your task is to +provide well-structured and detailed information about parameter or mentioned parameters to the client. +Below is a collection of information found in the documentation to answer a customer request, +surrounded by --------------------- + +--------------------- +{context} +--------------------- + +This information is a set of blocks, divided into headers in .md format with the following information. + +Customer query: {query} + +Your task is to respond to the customer query as follows: + +Usually, customer asks about one single parameter or difference between the parameters specified in the request. + +If customer query is about one single parameter, do the following: + +It is necessary to find information only on the parameter specified in the request and give the most +detailed answer and complete information only on this specific parameter. 
Information about any other parameters +should not be included in the response to the customer. The following information should be included in the response: + +1. Parameter name strictly as it is. +2. Parameter description strictly as it is. +3. Information about parameter value characteristics: necessity, type, length, etc. +4. Also, you may mention APIs and requests where this parameter is used. Any useful notes about this parameter +can also be added to the response if found. + +If a customer query is about several parameters, do the following: + +1. Find information about each mentioned parameter, as described above for a single-parameter request. +2. Include the information found for each parameter in the answer. +3. If a customer query contains a mention of any difference or differences between parameters, analyze how the +specified parameters differ from each other and include a summary of their differences in your answer. diff --git a/prompts/rag_api_parameters_list_en.txt b/prompts/rag_api_parameters_list_en.txt new file mode 100644 index 0000000000000000000000000000000000000000..84eb013ab0f51fbb17036c38e7696b8e60fbf9b7 --- /dev/null +++ b/prompts/rag_api_parameters_list_en.txt @@ -0,0 +1,21 @@ +The client wants to receive information about the parameter list or set. Your task is to +provide well-structured and beautifully presented information to the client. +Below is a collection of information found in the documentation to answer a customer request, +surrounded by --------------------- + +--------------------- +{context} +--------------------- + +This information is a set of blocks, divided into headers in .md format with the following information. + +Customer query: {query} + +Your task is to respond to the client's request as follows: + +1. If the question is about request parameters - select the part of the information that concerns the request parameters. 
+If the question is about response parameters - select the part of the information that concerns the response parameters. +If it is not specified what parameters the query is about - select the entire part of the information that concerns the +parameters. + +2. Present parameters as a detailed list with their description and characteristics. diff --git a/rag-system.ipynb b/rag-system.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..d69b9cf12a678789a773a4470faaf2bdd1a60dc3 --- /dev/null +++ b/rag-system.ipynb @@ -0,0 +1,1101 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "id": "d8543c46", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " WARNING: Failed to remove contents in a temporary directory 'C:\\Users\\User\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\~-mpy.libs'.\n", + " You can safely remove it manually.\n", + " WARNING: Failed to remove contents in a temporary directory 'C:\\Users\\User\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\~-mpy'.\n", + " You can safely remove it manually.\n", + "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "langchain 0.3.0 requires numpy<2,>=1; python_version < \"3.12\", but you have numpy 2.2.6 which is incompatible.\n", + "langchain-community 0.3.0 requires numpy<2,>=1; python_version < \"3.12\", but you have numpy 2.2.6 which is incompatible.\n", + "typer 0.24.1 requires click>=8.2.1, but you have click 8.1.8 which is incompatible.\n" + ] + } + ], + "source": [ + "pip install pywidgets -q" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9eee0f71", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ ClickHouse подключён!\n" + ] + } + ], + "source": [ + "import ollama\n", + "import clickhouse_connect\n", + "import pandas as pd\n", + "import json, time, re\n", + "from pypdf import PdfReader\n", + "from IPython.display import display, HTML, clear_output\n", + "import ipywidgets as widgets\n", + "\n", + "# ── Настройки ────────────────────────────────────────────\n", + "CH_HOST = 'ug1o26imbr.eu-central-1.aws.clickhouse.cloud'\n", + "CH_USER = 'default'\n", + "CH_PASSWORD = '~MlK_g7KdbqYH' # ← вставьте реальный пароль\n", + "\n", + "EMBED_MODEL = 'nomic-embed-text'\n", + "LLM_MODEL = 'llama3.1'\n", + "\n", + "PDF_FILES = [\n", + " (r'C:\\Users\\User\\Desktop\\Folder_vs_documents\\integration.pdf', 'Integration'),\n", + " (r'C:\\Users\\User\\Desktop\\Folder_vs_documents\\manager.pdf', 'Manager'),\n", + " (r'C:\\Users\\User\\Desktop\\Folder_vs_documents\\merchant.pdf', 'Merchant'),\n", + "]\n", + "\n", + "QUESTIONS_CSV = r'C:\\Users\\User\\Desktop\\Folder_vs_documents\\questions.csv' # ← путь к CSV с вопросами\n", + "QUESTION_COL = 'question' # ← название колонки с вопросами\n", + "\n", + "# ── Подключение ──────────────────────────────────────────\n", + "# ── Подключение к ClickHouse ──────────────────────────────\n", + "client = clickhouse_connect.get_client(\n", + " host='ug1o26imbr.eu-central-1.aws.clickhouse.cloud',\n", + " 
username='default',\n", + " password='~MlK_g7KdbqYH',\n", + " secure=True\n", + ")\n", + "print(\"✅ ClickHouse подключён!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee60aa34", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Все библиотеки загружены!\n" + ] + } + ], + "source": [ + "import ollama\n", + "import clickhouse_connect\n", + "import pandas as pd\n", + "import json, time, re\n", + "from pypdf import PdfReader\n", + "from IPython.display import display, HTML, clear_output\n", + "import ipywidgets as widgets\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7df65af0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Таблица rag_chunks готова\n" + ] + } + ], + "source": [ + "client.command(\"\"\"\n", + "CREATE TABLE IF NOT EXISTS default.rag_chunks (\n", + " id UInt64,\n", + " source String,\n", + " page UInt32,\n", + " chunk String,\n", + " embedding Array(Float32)\n", + ") ENGINE = MergeTree()\n", + "ORDER BY id\n", + "\"\"\")\n", + "\n", + "print(\"✅ Таблица rag_chunks готова\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "ea4f0913", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Функции готовы\n" + ] + } + ], + "source": [ + "def extract_pdf(pdf_path):\n", + " \"\"\"Извлечь текст по страницам из PDF\"\"\"\n", + " reader = PdfReader(pdf_path)\n", + " pages = []\n", + " for i, page in enumerate(reader.pages):\n", + " text = page.extract_text()\n", + " if text and len(text.strip()) > 30:\n", + " pages.append((i + 1, text.strip()))\n", + " return pages\n", + "\n", + "def split_chunks(text, size=100, overlap=10):\n", + " \"\"\"Разбить текст на перекрывающиеся чанки\"\"\"\n", + " words = text.split()\n", + " chunks = []\n", + " step = size - overlap\n", + " for i in range(0, len(words), step):\n", + " chunk = ' 
'.join(words[i:i + size])\n", + " if len(chunk) > 50:\n", + " chunks.append(chunk)\n", + " return chunks\n", + "\n", + "def get_embedding(text):\n", + " \"\"\"Получить вектор через Ollama\"\"\"\n", + " resp = ollama.embeddings(model=EMBED_MODEL, prompt=text)\n", + " return list(resp['embedding'])\n", + "\n", + "print(\"✅ Функции готовы\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "be8a3dc3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🗑️ Таблица очищена\n" + ] + } + ], + "source": [ + "client.command(\"TRUNCATE TABLE default.rag_chunks\")\n", + "print(\"🗑️ Таблица очищена\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2ca29917", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "📄 Обрабатываю: Integration\n", + " Страниц: 1154\n", + " ✅ Загружено чанков: 2833\n", + "\n", + "📄 Обрабатываю: Manager\n", + " Страниц: 564\n", + " ✅ Загружено чанков: 1350\n", + "\n", + "📄 Обрабатываю: Merchant\n", + " Страниц: 109\n", + " ✅ Загружено чанков: 205\n", + "\n", + "🎉 Итого загружено: 4388 чанков\n" + ] + } + ], + "source": [ + "chunk_id = 0\n", + "total_chunks = 0\n", + "\n", + "for pdf_path, source_name in PDF_FILES:\n", + " print(f\"\\n📄 Обрабатываю: {source_name}\")\n", + " pages = extract_pdf(pdf_path)\n", + " print(f\" Страниц: {len(pages)}\")\n", + "\n", + " rows = []\n", + " for page_num, text in pages:\n", + " chunks = split_chunks(text)\n", + " for chunk in chunks:\n", + " emb = get_embedding(chunk)\n", + " rows.append([chunk_id, source_name, page_num, chunk, emb])\n", + " chunk_id += 1\n", + "\n", + " client.insert(\n", + " 'default.rag_chunks',\n", + " rows,\n", + " column_names=['id', 'source', 'page', 'chunk', 'embedding']\n", + " )\n", + " total_chunks += len(rows)\n", + " print(f\" ✅ Загружено чанков: {len(rows)}\")\n", + "\n", + "print(f\"\\n🎉 Итого загружено: {total_chunks} чанков\")" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 29, + "id": "a3b15e25", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "⏱️ Время: 24.57s\n", + "🤖 Информация не найдена в документах.\n" + ] + } + ], + "source": [ + "def search_context(question, top_k=5):\n", + " \"\"\"Найти похожие чанки через cosineDistance\"\"\"\n", + " q_emb = get_embedding(question)\n", + " q_str = '[' + ','.join(map(str, q_emb)) + ']'\n", + "\n", + " res = client.query(f\"\"\"\n", + " SELECT chunk, source, page,\n", + " cosineDistance(embedding, {q_str}) AS dist\n", + " FROM default.rag_chunks\n", + " ORDER BY dist ASC\n", + " LIMIT {top_k}\n", + " \"\"\")\n", + " return res.result_rows\n", + "\n", + "def ask_rag(question, top_k=5):\n", + " \"\"\"Полный RAG пайплайн: вопрос → контекст → ответ\"\"\"\n", + " t_start = time.time()\n", + "\n", + " # 1. Найти контекст\n", + " ctx_rows = search_context(question, top_k)\n", + " context = \"\\n\\n\".join([\n", + " f\"[{r[1]}, стр.{r[2]}] {r[0]}\"\n", + " for r in ctx_rows\n", + " ])\n", + " t_retrieve = time.time() - t_start\n", + "\n", + " # 2. Спросить Llama\n", + " messages = [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": (\n", + " \"Ты — помощник по технической документации. \"\n", + " \"Отвечай ТОЛЬКО на основе предоставленного контекста. \"\n", + " \"Если ответ не найден — скажи 'Информация не найдена в документах'. 
\"\n", + " \"Отвечай кратко и по делу.\"\n", + " )\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"Контекст:\\n{context}\\n\\nВопрос: {question}\"\n", + " }\n", + " ]\n", + " resp = ollama.chat(model=LLM_MODEL, messages=messages)\n", + " t_total = time.time() - t_start\n", + "\n", + " return {\n", + " 'question' : question,\n", + " 'answer' : resp.message.content,\n", + " 'sources' : [(r[1], r[2]) for r in ctx_rows],\n", + " 'time_retrieve': round(t_retrieve, 2),\n", + " 'time_total' : round(t_total, 2),\n", + " }\n", + "\n", + "# Быстрый тест:\n", + "test = ask_rag(\"Тестовый вопрос?\")\n", + "print(f\"⏱️ Время: {test['time_total']}s\\n🤖 {test['answer'][:200]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "b98b3ba8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "📋 Вопросов: 59\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5ff9ce52f8c7414baa9fcc62c9c9ac32", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(IntProgress(value=0, bar_style='info', description='Прогресс:', layout=Layout(width='70%'), max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "📊 Результаты:\n", + " Вопросов обработано : 59/59\n", + " Среднее время ответа: 28.4 сек\n", + " Сохранено в : benchmark_results.csv\n" + ] + } + ], + "source": [ + "df_q = pd.read_csv(QUESTIONS_CSV)\n", + "questions = df_q[QUESTION_COL].dropna().tolist()\n", + "print(f\"📋 Вопросов: {len(questions)}\")\n", + "\n", + "# Прогресс-бар\n", + "progress = widgets.IntProgress(\n", + " value=0, min=0, max=len(questions),\n", + " description='Прогресс:',\n", + " bar_style='info',\n", + " layout=widgets.Layout(width='70%')\n", + ")\n", + "status_lbl = widgets.Label(value='Ожидание...')\n", + "display(widgets.VBox([progress, status_lbl]))\n", + "\n", + "# Запуск 
бенчмарка\n", + "results = []\n", + "for i, q in enumerate(questions):\n", + " status_lbl.value = f\"[{i+1}/{len(questions)}] {q[:70]}...\"\n", + " try:\n", + " res = ask_rag(q)\n", + " results.append({\n", + " '№' : i + 1,\n", + " 'question' : q,\n", + " 'answer' : res['answer'],\n", + " 'sources' : str(res['sources']),\n", + " 'time_retrieve': res['time_retrieve'],\n", + " 'time_total' : res['time_total'],\n", + " 'status' : 'ok'\n", + " })\n", + " except Exception as e:\n", + " results.append({\n", + " '№': i+1, 'question': q, 'answer': f'ОШИБКА: {e}',\n", + " 'sources':'', 'time_retrieve':0, 'time_total':0, 'status':'error'\n", + " })\n", + " progress.value = i + 1\n", + "\n", + "status_lbl.value = '✅ Готово!'\n", + "\n", + "# Сохранить результаты\n", + "df_res = pd.DataFrame(results)\n", + "df_res.to_csv('benchmark_results.csv', index=False, encoding='utf-8-sig')\n", + "\n", + "avg_time = df_res['time_total'].mean()\n", + "ok_count = (df_res['status'] == 'ok').sum()\n", + "print(f\"\\n📊 Результаты:\")\n", + "print(f\" Вопросов обработано : {ok_count}/{len(questions)}\")\n", + "print(f\" Среднее время ответа: {avg_time:.1f} сек\")\n", + "print(f\" Сохранено в : benchmark_results.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "770088f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🏆 Запуск LLM-судьи...\n", + " [1] Оценка: ⭐⭐ (2/5)\n", + " [2] Оценка: ⭐⭐ (2/5)\n", + " [3] Оценка: ⭐⭐ (2/5)\n", + " [4] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [5] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [6] Оценка: ⭐⭐⭐ (3/5)\n", + " [7] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [8] Оценка: ⭐ (1/5)\n", + " [9] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [10] Оценка: ⭐⭐ (2/5)\n", + " [11] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [12] Оценка: ⭐ (1/5)\n", + " [13] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [14] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [15] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [16] Оценка: ⭐⭐⭐ (3/5)\n", + " [17] Оценка: ⭐⭐⭐ (3/5)\n", + " [18] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [19] Оценка: ⭐⭐ (2/5)\n", + " 
[20] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [21] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [22] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [23] Оценка: ⭐⭐ (2/5)\n", + " [24] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [25] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [26] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [27] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [28] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [29] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [30] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [31] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [32] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [33] Оценка: ⭐⭐ (2/5)\n", + " [34] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [35] Оценка: ⭐⭐⭐ (3/5)\n", + " [36] Оценка: ⭐⭐⭐ (3/5)\n", + " [37] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [38] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [39] Оценка: ⭐⭐ (2/5)\n", + " [40] Оценка: ⭐ (1/5)\n", + " [41] Оценка: ⭐⭐⭐ (3/5)\n", + " [42] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [43] Оценка: ⭐ (1/5)\n", + " [44] Оценка: ⭐⭐⭐ (3/5)\n", + " [45] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [46] Оценка: ⭐⭐⭐ (3/5)\n", + " [47] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [48] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [49] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [50] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [51] Оценка: ⭐⭐ (2/5)\n", + " [52] Оценка: ⭐⭐ (2/5)\n", + " [53] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [54] Оценка: ⭐ (1/5)\n", + " [55] Оценка: ⭐⭐ (2/5)\n", + " [56] Оценка: ⭐⭐ (2/5)\n", + " [57] Оценка: ⭐⭐⭐⭐ (4/5)\n", + " [58] Оценка: ⭐ (1/5)\n", + " [59] Оценка: ⭐⭐ (2/5)\n", + "\n", + "✅ Средняя оценка: 3.12 / 5\n" + ] + } + ], + "source": [ + "def judge_answer(question, answer):\n", + " \"\"\"Llama оценивает качество своего же ответа от 1 до 5\"\"\"\n", + " prompt = f\"\"\"Оцени качество ответа на вопрос по шкале от 1 до 5.\n", + "\n", + "Вопрос: {question}\n", + "Ответ: {answer}\n", + "\n", + "Критерии оценки:\n", + "5 — полный, точный, понятный ответ\n", + "4 — хороший ответ с незначительными пробелами\n", + "3 — частично правильный ответ\n", + "2 — слабый ответ, мало полезной информации\n", + "1 — нет ответа или 'информация не найдена'\n", + "\n", + "Ответь ТОЛЬКО одной цифрой от 1 до 5.\"\"\"\n", + "\n", + " resp = ollama.chat(\n", + " model=LLM_MODEL,\n", + " messages=[{\"role\": \"user\", \"content\": 
prompt}]\n", + " )\n", + " text = resp.message.content.strip()\n", + " match = re.search(r'[1-5]', text)\n", + " return int(match.group()) if match else 0\n", + "\n", + "# Оценить все ответы\n", + "print(\"🏆 Запуск LLM-судьи...\")\n", + "scores = []\n", + "for i, row in df_res.iterrows():\n", + " score = judge_answer(row['question'], row['answer'])\n", + " scores.append(score)\n", + " print(f\" [{i+1}] Оценка: {'⭐'*score} ({score}/5)\")\n", + "\n", + "df_res['score'] = scores\n", + "df_res.to_csv('benchmark_results.csv', index=False, encoding='utf-8-sig')\n", + "print(f\"\\n✅ Средняя оценка: {sum(scores)/len(scores):.2f} / 5\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "1bcfd0ee", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "

\n", + " 📊 Результаты RAG Benchmark\n", + "

\n", + "
\n", + "
\n", + "
3.12
\n", + "
Средняя оценка / 5
\n", + "
\n", + "
\n", + "
28.4s
\n", + "
Среднее время ответа
\n", + "
\n", + "
\n", + "
100%
\n", + "
Успешных ответов
\n", + "
\n", + "
\n", + "
59
\n", + "
Всего вопросов
\n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
#ВопросОтветОценкаВремя
1Which tables, materialized views, and queries do I need to fully cover the analytics pipeline for paИнформация не найдена в документах....⭐⭐25.5s
2Provide an integration plan for loading payment and transaction data into ClickHouseИнформация не найдена в документах....⭐⭐24.86s
3Create a step-by-step guide: how to ingest data from application logs or APIs into ClickHouseИнформация не найдена в документах....⭐⭐26.11s
4How to test the integration (schema validation, data quality, idempotency, replay)?Для проверки интеграции следует выполнить тестовую транзакцию из Виртуального Терминала и убедиться, что все параметры подключены корректно....⭐⭐⭐⭐32.02s
5What is the difference between a table and a materialized view in ClickHouse?Информация не найдена в документах....⭐⭐⭐⭐25.65s
6Should I use separate tables per currency or one fact table with a currency dimension?Информация не найдена в документах....⭐⭐⭐25.54s
7How do I check ingestion status and the last successful batch?Информация не найдена в документах....⭐⭐⭐⭐25.14s
8Which ingestion or ETL job statuses should we model and store?Информация не найдена в документах....26.58s
9What is a Connecting Party in our data model (source owner vs analytics consumer)?Connecting Party в нашем модели представляет Merchant и ID назначенный внешним процессором, если транзакция обрабатывалась им....⭐⭐⭐⭐29.66s
10What is a Payment Gateway as a dimension or entity in the warehouse?Информация не найдена в документах....⭐⭐27.08s
11What is a table in ClickHouse, and when is it fact vs dimension?Информация не найдена в документах....⭐⭐⭐⭐27.46s
12What is a materialized view and when should I prefer it over a plain table?Информация не найдена в документах....25.32s
13What is a merchant control key, and may it appear in stored rows or only in secrets?Merchant Control Key - это ключ, который присваивается аккаунту Connecting Party в системе Doc2.0 Gateway. Он не хранится в таблицах, а используется к...⭐⭐⭐⭐37.94s
14How to sign HTTP requests to ClickHouse HTTP API or to external loaders using OAuth RSA-SHA256?Информация не найдена в документах....⭐⭐⭐⭐26.89s
15How to generate and securely store a private key for signing or client TLS?Use openssl commands:\n", + "\n", + "- Generate RSA keys:\n", + " ```\n", + "openssl genpkey -algorithm RSA -out private_key_pkcs_8.pem -pkeyopt rsa_keygen_bits:4096\n", + "```\n", + " ```\n", + "o...⭐⭐⭐⭐40.07s
16Do I need a private key for bulk transfer API (v4/transfer) or only for certain operations?Информация не найдена в документах....⭐⭐⭐25.73s
17Which APIs or clients require a private key or TLS client certificate?Информация не найдена в документах....⭐⭐⭐26.38s
18What is the difference between v2/sale and v4/sale event streams for schema and ingestion?Информация не найдена в документах....⭐⭐⭐⭐25.27s
19What is the difference between raw sale events and sale-form funnel events in the warehouse?Информация не найдена в документах....⭐⭐26.13s
20When will I receive a webhook or callback after an ingestion batch completes?После того, как исходная транзакция получит конечный статус....⭐⭐⭐⭐30.51s
21How to validate callback or webhook origin (signature, allowlist, TLS)?Информация не найдена в документах....⭐⭐⭐⭐29.4s
22Should I implement both status polling and callback handling for pipelines?Нет, поскольку в документах указано, что для получения окончательного статуса нужно реализовать только одну из двух возможностей: статусы можно получи...⭐⭐⭐⭐38.19s
23How to build an operator-facing “finish” or success view after a load completes?Информация не найдена в документах....⭐⭐30.47s
24How to parameterize dashboards or SQL safely without exposing raw secrets?Информация не найдена в документах....⭐⭐⭐⭐29.53s
25How to compute a control checksum for v2/sale payloads before insert into ClickHouse?control = SHA-1( ENDPOINTID | client_orderid | amount | email | merchant_control ) \n", + "\n", + "где ENDPOINTID — идентификатор эндпоинта, client_orderid — клиент...⭐⭐⭐⭐47.13s
26How to compute a control checksum for v4/transfer payloads before insert into ClickHouse?Для вычисления контрольной суммы для в4/transfer пакетов необходимо:\n", + "\n", + "1. Преобразовать параметры запроса к шестнадцатеричному представлению (как указа...⭐⭐⭐⭐52.15s
27What staging table or intermediate layer do we use for 3DS upload-method-url results?В документах не указано название конкретной таблицы или интермедиарного слоя....⭐⭐⭐⭐25.02s
28Do I need a 3DS upload path in the pipeline for v2/sale?Информация не найдена в документах....⭐⭐⭐⭐23.17s
29Do I need a 3DS upload path in the pipeline for v2/sale-form?Информация не найдена в документах....⭐⭐⭐⭐26.19s
30Do I need a 3DS upload path in the pipeline for v2/return?Информация не найдена в документах....⭐⭐⭐⭐25.22s
31How to model reversal events (compensating rows or status flags)?Compensating rows или status flag Type: Enum Default: Yes....⭐⭐⭐⭐28.31s
32How to model refund events and link them to the original transaction?Моделировать возвраты можно, используя статус «reversal», который указывает на полную или частичную отмену предыдущего одобренного транзакции. Этот ст...⭐⭐⭐⭐43.28s
33How to ingest Google Pay events into ClickHouse?Информация не найдена в документах....⭐⭐25.9s
34How to ingest Apple Pay events into ClickHouse?Информация не найдена в документах....⭐⭐⭐⭐26.87s
35Where are test scenarios or golden datasets documented? Информация не найдена в документах....⭐⭐⭐26.73s
36Where is the Postman collection or equivalent HTTP examples for our loaders?Информация не найдена в документах....⭐⭐⭐25.94s
37Where is the schema for ingestion status responses (system tables, metadata tables)?Информация не найдена в документах....⭐⭐⭐⭐27.42s
38What insert or query throughput per node should we plan for in ClickHouse?Информация не найдена в документах....⭐⭐⭐⭐26.87s
39Which tenants or merchants are represented in the warehouse?Информация не найдена в документах....⭐⭐26.03s
40Which currencies exist in the dimension table?Информация не найдена в документах....26.2s
41Which payment methods exist in the dimension table?Информация не найдена в документах....⭐⭐⭐26.03s
42Which transaction types are modeled in the fact table?sale, chargeback and amount of funds held....⭐⭐⭐⭐26.23s
43How to ingest bank transfer events into ClickHouse?Информация не найдена в документах....26.91s
44How to export full transaction history (SELECT … FORMAT, object storage, backups)?Информация не найдена в документах....⭐⭐⭐26.99s
45What is the INSERT or JSONEachRow schema for raw v2/sale events?Информация не найдена в документах....⭐⭐⭐⭐27.49s
46What is the SQL or report definition for the transaction report?Информация не найдена в документах....⭐⭐⭐26.46s
47Is the control parameter mandatory for v2/sale ingestion?Информация не найдена в документах....⭐⭐⭐⭐25.93s
48Is the control parameter mandatory for v4/transfer ingestion?Информация не найдена в документах....⭐⭐⭐⭐24.35s
49What is RPI in the transfer API, and which column or surrogate key replaces it in ClickHouse?Информация не найдена в документах....⭐⭐⭐⭐24.76s
50What is the difference between storing RPI vs card number in policy terms?Информация не найдена в документах....⭐⭐⭐⭐25.7s
51Is it safe to store RPI in ClickHouse under our retention and access model?Информация не найдена в документах....⭐⭐26.03s
52Is it safe to store card number in ClickHouse at all?Информация не найдена в документах....⭐⭐24.74s
53Do we need PCI controls for storing raw v2/sale payloads?Информация не найдена в документах....⭐⭐⭐⭐26.04s
54Do we need PCI controls for sale-form funnel events?Информация не найдена в документах....26.71s
55Do we need PCI DSS certification for integrating analytics with Payneteasy data?Информация не найдена в документах....⭐⭐30.22s
56I received a merchant control key; how do I use it with ClickHouse (secrets manager, not plain tableИнформация не найдена в документах....⭐⭐25.87s
57What is your model and version?Похоже, что вопрос относятся к версии Doc2.0 Manager Manual. \n", + "\n", + "Модель: Doc2.0\n", + "Версия: не указана (в данном контексте)....⭐⭐⭐⭐28.74s
58When were you launched?Информация не найдена в документах....25.64s
59Do you collect my messages?Информация не найдена в документах....⭐⭐29.01s
\n", + "

\n", + " 💾 Полные результаты сохранены в benchmark_results.csv\n", + "

\n", + "
\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Красивый HTML-отчёт прямо в Jupyter\n", + "avg_score = df_res['score'].mean()\n", + "avg_time = df_res['time_total'].mean()\n", + "ok_pct = (df_res['status'] == 'ok').mean() * 100\n", + "score_dist = df_res['score'].value_counts().sort_index()\n", + "\n", + "rows_html = \"\"\n", + "for _, r in df_res.iterrows():\n", + " color = '#f0fdf4' if r['score'] >= 4 else ('#fff8e1' if r['score'] >= 3 else '#fef2f2')\n", + " stars = '⭐' * int(r['score']) if r['score'] > 0 else '—'\n", + " rows_html += f\"\"\"\n", + " \n", + " {int(r['№'])}\n", + " {r['question'][:100]}\n", + " {str(r['answer'])[:150]}...\n", + " {stars}\n", + " {r['time_total']}s\n", + " \"\"\"\n", + "\n", + "html = f\"\"\"\n", + "
\n", + "

\n", + " 📊 Результаты RAG Benchmark\n", + "

\n", + "
\n", + "
\n", + "
{avg_score:.2f}
\n", + "
Средняя оценка / 5
\n", + "
\n", + "
\n", + "
{avg_time:.1f}s
\n", + "
Среднее время ответа
\n", + "
\n", + "
\n", + "
{ok_pct:.0f}%
\n", + "
Успешных ответов
\n", + "
\n", + "
\n", + "
{len(df_res)}
\n", + "
Всего вопросов
\n", + "
\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " {rows_html}\n", + "
#ВопросОтветОценкаВремя
\n", + "

\n", + " 💾 Полные результаты сохранены в benchmark_results.csv\n", + "

\n", + "
\n", + "\"\"\"\n", + "\n", + "display(HTML(html))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/rag_engine/engine.py b/rag_engine/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..19de9542afe459d13b60522a04fe2d4cb06c6ac2 --- /dev/null +++ b/rag_engine/engine.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +RAG Engine - FIXED VERSION +""" + +import time +import json +import hashlib +from typing import Dict, List, Tuple +import ollama + +from config import config +from core.database import db +from core.embeddings import embedder +from core.reranker import reranker +from router.smart_router import select_prompt + + +class RAGEngine: + """RAG движок с автоматическим выбором промпта""" + + @staticmethod + def ask(question: str) -> Dict: + t_start = time.time() + + # Кэш + cache_key = hashlib.md5(question.encode()).hexdigest() + cached = db.get_cache(cache_key) + if cached: + result = json.loads(cached) + result['cached'] = True + return result + + # Поиск контекста + q_emb = list(embedder.generate_cached(question)) + results = db.search(q_emb) + + if not results: + return { + 'question': question, + 'answer': "NOT FOUND in documentation", + 'sources': [], + 'time_total': round(time.time() - t_start, 2) + } + + # Реранжинг + reranked = reranker.rerank(question, results) + + # Контекст + context_parts = [] + sources = [] + for r in reranked[:config.rerank_top_k]: + chunk, source, page = r[0], r[1], r[2] + context_parts.append(f"[{source}, p.{page}]\n{chunk[:800]}") + sources.append((source, page)) + + context = 
"\n\n".join(context_parts) + + # Выбор промпта + system_prompt, num_predict, temperature = select_prompt(question) + + # Форматирование + if '{context}' in system_prompt and '{query}' in system_prompt: + formatted_prompt = system_prompt.format(context=context, query=question) + else: + formatted_prompt = system_prompt + + # Генерация + messages = [ + {"role": "system", "content": formatted_prompt}, + {"role": "user", "content": f"CONTEXT:\n{context}\n\nQUESTION: {question}"} + ] + + try: + response = ollama.chat( + model=config.llm_model, + messages=messages, + options={ + "num_predict": num_predict, + "temperature": temperature, + "top_k": 40, + "top_p": config.top_p, + "num_ctx": config.num_ctx, + "repeat_penalty": config.repeat_penalty + } + ) + # ✅ ИСПРАВЛЕНО: response - это словарь + if isinstance(response, dict): + answer = response['message']['content'] + else: + answer = response.message.content + status = 'success' + except Exception as e: + answer = f"ERROR: {e}" + status = 'error' + + result = { + 'question': question, + 'answer': answer, + 'sources': sources, + 'time_total': round(time.time() - t_start, 2), + 'cached': False, + 'status': status + } + + db.set_cache(cache_key, json.dumps(result)) + return result + + +rag = RAGEngine() \ No newline at end of file diff --git a/rag_engine/init.py b/rag_engine/init.py new file mode 100644 index 0000000000000000000000000000000000000000..fbecaec5935c8030913833a9fa84446a59a539ef --- /dev/null +++ b/rag_engine/init.py @@ -0,0 +1 @@ +from rag_engine.engine import RAGEngine, rag \ No newline at end of file diff --git a/rag_system_alternative.py b/rag_system_alternative.py new file mode 100644 index 0000000000000000000000000000000000000000..9e008f606c55911c019f68deebfaff1eb86ec2f5 --- /dev/null +++ b/rag_system_alternative.py @@ -0,0 +1,799 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +RAG SYSTEM WITH CLICKHOUSE + OLLAMA +Retrieval-Augmented Generation system for document querying. 
+Optimized for large documents (1000+ pages). +Target: 15-25 seconds per response, Accuracy: 3.5-4.0/5. +""" + +__version__ = "2.0.0" +__author__ = "RAG System" +__description__ = "Production RAG system with ClickHouse and Ollama" + +# ============================================================================== +# DEPENDENCY CHECK +# ============================================================================== + +def check_dependencies(): + """Check if all required packages are installed""" + required = { + 'clickhouse_connect': 'ClickHouse', + 'ollama': 'Ollama', + 'pypdf': 'PDF reader', + 'pandas': 'Data processing', + 'numpy': 'Numerical operations' + } + + optional = { + 'IPython': 'Jupyter widgets', + 'tqdm': 'Progress bars', + 'sklearn': 'Scikit-learn (reranking)', + 'pyarrow': 'Parquet export', + 'openpyxl': 'Excel export' + } + + missing_required = [] + missing_optional = [] + + for package, name in required.items(): + try: + __import__(package) + except ImportError: + missing_required.append(package) + print(f"[ERROR] {name} - MISSING") + + for package, name in optional.items(): + try: + __import__(package) + except ImportError: + missing_optional.append(package) + + if missing_required: + print(f"\n[WARN] Missing required packages: {', '.join(missing_required)}") + print(f"Install: pip install {' '.join(missing_required)}") + return False + + if missing_optional: + print(f"\n[INFO] Optional packages not installed: {', '.join(missing_optional)}") + print(f"Install for full features: pip install {' '.join(missing_optional)}") + + print("[OK] All required dependencies are installed!") + return True + +# ============================================================================== +# IMPORTS +# ============================================================================== + +import os +import sys +import json +import time +import re +import hashlib +from typing import List, Dict, Any, Optional, Tuple +from dataclasses import dataclass, field +from 
functools import lru_cache +import warnings +warnings.filterwarnings('ignore') + +# Core imports +import ollama +import clickhouse_connect +import pandas as pd +import numpy as np +from pypdf import PdfReader + +# Optional imports with fallbacks +try: + from IPython.display import display, HTML + IPYTHON_AVAILABLE = True +except ImportError: + IPYTHON_AVAILABLE = False + +try: + from tqdm.notebook import tqdm + TQDM_AVAILABLE = True +except ImportError: + try: + from tqdm import tqdm + TQDM_AVAILABLE = True + except ImportError: + TQDM_AVAILABLE = False + def tqdm(iterable, *args, **kwargs): + return iterable + +try: + from sklearn.metrics.pairwise import cosine_similarity + SKLEARN_AVAILABLE = True +except ImportError: + SKLEARN_AVAILABLE = False + +# ============================================================================== +# CONFIGURATION +# ============================================================================== + +@dataclass +class Config: + """Configuration class for RAG system""" + + # ClickHouse connection + ch_host: str = field(default_factory=lambda: os.getenv('CLICKHOUSE_HOST', 'ug1o26imbr.eu-central-1.aws.clickhouse.cloud')) + ch_user: str = field(default_factory=lambda: os.getenv('CLICKHOUSE_USER', 'default')) + ch_password: str = field(default_factory=lambda: os.getenv('CLICKHOUSE_PASSWORD', '~MlK_g7KdbqYH')) + ch_secure: bool = field(default_factory=lambda: os.getenv('CLICKHOUSE_SECURE', 'true').lower() == 'true') + + # Ollama models + embed_model: str = field(default_factory=lambda: os.getenv('EMBED_MODEL', 'nomic-embed-text')) + llm_model: str = field(default_factory=lambda: os.getenv('LLM_MODEL', 'llama3.2:3b')) + + # RAG parameters + chunk_size: int = field(default_factory=lambda: int(os.getenv('CHUNK_SIZE', '1000'))) + chunk_overlap: int = field(default_factory=lambda: int(os.getenv('CHUNK_OVERLAP', '150'))) + top_k: int = field(default_factory=lambda: int(os.getenv('TOP_K', '8'))) + rerank_top_k: int = field(default_factory=lambda: 
int(os.getenv('RERANK_TOP_K', '3'))) + similarity_threshold: float = field(default_factory=lambda: float(os.getenv('SIMILARITY_THRESHOLD', '0.35'))) + batch_size: int = field(default_factory=lambda: int(os.getenv('BATCH_SIZE', '32'))) + + # Generation parameters + num_ctx: int = field(default_factory=lambda: int(os.getenv('NUM_CTX', '4096'))) + num_predict: int = field(default_factory=lambda: int(os.getenv('NUM_PREDICT', '400'))) + temperature: float = field(default_factory=lambda: float(os.getenv('TEMPERATURE', '0.1'))) + top_p: float = field(default_factory=lambda: float(os.getenv('TOP_P', '0.9'))) + repeat_penalty: float = field(default_factory=lambda: float(os.getenv('REPEAT_PENALTY', '1.1'))) + + # Limits + max_text_length: int = field(default_factory=lambda: int(os.getenv('MAX_TEXT_LENGTH', '3072'))) + min_chunk_size: int = field(default_factory=lambda: int(os.getenv('MIN_CHUNK_SIZE', '100'))) + max_chunks_per_doc: int = field(default_factory=lambda: int(os.getenv('MAX_CHUNKS_PER_DOC', '2000'))) + + # Cache + cache_enabled: bool = field(default_factory=lambda: os.getenv('CACHE_ENABLED', 'true').lower() == 'true') + cache_ttl: int = field(default_factory=lambda: int(os.getenv('CACHE_TTL', '3600'))) + + # File paths + pdf_files: List[Tuple[str, str]] = field(default_factory=lambda: [ + (r'C:\Users\User\Desktop\Folder_vs_documents\integration.pdf', 'Integration'), + (r'C:\Users\User\Desktop\Folder_vs_documents\manager.pdf', 'Manager'), + (r'C:\Users\User\Desktop\Folder_vs_documents\merchant.pdf', 'Merchant'), + ]) + questions_csv: str = field(default_factory=lambda: os.getenv('QUESTIONS_CSV', r'C:\Users\User\Desktop\Folder_vs_documents\questions.csv')) + + def validate(self) -> bool: + """Validate configuration""" + if not self.ch_password and self.ch_host != 'localhost': + print("[ERROR] Password required for remote ClickHouse") + return False + if self.chunk_size <= self.chunk_overlap: + print("[ERROR] chunk_size must be > chunk_overlap") + return False + if 
not 0 < self.similarity_threshold < 1: + print("[ERROR] similarity_threshold must be between 0 and 1") + return False + return True + +config = Config() + +# ============================================================================== +# DATABASE MANAGER +# ============================================================================== + +class DatabaseManager: + """Manages ClickHouse database operations""" + + def __init__(self): + self._client = None + self._cache = {} + self._cache_time = {} + + def get_client(self): + """Get or create ClickHouse client""" + if self._client is None: + self._client = clickhouse_connect.get_client( + host=config.ch_host, + username=config.ch_user, + password=config.ch_password, + secure=config.ch_secure, + compress=True, + connect_timeout=30 + ) + print("[OK] Connected to ClickHouse") + return self._client + + def init_database(self): + """Initialize database schema""" + client = self.get_client() + client.command("DROP TABLE IF EXISTS default.rag_chunks") + client.command(""" + CREATE TABLE default.rag_chunks ( + id UInt64, + source String, + page UInt32, + chunk String, + embedding Array(Float32), + chunk_hash String, + char_count UInt32, + created_at DateTime DEFAULT now() + ) ENGINE = MergeTree() + PARTITION BY source + ORDER BY id + """) + print("[OK] Database initialized") + + def insert_batch(self, chunks: List[Dict]): + """Insert chunks in batches""" + if not chunks: + return + client = self.get_client() + rows = [[c['id'], c['source'], c['page'], c['chunk'], + c['embedding'], c['chunk_hash'], c['char_count']] for c in chunks] + client.insert('default.rag_chunks', rows, + column_names=['id', 'source', 'page', 'chunk', 'embedding', 'chunk_hash', 'char_count']) + print(f" [OK] Inserted {len(chunks)} chunks") + + def search(self, embedding: List[float]) -> List[tuple]: + """Search for similar chunks""" + client = self.get_client() + query = """ + SELECT chunk, source, page, cosineDistance(embedding, %(emb)s) AS distance + 
FROM default.rag_chunks + WHERE distance < %(threshold)s + ORDER BY distance ASC + LIMIT %(top_k)s + """ + result = client.query(query, parameters={ + 'emb': embedding, + 'threshold': config.similarity_threshold, + 'top_k': config.top_k + }) + return result.result_rows + + def get_cache(self, key: str): + """Get cached value""" + if not config.cache_enabled: + return None + if key in self._cache: + if time.time() - self._cache_time.get(key, 0) < config.cache_ttl: + return self._cache[key] + return None + + def set_cache(self, key: str, value: str): + """Set cached value""" + if config.cache_enabled: + self._cache[key] = value + self._cache_time[key] = time.time() + +db = DatabaseManager() + +# ============================================================================== +# EMBEDDING GENERATOR +# ============================================================================== + +class EmbeddingGenerator: + """Generates embeddings for text chunks""" + + def __init__(self): + self.model = config.embed_model + self.batch_size = config.batch_size + self.max_length = config.max_text_length + + def _truncate_text(self, text: str) -> str: + """Truncate text to safe length""" + if len(text) <= self.max_length: + return text + truncated = text[:self.max_length] + last_period = truncated.rfind('.') + if last_period > self.max_length // 2: + truncated = truncated[:last_period + 1] + return truncated.strip() + + def generate_batch(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for a batch of texts""" + if not texts: + return [] + safe_texts = [self._truncate_text(t) for t in texts] + + try: + response = ollama.embed(model=self.model, input=safe_texts) + return response['embeddings'] + except Exception as e: + print(f" [WARN] Embedding error: {e}") + return [[0.0] * 768 for _ in safe_texts] + + def generate(self, texts: List[str]) -> List[List[float]]: + """Generate embeddings for all texts with batching""" + all_embeddings = [] + for i in range(0, 
len(texts), self.batch_size): + batch = texts[i:i+self.batch_size] + embeddings = self.generate_batch(batch) + all_embeddings.extend(embeddings) + return all_embeddings + + @lru_cache(maxsize=256) + def generate_cached(self, text: str) -> tuple: + """Generate embedding with caching""" + embedding = self.generate_batch([text])[0] + return tuple(embedding) + +embedder = EmbeddingGenerator() + +# ============================================================================== +# RERANKER +# ============================================================================== + +class Reranker: + """Reranks search results for better accuracy""" + + @staticmethod + def rerank(question: str, results: List[tuple]) -> List[tuple]: + """Rerank results by keyword overlap and similarity""" + if not results: + return results + + q_words = set(re.findall(r'\b\w{4,}\b', question.lower())) + + scored = [] + for idx, result in enumerate(results): + chunk, source, page, distance = result + c_words = set(re.findall(r'\b\w{4,}\b', chunk.lower())) + overlap = len(q_words & c_words) / max(len(q_words), 1) + similarity = 1 - distance + final_score = similarity * 0.6 + overlap * 0.4 + scored.append((final_score, result)) + + scored.sort(key=lambda x: x[0], reverse=True) + return [r for _, r in scored[:config.rerank_top_k]] + +reranker = Reranker() + +# ============================================================================== +# PDF PROCESSOR +# ============================================================================== + +class PDFProcessor: + """Processes PDF documents into chunks""" + + @staticmethod + def extract_pdf(pdf_path: str, source_name: str) -> List[Tuple[int, str]]: + """Extract text from PDF""" + try: + reader = PdfReader(pdf_path) + total_pages = len(reader.pages) + print(f" Total pages: {total_pages}") + + pages = [] + for i in range(total_pages): + try: + page = reader.pages[i] + text = page.extract_text() + if text and len(text.strip()) > config.min_chunk_size: + text = 
re.sub(r'\n+', ' ', text) + pages.append((i + 1, text.strip())) + except: + pass + return pages + except Exception as e: + print(f" [ERROR] Error: {e}") + return [] + + @staticmethod + def split_chunks(text: str) -> List[str]: + """Split text into overlapping chunks""" + size = config.chunk_size + words = text.split() + chunks = [] + step = size - config.chunk_overlap + + for i in range(0, len(words), step): + chunk = ' '.join(words[i:i+size]) + if len(chunk) > config.min_chunk_size: + chunks.append(chunk) + if len(chunks) >= config.max_chunks_per_doc: + break + return chunks + + @staticmethod + def process_document(pdf_path: str, source_name: str, start_id: int) -> List[Dict]: + """Process entire document into chunks""" + print(f"\nProcessing: {source_name}") + pages = PDFProcessor.extract_pdf(pdf_path, source_name) + + if not pages: + return [] + + chunks = [] + for page_num, text in pages: + for chunk in PDFProcessor.split_chunks(text): + if len(chunk) > config.max_text_length: + chunk = chunk[:config.max_text_length] + chunks.append({ + 'id': start_id + len(chunks), + 'source': source_name, + 'page': page_num, + 'chunk': chunk, + 'chunk_hash': hashlib.md5(chunk.encode()).hexdigest(), + 'char_count': len(chunk) + }) + if len(chunks) >= config.max_chunks_per_doc: + break + if len(chunks) >= config.max_chunks_per_doc: + break + + print(f" [OK] Created {len(chunks)} chunks") + return chunks + +pdf_processor = PDFProcessor() + +# ============================================================================== +# RAG ENGINE +# ============================================================================== + +SYSTEM_PROMPT = """You are a technical documentation expert. Answer based ONLY on the provided context. 
+ +FORMAT: +ANSWER: [clear, specific answer] +SOURCE: [document name, page X] +EVIDENCE: [exact quote from documentation] + +If information not found: "NOT FOUND in documentation" + +Be concise, accurate, and always cite sources.""" + +class RAGEngine: + """Main RAG engine for question answering""" + + @staticmethod + def ask(question: str) -> Dict: + """Ask a question and get answer""" + t_start = time.time() + + # Check cache + cache_key = hashlib.md5(question.encode()).hexdigest() + cached = db.get_cache(cache_key) + if cached: + result = json.loads(cached) + result['cached'] = True + result['time_total'] = 0.5 + return result + + # Search + q_emb = list(embedder.generate_cached(question)) + results = db.search(q_emb) + + if not results: + return { + 'question': question, + 'answer': "NOT FOUND in documentation", + 'sources': [], + 'time_total': round(time.time() - t_start, 2) + } + + # Rerank + reranked = reranker.rerank(question, results) + + # Prepare context + context_parts = [] + sources = [] + for r in reranked: + chunk, source, page = r[0], r[1], r[2] + context_parts.append(f"[{source}, p.{page}]\n{chunk[:800]}\n[/]") + sources.append((source, page)) + + context = "\n\n".join(context_parts) + + # Generate answer + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": f"CONTEXT:\n{context}\n\nQUESTION: {question}"} + ] + + try: + response = ollama.chat( + model=config.llm_model, + messages=messages, + options={ + "num_predict": config.num_predict, + "temperature": config.temperature, + "top_k": 40, + "top_p": config.top_p, + "num_ctx": config.num_ctx, + "repeat_penalty": config.repeat_penalty + } + ) + answer = response.message.content + except Exception as e: + answer = f"ERROR: {e}" + + result = { + 'question': question, + 'answer': answer, + 'sources': sources, + 'time_total': round(time.time() - t_start, 2), + 'cached': False + } + + db.set_cache(cache_key, json.dumps(result)) + return result + +rag = RAGEngine() + 
+# ============================================================================== +# BENCHMARK +# ============================================================================== + +class Benchmark: + """Benchmarking utilities""" + + @staticmethod + def quick_judge(question: str, answer: str) -> int: + """Quick quality judgement""" + prompt = f"""Rate answer quality 1-5. Question: {question[:100]} Answer: {answer[:200]} +Reply ONLY a number 1-5.""" + + try: + resp = ollama.chat( + model=config.llm_model, + messages=[{"role": "user", "content": prompt}], + options={"num_predict": 5, "temperature": 0} + ) + match = re.search(r'[1-5]', resp.message.content) + return int(match.group()) if match else 0 + except: + return 0 + + @staticmethod + def run(questions: List[str]) -> pd.DataFrame: + """Run benchmark on questions""" + results = [] + for i, q in enumerate(tqdm(questions, desc=" Benchmarking")): + try: + res = rag.ask(q) + score = Benchmark.quick_judge(q, res['answer']) + results.append({ + 'id': i+1, + 'question': q, + 'answer': res['answer'][:300], + 'score': score, + 'time': res['time_total'], + 'sources': len(res['sources']), + 'status': 'ok' + }) + except Exception as e: + results.append({ + 'id': i+1, + 'question': q, + 'answer': f'ERROR: {e}', + 'score': 0, + 'time': 0, + 'sources': 0, + 'status': 'error' + }) + return pd.DataFrame(results) + +# ============================================================================== +# DATA EXPORT +# ============================================================================== + +class DataExporter: + """Export data to various formats""" + + @staticmethod + def to_csv(filename: str = "rag_chunks.csv") -> pd.DataFrame: + """Export chunks to CSV""" + print(f"\nExporting to {filename}...") + client = db.get_client() + result = client.query("SELECT id, source, page, chunk, char_count FROM default.rag_chunks ORDER BY id") + df = pd.DataFrame(result.result_rows, columns=['id', 'source', 'page', 'chunk', 'char_count']) + 
df.to_csv(filename, index=False, encoding='utf-8') + print(f" [OK] Exported {len(df)} chunks") + return df + + @staticmethod + def to_json(filename: str = "rag_results.json", results: List[Dict] = None): + """Export results to JSON""" + if results: + with open(filename, 'w', encoding='utf-8') as f: + json.dump(results, f, ensure_ascii=False, indent=2) + print(f" [OK] Exported to {filename}") + + @staticmethod + def to_parquet(filename: str = "rag_data.parquet"): + """Export to Parquet (requires pyarrow)""" + try: + import pyarrow + df = DataExporter.to_csv(filename.replace('.parquet', '.csv')) + df.to_parquet(filename) + print(f" [OK] Exported to {filename}") + except ImportError: + print(" [WARN] pyarrow not installed, skipping parquet export") + +# ============================================================================== +# MAIN EXECUTION +# ============================================================================== + +def load_documents(): + """Load all documents into database""" + print("\n" + "="*60) + print("LOADING DOCUMENTS") + print("="*60) + + for pdf_path, _ in config.pdf_files: + if not os.path.exists(pdf_path): + print(f"[ERROR] File not found: {pdf_path}") + return 0 + + db.init_database() + chunk_id = 0 + + for pdf_path, source_name in config.pdf_files: + chunks = pdf_processor.process_document(pdf_path, source_name, chunk_id) + if chunks: + texts = [c['chunk'] for c in chunks] + print(f" Generating {len(texts)} embeddings...") + embeddings = embedder.generate(texts) + for chunk, emb in zip(chunks, embeddings): + chunk['embedding'] = emb + db.insert_batch(chunks) + chunk_id += len(chunks) + + print(f"\nTOTAL CHUNKS: {chunk_id}") + return chunk_id + +def print_banner(): + """Print system banner""" + print(""" + ====================================================================== + RAG SYSTEM v2.0 + + Retrieval-Augmented Generation with ClickHouse + Ollama + + Features: + - Vector search in ClickHouse + - Reranking for better accuracy + - 
Caching for repeated queries + - Export to CSV/JSON/Parquet + + Expected: 15-25 seconds/query | Accuracy: 3.5-4.0/5 + ====================================================================== + """) + +def main(): + """Main execution function""" + print_banner() + + # Check dependencies + if not check_dependencies(): + print("\n[ERROR] Please install missing dependencies and try again.") + return + + # Validate config + if not config.validate(): + return + + # Check Ollama + try: + ollama.list() + print("[OK] Ollama is running\n") + except: + print("[ERROR] Ollama is not running!") + print(" Run: ollama serve") + print(" Then: ollama pull llama3.2:3b") + print(" And: ollama pull nomic-embed-text") + return + + # Load documents + total_chunks = load_documents() + if total_chunks == 0: + print("[ERROR] No documents loaded!") + return + + # Quick test + print("\n" + "="*60) + print("QUICK TEST") + print("="*60) + + test_q = "Which requests and responses I need to implement for Sale Form integration?" + print(f"\nQuestion: {test_q[:80]}...") + + start = time.time() + result = rag.ask(test_q) + elapsed = time.time() - start + + print(f"\nANSWER:\n{result['answer'][:500]}") + print(f"\nSOURCES: {result['sources']}") + print(f"TIME: {elapsed:.1f} seconds") + + if elapsed < 25: + print(f" [OK] Speed target achieved! ({elapsed:.1f}s)") + else: + print(f" [WARN] Still slow ({elapsed:.1f}s). 
Try: config.num_ctx = 2048") + + # Run benchmark if questions exist + if os.path.exists(config.questions_csv): + df = pd.read_csv(config.questions_csv) + questions = df.iloc[:, 0].dropna().tolist() + print(f"\nRunning benchmark on {len(questions)} questions...") + + results_df = Benchmark.run(questions) + + avg_score = results_df['score'].mean() + avg_time = results_df['time'].mean() + + print("\n" + "="*60) + print("BENCHMARK RESULTS") + print("="*60) + print(f"Average Score: {avg_score:.1f} / 5") + print(f"Average Time: {avg_time:.1f} seconds") + + # Save results + results_df.to_csv('benchmark_results.csv', index=False) + DataExporter.to_json('benchmark_results.json', results_df.to_dict('records')) + print(f"\nResults saved to: benchmark_results.csv") + + # Ask about export + print("\n" + "="*60) + print("DATA EXPORT") + print("="*60) + print("Export options:") + print(" 1. Export chunks to CSV") + print(" 2. Export chunks to Parquet") + print(" 3. Skip export") + + # Auto-export for non-interactive mode + DataExporter.to_csv() + + print("\n" + "="*60) + print("[OK] RAG SYSTEM READY!") + print("="*60) + print("\nUsage examples:") + print(" result = rag.ask('your question')") + print(" print(result['answer'])") + print(" print(result['sources'])") + print("\nCached questions respond in < 1 second") + +# ============================================================================== +# COMMAND LINE INTERFACE +# ============================================================================== + +def cli(): + """Command line interface""" + import argparse + + parser = argparse.ArgumentParser(description='RAG System with ClickHouse + Ollama') + parser.add_argument('--query', '-q', type=str, help='Question to ask') + parser.add_argument('--benchmark', '-b', action='store_true', help='Run benchmark') + parser.add_argument('--export', '-e', action='store_true', help='Export data') + parser.add_argument('--version', '-v', action='version', version=f'RAG System 
v{__version__}') + + args = parser.parse_args() + + if args.query: + # Single query mode + print_banner() + if load_documents() > 0: + result = rag.ask(args.query) + print(f"\nAnswer: {result['answer']}") + print(f"Sources: {result['sources']}") + print(f"Time: {result['time_total']}s") + elif args.benchmark: + # Benchmark mode + main() + elif args.export: + # Export mode + if load_documents() > 0: + DataExporter.to_csv() + DataExporter.to_parquet() + else: + # Interactive mode + main() + +# ============================================================================== +# ENTRY POINT +# ============================================================================== + +if __name__ == "__main__": + # Check if running in Jupyter + if 'get_ipython' in globals(): + main() + else: + cli() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..af8a152448fd9c367c1d993c3afc3a5cae64f61d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +clickhouse-connect==0.7.0 +ollama==0.3.0 +pypdf==4.0.0 +pandas==2.2.0 +numpy==1.24.3 +python-dotenv==1.0.0 \ No newline at end of file diff --git a/router/init.py b/router/init.py new file mode 100644 index 0000000000000000000000000000000000000000..d86798e54b662fde7d94d6cdb5b661a98b0da5a1 --- /dev/null +++ b/router/init.py @@ -0,0 +1 @@ +from router.smart_router import SmartPromptRouter, select_prompt \ No newline at end of file diff --git a/router/smart_router.py b/router/smart_router.py new file mode 100644 index 0000000000000000000000000000000000000000..c260e8da2b4581ccf3373d28ed72614c67ea545d --- /dev/null +++ b/router/smart_router.py @@ -0,0 +1,154 @@ +import re +import pandas as pd +from pathlib import Path +from typing import List, Dict, Tuple +from config import config + + +def load_prompt(filename: str) -> str: + prompt_path = Path(__file__).parent.parent / "prompts" / filename + if prompt_path.exists(): + with open(prompt_path, 'r', 
encoding='utf-8') as f: + return f.read() + return "" + + +class FewShotLoader: + @staticmethod + def load_examples() -> List[Dict]: + examples = [] + folder_path = Path(config.few_shot_folder) + + if not folder_path.exists(): + return [] + + for file_path in folder_path.rglob("*"): + if file_path.suffix == '.csv': + try: + df = pd.read_csv(file_path) + if 'question' in df.columns and 'answer' in df.columns: + for _, row in df.iterrows(): + examples.append({ + 'question': str(row['question']), + 'answer': str(row['answer']), + 'source': str(row.get('source', 'example')) + }) + except: + pass + elif file_path.suffix == '.txt': + try: + with open(file_path, 'r', encoding='utf-8') as f: + lines = [l.strip() for l in f if l.strip()] + for i in range(0, len(lines), 2): + if i+1 < len(lines): + examples.append({ + 'question': lines[i], + 'answer': lines[i+1], + 'source': 'txt' + }) + except: + pass + + return examples + + @staticmethod + def format_examples(examples: List[Dict], max_examples: int = 3) -> str: + if not examples: + return "" + + examples = examples[:max_examples] + formatted = "\n\n## EXAMPLES OF GOOD ANSWERS:\n\n" + + for i, ex in enumerate(examples, 1): + formatted += f"**Example {i}:**\n" + formatted += f"Question: {ex['question']}\n" + formatted += f"Answer: {ex['answer']}\n" + if ex.get('source'): + formatted += f"Source: {ex['source']}\n" + formatted += "\n" + + return formatted + + +RAG_API_PROMPT = load_prompt("rag_api_en.txt") +RAG_API_PARAMETER_PROMPT = load_prompt("rag_api_parameter_en.txt") +RAG_API_PARAMETERS_LIST_PROMPT = load_prompt("rag_api_parameters_list_en.txt") + +DEFAULT_PROMPT = """You are a technical documentation expert. Answer based ONLY on the provided context. 
+ +FORMAT: +ANSWER: [clear, specific answer] +SOURCE: [document name, page X] + +If not found: "NOT FOUND" +""" + +FEW_SHOT_EXAMPLES = FewShotLoader.load_examples() + + +def get_relevant_examples(question: str, max_examples: int = 3) -> List[Dict]: + if not FEW_SHOT_EXAMPLES: + return [] + + q_lower = question.lower() + q_words = set(re.findall(r'\b\w{4,}\b', q_lower)) + + scored_examples = [] + for ex in FEW_SHOT_EXAMPLES: + ex_lower = ex['question'].lower() + ex_words = set(re.findall(r'\b\w{4,}\b', ex_lower)) + + if q_words and ex_words: + score = len(q_words & ex_words) / len(q_words) + else: + score = 0 + + scored_examples.append((score, ex)) + + scored_examples.sort(key=lambda x: x[0], reverse=True) + return [ex for score, ex in scored_examples[:max_examples] if score > 0] + + +def enhance_prompt_with_examples(base_prompt: str, question: str) -> str: + if not FEW_SHOT_EXAMPLES: + return base_prompt + + relevant_examples = get_relevant_examples(question, max_examples=3) + if not relevant_examples: + return base_prompt + + examples_text = FewShotLoader.format_examples(relevant_examples) + return base_prompt + examples_text + + +def select_prompt(question: str) -> Tuple[str, int, float]: + q_lower = question.lower() + + if any(kw in q_lower for kw in ['list of parameters', 'all parameters', 'parameter list']): + base_prompt = RAG_API_PARAMETERS_LIST_PROMPT if RAG_API_PARAMETERS_LIST_PROMPT else DEFAULT_PROMPT + num_predict, temperature = 1200, 0.05 + elif any(kw in q_lower for kw in ['parameter', 'param', 'field', 'difference']): + base_prompt = RAG_API_PARAMETER_PROMPT if RAG_API_PARAMETER_PROMPT else DEFAULT_PROMPT + num_predict, temperature = 800, 0.05 + else: + base_prompt = RAG_API_PROMPT if RAG_API_PROMPT else DEFAULT_PROMPT + num_predict, temperature = 1000, 0.1 + + enhanced_prompt = enhance_prompt_with_examples(base_prompt, question) + return enhanced_prompt, num_predict, temperature + + +class SmartPromptRouter: + @staticmethod + def select(question: 
str) -> Tuple[str, int, float]: + return select_prompt(question) + + @staticmethod + def get_examples_count() -> int: + return len(FEW_SHOT_EXAMPLES) + + @staticmethod + def reload_examples(): + global FEW_SHOT_EXAMPLES + FEW_SHOT_EXAMPLES = FewShotLoader.load_examples() + return len(FEW_SHOT_EXAMPLES) \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000000000000000000000000000000000000..8563003f1b80a905c8278a02ef086695650910a5 --- /dev/null +++ b/run.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +import os +import sys +import argparse + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from config import config +from core.database import db +from core.embeddings import embedder +from core.document_processor import doc_processor +from rag_engine.engine import rag +from evaluator.folder_scanner import FolderScanner +from evaluator.qa_loader import QALoader +from evaluator.results import ResultsAnalyzer +from router.smart_router import SmartPromptRouter +from datetime import datetime + +def load_documents(): + print("\n" + "="*60) + print("LOADING KNOWLEDGE BASE") + print("="*60) + + if not config.doc_files: + print(f"[WARN] No files found in: {config.docs_folder}") + return 0 + + # Создаем папку для ошибок загрузки + error_folder = "./data/errors" + os.makedirs(error_folder, exist_ok=True) + + sources = {} + for _, source in config.doc_files: + src = source.split('/')[0] if '/' in source else 'root' + sources[src] = sources.get(src, 0) + 1 + + print(f"\nSources:") + for src, count in sorted(sources.items()): + print(f" - {src}: {count} files") + + db.init_database() + chunk_id = 0 + failed_files = [] + + for file_path, source_name in config.doc_files: + try: + chunks = doc_processor.process_document(file_path, source_name, chunk_id) + if chunks: + texts = [c['chunk'] for c in chunks] + print(f" Generating {len(texts)} embeddings...") + embeddings = embedder.generate(texts) + for chunk, emb in zip(chunks, 
def _word_overlap_score(reference: str, candidate: str) -> float:
    """Return the fraction of distinct words in ``reference`` that occur in ``candidate``.

    Words are lower-cased, whitespace-separated tokens. An empty reference scores 0.
    """
    reference_words = set(reference.lower().split())
    if not reference_words:
        return 0.0
    candidate_words = set(candidate.lower().split())
    return len(reference_words & candidate_words) / len(reference_words)


def run_evaluation(qa_pairs=None):
    """Ask every evaluation question and score the generated answers.

    Args:
        qa_pairs: Iterable of ``(question, expected_answer)`` tuples. When
            ``None`` the pairs are discovered with ``find_all_qa_pairs()``.

    Returns:
        A list with one result dict per question (``status`` is ``'success'``
        or ``'error'``), or ``None`` when there is nothing to evaluate.

    A question that raises is recorded as an error and the run continues.
    Failed questions are also written to a timestamped JSON file under
    ``./data/errors``.
    """
    print("\n" + "="*60)
    print("QA EVALUATION")
    print("="*60)

    if qa_pairs is None:
        qa_pairs = find_all_qa_pairs()

    if not qa_pairs:
        print("[ERROR] No QA pairs found for evaluation")
        return None

    # Folder that receives the list of failed questions.
    error_folder = "./data/errors"
    os.makedirs(error_folder, exist_ok=True)

    examples_count = SmartPromptRouter.get_examples_count()
    if examples_count > 0:
        print(f"\nFew-shot examples: {examples_count}")
    else:
        print("\nNo few-shot examples. Add with: python train_on_answers.py --question ...")

    all_results = []
    failed_questions = []

    for i, (question, expected_answer) in enumerate(qa_pairs):
        print(f"  [{i+1}/{len(qa_pairs)}] Processing...", end='\r')

        try:
            result = rag.ask(question)

            # Score the generated answer against the EXPECTED answer. The
            # previous code compared it with the question text, so the
            # reference answers never influenced the score.
            similarity = _word_overlap_score(expected_answer, result['answer'])

            all_results.append({
                'question': question[:200],
                'expected_answer': expected_answer[:200],
                'generated_answer': result['answer'][:300],
                'similarity_score': round(similarity, 3),
                'time_seconds': result['time_total'],
                'sources': str(result['sources']),
                'status': 'success'
            })

        except Exception as e:
            # Per-question boundary: one failing question must not abort the run.
            error_msg = str(e)
            print(f"\n  [WARN] ERROR on question {i+1}: {question[:50]}...")
            print(f"     Error: {error_msg[:100]}")

            failed_questions.append({
                'index': i+1,
                'question': question,
                'expected_answer': expected_answer,
                'error': error_msg,
                'timestamp': datetime.now().isoformat()
            })

            # Keep a placeholder row so the results stay aligned with the input.
            all_results.append({
                'question': question[:200],
                'expected_answer': expected_answer[:200],
                'generated_answer': f"ERROR: {error_msg[:100]}",
                'similarity_score': 0,
                'time_seconds': 0,
                'sources': '',
                'status': 'error'
            })

    print()

    if failed_questions:
        import json
        error_file = os.path.join(error_folder, f"failed_questions_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json")
        with open(error_file, 'w', encoding='utf-8') as f:
            json.dump(failed_questions, f, ensure_ascii=False, indent=2)
        print(f"\n[WARN] Failed questions saved to: {error_file}")
        print(f"   Total failed: {len(failed_questions)} / {len(qa_pairs)}")

    if all_results:
        ResultsAnalyzer.save(all_results)
        successes = [r for r in all_results if r['status'] == 'success']
        if successes:
            avg_score = sum(r['similarity_score'] for r in successes) / len(successes)
        else:
            avg_score = 0
        print(f"\n[STATS] Average score (success only): {avg_score:.3f}")
        print(f"   Success rate: {len(successes)}/{len(all_results)}")
        print("\n[INFO] Train on good answers: python train_on_answers.py --file data/results/evaluation_results.csv")

    return all_results
Run: ollama serve") + return + + if args.load_only: + load_documents() + return + + if args.evaluate: + load_documents() + run_evaluation() + return + + if args.query: + load_documents() + result = rag.ask(args.query) + print(f"\nQuestion: {args.query}") + print(f"\nANSWER:\n{result['answer']}") + print(f"\nSOURCES: {result['sources']}") + print(f"TIME: {result['time_total']}s") + return + +def load_documents(): + print("\n" + "="*60) + print("LOADING KNOWLEDGE BASE") + print("="*60) + + # Проверяем, есть ли уже данные в базе + existing_chunks = db.get_chunk_count() + + if existing_chunks > 0: + print(f"[INFO] Database already has {existing_chunks} chunks") + print("[INFO] Skipping document loading...") + print("[INFO] Use --force-reload to reload all documents") + return existing_chunks + + if not config.doc_files: + print(f"[WARN] No files found in: {config.docs_folder}") + return 0 + + # Создаём таблицу (без перезаписи) + db.init_database(force_recreate=False) + + # ... остальной код загрузки ... 
+ + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/train_on_answers.py b/train_on_answers.py new file mode 100644 index 0000000000000000000000000000000000000000..06b6c5d20c6d900c89a4075ecae94039a7dfbb4d --- /dev/null +++ b/train_on_answers.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +import os +import sys +import argparse +import pandas as pd +from pathlib import Path + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from config import config + + +def add_to_few_shot(question: str, answer: str, source: str = "training"): + examples_file = Path(config.few_shot_folder) / "training_examples.csv" + examples_file.parent.mkdir(parents=True, exist_ok=True) + + if examples_file.exists(): + df = pd.read_csv(examples_file) + else: + df = pd.DataFrame(columns=['question', 'answer', 'source']) + + if not df[df['question'] == question].empty: + print(f" Example already exists: {question[:50]}...") + return False + + new_row = pd.DataFrame([{ + 'question': question, + 'answer': answer, + 'source': source + }]) + + df = pd.concat([df, new_row], ignore_index=True) + df.to_csv(examples_file, index=False, encoding='utf-8') + print(f" Added to few-shot: {question[:50]}...") + return True + + +def train_from_evaluation_results(results_file: str, min_score: float = 0.7): + if not os.path.exists(results_file): + print(f" File not found: {results_file}") + return 0 + + df = pd.read_csv(results_file) + good_answers = df[df['similarity_score'] >= min_score] + + print(f" Found {len(good_answers)} good answers (score >= {min_score})") + + added = 0 + for _, row in good_answers.iterrows(): + if add_to_few_shot(row['question'], row['generated_answer'], source="evaluation"): + added += 1 + + print(f" Added {added} new examples to few-shot") + return added + + +def train_from_txt(questions_txt: str, answers_txt: str, source: str = "txt"): + if not os.path.exists(questions_txt) or not os.path.exists(answers_txt): + print(f" Files not found") + 
def list_examples():
    """Print every stored few-shot example, truncated to 80 characters per field.

    Reads ``training_examples.csv`` from ``config.few_shot_folder`` and prints
    a notice instead when that file does not exist.
    """
    examples_file = Path(config.few_shot_folder) / "training_examples.csv"

    if not examples_file.exists():
        print(" No examples found")
        return

    def preview(value) -> str:
        # pandas yields NaN for empty cells and numbers for numeric-looking
        # ones; slicing those directly raises TypeError.
        text = "" if pd.isna(value) else str(value)
        return text[:80]

    df = pd.read_csv(examples_file)
    print(f"\n FEW-SHOT EXAMPLES ({len(df)} total):")
    print("="*60)
    # enumerate() numbers by position rather than by the DataFrame index label.
    for number, (question, answer) in enumerate(zip(df['question'], df['answer']), 1):
        print(f"\n{number}. Q: {preview(question)}...")
        print(f" A: {preview(answer)}...")
args.list: + list_examples() + elif args.clear: + clear_examples() + elif args.file: + train_from_evaluation_results(args.file, args.min_score) + elif args.questions_txt and args.answers_txt: + train_from_txt(args.questions_txt, args.answers_txt, args.source) + elif args.question and args.answer: + add_to_few_shot(args.question, args.answer, args.source) + else: + print(""" + + TRAIN ON ANSWERS - FEW-SHOT LEARNING + + + Usage: + # Train from evaluation results + python train_on_answers.py --file data/results/evaluation_results.csv + + # Train from TXT files + python train_on_answers.py --questions-txt questions.txt --answers-txt answers.txt + + # Add single example + python train_on_answers.py --question "What is X?" --answer "X is Y" + + # List all examples + python train_on_answers.py --list + + # Clear all examples + python train_on_answers.py --clear + """) + + +if __name__ == "__main__": + main() \ No newline at end of file