Skip to content

Commit

Permalink
improving evolve (#11348)
Browse files Browse the repository at this point in the history
* improving evole in train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix gen_ranges value in mutation part.

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* fix invalid syntax in line 532

remove on tab from "else" 

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* fix range index

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

fix population size
add crossover min and max rate

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update comments

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* save population for last generation

The latest version incorporates a significant update whereby all hyper parameters are now stored in the population section of "evolve_population.yaml," located in "yolov5\data\hyps," following the transition to the new generation. This development allows for the continuation of a previously abandoned evolution process by utilizing the former population. Additionally, a new argument, "--evolve_population," has been introduced to enable the relocation of the manual "evolve_population.yaml" to any project directory to load for the aforementioned purpose. This enhancement offers greater flexibility and convenience to the users, making it easier for them to resume their evolutionary process.

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove try - except

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Add resume resume_evolve arg for **resume evolve from last generation**.
Population will load from data/hyp by default and load all yaml file form them.


Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update README.zh-CN.md

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

* Update train.py

update pop_size

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>

---------

Signed-off-by: Shayan Mousavinia <45814390+ShAmoNiA@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ShAmoNiA and pre-commit-ci[bot] committed Jan 3, 2024
1 parent b61143c commit 66edf38
Showing 1 changed file with 176 additions and 70 deletions.
246 changes: 176 additions & 70 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ def parse_opt(known=False):
parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
parser.add_argument('--noplots', action='store_true', help='save no plot files')
parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
parser.add_argument('--evolve_population',
type=str,
default=ROOT / 'data/hyps',
help='location for loading population')
parser.add_argument('--resume_evolve', type=str, default=None, help='resume evolve from last generation')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
Expand Down Expand Up @@ -555,37 +560,48 @@ def main(opt, callbacks=Callbacks()):

# Evolve hyperparameters (optional)
else:
# Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
# Hyperparameter evolution metadata (including this hyperparameter True-False, lower_limit, upper_limit)
meta = {
'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
'box': (1, 0.02, 0.2), # box loss gain
'cls': (1, 0.2, 4.0), # cls loss gain
'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (0, 0.1, 0.7), # IoU training threshold
'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
'scale': (1, 0.0, 0.9), # image scale (+/- gain)
'shear': (1, 0.0, 10.0), # image shear (+/- deg)
'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
'mixup': (1, 0.0, 1.0), # image mixup (probability)
'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
'lr0': (False, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
'lrf': (False, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (False, 0.6, 0.98), # SGD momentum/Adam beta1
'weight_decay': (False, 0.0, 0.001), # optimizer weight decay
'warmup_epochs': (False, 0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (False, 0.0, 0.95), # warmup initial momentum
'warmup_bias_lr': (False, 0.0, 0.2), # warmup initial bias lr
'box': (False, 0.02, 0.2), # box loss gain
'cls': (False, 0.2, 4.0), # cls loss gain
'cls_pw': (False, 0.5, 2.0), # cls BCELoss positive_weight
'obj': (False, 0.2, 4.0), # obj loss gain (scale with pixels)
'obj_pw': (False, 0.5, 2.0), # obj BCELoss positive_weight
'iou_t': (False, 0.1, 0.7), # IoU training threshold
'anchor_t': (False, 2.0, 8.0), # anchor-multiple threshold
'anchors': (False, 2.0, 10.0), # anchors per output grid (0 to ignore)
'fl_gamma': (False, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
'hsv_h': (True, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (True, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (True, 0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (True, 0.0, 45.0), # image rotation (+/- deg)
'translate': (True, 0.0, 0.9), # image translation (+/- fraction)
'scale': (True, 0.0, 0.9), # image scale (+/- gain)
'shear': (True, 0.0, 10.0), # image shear (+/- deg)
'perspective': (True, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (True, 0.0, 1.0), # image flip up-down (probability)
'fliplr': (True, 0.0, 1.0), # image flip left-right (probability)
'mosaic': (True, 0.0, 1.0), # image mixup (probability)
'mixup': (True, 0.0, 1.0), # image mixup (probability)
'copy_paste': (True, 0.0, 1.0)} # segment copy-paste (probability)

# GA configs
pop_size = 50
mutation_rate_min = 0.01
mutation_rate_max = 0.5
crossover_rate_min = 0.5
crossover_rate_max = 1
min_elite_size = 2
max_elite_size = 5
tournament_size_min = 2
tournament_size_max = 10

with open(opt.hyp, errors='ignore') as f:
hyp = yaml.safe_load(f) # load hyps dict
Expand All @@ -604,53 +620,143 @@ def main(opt, callbacks=Callbacks()):
f'gs://{opt.bucket}/evolve.csv',
str(evolve_csv), ])

for _ in range(opt.evolve): # generations to evolve
if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
# Select parent(s)
parent = 'single' # parent selection method: 'single' or 'weighted'
x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
n = min(5, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness(x))][:n] # top n mutations
w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
if parent == 'single' or len(x) == 1:
# x = x[random.randint(0, n - 1)] # random selection
x = x[random.choices(range(n), weights=w)[0]] # weighted selection
elif parent == 'weighted':
x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination

# Mutate
mp, s = 0.8, 0.2 # mutation probability, sigma
npr = np.random
npr.seed(int(time.time()))
g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
ng = len(meta)
v = np.ones(ng)
while all(v == 1): # mutate until a change occurs (prevent duplicates)
v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
hyp[k] = float(x[i + 7] * v[i]) # mutate

# Constrain to limits
for k, v in meta.items():
hyp[k] = max(hyp[k], v[1]) # lower limit
hyp[k] = min(hyp[k], v[2]) # upper limit
hyp[k] = round(hyp[k], 5) # significant digits

# Train mutation
results = train(hyp.copy(), opt, device, callbacks)
callbacks = Callbacks()
# Write mutation results
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
'val/obj_loss', 'val/cls_loss')
print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)

# Delete the items in meta dictionary whose first value is False
del_ = []
for item in meta.keys():
if meta[item][0] is False:
del_.append(item)
hyp_GA = hyp.copy() # Make a copy of hyp dictionary
for item in del_:
del meta[item] # Remove the item from meta dictionary
del hyp_GA[item] # Remove the item from hyp_GA dictionary

# Set lower_limit and upper_limit arrays to hold the search space boundaries
lower_limit = np.array([meta[k][1] for k in hyp_GA.keys()])
upper_limit = np.array([meta[k][2] for k in hyp_GA.keys()])

# Create gene_ranges list to hold the range of values for each gene in the population
gene_ranges = []
for i in range(len(upper_limit)):
gene_ranges.append((lower_limit[i], upper_limit[i]))

# Initialize the population with initial_values or random values
initial_values = []

# If resuming evolution from a previous checkpoint
if opt.resume_evolve is not None:
assert os.path.isfile(ROOT / opt.resume_evolve), 'evolve population path is wrong!'
with open(ROOT / opt.resume_evolve, errors='ignore') as f:
evolve_population = yaml.safe_load(f)
for value in evolve_population.values():
value = np.array([value[k] for k in hyp_GA.keys()])
initial_values.append(list(value))

# If not resuming from a previous checkpoint, generate initial values from .yaml files in opt.evolve_population
else:
yaml_files = [f for f in os.listdir(opt.evolve_population) if f.endswith('.yaml')]
for file_name in yaml_files:
with open(os.path.join(opt.evolve_population, file_name)) as yaml_file:
value = yaml.safe_load(yaml_file)
value = np.array([value[k] for k in hyp_GA.keys()])
initial_values.append(list(value))

# Generate random values within the search space for the rest of the population
if (initial_values is None):
population = [generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size)]
else:
if (pop_size > 1):
population = [
generate_individual(gene_ranges, len(hyp_GA)) for i in range(pop_size - len(initial_values))]
for initial_value in initial_values:
population = [initial_value] + population

# Run the genetic algorithm for a fixed number of generations
list_keys = list(hyp_GA.keys())
for generation in range(opt.evolve):
if (generation >= 1):
save_dict = {}
for i in range(len(population)):
little_dict = {}
for j in range(len(population[i])):
little_dict[list_keys[j]] = float(population[i][j])
save_dict['gen' + str(generation) + 'number' + str(i)] = little_dict

with open(save_dir / 'evolve_population.yaml', 'w') as outfile:
yaml.dump(save_dict, outfile, default_flow_style=False)

# Adaptive elite size
elite_size = min_elite_size + int((max_elite_size - min_elite_size) * (generation / opt.evolve))
# Evaluate the fitness of each individual in the population
fitness_scores = []
for individual in population:
for key, value in zip(hyp_GA.keys(), individual):
hyp_GA[key] = value
hyp.update(hyp_GA)
results = train(hyp.copy(), opt, device, callbacks)
callbacks = Callbacks()
# Write mutation results
keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
'val/box_loss', 'val/obj_loss', 'val/cls_loss')
print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
fitness_scores.append(results[2])

# Select the fittest individuals for reproduction using adaptive tournament selection
selected_indices = []
for i in range(pop_size - elite_size):
# Adaptive tournament size
tournament_size = max(max(2, tournament_size_min),
int(min(tournament_size_max, pop_size) - (generation / (opt.evolve / 10))))
# Perform tournament selection to choose the best individual
tournament_indices = random.sample(range(pop_size), tournament_size)
tournament_fitness = [fitness_scores[j] for j in tournament_indices]
winner_index = tournament_indices[tournament_fitness.index(max(tournament_fitness))]
selected_indices.append(winner_index)

# Add the elite individuals to the selected indices
elite_indices = [i for i in range(pop_size) if fitness_scores[i] in sorted(fitness_scores)[-elite_size:]]
selected_indices.extend(elite_indices)
# Create the next generation through crossover and mutation
next_generation = []
for i in range(pop_size):
parent1_index = selected_indices[random.randint(0, pop_size - 1)]
parent2_index = selected_indices[random.randint(0, pop_size - 1)]
# Adaptive crossover rate
crossover_rate = max(crossover_rate_min,
min(crossover_rate_max, crossover_rate_max - (generation / opt.evolve)))
if random.uniform(0, 1) < crossover_rate:
crossover_point = random.randint(1, len(hyp_GA) - 1)
child = population[parent1_index][:crossover_point] + population[parent2_index][crossover_point:]
else:
child = population[parent1_index]
# Adaptive mutation rate
mutation_rate = max(mutation_rate_min,
min(mutation_rate_max, mutation_rate_max - (generation / opt.evolve)))
for j in range(len(hyp_GA)):
if random.uniform(0, 1) < mutation_rate:
child[j] += random.uniform(-0.1, 0.1)
child[j] = min(max(child[j], gene_ranges[j][0]), gene_ranges[j][1])
next_generation.append(child)
# Replace the old population with the new generation
population = next_generation
# Print the best solution found
best_index = fitness_scores.index(max(fitness_scores))
best_individual = population[best_index]
print('Best solution found:', best_individual)
# Plot results
plot_evolve(evolve_csv)
LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
f"Results saved to {colorstr('bold', save_dir)}\n"
f'Usage example: $ python train.py --hyp {evolve_yaml}')


def generate_individual(input_ranges, individual_length):
individual = []
for i in range(individual_length):
lower_bound, upper_bound = input_ranges[i]
individual.append(random.uniform(lower_bound, upper_bound))
return individual


def run(**kwargs):
# Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
opt = parse_opt(True)
Expand Down

0 comments on commit 66edf38

Please sign in to comment.