Aurora11111/crnn-train-pytorch

 
 

CRNN TRAIN

This software implements the Convolutional Recurrent Neural Network (CRNN) in PyTorch. The original software can be found in crnn.

Environment

Python 3.6, PyTorch 0.4, OpenCV 2.4, lmdb, warp_ctc

ATTENTION!

getLmdb.py must be run with Python 2.x.

Issues when installing warp_ctc_pytorch

  • The build fails with "cicc: command not found":

    [ 11%] Building NVCC (Device) object CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o
    sh: cicc: command not found
    CMake Error at warpctc_generated_reduce.cu.o.cmake:279 (message):
      Error generating file /home/rice/warp-ctc/build/CMakeFiles/warpctc.dir/src/./warpctc_generated_reduce.cu.o
    make[2]: *** [CMakeFiles/warpctc.dir/build.make:256: CMakeFiles/warpctc.dir/src/warpctc_generated_reduce.cu.o] Error 1
    make[1]: *** [CMakeFiles/Makefile2:104: CMakeFiles/warpctc.dir/all] Error 2
    make: *** [Makefile:130: all] Error 2

    Reinstall CUDA and make sure the installation completes fully.
  • THCudaMalloc error: see https://github.com/baidu-research/warp-ctc/pull/71/files
  • Xtra-Computing/thundersvm#54 (comment)
  • For the error shown in my_error_image, link the CUDA headers into the THC include directory of PyTorch:

    ln -s /opt/cuda/include/* /home/rice/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/THC/

Train a new model

Construct the dataset following the original guide. For training with variable-length text, sort the images by text length. Reference: https://github.com/Aurora11111/TextRecognitionDataGenerator
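
Sorting by text length can be sketched as follows. This is a minimal illustration, not code from this repository. It assumes a label file with one `image_path label` pair per line, split on the first space; `parse_label_line` and `sort_by_label_length` are hypothetical helper names.

```python
def parse_label_line(line):
    """Split 'image_path label' on the first space (assumed file format)."""
    path, _, label = line.rstrip("\n").partition(" ")
    return path, label


def sort_by_label_length(lines):
    """Return (path, label) pairs ordered from shortest to longest label."""
    samples = [parse_label_line(line) for line in lines if line.strip()]
    # sorted() is stable, so samples with equal label length keep file order.
    return sorted(samples, key=lambda sample: len(sample[1]))


if __name__ == "__main__":
    # Hypothetical label lines, for illustration only.
    demo = ["img_0.jpg hello world\n", "img_1.jpg hi\n", "img_2.jpg abc\n"]
    for path, label in sort_by_label_length(demo):
        print(path, label)
```

If the real label file uses a different separator (for example a tab), only `parse_label_line` needs to change.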

  1. Data preprocessing

Run /contrib/crnn/tool/getLmdb.py

# Output path for the generated lmdb
outputPath = '/run/media/rice/DATA/chinese1/lmdb'
# Images and their corresponding labels
imgdata = open("/run/media/rice/DATA/chinese1/labels.txt")
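
For orientation, the kind of key/value layout such an LMDB dataset might use can be sketched as below. This is not the content of getLmdb.py. The key names `image-%09d`, `label-%09d` and `num-samples` are an assumption borrowed from common CRNN dataset tools, and `build_cache` and `write_cache` are hypothetical helpers. The sketch targets Python 3, whereas getLmdb.py itself requires Python 2.x.

```python
def build_cache(samples):
    """Map LMDB keys to byte values for a list of (image_bytes, label) pairs.

    Key names are assumed, not taken from getLmdb.py.
    """
    cache = {}
    for index, (image_bytes, label) in enumerate(samples, start=1):
        cache[b"image-%09d" % index] = image_bytes
        cache[b"label-%09d" % index] = label.encode("utf-8")
    cache[b"num-samples"] = str(len(samples)).encode("utf-8")
    return cache


def write_cache(output_path, cache, map_size=1 << 30):
    """Write every key/value pair into an LMDB environment at output_path."""
    import lmdb  # third-party; imported lazily so build_cache works without it

    env = lmdb.open(output_path, map_size=map_size)
    with env.begin(write=True) as txn:
        for key, value in cache.items():
            txn.put(key, value)
    env.close()
```

Keeping the cache construction separate from the write step makes it possible to inspect the keys before anything touches the disk.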

  2. Train the model

Run /contrib/crnn/crnn_main.py

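python crnn_main.py [--param val]
--trainroot        path to the training set
--valroot          path to the validation set
--workers          number of CPU workers, default=4
--batchSize        batch size, default=64
--imgH             image height, default=32
--imgW             image width, default=280 (all training images used are 280*32)
--nh               size of the LSTM hidden state, default=256
--niter            number of training epochs, default=25
--lr               learning rate, default=0.00005
--cuda             use the GPU, action='store_true'
--ngpu             number of GPUs to use, default=1
--crnn             path to a pretrained model
--alphabet         set of character classes
--experiment       directory in which models are saved
--displayInterval  number of iterations between progress displays, default=1000
--n_test_disp      number of samples displayed at each validation, default=10
--valInterval      number of iterations between validations, default=1000
--saveInterval     number of iterations between model saves, default=1000
--adam             use the Adam optimizer, default='True'
--adadelta         use the Adadelta optimizer, action='store_true'
--keep_ratio       keep the aspect ratio when resizing images, action='store_true'
--random_sample    sample the dataset with a random sampler, action='store_true'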

    Example: python /contrib/crnn/crnn_main.py --trainroot [training set path] --valroot [validation set path] --nh 128 --cuda --crnn [pretrained model path]
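
To show how options like these are typically parsed, here is a minimal argparse sketch covering a subset of the flags with the defaults listed above. It is not the parser from crnn_main.py. Marking `--trainroot` and `--valroot` as required is an assumption, and `build_parser` is a hypothetical name.

```python
import argparse


def build_parser():
    """Build a parser for a subset of the documented training options."""
    parser = argparse.ArgumentParser(description="CRNN training options (subset)")
    # required=True is an assumption; the real script may behave differently.
    parser.add_argument("--trainroot", required=True, help="path to the training set")
    parser.add_argument("--valroot", required=True, help="path to the validation set")
    parser.add_argument("--batchSize", type=int, default=64, help="batch size")
    parser.add_argument("--imgH", type=int, default=32, help="image height")
    parser.add_argument("--imgW", type=int, default=280, help="image width")
    parser.add_argument("--nh", type=int, default=256, help="LSTM hidden state size")
    parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
    parser.add_argument("--cuda", action="store_true", help="use the GPU")
    parser.add_argument("--crnn", default="", help="path to a pretrained model")
    return parser


if __name__ == "__main__":
    opt = build_parser().parse_args()
    print(opt)
```

Flags declared with `action="store_true"`, such as `--cuda`, take no value: passing the flag turns the option on, omitting it leaves it off.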

Modify alphabet = '012346789' in /contrib/crnn/keys.py to add or remove classes.

  3. Notes: the number of classes and the LSTM hidden size used for training and prediction must be the same.
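
The link between the alphabet and the class count can be sketched as follows. This is an illustration under one assumption: index 0 is reserved for the CTC blank, so the class count is the alphabet length plus one. The repository's own label converter may differ, and `build_char_table` and `encode` are hypothetical names.

```python
def build_char_table(alphabet):
    """Map each character to a class index, keeping 0 for the CTC blank (assumed)."""
    char_to_index = {ch: i + 1 for i, ch in enumerate(alphabet)}
    nclass = len(alphabet) + 1  # +1 for the blank class
    return char_to_index, nclass


def encode(text, char_to_index):
    """Convert a label string into a list of class indices."""
    return [char_to_index[ch] for ch in text]


if __name__ == "__main__":
    table, nclass = build_char_table("012346789")
    print(nclass)
    print(encode("2019", table))
```

Because the size of the output layer depends on the class count, a model trained with one alphabet cannot be loaded unchanged for prediction with another.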

Train a new model (new nclass differs from old nclass)

When your nclass differs from that of the pretrained model, you can fine-tune with: python finetune.py
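
One common way to reuse a pretrained model when the class count changes is to copy only the parameters whose names and shapes still match, and to leave the rest (typically the final classification layer) at their fresh initialization. The sketch below illustrates that idea. It is not the content of finetune.py; `filter_compatible` and the parameter names in the demo are hypothetical.

```python
from types import SimpleNamespace


def filter_compatible(pretrained_state, model_state):
    """Keep pretrained entries whose name and shape match the new model."""
    kept, skipped = {}, []
    for name, tensor in pretrained_state.items():
        target = model_state.get(name)
        if target is not None and tuple(tensor.shape) == tuple(target.shape):
            kept[name] = tensor
        else:
            skipped.append(name)
    return kept, skipped


if __name__ == "__main__":
    # Stand-ins for tensors; only the .shape attribute matters here.
    old = {
        "cnn.conv0.weight": SimpleNamespace(shape=(64, 1, 3, 3)),
        "rnn.embedding.weight": SimpleNamespace(shape=(37, 512)),
    }
    new = {
        "cnn.conv0.weight": SimpleNamespace(shape=(64, 1, 3, 3)),
        "rnn.embedding.weight": SimpleNamespace(shape=(10, 512)),
    }
    kept, skipped = filter_compatible(old, new)
    print(sorted(kept), skipped)
```

With PyTorch, the two dictionaries would come from `torch.load(path)` and `model.state_dict()`, and the filtered result could be passed to `model.load_state_dict(kept, strict=False)`.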

Run demo

A demo program can be found in src/demo.py. Before running the demo, download a pretrained model from Baidu Netdisk or Dropbox. This pretrained model was converted from the one offered by the original author, using tool. Put the downloaded model file crnn.pth into the directory data/. Then launch the demo with:

python demo.py

The demo reads an example image and recognizes its text content.

Example image: my_example_image

Expected output:

loading pretrained model from ./data/crnn.pth
a-----v--a-i-l-a-bb-l-ee-- => available
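
The mapping from the raw string to the final word in that output follows the usual greedy CTC decoding rule: merge consecutive repeated symbols, then drop the blank symbol `-`. A minimal sketch of that rule, with `ctc_collapse` as a hypothetical helper name rather than a function from this repository:

```python
from itertools import groupby


def ctc_collapse(raw, blank="-"):
    """Merge consecutive repeated symbols, then remove the blank symbol."""
    merged = [symbol for symbol, _ in groupby(raw)]
    return "".join(symbol for symbol in merged if symbol != blank)


print(ctc_collapse("a-----v--a-i-l-a-bb-l-ee--"))  # prints "available"
```

The blank is what allows genuine double letters: `l-l` decodes to `ll`, whereas `ll` without a blank in between collapses to a single `l`.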
