From e7107e9179d6287b1af4311fc017a74493d2dabc Mon Sep 17 00:00:00 2001
From: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>
Date: Wed, 3 Jan 2024 08:12:46 +0100
Subject: [PATCH] improving evolve (#11348)

* improving evolve in train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix gen_ranges value in mutation part.

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* fix invalid syntax in line 532

remove one tab from "else"

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* fix range index

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

fix population size
add crossover min and max rate

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update comments

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* save population for last generation

All hyperparameters are now stored in the population section of
"evolve_population.yaml", located in "yolov5\data\hyps", after the transition
to a new generation. This makes it possible to continue a previously
abandoned evolution process from the former population. Additionally, a new
argument, "--evolve_population", allows a manually relocated
"evolve_population.yaml" in any project directory to be loaded for this
purpose, giving users the flexibility to resume an evolution run conveniently.

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove try - except

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Add resume_evolve arg to **resume evolve from the last generation**.
The population is loaded from data/hyps by default, reading all YAML files
found there.

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update README.zh-CN.md

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

update pop_size

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

---------

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 train.py | 246 +++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 176 insertions(+), 70 deletions(-)

diff --git a/train.py b/train.py
index b4c14c76a3cc..752d7b450557 100644
--- a/train.py
+++ b/train.py
@@ -468,6 +468,11 @@ def parse_opt(known=False):
     parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
     parser.add_argument('--noplots', action='store_true', help='save no plot files')
     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
+    parser.add_argument('--evolve_population',
+                        type=str,
+                        default=ROOT / 'data/hyps',
+                        help='location for loading population')
+    parser.add_argument('--resume_evolve', type=str, default=None, help='resume evolve from last generation')
     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
     parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
@@ -555,37 +560,48 @@ def main(opt, callbacks=Callbacks()):
 
     # Evolve hyperparameters (optional)
     else:
-        # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
+        # Hyperparameter evolution metadata (include this hyperparameter True/False, lower_limit, upper_limit)
         meta = {
-            'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
-            'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
-            'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
-            'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
-            'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
-            'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
-            'box': (1, 0.02, 0.2),  # box loss gain
-            'cls': (1, 0.2, 4.0),  # cls loss gain
-            'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
-            'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
-            'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
-            'iou_t': (0, 0.1, 0.7),  # IoU training threshold
-            'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
-            'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
-            'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
-            'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
-            'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
-            'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
-            'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
-            'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
-            'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
-            'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
-            'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
-            'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
-            'mosaic': (1, 0.0, 1.0),  # image mixup (probability)
-            'mixup': (1, 0.0, 1.0),  # image mixup (probability)
-            'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)
+            'lr0': (False, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
+            'lrf': (False, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
+            'momentum': (False, 0.6, 0.98),  # SGD momentum/Adam beta1
+            'weight_decay': (False, 0.0, 0.001),  # optimizer weight decay
+            'warmup_epochs': (False, 0.0, 5.0),  # warmup epochs (fractions ok)
+            'warmup_momentum': (False, 0.0, 0.95),  # warmup initial momentum
+            'warmup_bias_lr': (False, 0.0, 0.2),  # warmup initial bias lr
+            'box': (False, 0.02, 0.2),  # box loss gain
+            'cls': (False, 0.2, 4.0),  # cls loss gain
+            'cls_pw': (False, 0.5, 2.0),  # cls BCELoss positive_weight
+            'obj': (False, 0.2, 4.0),  # obj loss gain (scale with pixels)
+            'obj_pw': (False, 0.5, 2.0),  # obj BCELoss positive_weight
+            'iou_t': (False, 0.1, 0.7),  # IoU training threshold
+            'anchor_t': (False, 2.0, 8.0),  # anchor-multiple threshold
+            'anchors': (False, 2.0, 10.0),  # anchors per output grid (0 to ignore)
+            'fl_gamma': (False, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
+            'hsv_h': (True, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+            'hsv_s': (True, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+            'hsv_v': (True, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
+            'degrees': (True, 0.0, 45.0),  # image rotation (+/- deg)
+            'translate': (True, 0.0, 0.9),  # image translation (+/- fraction)
+            'scale': (True, 0.0, 0.9),  # image scale (+/- gain)
+            'shear': (True, 0.0, 10.0),  # image shear (+/- deg)
+            'perspective': (True, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+            'flipud': (True, 0.0, 1.0),  # image flip up-down (probability)
+            'fliplr': (True, 0.0, 1.0),  # image flip left-right (probability)
+            'mosaic': (True, 0.0, 1.0),  # image mosaic (probability)
+            'mixup': (True, 0.0, 1.0),  # image mixup (probability)
+            'copy_paste': (True, 0.0, 1.0)}  # segment copy-paste (probability)
+
+        # GA configs
+        pop_size = 50
+        mutation_rate_min = 0.01
+        mutation_rate_max = 0.5
+        crossover_rate_min = 0.5
+        crossover_rate_max = 1
+        min_elite_size = 2
+        max_elite_size = 5
+        tournament_size_min = 2
+        tournament_size_max = 10
 
         with open(opt.hyp, errors='ignore') as f:
             hyp = yaml.safe_load(f)  # load hyps dict
@@ -604,46 +620,128 @@ def main(opt, callbacks=Callbacks()):
                 f'gs://{opt.bucket}/evolve.csv',
                 str(evolve_csv), ])
 
-        for _ in range(opt.evolve):  # generations to evolve
-            if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
-                # Select parent(s)
-                parent = 'single'  # parent selection method: 'single' or 'weighted'
-                x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
-                n = min(5, len(x))  # number of previous results to consider
-                x = x[np.argsort(-fitness(x))][:n]  # top n mutations
-                w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
-                if parent == 'single' or len(x) == 1:
-                    # x = x[random.randint(0, n - 1)]  # random selection
-                    x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
-                elif parent == 'weighted':
-                    x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
-
-                # Mutate
-                mp, s = 0.8, 0.2  # mutation probability, sigma
-                npr = np.random
-                npr.seed(int(time.time()))
-                g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
-                ng = len(meta)
-                v = np.ones(ng)
-                while all(v == 1):  # mutate until a change occurs (prevent duplicates)
-                    v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
-                for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
-                    hyp[k] = float(x[i + 7] * v[i])  # mutate
-
-                # Constrain to limits
-                for k, v in meta.items():
-                    hyp[k] = max(hyp[k], v[1])  # lower limit
-                    hyp[k] = min(hyp[k], v[2])  # upper limit
-                    hyp[k] = round(hyp[k], 5)  # significant digits
-
-            # Train mutation
-            results = train(hyp.copy(), opt, device, callbacks)
-            callbacks = Callbacks()
-            # Write mutation results
-            keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
-                    'val/obj_loss', 'val/cls_loss')
-            print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
-
+        # Delete the items in meta dictionary whose first value is False
+        del_ = []
+        for item in meta.keys():
+            if meta[item][0] is False:
+                del_.append(item)
+        hyp_GA = hyp.copy()  # Make a copy of hyp dictionary
+        for item in del_:
+            del meta[item]  # Remove the item from meta dictionary
+            del hyp_GA[item]  # Remove the item from hyp_GA dictionary
+
+        # Set lower_limit and upper_limit arrays to hold the search space boundaries
+        lower_limit = np.array([meta[k][1] for k in hyp_GA.keys()])
+        upper_limit = np.array([meta[k][2] for k in hyp_GA.keys()])
+
+        # Create gene_ranges list to hold the range of values for each gene in the population
+        gene_ranges = []
+        for i in range(len(upper_limit)):
+            gene_ranges.append((lower_limit[i], upper_limit[i]))
+
+        # Initialize the population with initial_values or random values
+        initial_values = []
+
+        # If resuming evolution from a previous checkpoint
+        if opt.resume_evolve is not None:
+            assert os.path.isfile(ROOT / opt.resume_evolve), 'evolve population path is wrong!'
+            with open(ROOT / opt.resume_evolve, errors='ignore') as f:
+                evolve_population = yaml.safe_load(f)
+                for value in evolve_population.values():
+                    value = np.array([value[k] for k in hyp_GA.keys()])
+                    initial_values.append(list(value))
+
+        # If not resuming from a previous checkpoint, generate initial values from .yaml files in opt.evolve_population
+        else:
+            yaml_files = [f for f in os.listdir(opt.evolve_population) if f.endswith('.yaml')]
+            for file_name in yaml_files:
+                with open(os.path.join(opt.evolve_population, file_name)) as yaml_file:
+                    value = yaml.safe_load(yaml_file)
+                    value = np.array([value[k] for k in hyp_GA.keys()])
+                    initial_values.append(list(value))
+
+        # Generate random values within the search space for the rest of the population
+        if not initial_values:  # no seed population was found
+            population = [generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size)]
+        else:
+            if (pop_size > 1):
+                population = [
+                    generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size - len(initial_values))]
+                for initial_value in initial_values:
+                    population = [initial_value] + population
+
+        # Run the genetic algorithm for a fixed number of generations
+        list_keys = list(hyp_GA.keys())
+        for generation in range(opt.evolve):
+            if (generation >= 1):
+                save_dict = {}
+                for i in range(len(population)):
+                    little_dict = {}
+                    for j in range(len(population[i])):
+                        little_dict[list_keys[j]] = float(population[i][j])
+                    save_dict['gen' + str(generation) + 'number' + str(i)] = little_dict
+
+                with open(save_dir / 'evolve_population.yaml', 'w') as outfile:
+                    yaml.dump(save_dict, outfile, default_flow_style=False)
+
+            # Adaptive elite size
+            elite_size = min_elite_size + int((max_elite_size - min_elite_size) * (generation / opt.evolve))
+            # Evaluate the fitness of each individual in the population
+            fitness_scores = []
+            for individual in population:
+                for key, value in zip(hyp_GA.keys(), individual):
+                    hyp_GA[key] = value
+                hyp.update(hyp_GA)
+                results = train(hyp.copy(), opt, device, callbacks)
+                callbacks = Callbacks()
+                # Write mutation results
+                keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
+                        'val/box_loss', 'val/obj_loss', 'val/cls_loss')
+                print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
+                fitness_scores.append(results[2])
+
+            # Select the fittest individuals for reproduction using adaptive tournament selection
+            selected_indices = []
+            for i in range(pop_size - elite_size):
+                # Adaptive tournament size
+                tournament_size = max(max(2, tournament_size_min),
+                                      int(min(tournament_size_max, pop_size) - (generation / (opt.evolve / 10))))
+                # Perform tournament selection to choose the best individual
+                tournament_indices = random.sample(range(pop_size), tournament_size)
+                tournament_fitness = [fitness_scores[j] for j in tournament_indices]
+                winner_index = tournament_indices[tournament_fitness.index(max(tournament_fitness))]
+                selected_indices.append(winner_index)
+
+            # Add the elite individuals to the selected indices
+            elite_indices = [i for i in range(pop_size) if fitness_scores[i] in sorted(fitness_scores)[-elite_size:]]
+            selected_indices.extend(elite_indices)
+            # Create the next generation through crossover and mutation
+            next_generation = []
+            for i in range(pop_size):
+                parent1_index = selected_indices[random.randint(0, pop_size - 1)]
+                parent2_index = selected_indices[random.randint(0, pop_size - 1)]
+                # Adaptive crossover rate
+                crossover_rate = max(crossover_rate_min,
+                                     min(crossover_rate_max, crossover_rate_max - (generation / opt.evolve)))
+                if random.uniform(0, 1) < crossover_rate:
+                    crossover_point = random.randint(1, len(hyp_GA) - 1)
+                    child = population[parent1_index][:crossover_point] + population[parent2_index][crossover_point:]
+                else:
+                    child = list(population[parent1_index])  # copy so mutation cannot modify the parent in place
+                # Adaptive mutation rate
+                mutation_rate = max(mutation_rate_min,
+                                    min(mutation_rate_max, mutation_rate_max - (generation / opt.evolve)))
+                for j in range(len(hyp_GA)):
+                    if random.uniform(0, 1) < mutation_rate:
+                        child[j] += random.uniform(-0.1, 0.1)
+                        child[j] = min(max(child[j], gene_ranges[j][0]), gene_ranges[j][1])
+                next_generation.append(child)
+            # Replace the old population with the new generation
+            population = next_generation
+        # Print the best solution found
+        best_index = fitness_scores.index(max(fitness_scores))
+        best_individual = population[best_index]
+        print('Best solution found:', best_individual)
         # Plot results
         plot_evolve(evolve_csv)
         LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
@@ -651,6 +749,14 @@ def main(opt, callbacks=Callbacks()):
                     f'Usage example: $ python train.py --hyp {evolve_yaml}')
 
 
+def generate_individual(input_ranges, individual_length):
+    individual = []
+    for i in range(individual_length):
+        lower_bound, upper_bound = input_ranges[i]
+        individual.append(random.uniform(lower_bound, upper_bound))
+    return individual
+
+
 def run(**kwargs):
     # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
     opt = parse_opt(True)
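The two new flags are reachable through the existing train.run() kwargs API
(run() forwards keyword arguments onto opt). A minimal usage sketch, assuming
the stock coco128.yaml dataset and yolov5s.pt weights; every other argument
keeps its default:

    import train

    # Fresh evolution: the seed population is read from every .yaml file
    # found under --evolve_population (data/hyps by default)
    train.run(data='coco128.yaml', weights='yolov5s.pt', evolve=300, evolve_population='data/hyps')

    # Resume an abandoned evolution from the population saved after the last
    # finished generation (resume_evolve is resolved relative to ROOT)
    train.run(data='coco128.yaml', weights='yolov5s.pt', evolve=300, resume_evolve='data/hyps/evolve_population.yaml')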
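The adaptive schedules above tighten the search as evolution progresses: the
elite grows, tournaments shrink, and the crossover and mutation rates anneal
linearly with the generation index. A small sketch that only evaluates those
formulas at a few sample generations, assuming the defaults opt.evolve = 300
and pop_size = 50:

    # Print the adaptive GA parameters at a few sample generations
    evolve, pop_size = 300, 50
    for generation in (0, 100, 200, 299):
        elite_size = 2 + int((5 - 2) * (generation / evolve))  # grows 2 -> 4
        tournament_size = max(2, int(min(10, pop_size) - generation / (evolve / 10)))  # shrinks 10 -> 2
        crossover_rate = max(0.5, min(1, 1 - generation / evolve))  # anneals 1.0 -> 0.5
        mutation_rate = max(0.01, min(0.5, 0.5 - generation / evolve))  # anneals 0.5 -> 0.01
        print(generation, elite_size, tournament_size, round(crossover_rate, 3), round(mutation_rate, 3))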
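For a quick sanity check of the generation step the new loop implements
(tournament selection with elitism, single-point crossover, bounded additive
mutation), here is a self-contained sketch with train() swapped for a toy
fitness function; the gene ranges and GA constants are made-up small values
so it runs in milliseconds:

    import random

    def toy_fitness(individual, gene_ranges):
        # Stand-in for train(): fitness peaks when genes sit mid-range
        return -sum((v - (lo + hi) / 2) ** 2 for v, (lo, hi) in zip(individual, gene_ranges))

    def generate_individual(input_ranges, individual_length):
        # Same helper as the patch: one uniform sample per gene
        return [random.uniform(*input_ranges[i]) for i in range(individual_length)]

    gene_ranges = [(0.0, 0.1), (0.0, 0.9), (0.0, 45.0)]  # e.g. hsv_h, hsv_s, degrees
    pop_size, generations = 8, 5
    elite_size, tournament_size, crossover_rate, mutation_rate = 2, 3, 0.8, 0.3

    population = [generate_individual(gene_ranges, len(gene_ranges)) for _ in range(pop_size)]
    for generation in range(generations):
        fitness_scores = [toy_fitness(ind, gene_ranges) for ind in population]

        # Tournament selection: each slot goes to the fittest of a random sample
        selected_indices = []
        for _ in range(pop_size - elite_size):
            tournament = random.sample(range(pop_size), tournament_size)
            selected_indices.append(max(tournament, key=lambda j: fitness_scores[j]))
        # Elitism: the best individuals are kept as parents unchanged
        elite = sorted(range(pop_size), key=lambda j: fitness_scores[j])[-elite_size:]
        selected_indices.extend(elite)

        next_generation = []
        for _ in range(pop_size):
            p1 = population[random.choice(selected_indices)]
            p2 = population[random.choice(selected_indices)]
            if random.random() < crossover_rate:  # single-point crossover
                point = random.randint(1, len(gene_ranges) - 1)
                child = p1[:point] + p2[point:]
            else:
                child = list(p1)  # copy so mutation cannot alter the parent
            for j, (lo, hi) in enumerate(gene_ranges):  # bounded additive mutation
                if random.random() < mutation_rate:
                    child[j] = min(max(child[j] + random.uniform(-0.1, 0.1), lo), hi)
            next_generation.append(child)
        population = next_generation

    print('best:', max(population, key=lambda ind: toy_fitness(ind, gene_ranges)))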