Added Optuna Search and Configuration changes
Harikaraja committed Sep 23, 2024
1 parent de20550 commit 7c24956
Showing 1 changed file with 62 additions and 72 deletions.
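In brief, the commit replaces Ray Tune's default (random) search with an OptunaSearch searcher, widens the sweep (num_samples 1 to 512, max_concurrent_trials 4 to 16), tightens ASHA's grace period, and adds a --use_gpu flag with per-trial resource requests. Below is a minimal, self-contained sketch of the same OptunaSearch + ASHAScheduler wiring, assuming Ray 2.x with the optuna package installed; the objective function, search space, and numbers are placeholders rather than the repository's code.

# Minimal sketch of the OptunaSearch + ASHAScheduler wiring this commit adds.
# Assumes Ray 2.x with the `optuna` package installed; the objective, search
# space, and numbers are placeholders, not the repository's code.
from ray import train, tune
from ray.tune.tune_config import TuneConfig
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch


def objective(config):
    # Toy objective: report a shrinking "loss" each iteration so ASHA can
    # early-stop weak trials and Optuna can rank configurations.
    for step in range(1, 11):
        train.report({"loss": (config["alpha"] - 0.05) ** 2 / step})


scheduler = ASHAScheduler(
    time_attr="training_iteration",
    max_t=10,
    grace_period=3,
    reduction_factor=3,
    metric="loss",
    mode="min",
)

tuner = tune.Tuner(
    objective,
    param_space={"alpha": tune.loguniform(1e-4, 1e-1)},
    tune_config=TuneConfig(
        search_alg=OptunaSearch(metric="loss", mode="min"),  # TPE-guided sampling
        scheduler=scheduler,
        num_samples=32,
        max_concurrent_trials=4,
    ),
)
results = tuner.fit()
print(results.get_best_result(metric="loss", mode="min").config)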
134 changes: 62 additions & 72 deletions seed_embeddings/OpenKE/generate_embedding_ray.py
@@ -23,7 +23,7 @@
from ray.tune.tune_config import TuneConfig
from ray.train import RunConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler

from ray.tune.search.optuna import OptunaSearch
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


@@ -33,7 +33,6 @@ def test_files(index_dir):
train = os.path.join(index_dir, "train2id.txt")

print(entities, relations, train)

if not os.path.exists(entities):
raise Exception("entity2id.txt not found")
if not os.path.exists(relations):
@@ -65,7 +64,6 @@ def train(arg_conf):
neg_ent=arg_conf["neg_ent"],
neg_rel=arg_conf["neg_rel"],
)

# dataloader for test (link prediction)
if arg_conf["link_pred"]:
test_dataloader = TestDataLoader(arg_conf["index_dir"], "link")
@@ -79,24 +77,21 @@
p_norm=1,
norm_flag=True,
)

# define the loss function
model = NegativeSampling(
model=transe,
loss=MarginLoss(margin=arg_conf["margin"]),
batch_size=train_dataloader.get_batch_size(),
)

# train the model
trainer = Trainer(
model=model,
data_loader=train_dataloader,
train_times=arg_conf["epoch"],
alpha=arg_conf["alpha"],
index_dir=arg_conf["index_dir"],
use_gpu=False,
use_gpu=arg_conf["use_gpu"],
)

trainer.run(
link_prediction=arg_conf["link_pred"],
test_dataloader=test_dataloader,
@@ -137,7 +132,6 @@ def findRep(src, dest, index_dir, src_type="json"):
if __name__ == "__main__":

ray.init()

parser = argparse.ArgumentParser()
parser.add_argument(
"--index_dir",
@@ -192,7 +186,14 @@ def findRep(src, dest, index_dir, src_type="json"):
type=float,
default=1.0,
)

parser.add_argument(
"--use_gpu",
dest="use_gpu",
help="To use GPU for computation",
required=False,
type=bool,
default=False,
)
arg_conf = parser.parse_args()

search_space = {
@@ -205,9 +206,11 @@ def findRep(src, dest, index_dir, src_type="json"):
"neg_ent": tune.randint(1, 30),
"neg_rel": tune.randint(1, 30),
"bern": tune.randint(0, 2),
"opt_method": tune.choice(["SGD", "Adagrad", "Adam", "Adadelta"]),
"opt_method": tune.choice(["SGD", "Adam"]),
#"opt_method": tune.choice(["SGD", "Adagrad", "Adam", "Adadelta"]),
"is_analogy": arg_conf.is_analogy,
"link_pred": arg_conf.link_pred,
"use_gpu": arg_conf.use_gpu,
}

try:
@@ -221,83 +224,58 @@ def findRep(src, dest, index_dir, src_type="json"):
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=min(arg_conf.epoch, 4000),
reduction_factor=2,
grace_period=15,
reduction_factor=3,
metric="AnalogiesScore",
mode="max",
)
tuner = tune.Tuner(
train,
param_space=search_space,
tune_config=TuneConfig(
max_concurrent_trials=4,
scheduler=scheduler,
num_samples=1,
),
run_config=RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
# *Best* checkpoints are determined by these params:
# checkpoint_score_attribute="AnalogiesScore",
# checkpoint_score_order="max",
)
),
)
elif arg_conf.link_pred:
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=min(arg_conf.epoch, 4000),
reduction_factor=2,
grace_period=15,
reduction_factor=3,
metric="hit1",
mode="max",
)
tuner = tune.Tuner(
train,
param_space=search_space,
tune_config=TuneConfig(
max_concurrent_trials=4,
scheduler=scheduler,
num_samples=1,
),
run_config=RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
# *Best* checkpoints are determined by these params:
# checkpoint_score_attribute="hit1",
# checkpoint_score_order="max",
)
),
)
else:
scheduler = ASHAScheduler(
time_attr="training_iteration",
max_t=arg_conf.epoch,
grace_period=min(arg_conf.epoch, 4000),
reduction_factor=2,
grace_period=15,
reduction_factor=3,
metric="loss",
mode="min",
)
tuner = tune.Tuner(
train,
param_space=search_space,
tune_config=TuneConfig(
max_concurrent_trials=4,
scheduler=scheduler,
num_samples=1,
),
run_config=RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=2,
# *Best* checkpoints are determined by these params:
# checkpoint_score_attribute="loss",
# checkpoint_score_order="min",
)
),
if arg_conf.use_gpu:
train_with_resources = tune.with_resources(
train, resources={"cpu": 0, "gpu": 0.0625}
)

else:
train_with_resources = tune.with_resources(
train, resources={"cpu": 10, "gpu": 0}
)

tuner = tune.Tuner(
train_with_resources,
param_space=search_space,
tune_config=TuneConfig(
search_alg=OptunaSearch(metric="loss", mode="min"),
max_concurrent_trials=16,
scheduler=scheduler,
num_samples=512,
),
run_config=RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=1,
# *Best* checkpoints are determined by these params:
checkpoint_score_attribute="loss",
checkpoint_score_order="min",
)
),
)
results = tuner.fit()

# Write the best result to a file, best_result.txt
best_result = None
if arg_conf.is_analogy:
@@ -315,12 +293,24 @@ def findRep(src, dest, index_dir, src_type="json"):
best_result = results.get_best_result(metric="AnalogiesScore", mode="max")

elif arg_conf.link_pred:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(results.get_best_result(metric="hit1", mode="max"))
)

print(
"Best Config Based on Hit1 : ",
results.get_best_result(metric="hit1", mode="max"),
)
best_result = results.get_best_result(metric="hit1", mode="max")
else:

with open(os.path.join(search_space["index_dir"], "best_result.txt"), "a") as f:
f.write(
"\n" + str(results.get_best_result(metric="loss", mode="min"))
)

print(
"Best Config Based on Loss : ",
results.get_best_result(metric="loss", mode="min"),
Expand All @@ -347,11 +337,10 @@ def findRep(src, dest, index_dir, src_type="json"):
margin,
),
)

best_checkpoint_path = best_result.checkpoint.path

print("best_checkpoint_path is: ",best_checkpoint_path)
file_name = os.listdir(best_checkpoint_path)[0]

print("file_name is: ",file_name)
if file_name.endswith(".ckpt"):
# Construct full file path
source_file = os.path.join(best_checkpoint_path, file_name)
@@ -368,6 +357,7 @@ def findRep(src, dest, index_dir, src_type="json"):
margin,
),
)
print("embeddings_path: ",embeddings_path)
findRep(outfile, embeddings_path, index_dir, src_type="ckpt")
else:
print("No .ckpt file found in the source directory.")
@@ -376,4 +366,4 @@ def findRep(src, dest, index_dir, src_type="json"):
print(result)
del results

print("Training finished...")
print("Training finished...")
