
[DEEP-180] Add mmsegmentation integration #2875

Merged
merged 15 commits on Jun 21, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -58,7 +58,7 @@ Deep Lake comes with built-in dataloaders for Pytorch and TensorFlow. Train your
</details>
<details>
<summary><b>Integrations with Powerful Tools</b></summary>
Deep Lake has integrations with <a href="https://github.com/hwchase17/langchain">Langchain</a> and <a href="https://github.com/jerryjliu/llama_index">LLamaIndex</a> as a vector store for LLM apps, <a href="https://wandb.ai/">Weights & Biases</a> for data lineage during model training, and <a href="https://github.com/open-mmlab/mmdetection">MMDetection</a> for training object detection models.
Deep Lake has integrations with <a href="https://github.com/hwchase17/langchain">Langchain</a> and <a href="https://github.com/jerryjliu/llama_index">LLamaIndex</a> as a vector store for LLM apps, <a href="https://wandb.ai/">Weights & Biases</a> for data lineage during model training, <a href="https://github.com/open-mmlab/mmdetection">MMDetection</a> for training object detection models, and <a href="https://github.com/open-mmlab/mmsegmentation">MMSegmentation</a> for training semantic segmentation models.
</details>
<details>
<summary><b>100+ most-popular image, video, and audio datasets available in seconds</b></summary>
3 changes: 0 additions & 3 deletions deeplake/client/client.py
@@ -101,9 +101,6 @@ def request(
headers (dict, optional): Dictionary of HTTP Headers to send with the request.
timeout (float,optional): How many seconds to wait for the server to send data before giving up.

Raises:
InvalidPasswordException: `password` cannot be `None` inside `json`.

Returns:
requests.Response: The response received from the server.
"""
Empty file.
154 changes: 154 additions & 0 deletions deeplake/integrations/mm/mm_common.py
@@ -0,0 +1,154 @@
import os
import torch
import mmcv # type: ignore
import deeplake as dp
from deeplake.util.warnings import always_warn
from deeplake.util.exceptions import EmptyTokenException
from deeplake.client.config import DEEPLAKE_AUTH_TOKEN



def ddp_setup(rank: int, world_size: int, port: int):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
        port: Port number
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )
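
For reference, a minimal sketch of how `ddp_setup` is typically driven: one process per GPU is spawned with `torch.multiprocessing.spawn`, and each worker calls `ddp_setup` with its rank before building the model. The `train_worker` name and the port value below are placeholders, not part of this PR.

```python
import torch
import torch.multiprocessing as mp

from deeplake.integrations.mm.mm_common import ddp_setup


def train_worker(rank: int, world_size: int, port: int):
    # Each spawned process joins the NCCL process group for its rank ...
    ddp_setup(rank, world_size, port)
    # ... and would then build its model/dataloader and run the training loop.


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # spawn() passes the process index as the first argument (rank).
    mp.spawn(train_worker, args=(world_size, 29500), nprocs=world_size)
```
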


def force_cudnn_initialization(device_id):
    dev = torch.device(f"cuda:{device_id}")
    torch.nn.functional.conv2d(
        torch.zeros(32, 32, 32, 32, device=dev), torch.zeros(32, 32, 32, 32, device=dev)
    )
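
A hypothetical usage note: calling this once per worker warms up cuDNN on the target GPU, so the first real convolution during training does not absorb the one-off cuDNN initialization cost.

```python
# Hypothetical call site: warm up cuDNN on GPU 0 before model construction.
force_cudnn_initialization(device_id=0)
```
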


def load_ds_from_cfg(cfg: mmcv.utils.config.ConfigDict):
    creds = cfg.get("deeplake_credentials", {})
    token = creds.get("token", None)
    token = token or os.environ.get(DEEPLAKE_AUTH_TOKEN)
    if token is None:
        raise EmptyTokenException()

    ds_path = cfg.deeplake_path
    ds = dp.load(ds_path, token=token, read_only=True)
    deeplake_commit = cfg.get("deeplake_commit")
    deeplake_view_id = cfg.get("deeplake_view_id")
    deeplake_query = cfg.get("deeplake_query")

    if deeplake_view_id and deeplake_query:
        raise Exception(
            "A query and view_id were specified simultaneously for a dataset in the config. Please specify either the deeplake_query or the deeplake_view_id."
        )

    if deeplake_commit:
        ds.checkout(deeplake_commit)

    if deeplake_view_id:
        ds = ds.load_view(id=deeplake_view_id)

    if deeplake_query:
        ds = ds.query(deeplake_query)

    return ds
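
For context, a minimal sketch of the deeplake-specific keys this helper reads from the config dict it receives (exactly where this dict sits in the full mmcv config depends on the calling integration); all values below are placeholders.

```python
# Hypothetical config fragment read by load_ds_from_cfg; every value is a placeholder.
deeplake_path = "hub://my-org/my-segmentation-dataset"
deeplake_credentials = dict(
    token=None,  # falls back to the environment variable named by DEEPLAKE_AUTH_TOKEN
)
# Optional pinning of the data; deeplake_view_id and deeplake_query are mutually exclusive.
deeplake_commit = None
deeplake_view_id = None
deeplake_query = None
```
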


def get_collect_keys(cfg):
    pipeline = cfg.train_pipeline
    for transform in pipeline:
        if transform["type"] == "Collect":
            return transform["keys"]
    raise ValueError("collection keys were not specified")
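
A hedged example of the pipeline shape this expects, following MMCV 1.x conventions: with a `train_pipeline` like the one below, `get_collect_keys(cfg)` would return `["img", "gt_semantic_seg"]`. The transform names are standard MMSegmentation pipeline steps, not something defined in this PR.

```python
# Hypothetical train_pipeline excerpt in an MMSegmentation-style config.
train_pipeline = [
    dict(type="RandomFlip", prob=0.5),
    dict(type="PhotoMetricDistortion"),
    dict(type="Normalize", mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True),
    dict(type="DefaultFormatBundle"),
    dict(type="Collect", keys=["img", "gt_semantic_seg"]),
]
```
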


def check_persistent_workers(train_persistent_workers, val_persistent_workers):
    if train_persistent_workers != val_persistent_workers:
        if train_persistent_workers:
            always_warn(
                "persistent workers for training and evaluation should be identical; "
                "otherwise, this could lead to performance issues. "
                "Either both of them should be `True` or both of them should be `False`. "
                "If you want to use persistent workers, set it to True for validation as well."
            )
        else:
            always_warn(
                "persistent workers for training and evaluation should be identical; "
                "otherwise, this could lead to performance issues. "
                "Either both of them should be `True` or both of them should be `False`. "
                "If you want to use persistent workers, set it to True for training as well."
            )


def find_tensor_with_htype(ds: dp.Dataset, htype: str, mm_class=None):
    tensors = [k for k, v in ds.tensors.items() if v.meta.htype == htype]
    if mm_class is not None:
        always_warn(
            f"No deeplake tensor name specified for '{mm_class}' in config. Fetching it using htype '{htype}'."
        )
    if not tensors:
        always_warn(f"No tensor found with htype='{htype}'")
        return None
    t = tensors[0]
    if len(tensors) > 1:
        always_warn(f"Multiple tensors with htype='{htype}' found. Choosing '{t}'.")
    return t
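
A hypothetical usage sketch, assuming `ds` is a loaded `deeplake.Dataset`; "image" and "segment_mask" are existing Deep Lake htypes, while the `mm_class` labels below are only illustrative.

```python
# Resolve tensor names when the config does not name them explicitly.
images_tensor = find_tensor_with_htype(ds, "image", mm_class="img")
masks_tensor = find_tensor_with_htype(ds, "segment_mask", mm_class="gt_semantic_seg")
```
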


def check_unsupported_functionalities(cfg):
    check_unused_dataset_fields(cfg)
    check_unsupported_train_pipeline_fields(cfg, mode="train")
    check_unsupported_train_pipeline_fields(cfg, mode="val")
    check_dataset_augmentation_formats(cfg)


def check_unused_dataset_fields(cfg):
    if cfg.get("dataset_type"):
        always_warn(
            "The deeplake mmdet integration does not use dataset_type to work with the data and compute metrics. "
            "All deeplake datasets are in the same deeplake format. To specify a metrics format, use deeplake_metrics_format."
        )

    if cfg.get("data_root"):
        always_warn(
            "The deeplake mmdet integration does not use data_root; this input will be ignored."
        )
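
To make the warnings above concrete, a hypothetical top-level config fragment: `dataset_type` and `data_root` are ignored by the Deep Lake loaders, while `deeplake_metrics_format` is what actually selects the metrics format. The values below are placeholders; "COCO" is the value documented for the existing mmdet integration.

```python
# Hypothetical config fragment; the first two keys only trigger the warnings above.
dataset_type = "CityscapesDataset"  # ignored by the deeplake integration
data_root = "data/cityscapes/"      # ignored by the deeplake integration
deeplake_metrics_format = "COCO"    # used instead of dataset_type to pick the metrics format
```
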


def check_unsupported_train_pipeline_fields(cfg, mode="train"):
    transforms = cfg.data[mode].pipeline

    for transform in transforms:
        transform_type = transform.get("type")

        if transform_type == "LoadImageFromFile":
            always_warn(
                "LoadImageFromFile is going to be skipped because the deeplake mmdet integration does not use it"
            )

        if transform_type == "LoadAnnotations":
            always_warn(
                "LoadAnnotations is going to be skipped because the deeplake mmdet integration does not use it"
            )

        if transform_type == "Corrupt":
            raise Exception("Corrupt augmentation is not supported yet.")

        elif transform_type == "CopyPaste":  # TODO: @adolkhan resolve this
            raise Exception("CopyPaste augmentation is not supported yet.")

        elif transform_type == "CutOut":  # TODO: @adolkhan resolve this
            raise Exception("CutOut augmentation is not supported yet.")

        elif transform_type == "Mosaic":  # TODO: @adolkhan resolve this
            raise Exception("Mosaic augmentation is not supported yet.")

Check warning on line 147 in deeplake/integrations/mm/mm_common.py

View check run for this annotation

Codecov / codecov/patch

deeplake/integrations/mm/mm_common.py#L146-L147

Added lines #L146 - L147 were not covered by tests


def check_dataset_augmentation_formats(cfg):
    if cfg.get("train_dataset"):
        always_warn(
            "train_dataset is going to be unused. Dataset types like ConcatDataset, RepeatDataset, ClassBalancedDataset, and MultiImageMixDataset are not supported."
        )