Skip to content

Commit

Permalink
fix #290
Browse files Browse the repository at this point in the history
  • Loading branch information
lacava committed Aug 17, 2023
1 parent e82b79b commit f0934c5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
6 changes: 4 additions & 2 deletions feat/feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def predict(self,X,Z=None):
else:
return self.cfeat_.predict(X)

def predict_archive(self,X,Z=None):
def predict_archive(self,X,Z=None,front=False):
"""Returns a list of dictionary predictions for all models."""
if not self.is_fitted_:
raise ValueError("Call fit before calling predict.")
Expand All @@ -313,9 +313,11 @@ def predict_archive(self,X,Z=None):
raise NotImplementedError('longitudinal not implemented')
return

archive = self.cfeat_.get_archive(False)
archive = self.cfeat_.get_archive(front)
preds = []
for ind in archive:
if ind['id'] == 9234:
print('individual:',json.dumps(ind,indent=2))
tmp = {}
tmp['id'] = ind['id']
tmp['y_pred'] = self.cfeat_.predict_archive(ind['id'], X)
Expand Down
2 changes: 1 addition & 1 deletion feat/versionstr.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__="0.5.2.post115"
__version__="0.5.2.post116"
14 changes: 13 additions & 1 deletion src/feat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,17 +1203,29 @@ VectorXf Feat::predict_archive(int id, MatrixXf& X, LongData& Z)
Data tmp_data(X,empty_y,Z);

/* cout << "individual prediction id " << id << "\n"; */
if (id == best_ind.id)
{
return best_ind.predict_vector(tmp_data);
}
for (int i = 0; i < this->archive.individuals.size(); ++i)
{
Individual& ind = this->archive.individuals.at(i);

if (id == ind.id)
return ind.predict_vector(tmp_data);

}
for (int i = 0; i < this->pop.individuals.size(); ++i)
{
Individual& ind = this->pop.individuals.at(i);

if (id == ind.id)
return ind.predict_vector(tmp_data);

}

THROW_INVALID_ARGUMENT("Could not find id = "
+ to_string(id) + "in archive.");
+ to_string(id) + "in archive or population.");
return VectorXf();
}

Expand Down

0 comments on commit f0934c5

Please sign in to comment.