Skip to content

Commit

Permalink
Merge pull request #361 from aws-samples/xuhan-aics
Browse files Browse the repository at this point in the history
refactor: refactor profile using chatbot table
  • Loading branch information
IcyKallen committed Aug 23, 2024
2 parents d6969ba + abe5b5f commit 645f22e
Show file tree
Hide file tree
Showing 13 changed files with 33 additions and 87 deletions.
2 changes: 0 additions & 2 deletions source/infrastructure/lib/api/api-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ export class ApiConstruct extends Construct {
const domainEndpoint = props.knowledgeBaseStackOutputs.aosDomainEndpoint;
const sessionsTableName = props.chatStackOutputs.sessionsTableName;
const messagesTableName = props.chatStackOutputs.messagesTableName;
const profileTableName = props.chatStackOutputs.profileTableName;
const resBucketName = props.sharedConstructOutputs.resultBucket.bucketName;
const executionTableName = props.knowledgeBaseStackOutputs.executionTableName;
const etlObjTableName = props.knowledgeBaseStackOutputs.etlObjTableName;
Expand Down Expand Up @@ -433,7 +432,6 @@ export class ApiConstruct extends Construct {
environment: {
INDEX_TABLE_NAME: props.sharedConstructOutputs.indexTable.tableName,
CHATBOT_TABLE_NAME: props.sharedConstructOutputs.chatbotTable.tableName,
PROFILE_TABLE_NAME: profileTableName,
MODEL_TABLE_NAME: props.sharedConstructOutputs.modelTable.tableName,
EMBEDDING_ENDPOINT: props.modelConstructOutputs.defaultEmbeddingModelName,
},
Expand Down
4 changes: 0 additions & 4 deletions source/infrastructure/lib/chat/chat-stack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ export interface ChatStackOutputs {
messagesTableName: string;
promptTableName: string;
intentionTableName: string;
profileTableName: string;
sqsStatement: iam.PolicyStatement;
messageQueue: Queue;
dlq: Queue;
Expand All @@ -57,7 +56,6 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
public messagesTableName: string;
public promptTableName: string;
public intentionTableName: string;
public profileTableName: string;
public sqsStatement: iam.PolicyStatement;
public messageQueue: Queue;
public dlq: Queue;
Expand Down Expand Up @@ -87,7 +85,6 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
this.messagesTableName = chatTablesConstruct.messagesTableName;
this.promptTableName = chatTablesConstruct.promptTableName;
this.intentionTableName = chatTablesConstruct.intentionTableName;
this.profileTableName = chatTablesConstruct.profileTableName
this.chatbotTableName = props.sharedConstructOutputs.chatbotTable.tableName;
this.indexTableName = props.sharedConstructOutputs.indexTable.tableName;
this.modelTableName = props.sharedConstructOutputs.modelTable.tableName;
Expand Down Expand Up @@ -124,7 +121,6 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
RERANK_ENDPOINT: props.modelConstructOutputs.defaultEmbeddingModelName,
EMBEDDING_ENDPOINT: props.modelConstructOutputs.defaultEmbeddingModelName,
CHATBOT_TABLE_NAME: props.sharedConstructOutputs.chatbotTable.tableName,
PROFILE_TABLE_NAME: chatTablesConstruct.profileTableName,
SESSIONS_TABLE_NAME: chatTablesConstruct.sessionsTableName,
MESSAGES_TABLE_NAME: chatTablesConstruct.messagesTableName,
PROMPT_TABLE_NAME: chatTablesConstruct.promptTableName,
Expand Down
7 changes: 0 additions & 7 deletions source/infrastructure/lib/chat/chat-tables.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ export class ChatTablesConstruct extends Construct {
public messagesTableName: string;
public promptTableName: string;
public intentionTableName: string;
public profileTableName: string;

public readonly byUserIdIndex: string = "byUserId";
public readonly bySessionIdIndex: string = "bySessionId";
Expand Down Expand Up @@ -62,10 +61,6 @@ export class ChatTablesConstruct extends Construct {
name: "intentionId",
type: dynamodb.AttributeType.STRING,
}
const profileIdAttr = {
name: "profileId",
type: dynamodb.AttributeType.STRING,
}

const sessionsTable = new DynamoDBTable(this, "Session", sessionIdAttr, userIdAttr).table;
sessionsTable.addGlobalSecondaryIndex({
Expand All @@ -83,12 +78,10 @@ export class ChatTablesConstruct extends Construct {

const promptTable = new DynamoDBTable(this, "Prompt", groupNameAttr2, sortKeyAttr).table;
const intentionTable = new DynamoDBTable(this, "Intention", groupNameAttr, intentionIdAttr).table;
const profileTable = new DynamoDBTable(this, "Profile", groupNameAttr, profileIdAttr).table;

this.sessionsTableName = sessionsTable.tableName;
this.messagesTableName = messagesTable.tableName;
this.promptTableName = promptTable.tableName;
this.intentionTableName = intentionTable.tableName;
this.profileTableName = profileTable.tableName;
}
}
7 changes: 2 additions & 5 deletions source/lambda/etl/create_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
initiate_chatbot,
initiate_index,
initiate_model,
initiate_profile,
is_chatbot_existed,
)

Expand All @@ -21,7 +20,6 @@
index_table = dynamodb.Table(os.environ.get("INDEX_TABLE_NAME"))
chatbot_table = dynamodb.Table(os.environ.get("CHATBOT_TABLE_NAME"))
model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME"))
profile_table = dynamodb.Table(os.environ.get("PROFILE_TABLE_NAME"))


def lambda_handler(event, context):
Expand Down Expand Up @@ -80,9 +78,8 @@ def lambda_handler(event, context):
create_time,
DESCRIPTION,
)
initiate_chatbot(chatbot_table, group_name, chatbot_id, create_time)
initiate_profile(
profile_table,
initiate_chatbot(
chatbot_table,
group_name,
chatbot_id,
index_id,
Expand Down
29 changes: 6 additions & 23 deletions source/lambda/etl/utils/ddb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,35 +67,18 @@ def initiate_index(
)


def initiate_chatbot(chatbot_table, group_name, chatbot_id, create_time=None):
def initiate_chatbot(
chatbot_table, group_name, chatbot_id, index_id, index_type, tag, create_time=None
):
if not create_time:
create_time = str(datetime.now(timezone.utc))
create_item_if_not_exist(
is_existed, item = create_item_if_not_exist(
chatbot_table,
{"groupName": group_name, "chatbotId": chatbot_id},
{
"groupName": group_name,
"chatbotId": chatbot_id,
"languages": ["zh"],
"profileIds": [chatbot_id],
"createTime": create_time,
"updateTime": create_time,
"status": Status.ACTIVE.value,
},
)


def initiate_profile(
profile_table, group_name, profile_id, index_id, index_type, tag, create_time=None
):
if not create_time:
create_time = str(datetime.now(timezone.utc))
is_existed, item = create_item_if_not_exist(
profile_table,
{"groupName": group_name, "profileId": profile_id},
{
"groupName": group_name,
"profileId": profile_id,
"indexIds": {index_type: {"count": 1, "value": {tag: index_id}}},
"createTime": create_time,
"updateTime": create_time,
Expand All @@ -118,11 +101,11 @@ def initiate_profile(
item["indexIds"][index_type]["count"] = len(
item["indexIds"][index_type]["value"]
)
profile_table.put_item(Item=item)
chatbot_table.put_item(Item=item)
else:
# Add a new index type
item["indexIds"][index_type] = {"count": 1, "value": {tag: index_id}}
profile_table.put_item(Item=item)
chatbot_table.put_item(Item=item)


def is_chatbot_existed(ddb_table, group_name: str, chatbot_id: str):
Expand Down
16 changes: 4 additions & 12 deletions source/lambda/online/common_logic/common_utils/chatbot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,24 @@


class ChatbotManager:
def __init__(self, chatbot_table, index_table, model_table, profile_table):
def __init__(self, chatbot_table, index_table, model_table):
self.chatbot_table = chatbot_table
self.index_table = index_table
self.model_table = model_table
self.profile_table = profile_table

@classmethod
def from_environ(cls):
chatbot_table_name = os.environ.get("CHATBOT_TABLE_NAME", "")
model_table_name = os.environ.get("MODEL_TABLE_NAME", "")
index_table_name = os.environ.get("INDEX_TABLE_NAME", "")
profile_table_name = os.environ.get("PROFILE_TABLE_NAME", "")
dynamodb = boto3.resource("dynamodb")
chatbot_table = dynamodb.Table(chatbot_table_name)
model_table = dynamodb.Table(model_table_name)
index_table = dynamodb.Table(index_table_name)
profile_table = dynamodb.Table(profile_table_name)
chatbot_manager = cls(chatbot_table, index_table, model_table, profile_table)
chatbot_manager = cls(chatbot_table, index_table, model_table)
return chatbot_manager

def get_chatbot(self, group_name: str, chatbot_id: str, user_profile: str):
def get_chatbot(self, group_name: str, chatbot_id: str):
"""Get chatbot from chatbot id and add index, model, etc. data
Args:
Expand All @@ -43,13 +40,8 @@ def get_chatbot(self, group_name: str, chatbot_id: str, user_profile: str):
Key={"groupName": group_name, "chatbotId": chatbot_id}
)
chatbot_content = chatbot_response.get("Item")
user_profile_response = self.profile_table.get_item(
Key={"groupName": group_name, "profileId": user_profile}
)
user_profile_content = user_profile_response.get("Item")
if not chatbot_content or not user_profile_content:
if not chatbot_content:
return Chatbot.from_dynamodb_item({})
chatbot_content["indexIds"] = user_profile_content["indexIds"]
for index_type, index_item in chatbot_content.get("indexIds").items():
for tag, index_id in index_item.get("value").items():
index_content = self.index_table.get_item(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class ChatbotConfig(AllowBaseModel):
user_id: str = "default_user_id"
group_name: str = "Admin"
chatbot_id: str = "admin"
user_profile: str = "admin"
chatbot_mode: ChatbotMode = ChatbotMode.chat
use_history: bool = True
enable_trace: bool = True
Expand Down Expand Up @@ -164,9 +163,9 @@ def get_index_info(index_infos: dict, index_type: str, index_name: str):
raise KeyError(f"key: {index_type}->{index_name} not exits")

@classmethod
def get_index_infos_from_ddb(cls, group_name, chatbot_id, user_profile):
def get_index_infos_from_ddb(cls, group_name, chatbot_id):
chatbot_manager = ChatbotManager.from_environ()
chatbot = chatbot_manager.get_chatbot(group_name, chatbot_id, user_profile)
chatbot = chatbot_manager.get_chatbot(group_name, chatbot_id)
_infos = chatbot.index_ids or {}
infos = {}
for index_type, index_info in _infos.items():
Expand All @@ -188,9 +187,10 @@ def update_retrievers(
default_index_names: dict[str, list],
default_retriever_config: dict[str, dict],
):
index_infos = self.get_index_infos_from_ddb(
self.group_name, self.chatbot_id, self.user_profile
)
index_infos = self.get_index_infos_from_ddb(self.group_name, self.chatbot_id)
print(f"index_infos: {index_infos}")
print(f"default_index_names: {default_index_names}")
print(f"default_retriever_config: {default_retriever_config}")
for task_name, index_names in default_index_names.items():
assert task_name in ("qq_match", "intention", "private_knowledge")
if task_name == "qq_match":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@
chatbot_table_name = os.environ.get("CHATBOT_TABLE", "")
model_table_name = os.environ.get("MODEL_TABLE", "")
index_table_name = os.environ.get("INDEX_TABLE", "")
profile_table_name = os.environ.get("PROFILE_TABLE", "")
dynamodb = boto3.resource("dynamodb")
chatbot_table = dynamodb.Table(chatbot_table_name)
model_table = dynamodb.Table(model_table_name)
index_table = dynamodb.Table(index_table_name)
profile_table = dynamodb.Table(profile_table_name)
chatbot_manager = ChatbotManager(chatbot_table, index_table, model_table, profile_table)
chatbot_manager = ChatbotManager(chatbot_table, index_table, model_table)

region = boto3.Session().region_name

Expand Down
13 changes: 5 additions & 8 deletions source/lambda/online/lambda_main/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,11 @@ def aics_restapi_event_handler(event_body: dict, context: dict, entry_executor):
user_id = "default_user_id"
group_name = "Admin"
chatbot_id = event_body.get("user_profile", {}).get("channel", "admin")
agent = event_body.get("user_profile", {}).get("agent")
if agent == 1:
user_profile = "admin"
else:
user_profile = "host"
# agent = event_body.get("user_profile", {}).get("agent")
# if agent == 1:
# user_profile = "admin"
# else:
# user_profile = "host"

ddb_history_obj = DynamoDBChatMessageHistory(
sessions_table_name=sessions_table_name,
Expand Down Expand Up @@ -175,7 +175,6 @@ def aics_restapi_event_handler(event_body: dict, context: dict, entry_executor):
standard_event_body["chatbot_config"]["user_id"] = user_id
standard_event_body["chatbot_config"]["group_name"] = group_name
standard_event_body["chatbot_config"]["chatbot_id"] = chatbot_id
standard_event_body["chatbot_config"]["user_profile"] = user_profile
standard_event_body["message_id"] = str(uuid.uuid4())
standard_event_body["custom_message_id"] = ""
standard_event_body["ws_connection_id"] = ""
Expand Down Expand Up @@ -298,7 +297,6 @@ def lambda_handler(event_body: dict, context: dict):
# TODO Need to modify key
group_name = event_body.get("chatbot_config", {}).get("group_name", "Admin")
chatbot_id = event_body.get("chatbot_config", {}).get("chatbot_id", "admin")
user_profile = event_body.get("chatbot_config", {}).get("user_profile", "admin")

ddb_history_obj = DynamoDBChatMessageHistory(
sessions_table_name=sessions_table_name,
Expand All @@ -320,7 +318,6 @@ def lambda_handler(event_body: dict, context: dict):
event_body["chatbot_config"]["user_id"] = user_id
event_body["chatbot_config"]["group_name"] = group_name
event_body["chatbot_config"]["chatbot_id"] = chatbot_id
event_body["chatbot_config"]["user_profile"] = user_profile
# TODO: chatbot id add to event body

# logger.info(f"event_body:\n{json.dumps(event_body,ensure_ascii=False,indent=2,cls=JSONEncoder)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,8 @@ def register_rag_tool(
def register_rag_tool_from_config(event_body: dict):
group_name = event_body.get("chatbot_config").get("group_name", "Admin")
chatbot_id = event_body.get("chatbot_config").get("chatbot_id", "admin")
user_profile = event_body.get("chatbot_config").get("user_profile", "admin")
chatbot_manager = ChatbotManager.from_environ()
chatbot = chatbot_manager.get_chatbot(group_name, chatbot_id, user_profile)
chatbot = chatbot_manager.get_chatbot(group_name, chatbot_id)
logger.info(chatbot)
for index_type, item_dict in chatbot.index_ids.items():
if index_type != IndexType.INTENTION:
Expand Down
5 changes: 1 addition & 4 deletions source/lambda/online/lambda_main/main_utils/parse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,9 @@ def from_chatbot_config(cls, chatbot_config: dict):

group_name = chatbot_config["group_name"]
chatbot_id = chatbot_config["chatbot_id"]
user_profile = chatbot_config["user_profile"]

# init chatbot config
chatbot_config_obj = ChatbotConfig(
group_name=group_name, chatbot_id=chatbot_id, user_profile=user_profile
)
chatbot_config_obj = ChatbotConfig(group_name=group_name, chatbot_id=chatbot_id)
# init default llm
chatbot_config_obj.update_llm_config(default_llm_config)

Expand Down
12 changes: 6 additions & 6 deletions source/portal/src/pages/chatbot/ChatBot.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { useAuth } from 'react-oidc-context';
import {
LLM_BOT_COMMON_MODEL_LIST,
LLM_BOT_RETAIL_MODEL_LIST,
LLM_BOT_USER_PROFILE_LIST,
LLM_BOT_CHATBOT_LIST,
SCENARIO_LIST,
RETAIL_GOODS_LIST,
} from 'src/utils/const';
Expand Down Expand Up @@ -77,8 +77,8 @@ const ChatBot: React.FC<ChatBotProps> = (props: ChatBotProps) => {
// const [chatModeOption, setChatModeOption] = useState<SelectProps.Option>(
// LLM_BOT_CHAT_MODE_LIST[0],
// );
const [userProfileOption, setUserProfileOption] = useState<SelectProps.Option>(
LLM_BOT_USER_PROFILE_LIST[0],
const [chatbotOption, setChatbotOption] = useState<SelectProps.Option>(
LLM_BOT_CHATBOT_LIST[0],
);
const [useChatHistory, setUseChatHistory] = useState(true);
const [enableTrace, setEnableTrace] = useState(true);
Expand Down Expand Up @@ -433,10 +433,10 @@ const ChatBot: React.FC<ChatBotProps> = (props: ChatBotProps) => {
<div className="flex-v gap-10">
<div className="flex gap-5 send-message">
<Select
options={LLM_BOT_USER_PROFILE_LIST}
selectedOption={userProfileOption}
options={LLM_BOT_CHATBOT_LIST}
selectedOption={chatbotOption}
onChange={({ detail }) => {
setUserProfileOption(detail.selectedOption);
setChatbotOption(detail.selectedOption);
}}
/>
<div className="flex-1 pr">
Expand Down
6 changes: 1 addition & 5 deletions source/portal/src/utils/const.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,11 @@ export const LLM_BOT_CHAT_MODE_LIST: SelectProps.Option[] = [
},
];

export const LLM_BOT_USER_PROFILE_LIST: SelectProps.Option[] = [
export const LLM_BOT_CHATBOT_LIST: SelectProps.Option[] = [
{
label: 'admin',
value: 'admin',
},
{
label: 'host',
value: 'host',
}
];

export const SCENARIO_LIST: SelectProps.Option[] = [
Expand Down

0 comments on commit 645f22e

Please sign in to comment.