diff --git a/Makefile b/Makefile index 85e1a3e8c0d..7f0547d3360 100644 --- a/Makefile +++ b/Makefile @@ -118,7 +118,7 @@ LDFLAGS+= -L/usr/local/zed/lib -lsl_core -lsl_input -lsl_zed #-lstdc++ -D_GLIBCXX_USE_CXX11_ABI=0 endif -OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o +OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o ifeq ($(GPU), 1) LDFLAGS+= -lstdc++ OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o diff --git a/README.md b/README.md index 8d87a98eb29..52df11995a8 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,20 @@ # Yolo-v3 and Yolo-v2 for Windows and Linux ### (neural network for object detection) - Tensor Cores can be used on [Linux](https://github.com/AlexeyAB/darknet#how-to-compile-on-linux) and [Windows](https://github.com/AlexeyAB/darknet#how-to-compile-on-windows-using-vcpkg) -Contributors: https://github.com/AlexeyAB/darknet/graphs/contributors More details: http://pjreddie.com/darknet/yolo/ [![CircleCI](https://circleci.com/gh/AlexeyAB/darknet.svg?style=svg)](https://circleci.com/gh/AlexeyAB/darknet) [![TravisCI](https://travis-ci.org/AlexeyAB/darknet.svg?branch=master)](https://travis-ci.org/AlexeyAB/darknet) [![AppveyorCI](https://ci.appveyor.com/api/projects/status/594bwb5uoc1fxwiu/branch/master?svg=true)](https://ci.appveyor.com/project/AlexeyAB/darknet/branch/master) +[![Contributors](https://img.shields.io/github/contributors/AlexeyAB/Darknet.svg)](https://github.com/AlexeyAB/darknet/graphs/contributors) +[![License: Unlicense](https://img.shields.io/badge/license-Unlicense-blue.svg)](https://github.com/AlexeyAB/darknet/blob/master/LICENSE) * [Requirements (and how to install dependecies)](#requirements) * [Pre-trained models](#pre-trained-models) * [Explanations in issues](https://github.com/AlexeyAB/darknet/issues?q=is%3Aopen+is%3Aissue+label%3AExplanations) -* [Yolo v3 in other frameworks (TensorFlow, OpenVINO, OpenCV-dnn, ...)](#yolo-v3-in-other-frameworks) +* [Yolo v3 in other frameworks (TensorRT, TensorFlow, PyTorch, OpenVINO, OpenCV-dnn,...)](#yolo-v3-in-other-frameworks) 0. [Improvements in this repository](#improvements-in-this-repository) 1. [How to use](#how-to-use-on-the-command-line) @@ -44,7 +45,7 @@ More details: http://pjreddie.com/darknet/yolo/ * Windows or Linux * **CMake >= 3.8** for modern CUDA support: https://cmake.org/download/ * **CUDA 10.0**: https://developer.nvidia.com/cuda-toolkit-archive (on Linux do [Post-installation Actions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#post-installation-actions)) -* **OpenCV > 2.4**: use your preferred package manager (brew, apt), build from source using [vcpkg](https://github.com/Microsoft/vcpkg) or download from [OpenCV official site](https://opencv.org/releases.html) (on Windows set system variable `OpenCV_DIR` = `C:\opencv\build` - where are the `include` and `x64` folders [image](https://user-images.githubusercontent.com/4096485/53249516-5130f480-36c9-11e9-8238-a6e82e48c6f2.png)) +* **OpenCV >= 2.4**: use your preferred package manager (brew, apt), build from source using [vcpkg](https://github.com/Microsoft/vcpkg) or download from [OpenCV official site](https://opencv.org/releases.html) (on Windows set system variable `OpenCV_DIR` = `C:\opencv\build` - where are the `include` and `x64` folders [image](https://user-images.githubusercontent.com/4096485/53249516-5130f480-36c9-11e9-8238-a6e82e48c6f2.png)) * **cuDNN >= 7.0 for CUDA 10.0** https://developer.nvidia.com/rdp/cudnn-archive (on **Linux** copy `cudnn.h`,`libcudnn.so`... as desribed here https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installlinux-tar , on **Windows** copy `cudnn.h`,`cudnn64_7.dll`, `cudnn64_7.lib` as desribed here https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installwindows ) * **GPU with CC >= 3.0**: https://en.wikipedia.org/wiki/CUDA#GPUs_supported * on Linux **GCC or Clang**, on Windows **MSVC 2015/2017/2019** https://visualstudio.microsoft.com/thank-you-downloading-visual-studio/?sku=Community @@ -73,9 +74,11 @@ You can get cfg-files by path: `darknet/cfg/` #### Yolo v3 in other frameworks -* Convert `yolov3.weights`/`cfg` model to **TensorFlow**: by using [mystic123](https://github.com/mystic123/tensorflow-yolo-v3) or [jinyu121](https://github.com/jinyu121/DW2TF) projects, and [TensorFlow-lite](https://www.tensorflow.org/lite/guide/get_started#2_convert_the_model_format) -* To use Yolo v3 model in **Intel OpenVINO** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model) -* **OpenCV-dnn** is very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150) +* **TensorFlow:** convert `yolov3.weights`/`cfg` files to `yolov3.ckpt`/`pb/meta`: by using [mystic123](https://github.com/mystic123/tensorflow-yolo-v3) or [jinyu121](https://github.com/jinyu121/DW2TF) projects, and [TensorFlow-lite](https://www.tensorflow.org/lite/guide/get_started#2_convert_the_model_format) +* **Intel OpenVINO 2019 R1:** (Myriad X / USB Neural Compute Stick / Arria FPGA): read this [manual](https://software.intel.com/en-us/articles/OpenVINO-Using-TensorFlow#converting-a-darknet-yolo-model) +* **OpenCV-dnn** is a very fast DNN implementation on CPU (x86/ARM-Android), use `yolov3.weights`/`cfg` with: [C++ example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.cpp#L192-L221), [Python example](https://github.com/opencv/opencv/blob/8c25a8eb7b10fb50cda323ee6bec68aa1a9ce43c/samples/dnn/object_detection.py#L129-L150) +* **PyTorch > ONNX > CoreML > iOS** how to convert cfg/weights-files to pt-file: [ultralytics/yolov3](https://github.com/ultralytics/yolov3#darknet-conversion) and [iOS App](https://itunes.apple.com/app/id1452689527) +* **TensorRT** for YOLOv3 (-70% faster inference): [TensorRT & DeepStream](https://github.com/NVIDIA-AI-IOT/deepstream_reference_apps) ##### Examples of results @@ -164,6 +167,8 @@ Before make, you can set such options in the `Makefile`: [link](https://github.c * `OPENMP=1` to build with OpenMP support to accelerate Yolo by using multi-core CPU * `LIBSO=1` to build a library `darknet.so` and binary runable file `uselib` that uses this library. Or you can try to run so `LD_LIBRARY_PATH=./:$LD_LIBRARY_PATH ./uselib test.mp4` How to use this SO-library from your own code - you can look at C++ example: https://github.com/AlexeyAB/darknet/blob/master/src/yolo_console_dll.cpp or use in such a way: `LD_LIBRARY_PATH=./:$LD_LIBRARY_PATH ./uselib data/coco.names cfg/yolov3.cfg yolov3.weights test.mp4` +* `ZED_CAMERA=1` to build a library with ZED-3D-camera support (should be ZED SDK installed), then run + `LD_LIBRARY_PATH=./:$LD_LIBRARY_PATH ./uselib data/coco.names cfg/yolov3.cfg yolov3.weights zed_camera` To run Darknet on Linux use examples from this article, just use `./darknet` instead of `darknet.exe`, i.e. use this command: `./darknet detector test ./cfg/coco.data ./cfg/yolov3.cfg ./yolov3.weights` @@ -514,9 +519,9 @@ Example of custom object detection: `darknet.exe detector test data/obj.data yol * increase network resolution in your `.cfg`-file (`height=608`, `width=608` or any value multiple of 32) - it will increase precision - * check that each object is mandatory labeled in your dataset - no one object in your data set should not be without label. In the most training issues - there are wrong labels in your dataset (got labels by using some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark + * check that each object that you want to detect is mandatory labeled in your dataset - no one object in your data set should not be without label. In the most training issues - there are wrong labels in your dataset (got labels by using some conversion script, marked with a third-party tool, ...). Always check your dataset by using: https://github.com/AlexeyAB/Yolo_mark - * for each object which you want to detect - there must be at least 1 similar object in the Training dataset with about the same: shape, side of object, relative size, angle of rotation, tilt, illumination. So desirable that your training dataset include images with objects at diffrent: scales, rotations, lightings, from different sides, on different backgrounds - you should preferably have 2000 different images for each class or more, and you should train `2000*classes` iterations or more + * for each object which you want to detect - there must be at least 1 similar object in the Training dataset with about the same: shape, side of object, relative size, angle of rotation, tilt, illumination. So desirable that your training dataset include images with objects at diffrent: scales, rotations, lightings, from different sides, on different backgrounds - you should preferably have 2000 different images for each class or more, and you should train `2000*classes` iterations or more * desirable that your training dataset include images with non-labeled objects that you do not want to detect - negative samples without bounded box (empty `.txt` files) - use as many images of negative samples as there are images with objects diff --git a/build/darknet/darknet.vcxproj b/build/darknet/darknet.vcxproj index 787602d56c0..5a78ac602ec 100644 --- a/build/darknet/darknet.vcxproj +++ b/build/darknet/darknet.vcxproj @@ -140,6 +140,8 @@ Default NDEBUG true + + true @@ -183,6 +185,7 @@ + @@ -248,6 +251,7 @@ + diff --git a/build/darknet/darknet_no_gpu.vcxproj b/build/darknet/darknet_no_gpu.vcxproj index ec4cea997ff..5b0fe209248 100644 --- a/build/darknet/darknet_no_gpu.vcxproj +++ b/build/darknet/darknet_no_gpu.vcxproj @@ -189,6 +189,7 @@ + @@ -254,6 +255,7 @@ + diff --git a/build/darknet/x64/partial.cmd b/build/darknet/x64/partial.cmd index f0c2b9e8198..03759e8afeb 100644 --- a/build/darknet/x64/partial.cmd +++ b/build/darknet/x64/partial.cmd @@ -33,6 +33,9 @@ darknet.exe partial cfg/yolov3-spp.cfg yolov3-spp.weights yolov3-spp.conv.85 85 darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.15 15 +darknet.exe partial cfg/yolov3-tiny.cfg yolov3-tiny.weights yolov3-tiny.conv.14 14 + + darknet.exe partial cfg/yolo9000.cfg yolo9000.weights yolo9000.conv.22 22 diff --git a/build/darknet/yolo_cpp_dll.vcxproj b/build/darknet/yolo_cpp_dll.vcxproj index f2c22b03862..b14766998a2 100644 --- a/build/darknet/yolo_cpp_dll.vcxproj +++ b/build/darknet/yolo_cpp_dll.vcxproj @@ -187,6 +187,7 @@ + @@ -254,6 +255,7 @@ + diff --git a/build/darknet/yolo_cpp_dll_no_gpu.vcxproj b/build/darknet/yolo_cpp_dll_no_gpu.vcxproj index 5300e40b771..c5c17e84d66 100644 --- a/build/darknet/yolo_cpp_dll_no_gpu.vcxproj +++ b/build/darknet/yolo_cpp_dll_no_gpu.vcxproj @@ -173,6 +173,7 @@ + @@ -240,6 +241,7 @@ + diff --git a/include/darknet.h b/include/darknet.h index 878d07d809c..dc33db7d4db 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -32,7 +32,6 @@ #endif #endif -#define NFRAMES 3 #define SECRET_NUM -1234 #ifdef GPU @@ -136,6 +135,7 @@ typedef enum { RNN, GRU, LSTM, + CONV_LSTM, CRNN, BATCHNORM, NETWORK, @@ -208,8 +208,10 @@ struct layer { int index; int binary; int xnor; + int peephole; int use_bin_output; int steps; + int state_constrain; int hidden; int truth; float smooth; @@ -354,6 +356,7 @@ struct layer { float *z_cpu; float *r_cpu; float *h_cpu; + float *stored_h_cpu; float * prev_state_cpu; float *temp_cpu; @@ -369,6 +372,7 @@ struct layer { float *g_cpu; float *o_cpu; float *c_cpu; + float *stored_c_cpu; float *dc_cpu; float *binary_input; @@ -407,10 +411,13 @@ struct layer { struct layer *uh; struct layer *uo; struct layer *wo; + struct layer *vo; struct layer *uf; struct layer *wf; + struct layer *vf; struct layer *ui; struct layer *wi; + struct layer *vi; struct layer *ug; struct layer *wg; @@ -424,6 +431,7 @@ struct layer { float *z_gpu; float *r_gpu; float *h_gpu; + float *stored_h_gpu; float *temp_gpu; float *temp2_gpu; @@ -432,12 +440,16 @@ struct layer { float *dh_gpu; float *hh_gpu; float *prev_cell_gpu; + float *prev_state_gpu; + float *last_prev_state_gpu; + float *last_prev_cell_gpu; float *cell_gpu; float *f_gpu; float *i_gpu; float *g_gpu; float *o_gpu; float *c_gpu; + float *stored_c_gpu; float *dc_gpu; // adam @@ -451,7 +463,6 @@ struct layer { float * combine_gpu; float * combine_delta_gpu; - float * prev_state_gpu; float * forgot_state_gpu; float * forgot_delta_gpu; float * state_gpu; @@ -541,6 +552,7 @@ typedef struct network { float learning_rate_min; float learning_rate_max; int batches_per_cycle; + int batches_cycle_mult; float momentum; float decay; float gamma; @@ -549,6 +561,7 @@ typedef struct network { int time_steps; int step; int max_batches; + float *seq_scales; float *scales; int *steps; int num_steps; @@ -571,6 +584,7 @@ typedef struct network { float min_ratio; int center; int flip; // horizontal flip 50% probability augmentaiont for classifier training (default = 1) + int blur; float angle; float aspect; float exposure; @@ -579,6 +593,9 @@ typedef struct network { int random; int track; int augment_speed; + int sequential_subdivisions; + int init_sequential_subdivisions; + int current_subdivision; int try_fix_nan; int gpu_index; @@ -713,6 +730,7 @@ typedef struct load_args { int show_imgs; float jitter; int flip; + int blur; float angle; float aspect; float saturation; @@ -778,7 +796,7 @@ LIB_API float *network_predict_image(network *net, image im); LIB_API float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, float thresh_calc_avg_iou, const float iou_thresh, const int map_points, network *existing_net); LIB_API void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear, int dont_show, int calc_map, int mjpeg_port, int show_imgs); LIB_API void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile); + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box); LIB_API int network_width(network *net); LIB_API int network_height(network *net); LIB_API void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm); diff --git a/include/yolo_v2_class.hpp b/include/yolo_v2_class.hpp index fcaddcf0a51..323f7a1a66f 100644 --- a/include/yolo_v2_class.hpp +++ b/include/yolo_v2_class.hpp @@ -132,33 +132,28 @@ class Detector { else if (img_src.channels() == 1) cv::cvtColor(img_src, img, cv::COLOR_GRAY2BGR); else std::cerr << " Warning: img_src.channels() is not 1, 3 or 4. It is = " << img_src.channels() << std::endl; std::shared_ptr image_ptr(new image_t, [](image_t *img) { free_image(*img); delete img; }); - std::shared_ptr ipl_small = std::make_shared(img); - *image_ptr = ipl_to_image(ipl_small.get()); + *image_ptr = mat_to_image_custom(img); return image_ptr; } private: - static image_t ipl_to_image(IplImage* src) + static image_t mat_to_image_custom(cv::Mat mat) { - unsigned char *data = (unsigned char *)src->imageData; - int h = src->height; - int w = src->width; - int c = src->nChannels; - int step = src->widthStep; - image_t out = make_image_custom(w, h, c); - int count = 0; - - for (int k = 0; k < c; ++k) { - for (int i = 0; i < h; ++i) { - int i_step = i*step; - for (int j = 0; j < w; ++j) { - out.data[count++] = data[i_step + j*c + k] / 255.; + int w = mat.cols; + int h = mat.rows; + int c = mat.channels(); + image_t im = make_image_custom(w, h, c); + unsigned char *data = (unsigned char *)mat.data; + int step = mat.step; + for (int y = 0; y < h; ++y) { + for (int k = 0; k < c; ++k) { + for (int x = 0; x < w; ++x) { + im.data[k*w*h + y*w + x] = data[y*step + x*c + k] / 255.0f; } } } - - return out; + return im; } static image_t make_empty_image(int w, int h, int c) diff --git a/scripts/README.md b/scripts/README.md index 273ed2a7dcd..a38327ecb1b 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -12,6 +12,8 @@ ImageNet (ILSVRC2012): http://www.image-net.org/challenges/LSVRC/2012/nonpub-dow ImageNet (ILSVRC2015): http://image-net.org/small/download.php +ImageNet VID: http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php + Open Images: https://storage.googleapis.com/openimages/web/download.html Cityscapes: https://www.cityscapes-dataset.com/ diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu index 6bd86dcae37..6fd23c2ed27 100644 --- a/src/activation_kernels.cu +++ b/src/activation_kernels.cu @@ -210,6 +210,30 @@ __global__ void activate_array_logistic_kernel(float *x, int n) } } +__global__ void activate_array_tanh_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = tanh_activate_kernel(x[index]); + } +} + +__global__ void activate_array_hardtan_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = hardtan_activate_kernel(x[index]); + } +} + +__global__ void activate_array_relu_kernel(float *x, int n) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + x[index] = relu_activate_kernel(x[index]); + } +} + __global__ void gradient_array_kernel(float *x, int n, ACTIVATION a, float *delta) { int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x; @@ -240,6 +264,14 @@ __global__ void gradient_array_logistic_kernel(float *x, int n, float *delta) } } +__global__ void gradient_array_tanh_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= tanh_gradient_kernel(x[index]); + } +} + __global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta) { int index = blockIdx.x*blockDim.x + threadIdx.x; @@ -248,12 +280,23 @@ __global__ void gradient_array_hardtan_kernel(float *x, int n, float *delta) } } +__global__ void gradient_array_relu_kernel(float *x, int n, float *delta) +{ + int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < n) { + delta[index] *= relu_gradient_kernel(x[index]); + } +} + extern "C" void activate_array_ongpu(float *x, int n, ACTIVATION a) { const int num_blocks = get_number_of_blocks(n, BLOCK); if (a == LINEAR) return; else if(a == LEAKY) activate_array_leaky_kernel << > >(x, n); else if (a == LOGISTIC) activate_array_logistic_kernel << > >(x, n); + else if (a == TANH) activate_array_tanh_kernel << > >(x, n); + else if (a == HARDTAN) activate_array_hardtan_kernel << > >(x, n); + else if (a == RELU) activate_array_relu_kernel << > >(x, n); else if (a == SELU) activate_array_selu_kernel << > >(x, n); else activate_array_kernel<<>>(x, n, a); @@ -266,8 +309,10 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta if (a == LINEAR) return; else if (a == LEAKY) gradient_array_leaky_kernel << > >(x, n, delta); else if (a == LOGISTIC) gradient_array_logistic_kernel << > >(x, n, delta); - else if (a == SELU) gradient_array_selu_kernel << > >(x, n, delta); + else if (a == TANH) gradient_array_tanh_kernel << > >(x, n, delta); else if (a == HARDTAN) gradient_array_hardtan_kernel << > >(x, n, delta); + else if (a == RELU) gradient_array_relu_kernel << > >(x, n, delta); + else if (a == SELU) gradient_array_selu_kernel << > >(x, n, delta); else gradient_array_kernel << > > (x, n, a, delta); CHECK_CUDA(cudaPeekAtLastError()); diff --git a/src/blas.c b/src/blas.c index c68b64c3d5e..a3ff84b6515 100644 --- a/src/blas.c +++ b/src/blas.c @@ -334,3 +334,22 @@ void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int for } } } + + +void constrain_cpu(int size, float ALPHA, float *X) +{ + int i; + for (i = 0; i < size; ++i) { + X[i] = fminf(ALPHA, fmaxf(-ALPHA, X[i])); + } +} + +void fix_nan_and_inf_cpu(float *input, size_t size) +{ + int i; + for (i = 0; i < size; ++i) { + float val = input[i]; + if (isnan(val) || isinf(val)) + input[i] = 1.0f / i; // pseudo random value + } +} \ No newline at end of file diff --git a/src/blas.h b/src/blas.h index 8e91fff2f76..09b9c9a70e5 100644 --- a/src/blas.h +++ b/src/blas.h @@ -1,9 +1,13 @@ #ifndef BLAS_H #define BLAS_H +#include +#include "darknet.h" + #ifdef GPU #include "dark_cuda.h" #include "tree.h" #endif + #ifdef __cplusplus extern "C" { #endif @@ -46,6 +50,8 @@ void softmax(float *input, int n, float temp, float *output, int stride); void upsample_cpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out); void softmax_cpu(float *input, int n, int batch, int batch_offset, int groups, int group_offset, int stride, float temp, float *output); void softmax_x_ent_cpu(int n, float *pred, float *truth, float *delta, float *error); +void constrain_cpu(int size, float ALPHA, float *X); +void fix_nan_and_inf_cpu(float *input, size_t size); #ifdef GPU @@ -105,6 +111,10 @@ void softmax_tree_gpu(float *input, int spatial, int batch, int stride, float te void fix_nan_and_inf(float *input, size_t size); int is_nan_or_inf(float *input, size_t size); +void add_3_arrays_activate(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst); +void sum_of_mults(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst); +void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst); + #endif #ifdef __cplusplus } diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu index 4115669316c..2053bde627a 100644 --- a/src/blas_kernels.cu +++ b/src/blas_kernels.cu @@ -983,7 +983,7 @@ __global__ void fix_nan_and_inf_kernel(float *input, size_t size) if (index < size) { float val = input[index]; if (isnan(val) || isinf(val)) - input[index] = index; // pseudo random value + input[index] = 1.0f / index; // pseudo random value } } @@ -1022,3 +1022,66 @@ extern "C" int is_nan_or_inf(float *input, size_t size) CHECK_CUDA(cudaFreeHost(pinned_return)); return ret_val; } + +__global__ void add_3_arrays_activate_kernel(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst) +{ + const int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < size) { + float val = 0; + val += a1[index]; + val += a2[index]; + if (a3) val += a3[index]; + if (a == LOGISTIC) val = 1.f / (1.f + expf(-val)); + else if(a == TANH) val = (2 / (1 + expf(-2 * val)) - 1); + dst[index] = val; + } +} + +extern "C" void add_3_arrays_activate(float *a1, float *a2, float *a3, size_t size, ACTIVATION a, float *dst) +{ + const int block_size = BLOCK; + const int num_blocks = get_number_of_blocks(size, block_size); + if (a != LOGISTIC && a != TANH) { + printf(" add_3_arrays_activate() doesn't support activation %d, it supports only LOGISTIC and TANH \n", a); + exit(EXIT_FAILURE); + } + add_3_arrays_activate_kernel << > >(a1, a2, a3, size, a, dst); +} + + +__global__ void sum_of_mults_kernel(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst) +{ + const int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < size) { + dst[index] = a1[index] * a2[index] + b1[index] * b2[index]; + } +} + +extern "C" void sum_of_mults(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst) +{ + const int block_size = BLOCK; + const int num_blocks = get_number_of_blocks(size, block_size); + sum_of_mults_kernel << > >(a1, a2, b1, b2, size, dst); +} + + +__global__ void activate_and_mult_kernel(float *a1, float *a2, size_t size, ACTIVATION a, float *dst) +{ + const int index = blockIdx.x*blockDim.x + threadIdx.x; + if (index < size) { + float val = a1[index]; + if (a == TANH) val = (2 / (1 + expf(-2 * val)) - 1); + dst[index] = val * a2[index]; + } +} + +extern "C" void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst) +{ + const int block_size = BLOCK; + const int num_blocks = get_number_of_blocks(size, block_size); + if (a != TANH) { + printf(" activat_and_mult() doesn't support activation %d, it supports only TANH \n", a); + exit(EXIT_FAILURE); + } + activate_and_mult_kernel << > >(a1, a2, size, a, dst); +} diff --git a/src/coco.c b/src/coco.c index 931c406c7a8..c1535a35f19 100644 --- a/src/coco.c +++ b/src/coco.c @@ -384,5 +384,5 @@ void run_coco(int argc, char **argv) else if(0==strcmp(argv[2], "valid")) validate_coco(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_coco_recall(cfg, weights); else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, hier_thresh, cam_index, filename, coco_classes, 80, frame_skip, - prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output); + prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output, 0); } diff --git a/src/conv_lstm_layer.c b/src/conv_lstm_layer.c new file mode 100644 index 00000000000..f3041cedc65 --- /dev/null +++ b/src/conv_lstm_layer.c @@ -0,0 +1,1186 @@ +// Page 4: https://arxiv.org/abs/1506.04214v2 +// Page 3: https://arxiv.org/pdf/1705.06368v3.pdf +// https://wikimedia.org/api/rest_v1/media/math/render/svg/1edbece2559479959fe829e9c6657efb380debe7 + +#include "conv_lstm_layer.h" +#include "connected_layer.h" +#include "convolutional_layer.h" +#include "utils.h" +#include "dark_cuda.h" +#include "blas.h" +#include "gemm.h" + +#include +#include +#include +#include + +static void increment_layer(layer *l, int steps) +{ + int num = l->outputs*l->batch*steps; + l->output += num; + l->delta += num; + l->x += num; + l->x_norm += num; + +#ifdef GPU + l->output_gpu += num; + l->delta_gpu += num; + l->x_gpu += num; + l->x_norm_gpu += num; +#endif +} + + +layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor) +{ + fprintf(stderr, "CONV_LSTM Layer: %d x %d x %d image, %d filters\n", h, w, c, output_filters); + /* + batch = batch / steps; + layer l = { (LAYER_TYPE)0 }; + l.batch = batch; + l.type = LSTM; + l.steps = steps; + l.inputs = inputs; + l.out_w = 1; + l.out_h = 1; + l.out_c = outputs; + */ + batch = batch / steps; + layer l = { (LAYER_TYPE)0 }; + l.batch = batch; + l.type = CONV_LSTM; + l.steps = steps; + l.size = size; + l.stride = stride; + l.pad = pad; + l.h = h; + l.w = w; + l.c = c; + l.groups = groups; + l.out_c = output_filters; + l.inputs = h * w * c; + l.xnor = xnor; + l.peephole = peephole; + + // U + l.uf = (layer*)calloc(1, sizeof(layer)); + *(l.uf) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.uf->batch = batch; + if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size; + + l.ui = (layer*)calloc(1, sizeof(layer)); + *(l.ui) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.ui->batch = batch; + if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size; + + l.ug = (layer*)calloc(1, sizeof(layer)); + *(l.ug) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.ug->batch = batch; + if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size; + + l.uo = (layer*)calloc(1, sizeof(layer)); + *(l.uo) = make_convolutional_layer(batch, steps, h, w, c, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.uo->batch = batch; + if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size; + + + // W + l.wf = (layer*)calloc(1, sizeof(layer)); + *(l.wf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.wf->batch = batch; + if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size; + + l.wi = (layer*)calloc(1, sizeof(layer)); + *(l.wi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.wi->batch = batch; + if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size; + + l.wg = (layer*)calloc(1, sizeof(layer)); + *(l.wg) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.wg->batch = batch; + if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size; + + l.wo = (layer*)calloc(1, sizeof(layer)); + *(l.wo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.wo->batch = batch; + if (l.workspace_size < l.wo->workspace_size) l.workspace_size = l.wo->workspace_size; + + + // V + l.vf = (layer*)calloc(1, sizeof(layer)); + if (l.peephole) { + *(l.vf) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.vf->batch = batch; + if (l.workspace_size < l.vf->workspace_size) l.workspace_size = l.vf->workspace_size; + } + + l.vi = (layer*)calloc(1, sizeof(layer)); + if (l.peephole) { + *(l.vi) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.vi->batch = batch; + if (l.workspace_size < l.vi->workspace_size) l.workspace_size = l.vi->workspace_size; + } + + l.vo = (layer*)calloc(1, sizeof(layer)); + if (l.peephole) { + *(l.vo) = make_convolutional_layer(batch, steps, h, w, output_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.vo->batch = batch; + if (l.workspace_size < l.vo->workspace_size) l.workspace_size = l.vo->workspace_size; + } + + + l.batch_normalize = batch_normalize; + + l.out_h = l.wo->out_h; + l.out_w = l.wo->out_w; + l.outputs = l.wo->outputs; + int outputs = l.outputs; + l.inputs = w*h*c; + + assert(l.wo->outputs == l.uo->outputs); + + l.output = (float*)calloc(outputs * batch * steps, sizeof(float)); + //l.state = (float*)calloc(outputs * batch, sizeof(float)); + + l.forward = forward_conv_lstm_layer; + l.update = update_conv_lstm_layer; + l.backward = backward_conv_lstm_layer; + + l.prev_state_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.prev_cell_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.cell_cpu = (float*)calloc(batch*outputs*steps, sizeof(float)); + + l.f_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.i_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.g_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.o_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.c_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.stored_c_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.h_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.stored_h_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.temp_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.temp2_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.temp3_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.dc_cpu = (float*)calloc(batch*outputs, sizeof(float)); + l.dh_cpu = (float*)calloc(batch*outputs, sizeof(float)); + +#ifdef GPU + l.forward_gpu = forward_conv_lstm_layer_gpu; + l.backward_gpu = backward_conv_lstm_layer_gpu; + l.update_gpu = update_conv_lstm_layer_gpu; + + //l.state_gpu = cuda_make_array(l.state, batch*l.outputs); + + l.output_gpu = cuda_make_array(0, batch*outputs*steps); + l.delta_gpu = cuda_make_array(0, batch*l.outputs*steps); + + l.prev_state_gpu = cuda_make_array(0, batch*outputs); + l.prev_cell_gpu = cuda_make_array(0, batch*outputs); + l.cell_gpu = cuda_make_array(0, batch*outputs*steps); + + l.f_gpu = cuda_make_array(0, batch*outputs); + l.i_gpu = cuda_make_array(0, batch*outputs); + l.g_gpu = cuda_make_array(0, batch*outputs); + l.o_gpu = cuda_make_array(0, batch*outputs); + l.c_gpu = cuda_make_array(0, batch*outputs); + l.h_gpu = cuda_make_array(0, batch*outputs); + l.stored_c_gpu = cuda_make_array(0, batch*outputs); + l.stored_h_gpu = cuda_make_array(0, batch*outputs); + l.temp_gpu = cuda_make_array(0, batch*outputs); + l.temp2_gpu = cuda_make_array(0, batch*outputs); + l.temp3_gpu = cuda_make_array(0, batch*outputs); + l.dc_gpu = cuda_make_array(0, batch*outputs); + l.dh_gpu = cuda_make_array(0, batch*outputs); + l.last_prev_state_gpu = cuda_make_array(0, l.batch*l.outputs); + l.last_prev_cell_gpu = cuda_make_array(0, l.batch*l.outputs); + +#endif + + l.bflops = l.uf->bflops + l.ui->bflops + l.ug->bflops + l.uo->bflops + + l.wf->bflops + l.wi->bflops + l.wg->bflops + l.wo->bflops + + l.vf->bflops + l.vi->bflops + l.vo->bflops; + + if(l.peephole) l.bflops += 12 * l.outputs*l.batch / 1000000000.; + else l.bflops += 9 * l.outputs*l.batch / 1000000000.; + + return l; +} + +void update_conv_lstm_layer(layer l, int batch, float learning_rate, float momentum, float decay) +{ + if (l.peephole) { + update_convolutional_layer(*(l.vf), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.vi), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.vo), batch, learning_rate, momentum, decay); + } + update_convolutional_layer(*(l.wf), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.wi), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.wg), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.wo), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.uf), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.ui), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.ug), batch, learning_rate, momentum, decay); + update_convolutional_layer(*(l.uo), batch, learning_rate, momentum, decay); +} + +void resize_conv_lstm_layer(layer *l, int w, int h) +{ + if (l->peephole) { + resize_convolutional_layer(l->vf, w, h); + if (l->workspace_size < l->vf->workspace_size) l->workspace_size = l->vf->workspace_size; + + resize_convolutional_layer(l->vi, w, h); + if (l->workspace_size < l->vi->workspace_size) l->workspace_size = l->vi->workspace_size; + + resize_convolutional_layer(l->vo, w, h); + if (l->workspace_size < l->vo->workspace_size) l->workspace_size = l->vo->workspace_size; + } + + resize_convolutional_layer(l->wf, w, h); + if (l->workspace_size < l->wf->workspace_size) l->workspace_size = l->wf->workspace_size; + + resize_convolutional_layer(l->wi, w, h); + if (l->workspace_size < l->wi->workspace_size) l->workspace_size = l->wi->workspace_size; + + resize_convolutional_layer(l->wg, w, h); + if (l->workspace_size < l->wg->workspace_size) l->workspace_size = l->wg->workspace_size; + + resize_convolutional_layer(l->wo, w, h); + if (l->workspace_size < l->wo->workspace_size) l->workspace_size = l->wo->workspace_size; + + + resize_convolutional_layer(l->uf, w, h); + if (l->workspace_size < l->uf->workspace_size) l->workspace_size = l->uf->workspace_size; + + resize_convolutional_layer(l->ui, w, h); + if (l->workspace_size < l->ui->workspace_size) l->workspace_size = l->ui->workspace_size; + + resize_convolutional_layer(l->ug, w, h); + if (l->workspace_size < l->ug->workspace_size) l->workspace_size = l->ug->workspace_size; + + resize_convolutional_layer(l->uo, w, h); + if (l->workspace_size < l->uo->workspace_size) l->workspace_size = l->uo->workspace_size; + + l->w = w; + l->h = h; + l->out_h = l->wo->out_h; + l->out_w = l->wo->out_w; + l->outputs = l->wo->outputs; + int outputs = l->outputs; + l->inputs = w*h*l->c; + int steps = l->steps; + int batch = l->batch; + + assert(l->wo->outputs == l->uo->outputs); + + l->output = (float*)realloc(l->output, outputs * batch * steps * sizeof(float)); + //l->state = (float*)realloc(l->state, outputs * batch * sizeof(float)); + + l->prev_state_cpu = (float*)realloc(l->prev_state_cpu, batch*outputs * sizeof(float)); + l->prev_cell_cpu = (float*)realloc(l->prev_cell_cpu, batch*outputs * sizeof(float)); + l->cell_cpu = (float*)realloc(l->cell_cpu, batch*outputs*steps * sizeof(float)); + + l->f_cpu = (float*)realloc(l->f_cpu, batch*outputs * sizeof(float)); + l->i_cpu = (float*)realloc(l->i_cpu, batch*outputs * sizeof(float)); + l->g_cpu = (float*)realloc(l->g_cpu, batch*outputs * sizeof(float)); + l->o_cpu = (float*)realloc(l->o_cpu, batch*outputs * sizeof(float)); + l->c_cpu = (float*)realloc(l->c_cpu, batch*outputs * sizeof(float)); + l->h_cpu = (float*)realloc(l->h_cpu, batch*outputs * sizeof(float)); + l->temp_cpu = (float*)realloc(l->temp_cpu, batch*outputs * sizeof(float)); + l->temp2_cpu = (float*)realloc(l->temp2_cpu, batch*outputs * sizeof(float)); + l->temp3_cpu = (float*)realloc(l->temp3_cpu, batch*outputs * sizeof(float)); + l->dc_cpu = (float*)realloc(l->dc_cpu, batch*outputs * sizeof(float)); + l->dh_cpu = (float*)realloc(l->dh_cpu, batch*outputs * sizeof(float)); + l->stored_c_cpu = (float*)realloc(l->stored_c_cpu, batch*outputs * sizeof(float)); + l->stored_h_cpu = (float*)realloc(l->stored_h_cpu, batch*outputs * sizeof(float)); + +#ifdef GPU + //if (l->state_gpu) cudaFree(l->state_gpu); + //l->state_gpu = cuda_make_array(l->state, batch*l->outputs); + + if (l->output_gpu) cudaFree(l->output_gpu); + l->output_gpu = cuda_make_array(0, batch*outputs*steps); + + if (l->delta_gpu) cudaFree(l->delta_gpu); + l->delta_gpu = cuda_make_array(0, batch*outputs*steps); + + if (l->prev_state_gpu) cudaFree(l->prev_state_gpu); + l->prev_state_gpu = cuda_make_array(0, batch*outputs); + + if (l->prev_cell_gpu) cudaFree(l->prev_cell_gpu); + l->prev_cell_gpu = cuda_make_array(0, batch*outputs); + + if (l->cell_gpu) cudaFree(l->cell_gpu); + l->cell_gpu = cuda_make_array(0, batch*outputs*steps); + + if (l->f_gpu) cudaFree(l->f_gpu); + l->f_gpu = cuda_make_array(0, batch*outputs); + + if (l->i_gpu) cudaFree(l->i_gpu); + l->i_gpu = cuda_make_array(0, batch*outputs); + + if (l->g_gpu) cudaFree(l->g_gpu); + l->g_gpu = cuda_make_array(0, batch*outputs); + + if (l->o_gpu) cudaFree(l->o_gpu); + l->o_gpu = cuda_make_array(0, batch*outputs); + + if (l->c_gpu) cudaFree(l->c_gpu); + l->c_gpu = cuda_make_array(0, batch*outputs); + + if (l->h_gpu) cudaFree(l->h_gpu); + l->h_gpu = cuda_make_array(0, batch*outputs); + + if (l->temp_gpu) cudaFree(l->temp_gpu); + l->temp_gpu = cuda_make_array(0, batch*outputs); + + if (l->temp2_gpu) cudaFree(l->temp2_gpu); + l->temp2_gpu = cuda_make_array(0, batch*outputs); + + if (l->temp3_gpu) cudaFree(l->temp3_gpu); + l->temp3_gpu = cuda_make_array(0, batch*outputs); + + if (l->dc_gpu) cudaFree(l->dc_gpu); + l->dc_gpu = cuda_make_array(0, batch*outputs); + + if (l->dh_gpu) cudaFree(l->dh_gpu); + l->dh_gpu = cuda_make_array(0, batch*outputs); + + if (l->stored_c_gpu) cudaFree(l->stored_c_gpu); + l->stored_c_gpu = cuda_make_array(0, batch*outputs); + + if (l->stored_h_gpu) cudaFree(l->stored_h_gpu); + l->stored_h_gpu = cuda_make_array(0, batch*outputs); + + if (l->last_prev_state_gpu) cudaFree(l->last_prev_state_gpu); + l->last_prev_state_gpu = cuda_make_array(0, batch*outputs); + + if (l->last_prev_cell_gpu) cudaFree(l->last_prev_cell_gpu); + l->last_prev_cell_gpu = cuda_make_array(0, batch*outputs); +#endif +} + +void free_state_conv_lstm(layer l) +{ + int i; + for (i = 0; i < l.outputs * l.batch; ++i) l.h_cpu[i] = 0; + for (i = 0; i < l.outputs * l.batch; ++i) l.c_cpu[i] = 0; + +#ifdef GPU + cuda_push_array(l.h_gpu, l.h_cpu, l.outputs * l.batch); + cuda_push_array(l.c_gpu, l.c_cpu, l.outputs * l.batch); + + //fill_ongpu(l.outputs * l.batch, 0, l.dc_gpu, 1); // dont use + //fill_ongpu(l.outputs * l.batch, 0, l.dh_gpu, 1); // dont use +#endif // GPU +} + +void randomize_state_conv_lstm(layer l) +{ + int i; + for (i = 0; i < l.outputs * l.batch; ++i) l.h_cpu[i] = rand_uniform(-1, 1); + for (i = 0; i < l.outputs * l.batch; ++i) l.c_cpu[i] = rand_uniform(-1, 1); + +#ifdef GPU + cuda_push_array(l.h_gpu, l.h_cpu, l.outputs * l.batch); + cuda_push_array(l.c_gpu, l.c_cpu, l.outputs * l.batch); +#endif // GPU +} + + +void remember_state_conv_lstm(layer l) +{ + memcpy(l.stored_c_cpu, l.c_cpu, l.outputs * l.batch * sizeof(float)); + memcpy(l.stored_h_cpu, l.h_cpu, l.outputs * l.batch * sizeof(float)); + +#ifdef GPU + copy_ongpu(l.outputs*l.batch, l.c_gpu, 1, l.stored_c_gpu, 1); + copy_ongpu(l.outputs*l.batch, l.h_gpu, 1, l.stored_h_gpu, 1); +#endif // GPU +} + +void restore_state_conv_lstm(layer l) +{ + memcpy(l.c_cpu, l.stored_c_cpu, l.outputs * l.batch * sizeof(float)); + memcpy(l.h_cpu, l.stored_h_cpu, l.outputs * l.batch * sizeof(float)); + +#ifdef GPU + copy_ongpu(l.outputs*l.batch, l.stored_c_gpu, 1, l.c_gpu, 1); + copy_ongpu(l.outputs*l.batch, l.stored_h_gpu, 1, l.h_gpu, 1); +#endif // GPU +} + +void forward_conv_lstm_layer(layer l, network_state state) +{ + network_state s = { 0 }; + s.train = state.train; + s.workspace = state.workspace; + s.net = state.net; + int i; + layer vf = *(l.vf); + layer vi = *(l.vi); + layer vo = *(l.vo); + + layer wf = *(l.wf); + layer wi = *(l.wi); + layer wg = *(l.wg); + layer wo = *(l.wo); + + layer uf = *(l.uf); + layer ui = *(l.ui); + layer ug = *(l.ug); + layer uo = *(l.uo); + + if (state.train) { + if (l.peephole) { + fill_cpu(l.outputs * l.batch * l.steps, 0, vf.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, vi.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, vo.delta, 1); + } + + fill_cpu(l.outputs * l.batch * l.steps, 0, wf.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, wi.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, wg.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, wo.delta, 1); + + fill_cpu(l.outputs * l.batch * l.steps, 0, uf.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, ui.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, ug.delta, 1); + fill_cpu(l.outputs * l.batch * l.steps, 0, uo.delta, 1); + + fill_cpu(l.outputs * l.batch * l.steps, 0, l.delta, 1); + } + + for (i = 0; i < l.steps; ++i) + { + if (l.peephole) { + assert(l.outputs == vf.out_w * vf.out_h * vf.out_c); + s.input = l.c_cpu; + forward_convolutional_layer(vf, s); + forward_convolutional_layer(vi, s); + // vo below + } + + assert(l.outputs == wf.out_w * wf.out_h * wf.out_c); + assert(wf.c == l.out_c && wi.c == l.out_c && wg.c == l.out_c && wo.c == l.out_c); + + s.input = l.h_cpu; + forward_convolutional_layer(wf, s); + forward_convolutional_layer(wi, s); + forward_convolutional_layer(wg, s); + forward_convolutional_layer(wo, s); + + assert(l.inputs == uf.w * uf.h * uf.c); + assert(uf.c == l.c && ui.c == l.c && ug.c == l.c && uo.c == l.c); + + s.input = state.input; + forward_convolutional_layer(uf, s); + forward_convolutional_layer(ui, s); + forward_convolutional_layer(ug, s); + forward_convolutional_layer(uo, s); + + // f = wf + uf + vf + copy_cpu(l.outputs*l.batch, wf.output, 1, l.f_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, uf.output, 1, l.f_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vf.output, 1, l.f_cpu, 1); + + // i = wi + ui + vi + copy_cpu(l.outputs*l.batch, wi.output, 1, l.i_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, ui.output, 1, l.i_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vi.output, 1, l.i_cpu, 1); + + // g = wg + ug + copy_cpu(l.outputs*l.batch, wg.output, 1, l.g_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, ug.output, 1, l.g_cpu, 1); + + activate_array(l.f_cpu, l.outputs*l.batch, LOGISTIC); + activate_array(l.i_cpu, l.outputs*l.batch, LOGISTIC); + activate_array(l.g_cpu, l.outputs*l.batch, TANH); + + // c = f*c + i*g + copy_cpu(l.outputs*l.batch, l.i_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.g_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.f_cpu, 1, l.c_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, l.temp_cpu, 1, l.c_cpu, 1); + + // o = wo + uo + vo(c_new) + if (l.peephole) { + s.input = l.c_cpu; + forward_convolutional_layer(vo, s); + } + copy_cpu(l.outputs*l.batch, wo.output, 1, l.o_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, uo.output, 1, l.o_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vo.output, 1, l.o_cpu, 1); + activate_array(l.o_cpu, l.outputs*l.batch, LOGISTIC); + + // h = o * tanh(c) + copy_cpu(l.outputs*l.batch, l.c_cpu, 1, l.h_cpu, 1); + activate_array(l.h_cpu, l.outputs*l.batch, TANH); + mul_cpu(l.outputs*l.batch, l.o_cpu, 1, l.h_cpu, 1); + + if (l.state_constrain) constrain_cpu(l.outputs*l.batch, l.state_constrain, l.c_cpu); + fix_nan_and_inf_cpu(l.c_cpu, l.outputs*l.batch); + fix_nan_and_inf_cpu(l.h_cpu, l.outputs*l.batch); + + copy_cpu(l.outputs*l.batch, l.c_cpu, 1, l.cell_cpu, 1); + copy_cpu(l.outputs*l.batch, l.h_cpu, 1, l.output, 1); + + state.input += l.inputs*l.batch; + l.output += l.outputs*l.batch; + l.cell_cpu += l.outputs*l.batch; + + if (l.peephole) { + increment_layer(&vf, 1); + increment_layer(&vi, 1); + increment_layer(&vo, 1); + } + + increment_layer(&wf, 1); + increment_layer(&wi, 1); + increment_layer(&wg, 1); + increment_layer(&wo, 1); + + increment_layer(&uf, 1); + increment_layer(&ui, 1); + increment_layer(&ug, 1); + increment_layer(&uo, 1); + } +} + +void backward_conv_lstm_layer(layer l, network_state state) +{ + network_state s = { 0 }; + s.train = state.train; + s.workspace = state.workspace; + int i; + layer vf = *(l.vf); + layer vi = *(l.vi); + layer vo = *(l.vo); + + layer wf = *(l.wf); + layer wi = *(l.wi); + layer wg = *(l.wg); + layer wo = *(l.wo); + + layer uf = *(l.uf); + layer ui = *(l.ui); + layer ug = *(l.ug); + layer uo = *(l.uo); + + if (l.peephole) { + increment_layer(&vf, l.steps - 1); + increment_layer(&vi, l.steps - 1); + increment_layer(&vo, l.steps - 1); + } + + increment_layer(&wf, l.steps - 1); + increment_layer(&wi, l.steps - 1); + increment_layer(&wg, l.steps - 1); + increment_layer(&wo, l.steps - 1); + + increment_layer(&uf, l.steps - 1); + increment_layer(&ui, l.steps - 1); + increment_layer(&ug, l.steps - 1); + increment_layer(&uo, l.steps - 1); + + state.input += l.inputs*l.batch*(l.steps - 1); + if (state.delta) state.delta += l.inputs*l.batch*(l.steps - 1); + + l.output += l.outputs*l.batch*(l.steps - 1); + l.cell_cpu += l.outputs*l.batch*(l.steps - 1); + l.delta += l.outputs*l.batch*(l.steps - 1); + + for (i = l.steps - 1; i >= 0; --i) { + if (i != 0) copy_cpu(l.outputs*l.batch, l.cell_cpu - l.outputs*l.batch, 1, l.prev_cell_cpu, 1); + copy_cpu(l.outputs*l.batch, l.cell_cpu, 1, l.c_cpu, 1); + if (i != 0) copy_cpu(l.outputs*l.batch, l.output - l.outputs*l.batch, 1, l.prev_state_cpu, 1); + copy_cpu(l.outputs*l.batch, l.output, 1, l.h_cpu, 1); + + l.dh_cpu = (i == 0) ? 0 : l.delta - l.outputs*l.batch; + + // f = wf + uf + vf + copy_cpu(l.outputs*l.batch, wf.output, 1, l.f_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, uf.output, 1, l.f_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vf.output, 1, l.f_cpu, 1); + + // i = wi + ui + vi + copy_cpu(l.outputs*l.batch, wi.output, 1, l.i_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, ui.output, 1, l.i_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vi.output, 1, l.i_cpu, 1); + + // g = wg + ug + copy_cpu(l.outputs*l.batch, wg.output, 1, l.g_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, ug.output, 1, l.g_cpu, 1); + + // o = wo + uo + vo + copy_cpu(l.outputs*l.batch, wo.output, 1, l.o_cpu, 1); + axpy_cpu(l.outputs*l.batch, 1, uo.output, 1, l.o_cpu, 1); + if (l.peephole) axpy_cpu(l.outputs*l.batch, 1, vo.output, 1, l.o_cpu, 1); + + activate_array(l.f_cpu, l.outputs*l.batch, LOGISTIC); + activate_array(l.i_cpu, l.outputs*l.batch, LOGISTIC); + activate_array(l.g_cpu, l.outputs*l.batch, TANH); + activate_array(l.o_cpu, l.outputs*l.batch, LOGISTIC); + + copy_cpu(l.outputs*l.batch, l.delta, 1, l.temp3_cpu, 1); + + copy_cpu(l.outputs*l.batch, l.c_cpu, 1, l.temp_cpu, 1); + activate_array(l.temp_cpu, l.outputs*l.batch, TANH); + + copy_cpu(l.outputs*l.batch, l.temp3_cpu, 1, l.temp2_cpu, 1); + mul_cpu(l.outputs*l.batch, l.o_cpu, 1, l.temp2_cpu, 1); + + gradient_array(l.temp_cpu, l.outputs*l.batch, TANH, l.temp2_cpu); + axpy_cpu(l.outputs*l.batch, 1, l.dc_cpu, 1, l.temp2_cpu, 1); + // temp = tanh(c) + // temp2 = delta * o * grad_tanh(tanh(c)) + // temp3 = delta + + copy_cpu(l.outputs*l.batch, l.c_cpu, 1, l.temp_cpu, 1); + activate_array(l.temp_cpu, l.outputs*l.batch, TANH); + mul_cpu(l.outputs*l.batch, l.temp3_cpu, 1, l.temp_cpu, 1); + gradient_array(l.o_cpu, l.outputs*l.batch, LOGISTIC, l.temp_cpu); + // delta for o(w,u,v): temp = delta * tanh(c) * grad_logistic(o) + // delta for c,f,i,g(w,u,v): temp2 = delta * o * grad_tanh(tanh(c)) + delta_c(???) + // delta for output: temp3 = delta + + // o + // delta for O(w,u,v): temp = delta * tanh(c) * grad_logistic(o) + if (l.peephole) { + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, vo.delta, 1); + s.input = l.cell_cpu; + //s.delta = l.dc_cpu; + backward_convolutional_layer(vo, s); + } + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, wo.delta, 1); + s.input = l.prev_state_cpu; + //s.delta = l.dh_cpu; + backward_convolutional_layer(wo, s); + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, uo.delta, 1); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer(uo, s); + + // g + copy_cpu(l.outputs*l.batch, l.temp2_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.i_cpu, 1, l.temp_cpu, 1); + gradient_array(l.g_cpu, l.outputs*l.batch, TANH, l.temp_cpu); + // delta for c,f,i,g(w,u,v): temp2 = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * g * grad_logistic(i) + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, wg.delta, 1); + s.input = l.prev_state_cpu; + //s.delta = l.dh_cpu; + backward_convolutional_layer(wg, s); + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, ug.delta, 1); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer(ug, s); + + // i + copy_cpu(l.outputs*l.batch, l.temp2_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.g_cpu, 1, l.temp_cpu, 1); + gradient_array(l.i_cpu, l.outputs*l.batch, LOGISTIC, l.temp_cpu); + // delta for c,f,i,g(w,u,v): temp2 = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * g * grad_logistic(i) + + if (l.peephole) { + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, vi.delta, 1); + s.input = l.prev_cell_cpu; + //s.delta = l.dc_cpu; + backward_convolutional_layer(vi, s); + } + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, wi.delta, 1); + s.input = l.prev_state_cpu; + //s.delta = l.dh_cpu; + backward_convolutional_layer(wi, s); + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, ui.delta, 1); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer(ui, s); + + // f + copy_cpu(l.outputs*l.batch, l.temp2_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.prev_cell_cpu, 1, l.temp_cpu, 1); + gradient_array(l.f_cpu, l.outputs*l.batch, LOGISTIC, l.temp_cpu); + // delta for c,f,i,g(w,u,v): temp2 = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * c * grad_logistic(f) + + if (l.peephole) { + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, vf.delta, 1); + s.input = l.prev_cell_cpu; + //s.delta = l.dc_cpu; + backward_convolutional_layer(vf, s); + } + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, wf.delta, 1); + s.input = l.prev_state_cpu; + //s.delta = l.dh_cpu; + backward_convolutional_layer(wf, s); + + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, uf.delta, 1); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer(uf, s); + + copy_cpu(l.outputs*l.batch, l.temp2_cpu, 1, l.temp_cpu, 1); + mul_cpu(l.outputs*l.batch, l.f_cpu, 1, l.temp_cpu, 1); + copy_cpu(l.outputs*l.batch, l.temp_cpu, 1, l.dc_cpu, 1); + + state.input -= l.inputs*l.batch; + if (state.delta) state.delta -= l.inputs*l.batch; + l.output -= l.outputs*l.batch; + l.cell_cpu -= l.outputs*l.batch; + l.delta -= l.outputs*l.batch; + + if (l.peephole) { + increment_layer(&vf, -1); + increment_layer(&vi, -1); + increment_layer(&vo, -1); + } + + increment_layer(&wf, -1); + increment_layer(&wi, -1); + increment_layer(&wg, -1); + increment_layer(&wo, -1); + + increment_layer(&uf, -1); + increment_layer(&ui, -1); + increment_layer(&ug, -1); + increment_layer(&uo, -1); + } +} + +#ifdef GPU +void pull_conv_lstm_layer(layer l) +{ + if (l.peephole) { + pull_convolutional_layer(*(l.vf)); + pull_convolutional_layer(*(l.vi)); + pull_convolutional_layer(*(l.vo)); + } + pull_convolutional_layer(*(l.wf)); + pull_convolutional_layer(*(l.wi)); + pull_convolutional_layer(*(l.wg)); + pull_convolutional_layer(*(l.wo)); + pull_convolutional_layer(*(l.uf)); + pull_convolutional_layer(*(l.ui)); + pull_convolutional_layer(*(l.ug)); + pull_convolutional_layer(*(l.uo)); +} + +void push_conv_lstm_layer(layer l) +{ + if (l.peephole) { + push_convolutional_layer(*(l.vf)); + push_convolutional_layer(*(l.vi)); + push_convolutional_layer(*(l.vo)); + } + push_convolutional_layer(*(l.wf)); + push_convolutional_layer(*(l.wi)); + push_convolutional_layer(*(l.wg)); + push_convolutional_layer(*(l.wo)); + push_convolutional_layer(*(l.uf)); + push_convolutional_layer(*(l.ui)); + push_convolutional_layer(*(l.ug)); + push_convolutional_layer(*(l.uo)); +} + +void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay) +{ + if (l.peephole) { + update_convolutional_layer_gpu(*(l.vf), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.vi), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.vo), batch, learning_rate, momentum, decay); + } + update_convolutional_layer_gpu(*(l.wf), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.wi), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.wg), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.wo), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.uf), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.ui), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.ug), batch, learning_rate, momentum, decay); + update_convolutional_layer_gpu(*(l.uo), batch, learning_rate, momentum, decay); +} + +void forward_conv_lstm_layer_gpu(layer l, network_state state) +{ + network_state s = { 0 }; + s.train = state.train; + s.workspace = state.workspace; + s.net = state.net; + if (!state.train) s.index = state.index; // don't use TC for training (especially without cuda_convert_f32_to_f16() ) + int i; + layer vf = *(l.vf); + layer vi = *(l.vi); + layer vo = *(l.vo); + + layer wf = *(l.wf); + layer wi = *(l.wi); + layer wg = *(l.wg); + layer wo = *(l.wo); + + layer uf = *(l.uf); + layer ui = *(l.ui); + layer ug = *(l.ug); + layer uo = *(l.uo); + + if (state.train) { + if (l.peephole) { + fill_ongpu(l.outputs * l.batch * l.steps, 0, vf.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, vi.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, vo.delta_gpu, 1); + } + + fill_ongpu(l.outputs * l.batch * l.steps, 0, wf.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, wi.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, wg.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, wo.delta_gpu, 1); + + fill_ongpu(l.outputs * l.batch * l.steps, 0, uf.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, ui.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, ug.delta_gpu, 1); + fill_ongpu(l.outputs * l.batch * l.steps, 0, uo.delta_gpu, 1); + + fill_ongpu(l.outputs * l.batch * l.steps, 0, l.delta_gpu, 1); + } + + for (i = 0; i < l.steps; ++i) + { + if (l.peephole) { + assert(l.outputs == vf.out_w * vf.out_h * vf.out_c); + s.input = l.c_gpu; + forward_convolutional_layer_gpu(vf, s); + forward_convolutional_layer_gpu(vi, s); + // vo below + } + + assert(l.outputs == wf.out_w * wf.out_h * wf.out_c); + assert(wf.c == l.out_c && wi.c == l.out_c && wg.c == l.out_c && wo.c == l.out_c); + + s.input = l.h_gpu; + forward_convolutional_layer_gpu(wf, s); + forward_convolutional_layer_gpu(wi, s); + forward_convolutional_layer_gpu(wg, s); + forward_convolutional_layer_gpu(wo, s); + + assert(l.inputs == uf.w * uf.h * uf.c); + assert(uf.c == l.c && ui.c == l.c && ug.c == l.c && uo.c == l.c); + + s.input = state.input; + forward_convolutional_layer_gpu(uf, s); + forward_convolutional_layer_gpu(ui, s); + forward_convolutional_layer_gpu(ug, s); + forward_convolutional_layer_gpu(uo, s); + + // f = wf + uf + vf + add_3_arrays_activate(wf.output_gpu, uf.output_gpu, (l.peephole)?vf.output_gpu:NULL, l.outputs*l.batch, LOGISTIC, l.f_gpu); + //copy_ongpu(l.outputs*l.batch, wf.output_gpu, 1, l.f_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, uf.output_gpu, 1, l.f_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vf.output_gpu, 1, l.f_gpu, 1); + //activate_array_ongpu(l.f_gpu, l.outputs*l.batch, LOGISTIC); + + // i = wi + ui + vi + add_3_arrays_activate(wi.output_gpu, ui.output_gpu, (l.peephole) ? vi.output_gpu : NULL, l.outputs*l.batch, LOGISTIC, l.i_gpu); + //copy_ongpu(l.outputs*l.batch, wi.output_gpu, 1, l.i_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, ui.output_gpu, 1, l.i_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vi.output_gpu, 1, l.i_gpu, 1); + //activate_array_ongpu(l.i_gpu, l.outputs*l.batch, LOGISTIC); + + // g = wg + ug + add_3_arrays_activate(wg.output_gpu, ug.output_gpu, NULL, l.outputs*l.batch, TANH, l.g_gpu); + //copy_ongpu(l.outputs*l.batch, wg.output_gpu, 1, l.g_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, ug.output_gpu, 1, l.g_gpu, 1); + //activate_array_ongpu(l.g_gpu, l.outputs*l.batch, TANH); + + // c = f*c + i*g + sum_of_mults(l.f_gpu, l.c_gpu, l.i_gpu, l.g_gpu, l.outputs*l.batch, l.c_gpu); // decreases mAP??? + //copy_ongpu(l.outputs*l.batch, l.i_gpu, 1, l.temp_gpu, 1); + //mul_ongpu(l.outputs*l.batch, l.g_gpu, 1, l.temp_gpu, 1); + //mul_ongpu(l.outputs*l.batch, l.f_gpu, 1, l.c_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, l.temp_gpu, 1, l.c_gpu, 1); + + // o = wo + uo + vo(c_new) + if (l.peephole) { + s.input = l.c_gpu; + forward_convolutional_layer_gpu(vo, s); + } + add_3_arrays_activate(wo.output_gpu, uo.output_gpu, (l.peephole) ? vo.output_gpu : NULL, l.outputs*l.batch, LOGISTIC, l.o_gpu); + //copy_ongpu(l.outputs*l.batch, wo.output_gpu, 1, l.o_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, uo.output_gpu, 1, l.o_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vo.output_gpu, 1, l.o_gpu, 1); + //activate_array_ongpu(l.o_gpu, l.outputs*l.batch, LOGISTIC); + + // h = o * tanh(c) + activate_and_mult(l.c_gpu, l.o_gpu, l.outputs*l.batch, TANH, l.h_gpu); + //simple_copy_ongpu(l.outputs*l.batch, l.c_gpu, l.h_gpu); + //activate_array_ongpu(l.h_gpu, l.outputs*l.batch, TANH); + //mul_ongpu(l.outputs*l.batch, l.o_gpu, 1, l.h_gpu, 1); + + fix_nan_and_inf(l.c_gpu, l.outputs*l.batch); + fix_nan_and_inf(l.h_gpu, l.outputs*l.batch); + if (l.state_constrain) constrain_ongpu(l.outputs*l.batch, l.state_constrain, l.c_gpu, 1); + + if(state.train) simple_copy_ongpu(l.outputs*l.batch, l.c_gpu, l.cell_gpu); + simple_copy_ongpu(l.outputs*l.batch, l.h_gpu, l.output_gpu); // is required for both Detection and Training + + state.input += l.inputs*l.batch; + l.output_gpu += l.outputs*l.batch; + l.cell_gpu += l.outputs*l.batch; + + if (l.peephole) { + increment_layer(&vf, 1); + increment_layer(&vi, 1); + increment_layer(&vo, 1); + } + + increment_layer(&wf, 1); + increment_layer(&wi, 1); + increment_layer(&wg, 1); + increment_layer(&wo, 1); + + increment_layer(&uf, 1); + increment_layer(&ui, 1); + increment_layer(&ug, 1); + increment_layer(&uo, 1); + } +} + +void backward_conv_lstm_layer_gpu(layer l, network_state state) +{ + float *last_output = l.output_gpu + l.outputs*l.batch*(l.steps - 1); + float *last_cell = l.cell_gpu + l.outputs*l.batch*(l.steps - 1); + + network_state s = { 0 }; + s.train = state.train; + s.workspace = state.workspace; + s.net = state.net; + int i; + layer vf = *(l.vf); + layer vi = *(l.vi); + layer vo = *(l.vo); + + layer wf = *(l.wf); + layer wi = *(l.wi); + layer wg = *(l.wg); + layer wo = *(l.wo); + + layer uf = *(l.uf); + layer ui = *(l.ui); + layer ug = *(l.ug); + layer uo = *(l.uo); + + if (l.peephole) { + increment_layer(&vf, l.steps - 1); + increment_layer(&vi, l.steps - 1); + increment_layer(&vo, l.steps - 1); + } + + increment_layer(&wf, l.steps - 1); + increment_layer(&wi, l.steps - 1); + increment_layer(&wg, l.steps - 1); + increment_layer(&wo, l.steps - 1); + + increment_layer(&uf, l.steps - 1); + increment_layer(&ui, l.steps - 1); + increment_layer(&ug, l.steps - 1); + increment_layer(&uo, l.steps - 1); + + state.input += l.inputs*l.batch*(l.steps - 1); + if (state.delta) state.delta += l.inputs*l.batch*(l.steps - 1); + + l.output_gpu += l.outputs*l.batch*(l.steps - 1); + l.cell_gpu += l.outputs*l.batch*(l.steps - 1); + l.delta_gpu += l.outputs*l.batch*(l.steps - 1); + + //fill_ongpu(l.outputs * l.batch, 0, l.dc_gpu, 1); // dont use + const int sequence = get_sequence_value(state.net); + + for (i = l.steps - 1; i >= 0; --i) { + if (i != 0) simple_copy_ongpu(l.outputs*l.batch, l.cell_gpu - l.outputs*l.batch, l.prev_cell_gpu); + //else fill_ongpu(l.outputs * l.batch, 0, l.prev_cell_gpu, 1); // dont use + else if (state.net.current_subdivision % sequence != 0) simple_copy_ongpu(l.outputs*l.batch, l.last_prev_cell_gpu, l.prev_cell_gpu); + + simple_copy_ongpu(l.outputs*l.batch, l.cell_gpu, l.c_gpu); + + if (i != 0) simple_copy_ongpu(l.outputs*l.batch, l.output_gpu - l.outputs*l.batch, l.prev_state_gpu); + //else fill_ongpu(l.outputs * l.batch, 0, l.prev_state_gpu, 1); // dont use + else if(state.net.current_subdivision % sequence != 0) simple_copy_ongpu(l.outputs*l.batch, l.last_prev_state_gpu, l.prev_state_gpu); + + simple_copy_ongpu(l.outputs*l.batch, l.output_gpu, l.h_gpu); + + l.dh_gpu = (i == 0) ? 0 : l.delta_gpu - l.outputs*l.batch; + + // f = wf + uf + vf + add_3_arrays_activate(wf.output_gpu, uf.output_gpu, (l.peephole) ? vf.output_gpu : NULL, l.outputs*l.batch, LOGISTIC, l.f_gpu); + //copy_ongpu(l.outputs*l.batch, wf.output_gpu, 1, l.f_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, uf.output_gpu, 1, l.f_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vf.output_gpu, 1, l.f_gpu, 1); + //activate_array_ongpu(l.f_gpu, l.outputs*l.batch, LOGISTIC); + + // i = wi + ui + vi + add_3_arrays_activate(wi.output_gpu, ui.output_gpu, (l.peephole) ? vi.output_gpu : NULL, l.outputs*l.batch, LOGISTIC, l.i_gpu); + //copy_ongpu(l.outputs*l.batch, wi.output_gpu, 1, l.i_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, ui.output_gpu, 1, l.i_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vi.output_gpu, 1, l.i_gpu, 1); + //activate_array_ongpu(l.i_gpu, l.outputs*l.batch, LOGISTIC); + + // g = wg + ug + add_3_arrays_activate(wg.output_gpu, ug.output_gpu, NULL, l.outputs*l.batch, TANH, l.g_gpu); + //copy_ongpu(l.outputs*l.batch, wg.output_gpu, 1, l.g_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, ug.output_gpu, 1, l.g_gpu, 1); + //activate_array_ongpu(l.g_gpu, l.outputs*l.batch, TANH); + + // o = wo + uo + vo + add_3_arrays_activate(wo.output_gpu, uo.output_gpu, (l.peephole) ? vo.output_gpu : NULL, l.outputs*l.batch, LOGISTIC, l.o_gpu); + //copy_ongpu(l.outputs*l.batch, wo.output_gpu, 1, l.o_gpu, 1); + //axpy_ongpu(l.outputs*l.batch, 1, uo.output_gpu, 1, l.o_gpu, 1); + //if (l.peephole) axpy_ongpu(l.outputs*l.batch, 1, vo.output_gpu, 1, l.o_gpu, 1); + //activate_array_ongpu(l.o_gpu, l.outputs*l.batch, LOGISTIC); + + + simple_copy_ongpu(l.outputs*l.batch, l.delta_gpu, l.temp3_gpu); // temp3 = delta + + simple_copy_ongpu(l.outputs*l.batch, l.c_gpu, l.temp_gpu); + activate_array_ongpu(l.temp_gpu, l.outputs*l.batch, TANH); // temp = tanh(c) + + simple_copy_ongpu(l.outputs*l.batch, l.temp3_gpu, l.temp2_gpu); + mul_ongpu(l.outputs*l.batch, l.o_gpu, 1, l.temp2_gpu, 1); // temp2 = delta * o + + gradient_array_ongpu(l.temp_gpu, l.outputs*l.batch, TANH, l.temp2_gpu); // temp2 = delta * o * grad_tanh(tanh(c)) + //??? + axpy_ongpu(l.outputs*l.batch, 1, l.dc_gpu, 1, l.temp2_gpu, 1); // temp2 = delta * o * grad_tanh(tanh(c)) + delta_c(???) + // temp = tanh(c) + // temp2 = delta * o * grad_tanh(tanh(c)) + delta_c(???) + // temp3 = delta + + simple_copy_ongpu(l.outputs*l.batch, l.c_gpu, l.temp_gpu); + activate_array_ongpu(l.temp_gpu, l.outputs*l.batch, TANH); // temp = tanh(c) + + mul_ongpu(l.outputs*l.batch, l.temp3_gpu, 1, l.temp_gpu, 1); // temp = delta * tanh(c) + gradient_array_ongpu(l.o_gpu, l.outputs*l.batch, LOGISTIC, l.temp_gpu); // temp = delta * tanh(c) * grad_logistic(o) + // delta for o(w,u,v): temp = delta * tanh(c) * grad_logistic(o) + // delta for c,f,i,g(w,u,v): temp2 = delta * o * grad_tanh(tanh(c)) + delta_c(???) + // delta for output: temp3 = delta + + // o + // delta for O(w,u,v): temp = delta * tanh(c) * grad_logistic(o) + if (l.peephole) { + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, vo.delta_gpu); + s.input = l.cell_gpu; + //s.delta = l.dc_gpu; + backward_convolutional_layer_gpu(vo, s); + } + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, wo.delta_gpu); + s.input = l.prev_state_gpu; + //s.delta = l.dh_gpu; + backward_convolutional_layer_gpu(wo, s); + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, uo.delta_gpu); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer_gpu(uo, s); + + // g + simple_copy_ongpu(l.outputs*l.batch, l.temp2_gpu, l.temp_gpu); + mul_ongpu(l.outputs*l.batch, l.i_gpu, 1, l.temp_gpu, 1); + gradient_array_ongpu(l.g_gpu, l.outputs*l.batch, TANH, l.temp_gpu); + // delta for c,f,i,g(w,u,v): temp = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * i * grad_tanh(g) + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, wg.delta_gpu); + s.input = l.prev_state_gpu; + //s.delta = l.dh_gpu; + backward_convolutional_layer_gpu(wg, s); + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, ug.delta_gpu); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer_gpu(ug, s); + + // i + simple_copy_ongpu(l.outputs*l.batch, l.temp2_gpu, l.temp_gpu); + mul_ongpu(l.outputs*l.batch, l.g_gpu, 1, l.temp_gpu, 1); + gradient_array_ongpu(l.i_gpu, l.outputs*l.batch, LOGISTIC, l.temp_gpu); + // delta for c,f,i,g(w,u,v): temp = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * g * grad_logistic(i) + + if (l.peephole) { + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, vi.delta_gpu); + s.input = l.prev_cell_gpu; + //s.delta = l.dc_gpu; + backward_convolutional_layer_gpu(vi, s); + } + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, wi.delta_gpu); + s.input = l.prev_state_gpu; + //s.delta = l.dh_gpu; + backward_convolutional_layer_gpu(wi, s); + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, ui.delta_gpu); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer_gpu(ui, s); + + // f + simple_copy_ongpu(l.outputs*l.batch, l.temp2_gpu, l.temp_gpu); + mul_ongpu(l.outputs*l.batch, l.prev_cell_gpu, 1, l.temp_gpu, 1); + gradient_array_ongpu(l.f_gpu, l.outputs*l.batch, LOGISTIC, l.temp_gpu); + // delta for c,f,i,g(w,u,v): temp = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * c * grad_logistic(f) + + if (l.peephole) { + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, vf.delta_gpu); + s.input = l.prev_cell_gpu; + //s.delta = l.dc_gpu; + backward_convolutional_layer_gpu(vf, s); + } + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, wf.delta_gpu); + s.input = l.prev_state_gpu; + //s.delta = l.dh_gpu; + backward_convolutional_layer_gpu(wf, s); + + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, uf.delta_gpu); + s.input = state.input; + s.delta = state.delta; + backward_convolutional_layer_gpu(uf, s); + + // c + simple_copy_ongpu(l.outputs*l.batch, l.temp2_gpu, l.temp_gpu); + mul_ongpu(l.outputs*l.batch, l.f_gpu, 1, l.temp_gpu, 1); + simple_copy_ongpu(l.outputs*l.batch, l.temp_gpu, l.dc_gpu); + fix_nan_and_inf(l.dc_gpu, l.outputs*l.batch); + // delta for c,f,i,g(w,u,v): delta_c = temp = (delta * o * grad_tanh(tanh(c)) + delta_c(???)) * f // (grad_linear(c)==1) + + state.input -= l.inputs*l.batch; + if (state.delta) state.delta -= l.inputs*l.batch; // new delta: state.delta = prev_layer.delta_gpu; + l.output_gpu -= l.outputs*l.batch; + l.cell_gpu -= l.outputs*l.batch; + l.delta_gpu -= l.outputs*l.batch; + + if (l.peephole) { + increment_layer(&vf, -1); + increment_layer(&vi, -1); + increment_layer(&vo, -1); + } + + increment_layer(&wf, -1); + increment_layer(&wi, -1); + increment_layer(&wg, -1); + increment_layer(&wo, -1); + + increment_layer(&uf, -1); + increment_layer(&ui, -1); + increment_layer(&ug, -1); + increment_layer(&uo, -1); + } + + simple_copy_ongpu(l.outputs*l.batch, last_output, l.last_prev_state_gpu); + simple_copy_ongpu(l.outputs*l.batch, last_cell, l.last_prev_cell_gpu); + + // free state after each 100 iterations + //if (get_current_batch(state.net) % 100) free_state_conv_lstm(l); // dont use +} +#endif diff --git a/src/conv_lstm_layer.h b/src/conv_lstm_layer.h new file mode 100644 index 00000000000..56a57298243 --- /dev/null +++ b/src/conv_lstm_layer.h @@ -0,0 +1,33 @@ +#ifndef CONV_LSTM_LAYER_H +#define CONV_LSTM_LAYER_H + +#include "activations.h" +#include "layer.h" +#include "network.h" +#define USET + +#ifdef __cplusplus +extern "C" { +#endif +layer make_conv_lstm_layer(int batch, int h, int w, int c, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int peephole, int xnor); +void resize_conv_lstm_layer(layer *l, int w, int h); +void free_state_conv_lstm(layer l); +void randomize_state_conv_lstm(layer l); +void remember_state_conv_lstm(layer l); +void restore_state_conv_lstm(layer l); + +void forward_conv_lstm_layer(layer l, network_state state); +void backward_conv_lstm_layer(layer l, network_state state); +void update_conv_lstm_layer(layer l, int batch, float learning_rate, float momentum, float decay); + +#ifdef GPU +void forward_conv_lstm_layer_gpu(layer l, network_state state); +void backward_conv_lstm_layer_gpu(layer l, network_state state); +void update_conv_lstm_layer_gpu(layer l, int batch, float learning_rate, float momentum, float decay); +#endif + +#ifdef __cplusplus +} +#endif + +#endif // CONV_LSTM_LAYER_H diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 892ccc93443..a26b95eaf0d 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -166,20 +166,16 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) { //fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); if(l.binary){ - binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu); + binarize_weights_gpu(l.weights_gpu, l.n, (l.c / l.groups)*l.size*l.size, l.binary_weights_gpu); swap_binary(&l); } if(l.xnor){ if (!l.align_bit_weights_gpu || state.train) { - //binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu); + //binarize_weights_gpu(l.weights_gpu, l.n, (l.c / l.groups)*l.size*l.size, l.binary_weights_gpu); - fast_binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu, l.mean_arr_gpu); + fast_binarize_weights_gpu(l.weights_gpu, l.n, (l.c / l.groups)*l.size*l.size, l.binary_weights_gpu, l.mean_arr_gpu); } - //swap_binary(&l); - //binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu); - //state.input = l.binary_input_gpu; - //cudaDeviceSynchronize(); if (l.align_bit_weights_gpu && !state.train && l.c >= 32) { @@ -187,11 +183,15 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) //cudaError_t status = cudaSuccess; //int input_size = l.c*l.h*l.w*l.batch; - int m = l.n; - int k = l.size*l.size*l.c; + int m = l.n / l.groups; + int k = l.size*l.size*l.c / l.groups; int n = l.out_w*l.out_h; //float * a = l.weights_gpu; + // int i, j; + // for(i = 0; i < l.batch; ++i){ + // for (j = 0; j < l.groups; ++j) { + int ldb_align = l.lda_align; size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; //size_t t_intput_size = new_ldb * n; @@ -484,14 +484,14 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) l.normDstTensorDescF16, output16, // output l.normTensorDesc, - l.scales_gpu, - l.biases_gpu, + l.scales_gpu, // input + l.biases_gpu, // input .01, - l.rolling_mean_gpu, // output (should be FP32) - l.rolling_variance_gpu, // output (should be FP32) + l.rolling_mean_gpu, // input/output (should be FP32) + l.rolling_variance_gpu, // input/output (should be FP32) .00001, - l.mean_gpu, // output (should be FP32) - l.variance_gpu)); // output (should be FP32) + l.mean_gpu, // output (should be FP32) - optional cache to speedup cudnnBatchNormalizationBackward() + l.variance_gpu)); // output (should be FP32) - optional cache to speedup cudnnBatchNormalizationBackward() cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu); //forward_batchnorm_layer_gpu(l, state); @@ -551,22 +551,25 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) #else fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1); - int i; - int m = l.n; - int k = l.size*l.size*l.c; + int i, j; + int m = l.n / l.groups; + int k = l.size*l.size*l.c / l.groups; int n = l.out_w*l.out_h; for(i = 0; i < l.batch; ++i){ - float *im = state.input + i*l.c*l.h*l.w; - float * a = l.weights_gpu; - float * b = state.workspace; - float * c = l.output_gpu; - if (l.size == 1) { - b = im; - } - else { - im2col_ongpu(im, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace); + for (j = 0; j < l.groups; ++j) { + //float *im = state.input + i*l.c*l.h*l.w; + float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w; + float *a = l.weights_gpu + j*l.nweights / l.groups; + float *b = state.workspace; + float *c = l.output_gpu + (i*l.groups + j)*n*m; + if (l.size == 1) { + b = im; + } + else { + im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace); + } + gemm_ongpu(0, 0, m, n, k, 1., a, k, b, n, 1., c + i*m*n, n); } - gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n); } if (l.batch_normalize) { @@ -655,13 +658,13 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state &one, &one, l.normDstTensorDescF16, - l.x_gpu, // input + l.x_gpu, // input (input in BN-forward-inference) l.normDstTensorDescF16, delta16, // input l.normDstTensorDescF16, - l.x_norm_gpu, // output + l.x_norm_gpu, // output (new delta) l.normTensorDesc, - l.scales_gpu, // output (should be FP32) + l.scales_gpu, // input (should be FP32) l.scale_updates_gpu, // output (should be FP32) l.bias_updates_gpu, // output (should be FP32) .00001, @@ -782,32 +785,38 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state backward_batchnorm_layer_gpu(l, state); } - int m = l.n; - int n = l.size*l.size*l.c; + int m = l.n / l.groups; + int n = l.size*l.size*l.c / l.groups; int k = l.out_w*l.out_h; - int i; + int i, j; for(i = 0; i < l.batch; ++i){ - float * a = l.delta_gpu; - float * b = state.workspace; - float * c = l.weight_updates_gpu; + for (j = 0; j < l.groups; ++j) { + float * a = l.delta_gpu + (i*l.groups + j)*m*k; + float * b = state.workspace; + float * c = l.weight_updates_gpu + j*l.nweights / l.groups; + + float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w; - im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace); - gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n); + im2col_ongpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, state.workspace); + gemm_ongpu(0, 1, m, n, k, 1, a + i*m*k, k, b, k, 1, c, n); - if(state.delta){ - if(l.binary || l.xnor) swap_binary(&l); - float * a = l.weights_gpu; - float * b = l.delta_gpu; - float * c = state.workspace; + if (state.delta) { + if (l.binary || l.xnor) swap_binary(&l); + float * a = l.weights_gpu + j*l.nweights / l.groups; + float * b = l.delta_gpu + (i*l.groups + j)*m*k; + float * c = state.workspace; - gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k); + gemm_ongpu(1, 0, n, k, m, 1, a, n, b + i*k*m, k, 0, c, k); - col2im_ongpu(state.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w); - if(l.binary || l.xnor) { - swap_binary(&l); + float *delta = state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w; + + col2im_ongpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, delta); + if (l.binary || l.xnor) { + swap_binary(&l); + } + if (l.xnor) gradient_array_ongpu(original_input + i*l.c*l.h*l.w, l.c*l.h*l.w, HARDTAN, state.delta + i*l.c*l.h*l.w); } - if(l.xnor) gradient_array_ongpu(original_input + i*l.c*l.h*l.w, l.c*l.h*l.w, HARDTAN, state.delta + i*l.c*l.h*l.w); } } #endif @@ -821,43 +830,43 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state } } -void pull_convolutional_layer(convolutional_layer layer) +void pull_convolutional_layer(convolutional_layer l) { - cuda_pull_array_async(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array_async(layer.biases_gpu, layer.biases, layer.n); - cuda_pull_array_async(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array_async(layer.bias_updates_gpu, layer.bias_updates, layer.n); - if (layer.batch_normalize){ - cuda_pull_array_async(layer.scales_gpu, layer.scales, layer.n); - cuda_pull_array_async(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); - cuda_pull_array_async(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights); + cuda_pull_array_async(l.biases_gpu, l.biases, l.n); + cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights); + cuda_pull_array_async(l.bias_updates_gpu, l.bias_updates, l.n); + if (l.batch_normalize){ + cuda_pull_array_async(l.scales_gpu, l.scales, l.n); + cuda_pull_array_async(l.rolling_mean_gpu, l.rolling_mean, l.n); + cuda_pull_array_async(l.rolling_variance_gpu, l.rolling_variance, l.n); } - if (layer.adam){ - cuda_pull_array_async(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size); - cuda_pull_array_async(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size); + if (l.adam){ + cuda_pull_array_async(l.m_gpu, l.m, l.nweights); + cuda_pull_array_async(l.v_gpu, l.v, l.nweights); } CHECK_CUDA(cudaPeekAtLastError()); cudaStreamSynchronize(get_cuda_stream()); } -void push_convolutional_layer(convolutional_layer layer) +void push_convolutional_layer(convolutional_layer l) { - cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size); + cuda_push_array(l.weights_gpu, l.weights, l.nweights); #ifdef CUDNN_HALF - assert((layer.c*layer.n*layer.size*layer.size) > 0); - cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, layer.weights_gpu16); + assert(l.nweights > 0); + cuda_convert_f32_to_f16(l.weights_gpu, l.nweights, l.weights_gpu16); #endif - cuda_push_array(layer.biases_gpu, layer.biases, layer.n); - cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size); - cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n); - if (layer.batch_normalize){ - cuda_push_array(layer.scales_gpu, layer.scales, layer.n); - cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n); - cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n); + cuda_push_array(l.biases_gpu, l.biases, l.n); + cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights); + cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n); + if (l.batch_normalize){ + cuda_push_array(l.scales_gpu, l.scales, l.n); + cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.n); + cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.n); } - if (layer.adam){ - cuda_push_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size); - cuda_push_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size); + if (l.adam){ + cuda_push_array(l.m_gpu, l.m, l.nweights); + cuda_push_array(l.v_gpu, l.v, l.nweights); } CHECK_CUDA(cudaPeekAtLastError()); } @@ -868,11 +877,10 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init //float momentum = a.momentum; //float decay = a.decay; //int batch = a.batch; - int size = l.size*l.size*l.c*l.n; // old if (l.adam) { //adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.nweights, batch, a.t); - adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, size, batch, l.t); + adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.nweights, batch, l.t); adam_update_gpu(l.biases_gpu, l.bias_updates_gpu, l.bias_m_gpu, l.bias_v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.n, batch, l.t); if (l.scales_gpu) { @@ -883,9 +891,9 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init //axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); //axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); //scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1); - axpy_ongpu(size, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); - axpy_ongpu(size, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); - scal_ongpu(size, momentum, l.weight_updates_gpu, 1); + axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1); + axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1); + scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1); axpy_ongpu(l.n, learning_rate / batch, l.bias_updates_gpu, 1, l.biases_gpu, 1); scal_ongpu(l.n, momentum, l.bias_updates_gpu, 1); diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c index 84d36d93ef0..d983ab61725 100644 --- a/src/convolutional_layer.c +++ b/src/convolutional_layer.c @@ -140,7 +140,7 @@ size_t get_workspace_size32(layer l){ if (workspace_size < re_packed_input_size) workspace_size = re_packed_input_size; return workspace_size; } - return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float); + return (size_t)l.out_h*l.out_w*l.size*l.size*(l.c / l.groups)*sizeof(float); } size_t get_workspace_size16(layer l) { @@ -231,9 +231,14 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) // 3. FP32 Master Copy of Weights // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH)); + CHECK_CUDNN(cudnnSetConvolutionGroupCount(l->convDesc, l->groups)); #if((CUDNN_MAJOR*10 + CUDNN_MINOR) >= 72) // cuDNN >= 7.2 CHECK_CUDNN(cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); #endif +#else //if(CUDNN_MAJOR >= 7) + if (l->groups > 1) { + error("CUDNN < 7 doesn't support groups, please upgrade!"); + } #endif // INT8_CONFIG, INT8_EXT_CONFIG, INT8x4_CONFIG and INT8x4_EXT_CONFIG are only supported @@ -243,23 +248,23 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) // backward delta CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w)); CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w)); - CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c / l->groups, l->size, l->size)); // forward CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w)); CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w)); - CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c / l->groups, l->size, l->size)); //#ifdef CUDNN_HALF // backward delta CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dsrcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w)); CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->ddstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); - CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->dweightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c / l->groups, l->size, l->size)); // forward CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->srcTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->c, l->h, l->w)); CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->dstTensorDesc16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); - CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size)); + CHECK_CUDNN(cudnnSetFilter4dDescriptor(l->weightDesc16, CUDNN_DATA_HALF, CUDNN_TENSOR_NCHW, l->n, l->c / l->groups, l->size, l->size)); // batch norm CHECK_CUDNN(cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, CUDNN_DATA_HALF, l->batch, l->out_c, l->out_h, l->out_w)); @@ -326,17 +331,21 @@ void cudnn_convolutional_setup(layer *l, int cudnn_preference) #endif #endif -convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index) +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index) { int total_batch = batch*steps; int i; convolutional_layer l = { (LAYER_TYPE)0 }; l.type = CONVOLUTIONAL; + if (xnor) groups = 1; // disable groups for XNOR-net + if (groups < 1) groups = 1; + l.index = index; l.h = h; l.w = w; l.c = c; + l.groups = groups; l.n = n; l.binary = binary; l.xnor = xnor; @@ -348,17 +357,17 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, l.pad = padding; l.batch_normalize = batch_normalize; l.learning_rate_scale = 1; - l.nweights = l.c*l.n*l.size*l.size; + l.nweights = (c / groups) * n * size * size; - l.weights = (float*)calloc(c * n * size * size, sizeof(float)); - l.weight_updates = (float*)calloc(c * n * size * size, sizeof(float)); + l.weights = (float*)calloc(l.nweights, sizeof(float)); + l.weight_updates = (float*)calloc(l.nweights, sizeof(float)); l.biases = (float*)calloc(n, sizeof(float)); l.bias_updates = (float*)calloc(n, sizeof(float)); // float scale = 1./sqrt(size*size*c); - float scale = sqrt(2./(size*size*c)); - for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1); + float scale = sqrt(2./(size*size*c/groups)); + for(i = 0; i < l.nweights; ++i) l.weights[i] = scale*rand_uniform(-1, 1); // rand_normal(); int out_h = convolutional_out_height(l); int out_w = convolutional_out_width(l); l.out_h = out_h; @@ -375,12 +384,12 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, l.backward = backward_convolutional_layer; l.update = update_convolutional_layer; if(binary){ - l.binary_weights = (float*)calloc(c * n * size * size, sizeof(float)); - l.cweights = (char*)calloc(c * n * size * size, sizeof(char)); + l.binary_weights = (float*)calloc(l.nweights, sizeof(float)); + l.cweights = (char*)calloc(l.nweights, sizeof(char)); l.scales = (float*)calloc(n, sizeof(float)); } if(xnor){ - l.binary_weights = (float*)calloc(c * n * size * size, sizeof(float)); + l.binary_weights = (float*)calloc(l.nweights, sizeof(float)); l.binary_input = (float*)calloc(l.inputs * l.batch, sizeof(float)); int align = 32;// 8; @@ -420,8 +429,8 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, } if(adam){ l.adam = 1; - l.m = (float*)calloc(c * n * size * size, sizeof(float)); - l.v = (float*)calloc(c * n * size * size, sizeof(float)); + l.m = (float*)calloc(l.nweights, sizeof(float)); + l.v = (float*)calloc(l.nweights, sizeof(float)); l.bias_m = (float*)calloc(n, sizeof(float)); l.scale_m = (float*)calloc(n, sizeof(float)); l.bias_v = (float*)calloc(n, sizeof(float)); @@ -435,19 +444,19 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, if(gpu_index >= 0){ if (adam) { - l.m_gpu = cuda_make_array(l.m, c*n*size*size); - l.v_gpu = cuda_make_array(l.v, c*n*size*size); + l.m_gpu = cuda_make_array(l.m, l.nweights); + l.v_gpu = cuda_make_array(l.v, l.nweights); l.bias_m_gpu = cuda_make_array(l.bias_m, n); l.bias_v_gpu = cuda_make_array(l.bias_v, n); l.scale_m_gpu = cuda_make_array(l.scale_m, n); l.scale_v_gpu = cuda_make_array(l.scale_v, n); } - l.weights_gpu = cuda_make_array(l.weights, c*n*size*size); - l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size); + l.weights_gpu = cuda_make_array(l.weights, l.nweights); + l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights); #ifdef CUDNN_HALF - l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weights, c*n*size*size / 2); - l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weight_updates, c*n*size*size / 2); + l.weights_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); + l.weight_updates_gpu16 = cuda_make_array(NULL, l.nweights / 2 + 1); #endif l.biases_gpu = cuda_make_array(l.biases, n); @@ -457,10 +466,10 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, l.delta_gpu = cuda_make_array(l.delta, total_batch*out_h*out_w*n); if(binary){ - l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size); + l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights); } if(xnor){ - l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size); + l.binary_weights_gpu = cuda_make_array(l.weights, l.nweights); l.mean_arr_gpu = cuda_make_array(0, l.n); l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch); } @@ -490,7 +499,7 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, l.workspace_size = get_convolutional_workspace_size(l); //fprintf(stderr, "conv %5d %2d x%2d /%2d %4d x%4d x%4d -> %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c); - l.bflops = (2.0 * l.n * l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.; + l.bflops = (2.0 * l.nweights * l.out_h*l.out_w) / 1000000000.; if (l.xnor && l.use_bin_output) fprintf(stderr, "convXB"); else if (l.xnor) fprintf(stderr, "convX "); else fprintf(stderr, "conv "); @@ -504,8 +513,8 @@ void denormalize_convolutional_layer(convolutional_layer l) int i, j; for(i = 0; i < l.n; ++i){ float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001); - for(j = 0; j < l.c*l.size*l.size; ++j){ - l.weights[i*l.c*l.size*l.size + j] *= scale; + for(j = 0; j < l.nweights; ++j){ + l.weights[i*l.nweights + j] *= scale; } l.biases[i] -= l.rolling_mean[i] * scale; l.scales[i] = 1; @@ -516,7 +525,7 @@ void denormalize_convolutional_layer(convolutional_layer l) void test_convolutional_layer() { - convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0); + convolutional_layer l = make_convolutional_layer(1, 1, 5, 5, 3, 2, 1, 5, 2, 1, LEAKY, 1, 0, 0, 0, 0, 0); l.batch_normalize = 1; float data[] = {1,1,1,1,1, 1,1,1,1,1, @@ -691,8 +700,8 @@ void bit_to_float(unsigned char *src, float *dst, size_t size, size_t filters, f void binary_align_weights(convolutional_layer *l) { - int m = l->n; - int k = l->size*l->size*l->c; + int m = l->n; // (l->n / l->groups) + int k = l->size*l->size*l->c; // ->size*l->size*(l->c / l->groups) size_t new_lda = k + (l->lda_align - k % l->lda_align); // (k / 8 + 1) * 8; l->new_lda = new_lda; @@ -823,13 +832,13 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) { int out_h = convolutional_out_height(l); int out_w = convolutional_out_width(l); - int i; + int i, j; fill_cpu(l.outputs*l.batch, 0, l.output, 1); if (l.xnor && (!l.align_bit_weights || state.train)) { if (!l.align_bit_weights || state.train) { - binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights); + binarize_weights(l.weights, l.n, l.nweights, l.binary_weights); //printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights); } swap_binary(&l); @@ -837,147 +846,150 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) state.input = l.binary_input; } - int m = l.n; - int k = l.size*l.size*l.c; + int m = l.n / l.groups; + int k = l.size*l.size*l.c / l.groups; int n = out_h*out_w; - float *a = l.weights; - float *b = state.workspace; - float *c = l.output; - static int u = 0; u++; for(i = 0; i < l.batch; ++i){ + for (j = 0; j < l.groups; ++j) { - //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); - //gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n); - if (l.xnor && l.align_bit_weights && !state.train) - { - memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float)); + float *a = l.weights + j*l.nweights / l.groups; + float *b = state.workspace; + float *c = l.output + (i*l.groups + j)*n*m; - if(l.c % 32 == 0) + //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n); + //gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n); + if (l.xnor && l.align_bit_weights && !state.train) { - //printf(" l.index = %d - new XNOR \n", l.index); + memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float)); - int ldb_align = l.lda_align; - size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; - //size_t t_intput_size = new_ldb * l.bit_align;// n; - //size_t t_bit_input_size = t_intput_size / 8;// +1; + if (l.c % 32 == 0) + { + //printf(" l.index = %d - new XNOR \n", l.index); - int re_packed_input_size = l.c * l.w * l.h; - memset(state.workspace, 0, re_packed_input_size * sizeof(float)); + int ldb_align = l.lda_align; + size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8; + //size_t t_intput_size = new_ldb * l.bit_align;// n; + //size_t t_bit_input_size = t_intput_size / 8;// +1; - const size_t new_c = l.c / 32; - size_t in_re_packed_input_size = new_c * l.w * l.h + 1; - memset(l.bin_re_packed_input, 0, in_re_packed_input_size * sizeof(uint32_t)); + int re_packed_input_size = l.c * l.w * l.h; + memset(state.workspace, 0, re_packed_input_size * sizeof(float)); - //float *re_packed_input = calloc(l.c * l.w * l.h, sizeof(float)); - //uint32_t *bin_re_packed_input = calloc(new_c * l.w * l.h + 1, sizeof(uint32_t)); + const size_t new_c = l.c / 32; + size_t in_re_packed_input_size = new_c * l.w * l.h + 1; + memset(l.bin_re_packed_input, 0, in_re_packed_input_size * sizeof(uint32_t)); - // float32x4 by channel (as in cuDNN) - repack_input(state.input, state.workspace, l.w, l.h, l.c); + //float *re_packed_input = calloc(l.c * l.w * l.h, sizeof(float)); + //uint32_t *bin_re_packed_input = calloc(new_c * l.w * l.h + 1, sizeof(uint32_t)); - // 32 x floats -> 1 x uint32_t - float_to_bit(state.workspace, (unsigned char *)l.bin_re_packed_input, l.c * l.w * l.h); + // float32x4 by channel (as in cuDNN) + repack_input(state.input, state.workspace, l.w, l.h, l.c); - //free(re_packed_input); + // 32 x floats -> 1 x uint32_t + float_to_bit(state.workspace, (unsigned char *)l.bin_re_packed_input, l.c * l.w * l.h); - // slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN) - //convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output, - // l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr); + //free(re_packed_input); - // // then exit from if() + // slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN) + //convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output, + // l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr); + // // then exit from if() - im2col_cpu_custom((float *)l.bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, state.workspace); - //im2col_cpu((float *)bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, b); - //free(bin_re_packed_input); + im2col_cpu_custom((float *)l.bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, state.workspace); + //im2col_cpu((float *)bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, b); - int new_k = l.size*l.size*l.c / 32; + //free(bin_re_packed_input); - // good for (l.c == 64) - //gemm_nn_bin_32bit_packed(m, n, new_k, 1, - // l.align_bit_weights, l.new_lda/32, - // b, n, - // c, n, l.mean_arr); + int new_k = l.size*l.size*l.c / 32; -// // then exit from if() + // good for (l.c == 64) + //gemm_nn_bin_32bit_packed(m, n, new_k, 1, + // l.align_bit_weights, l.new_lda/32, + // b, n, + // c, n, l.mean_arr); - transpose_uint32((uint32_t *)state.workspace, (uint32_t*)l.t_bit_input, new_k, n, n, new_ldb); + // // then exit from if() - // the main GEMM function - gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr); + transpose_uint32((uint32_t *)state.workspace, (uint32_t*)l.t_bit_input, new_k, n, n, new_ldb); - // // alternative GEMM - //gemm_nn_bin_transposed_32bit_packed(m, n, new_k, 1, - // l.align_bit_weights, l.new_lda/32, - // t_bit_input, new_ldb / 32, - // c, n, l.mean_arr); + // the main GEMM function + gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr); - //free(t_bit_input); + // // alternative GEMM + //gemm_nn_bin_transposed_32bit_packed(m, n, new_k, 1, + // l.align_bit_weights, l.new_lda/32, + // t_bit_input, new_ldb / 32, + // c, n, l.mean_arr); - } - else - { // else (l.c % 32 != 0) + //free(t_bit_input); - //-------------------------------------------------------- - //printf(" l.index = %d - old XNOR \n", l.index); + } + else + { // else (l.c % 32 != 0) - //im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); - im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align); + //-------------------------------------------------------- + //printf(" l.index = %d - old XNOR \n", l.index); - //size_t output_size = l.outputs; - //float *count_output = calloc(output_size, sizeof(float)); - //size_t bit_output_size = output_size / 8 + 1; - //char *bit_output = calloc(bit_output_size, sizeof(char)); + //im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align); + im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align); - //size_t intput_size = n * k; // (out_h*out_w) X (l.size*l.size*l.c) : after im2col() - //size_t bit_input_size = intput_size / 8 + 1; - //char *bit_input = calloc(bit_input_size, sizeof(char)); + //size_t output_size = l.outputs; + //float *count_output = calloc(output_size, sizeof(float)); + //size_t bit_output_size = output_size / 8 + 1; + //char *bit_output = calloc(bit_output_size, sizeof(char)); - //size_t weights_size = k * m; //l.size*l.size*l.c*l.n; - //size_t bit_weights_size = weights_size / 8 + 1; + //size_t intput_size = n * k; // (out_h*out_w) X (l.size*l.size*l.c) : after im2col() + //size_t bit_input_size = intput_size / 8 + 1; + //char *bit_input = calloc(bit_input_size, sizeof(char)); - //char *bit_weights = calloc(bit_weights_size, sizeof(char)); - //float *mean_arr = calloc(l.n, sizeof(float)); + //size_t weights_size = k * m; //l.size*l.size*l.c*l.n; + //size_t bit_weights_size = weights_size / 8 + 1; - // transpose B from NxK to KxN (x-axis (ldb = l.size*l.size*l.c) - should be multiple of 8 bits) - { - //size_t ldb_align = 256; // 256 bit for AVX2 - int ldb_align = l.lda_align; - size_t new_ldb = k + (ldb_align - k%ldb_align); - size_t t_intput_size = binary_transpose_align_input(k, n, state.workspace, &l.t_bit_input, ldb_align, l.bit_align); + //char *bit_weights = calloc(bit_weights_size, sizeof(char)); + //float *mean_arr = calloc(l.n, sizeof(float)); - // 5x times faster than gemm()-float32 - gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr); + // transpose B from NxK to KxN (x-axis (ldb = l.size*l.size*l.c) - should be multiple of 8 bits) + { + //size_t ldb_align = 256; // 256 bit for AVX2 + int ldb_align = l.lda_align; + size_t new_ldb = k + (ldb_align - k%ldb_align); + size_t t_intput_size = binary_transpose_align_input(k, n, state.workspace, &l.t_bit_input, ldb_align, l.bit_align); - //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr); + // 5x times faster than gemm()-float32 + gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr); - //free(t_input); - //free(t_bit_input); - //} - } + //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr); - } + //free(t_input); + //free(t_bit_input); + //} + } - add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w); + } - //activate_array(l.output, m*n*l.batch, l.activation); - activate_array_cpu_custom(l.output, m*n*l.batch, l.activation); - return; + add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w); - } - else { - //printf(" l.index = %d - FP32 \n", l.index); - im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b); + //activate_array(l.output, m*n*l.batch, l.activation); + activate_array_cpu_custom(l.output, m*n*l.batch, l.activation); + return; - gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n); - // bit-count to float + } + else { + //printf(" l.index = %d - FP32 \n", l.index); + im2col_cpu(state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w, + l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b); + + gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n); + // bit-count to float + } + c += n*m; + state.input += l.c*l.h*l.w; } - c += n*m; - state.input += l.c*l.h*l.w; } if(l.batch_normalize){ @@ -986,78 +998,84 @@ void forward_convolutional_layer(convolutional_layer l, network_state state) add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w); //activate_array(l.output, m*n*l.batch, l.activation); - activate_array_cpu_custom(l.output, m*n*l.batch, l.activation); + activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation); if(l.binary || l.xnor) swap_binary(&l); } + void backward_convolutional_layer(convolutional_layer l, network_state state) { - int i; - int m = l.n; - int n = l.size*l.size*l.c; - int k = convolutional_out_height(l)* - convolutional_out_width(l); + int i, j; + int m = l.n / l.groups; + int n = l.size*l.size*l.c / l.groups; + int k = l.out_w*l.out_h; - gradient_array(l.output, m*k*l.batch, l.activation, l.delta); - backward_bias(l.bias_updates, l.delta, l.batch, l.n, k); + gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta); - if(l.batch_normalize){ + if (l.batch_normalize) { backward_batchnorm_layer(l, state); } + else { + backward_bias(l.bias_updates, l.delta, l.batch, l.n, k); + } - for(i = 0; i < l.batch; ++i){ - float *a = l.delta + i*m*k; - float *b = state.workspace; - float *c = l.weight_updates; + for (i = 0; i < l.batch; ++i) { + for (j = 0; j < l.groups; ++j) { + float *a = l.delta + (i*l.groups + j)*m*k; + float *b = state.workspace; + float *c = l.weight_updates + j*l.nweights / l.groups; - float *im = state.input+i*l.c*l.h*l.w; + float *im = state.input + (i*l.groups + j)*l.c / l.groups*l.h*l.w; - im2col_cpu(im, l.c, l.h, l.w, + im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b); - gemm(0,1,m,n,k,1,a,k,b,k,1,c,n); + gemm(0, 1, m, n, k, 1, a, k, b, k, 1, c, n); - if(state.delta){ - a = l.weights; - b = l.delta + i*m*k; - c = state.workspace; + if (state.delta) { + a = l.weights + j*l.nweights / l.groups; + b = l.delta + (i*l.groups + j)*m*k; + c = state.workspace; - gemm(1,0,n,k,m,1,a,n,b,k,0,c,k); + gemm(1, 0, n, k, m, 1, a, n, b, k, 0, c, k); - col2im_cpu(state.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w); + col2im_cpu(state.workspace, l.c / l.groups, l.h, l.w, l.size, l.stride, + l.pad, state.delta + (i*l.groups + j)*l.c / l.groups*l.h*l.w); + } } } } void update_convolutional_layer(convolutional_layer l, int batch, float learning_rate, float momentum, float decay) { - int size = l.size*l.size*l.c*l.n; - axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1); + //int size = l.size*l.size*l.c*l.n; + axpy_cpu(l.n, learning_rate / batch, l.bias_updates, 1, l.biases, 1); scal_cpu(l.n, momentum, l.bias_updates, 1); - if(l.scales){ - axpy_cpu(l.n, learning_rate/batch, l.scale_updates, 1, l.scales, 1); + if (l.scales) { + axpy_cpu(l.n, learning_rate / batch, l.scale_updates, 1, l.scales, 1); scal_cpu(l.n, momentum, l.scale_updates, 1); } - axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1); - axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1); - scal_cpu(size, momentum, l.weight_updates, 1); + axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1); + axpy_cpu(l.nweights, learning_rate / batch, l.weight_updates, 1, l.weights, 1); + scal_cpu(l.nweights, momentum, l.weight_updates, 1); } + image get_convolutional_weight(convolutional_layer l, int i) { int h = l.size; int w = l.size; - int c = l.c; - return float_to_image(w,h,c,l.weights+i*h*w*c); + int c = l.c / l.groups; + return float_to_image(w, h, c, l.weights + i*h*w*c); } void rgbgr_weights(convolutional_layer l) { int i; - for(i = 0; i < l.n; ++i){ + for (i = 0; i < l.n; ++i) { image im = get_convolutional_weight(l, i); if (im.c == 3) { rgbgr_image(im); @@ -1068,7 +1086,7 @@ void rgbgr_weights(convolutional_layer l) void rescale_weights(convolutional_layer l, float scale, float trans) { int i; - for(i = 0; i < l.n; ++i){ + for (i = 0; i < l.n; ++i) { image im = get_convolutional_weight(l, i); if (im.c == 3) { scale_image(im, scale); @@ -1080,12 +1098,18 @@ void rescale_weights(convolutional_layer l, float scale, float trans) image *get_weights(convolutional_layer l) { - image* weights = (image*)calloc(l.n, sizeof(image)); + image *weights = (image *)calloc(l.n, sizeof(image)); int i; - for(i = 0; i < l.n; ++i){ + for (i = 0; i < l.n; ++i) { weights[i] = copy_image(get_convolutional_weight(l, i)); - //normalize_image(weights[i]); + normalize_image(weights[i]); + /* + char buff[256]; + sprintf(buff, "filter%d", i); + save_image(weights[i], buff); + */ } + //error("hey"); return weights; } @@ -1102,4 +1126,4 @@ image *visualize_convolutional_layer(convolutional_layer l, char *window, image //save_image(dc, buff); free_image(dc); return single_weights; -} +} \ No newline at end of file diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h index da7b8feb383..dc00dabf265 100644 --- a/src/convolutional_layer.h +++ b/src/convolutional_layer.h @@ -30,7 +30,7 @@ void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16); #endif size_t get_convolutional_workspace_size(layer l); -convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index); +convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index); void denormalize_convolutional_layer(convolutional_layer l); void resize_convolutional_layer(convolutional_layer *layer, int w, int h); void forward_convolutional_layer(const convolutional_layer layer, network_state state); diff --git a/src/crnn_layer.c b/src/crnn_layer.c index c591276adf9..8534c69a7f1 100644 --- a/src/crnn_layer.c +++ b/src/crnn_layer.c @@ -26,7 +26,7 @@ static void increment_layer(layer *l, int steps) #endif } -layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor) +layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor) { fprintf(stderr, "CRNN Layer: %d x %d x %d image, %d filters\n", h,w,c,output_filters); batch = batch / steps; @@ -40,6 +40,7 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou l.h = h; l.w = w; l.c = c; + l.groups = groups; l.out_c = output_filters; l.inputs = h * w * c; l.hidden = h * w * hidden_filters; @@ -47,18 +48,18 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou l.state = (float*)calloc(l.hidden * l.batch * (l.steps + 1), sizeof(float)); - l.input_layer = (layer*)malloc(sizeof(layer)); - *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.input_layer = (layer*)calloc(1, sizeof(layer)); + *(l.input_layer) = make_convolutional_layer(batch, steps, h, w, c, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); l.input_layer->batch = batch; if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; - l.self_layer = (layer*)malloc(sizeof(layer)); - *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.self_layer = (layer*)calloc(1, sizeof(layer)); + *(l.self_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, hidden_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); l.self_layer->batch = batch; if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; - l.output_layer = (layer*)malloc(sizeof(layer)); - *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); + l.output_layer = (layer*)calloc(1, sizeof(layer)); + *(l.output_layer) = make_convolutional_layer(batch, steps, h, w, hidden_filters, output_filters, groups, size, stride, pad, activation, batch_normalize, 0, xnor, 0, 0, 0); l.output_layer->batch = batch; if (l.workspace_size < l.output_layer->workspace_size) l.workspace_size = l.output_layer->workspace_size; @@ -85,6 +86,8 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou l.delta_gpu = l.output_layer->delta_gpu; #endif + l.bflops = l.input_layer->bflops + l.self_layer->bflops + l.output_layer->bflops; + return l; } @@ -128,6 +131,16 @@ void resize_crnn_layer(layer *l, int w, int h) #endif } +void free_state_crnn(layer l) +{ + int i; + for (i = 0; i < l.outputs * l.batch; ++i) l.self_layer->output[i] = rand_uniform(-1, 1); + +#ifdef GPU + cuda_push_array(l.self_layer->output_gpu, l.self_layer->output, l.outputs * l.batch); +#endif // GPU +} + void update_crnn_layer(layer l, int batch, float learning_rate, float momentum, float decay) { update_convolutional_layer(*(l.input_layer), batch, learning_rate, momentum, decay); diff --git a/src/crnn_layer.h b/src/crnn_layer.h index 33487020103..55feb599bb5 100644 --- a/src/crnn_layer.h +++ b/src/crnn_layer.h @@ -9,8 +9,9 @@ #ifdef __cplusplus extern "C" { #endif -layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor); +layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int output_filters, int groups, int steps, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int xnor); void resize_crnn_layer(layer *l, int w, int h); +void free_state_crnn(layer l); void forward_crnn_layer(layer l, network_state state); void backward_crnn_layer(layer l, network_state state); diff --git a/src/darknet.c b/src/darknet.c index 06092d48192..f7ac6593cd5 100644 --- a/src/darknet.c +++ b/src/darknet.c @@ -476,7 +476,7 @@ int main(int argc, char **argv) float thresh = find_float_arg(argc, argv, "-thresh", .24); int ext_output = find_arg(argc, argv, "-ext_output"); char *filename = (argc > 4) ? argv[4]: 0; - test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL); + test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh, 0.5, 0, ext_output, 0, NULL, 0); } else if (0 == strcmp(argv[1], "cifar")){ run_cifar(argc, argv); } else if (0 == strcmp(argv[1], "go")){ diff --git a/src/data.c b/src/data.c index 5de9e92e16a..53959b4c00c 100644 --- a/src/data.c +++ b/src/data.c @@ -231,6 +231,15 @@ void correct_boxes(box_label *boxes, int n, float dx, float dy, float sx, float boxes[i].h = 999999; continue; } + if ((boxes[i].x + boxes[i].w / 2) < 0 || (boxes[i].y + boxes[i].h / 2) < 0 || + (boxes[i].x - boxes[i].w / 2) > 1 || (boxes[i].y - boxes[i].h / 2) > 1) + { + boxes[i].x = 999999; + boxes[i].y = 999999; + boxes[i].w = 999999; + boxes[i].h = 999999; + continue; + } boxes[i].left = boxes[i].left * sx - dx; boxes[i].right = boxes[i].right * sx - dx; boxes[i].top = boxes[i].top * sy - dy; @@ -378,7 +387,7 @@ void fill_truth_detection(const char *path, int num_boxes, float *truth, int cla continue; } if (x == 999999 || y == 999999) { - printf("\n Wrong annotation: x = 0, y = 0 \n"); + printf("\n Wrong annotation: x = 0, y = 0, < 0 or > 1 \n"); sprintf(buff, "echo %s \"Wrong annotation: x = 0 or y = 0\" >> bad_label.list", labelpath); system(buff); ++sub; @@ -769,9 +778,10 @@ static box float_to_box_stride(float *f, int stride) #include "http_stream.h" -data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs) { + const int random_index = random_gen(); c = c ? c : 3; char **random_paths; if (track) random_paths = get_sequential_paths(paths, n, m, mini_batch, augment_speed); @@ -785,7 +795,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo d.X.cols = h*w*c; float r1 = 0, r2 = 0, r3 = 0, r4 = 0; - float dhue = 0, dsat = 0, dexp = 0, flip = 0; + float dhue = 0, dsat = 0, dexp = 0, flip = 0, blur = 0; int augmentation_calculated = 0; d.y = make_matrix(n, 5*boxes); @@ -819,6 +829,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo dexp = rand_scale(exposure); flip = use_flip ? random_gen() % 2 : 0; + blur = rand_int(0, 1) ? (use_blur) : 0; } int pleft = rand_precalc_random(-dw, dw, r1); @@ -835,15 +846,17 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo float dx = ((float)pleft/ow)/sx; float dy = ((float)ptop /oh)/sy; - image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp); - d.X.vals[i] = ai.data; + fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1. / sx, 1. / sy, w, h); - fill_truth_detection(filename, boxes, d.y.vals[i], classes, flip, dx, dy, 1./sx, 1./sy, w, h); + image ai = image_data_augmentation(src, w, h, pleft, ptop, swidth, sheight, flip, jitter, dhue, dsat, dexp, + blur, boxes, d.y.vals[i]); + + d.X.vals[i] = ai.data; if(show_imgs) { char buff[1000]; - sprintf(buff, "aug_%s_%d", basecfg(random_paths[i]), random_gen()); + sprintf(buff, "aug_%d_%d_%s_%d", random_index, i, basecfg(random_paths[i]), random_gen()); int t; for (t = 0; t < boxes; ++t) { box b = float_to_box_stride(d.y.vals[i] + t*(4 + 1), 1); @@ -869,7 +882,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo return d; } #else // OPENCV -data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs) { c = c ? c : 3; @@ -989,7 +1002,7 @@ void *load_thread(void *ptr) } else if (a.type == REGION_DATA){ *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure); } else if (a.type == DETECTION_DATA){ - *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.jitter, + *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.c, a.num_boxes, a.classes, a.flip, a.blur, a.jitter, a.hue, a.saturation, a.exposure, a.mini_batch, a.track, a.augment_speed, a.show_imgs); } else if (a.type == SWAG_DATA){ *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter); diff --git a/src/data.h b/src/data.h index ea2367eec9d..3305db50ca6 100644 --- a/src/data.h +++ b/src/data.h @@ -86,7 +86,7 @@ void print_letters(float *pred, int n); data load_data_captcha(char **paths, int n, int m, int k, int w, int h); data load_data_captcha_encode(char **paths, int n, int m, int w, int h); data load_data_old(char **paths, int n, int m, char **labels, int k, int w, int h); -data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, float jitter, +data load_data_detection(int n, char **paths, int m, int w, int h, int c, int boxes, int classes, int use_flip, int use_blur, float jitter, float hue, float saturation, float exposure, int mini_batch, int track, int augment_speed, int show_imgs); data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int size, float angle, float aspect, float hue, float saturation, float exposure); diff --git a/src/demo.c b/src/demo.c index 8219de53e3c..6c7f5d39848 100644 --- a/src/demo.c +++ b/src/demo.c @@ -37,6 +37,8 @@ static int demo_ext_output = 0; static long long int frame_id = 0; static int demo_json_port = -1; +#define NFRAMES 3 + static float* predictions[NFRAMES]; static int demo_index = 0; static image images[NFRAMES]; @@ -48,7 +50,7 @@ mat_cv* det_img; mat_cv* show_img; static volatile int flag_exit; -static const int letter_box = 0; +static int letter_box = 0; void *fetch_in_thread(void *ptr) { @@ -102,8 +104,9 @@ double get_wall_time() } void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output) + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in) { + letter_box = letter_box_in; in_img = det_img = show_img = NULL; //skip = frame_skip; image **alphabet = load_alphabet(); @@ -319,7 +322,7 @@ void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int } #else void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output) + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in) { fprintf(stderr, "Demo needs OpenCV for webcam images.\n"); } diff --git a/src/demo.h b/src/demo.h index b26b9592f1b..1f749b899e5 100644 --- a/src/demo.h +++ b/src/demo.h @@ -6,7 +6,7 @@ extern "C" { #endif void demo(char *cfgfile, char *weightfile, float thresh, float hier_thresh, int cam_index, const char *filename, char **names, int classes, - int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output); + int frame_skip, char *prefix, char *out_filename, int mjpeg_port, int json_port, int dont_show, int ext_output, int letter_box_in); #ifdef __cplusplus } #endif diff --git a/src/detector.c b/src/detector.c index b1a612cb54b..d7ed7af65ea 100644 --- a/src/detector.c +++ b/src/detector.c @@ -42,10 +42,18 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i cuda_set_device(gpus[0]); printf(" Prepare additional network for mAP calculation...\n"); net_map = parse_network_cfg_custom(cfgfile, 1, 1); + const int net_classes = net_map.layers[net_map.n - 1].classes; int k; // free memory unnecessary arrays - for (k = 0; k < net_map.n; ++k) { - free_layer(net_map.layers[k]); + for (k = 0; k < net_map.n - 1; ++k) free_layer(net_map.layers[k]); + + char *name_list = option_find_str(options, "names", "data/names.list"); + int names_size = 0; + char **names = get_labels_custom(name_list, &names_size); + if (net_classes != names_size) { + printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n", + name_list, names_size, net_classes, cfgfile); + if (net_classes > names_size) getchar(); } } @@ -119,6 +127,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i args.threads = 64; // 16 or 64 args.angle = net.angle; + args.blur = net.blur; args.exposure = net.exposure; args.saturation = net.saturation; args.hue = net.hue; @@ -137,7 +146,8 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i if (net.track) { args.track = net.track; args.augment_speed = net.augment_speed; - args.threads = net.subdivisions * ngpus; // 2 * ngpus; + if (net.sequential_subdivisions) args.threads = net.sequential_subdivisions * ngpus; + else args.threads = net.subdivisions * ngpus; args.mini_batch = net.batch / net.time_steps; printf("\n Tracking! batch = %d, subdiv = %d, time_steps = %d, mini_batch = %d \n", net.batch, net.subdivisions, net.time_steps, args.mini_batch); } @@ -180,6 +190,11 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i time = what_time_is_it_now(); pthread_join(load_thread, 0); train = buffer; + if (net.track) { + net.sequential_subdivisions = get_current_seq_subdivisions(net); + args.threads = net.sequential_subdivisions * ngpus; + printf(" sequential_subdivisions = %d, sequence = %d \n", net.sequential_subdivisions, get_sequence_value(net)); + } load_thread = load_data(args); /* @@ -223,7 +238,7 @@ void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, i calc_map_for_each = fmax(calc_map_for_each, 100); int next_map_calc = iter_map + calc_map_for_each; next_map_calc = fmax(next_map_calc, net.burn_in); - next_map_calc = fmax(next_map_calc, 1000); + next_map_calc = fmax(next_map_calc, 400); if (calc_map) { printf("\n (next mAP calculation at %d iterations) ", next_map_calc); if (mean_average_precision > 0) printf("\n Last accuracy mAP@0.5 = %2.2f %% ", mean_average_precision * 100); @@ -638,7 +653,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa char *valid_images = option_find_str(options, "valid", "data/train.txt"); char *difficult_valid_images = option_find_str(options, "difficult", NULL); char *name_list = option_find_str(options, "names", "data/names.list"); - char **names = get_labels(name_list); + int names_size = 0; + char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list); //char *mapf = option_find_str(options, "map", 0); //int *map = 0; //if (mapf) map = read_map(mapf); @@ -650,6 +666,8 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa char *train_images = option_find_str(options, "train", "data/train.txt"); valid_images = option_find_str(options, "valid", train_images); net = *existing_net; + remember_network_recurrent_state(*existing_net); + free_network_recurrent_state(*existing_net); } else { net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1 @@ -660,6 +678,11 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa fuse_conv_batchnorm(net); calculate_binary_weights(net); } + if (net.layers[net.n - 1].classes != names_size) { + printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n", + name_list, names_size, net.layers[net.n - 1].classes, cfgfile); + getchar(); + } srand(time(0)); printf("\n calculation mAP (mean average precision)...\n"); @@ -1053,6 +1076,9 @@ float validate_detector_map(char *datacfg, char *cfgfile, char *weightfile, floa if (existing_net) { //set_batch_network(&net, initial_batch); + //free_network_recurrent_state(*existing_net); + restore_network_recurrent_state(*existing_net); + //randomize_network_recurrent_state(*existing_net); } else { free_network(net); @@ -1220,7 +1246,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int if (show) { #ifdef OPENCV - //show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height); + show_acnhors(number_of_boxes, num_of_clusters, rel_width_height_array, anchors_data, width, height); #endif // OPENCV } free(rel_width_height_array); @@ -1230,7 +1256,7 @@ void calc_anchors(char *datacfg, int num_of_clusters, int width, int height, int void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, - float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile) + float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box) { list *options = read_data_cfg(datacfg); char *name_list = option_find_str(options, "names", "data/names.list"); @@ -1278,9 +1304,9 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam //image im; //image sized = load_image_resize(input, net.w, net.h, net.c, &im); image im = load_image(input, 0, 0, net.c); - image sized = resize_image(im, net.w, net.h); - int letterbox = 0; - //image sized = letterbox_image(im, net.w, net.h); letterbox = 1; + image sized; + if(letter_box) sized = letterbox_image(im, net.w, net.h); + else sized = resize_image(im, net.w, net.h); layer l = net.layers[net.n - 1]; //box *boxes = calloc(l.w*l.h*l.n, sizeof(box)); @@ -1297,7 +1323,7 @@ void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filenam //printf("%s: Predicted in %f seconds.\n", input, (what_time_is_it_now()-time)); int nboxes = 0; - detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox); + detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letter_box); if (nms) do_nms_sort(dets, nboxes, l.classes, nms); draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output); save_image(im, "predictions"); @@ -1383,6 +1409,7 @@ void run_detector(int argc, char **argv) { int dont_show = find_arg(argc, argv, "-dont_show"); int show = find_arg(argc, argv, "-show"); + int letter_box = find_arg(argc, argv, "-letter_box"); int calc_map = find_arg(argc, argv, "-map"); int map_points = find_int_arg(argc, argv, "-points", 0); check_mistakes = find_arg(argc, argv, "-check_mistakes"); @@ -1441,7 +1468,7 @@ void run_detector(int argc, char **argv) if (strlen(weights) > 0) if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0; char *filename = (argc > 6) ? argv[6] : 0; - if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile); + if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output, save_labels, outfile, letter_box); else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show, calc_map, mjpeg_port, show_imgs); else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile); else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights); @@ -1456,7 +1483,7 @@ void run_detector(int argc, char **argv) if (strlen(filename) > 0) if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0; demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename, - mjpeg_port, json_port, dont_show, ext_output); + mjpeg_port, json_port, dont_show, ext_output, letter_box); free_list_contents_kvp(options); free_list(options); diff --git a/src/image_opencv.cpp b/src/image_opencv.cpp index a227db103b7..d6f726c3c15 100644 --- a/src/image_opencv.cpp +++ b/src/image_opencv.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -86,8 +87,8 @@ extern "C" { // ==================================================================== image mat_to_image(cv::Mat mat); cv::Mat image_to_mat(image img); - image ipl_to_image(mat_cv* src); - mat_cv *image_to_ipl(image img); +// image ipl_to_image(mat_cv* src); +// mat_cv *image_to_ipl(image img); // cv::Mat ipl_to_mat(IplImage *ipl); // IplImage *mat_to_ipl(cv::Mat mat); @@ -112,7 +113,8 @@ mat_cv *load_image_mat_cv(const char *filename, int flag) //if (check_mistakes) getchar(); return NULL; } - cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR); + if (mat.channels() == 3) cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR); + else if (mat.channels() == 4) cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGRA); return (mat_cv *)mat_ptr; } @@ -151,7 +153,7 @@ image load_image_cv(char *filename, int channels) cv::Mat mat = load_image_mat(filename, channels); if (mat.empty()) { - return make_image(10, 10, 3); + return make_image(10, 10, channels); } return mat_to_image(mat); } @@ -429,7 +431,8 @@ void show_image_cv(image p, const char *name) constrain_image(copy); cv::Mat mat = image_to_mat(copy); - cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR); + if (mat.channels() == 3) cv::cvtColor(mat, mat, cv::COLOR_RGB2BGR); + else if (mat.channels() == 4) cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGR); cv::namedWindow(name, cv::WINDOW_NORMAL); cv::imshow(name, mat); free_image(copy); @@ -816,7 +819,7 @@ extern int stbi_write_jpg(char const *filename, int x, int y, int comp, const vo void save_mat_png(cv::Mat img_src, const char *name) { cv::Mat img_rgb; - cv::cvtColor(img_src, img_rgb, cv::COLOR_RGB2BGR); + if (img_src.channels() >= 3) cv::cvtColor(img_src, img_rgb, cv::COLOR_RGB2BGR); stbi_write_png(name, img_rgb.cols, img_rgb.rows, 3, (char *)img_rgb.data, 0); } // ---------------------------------------- @@ -824,7 +827,7 @@ void save_mat_png(cv::Mat img_src, const char *name) void save_mat_jpg(cv::Mat img_src, const char *name) { cv::Mat img_rgb; - cv::cvtColor(img_src, img_rgb, cv::COLOR_RGB2BGR); + if (img_src.channels() >= 3) cv::cvtColor(img_src, img_rgb, cv::COLOR_RGB2BGR); stbi_write_jpg(name, img_rgb.cols, img_rgb.rows, 3, (char *)img_rgb.data, 80); } // ---------------------------------------- @@ -850,131 +853,136 @@ void save_cv_jpg(mat_cv *img_src, const char *name) // ==================================================================== void draw_detections_cv_v3(mat_cv* mat, detection *dets, int num, float thresh, char **names, image **alphabet, int classes, int ext_output) { - cv::Mat *show_img = mat; - int i, j; - if (!show_img) return; - static int frame_id = 0; - frame_id++; - - for (i = 0; i < num; ++i) { - char labelstr[4096] = { 0 }; - int class_id = -1; - for (j = 0; j < classes; ++j) { - int show = strncmp(names[j], "dont_show", 9); - if (dets[i].prob[j] > thresh && show) { - if (class_id < 0) { - strcat(labelstr, names[j]); - class_id = j; - char buff[10]; - sprintf(buff, " (%2.0f%%)", dets[i].prob[j]*100); - strcat(labelstr, buff); - } - else { - strcat(labelstr, ", "); - strcat(labelstr, names[j]); + try { + cv::Mat *show_img = mat; + int i, j; + if (!show_img) return; + static int frame_id = 0; + frame_id++; + + for (i = 0; i < num; ++i) { + char labelstr[4096] = { 0 }; + int class_id = -1; + for (j = 0; j < classes; ++j) { + int show = strncmp(names[j], "dont_show", 9); + if (dets[i].prob[j] > thresh && show) { + if (class_id < 0) { + strcat(labelstr, names[j]); + class_id = j; + char buff[10]; + sprintf(buff, " (%2.0f%%)", dets[i].prob[j] * 100); + strcat(labelstr, buff); + } + else { + strcat(labelstr, ", "); + strcat(labelstr, names[j]); + } + printf("%s: %.0f%% ", names[j], dets[i].prob[j] * 100); } - printf("%s: %.0f%% ", names[j], dets[i].prob[j] * 100); + } + if (class_id >= 0) { + int width = std::max(1.0f, show_img->rows * .002f); + + //if(0){ + //width = pow(prob, 1./2.)*10+1; + //alphabet = 0; + //} + + //printf("%d %s: %.0f%%\n", i, names[class_id], prob*100); + int offset = class_id * 123457 % classes; + float red = get_color(2, offset, classes); + float green = get_color(1, offset, classes); + float blue = get_color(0, offset, classes); + float rgb[3]; + + //width = prob*20+2; + + rgb[0] = red; + rgb[1] = green; + rgb[2] = blue; + box b = dets[i].bbox; + if (std::isnan(b.w) || std::isinf(b.w)) b.w = 0.5; + if (std::isnan(b.h) || std::isinf(b.h)) b.h = 0.5; + if (std::isnan(b.x) || std::isinf(b.x)) b.x = 0.5; + if (std::isnan(b.y) || std::isinf(b.y)) b.y = 0.5; + b.w = (b.w < 1) ? b.w : 1; + b.h = (b.h < 1) ? b.h : 1; + b.x = (b.x < 1) ? b.x : 1; + b.y = (b.y < 1) ? b.y : 1; + //printf("%f %f %f %f\n", b.x, b.y, b.w, b.h); + + int left = (b.x - b.w / 2.)*show_img->cols; + int right = (b.x + b.w / 2.)*show_img->cols; + int top = (b.y - b.h / 2.)*show_img->rows; + int bot = (b.y + b.h / 2.)*show_img->rows; + + if (left < 0) left = 0; + if (right > show_img->cols - 1) right = show_img->cols - 1; + if (top < 0) top = 0; + if (bot > show_img->rows - 1) bot = show_img->rows - 1; + + //int b_x_center = (left + right) / 2; + //int b_y_center = (top + bot) / 2; + //int b_width = right - left; + //int b_height = bot - top; + //sprintf(labelstr, "%d x %d - w: %d, h: %d", b_x_center, b_y_center, b_width, b_height); + + float const font_size = show_img->rows / 1000.F; + cv::Size const text_size = cv::getTextSize(labelstr, cv::FONT_HERSHEY_COMPLEX_SMALL, font_size, 1, 0); + cv::Point pt1, pt2, pt_text, pt_text_bg1, pt_text_bg2; + pt1.x = left; + pt1.y = top; + pt2.x = right; + pt2.y = bot; + pt_text.x = left; + pt_text.y = top - 4;// 12; + pt_text_bg1.x = left; + pt_text_bg1.y = top - (3 + 18 * font_size); + pt_text_bg2.x = right; + if ((right - left) < text_size.width) pt_text_bg2.x = left + text_size.width; + pt_text_bg2.y = top; + cv::Scalar color; + color.val[0] = red * 256; + color.val[1] = green * 256; + color.val[2] = blue * 256; + + // you should create directory: result_img + //static int copied_frame_id = -1; + //static IplImage* copy_img = NULL; + //if (copied_frame_id != frame_id) { + // copied_frame_id = frame_id; + // if(copy_img == NULL) copy_img = cvCreateImage(cvSize(show_img->width, show_img->height), show_img->depth, show_img->nChannels); + // cvCopy(show_img, copy_img, 0); + //} + //static int img_id = 0; + //img_id++; + //char image_name[1024]; + //sprintf(image_name, "result_img/img_%d_%d_%d_%s.jpg", frame_id, img_id, class_id, names[class_id]); + //CvRect rect = cvRect(pt1.x, pt1.y, pt2.x - pt1.x, pt2.y - pt1.y); + //cvSetImageROI(copy_img, rect); + //cvSaveImage(image_name, copy_img, 0); + //cvResetImageROI(copy_img); + + cv::rectangle(*show_img, pt1, pt2, color, width, 8, 0); + if (ext_output) + printf("\t(left_x: %4.0f top_y: %4.0f width: %4.0f height: %4.0f)\n", + (float)left, (float)top, b.w*show_img->cols, b.h*show_img->rows); + else + printf("\n"); + + cv::rectangle(*show_img, pt_text_bg1, pt_text_bg2, color, width, 8, 0); + cv::rectangle(*show_img, pt_text_bg1, pt_text_bg2, color, CV_FILLED, 8, 0); // filled + cv::Scalar black_color = CV_RGB(0, 0, 0); + cv::putText(*show_img, labelstr, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, font_size, black_color, 2 * font_size, CV_AA); + // cv::FONT_HERSHEY_COMPLEX_SMALL, cv::FONT_HERSHEY_SIMPLEX } } - if (class_id >= 0) { - int width = std::max(1.0f, show_img->rows * .002f); - - //if(0){ - //width = pow(prob, 1./2.)*10+1; - //alphabet = 0; - //} - - //printf("%d %s: %.0f%%\n", i, names[class_id], prob*100); - int offset = class_id * 123457 % classes; - float red = get_color(2, offset, classes); - float green = get_color(1, offset, classes); - float blue = get_color(0, offset, classes); - float rgb[3]; - - //width = prob*20+2; - - rgb[0] = red; - rgb[1] = green; - rgb[2] = blue; - box b = dets[i].bbox; - if (std::isnan(b.w) || std::isinf(b.w)) b.w = 0.5; - if (std::isnan(b.h) || std::isinf(b.h)) b.h = 0.5; - if (std::isnan(b.x) || std::isinf(b.x)) b.x = 0.5; - if (std::isnan(b.y) || std::isinf(b.y)) b.y = 0.5; - b.w = (b.w < 1) ? b.w : 1; - b.h = (b.h < 1) ? b.h : 1; - b.x = (b.x < 1) ? b.x : 1; - b.y = (b.y < 1) ? b.y : 1; - //printf("%f %f %f %f\n", b.x, b.y, b.w, b.h); - - int left = (b.x - b.w / 2.)*show_img->cols; - int right = (b.x + b.w / 2.)*show_img->cols; - int top = (b.y - b.h / 2.)*show_img->rows; - int bot = (b.y + b.h / 2.)*show_img->rows; - - if (left < 0) left = 0; - if (right > show_img->cols - 1) right = show_img->cols - 1; - if (top < 0) top = 0; - if (bot > show_img->rows - 1) bot = show_img->rows - 1; - - //int b_x_center = (left + right) / 2; - //int b_y_center = (top + bot) / 2; - //int b_width = right - left; - //int b_height = bot - top; - //sprintf(labelstr, "%d x %d - w: %d, h: %d", b_x_center, b_y_center, b_width, b_height); - - float const font_size = show_img->rows / 1000.F; - cv::Size const text_size = cv::getTextSize(labelstr, cv::FONT_HERSHEY_COMPLEX_SMALL, font_size, 1, 0); - cv::Point pt1, pt2, pt_text, pt_text_bg1, pt_text_bg2; - pt1.x = left; - pt1.y = top; - pt2.x = right; - pt2.y = bot; - pt_text.x = left; - pt_text.y = top - 4;// 12; - pt_text_bg1.x = left; - pt_text_bg1.y = top - (1 + 18 * font_size); - pt_text_bg2.x = right; - if ((right - left) < text_size.width) pt_text_bg2.x = left + text_size.width; - pt_text_bg2.y = top; - cv::Scalar color; - color.val[0] = red * 256; - color.val[1] = green * 256; - color.val[2] = blue * 256; - - // you should create directory: result_img - //static int copied_frame_id = -1; - //static IplImage* copy_img = NULL; - //if (copied_frame_id != frame_id) { - // copied_frame_id = frame_id; - // if(copy_img == NULL) copy_img = cvCreateImage(cvSize(show_img->width, show_img->height), show_img->depth, show_img->nChannels); - // cvCopy(show_img, copy_img, 0); - //} - //static int img_id = 0; - //img_id++; - //char image_name[1024]; - //sprintf(image_name, "result_img/img_%d_%d_%d_%s.jpg", frame_id, img_id, class_id, names[class_id]); - //CvRect rect = cvRect(pt1.x, pt1.y, pt2.x - pt1.x, pt2.y - pt1.y); - //cvSetImageROI(copy_img, rect); - //cvSaveImage(image_name, copy_img, 0); - //cvResetImageROI(copy_img); - - cv::rectangle(*show_img, pt1, pt2, color, width, 8, 0); - if (ext_output) - printf("\t(left_x: %4.0f top_y: %4.0f width: %4.0f height: %4.0f)\n", - (float)left, (float)top, b.w*show_img->cols, b.h*show_img->rows); - else - printf("\n"); - - cv::rectangle(*show_img, pt_text_bg1, pt_text_bg2, color, width, 8, 0); - cv::rectangle(*show_img, pt_text_bg1, pt_text_bg2, color, CV_FILLED, 8, 0); // filled - cv::Scalar black_color = CV_RGB(0,0,0); - cv::putText(*show_img, labelstr, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, font_size, black_color, 2*font_size, CV_AA); - // cv::FONT_HERSHEY_COMPLEX_SMALL, cv::FONT_HERSHEY_SIMPLEX + if (ext_output) { + fflush(stdout); } } - if (ext_output) { - fflush(stdout); + catch (...) { + cerr << "OpenCV exception: draw_detections_cv_v3() \n"; } } // ---------------------------------------- @@ -984,53 +992,58 @@ void draw_detections_cv_v3(mat_cv* mat, detection *dets, int num, float thresh, // ==================================================================== mat_cv* draw_train_chart(float max_img_loss, int max_batches, int number_of_lines, int img_size, int dont_show) { - int img_offset = 50; + int img_offset = 60; int draw_size = img_size - img_offset; cv::Mat *img_ptr = new cv::Mat(img_size, img_size, CV_8UC3, CV_RGB(255, 255, 255)); cv::Mat &img = *img_ptr; cv::Point pt1, pt2, pt_text; - char char_buff[100]; - int i; - // vertical lines - pt1.x = img_offset; pt2.x = img_size, pt_text.x = 10; - for (i = 1; i <= number_of_lines; ++i) { - pt1.y = pt2.y = (float)i * draw_size / number_of_lines; - cv::line(img, pt1, pt2, CV_RGB(224, 224, 224), 1, 8, 0); - if (i % 10 == 0) { - sprintf(char_buff, "%2.1f", max_img_loss*(number_of_lines - i) / number_of_lines); - pt_text.y = pt1.y + 5; - - cv::putText(img, char_buff, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - cv::line(img, pt1, pt2, CV_RGB(128, 128, 128), 1, 8, 0); + try { + char char_buff[100]; + int i; + // vertical lines + pt1.x = img_offset; pt2.x = img_size, pt_text.x = 30; + for (i = 1; i <= number_of_lines; ++i) { + pt1.y = pt2.y = (float)i * draw_size / number_of_lines; + cv::line(img, pt1, pt2, CV_RGB(224, 224, 224), 1, 8, 0); + if (i % 10 == 0) { + sprintf(char_buff, "%2.1f", max_img_loss*(number_of_lines - i) / number_of_lines); + pt_text.y = pt1.y + 3; + + cv::putText(img, char_buff, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + cv::line(img, pt1, pt2, CV_RGB(128, 128, 128), 1, 8, 0); + } } - } - // horizontal lines - pt1.y = draw_size; pt2.y = 0, pt_text.y = draw_size + 15; - for (i = 0; i <= number_of_lines; ++i) { - pt1.x = pt2.x = img_offset + (float)i * draw_size / number_of_lines; - cv::line(img, pt1, pt2, CV_RGB(224, 224, 224), 1, 8, 0); - if (i % 10 == 0) { - sprintf(char_buff, "%d", max_batches * i / number_of_lines); - pt_text.x = pt1.x - 20; - cv::putText(img, char_buff, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - cv::line(img, pt1, pt2, CV_RGB(128, 128, 128), 1, 8, 0); + // horizontal lines + pt1.y = draw_size; pt2.y = 0, pt_text.y = draw_size + 15; + for (i = 0; i <= number_of_lines; ++i) { + pt1.x = pt2.x = img_offset + (float)i * draw_size / number_of_lines; + cv::line(img, pt1, pt2, CV_RGB(224, 224, 224), 1, 8, 0); + if (i % 10 == 0) { + sprintf(char_buff, "%d", max_batches * i / number_of_lines); + pt_text.x = pt1.x - 20; + cv::putText(img, char_buff, pt_text, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + cv::line(img, pt1, pt2, CV_RGB(128, 128, 128), 1, 8, 0); + } } - } - cv::putText(img, "Loss", cv::Point(0, 35), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - cv::putText(img, "Iteration number", cv::Point(draw_size / 2, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - char max_batches_buff[100]; - sprintf(max_batches_buff, "in cfg max_batches=%d", max_batches); - cv::putText(img, max_batches_buff, cv::Point(draw_size - 195, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - cv::putText(img, "Press 's' to save : chart.png", cv::Point(5, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); - if (!dont_show) { - printf(" If error occurs - run training with flag: -dont_show \n"); - cv::namedWindow("average loss", cv::WINDOW_NORMAL); - cv::moveWindow("average loss", 0, 0); - cv::resizeWindow("average loss", img_size, img_size); - cv::imshow("average loss", img); - cv::waitKey(20); + cv::putText(img, "Loss", cv::Point(10, 55), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 255), 1, CV_AA); + cv::putText(img, "Iteration number", cv::Point(draw_size / 2, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + char max_batches_buff[100]; + sprintf(max_batches_buff, "in cfg max_batches=%d", max_batches); + cv::putText(img, max_batches_buff, cv::Point(draw_size - 195, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + cv::putText(img, "Press 's' to save : chart.png", cv::Point(5, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + if (!dont_show) { + printf(" If error occurs - run training with flag: -dont_show \n"); + cv::namedWindow("average loss", cv::WINDOW_NORMAL); + cv::moveWindow("average loss", 0, 0); + cv::resizeWindow("average loss", img_size, img_size); + cv::imshow("average loss", img); + cv::waitKey(20); + } + } + catch (...) { + cerr << "OpenCV exception: draw_train_chart() \n"; } return (mat_cv*)img_ptr; } @@ -1039,60 +1052,72 @@ mat_cv* draw_train_chart(float max_img_loss, int max_batches, int number_of_line void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_img_loss, int current_batch, int max_batches, float precision, int draw_precision, char *accuracy_name, int dont_show, int mjpeg_port) { - cv::Mat &img = *(cv::Mat*)img_src; - int img_offset = 50; - int draw_size = img_size - img_offset; - char char_buff[100]; - cv::Point pt1, pt2; - pt1.x = img_offset + draw_size * (float)current_batch / max_batches; - pt1.y = draw_size * (1 - avg_loss / max_img_loss); - if (pt1.y < 0) pt1.y = 1; - cv::circle(img, pt1, 1, CV_RGB(0, 0, 255), CV_FILLED, 8, 0); - - // precision - if (draw_precision) { - static float old_precision = 0; - static int iteration_old = 0; - static int text_iteration_old = 0; - if (iteration_old == 0) - cv::putText(img, accuracy_name, cv::Point(0, 12), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 0, 0), 1, CV_AA); - - cv::line(img, - cv::Point(img_offset + draw_size * (float)iteration_old / max_batches, draw_size * (1 - old_precision)), - cv::Point(img_offset + draw_size * (float)current_batch / max_batches, draw_size * (1 - precision)), - CV_RGB(255, 0, 0), 1, 8, 0); - - if (((int)(old_precision * 10) != (int)(precision * 10)) || (current_batch - text_iteration_old) >= max_batches / 10) { - text_iteration_old = current_batch; - sprintf(char_buff, "%2.0f%% ", precision * 100); - cv::putText(img, char_buff, cv::Point(pt1.x - 30, draw_size * (1 - precision) + 15), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 5, CV_AA); - - cv::putText(img, char_buff, cv::Point(pt1.x - 30, draw_size * (1 - precision) + 15), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(200, 0, 0), 1, CV_AA); + try { + cv::Mat &img = *(cv::Mat*)img_src; + int img_offset = 60; + int draw_size = img_size - img_offset; + char char_buff[100]; + cv::Point pt1, pt2; + pt1.x = img_offset + draw_size * (float)current_batch / max_batches; + pt1.y = draw_size * (1 - avg_loss / max_img_loss); + if (pt1.y < 0) pt1.y = 1; + cv::circle(img, pt1, 1, CV_RGB(0, 0, 255), CV_FILLED, 8, 0); + + // precision + if (draw_precision) { + static float old_precision = 0; + static float max_precision = 0; + static int iteration_old = 0; + static int text_iteration_old = 0; + if (iteration_old == 0) + cv::putText(img, accuracy_name, cv::Point(10, 12), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 0, 0), 1, CV_AA); + + cv::line(img, + cv::Point(img_offset + draw_size * (float)iteration_old / max_batches, draw_size * (1 - old_precision)), + cv::Point(img_offset + draw_size * (float)current_batch / max_batches, draw_size * (1 - precision)), + CV_RGB(255, 0, 0), 1, 8, 0); + + sprintf(char_buff, "%2.1f%% ", precision * 100); + cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 5, CV_AA); + cv::putText(img, char_buff, cv::Point(10, 28), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(200, 0, 0), 1, CV_AA); + + if ((std::fabs(old_precision - precision) > 0.1) || (max_precision < precision) || (current_batch - text_iteration_old) >= max_batches / 10) { + text_iteration_old = current_batch; + max_precision = std::max(max_precision, precision); + sprintf(char_buff, "%2.0f%% ", precision * 100); + cv::putText(img, char_buff, cv::Point(pt1.x - 30, draw_size * (1 - precision) + 15), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 5, CV_AA); + cv::putText(img, char_buff, cv::Point(pt1.x - 30, draw_size * (1 - precision) + 15), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(200, 0, 0), 1, CV_AA); + } + old_precision = precision; + iteration_old = current_batch; } - old_precision = precision; - iteration_old = current_batch; - } - sprintf(char_buff, "current avg loss = %2.4f iteration = %d", avg_loss, current_batch); - pt1.x = 55, pt1.y = 10; - pt2.x = pt1.x + 460, pt2.y = pt1.y + 20; - cv::rectangle(img, pt1, pt2, CV_RGB(255, 255, 255), CV_FILLED, 8, 0); - pt1.y += 15; - cv::putText(img, char_buff, pt1, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 0), 1, CV_AA); + sprintf(char_buff, "current avg loss = %2.4f iteration = %d", avg_loss, current_batch); + pt1.x = 15, pt1.y = draw_size + 18; + pt2.x = pt1.x + 460, pt2.y = pt1.y + 20; + cv::rectangle(img, pt1, pt2, CV_RGB(255, 255, 255), CV_FILLED, 8, 0); + pt1.y += 15; + cv::putText(img, char_buff, pt1, cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(0, 0, 100), 1, CV_AA); + + int k = 0; + if (!dont_show) { + cv::imshow("average loss", img); + k = cv::waitKey(20); + } + static int old_batch = 0; + if (k == 's' || current_batch == (max_batches - 1) || (current_batch / 100 > old_batch / 100)) { + old_batch = current_batch; + save_mat_png(img, "chart.png"); + cv::putText(img, "- Saved", cv::Point(260, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 0, 0), 1, CV_AA); + } + else + cv::putText(img, "- Saved", cv::Point(260, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 1, CV_AA); - int k = 0; - if (!dont_show) { - cv::imshow("average loss", img); - k = cv::waitKey(20); + if (mjpeg_port > 0) send_mjpeg((mat_cv *)&img, mjpeg_port, 500000, 100); } - if (k == 's' || current_batch == (max_batches - 1) || current_batch % 100 == 0) { - save_mat_png(img, "chart.png"); - cv::putText(img, "- Saved", cv::Point(260, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 0, 0), 1, CV_AA); + catch (...) { + cerr << "OpenCV exception: draw_train_loss() \n"; } - else - cv::putText(img, "- Saved", cv::Point(260, img_size - 10), cv::FONT_HERSHEY_COMPLEX_SMALL, 0.7, CV_RGB(255, 255, 255), 1, CV_AA); - - if (mjpeg_port > 0) send_mjpeg((mat_cv *)&img, mjpeg_port, 500000, 100); } // ---------------------------------------- @@ -1100,9 +1125,20 @@ void draw_train_loss(mat_cv* img_src, int img_size, float avg_loss, float max_im // ==================================================================== // Data augmentation // ==================================================================== +static box float_to_box_stride(float *f, int stride) +{ + box b = { 0 }; + b.x = f[0]; + b.y = f[1 * stride]; + b.w = f[2 * stride]; + b.h = f[3 * stride]; + return b; +} + image image_data_augmentation(mat_cv* mat, int w, int h, int pleft, int ptop, int swidth, int sheight, int flip, - float jitter, float dhue, float dsat, float dexp) + float jitter, float dhue, float dsat, float dexp, + int blur, int num_boxes, float *truth) { image out; try { @@ -1114,43 +1150,52 @@ image image_data_augmentation(mat_cv* mat, int w, int h, cv::Rect new_src_rect = src_rect & img_rect; cv::Rect dst_rect(cv::Point2i(std::max(0, -pleft), std::max(0, -ptop)), new_src_rect.size()); + cv::Mat sized; - cv::Mat cropped(cv::Size(src_rect.width, src_rect.height), img.type()); - cropped.setTo(cv::Scalar::all(0)); + if (src_rect.x == 0 && src_rect.y == 0 && src_rect.size() == img.size()) { + cv::resize(img, sized, cv::Size(w, h), 0, 0, cv::INTER_LINEAR); + } + else { + cv::Mat cropped(src_rect.size(), img.type()); + //cropped.setTo(cv::Scalar::all(0)); + cropped.setTo(cv::mean(img)); - img(new_src_rect).copyTo(cropped(dst_rect)); + img(new_src_rect).copyTo(cropped(dst_rect)); - // resize - cv::Mat sized; - cv::resize(cropped, sized, cv::Size(w, h), 0, 0, cv::INTER_LINEAR); + // resize + cv::resize(cropped, sized, cv::Size(w, h), 0, 0, cv::INTER_LINEAR); + } // flip if (flip) { + cv::Mat cropped; cv::flip(sized, cropped, 1); // 0 - x-axis, 1 - y-axis, -1 - both axes (x & y) sized = cropped.clone(); } // HSV augmentation // cv::COLOR_BGR2HSV, cv::COLOR_RGB2HSV, cv::COLOR_HSV2BGR, cv::COLOR_HSV2RGB - if(img.channels() >= 3) - { - cv::Mat hsv_src; - cvtColor(sized, hsv_src, cv::COLOR_RGB2HSV); // RGB to HSV + if (dsat != 1 || dexp != 1 || dhue != 0) { + if (img.channels() >= 3) + { + cv::Mat hsv_src; + cvtColor(sized, hsv_src, cv::COLOR_RGB2HSV); // RGB to HSV - std::vector hsv; - cv::split(hsv_src, hsv); + std::vector hsv; + cv::split(hsv_src, hsv); - hsv[1] *= dsat; - hsv[2] *= dexp; - hsv[0] += 179 * dhue; + hsv[1] *= dsat; + hsv[2] *= dexp; + hsv[0] += 179 * dhue; - cv::merge(hsv, hsv_src); + cv::merge(hsv, hsv_src); - cvtColor(hsv_src, sized, cv::COLOR_HSV2RGB); // HSV to RGB (the same as previous) - } - else - { - sized *= dexp; + cvtColor(hsv_src, sized, cv::COLOR_HSV2RGB); // HSV to RGB (the same as previous) + } + else + { + sized *= dexp; + } } //std::stringstream window_name; @@ -1158,6 +1203,31 @@ image image_data_augmentation(mat_cv* mat, int w, int h, //cv::imshow(window_name.str(), sized); //cv::waitKey(0); + if (blur) { + cv::Mat dst(sized.size(), sized.type()); + if(blur == 1) cv::GaussianBlur(sized, dst, cv::Size(31, 31), 0); + else cv::GaussianBlur(sized, dst, cv::Size((blur / 2) * 2 + 1, (blur / 2) * 2 + 1), 0); + cv::Rect img_rect(0, 0, sized.cols, sized.rows); + //std::cout << " blur num_boxes = " << num_boxes << std::endl; + + if (blur == 1) { + int t; + for (t = 0; t < num_boxes; ++t) { + box b = float_to_box_stride(truth + t*(4 + 1), 1); + if (!b.x) break; + int left = (b.x - b.w / 2.)*sized.cols; + int width = b.w*sized.cols; + int top = (b.y - b.h / 2.)*sized.rows; + int height = b.h*sized.rows; + cv::Rect roi(left, top, width, height); + roi = roi & img_rect; + + sized(roi).copyTo(dst(roi)); + } + } + dst.copyTo(sized); + } + // Mat -> image out = mat_to_image(sized); } diff --git a/src/image_opencv.h b/src/image_opencv.h index 13e67210cd8..21ee6939fae 100644 --- a/src/image_opencv.h +++ b/src/image_opencv.h @@ -95,7 +95,8 @@ void draw_train_loss(mat_cv* img, int img_size, float avg_loss, float max_img_lo // Data augmentation image image_data_augmentation(mat_cv* mat, int w, int h, int pleft, int ptop, int swidth, int sheight, int flip, - float jitter, float dhue, float dsat, float dexp); + float jitter, float dhue, float dsat, float dexp, + int blur, int num_boxes, float *truth); // Show Anchors void show_acnhors(int number_of_boxes, int num_of_clusters, float *rel_width_height_array, model anchors_data, int width, int height); diff --git a/src/layer.c b/src/layer.c index a60978486a4..6409bd892e7 100644 --- a/src/layer.c +++ b/src/layer.c @@ -2,22 +2,40 @@ #include "dark_cuda.h" #include +void free_sublayer(layer *l) +{ + if (l) { + free_layer(*l); + free(l); + } +} + void free_layer(layer l) { - // free layers: input_layer, self_layer, output_layer, ... - if (l.type == CRNN) { - if (l.input_layer) { - free_layer(*l.input_layer); - free(l.input_layer); - } - if (l.self_layer) { - free_layer(*l.self_layer); - free(l.self_layer); + if (l.type == CONV_LSTM) { + if (l.peephole) { + free_sublayer(l.vf); + free_sublayer(l.vi); + free_sublayer(l.vo); } - if (l.output_layer) { - free_layer(*l.output_layer); - free(l.output_layer); + else { + free(l.vf); + free(l.vi); + free(l.vo); } + free_sublayer(l.wf); + free_sublayer(l.wi); + free_sublayer(l.wg); + free_sublayer(l.wo); + free_sublayer(l.uf); + free_sublayer(l.ui); + free_sublayer(l.ug); + free_sublayer(l.uo); + } + if (l.type == CRNN) { + free_sublayer(l.input_layer); + free_sublayer(l.self_layer); + free_sublayer(l.output_layer); l.output = NULL; l.delta = NULL; #ifdef GPU @@ -83,21 +101,36 @@ void free_layer(layer l) if (l.v) free(l.v); if (l.z_cpu) free(l.z_cpu); if (l.r_cpu) free(l.r_cpu); - if (l.h_cpu) free(l.h_cpu); if (l.binary_input) free(l.binary_input); if (l.bin_re_packed_input) free(l.bin_re_packed_input); if (l.t_bit_input) free(l.t_bit_input); if (l.loss) free(l.loss); + // CONV-LSTM + if (l.f_cpu) free(l.f_cpu); + if (l.i_cpu) free(l.i_cpu); + if (l.g_cpu) free(l.g_cpu); + if (l.o_cpu) free(l.o_cpu); + if (l.c_cpu) free(l.c_cpu); + if (l.h_cpu) free(l.h_cpu); + if (l.temp_cpu) free(l.temp_cpu); + if (l.temp2_cpu) free(l.temp2_cpu); + if (l.temp3_cpu) free(l.temp3_cpu); + if (l.dc_cpu) free(l.dc_cpu); + if (l.dh_cpu) free(l.dh_cpu); + if (l.prev_state_cpu) free(l.prev_state_cpu); + if (l.prev_cell_cpu) free(l.prev_cell_cpu); + if (l.stored_c_cpu) free(l.stored_c_cpu); + if (l.stored_h_cpu) free(l.stored_h_cpu); + if (l.cell_cpu) free(l.cell_cpu); + #ifdef GPU if (l.indexes_gpu) cuda_free((float *)l.indexes_gpu); if (l.z_gpu) cuda_free(l.z_gpu); if (l.r_gpu) cuda_free(l.r_gpu); - if (l.h_gpu) cuda_free(l.h_gpu); if (l.m_gpu) cuda_free(l.m_gpu); if (l.v_gpu) cuda_free(l.v_gpu); - if (l.prev_state_gpu) cuda_free(l.prev_state_gpu); if (l.forgot_state_gpu) cuda_free(l.forgot_state_gpu); if (l.forgot_delta_gpu) cuda_free(l.forgot_delta_gpu); if (l.state_gpu) cuda_free(l.state_gpu); @@ -137,5 +170,25 @@ void free_layer(layer l) if (l.rand_gpu) cuda_free(l.rand_gpu); if (l.squared_gpu) cuda_free(l.squared_gpu); if (l.norms_gpu) cuda_free(l.norms_gpu); + + // CONV-LSTM + if (l.f_gpu) cuda_free(l.f_gpu); + if (l.i_gpu) cuda_free(l.i_gpu); + if (l.g_gpu) cuda_free(l.g_gpu); + if (l.o_gpu) cuda_free(l.o_gpu); + if (l.c_gpu) cuda_free(l.c_gpu); + if (l.h_gpu) cuda_free(l.h_gpu); + if (l.temp_gpu) cuda_free(l.temp_gpu); + if (l.temp2_gpu) cuda_free(l.temp2_gpu); + if (l.temp3_gpu) cuda_free(l.temp3_gpu); + if (l.dc_gpu) cuda_free(l.dc_gpu); + if (l.dh_gpu) cuda_free(l.dh_gpu); + if (l.prev_state_gpu) cuda_free(l.prev_state_gpu); + if (l.prev_cell_gpu) cuda_free(l.prev_cell_gpu); + if (l.stored_c_gpu) cuda_free(l.stored_c_gpu); + if (l.stored_h_gpu) cuda_free(l.stored_h_gpu); + if (l.last_prev_state_gpu) cuda_free(l.last_prev_state_gpu); + if (l.last_prev_cell_gpu) cuda_free(l.last_prev_cell_gpu); + if (l.cell_gpu) cuda_free(l.cell_gpu); #endif } diff --git a/src/lstm_layer.c b/src/lstm_layer.c index bf1e303b4c5..94664ce3aa1 100644 --- a/src/lstm_layer.c +++ b/src/lstm_layer.c @@ -39,49 +39,49 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n l.out_h = 1; l.out_c = outputs; - l.uf = (layer*)malloc(sizeof(layer)); + l.uf = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.uf) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.uf->batch = batch; if (l.workspace_size < l.uf->workspace_size) l.workspace_size = l.uf->workspace_size; - l.ui = (layer*)malloc(sizeof(layer)); + l.ui = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.ui) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.ui->batch = batch; if (l.workspace_size < l.ui->workspace_size) l.workspace_size = l.ui->workspace_size; - l.ug = (layer*)malloc(sizeof(layer)); + l.ug = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.ug) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.ug->batch = batch; if (l.workspace_size < l.ug->workspace_size) l.workspace_size = l.ug->workspace_size; - l.uo = (layer*)malloc(sizeof(layer)); + l.uo = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.uo) = make_connected_layer(batch, steps, inputs, outputs, LINEAR, batch_normalize); l.uo->batch = batch; if (l.workspace_size < l.uo->workspace_size) l.workspace_size = l.uo->workspace_size; - l.wf = (layer*)malloc(sizeof(layer)); + l.wf = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wf) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wf->batch = batch; if (l.workspace_size < l.wf->workspace_size) l.workspace_size = l.wf->workspace_size; - l.wi = (layer*)malloc(sizeof(layer)); + l.wi = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wi) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wi->batch = batch; if (l.workspace_size < l.wi->workspace_size) l.workspace_size = l.wi->workspace_size; - l.wg = (layer*)malloc(sizeof(layer)); + l.wg = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wg) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wg->batch = batch; if (l.workspace_size < l.wg->workspace_size) l.workspace_size = l.wg->workspace_size; - l.wo = (layer*)malloc(sizeof(layer)); + l.wo = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.wo) = make_connected_layer(batch, steps, outputs, outputs, LINEAR, batch_normalize); l.wo->batch = batch; @@ -95,6 +95,7 @@ layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_n l.forward = forward_lstm_layer; l.update = update_lstm_layer; + l.backward = backward_lstm_layer; l.prev_state_cpu = (float*)calloc(batch*outputs, sizeof(float)); l.prev_cell_cpu = (float*)calloc(batch*outputs, sizeof(float)); diff --git a/src/lstm_layer.h b/src/lstm_layer.h index f60a0cabafe..dc8eb37a38d 100644 --- a/src/lstm_layer.h +++ b/src/lstm_layer.h @@ -12,6 +12,7 @@ extern "C" { layer make_lstm_layer(int batch, int inputs, int outputs, int steps, int batch_normalize); void forward_lstm_layer(layer l, network_state state); +void backward_lstm_layer(layer l, network_state state); void update_lstm_layer(layer l, int batch, float learning_rate, float momentum, float decay); #ifdef GPU diff --git a/src/network.c b/src/network.c index 248b93ff0c9..86c248c4a1f 100644 --- a/src/network.c +++ b/src/network.c @@ -15,6 +15,7 @@ #include "gru_layer.h" #include "rnn_layer.h" #include "crnn_layer.h" +#include "conv_lstm_layer.h" #include "local_layer.h" #include "convolutional_layer.h" #include "activation_layer.h" @@ -90,6 +91,32 @@ void reset_rnn(network *net) reset_network_state(net, 0); } +float get_current_seq_subdivisions(network net) +{ + int sequence_subdivisions = net.init_sequential_subdivisions; + + if (net.num_steps > 0) + { + int batch_num = get_current_batch(net); + int i; + for (i = 0; i < net.num_steps; ++i) { + if (net.steps[i] > batch_num) break; + sequence_subdivisions *= net.seq_scales[i]; + } + } + if (sequence_subdivisions < 1) sequence_subdivisions = 1; + if (sequence_subdivisions > net.subdivisions) sequence_subdivisions = net.subdivisions; + return sequence_subdivisions; +} + +int get_sequence_value(network net) +{ + int sequence = 1; + if (net.sequential_subdivisions != 0) sequence = net.subdivisions / net.sequential_subdivisions; + if (sequence < 1) sequence = 1; + return sequence; +} + float get_current_rate(network net) { int batch_num = get_current_batch(net); @@ -120,11 +147,20 @@ float get_current_rate(network net) case SIG: return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step)))); case SGDR: + { + int last_iteration_start = 0; + int cycle_size = net.batches_per_cycle; + while ((last_iteration_start + cycle_size) < batch_num) + { + last_iteration_start += cycle_size; + cycle_size *= net.batches_cycle_mult; + } rate = net.learning_rate_min + - 0.5*(net.learning_rate-net.learning_rate_min) - * (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle)); + 0.5*(net.learning_rate - net.learning_rate_min) + * (1. + cos((float)(batch_num - last_iteration_start)*3.14159265 / cycle_size)); return rate; + } default: fprintf(stderr, "Policy is weird!\n"); return net.learning_rate; @@ -315,6 +351,7 @@ float train_network_sgd(network net, data d, int n) float sum = 0; for(i = 0; i < n; ++i){ get_random_batch(d, batch, X, y); + net.current_subdivision = i; float err = train_network_datum(net, X, y); sum += err; } @@ -340,6 +377,7 @@ float train_network_waitkey(network net, data d, int wait_key) float sum = 0; for(i = 0; i < n; ++i){ get_next_batch(d, batch, i*batch, X, y); + net.current_subdivision = i; float err = train_network_datum(net, X, y); sum += err; if(wait_key) wait_key_cv(5); @@ -925,6 +963,7 @@ void free_network(network net) } free(net.layers); + free(net.seq_scales); free(net.scales); free(net.steps); free(net.seen); @@ -967,14 +1006,16 @@ void fuse_conv_batchnorm(network net) int f; for (f = 0; f < l->n; ++f) { - l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f]) + .000001f); + //l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f]) + .000001f); + l->biases[f] = l->biases[f] - (double)l->scales[f] * l->rolling_mean[f] / (sqrt((double)l->rolling_variance[f] + .000001)); - const size_t filter_size = l->size*l->size*l->c; + const size_t filter_size = l->size*l->size*l->c / l->groups; int i; for (i = 0; i < filter_size; ++i) { int w_index = f*filter_size + i; - l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f]) + .000001f); + //l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f]) + .000001f); + l->weights[w_index] = (double)l->weights[w_index] * l->scales[f] / (sqrt((double)l->rolling_variance[f] + .000001)); } } @@ -1111,3 +1152,40 @@ network combine_train_valid_networks(network net_train, network net_map) } return net_combined; } + +void free_network_recurrent_state(network net) +{ + int k; + for (k = 0; k < net.n; ++k) { + if (net.layers[k].type == CONV_LSTM) free_state_conv_lstm(net.layers[k]); + if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]); + } +} + +void randomize_network_recurrent_state(network net) +{ + int k; + for (k = 0; k < net.n; ++k) { + if (net.layers[k].type == CONV_LSTM) randomize_state_conv_lstm(net.layers[k]); + if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]); + } +} + + +void remember_network_recurrent_state(network net) +{ + int k; + for (k = 0; k < net.n; ++k) { + if (net.layers[k].type == CONV_LSTM) remember_state_conv_lstm(net.layers[k]); + //if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]); + } +} + +void restore_network_recurrent_state(network net) +{ + int k; + for (k = 0; k < net.n; ++k) { + if (net.layers[k].type == CONV_LSTM) restore_state_conv_lstm(net.layers[k]); + if (net.layers[k].type == CRNN) free_state_crnn(net.layers[k]); + } +} \ No newline at end of file diff --git a/src/network.h b/src/network.h index c4138c74c45..186fe3d8bbb 100644 --- a/src/network.h +++ b/src/network.h @@ -104,6 +104,8 @@ void backward_network_gpu(network net, network_state state); void update_network_gpu(network net); #endif +float get_current_seq_subdivisions(network net); +int get_sequence_value(network net); float get_current_rate(network net); int get_current_batch(network net); void free_network(network net); @@ -163,6 +165,10 @@ int get_network_background(network net); //LIB_API void calculate_binary_weights(network net); network combine_train_valid_networks(network net_train, network net_map); void copy_weights_net(network net_train, network *net_map); +void free_network_recurrent_state(network net); +void randomize_network_recurrent_state(network net); +void remember_network_recurrent_state(network net); +void restore_network_recurrent_state(network net); #ifdef __cplusplus } diff --git a/src/network_kernels.cu b/src/network_kernels.cu index 9162c8b2b36..40f71eb0ba5 100644 --- a/src/network_kernels.cu +++ b/src/network_kernels.cu @@ -126,7 +126,7 @@ void update_network_gpu(network net) { cuda_set_device(net.gpu_index); int i; - int update_batch = net.batch*net.subdivisions; + int update_batch = net.batch*net.subdivisions * get_sequence_value(net); float rate = get_current_rate(net); for(i = 0; i < net.n; ++i){ layer l = net.layers[i]; @@ -171,6 +171,22 @@ void forward_backward_network_gpu(network net, float *x, float *y) cuda_convert_f32_to_f16(l.self_layer->weights_gpu, l.self_layer->nweights, l.self_layer->weights_gpu16); cuda_convert_f32_to_f16(l.output_layer->weights_gpu, l.output_layer->nweights, l.output_layer->weights_gpu16); } + else if (l.type == CONV_LSTM && l.wf->weights_gpu && l.wf->weights_gpu16) { + assert((l.wf->c * l.wf->n * l.wf->size * l.wf->size) > 0); + if (l.peephole) { + cuda_convert_f32_to_f16(l.vf->weights_gpu, l.vf->nweights, l.vf->weights_gpu16); + cuda_convert_f32_to_f16(l.vi->weights_gpu, l.vi->nweights, l.vi->weights_gpu16); + cuda_convert_f32_to_f16(l.vo->weights_gpu, l.vo->nweights, l.vo->weights_gpu16); + } + cuda_convert_f32_to_f16(l.wf->weights_gpu, l.wf->nweights, l.wf->weights_gpu16); + cuda_convert_f32_to_f16(l.wi->weights_gpu, l.wi->nweights, l.wi->weights_gpu16); + cuda_convert_f32_to_f16(l.wg->weights_gpu, l.wg->nweights, l.wg->weights_gpu16); + cuda_convert_f32_to_f16(l.wo->weights_gpu, l.wo->nweights, l.wo->weights_gpu16); + cuda_convert_f32_to_f16(l.uf->weights_gpu, l.uf->nweights, l.uf->weights_gpu16); + cuda_convert_f32_to_f16(l.ui->weights_gpu, l.ui->nweights, l.ui->weights_gpu16); + cuda_convert_f32_to_f16(l.ug->weights_gpu, l.ug->nweights, l.ug->weights_gpu16); + cuda_convert_f32_to_f16(l.uo->weights_gpu, l.uo->nweights, l.uo->weights_gpu16); + } } } #endif @@ -184,7 +200,9 @@ float train_network_datum_gpu(network net, float *x, float *y) *net.seen += net.batch; forward_backward_network_gpu(net, x, y); float error = get_network_cost(net); - if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net); + //if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net); + const int sequence = get_sequence_value(net); + if (((*net.seen) / net.batch) % (net.subdivisions*sequence) == 0) update_network_gpu(net); return error; } @@ -219,7 +237,7 @@ void pull_updates(layer l) { if(l.type == CONVOLUTIONAL){ cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n); - cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c); + cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.nweights); if(l.scale_updates) cuda_pull_array(l.scale_updates_gpu, l.scale_updates, l.n); } else if(l.type == CONNECTED){ cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs); @@ -231,7 +249,7 @@ void push_updates(layer l) { if(l.type == CONVOLUTIONAL){ cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n); - cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c); + cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights); if(l.scale_updates) cuda_push_array(l.scale_updates_gpu, l.scale_updates, l.n); } else if(l.type == CONNECTED){ cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs); @@ -253,7 +271,7 @@ void merge_weights(layer l, layer base) { if (l.type == CONVOLUTIONAL) { axpy_cpu(l.n, 1, l.biases, 1, base.biases, 1); - axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weights, 1, base.weights, 1); + axpy_cpu(l.nweights, 1, l.weights, 1, base.weights, 1); if (l.scales) { axpy_cpu(l.n, 1, l.scales, 1, base.scales, 1); } @@ -267,7 +285,7 @@ void scale_weights(layer l, float s) { if (l.type == CONVOLUTIONAL) { scal_cpu(l.n, s, l.biases, 1); - scal_cpu(l.n*l.size*l.size*l.c, s, l.weights, 1); + scal_cpu(l.nweights, s, l.weights, 1); if (l.scales) { scal_cpu(l.n, s, l.scales, 1); } @@ -282,7 +300,7 @@ void pull_weights(layer l) { if(l.type == CONVOLUTIONAL){ cuda_pull_array(l.biases_gpu, l.biases, l.n); - cuda_pull_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c); + cuda_pull_array(l.weights_gpu, l.weights, l.nweights); if(l.scales) cuda_pull_array(l.scales_gpu, l.scales, l.n); } else if(l.type == CONNECTED){ cuda_pull_array(l.biases_gpu, l.biases, l.outputs); @@ -294,7 +312,7 @@ void push_weights(layer l) { if(l.type == CONVOLUTIONAL){ cuda_push_array(l.biases_gpu, l.biases, l.n); - cuda_push_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c); + cuda_push_array(l.weights_gpu, l.weights, l.nweights); if(l.scales) cuda_push_array(l.scales_gpu, l.scales, l.n); } else if(l.type == CONNECTED){ cuda_push_array(l.biases_gpu, l.biases, l.outputs); @@ -306,7 +324,7 @@ void distribute_weights(layer l, layer base) { if(l.type == CONVOLUTIONAL){ cuda_push_array(l.biases_gpu, base.biases, l.n); - cuda_push_array(l.weights_gpu, base.weights, l.n*l.size*l.size*l.c); + cuda_push_array(l.weights_gpu, base.weights, l.nweights); if(base.scales) cuda_push_array(l.scales_gpu, base.scales, l.n); } else if(l.type == CONNECTED){ cuda_push_array(l.biases_gpu, base.biases, l.outputs); @@ -319,7 +337,7 @@ void merge_updates(layer l, layer base) { if (l.type == CONVOLUTIONAL) { axpy_cpu(l.n, 1, l.bias_updates, 1, base.bias_updates, 1); - axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weight_updates, 1, base.weight_updates, 1); + axpy_cpu(l.nweights, 1, l.weight_updates, 1, base.weight_updates, 1); if (l.scale_updates) { axpy_cpu(l.n, 1, l.scale_updates, 1, base.scale_updates, 1); } @@ -333,7 +351,7 @@ void distribute_updates(layer l, layer base) { if(l.type == CONVOLUTIONAL){ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.n); - cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.n*l.size*l.size*l.c); + cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.nweights); if(base.scale_updates) cuda_push_array(l.scale_updates_gpu, base.scale_updates, l.n); } else if(l.type == CONNECTED){ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.outputs); diff --git a/src/parser.c b/src/parser.c index 782b5820c0f..56a0bcb78f9 100644 --- a/src/parser.c +++ b/src/parser.c @@ -20,6 +20,7 @@ #include "list.h" #include "local_layer.h" #include "lstm_layer.h" +#include "conv_lstm_layer.h" #include "maxpool_layer.h" #include "normalization_layer.h" #include "option_list.h" @@ -61,13 +62,14 @@ LAYER_TYPE string_to_layer_type(char * type) if (strcmp(type, "[crnn]")==0) return CRNN; if (strcmp(type, "[gru]")==0) return GRU; if (strcmp(type, "[lstm]")==0) return LSTM; + if (strcmp(type, "[conv_lstm]") == 0) return CONV_LSTM; if (strcmp(type, "[rnn]")==0) return RNN; if (strcmp(type, "[conn]")==0 || strcmp(type, "[connected]")==0) return CONNECTED; if (strcmp(type, "[max]")==0 || strcmp(type, "[maxpool]")==0) return MAXPOOL; - if (strcmp(type, "[reorg]")==0) return REORG; - if (strcmp(type, "[reorg_old]") == 0) return REORG_OLD; + if (strcmp(type, "[reorg3d]")==0) return REORG; + if (strcmp(type, "[reorg]") == 0) return REORG_OLD; if (strcmp(type, "[avg]")==0 || strcmp(type, "[avgpool]")==0) return AVGPOOL; if (strcmp(type, "[dropout]")==0) return DROPOUT; @@ -148,6 +150,7 @@ local_layer parse_local(list *options, size_params params) convolutional_layer parse_convolutional(list *options, size_params params) { int n = option_find_int(options, "filters",1); + int groups = option_find_int_quiet(options, "groups", 1); int size = option_find_int(options, "size",1); int stride = option_find_int(options, "stride",1); int pad = option_find_int_quiet(options, "pad",0); @@ -168,7 +171,7 @@ convolutional_layer parse_convolutional(list *options, size_params params) int xnor = option_find_int_quiet(options, "xnor", 0); int use_bin_output = option_find_int_quiet(options, "bin_output", 0); - convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index); + convolutional_layer layer = make_convolutional_layer(batch,1,h,w,c,n,groups,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam, use_bin_output, params.index); layer.flipped = option_find_int_quiet(options, "flipped", 0); layer.dot = option_find_float_quiet(options, "dot", 0); @@ -191,12 +194,13 @@ layer parse_crnn(list *options, size_params params) int output_filters = option_find_int(options, "output",1); int hidden_filters = option_find_int(options, "hidden",1); + int groups = option_find_int_quiet(options, "groups", 1); char *activation_s = option_find_str(options, "activation", "logistic"); ACTIVATION activation = get_activation(activation_s); int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0); int xnor = option_find_int_quiet(options, "xnor", 0); - layer l = make_crnn_layer(params.batch, params.w, params.h, params.c, hidden_filters, output_filters, params.time_steps, size, stride, padding, activation, batch_normalize, xnor); + layer l = make_crnn_layer(params.batch, params.h, params.w, params.c, hidden_filters, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, xnor); l.shortcut = option_find_int_quiet(options, "shortcut", 0); @@ -239,6 +243,31 @@ layer parse_lstm(list *options, size_params params) return l; } +layer parse_conv_lstm(list *options, size_params params) +{ + // a ConvLSTM with a larger transitional kernel should be able to capture faster motions + int size = option_find_int_quiet(options, "size", 3); + int stride = option_find_int_quiet(options, "stride", 1); + int pad = option_find_int_quiet(options, "pad", 0); + int padding = option_find_int_quiet(options, "padding", 0); + if (pad) padding = size / 2; + + int output_filters = option_find_int(options, "output", 1); + int groups = option_find_int_quiet(options, "groups", 1); + char *activation_s = option_find_str(options, "activation", "LINEAR"); + ACTIVATION activation = get_activation(activation_s); + int batch_normalize = option_find_int_quiet(options, "batch_normalize", 0); + int xnor = option_find_int_quiet(options, "xnor", 0); + int peephole = option_find_int_quiet(options, "peephole", 0); + + layer l = make_conv_lstm_layer(params.batch, params.h, params.w, params.c, output_filters, groups, params.time_steps, size, stride, padding, activation, batch_normalize, peephole, xnor); + + l.state_constrain = option_find_int_quiet(options, "state_constrain", params.time_steps * 32); + l.shortcut = option_find_int_quiet(options, "shortcut", 0); + + return l; +} + connected_layer parse_connected(list *options, size_params params) { int output = option_find_int(options, "output",1); @@ -640,13 +669,16 @@ void parse_net_options(list *options, network *net) net->batch = option_find_int(options, "batch",1); net->learning_rate = option_find_float(options, "learning_rate", .001); net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001); - net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 500); + net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 1000); + net->batches_cycle_mult = option_find_int_quiet(options, "sgdr_mult", 2); net->momentum = option_find_float(options, "momentum", .9); net->decay = option_find_float(options, "decay", .0001); int subdivs = option_find_int(options, "subdivisions",1); net->time_steps = option_find_int_quiet(options, "time_steps",1); net->track = option_find_int_quiet(options, "track", 0); net->augment_speed = option_find_int_quiet(options, "augment_speed", 2); + net->init_sequential_subdivisions = net->sequential_subdivisions = option_find_int_quiet(options, "sequential_subdivisions", subdivs); + if (net->sequential_subdivisions > subdivs) net->init_sequential_subdivisions = net->sequential_subdivisions = subdivs; net->try_fix_nan = option_find_int_quiet(options, "try_fix_nan", 0); net->batch /= subdivs; net->batch *= net->time_steps; @@ -666,6 +698,7 @@ void parse_net_options(list *options, network *net) net->max_crop = option_find_int_quiet(options, "max_crop",net->w*2); net->min_crop = option_find_int_quiet(options, "min_crop",net->w); net->flip = option_find_int_quiet(options, "flip", 1); + net->blur = option_find_int_quiet(options, "blur", 0); net->angle = option_find_float_quiet(options, "angle", 0); net->aspect = option_find_float_quiet(options, "aspect", 1); @@ -691,30 +724,44 @@ void parse_net_options(list *options, network *net) if(net->policy == STEP){ net->step = option_find_int(options, "step", 1); net->scale = option_find_float(options, "scale", 1); - } else if (net->policy == STEPS){ + } else if (net->policy == STEPS || net->policy == SGDR){ char *l = option_find(options, "steps"); char *p = option_find(options, "scales"); - if(!l || !p) error("STEPS policy must have steps and scales in cfg file"); + char *s = option_find(options, "seq_scales"); + if(net->policy == STEPS && (!l || !p)) error("STEPS policy must have steps and scales in cfg file"); - int len = strlen(l); - int n = 1; - int i; - for(i = 0; i < len; ++i){ - if (l[i] == ',') ++n; - } - int* steps = (int*)calloc(n, sizeof(int)); - float* scales = (float*)calloc(n, sizeof(float)); - for(i = 0; i < n; ++i){ - int step = atoi(l); - float scale = atof(p); - l = strchr(l, ',')+1; - p = strchr(p, ',')+1; - steps[i] = step; - scales[i] = scale; + if (l) { + int len = strlen(l); + int n = 1; + int i; + for (i = 0; i < len; ++i) { + if (l[i] == ',') ++n; + } + int* steps = (int*)calloc(n, sizeof(int)); + float* scales = (float*)calloc(n, sizeof(float)); + float* seq_scales = (float*)calloc(n, sizeof(float)); + for (i = 0; i < n; ++i) { + float scale = 1.0; + if (p) { + scale = atof(p); + p = strchr(p, ',') + 1; + } + float sequence_scale = 1.0; + if (s) { + sequence_scale = atof(s); + s = strchr(s, ',') + 1; + } + int step = atoi(l); + l = strchr(l, ',') + 1; + steps[i] = step; + scales[i] = scale; + seq_scales[i] = sequence_scale; + } + net->scales = scales; + net->steps = steps; + net->seq_scales = seq_scales; + net->num_steps = n; } - net->scales = scales; - net->steps = steps; - net->num_steps = n; } else if (net->policy == EXP){ net->gamma = option_find_float(options, "gamma", 1); } else if (net->policy == SIG){ @@ -789,6 +836,8 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps) l = parse_gru(options, params); }else if(lt == LSTM){ l = parse_lstm(options, params); + }else if (lt == CONV_LSTM) { + l = parse_conv_lstm(options, params); }else if(lt == CRNN){ l = parse_crnn(options, params); }else if(lt == CONNECTED){ @@ -950,8 +999,8 @@ void save_convolutional_weights_binary(layer l, FILE *fp) pull_convolutional_layer(l); } #endif - binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights); - int size = l.c*l.size*l.size; + int size = (l.c/l.groups)*l.size*l.size; + binarize_weights(l.weights, l.n, size, l.binary_weights); int i, j, k; fwrite(l.biases, sizeof(float), l.n, fp); if (l.batch_normalize){ @@ -986,7 +1035,7 @@ void save_convolutional_weights(layer l, FILE *fp) pull_convolutional_layer(l); } #endif - int num = l.n*l.c*l.size*l.size; + int num = l.nweights; fwrite(l.biases, sizeof(float), l.n, fp); if (l.batch_normalize){ fwrite(l.scales, sizeof(float), l.n, fp); @@ -1076,6 +1125,20 @@ void save_weights_upto(network net, char *filename, int cutoff) save_connected_weights(*(l.ui), fp); save_connected_weights(*(l.ug), fp); save_connected_weights(*(l.uo), fp); + } if (l.type == CONV_LSTM) { + if (l.peephole) { + save_convolutional_weights(*(l.vf), fp); + save_convolutional_weights(*(l.vi), fp); + save_convolutional_weights(*(l.vo), fp); + } + save_convolutional_weights(*(l.wf), fp); + save_convolutional_weights(*(l.wi), fp); + save_convolutional_weights(*(l.wg), fp); + save_convolutional_weights(*(l.wo), fp); + save_convolutional_weights(*(l.uf), fp); + save_convolutional_weights(*(l.ui), fp); + save_convolutional_weights(*(l.ug), fp); + save_convolutional_weights(*(l.uo), fp); } if(l.type == CRNN){ save_convolutional_weights(*(l.input_layer), fp); save_convolutional_weights(*(l.self_layer), fp); @@ -1156,7 +1219,7 @@ void load_convolutional_weights_binary(layer l, FILE *fp) fread(l.rolling_mean, sizeof(float), l.n, fp); fread(l.rolling_variance, sizeof(float), l.n, fp); } - int size = l.c*l.size*l.size; + int size = (l.c / l.groups)*l.size*l.size; int i, j, k; for(i = 0; i < l.n; ++i){ float mean = 0; @@ -1184,7 +1247,7 @@ void load_convolutional_weights(layer l, FILE *fp) //load_convolutional_weights_binary(l, fp); //return; } - int num = l.n*l.c*l.size*l.size; + int num = l.nweights; fread(l.biases, sizeof(float), l.n, fp); //fread(l.weights, sizeof(float), num, fp); // as in connected layer if (l.batch_normalize && (!l.dontloadscales)){ @@ -1214,9 +1277,9 @@ void load_convolutional_weights(layer l, FILE *fp) //} //if(l.c == 3) scal_cpu(num, 1./256, l.weights, 1); if (l.flipped) { - transpose_matrix(l.weights, l.c*l.size*l.size, l.n); + transpose_matrix(l.weights, (l.c/l.groups)*l.size*l.size, l.n); } - //if (l.binary) binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.weights); + //if (l.binary) binarize_weights(l.weights, l.n, (l.c/l.groups)*l.size*l.size, l.weights); #ifdef GPU if(gpu_index >= 0){ push_convolutional_layer(l); @@ -1298,6 +1361,21 @@ void load_weights_upto(network *net, char *filename, int cutoff) load_connected_weights(*(l.ug), fp, transpose); load_connected_weights(*(l.uo), fp, transpose); } + if (l.type == CONV_LSTM) { + if (l.peephole) { + load_convolutional_weights(*(l.vf), fp); + load_convolutional_weights(*(l.vi), fp); + load_convolutional_weights(*(l.vo), fp); + } + load_convolutional_weights(*(l.wf), fp); + load_convolutional_weights(*(l.wi), fp); + load_convolutional_weights(*(l.wg), fp); + load_convolutional_weights(*(l.wo), fp); + load_convolutional_weights(*(l.uf), fp); + load_convolutional_weights(*(l.ui), fp); + load_convolutional_weights(*(l.ug), fp); + load_convolutional_weights(*(l.uo), fp); + } if(l.type == LOCAL){ int locations = l.out_w*l.out_h; int size = l.size*l.size*l.c*l.n*locations; diff --git a/src/rnn_layer.c b/src/rnn_layer.c index 28163d754de..4b5b9c2c64c 100644 --- a/src/rnn_layer.c +++ b/src/rnn_layer.c @@ -42,19 +42,19 @@ layer make_rnn_layer(int batch, int inputs, int hidden, int outputs, int steps, l.state = (float*)calloc(batch * hidden * (steps + 1), sizeof(float)); - l.input_layer = (layer*)malloc(sizeof(layer)); + l.input_layer = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.input_layer) = make_connected_layer(batch, steps, inputs, hidden, activation, batch_normalize); l.input_layer->batch = batch; if (l.workspace_size < l.input_layer->workspace_size) l.workspace_size = l.input_layer->workspace_size; - l.self_layer = (layer*)malloc(sizeof(layer)); + l.self_layer = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.self_layer) = make_connected_layer(batch, steps, hidden, hidden, (log==2)?LOGGY:(log==1?LOGISTIC:activation), batch_normalize); l.self_layer->batch = batch; if (l.workspace_size < l.self_layer->workspace_size) l.workspace_size = l.self_layer->workspace_size; - l.output_layer = (layer*)malloc(sizeof(layer)); + l.output_layer = (layer*)calloc(1, sizeof(layer)); fprintf(stderr, "\t\t"); *(l.output_layer) = make_connected_layer(batch, steps, hidden, outputs, activation, batch_normalize); l.output_layer->batch = batch; diff --git a/src/yolo.c b/src/yolo.c index 07a2092c0b1..711470eade2 100644 --- a/src/yolo.c +++ b/src/yolo.c @@ -351,5 +351,5 @@ void run_yolo(int argc, char **argv) else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights); else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights); else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, hier_thresh, cam_index, filename, voc_names, 20, frame_skip, - prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output); + prefix, out_filename, mjpeg_port, json_port, dont_show, ext_output, 0); } diff --git a/src/yolo_console_dll.cpp b/src/yolo_console_dll.cpp index 4f2b2b7432e..60da53f6516 100644 --- a/src/yolo_console_dll.cpp +++ b/src/yolo_console_dll.cpp @@ -129,7 +129,9 @@ cv::Mat slMat2cvMat(sl::Mat &input) { cv::Mat zed_capture_rgb(sl::Camera &zed) { sl::Mat left; zed.retrieveImage(left); - return slMat2cvMat(left).clone(); + cv::Mat left_rgb; + cv::cvtColor(slMat2cvMat(left), left_rgb, CV_RGBA2RGB); + return left_rgb; } cv::Mat zed_capture_3d(sl::Camera &zed) { diff --git a/src/yolo_layer.c b/src/yolo_layer.c index d303b5aa20b..ae48ef7ae07 100644 --- a/src/yolo_layer.c +++ b/src/yolo_layer.c @@ -240,6 +240,7 @@ void forward_yolo_layer(const layer l, network_state state) int class_id = state.truth[t*(4 + 1) + b*l.truths + 4]; if (class_id >= l.classes) { printf(" Warning: in txt-labels class_id=%d >= classes=%d in cfg-file. In txt-labels class_id should be [from 0 to %d] \n", class_id, l.classes, l.classes - 1); + printf(" truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f, class_id = %d \n", truth.x, truth.y, truth.w, truth.h, class_id); getchar(); continue; // if label contains class_id more than number of classes in the cfg-file } @@ -271,6 +272,9 @@ void forward_yolo_layer(const layer l, network_state state) } for(t = 0; t < l.max_boxes; ++t){ box truth = float_to_box_stride(state.truth + t*(4 + 1) + b*l.truths, 1); + if (truth.x < 0 || truth.y < 0 || truth.x > 1 || truth.y > 1 || truth.w < 0 || truth.h < 0) { + printf(" Wrong label: truth.x = %f, truth.y = %f, truth.w = %f, truth.h = %f \n", truth.x, truth.y, truth.w, truth.h); + } int class_id = state.truth[t*(4 + 1) + b*l.truths + 4]; if (class_id >= l.classes) continue; // if label contains class_id more than number of classes in the cfg-file diff --git a/src/yolo_v2_class.cpp b/src/yolo_v2_class.cpp index 6eaa38b7291..cffbd625718 100644 --- a/src/yolo_v2_class.cpp +++ b/src/yolo_v2_class.cpp @@ -22,6 +22,7 @@ extern "C" { #include #include +#define NFRAMES 3 //static Detector* detector = NULL; static std::unique_ptr detector;