Source code for dbt_llm_tools.vector_store

import os
import json

import chromadb
from chromadb.utils import embedding_functions

from dbt_llm_tools.dbt_model import DbtModel
from dbt_llm_tools.types import ParsedSearchResult


[docs] class VectorStore: """ A class representing a vector store for dbt models. Methods: get_client: Returns the client object for the vector store. upsert_models: Upsert the models into the vector store. reset_collection: Clear the collection of all documents. """ def __init__( self, openai_api_key: str, embedding_model_name: str = "text-embedding-3-large", vector_db_path: str = ".local_storage/chroma.db", test_mode: bool = False, ) -> None: """ Initializes a vector store for dbt models. Args: openai_api_key (str): Your OpenAI API key. embedding_model_name (str, optional): The name of the OpenAI embedding model to be used. db_persist_path (str, optional): The path to the persistent database file. Defaults to "./chroma.db". test_mode (bool, optional): Whether the vector store is being used in test mode. Defaults to False. """ if not isinstance(vector_db_path, str) or vector_db_path == "": raise Exception("Please provide a valid path for the persistent database.") os.makedirs(vector_db_path, exist_ok=True) self.__client = chromadb.PersistentClient(vector_db_path) self.__collection_name = "model_documentation" self.__openai_api_key = openai_api_key self.__embedding_fn = self.__get_embedding_fn( embedding_model_name, test_mode=test_mode ) self.__collection = self.__create_collection() def __get_embedding_fn( self, embedding_model_name: str, test_mode: bool = False ) -> embedding_functions.OpenAIEmbeddingFunction: """ Get the embedding function for the vector store. Args: embedding_model_name (str): The name of the OpenAI embedding model to be used. test_mode (bool, optional): Whether the vector store is being used in test mode. Defaults to False. Returns: embedding_functions.OpenAIEmbeddingFunction: The embedding function for the vector store. """ if test_mode: return embedding_functions.DefaultEmbeddingFunction() return embedding_functions.OpenAIEmbeddingFunction( api_key=self.__openai_api_key, model_name=embedding_model_name ) def __create_collection(self, distance_fn: str = "l2") -> chromadb.Collection: """ Create a new collection in the vector store. Args: distance_fn (str, optional): The distance function to be used for nearest neighbour search. Defaults to "l2". Returns: chromadb.Collection: The newly created collection. """ return self.__client.get_or_create_collection( name=self.__collection_name, metadata={"hnsw:space": distance_fn}, embedding_function=self.__embedding_fn, )
[docs] def set_embedding_fn(self, embedding_model_name: str) -> None: """ Set the embedding function for the vector store. Args: embedding_model_name (str): The name of the OpenAI embedding model to be used. """ self.__embedding_fn = self.__get_embedding_fn(embedding_model_name)
[docs] def get_client(self) -> chromadb.PersistentClient: """ Returns the client object for the vector store. Returns: chromadb.PersistentClient: The client object for the vector store. """ return self.__client
[docs] def upsert_models( self, models: list[DbtModel], ) -> None: """ Upsert the models into the vector store. Args: models (list[DbtModel]): A list of dbt model objects to be upserted into the vector store. Returns: None """ documents = [] metadatas = [] ids = [] for model in models: if not isinstance(model, DbtModel): raise Exception("Please provide a list of valid dbt model objects.") model_text = model.as_prompt_text() documents.append(model_text) metadatas.append( { "tags": json.dumps(model.tags), } ) ids.append(model.name) return self.__collection.upsert( documents=documents, metadatas=metadatas, ids=ids )
[docs] def get_models(self, model_ids: list[str] = None) -> list[DbtModel]: """ Get the models from the vector store. Args: model_ids (list[str], optional): A list of model ids to be retrieved from the vector store. Returns: list[DbtModel]: A list of dbt model objects retrieved from the vector store. """ models = [] raw_models = self.__collection.get(ids=model_ids) for i in range(len(raw_models["ids"])): models.append( { "id": raw_models["ids"][i], "document": raw_models["documents"][i], } ) return models
[docs] def query_collection( self, query: str, n_results: int = 3 ) -> list[ParsedSearchResult]: """ Query the collection for the k nearest neighbours to the query. Args: query (str): The query to be used for nearest neighbour search. n_results (int, optional): The number of nearest neighbours to be returned. Defaults to 3. Returns: list[ParsedSearchResult]: A list of parsed search results. """ closest_models = [] if not isinstance(query, str) or query == "": raise Exception("Please provide a valid query.") search_results = self.__collection.query( query_texts=[query], n_results=n_results, include=["documents", "distances", "metadatas"], ) for i in range(len(search_results["ids"][0])): closest_models.append( { "id": search_results["ids"][0][i], "metadata": search_results["metadatas"][0][i], "document": search_results["documents"][0][i], "distance": search_results["distances"][0][i], } ) return closest_models
[docs] def reset_collection(self) -> None: """ Clear the collection of all documents. Returns: None """ self.__client.delete_collection(self.__collection_name) self.__collection = self.__create_collection()