
jgrss/v170 #49

Merged: 312 commits merged into main on Mar 9, 2023

Conversation

@jgrss (Owner) commented Mar 3, 2023

This PR introduces changes toward v170.

# `num_classes` includes background
'count': 3 + num_classes - 1,
'dtype': 'uint16',
'blockxsize': 64 if 64 < src.gw.ncols else src.gw.ncols,

Collaborator suggestion:

min(64, src.gw.ncols)

'sharing': False,
'compress': compression
}
profile['tiled'] = True if max(profile['blockxsize'], profile['blockysize']) >= 16 else False

Collaborator:

The True if ... else False part isn't strictly needed. If clarity is the goal, it's probably better to wrap this expression in a function:

def is_tiled(blockxsize, blockysize, tile_limit=16):
    return max(blockxsize, blockysize) >= tile_limit
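
A small usage sketch tying this together with the min(64, src.gw.ncols) suggestion above (ncols and nrows are placeholders standing in for src.gw.ncols / src.gw.nrows, and blockysize is assumed by analogy with the snippet):

ncols, nrows = 50, 200             # placeholders for src.gw.ncols, src.gw.nrows
profile = {
    'blockxsize': min(64, ncols),  # 50: capped by the raster width
    'blockysize': min(64, nrows),  # 64
}
profile['tiled'] = is_tiled(profile['blockxsize'], profile['blockysize'])  # True (64 >= 16)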

)
rheight = pad_slice2d[0].stop - pad_slice2d[0].start
rwidth = pad_slice2d[1].stop - pad_slice2d[1].start
def reshaper(x: torch.Tensor, channel_dims: int) -> torch.Tensor:

Collaborator:

It might be good to introduce an autoformatter like black.

Owner Author (@jgrss):

good idea

train_data = joblib.load(train_path)
if train_data.train_id == train_id:
    batch_stored = True
aug_method = AugmenterMapping[aug.replace('-', '_')].value

Collaborator:

The augmenter names should be consistent: just use _ or - everywhere, or better yet, use enums.
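
For illustration, a hedged sketch of an enum-based mapping (the member names here are made up, not the project's actual augmenter names):

import enum


class AugmenterName(enum.Enum):
    # Hypothetical members; the values hold the user-facing hyphenated spelling
    ts_warp = 'ts-warp'
    gaussian_noise = 'gaussian-noise'


# Look up by value, so no aug.replace('-', '_') is needed at the call site
aug_method = AugmenterName('ts-warp')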

# Clip the edges to the current grid
try:
    grid_edges = gpd.clip(df_edges, row.geometry)
except:

Collaborator:

You may want to explicitly catch topology errors; otherwise you may emit misleading warnings when you run into other kinds of errors.
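
A minimal sketch of what that could look like, assuming Shapely 2.x (where shapely.errors.GEOSException covers GEOS/topology failures; older Shapely has shapely.errors.TopologicalError). df_edges and row.geometry correspond to the objects in the snippet above:

import warnings

import geopandas as gpd
from shapely.errors import GEOSException


def clip_to_grid(df_edges: gpd.GeoDataFrame, geometry) -> gpd.GeoDataFrame:
    """Clip edges to a grid geometry, warning only on topology failures."""
    try:
        return gpd.clip(df_edges, geometry)
    except GEOSException:
        warnings.warn('Could not clip edges to the grid (invalid geometry).')
        return df_edges.iloc[0:0]  # empty frame with the same schema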

window_pad
) for window, window_pad in window_chunk
)
pbar_total.update(len(window_chunk))


def create_dataset(

Collaborator:

This thing is getting very long. Probably a good idea to find logical chunks to wrap in functions. See here for some guidelines on how to tell when functions are getting too long: https://stackoverflow.com/questions/475675/when-is-a-function-too-long.

qt = QuadTree(df_unique_locations, force_square=False)
qt.split_recursive(max_samples=1)
n_val = int(val_frac * len(df_unique_locations.index))
df_val_sample = qt.sample(n=n_val)

Collaborator:

What does this do? Something regarding the spatial distribution of the validation set, but it's not totally clear to me.

Owner Author (@jgrss):

It's the spatially-balanced splitting method (see https://github.com/jgrss/geosample). I've added comments on each step to help clarify this.
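
For reference, a commented restatement of the snippet above (QuadTree comes from jgrss/geosample; df_unique_locations and val_frac are the variables already in scope there):

# 1) Build a quadtree over the unique sample locations.
qt = QuadTree(df_unique_locations, force_square=False)

# 2) Split recursively until each leaf holds at most one sample, so the tree
#    mirrors the spatial density of the points.
qt.split_recursive(max_samples=1)

# 3) Sample the validation set across the leaves, which spreads the validation
#    points evenly in space instead of letting them cluster.
n_val = int(val_frac * len(df_unique_locations.index))
df_val_sample = qt.sample(n=n_val)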

@@ -134,21 +248,61 @@ def tanimoto(y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
class TanimotoDistLoss(torch.nn.Module):

Collaborator:

Same note as above, this probably needs tests.
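
For instance, a minimal smoke test could look roughly like this (the import path and the (yhat, y) call signature are assumptions and would need to match the actual class):

import torch

from cultionet.losses import TanimotoDistLoss  # assumed import path


def test_tanimoto_dist_loss_returns_finite_scalar():
    torch.manual_seed(0)
    yhat = torch.rand(4, 2, 8, 8)                  # predicted probabilities
    y = torch.randint(0, 2, (4, 2, 8, 8)).float()  # binary targets
    loss = TanimotoDistLoss()(yhat, y)
    assert loss.ndim == 0
    assert torch.isfinite(loss)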

Comment on lines 321 to 326
# train_ds, val_ds = dataset.split_train_val_by_partition(
# spatial_partitions=spatial_partitions,
# partition_column=partition_column,
# val_frac=val_frac,
# partition_name=partition_name
# )

Collaborator:

can we remove this commented-out block?

Comment on lines 17 to 21
# assert dims in (2, 3)
# if dims == 2:
# ones = torch.ones((1, channels, 1, 1))
# else:
# ones = torch.ones((1, channels, 1, 1, 1))

Collaborator:

delete?

import enum


class ModelTypes(enum.Enum):

Collaborator:

This enum class might be useful here, or elsewhere in the codebase: https://docs.python.org/3/library/enum.html#enum.StrEnum

Works nicely when you want to map enum values to strings of their names.
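
A sketch of the pattern (enum.StrEnum requires Python 3.11+; the member names here are illustrative):

import enum


class ModelTypes(enum.StrEnum):
    # auto() gives each member its lowercased member name as the value
    UNET = enum.auto()     # 'unet'
    RESUNET = enum.auto()  # 'resunet'


assert ModelTypes.UNET == 'unet'  # members compare equal to plain strings
assert str(ModelTypes.RESUNET) == 'resunet'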

@jgrss mentioned this pull request on Mar 8, 2023
Comment on lines +30 to +33

class SetActivation(torch.nn.Module):
def __init__(
self,

Collaborator:

Nice small class! Now you can just run the following instead of reaching for torch.nn.{activation_type} directly:

SetActivation(activation_type, channels=out_channels, dims=2)

Comment on lines +460 to +468

def var(self, unbiased=True):
    mean = self.mean()[:, None]
    return self.integrate(
        lambda x: (x - mean).pow(2)
    ) / (self.count - (1 if unbiased else 0))

def std(self, unbiased=True):
    return self.var(unbiased=unbiased).sqrt()

Collaborator:

This is interesting! It's similar to torch.var.
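
For comparison, torch.var's unbiased flag applies the same Bessel correction (divide by N - 1 vs. N) that the unbiased argument above toggles; a tiny standalone sketch:

import torch

x = torch.rand(100, 3)

var_unbiased = torch.var(x, dim=0, unbiased=True)   # divides by N - 1
var_biased = torch.var(x, dim=0, unbiased=False)    # divides by N
std_unbiased = torch.std(x, dim=0, unbiased=True)

assert torch.allclose(std_unbiased, var_unbiased.sqrt())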

Comment on lines 824 to 826
if len(ts_list) <= 1:
    pbar.update(1)
    pbar.set_description('TS too short')

Collaborator:

so hypothetically, if I only have 20210101.tif and 20220101.tif in my features, this function will continue?

Comment on lines +1071 to +1073
def generate_model_graph(args):
from cultionet.models.convstar import StarRNN
from cultionet.models.nunet import ResUNet3Psi

Collaborator:

Was there a reason we import inside a function for this one? Is it because the imports are only relevant for this function and take a long time to import?

Owner Author (@jgrss) commented Mar 9, 2023:

> Is it because the imports are only relevant for this function

Yes -- this function only serves the purpose of creating .onnx files for viewing graphs. It's called in isolation, and there's no need for the imports elsewhere.

Comment on lines 1375 to 1376
with open(
    project_path / f"{args.process}_command_{now.strftime('%Y%m%d-%H%M')}.json", mode='w'

Collaborator:

👍

out_3_1=out_3_1,
out_2_2=out_2_2,
out_1_3=out_1_3
)

return out

Collaborator:

This file does seem verbose, with a lot of repeated blocks. Attempting to reduce repetition could be the subject of a future PR.

@@ -0,0 +1,798 @@
"""

Collaborator:

Lots of formatting stuff in this file. Linters will likely get it all.

jgrss and others added 8 commits March 9, 2023 11:01
* add flake8 to precommit

* add black and flake8 to pyproject.toml

* change flake8 repo

* add install test extras

* simplify checks

* black formatting

* created CONTRIBUTING file

* format

* format

* format

* sync names

* format

* format

* format

* remove unused function

* format

* moved line

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* format

* use StrEnum

* remove StrEnum

* add version comment

* format

* format

* fix: jgrss/refine (#58)

* format

* test

* add missing reshape

* remove edge temperature

* removed edge refine layer

* format

* format

* remove sigmoid

* remove temperature override

* increase lr

* fixed arg name

* add bash scripts

* update docstring

* fix: jgrss/refine (#59)

* format

* fix arg

* use all data for refinement

* add random sampler for refinement

* format

* format

* remove old arg

* format
@jgrss merged commit 4135817 into main on Mar 9, 2023
@jgrss deleted the jgrss/topo_v2_time_rnn_nobal_batch_act_dgm_test branch on March 9, 2023 18:09
Labels: bug, documentation, enhancement