Improve crossover and neat parameters

This commit is contained in:
Felix Bargfeldt 2019-05-16 21:36:45 +02:00
parent d9c23e6674
commit aa33733315
No known key found for this signature in database
GPG key ID: 99184F5FDC589A67
4 changed files with 14 additions and 14 deletions

View file

@ -43,13 +43,13 @@ class Genome:
trait: ConnectionGene = random.choice([gene_map1, gene_map2])[key].copy()
if (not gene_map1[key].enabled) or (not gene_map2[key].enabled) and random.random() < 0.75:
trait.enabled = False
elif parent1.fitness == parent2.fitness:
trait: ConnectionGene = random.choice([gene_map1, gene_map2]).get(key, None)
elif key in gene_map1:
trait: ConnectionGene = gene_map1[key]
else:
trait: ConnectionGene = gene_map1.get(key, None)
# continue
trait: ConnectionGene = gene_map2[key]
if trait is not None:
child.connection_gene_list.append(trait.copy())
child.connection_gene_list.append(trait.copy())
return child
@ -137,11 +137,10 @@ class Genome:
def mutate_weight(self):
for connection in self.connection_gene_list:
if random.random() < WEIGHT_CHANCE:
if random.random() < PERTURB_CHANCE:
connection.weight += (2 * random.random() - 1) * STEPS
else:
connection.weight = 4 * random.random() - 2
if random.random() < PERTURB_CHANCE:
connection.weight += random.gauss(0, 1)
else:
connection.weight = 4 * random.random() - 2
def mutate_add_connection(self):
self.generate_network()

View file

@ -1,4 +1,4 @@
COMPATIBILITY_THRESHOLD: float = 1
COMPATIBILITY_THRESHOLD: float = 3
EXCESS_COEFFICIENT: float = 1
DISJOINT_COEFFICIENT: float = 1
WEIGHT_COEFFICIENT: float = 0.4

View file

@ -23,7 +23,7 @@ class Species:
def remove_weak_genomes(self):
self.genomes.sort(key=lambda genome: genome.fitness, reverse=True)
survive_count: int = math.ceil(len(self.genomes) / 2)
survive_count: int = math.ceil(len(self.genomes) / 10)
self.genomes: List[Genome] = self.genomes[:survive_count]
def get_top_genome(self) -> Genome:

5
xor.py
View file

@ -7,11 +7,12 @@ from pool import Pool
evaluator: Evaluator = Evaluator([([i, j], [i ^ j]) for i in range(2) for j in range(2)])
pool: Pool = Pool(300, 2, 1)
generation: int = 1
top_genome: Genome = None
while True:
pool.evaluate_fitness(evaluator.evaluate)
top_genome: Genome = pool.get_top_genome()
print(f"Generation {generation} | Top Fitness: {top_genome.fitness} | Species: {len(pool.species)}")
if top_genome.fitness > 3.99:
print(f"Generation {generation} | Top Fitness: {top_genome.fitness} | Species: {len(pool.species)} | Total Genomes: {sum(len(s.genomes) for s in pool.species)}")
if top_genome.fitness > 3.9:
break
pool.breed_new_generation()
generation += 1