"""
|
|
Embeddings Service - Code indexing with vector embeddings.
|
|
"""
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|
import httpx
|
|
import numpy as np
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class CodeChunk:
    """A chunk of indexed code."""

    file_path: str
    content: str
    start_line: int
    end_line: int
    chunk_type: str  # program, section, paragraph, or copybook
    metadata: Dict[str, Any]

class EmbeddingsService:
    """
    Service for generating and managing code embeddings.

    Supports:
    - Local MiniLM-L6-v2 (development)
    - Azure OpenAI embeddings (production)
    """

    def __init__(
        self,
        provider: str = "local",
        azure_endpoint: Optional[str] = None,
        azure_key: Optional[str] = None,
        azure_model: str = "text-embedding-3-large",
        qdrant_url: str = "http://localhost:6333",
    ):
        self.provider = provider
        self.azure_endpoint = azure_endpoint
        self.azure_key = azure_key
        self.azure_model = azure_model
        self.qdrant_url = qdrant_url
        self._local_model = None

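    # Illustrative configuration sketch (the endpoint and key values below
    # are placeholders, not real credentials):
    #
    #     svc = EmbeddingsService()  # local MiniLM, default Qdrant URL
    #     svc = EmbeddingsService(
    #         provider="azure",
    #         azure_endpoint="https://<resource>.openai.azure.com",
    #         azure_key="<api-key>",
    #     )
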
    async def embed_text(self, text: str) -> List[float]:
        """Generate an embedding for the given text."""
        if self.provider == "azure":
            return await self._embed_azure(text)
        # Note: the local path runs synchronously and blocks the event loop
        # while the model encodes.
        return self._embed_local(text)

    async def _embed_azure(self, text: str) -> List[float]:
        """Generate embedding using Azure OpenAI."""
        if not self.azure_endpoint or not self.azure_key:
            raise ValueError("Azure provider requires azure_endpoint and azure_key")

        url = (
            f"{self.azure_endpoint}/openai/deployments/{self.azure_model}"
            f"/embeddings?api-version=2024-02-01"
        )

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                headers={
                    "api-key": self.azure_key,
                    "Content-Type": "application/json",
                },
                json={"input": text},
                timeout=60.0,
            )
            response.raise_for_status()
            data = response.json()
            # A single input string yields a single embedding in data["data"][0]
            return data["data"][0]["embedding"]

    def _embed_local(self, text: str) -> List[float]:
        """Generate embedding using local MiniLM model."""
        # Lazy import, so sentence-transformers is only required when the
        # local provider is used; all-MiniLM-L6-v2 produces 384-dim vectors,
        # matching the create_collection default below.
        if self._local_model is None:
            from sentence_transformers import SentenceTransformer
            self._local_model = SentenceTransformer("all-MiniLM-L6-v2")

        embedding = self._local_model.encode(text)
        return embedding.tolist()

    def parse_cobol_program(self, content: str, file_path: str) -> List[CodeChunk]:
        """
        Parse a COBOL program into indexable chunks.

        Currently extracts:
        - PROGRAM-ID
        - COPY statements
        - CALL statements
        - SECTIONs (fixed-format, area A)

        Paragraph, FILE-CONTROL, and WORKING-STORAGE extraction is not
        implemented yet.
        """
        chunks = []
        lines = content.split("\n")

        # Extract PROGRAM-ID
        program_id = None
        for line in lines:
            match = re.search(r"PROGRAM-ID\.\s+(\S+)", line, re.IGNORECASE)
            if match:
                program_id = match.group(1).rstrip(".")
                break

        # Extract COPY statements
        copies = []
        for line in lines:
            match = re.search(r"COPY\s+(\S+)", line, re.IGNORECASE)
            if match:
                copies.append(match.group(1).rstrip("."))

        # Extract CALL statements
        calls = []
        for line in lines:
            match = re.search(r"CALL\s+['\"](\S+)['\"]", line, re.IGNORECASE)
            if match:
                calls.append(match.group(1))

        # Extract SECTIONs (start_line/end_line are 1-based and inclusive,
        # consistent with the whole-program chunk below)
        current_section = None
        section_start = 0
        section_content = []

        for i, line in enumerate(lines):
            # Fixed-format COBOL: a SECTION header starts in area A, after
            # the 7-character sequence-number/indicator area.
            match = re.search(r"^\s{7}(\S+)\s+SECTION", line)
            if match:
                # Save the previous section before starting a new one
                if current_section:
                    chunks.append(CodeChunk(
                        file_path=file_path,
                        content="\n".join(section_content),
                        start_line=section_start,
                        end_line=i,
                        chunk_type="section",
                        metadata={
                            "program_id": program_id,
                            "section_name": current_section,
                            "copies": copies,
                            "calls": calls,
                        },
                    ))
                current_section = match.group(1)
                section_start = i + 1
                section_content = [line]
            elif current_section:
                section_content.append(line)

        # Save the last section
        if current_section:
            chunks.append(CodeChunk(
                file_path=file_path,
                content="\n".join(section_content),
                start_line=section_start,
                end_line=len(lines),
                chunk_type="section",
                metadata={
                    "program_id": program_id,
                    "section_name": current_section,
                    "copies": copies,
                    "calls": calls,
                },
            ))

        # If no sections were found, index the whole program as one chunk
        if not chunks:
            chunks.append(CodeChunk(
                file_path=file_path,
                content=content,
                start_line=1,
                end_line=len(lines),
                chunk_type="program",
                metadata={
                    "program_id": program_id,
                    "copies": copies,
                    "calls": calls,
                },
            ))

        return chunks

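    # Illustrative sketch of the parser's output on a minimal fixed-format
    # COBOL source (the sample program is hypothetical):
    #
    #     svc = EmbeddingsService()
    #     sample = "\n".join([
    #         "       IDENTIFICATION DIVISION.",
    #         "       PROGRAM-ID. PAYROLL.",
    #         "       PROCEDURE DIVISION.",
    #         "       MAIN-LOGIC SECTION.",
    #         "           COPY PAYREC.",
    #         "           CALL 'CALCTAX'.",
    #     ])
    #     chunks = svc.parse_cobol_program(sample, "PAYROLL.cbl")
    #     # -> one "section" chunk (MAIN-LOGIC) whose metadata carries
    #     #    program_id="PAYROLL", copies=["PAYREC"], calls=["CALCTAX"]
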
    async def index_chunks(
        self,
        chunks: List[CodeChunk],
        collection: str,
        product: str,
        client: str,
    ) -> int:
        """Index code chunks into Qdrant."""
        indexed = 0

        for chunk in chunks:
            # Build the text to embed: file path, chunk type, section name
            # (if any), and the first 1000 characters of the chunk content.
            text_to_embed = "\n".join([
                f"File: {chunk.file_path}",
                f"Type: {chunk.chunk_type}",
                chunk.metadata.get("section_name", ""),
                chunk.content[:1000],
            ])
            embedding = await self.embed_text(text_to_embed)

            # Store in Qdrant
            await self._store_vector(
                collection=collection,
                vector=embedding,
                payload={
                    "file_path": chunk.file_path,
                    "content": chunk.content,
                    "start_line": chunk.start_line,
                    "end_line": chunk.end_line,
                    "chunk_type": chunk.chunk_type,
                    "product": product,
                    "client": client,
                    **chunk.metadata,
                },
            )
            indexed += 1

        return indexed

    async def search_similar(
        self,
        query: str,
        collection: str,
        limit: int = 10,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Search for similar code chunks."""
        embedding = await self.embed_text(query)

        async with httpx.AsyncClient() as client:
            body = {
                "vector": embedding,
                "limit": limit,
                "with_payload": True,
            }
            if filters:
                body["filter"] = filters

            response = await client.post(
                f"{self.qdrant_url}/collections/{collection}/points/search",
                json=body,
                timeout=30.0,
            )

            if response.status_code == 200:
                results = response.json().get("result", [])
                return [
                    {"score": r["score"], **r["payload"]}
                    for r in results
                ]
            return []

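    # Illustrative sketch: narrowing a search with a Qdrant payload filter
    # (standard Qdrant filter syntax; the key matches the payload written by
    # index_chunks, the values are placeholders):
    #
    #     results = await svc.search_similar(
    #         query="interest calculation",
    #         collection="cobol_code",
    #         filters={"must": [{"key": "product", "match": {"value": "core"}}]},
    #     )
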
    async def _store_vector(
        self,
        collection: str,
        vector: List[float],
        payload: Dict[str, Any],
    ) -> bool:
        """Store a vector in Qdrant."""
        async with httpx.AsyncClient() as client:
            response = await client.put(
                f"{self.qdrant_url}/collections/{collection}/points",
                json={
                    "points": [
                        {
                            "id": str(uuid.uuid4()),
                            "vector": vector,
                            "payload": payload,
                        }
                    ]
                },
                timeout=30.0,
            )
            return response.status_code == 200

    async def create_collection(
        self,
        name: str,
        vector_size: int = 384,  # all-MiniLM-L6-v2; use 3072 for text-embedding-3-large
    ) -> bool:
        """Create a Qdrant collection."""
        async with httpx.AsyncClient() as client:
            response = await client.put(
                f"{self.qdrant_url}/collections/{name}",
                json={
                    "vectors": {
                        "size": vector_size,
                        "distance": "Cosine",
                    }
                },
                timeout=30.0,
            )
            return response.status_code in (200, 201)

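
# Illustrative end-to-end sketch. Assumes a Qdrant instance reachable at the
# default http://localhost:6333 and the sentence-transformers package
# installed; the collection, product, client, and file names are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        svc = EmbeddingsService(provider="local")
        await svc.create_collection("cobol_code", vector_size=384)
        with open("PAYROLL.cbl") as f:  # placeholder source file
            source = f.read()
        chunks = svc.parse_cobol_program(source, "PAYROLL.cbl")
        count = await svc.index_chunks(
            chunks, collection="cobol_code", product="core", client="acme"
        )
        print(f"Indexed {count} chunks")
        for hit in await svc.search_similar("tax calculation", "cobol_code", limit=5):
            print(hit["score"], hit["file_path"])

    asyncio.run(_demo())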