diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 8b0e3a85d1..728c81ac5d 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -145,12 +145,14 @@ class PassageModel(Base): source_id = Column(String) # vector storage - if config.archival_storage_type == "postgres": + if settings.memgpt_pg_uri_no_default: from pgvector.sqlalchemy import Vector embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) - else: + elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma": embedding = Column(CommonVector) + else: + raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}") embedding_config = Column(EmbeddingConfigColumn) metadata_ = Column(MutableJson) @@ -177,40 +179,6 @@ def to_record(self): ) -# def get_db_model( -# config: MemGPTConfig, -# table_name: str, -# table_type: TableType, -# #user_id: str, -# #agent_id: Optional[str] = None, -# dialect="postgresql", -# ): -# # Define a helper function to create or get the model class -# def create_or_get_model(class_name, base_model, table_name): -# if class_name in globals(): -# return globals()[class_name] -# Model = type(class_name, (base_model,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}}) -# globals()[class_name] = Model -# return Model -# -# if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: -# # create schema for archival memory -# -# """Create database model for table_name""" -# #class_name = f"{table_name.capitalize()}Model" + dialect -# return create_or_get_model(class_name, PassageModel, table_name) -# -# elif table_type == TableType.RECALL_MEMORY: -# -# """Create database model for table_name""" -# class_name = f"{table_name.capitalize()}Model" + dialect -# return create_or_get_model(class_name, MessageModel, table_name) -# -# else: -# raise ValueError(f"Table type {table_type} not implemented") -# - - class SQLStorageConnector(StorageConnector): def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) @@ -386,9 +354,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) - # create table - # self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id) - # construct URI from enviornment variables if settings.pg_uri: self.uri = settings.pg_uri @@ -419,8 +384,8 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None self.session_maker = db_context # self.session_maker = sessionmaker(bind=self.engine) - # with self.session_maker() as session: - # session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension + with self.session_maker() as session: + session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension ## create table # Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist @@ -527,7 +492,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None self.path = os.path.join(self.path, f"sqlite.db") # Create the SQLAlchemy engine - # self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite") # self.engine = create_engine(f"sqlite:///{self.path}") # Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist # self.session_maker = sessionmaker(bind=self.engine)