From 7b78d8fba0a779c561124dd3f79fd67b341700c0 Mon Sep 17 00:00:00 2001 From: Cyprien Noel Date: Mon, 8 Aug 2016 17:31:47 -0700 Subject: [PATCH] Python multi-GPU --- include/caffe/blob.hpp | 1 + include/caffe/data_transformer.hpp | 2 +- include/caffe/layers/python_layer.hpp | 8 +- python/caffe/__init__.py | 5 +- python/caffe/_caffe.cpp | 215 +++++++++++++++++- python/caffe/pycaffe.py | 3 +- python/caffe/train.py | 278 +++++++++++++++++++++++ python/multi_gpu.py | 11 + src/caffe/blob.cpp | 6 + src/caffe/data_transformer.cpp | 5 +- src/caffe/layer_factory.cpp | 7 + src/caffe/layers/base_data_layer.cpp | 4 +- src/caffe/test/test_data_transformer.cpp | 14 +- 13 files changed, 530 insertions(+), 29 deletions(-) create mode 100644 python/caffe/train.py create mode 100644 python/multi_gpu.py diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index af360ac24bd..2f59471c29e 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -220,6 +220,7 @@ class Blob { void set_cpu_data(Dtype* data); const int* gpu_shape() const; const Dtype* gpu_data() const; + void set_gpu_data(Dtype* data); const Dtype* cpu_diff() const; const Dtype* gpu_diff() const; Dtype* mutable_cpu_data(); diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp index 97b4ee6a8c4..87c2c17eb9e 100644 --- a/include/caffe/data_transformer.hpp +++ b/include/caffe/data_transformer.hpp @@ -23,7 +23,7 @@ class DataTransformer { * @brief Initialize the Random number generations if needed by the * transformation. 
*/ - void InitRand(); + void InitRand(unsigned int seed); /** * @brief Applies the transformation defined in the data layer's diff --git a/include/caffe/layers/python_layer.hpp b/include/caffe/layers/python_layer.hpp index 66dbbdf13b8..ad6ae903ef6 100644 --- a/include/caffe/layers/python_layer.hpp +++ b/include/caffe/layers/python_layer.hpp @@ -20,10 +20,10 @@ class PythonLayer : public Layer { const vector*>& top) { // Disallow PythonLayer in MultiGPU training stage, due to GIL issues // Details: https://github.com/BVLC/caffe/issues/2936 - if (this->phase_ == TRAIN && Caffe::solver_count() > 1 - && !ShareInParallel()) { - LOG(FATAL) << "PythonLayer is not implemented in Multi-GPU training"; - } +// if (this->phase_ == TRAIN && Caffe::solver_count() > 1 +// && !ShareInParallel()) { +// LOG(FATAL) << "PythonLayer is not implemented in Multi-GPU training"; +// } self_.attr("param_str") = bp::str( this->layer_param_.python_param().param_str()); self_.attr("phase") = static_cast(this->phase_); diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py index 35868a403a3..47d65c9c634 100644 --- a/python/caffe/__init__.py +++ b/python/caffe/__init__.py @@ -1,8 +1,9 @@ -from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver -from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed +from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver, DataTransformer, Blob, NCCL, Timer +from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed, get_random, solver_count, set_solver_count, solver_rank, set_solver_rank from ._caffe import __version__ from .proto.caffe_pb2 import TRAIN, TEST from .classifier import Classifier from .detector import Detector from . 
import io from .net_spec import layers, params, NetSpec, to_proto +from .train import train \ No newline at end of file diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index bdee75acd6c..dc28ee99934 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -53,6 +53,16 @@ void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); } void set_random_seed(unsigned int seed) { Caffe::set_random_seed(seed); } +void InitLog(int level) { + FLAGS_logtostderr = 1; + FLAGS_minloglevel = level; + ::google::InitGoogleLogging(""); + ::google::InstallFailureSignalHandler(); +} +void Log(const string& s) { + LOG(INFO) << s; +} + // For convenience, check that input files can be opened, and raise an // exception that boost will send to Python if not (caffe could still crash // later if the input files are disturbed before they are actually used, but @@ -254,12 +264,12 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) { } template -class PythonCallback: public Solver::Callback { +class SolverCallback: public Solver::Callback { protected: bp::object on_start_, on_gradients_ready_; public: - PythonCallback(bp::object on_start, bp::object on_gradients_ready) + SolverCallback(bp::object on_start, bp::object on_gradients_ready) : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { } virtual void on_gradients_ready() { on_gradients_ready_(); @@ -271,7 +281,121 @@ class PythonCallback: public Solver::Callback { template void Solver_add_callback(Solver * solver, bp::object on_start, bp::object on_gradients_ready) { - solver->add_callback(new PythonCallback(on_start, on_gradients_ready)); + solver->add_callback(new SolverCallback(on_start, on_gradients_ready)); +} +// Seems boost cannot call the base method directly +void Solver_add_nccl(SGDSolver* solver, NCCL* nccl) { + solver->add_callback(nccl); +} +template +class NetCallback: public Net::Callback { + public: + explicit NetCallback(bp::object run) : run_(run) {} + + protected: + virtual void 
run(int layer) { + run_(layer); + } + bp::object run_; +}; +void Net_before_forward(Net* net, bp::object run) { + net->add_before_forward(new NetCallback(run)); +} +void Net_after_forward(Net* net, bp::object run) { + net->add_after_forward(new NetCallback(run)); +} +void Net_before_backward(Net* net, bp::object run) { + net->add_before_backward(new NetCallback(run)); +} +void Net_after_backward(Net* net, bp::object run) { + net->add_after_backward(new NetCallback(run)); +} +void Net_add_nccl(Net* net, NCCL* nccl) { + net->add_after_backward(nccl); +} + +// Transformer constructor for passing phase as int +shared_ptr > Transformer_Init( + const TransformationParameter& param, int phase) { + shared_ptr > t( + new DataTransformer(param, + static_cast(phase))); + return t; +} + +void Transform(DataTransformer* trans, + const string& str, + Blob* data, + Blob* label, + int index) { + Datum datum; + datum.ParseFromString(str); + vector shape(data->shape()); + shape[0] = 1; + Blob tmp(shape); + tmp.set_cpu_data(data->mutable_cpu_data() + data->offset(index)); + trans->Transform(datum, &tmp); + label->mutable_cpu_data()[label->offset(index)] = datum.label(); +} + +template +struct proto_pickle : bp::pickle_suite { + static bp::tuple getstate(const T& proto) { + return bp::make_tuple(proto.SerializeAsString()); + } + + static void setstate(T& proto, // NOLINT(runtime/references) + bp::tuple state) { + string s = bp::extract(state[0])(); + proto.ParseFromString(s); + } +}; + +struct blob_pickle : bp::pickle_suite { + // TODO also transfer cpu side through regular IPC + static bp::tuple getstate(const Blob& blob) { + string s1(sizeof(int) * blob.shape().size(), 0); + memcpy(&s1[0], &blob.shape()[0], s1.size()); // NOLINT(caffe/alt_fn) + + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, blob.gpu_data())); + CUDA_CHECK(cudaSetDevice(attributes.device)); + + cudaIpcMemHandle_t handle; + CUDA_CHECK(cudaIpcGetMemHandle(&handle, + 
reinterpret_cast(const_cast(blob.gpu_data())))); + string s2(CUDA_IPC_HANDLE_SIZE, 0); + memcpy(&s2[0], &handle, CUDA_IPC_HANDLE_SIZE); // NOLINT(caffe/alt_fn) + + return bp::make_tuple(s1, s2); + } + + static void setstate(Blob& blob, // NOLINT(runtime/references) + bp::tuple state) { + string s1 = bp::extract(state[0])(); + string s2 = bp::extract(state[1])(); + + vector shape(s1.size() / sizeof(int)); + memcpy(&shape[0], &s1[0], s1.size()); // NOLINT(caffe/alt_fn) + blob.Reshape(shape); + + cudaIpcMemHandle_t handle; + memcpy(&handle, &s2[0], CUDA_IPC_HANDLE_SIZE); // NOLINT(caffe/alt_fn) + Dtype* data; + CUDA_CHECK(cudaIpcOpenMemHandle(reinterpret_cast(&data), handle, + cudaIpcMemLazyEnablePeerAccess)); + blob.set_gpu_data(data); + } +}; + +int phase_as_int(LayerParameter* param) { + return static_cast(param->phase()); +} +void prefetch_to_gpu(Blob* blob) { + blob->gpu_data(); +} +void set_gpu_data(Blob* blob, Blob* source) { + blob->set_gpu_data(source->mutable_gpu_data()); } BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1); @@ -283,10 +407,17 @@ BOOST_PYTHON_MODULE(_caffe) { bp::scope().attr("__version__") = AS_STRING(CAFFE_VERSION); // Caffe utility functions + bp::def("init_log", &InitLog); + bp::def("log", &Log); bp::def("set_mode_cpu", &set_mode_cpu); bp::def("set_mode_gpu", &set_mode_gpu); bp::def("set_random_seed", &set_random_seed); + bp::def("get_random", &caffe_rng_rand); bp::def("set_device", &Caffe::SetDevice); + bp::def("solver_count", &Caffe::solver_count); + bp::def("set_solver_count", &Caffe::set_solver_count); + bp::def("solver_rank", &Caffe::solver_rank); + bp::def("set_solver_rank", &Caffe::set_solver_rank); bp::def("layer_type_list", &LayerRegistry::LayerTypeList); @@ -317,6 +448,7 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_internal_reference<>())) .add_property("layers", bp::make_function(&Net::layers, bp::return_internal_reference<>())) + .def("layer", bp::make_function(&Net::layer_by_name)) 
.add_property("_blob_names", bp::make_function(&Net::blob_names, bp::return_value_policy())) .add_property("_layer_names", bp::make_function(&Net::layer_names, @@ -330,11 +462,16 @@ BOOST_PYTHON_MODULE(_caffe) { bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >()) .def("save", &Net_Save) .def("save_hdf5", &Net_SaveHDF5) - .def("load_hdf5", &Net_LoadHDF5); + .def("load_hdf5", &Net_LoadHDF5) + .def("before_forward", &Net_before_forward) + .def("after_forward", &Net_after_forward) + .def("before_backward", &Net_before_backward) + .def("after_backward", &Net_after_backward) + .def("after_backward", &Net_add_nccl); BP_REGISTER_SHARED_PTR_TO_PYTHON(Net); bp::class_, shared_ptr >, boost::noncopyable>( - "Blob", bp::no_init) + "Blob", bp::init<>()) .add_property("shape", bp::make_function( static_cast& (Blob::*)() const>( @@ -350,7 +487,10 @@ BOOST_PYTHON_MODULE(_caffe) { .add_property("data", bp::make_function(&Blob::mutable_cpu_data, NdarrayCallPolicies())) .add_property("diff", bp::make_function(&Blob::mutable_cpu_diff, - NdarrayCallPolicies())); + NdarrayCallPolicies())) + .def_pickle(blob_pickle()) + .def("prefetch_to_gpu", &prefetch_to_gpu) + .def("set_gpu_data", &set_gpu_data); BP_REGISTER_SHARED_PTR_TO_PYTHON(Blob); bp::class_, shared_ptr >, @@ -359,10 +499,43 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_internal_reference<>())) .def("setup", &Layer::LayerSetUp) .def("reshape", &Layer::Reshape) - .add_property("type", bp::make_function(&Layer::type)); + .add_property("type", bp::make_function(&Layer::type)) + .add_property("layer_param", bp::make_function(&Layer::layer_param, + bp::return_value_policy())); BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer); - bp::class_("LayerParameter", bp::no_init); + bp::class_("SolverParameter", bp::init<>()) + .add_property("max_iter", &SolverParameter::max_iter) + .add_property("display", &SolverParameter::display) + .def_pickle(proto_pickle()); + bp::class_("TransformationParameter", bp::init<>()) + 
.add_property("crop_size", &TransformationParameter::crop_size); + bp::class_("DataParameter", bp::init<>()) + .add_property("batch_size", &DataParameter::batch_size) + .add_property("source", bp::make_function(&DataParameter::source, + bp::return_value_policy())) + .add_property("backend", &DataParameter::backend) + .def_pickle(proto_pickle()); + bp::class_("MemoryDataParameter", bp::init<>()) + .add_property("batch_size", &MemoryDataParameter::batch_size) + .add_property("channels", &MemoryDataParameter::channels) + .add_property("height", &MemoryDataParameter::height) + .add_property("width", &MemoryDataParameter::width) + .def_pickle(proto_pickle()); + bp::class_("LayerParameter", bp::init<>()) + .add_property("name", bp::make_function(&LayerParameter::name, + bp::return_value_policy())) + .add_property("phase", &phase_as_int) + .add_property("top_size", &LayerParameter::top_size) + .add_property("transform_param", + bp::make_function(&LayerParameter::transform_param, + bp::return_value_policy())) + .add_property("data_param", bp::make_function(&LayerParameter::data_param, + bp::return_value_policy())) + .add_property("memory_data_param", + bp::make_function(&LayerParameter::memory_data_param, + bp::return_value_policy())) + .def_pickle(proto_pickle()); bp::class_, shared_ptr >, boost::noncopyable>( "Solver", bp::no_init) @@ -371,11 +544,14 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_internal_reference<>())) .add_property("iter", &Solver::iter) .def("add_callback", &Solver_add_callback) + .def("add_callback", &Solver_add_nccl) .def("solve", static_cast::*)(const char*)>( &Solver::Solve), SolveOverloads()) .def("step", &Solver::Step) .def("restore", &Solver::Restore) - .def("snapshot", &Solver::Snapshot); + .def("snapshot", &Solver::Snapshot) + .add_property("param", bp::make_function(&Solver::param, + bp::return_value_policy())); BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver); bp::class_, bp::bases >, @@ -419,6 +595,27 @@ BOOST_PYTHON_MODULE(_caffe) { bp::class_ 
>("BoolVec") .def(bp::vector_indexing_suite >()); + bp::class_, shared_ptr >, + boost::noncopyable>("DataTransformer", bp::no_init) + .def("__init__", bp::make_constructor(&Transformer_Init)) + .def("init_rand", &DataTransformer::InitRand) + .def("transform", &Transform); + BP_REGISTER_SHARED_PTR_TO_PYTHON(DataTransformer); + + bp::class_, shared_ptr >, + boost::noncopyable>("NCCL", + bp::init >, const string&>()) + .def("new_uid", &NCCL::new_uid).staticmethod("new_uid") + .def("bcast", &NCCL::bcast); + BP_REGISTER_SHARED_PTR_TO_PYTHON(NCCL); + + bp::class_, boost::noncopyable>( + "Timer", bp::init<>()) + .def("start", &Timer::Start) + .def("stop", &Timer::Stop) + .add_property("ms", &Timer::MilliSeconds); + BP_REGISTER_SHARED_PTR_TO_PYTHON(Timer); + // boost python expects a void (missing) return value, while import_array // returns NULL for python3. import_array1() forces a void return value. import_array1(); diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5bae18d9a4d..4c0c0a7765b 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -11,7 +11,8 @@ import numpy as np from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \ - RMSPropSolver, AdaDeltaSolver, AdamSolver + RMSPropSolver, AdaDeltaSolver, AdamSolver, DataTransformer, \ + Blob, NCCL, Timer import caffe.io import six diff --git a/python/caffe/train.py b/python/caffe/train.py new file mode 100644 index 00000000000..04383de198e --- /dev/null +++ b/python/caffe/train.py @@ -0,0 +1,278 @@ +""" +Trains a model using one or more GPUs. Existing solver and model params can +be used, with the only modification of changing the data layer type from 'Data' +to 'Python', and adding a python_params section like this: + + python_param { + module: "caffe.train" + layer: "DataLayer" + } + +Other sections like transform_param and data_param will be interpreted by the +python layer and can be used as is. LMDB is currently the only supported backend. 
+""" +from Queue import Empty +from multiprocessing import Process, Queue + +import lmdb, time +import caffe + + +def train( + solver, # solver proto definition + snapshot, # network snapshot to restore + gpus, # set of GPUs to run on + layer_wise_reduce=False, # overlaps training and transfers + data_layer_count=1, # For nets with multiple data layers. TODO get from net proto + prefetchers_per_gpu=2, # 0 to disable prefetching + prefetch_queue_size=4, # more can improve perf, but needs GPU memory + timing=False, # show timing info for compute and communications +): + _train(solver=solver, snapshot=snapshot, gpus=gpus, + layer_wise_reduce=layer_wise_reduce, + data_layer_count=data_layer_count, + prefetchers_per_gpu=prefetchers_per_gpu, + prefetch_queue_size=prefetch_queue_size, + timing=timing) + + +def _train(**kwargs): + # NCCL uses a uid to identify the session + uid = caffe.NCCL.new_uid() + + procs = [] + for rank in range(len(kwargs['gpus'])): + queues = [] + for _ in range(kwargs['data_layer_count']): + for _ in range(kwargs['prefetchers_per_gpu'] * 2): # train + test + queues.append(QueuePair()) + + p = Process(target=solve, args=(uid, rank, queues, kwargs)) + p.daemon = True + p.start() + procs.append(p) + + for i in range(len(queues)): + p = Process(target=fill, args=(queues[i], rank, i, kwargs)) + p.daemon = True + p.start() + + for p in procs: + p.join() + + +class DataLayer(caffe.Layer): + """ + Python version of Caffe's data layer. It reads transform_param to apply the + same transforms as the original. If prefetching is enabled, loading and transform + are done in separate processes, and retrieved through zero-copy CUDA IPC. 
+ """ + + def init(self, queues, kwargs): + self.queues = queues + self.queue = 0 + self.index = -1 + self.kwargs = kwargs + if kwargs['prefetchers_per_gpu'] == 0: + self.source = DataSource(caffe.solver_rank(), 0, + self.layer_param, + caffe.get_random(), kwargs) + for queue in queues: + assert len(queue.items) == 0 + for i in range(kwargs['prefetch_queue_size']): + queue.free.put(i) + blobs = [] + for _ in range(self.layer_param.top_size): + blobs.append(caffe.Blob()) + self.reshape(None, blobs) + # Make sure buffer is created before queueing + for blob in blobs: + blob.prefetch_to_gpu() + queue.items.append(blobs) + # Arguments for prefetch process + queue.init.put((queue.items, self.layer_param, caffe.get_random())) + + def setup(self, bottom, top): + pass + + def reshape(self, bottom, top): + batch = self.layer_param.data_param.batch_size + # If the data layer does not have a transform, you need to specify + # its shape manually, e.g. for mnist: + # top[0].reshape(batch, 1, 28, 28) + top[0].reshape(batch, + 3, + self.layer_param.transform_param.crop_size, + self.layer_param.transform_param.crop_size) + top[1].reshape(batch, 1) + + def forward(self, bottom, top): + if self.kwargs['prefetchers_per_gpu'] == 0: + self.source.batch(top) + else: + if self.index != -1: + self.queues[self.queue].free.put(self.index) + # Round robin for deterministic runs + self.queue += 1 + if self.queue == len(self.queues): + self.queue = 0 + + qp = self.queues[self.queue] + try: + self.index = qp.full.get(block=False) + except Empty: + caffe.log('Waiting on data') + self.index = qp.full.get() + data = qp.items[self.index] + for i in range(len(data)): + top[i].set_gpu_data(data[i]) + + def backward(self, top, propagate_down, bottom): + pass + + +class DataSource: + def __init__(self, rank, index, param, seed, kwargs): + self.batch_size = param.data_param.batch_size + self.db = self.open_db(param.data_param.source, rank, index, + param.phase, kwargs) + self.tr = 
caffe.DataTransformer(param.transform_param, param.phase) + self.tr.init_rand(seed) + + def open_db(self, path, rank, index, phase, kwargs): + """ + Reads items from lmdb, skipping keys that will be read by + other processes and threads. + """ + caffe.log('lmdb open %s' % path) + env = lmdb.open(path, map_size=1024 ** 4, readonly=True, create=False) # 1 TiB + txn = env.begin() + + per_gpu = max(1, kwargs['prefetchers_per_gpu']) + segment = len(kwargs['gpus']) * per_gpu + offset = 0 + while True: + for key, value in txn.cursor(): # TODO also test in parallel + if offset % segment == rank * per_gpu + index or phase == caffe.TEST: + yield value + offset += 1 + + def batch(self, blobs): + for i in range(self.batch_size): + self.tr.transform(self.db.next(), blobs[0], blobs[1], i) + + +class QueuePair: + """ + Exchange items between processes. Items are sent once initially + through the init queue, then only indexes are exchanged. + """ + + def __init__(self): + self.free = Queue() + self.full = Queue() + + # Initial arguments and items cached on each side + self.init = Queue() + self.items = [] + + +def solve(uid, rank, queues, kwargs): + gpus = kwargs['gpus'] + + # glog levels: INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 + caffe.init_log(0 if rank == 0 else 1) + caffe.set_mode_gpu() + caffe.set_device(gpus[rank]) + caffe.set_solver_count(len(gpus)) + caffe.set_solver_rank(rank) + + solver = caffe.SGDSolver(kwargs['solver']) + if rank == 0 and kwargs['snapshot'] is not None: + solver.restore(kwargs['snapshot']) + + index = 0 + batch = 0 + prefetchers_per_gpu = kwargs['prefetchers_per_gpu'] + for layer in solver.net.layers: + if isinstance(layer, DataLayer): + layer.init(queues[index: index + prefetchers_per_gpu], kwargs) + index += prefetchers_per_gpu + batch = layer.layer_param.data_param.batch_size + for layer in solver.test_nets[0].layers: + if isinstance(layer, DataLayer): + layer.init(queues[index: index + prefetchers_per_gpu], kwargs) + index += prefetchers_per_gpu + 
assert index == len(queues) + + nccl = caffe.NCCL(solver, uid) + nccl.bcast() + display = solver.param.display + + if kwargs['timing']: + fprop = [] + bprop = [] + total = caffe.Timer() + allrd = caffe.Timer() + for _ in range(len(solver.net.layers)): + fprop.append(caffe.Timer()) + bprop.append(caffe.Timer()) + + def show_time(): + if solver.iter % display == 0: + s = '\n' + for i in range(len(solver.net.layers)): + s += 'forw %3d %8s ' % (i, solver.net.layers[i].layer_param.name) + s += ': %.2f\n' % fprop[i].ms + for i in range(len(solver.net.layers) - 1, -1, -1): + s += 'back %3d %8s ' % (i, solver.net.layers[i].layer_param.name) + s += ': %.2f\n' % bprop[i].ms + s += 'solver total: %.2f\n' % total.ms + s += 'allreduce: %.2f\n' % allrd.ms + caffe.log(s) + + solver.net.before_forward(lambda layer: fprop[layer].start()) + solver.net.after_forward(lambda layer: fprop[layer].stop()) + solver.net.before_backward(lambda layer: bprop[layer].start()) + solver.net.after_backward(lambda layer: bprop[layer].stop()) + solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start())) + solver.add_callback(nccl) + solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time())) + else: + solver.add_callback(nccl) + + class Rate: + def __init__(self): + self.start = time.time() + self.count = 0 + + def run(self): + if solver.iter % display == 0: + nbr = batch * solver.iter * len(gpus) + now = time.time() + caffe.log('%d examples/s' % ((nbr - self.count) / (now - self.start))) + self.start = now + self.count = nbr + + rate = Rate() + solver.add_callback(lambda: '', lambda: rate.run()) + + if kwargs['layer_wise_reduce']: + solver.net.after_backward(nccl) + solver.step(solver.param.max_iter) + + +def fill(qp, rank, index, kwargs): + caffe.init_log(0 if rank == 0 else 1) + caffe.set_device(kwargs['gpus'][rank]) + args = qp.init.get() + assert len(qp.items) == 0 + qp.items = args[0] + assert len(qp.items) == kwargs['prefetch_queue_size'] + source = 
DataSource(rank, index, args[1], args[2], kwargs) + while True: + index = qp.free.get() + source.batch(qp.items[index]) + for blob in qp.items[index]: + blob.prefetch_to_gpu() + qp.full.put(index) diff --git a/python/multi_gpu.py b/python/multi_gpu.py new file mode 100644 index 00000000000..23bb9872fe2 --- /dev/null +++ b/python/multi_gpu.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +import caffe + +# Example multi-GPU training +caffe.train( + solver='models/bvlc_reference_caffenet/solver.prototxt', + snapshot=None, + gpus=range(8), + layer_wise_reduce=True, + # timing=True, +) diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 4a34e4c5856..863d940c190 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -98,6 +98,12 @@ const Dtype* Blob::gpu_data() const { return (const Dtype*)data_->gpu_data(); } +template +void Blob::set_gpu_data(Dtype* data) { + CHECK(data); + data_->set_gpu_data(data); +} + template const Dtype* Blob::cpu_diff() const { CHECK(diff_); diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 7189d67e289..21216aa9c9d 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -520,12 +520,11 @@ vector DataTransformer::InferBlobShape( #endif // USE_OPENCV template -void DataTransformer::InitRand() { +void DataTransformer::InitRand(unsigned int seed) { const bool needs_rand = param_.mirror() || (phase_ == TRAIN && param_.crop_size()); if (needs_rand) { - const unsigned int rng_seed = caffe_rng_rand(); - rng_.reset(new Caffe::RNG(rng_seed)); + rng_.reset(new Caffe::RNG(seed)); } else { rng_.reset(); } diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index e967bd6181c..f14253a510e 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -67,6 +67,7 @@ shared_ptr > GetConvolutionLayer( #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -104,6 +105,7 @@ shared_ptr > 
GetPoolingLayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -141,6 +143,7 @@ shared_ptr > GetLRNLayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -164,6 +167,7 @@ shared_ptr > GetReLULayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -187,6 +191,7 @@ shared_ptr > GetSigmoidLayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -210,6 +215,7 @@ shared_ptr > GetSoftmaxLayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } @@ -233,6 +239,7 @@ shared_ptr > GetTanHLayer(const LayerParameter& param) { #endif } else { LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning } } diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 989319f1a07..6906e64bb0d 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -27,7 +27,7 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, } data_transformer_.reset( new DataTransformer(transform_param_, this->phase_)); - data_transformer_->InitRand(); + data_transformer_->InitRand(caffe_rng_rand()); // The subclasses should setup the size of bottom and top DataLayerSetUp(bottom, top); } @@ -67,7 +67,7 @@ void BasePrefetchingDataLayer::LayerSetUp( } #endif DLOG(INFO) << "Initializing prefetch"; - this->data_transformer_->InitRand(); + this->data_transformer_->InitRand(caffe_rng_rand()); StartInternalThread(); DLOG(INFO) << "Prefetch 
initialized."; } diff --git a/src/caffe/test/test_data_transformer.cpp b/src/caffe/test/test_data_transformer.cpp index 31bf1c1fb14..9f07936eb0b 100644 --- a/src/caffe/test/test_data_transformer.cpp +++ b/src/caffe/test/test_data_transformer.cpp @@ -42,7 +42,7 @@ class DataTransformTest : public ::testing::Test { DataTransformer transformer(transform_param, phase); const int crop_size = transform_param.crop_size(); Caffe::set_random_seed(seed_); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); Blob blob(1, datum.channels(), datum.height(), datum.width()); if (transform_param.crop_size() > 0) { blob.Reshape(1, datum.channels(), crop_size, crop_size); @@ -87,7 +87,7 @@ TYPED_TEST(DataTransformTest, TestEmptyTransform) { FillDatum(label, channels, height, width, unique_pixels, &datum); Blob blob(1, channels, height, width); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); transformer.Transform(datum, &blob); EXPECT_EQ(blob.num(), 1); EXPECT_EQ(blob.channels(), datum.channels()); @@ -110,7 +110,7 @@ TYPED_TEST(DataTransformTest, TestEmptyTransformUniquePixels) { FillDatum(label, channels, height, width, unique_pixels, &datum); Blob blob(1, 3, 4, 5); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); transformer.Transform(datum, &blob); EXPECT_EQ(blob.num(), 1); EXPECT_EQ(blob.channels(), datum.channels()); @@ -134,7 +134,7 @@ TYPED_TEST(DataTransformTest, TestCropSize) { Datum datum; FillDatum(label, channels, height, width, unique_pixels, &datum); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); Blob blob(1, channels, crop_size, crop_size); for (int iter = 0; iter < this->num_iter_; ++iter) { transformer.Transform(datum, &blob); @@ -272,7 +272,7 @@ TYPED_TEST(DataTransformTest, TestMeanValue) { FillDatum(label, channels, height, width, 
unique_pixels, &datum); Blob blob(1, channels, height, width); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); transformer.Transform(datum, &blob); for (int j = 0; j < blob.count(); ++j) { EXPECT_EQ(blob.cpu_data()[j], label - mean_value); @@ -294,7 +294,7 @@ TYPED_TEST(DataTransformTest, TestMeanValues) { FillDatum(label, channels, height, width, unique_pixels, &datum); Blob blob(1, channels, height, width); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); transformer.Transform(datum, &blob); for (int c = 0; c < channels; ++c) { for (int j = 0; j < height * width; ++j) { @@ -333,7 +333,7 @@ TYPED_TEST(DataTransformTest, TestMeanFile) { FillDatum(label, channels, height, width, unique_pixels, &datum); Blob blob(1, channels, height, width); DataTransformer transformer(transform_param, TEST); - transformer.InitRand(); + transformer.InitRand(caffe_rng_rand()); transformer.Transform(datum, &blob); for (int j = 0; j < blob.count(); ++j) { EXPECT_EQ(blob.cpu_data()[j], 0);