Skip to content

Commit

Permalink
refactor: change classify_img->predict_img
Browse files Browse the repository at this point in the history
  • Loading branch information
martibosch committed Mar 28, 2024
1 parent 4f50349 commit 73ab0c4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
16 changes: 8 additions & 8 deletions detectree/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def __init__(

self.pixel_features_builder_kwargs = pixel_features_builder_kwargs

def _classify_img(self, img_filepath, clf, *, output_filepath=None):
def _predict_img(self, img_filepath, clf, *, output_filepath=None):
# ACHTUNG: Note that we do not use keyword-only arguments in this method because
# `output_filepath` works as the only "optional" argument
src = rio.open(img_filepath)
Expand Down Expand Up @@ -403,7 +403,7 @@ def _classify_img(self, img_filepath, clf, *, output_filepath=None):
src.close()
return y_pred

def _classify_imgs(self, img_filepaths, clf, output_dir):
def _predict_imgs(self, img_filepaths, clf, output_dir):
pred_imgs_lazy = []
pred_img_filepaths = []
for img_filepath in img_filepaths:
Expand All @@ -412,7 +412,7 @@ def _classify_imgs(self, img_filepaths, clf, output_dir):
# output_dir, f"{filename}-pred{ext}")
pred_img_filepath = path.join(output_dir, path.basename(img_filepath))
pred_imgs_lazy.append(
dask.delayed(self._classify_img)(
dask.delayed(self._predict_img)(
img_filepath, clf, output_filepath=pred_img_filepath
)
)
Expand All @@ -423,7 +423,7 @@ def _classify_imgs(self, img_filepaths, clf, output_dir):

return pred_img_filepaths

def classify_img(self, img_filepath, *, img_cluster=None, output_filepath=None):
def predict_img(self, img_filepath, *, img_cluster=None, output_filepath=None):
"""
Use a trained classifier to predict tree pixels in an image.
Expand Down Expand Up @@ -460,11 +460,11 @@ def classify_img(self, img_filepath, *, img_cluster=None, output_filepath=None):
# return self._classify_img(
# img_filepath, clf, output_filepath=output_filepath
# )
return self._classify_img(
return self._predict_img(
img_filepath, self.clf, output_filepath=output_filepath
)

def classify_imgs(self, split_df, output_dir):
def predict_imgs(self, split_df, output_dir):
"""
Use trained classifier(s) to predict tree pixels in multiple images.
Expand All @@ -484,7 +484,7 @@ def classify_imgs(self, split_df, output_dir):
File paths of the dumped tiles.
"""
if hasattr(self, "clf"):
return self._classify_imgs(
return self._predict_imgs(
split_df[~split_df["train"]]["img_filepath"], self.clf, output_dir
)
else:
Expand All @@ -498,7 +498,7 @@ def classify_imgs(self, split_df, output_dir):
f"Classifier for cluster {img_cluster} not found in"
" `self.clf_dict`."
)
pred_imgs[img_cluster] = self._classify_imgs(
pred_imgs[img_cluster] = self._predict_imgs(
utils.get_img_filepaths(split_df, img_cluster, False),
clf,
output_dir,
Expand Down
8 changes: 4 additions & 4 deletions detectree/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def train_classifiers(
@click.option("--refine-int-rescale", type=int)
@click.option("--pixel-features-builder-kwargs", cls=_OptionEatAll)
@click.option("--output-filepath", type=click.Path())
def classify_img(
def predict_img(
ctx,
img_filepath,
clf_filepath,
Expand Down Expand Up @@ -354,7 +354,7 @@ def classify_img(
filename, ext = path.splitext(path.basename(img_filepath))
output_filepath = f"{filename}-pred{ext}"

c.classify_img(
c.predict_img(
img_filepath,
output_filepath=output_filepath,
)
Expand All @@ -373,7 +373,7 @@ def classify_img(
@click.option("--refine-int-rescale", type=int)
@click.option("--pixel-features-builder-kwargs", cls=_OptionEatAll)
@click.option("--output-dir", type=click.Path(exists=True))
def classify_imgs(
def predict_imgs(
ctx,
split_filepath,
clf_filepath,
Expand Down Expand Up @@ -434,7 +434,7 @@ def classify_imgs(
if output_dir is None:
output_dir = ""

pred_imgs = c.classify_imgs(
pred_imgs = c.predict_imgs(
split_df,
output_dir,
)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_detectree.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,11 +585,11 @@ def test_classifier(self):
]:
img_filepath = self.split_i_df.iloc[0]["img_filepath"]
# test that `classify_img` returns a ndarray
self.assertIsInstance(c.classify_img(img_filepath), np.ndarray)
self.assertIsInstance(c.predict_img(img_filepath), np.ndarray)
# test that `classify_img` with `output_filepath` returns a ndarray and
# dumps it
output_filepath = path.join(self.tmp_output_dir, "foo.tif")
y_pred = c.classify_img(img_filepath, output_filepath=output_filepath)
y_pred = c.predict_img(img_filepath, output_filepath=output_filepath)
self.assertIsInstance(y_pred, np.ndarray)
self.assertTrue(os.path.exists(output_filepath))
# remove it so that the output dir is clean in the tests below
Expand All @@ -599,7 +599,7 @@ def test_classifier(self):
# dumped. This works regardless of whether a "img_cluster" column is present
# in the split data frame - since it is ignored for "cluster-I"
for split_df in [self.split_i_df, self.split_ii_df]:
pred_imgs = c.classify_imgs(split_df, self.tmp_output_dir)
pred_imgs = c.predict_imgs(split_df, self.tmp_output_dir)
self.assertIsInstance(pred_imgs, list)
self._test_imgs_exist_and_rm(pred_imgs)

Expand All @@ -610,22 +610,22 @@ def test_classifier(self):
]:
img_filepath = self.split_i_df.iloc[0]["img_filepath"]
# test that `classify_img` returns a ndarray
self.assertIsInstance(c.classify_img(img_filepath), np.ndarray)
self.assertIsInstance(c.predict_img(img_filepath), np.ndarray)

# "cluster-II"
c = dtr.Classifier(clf_dict=self.clf_dict)
# `classify_imgs` should raise a `KeyError` if `split_df` doesn't have a
# "img_cluster" column
self.assertRaises(
KeyError, c.classify_imgs, self.split_i_df, self.tmp_output_dir
KeyError, c.predict_imgs, self.split_i_df, self.tmp_output_dir
)
# otherwise it should return a list and dump the images (regardless of the
# `refine` value
for c in [
dtr.Classifier(clf_dict=self.clf_dict, refine=refine)
for refine in [True, False]
]:
pred_imgs = c.classify_imgs(self.split_ii_df, self.tmp_output_dir)
pred_imgs = c.predict_imgs(self.split_ii_df, self.tmp_output_dir)
self.assertIsInstance(pred_imgs, dict)
for img_cluster in pred_imgs:
self._test_imgs_exist_and_rm(pred_imgs[img_cluster])
Expand Down Expand Up @@ -787,9 +787,9 @@ def test_train_classifiers(self):
)
self.assertEqual(result.exit_code, 0)

def test_classify_img(self):
def test_predict_img(self):
base_args = [
"classify-img",
"predict-img",
glob.glob(path.join(self.img_dir, "*.tif"))[0],
"--output-filepath",
path.join(self.tmp_dir, "foo.tif"),
Expand All @@ -798,8 +798,8 @@ def test_classify_img(self):
result = self.runner.invoke(main.cli, base_args + extra_args)
self.assertEqual(result.exit_code, 0)

def test_classify_imgs(self):
base_args = ["classify-imgs", self.split_ii_filepath]
def test_predict_imgs(self):
base_args = ["predict-imgs", self.split_ii_filepath]
final_args = [
"--output-dir",
self.tmp_dir,
Expand Down

0 comments on commit 73ab0c4

Please sign in to comment.