Skip to content

Commit

Permalink
fix postgres issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders committed Sep 18, 2024
1 parent 12e48ec commit 8330e98
Showing 1 changed file with 6 additions and 42 deletions.
48 changes: 6 additions & 42 deletions memgpt/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,14 @@ class PassageModel(Base):
source_id = Column(String)

# vector storage
if config.archival_storage_type == "postgres":
if settings.memgpt_pg_uri_no_default:
from pgvector.sqlalchemy import Vector

embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
else:
elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
embedding = Column(CommonVector)
else:
raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
embedding_config = Column(EmbeddingConfigColumn)
metadata_ = Column(MutableJson)

Expand All @@ -177,40 +179,6 @@ def to_record(self):
)


# def get_db_model(
# config: MemGPTConfig,
# table_name: str,
# table_type: TableType,
# #user_id: str,
# #agent_id: Optional[str] = None,
# dialect="postgresql",
# ):
# # Define a helper function to create or get the model class
# def create_or_get_model(class_name, base_model, table_name):
# if class_name in globals():
# return globals()[class_name]
# Model = type(class_name, (base_model,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
# globals()[class_name] = Model
# return Model
#
# if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
# # create schema for archival memory
#
# """Create database model for table_name"""
# #class_name = f"{table_name.capitalize()}Model" + dialect
# return create_or_get_model(class_name, PassageModel, table_name)
#
# elif table_type == TableType.RECALL_MEMORY:
#
# """Create database model for table_name"""
# class_name = f"{table_name.capitalize()}Model" + dialect
# return create_or_get_model(class_name, MessageModel, table_name)
#
# else:
# raise ValueError(f"Table type {table_type} not implemented")
#


class SQLStorageConnector(StorageConnector):
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
Expand Down Expand Up @@ -386,9 +354,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None

super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)

# create table
# self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id)

# construct URI from enviornment variables
if settings.pg_uri:
self.uri = settings.pg_uri
Expand Down Expand Up @@ -419,8 +384,8 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None

self.session_maker = db_context
# self.session_maker = sessionmaker(bind=self.engine)
# with self.session_maker() as session:
# session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
with self.session_maker() as session:
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension

## create table
# Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
Expand Down Expand Up @@ -527,7 +492,6 @@ def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None
self.path = os.path.join(self.path, f"sqlite.db")

# Create the SQLAlchemy engine
# self.db_model = get_db_model(config, self.table_name, table_type, user_id, agent_id, dialect="sqlite")
# self.engine = create_engine(f"sqlite:///{self.path}")
# Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
# self.session_maker = sessionmaker(bind=self.engine)
Expand Down

0 comments on commit 8330e98

Please sign in to comment.