Skip to content

Commit

Permalink
added stop_event checker
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoLisin committed Nov 3, 2023
1 parent 87551b4 commit 961c517
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion supervisely/train/src/sly_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def train(api: sly.Api, task_id, context, state, app_logger):
# start train script
api.app.set_field(task_id, "state.activeNames", ["labels", "train", "pred", "metrics"]) # "logs",
get_progress_cb("YOLOv5: Scanning data ", 1)(1)
train_yolov5.main()
train_yolov5.main(stop_event_check=g.my_app.app_is_stoped)

# upload artifacts directory to Team Files
upload_artifacts(g.local_artifacts_dir, g.remote_artifacts_dir)
Expand Down
13 changes: 8 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copy import deepcopy
from pathlib import Path
from threading import Thread
from typing import Callable

import numpy as np
import torch.distributed as dist
Expand Down Expand Up @@ -66,7 +67,7 @@
from supervisely import logger


def train(hyp, opt, device, tb_writer=None):
def train(hyp, opt, device, stop_event_check: Callable[[], bool], tb_writer=None):
train_batches_uploaded = False

logger.info("hyperparameters", extra=hyp)
Expand Down Expand Up @@ -508,7 +509,9 @@ def train(hyp, opt, device, tb_writer=None):
if plots and ni == 10 and opt.sly:
train_batches_uploaded = True
upload_train_data_vis()

# stop training early if the stop-event check reports a stop
if stop_event_check() is True:
return
# end batch ------------------------------------------------------------------------------------------------
# end epoch ----------------------------------------------------------------------------------------------------

Expand Down Expand Up @@ -706,7 +709,7 @@ def train(hyp, opt, device, tb_writer=None):
return results


def main():
def main(stop_event_check: Callable[[], bool]):
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=str, default="yolov5s.pt", help="initial weights path")
parser.add_argument("--cfg", type=str, default="", help="model.yaml path")
Expand Down Expand Up @@ -854,7 +857,7 @@ def main():
# prefix = colorstr('tensorboard: ')
# logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer)
train(hyp, opt, device, stop_event_check, tb_writer)

# Evolve hyperparameters (optional)
else:
Expand Down Expand Up @@ -934,7 +937,7 @@ def main():
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
results = train(hyp.copy(), opt, device)
results = train(hyp.copy(), opt, device, stop_event_check)

# Write mutation results
print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
Expand Down

0 comments on commit 961c517

Please sign in to comment.