
Commit

fixing Win failed import (#1163)
* version

* try fix distrib

* update try import
Borda committed Mar 17, 2020
1 parent 49d000c · commit e461ec0
Showing 8 changed files with 18 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/rebase.yml
@@ -2,8 +2,8 @@ name: Automatic Rebase
 # https://github.com/marketplace/actions/automatic-rebase

 on:
-  issue_comment:
-    types: [created]
+- pull_request

 jobs:
   rebase:
     name: Rebase
2 changes: 1 addition & 1 deletion pytorch_lightning/__init__.py
@@ -1,6 +1,6 @@
 """Root package info."""

-__version__ = '0.7.1'
+__version__ = '0.7.2-dev'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -22,10 +22,10 @@

 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True


 class ModelHooks(torch.nn.Module):
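The hunk above is the core of the "update try import" change: the availability flag now lives in the `else:` clause, so `except ImportError` guards nothing but the import line itself. A minimal, self-contained sketch of the pattern, not part of this diff; `optional_dep` is a placeholder module name used purely for illustration:

```python
# Sketch of the optional-import pattern adopted in this commit.
# `optional_dep` is a hypothetical module name, not a real package.
try:
    import optional_dep  # keep only the import inside the try block
except ImportError:
    OPTIONAL_DEP_AVAILABLE = False
else:
    # this branch runs only when the try block raised nothing, so the flag
    # cannot be left unset by an unrelated error after a successful import
    OPTIONAL_DEP_AVAILABLE = True

print(f"optional_dep available: {OPTIONAL_DEP_AVAILABLE}")
```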
8 changes: 4 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -9,7 +9,7 @@

 import torch
 from torch import Tensor
-from torch.distributed import init_process_group
+import torch.distributed as torch_distrib
 from torch.nn.parallel import DistributedDataParallel
 from torch.optim import Adam
 from torch.optim.optimizer import Optimizer
@@ -24,10 +24,10 @@

 try:
     import torch_xla.core.xla_model as xm
-    XLA_AVAILABLE = True
-
 except ImportError:
     XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True


 class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
@@ -859,7 +859,7 @@ def init_ddp_connection(self):

         root_node = self.trainer.resolve_root_node_address(root_node)
         os.environ['MASTER_ADDR'] = root_node
-        init_process_group('nccl', rank=proc_rank, world_size=world_size)
+        torch_distrib.init_process_group('nccl', rank=proc_rank, world_size=world_size)

     def configure_apex(
             self,
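Aliasing the module (`import torch.distributed as torch_distrib`) instead of importing `init_process_group` by name defers the attribute lookup to call time, which is presumably what fixes the Windows import failure: on builds where the distributed package is incomplete, the module still imports, and the missing symbol only matters if DDP is actually started. A rough sketch of a guarded deferred call under that assumption; the `init_ddp` helper and the default address/port are illustrative, not Lightning's actual code:

```python
import os

import torch
import torch.distributed as torch_distrib  # module alias; attributes resolved at call time


def init_ddp(proc_rank: int, world_size: int, backend: str = "gloo") -> None:
    """Illustrative helper: start a process group only if this build supports it."""
    if not torch.distributed.is_available():
        raise RuntimeError("this PyTorch build has no distributed support")
    # the default env:// init method reads these environment variables
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch_distrib.init_process_group(backend, rank=proc_rank, world_size=world_size)
```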
7 changes: 3 additions & 4 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -1,13 +1,12 @@
-
+import logging as log
 from abc import ABC

 try:
     from apex import amp
-
-    APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
-import logging as log
+else:
+    APEX_AVAILABLE = True


 class TrainerAMPMixin(ABC):
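At import time `TrainerAMPMixin` only needs the `APEX_AVAILABLE` flag; `amp` itself is touched later, when mixed precision is actually requested. A hedged sketch of how such a flag is typically consumed; the `configure_amp` helper is an illustration, not the trainer's real method, and the flag setup is repeated only so the snippet runs standalone:

```python
import torch

try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False
else:
    APEX_AVAILABLE = True


def configure_amp(model: torch.nn.Module, optimizer: torch.optim.Optimizer, opt_level: str = "O1"):
    """Wrap model and optimizer with apex AMP when apex is installed, else fall back."""
    if not APEX_AVAILABLE:
        return model, optimizer  # full-precision fallback
    return amp.initialize(model, optimizer, opt_level=opt_level)
```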
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable

-import torch.distributed as dist
+import torch.distributed as torch_distrib
 from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler

@@ -224,7 +224,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
         # get the function we'll use to get data
         if self.use_ddp or self.use_ddp2:
             # all processes wait until data download has happened
-            dist.barrier()
+            torch_distrib.barrier()

         # data download/load on TPU
         elif self.use_tpu and XLA_AVAILABLE:
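The `barrier()` call keeps all DDP processes in lockstep so that, for example, a single rank can download a dataset while the others wait for it to appear on disk. A small sketch of that pattern under the new alias; `prepare_data_once` is an illustrative helper, not library code, and it assumes the process group is already initialised:

```python
import os

import torch.distributed as torch_distrib


def prepare_data_once(proc_rank: int, data_dir: str) -> None:
    """Illustrative: rank 0 materialises the data, every rank waits before reading it.

    Assumes torch_distrib.init_process_group(...) has already been called.
    """
    if proc_rank == 0:
        os.makedirs(data_dir, exist_ok=True)  # stand-in for the real download step
    torch_distrib.barrier()  # all ranks block here until rank 0 has finished
```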
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -8,7 +8,7 @@

 import torch
 from torch import optim
-import torch.distributed as dist
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -748,7 +748,7 @@ def run_pretrain_routine(self, model: LightningModule):
         self.logger.save()

         if self.use_ddp or self.use_ddp2:
-            dist.barrier()
+            torch_distrib.barrier()

         # wait for all models to restore weights
         if self.on_tpu and XLA_AVAILABLE:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/training_io.py
@@ -100,7 +100,7 @@
 from typing import Union

 import torch
-import torch.distributed as dist
+import torch.distributed as torch_distrib

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
@@ -177,7 +177,7 @@ def restore_weights(self, model):
         # wait for all models to restore weights
         if self.use_ddp or self.use_ddp2:
             # wait for all processes to catch up
-            dist.barrier()
+            torch_distrib.barrier()

         # wait for all models to restore weights
         if self.on_tpu and XLA_AVAILABLE:
