说明:本仓库提供用于训练与推理的脚本,目标是基于 HWDB1.0 数据集训练一个自定义 CNN(使用 PyTorch,默认在 CPU 上运行)。因为 HWDB 真实下载可能需要注册与手动获取,本项目在找不到数据时会回退到 MNIST 做演示。
目录结构(将由脚本创建或期待):
data/hwdb/ : 如果你已准备好 HWDB 的图像数据,可放在这里,要求为 ImageFolder 风格:data/hwdb/train/<label>/*.png 和 data/hwdb/val/<label>/*.png。models/ : 训练检查点保存位置。快速开始
pip install -r requirements.txt
# 若需要 CPU-only 的 PyTorch,可以参考:https://pytorch.org/get-started/locally/
scripts/download_hwdb.py 来下载并解压(需提供 url)。ImageFolder 结构并放到 data/hwdb/。如果你想直接从 NLPR 官方站点下载 HWDB1.0 的 feature 数据(我已测试可用),可以运行下面的脚本或命令:
使用提供的脚本(推荐):
bash scripts/download_hwdb_nlpr.sh
或直接运行等效命令(把工作目录切到仓库根目录):
mkdir -p data/hwdb_raw && cd data/hwdb_raw
wget -c https://nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.0trn.zip -O HWDB1.0trn.zip
wget -c https://nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.0tst.zip -O HWDB1.0tst.zip
unzip -o HWDB1.0trn.zip && unzip -o HWDB1.0tst.zip
下载后会把 feature 数据(.mpf 文件)解压到 data/hwdb_raw/HWDB1.0trn/ 和 data/hwdb_raw/HWDB1.0tst/。
注意:上述 feature 数据为特征文件(.mpf),不是原始字符图像(.gnt)。如果你需要像素级图像用于 CNN 训练,请考虑下载“Character Sample Data”(Gnt1.0TrainPart*.zip / Gnt1.0Test.zip),这些文件体积较大(每个部分 ~1GB)。
train.py 中修改 device 为 cuda):python train.py --data-dir data/hwdb --epochs 10 --batch-size 64
如果 data/hwdb 不存在或为空,脚本会自动回退到 MNIST 小示例来演示训练流程。
python infer_gradio.py --model-path models/checkpoint.pth
文件说明
scripts/download_hwdb.py : 下载并解压 HWDB(若可用)train.py : 训练脚本,包含数据加载、模型定义(使用 src/models.py)、训练/验证循环infer_gradio.py : Gradio 推理界面,加载模型并提供上传/绘制图像的接口src/dataset.py : 数据集帮助函数(支持 ImageFolder 与 MNIST 回退)src/models.py : 自定义 CNN 实现注意
train.py 和 infer_gradio.py 中的 device。