Files
solidstate-tools/mcp/SearchPaperByEmbedding/search.py
koko c0b2ec5983 sofvBV_mcp重构v2
Embedding copy
2025-10-22 23:59:23 +08:00

157 lines
5.6 KiB
Python

import json
import numpy as np
import os
import hashlib
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
class PaperSearcher:
    """Semantic search over a JSON list of papers using text embeddings.

    Every paper is turned into a text (title / abstract / keywords), embedded
    either through the OpenAI embeddings API or a local SentenceTransformer
    model, and the resulting matrix is cached as a ``.npy`` file next to the
    papers file. Queries -- a free-text string or a list of example papers --
    are embedded the same way and ranked against the corpus by cosine
    similarity.
    """

    def __init__(self, papers_file, model_type="openai", api_key=None, base_url=None):
        """Load the papers and initialise the embedding backend.

        Args:
            papers_file: Path (str or os.PathLike) to a UTF-8 JSON file that
                holds a list of paper dicts. Keys read by this class are
                'title', 'abstract', 'keywords', 'number', 'primary_area' and
                'forum_url'; all of them are optional.
            model_type: "openai" selects the OpenAI embeddings API with
                text-embedding-3-large; any other value selects the local
                SentenceTransformer model all-MiniLM-L6-v2.
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY
                environment variable when not given.
            base_url: Optional base URL for an OpenAI-compatible endpoint.
        """
        with open(papers_file, 'r', encoding='utf-8') as f:
            self.papers = json.load(f)
        self.model_type = model_type
        self.cache_file = self._get_cache_file(papers_file, model_type)
        self.embeddings = None
        if model_type == "openai":
            # Imported lazily so only the selected backend has to be installed.
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key or os.getenv('OPENAI_API_KEY'),
                base_url=base_url
            )
            self.model_name = "text-embedding-3-large"
        else:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.model_name = "all-MiniLM-L6-v2"
        self._load_cache()

    def _get_cache_file(self, papers_file, model_type):
        """Return the path of the embedding cache file for ``papers_file``.

        The file sits in the same directory as the papers file and is named
        ``cache_<stem>_<hash>_<model_type>.npy``, where ``<hash>`` is the first
        8 hex digits of the MD5 of the path string as given.

        NOTE(review): the key depends on the path spelling and model type only,
        not on the file contents; an edited papers file with the same number of
        papers will reuse the stale cache.
        """
        # Accept str and os.PathLike alike; str input hashes exactly as before.
        papers_path = os.fspath(papers_file)
        base_name = Path(papers_path).stem
        # MD5 only derives a short stable file name; it is not used for security.
        file_hash = hashlib.md5(papers_path.encode()).hexdigest()[:8]
        cache_name = f"cache_{base_name}_{file_hash}_{model_type}.npy"
        return str(Path(papers_path).parent / cache_name)

    def _load_cache(self):
        """Try to populate ``self.embeddings`` from the cache file.

        Returns:
            True when the cache file exists, loads, and is a 2-D array with
            one row per paper. Otherwise returns False, leaving
            ``self.embeddings`` as None when a cache file was present.
        """
        if not os.path.exists(self.cache_file):
            return False
        try:
            cached = np.load(self.cache_file)
        except (OSError, ValueError, EOFError) as exc:
            # Corrupt or unreadable cache: recompute instead of crashing.
            print(f"Ignoring unreadable cache {self.cache_file}: {exc}")
            self.embeddings = None
            return False
        if cached.ndim != 2 or len(cached) != len(self.papers):
            # Stale cache (paper count changed) or unexpected array layout.
            self.embeddings = None
            return False
        self.embeddings = cached
        print(f"Loaded cache: {self.embeddings.shape}")
        return True

    def _save_cache(self):
        """Write ``self.embeddings`` to the cache file (best effort)."""
        try:
            np.save(self.cache_file, self.embeddings)
        except OSError as exc:
            # A failed write must not throw away embeddings that may have cost
            # API calls to compute; they remain available in memory.
            print(f"Could not save cache {self.cache_file}: {exc}")
            return
        print(f"Saved cache: {self.cache_file}")

    def _create_text(self, paper):
        """Build the text that is embedded for a paper (or example) dict.

        Includes "Title: ...", "Abstract: ..." and "Keywords: ..." segments,
        each only when the field is present and truthy, joined by single
        spaces. A list of keywords is joined with ", ". Returns "" when none
        of the fields is set.
        """
        parts = []
        if paper.get('title'):
            parts.append(f"Title: {paper['title']}")
        if paper.get('abstract'):
            parts.append(f"Abstract: {paper['abstract']}")
        if paper.get('keywords'):
            kw = ', '.join(paper['keywords']) if isinstance(paper['keywords'], list) else paper['keywords']
            parts.append(f"Keywords: {kw}")
        return ' '.join(parts)

    def _embed_openai(self, texts):
        """Embed a str or list of str with the OpenAI API, 100 texts per request.

        Returns a numpy array built from the returned vectors, one row per
        input text, in input order.
        NOTE(review): the API is believed to reject empty strings; papers with
        no title/abstract/keywords produce "" -- confirm handling if needed.
        """
        if isinstance(texts, str):
            texts = [texts]
        embeddings = []
        batch_size = 100
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            response = self.client.embeddings.create(input=batch, model=self.model_name)
            embeddings.extend(item.embedding for item in response.data)
        return np.array(embeddings)

    def _embed_local(self, texts):
        """Embed a str or list of str with the local SentenceTransformer model.

        A progress bar is shown only for more than 100 texts.
        """
        if isinstance(texts, str):
            texts = [texts]
        return self.model.encode(texts, show_progress_bar=len(texts) > 100)

    def _embed(self, texts):
        """Embed ``texts`` with whichever backend ``model_type`` selected."""
        if self.model_type == "openai":
            return self._embed_openai(texts)
        return self._embed_local(texts)

    def compute_embeddings(self, force=False):
        """Return the corpus embedding matrix, computing it when needed.

        Args:
            force: Recompute even if embeddings are already loaded.

        Returns:
            The embedding matrix with one row per paper. It is also stored on
            ``self.embeddings`` and, when possible, written to the cache file.
        """
        if self.embeddings is not None and not force:
            print("Using cached embeddings")
            return self.embeddings
        print(f"Computing embeddings ({self.model_name})...")
        texts = [self._create_text(p) for p in self.papers]
        self.embeddings = self._embed(texts)
        print(f"Computed: {self.embeddings.shape}")
        self._save_cache()
        return self.embeddings

    def search(self, examples=None, query=None, top_k=100):
        """Rank papers by cosine similarity to a query or to example papers.

        Args:
            examples: List of paper-like dicts. When non-empty, their
                embeddings are averaged into one query vector and ``query``
                is ignored.
            query: Free-text query, used when ``examples`` is empty/None.
            top_k: Maximum number of results to return.

        Returns:
            A list of ``{'paper': <paper dict>, 'similarity': float}`` dicts,
            most similar first.

        Raises:
            ValueError: If neither ``examples`` nor ``query`` is given.
        """
        # Validate before the (possibly expensive) corpus embedding step.
        if not examples and not query:
            raise ValueError("Provide either examples or query")
        if self.embeddings is None:
            self.compute_embeddings()
        if examples:
            # Use the same text template as the corpus so example vectors are
            # comparable with the paper vectors.
            texts = [self._create_text(ex) for ex in examples]
            query_emb = np.mean(self._embed(texts), axis=0).reshape(1, -1)
        else:
            query_emb = self._embed(query).reshape(1, -1)
        similarities = cosine_similarity(query_emb, self.embeddings)[0]
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [{
            'paper': self.papers[idx],
            'similarity': float(similarities[idx])
        } for idx in top_indices]

    def display(self, results, n=10):
        """Print the first ``n`` search results in a human-readable list.

        Missing paper fields are shown as 'N/A'.
        """
        print(f"\n{'='*80}")
        print(f"Top {len(results)} Results (showing {min(n, len(results))})")
        print(f"{'='*80}\n")
        for i, result in enumerate(results[:n], 1):
            paper = result['paper']
            sim = result['similarity']
            print(f"{i}. [{sim:.4f}] {paper.get('title', 'N/A')}")
            print(f" #{paper.get('number', 'N/A')} | {paper.get('primary_area', 'N/A')}")
            print(f" {paper.get('forum_url', 'N/A')}\n")

    def save(self, results, output_file):
        """Write search results, the model name and the result count to a JSON file."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                'model': self.model_name,
                'total': len(results),
                'results': results
            }, f, ensure_ascii=False, indent=2)
        print(f"Saved to {output_file}")