logo
0
0
Login

HWDB1.0 PyTorch 示例项目(CPU 优先)

说明:本仓库提供用于训练与推理的脚本,目标是基于 HWDB1.0 数据集训练一个自定义 CNN(使用 PyTorch,默认在 CPU 上运行)。因为 HWDB 真实下载可能需要注册与手动获取,本项目在找不到数据时会回退到 MNIST 做演示。

目录结构(将由脚本创建或期待):

  • data/hwdb/ : 如果你已准备好 HWDB 的图像数据,可放在这里,要求为 ImageFolder 风格:data/hwdb/train/<label>/*.pngdata/hwdb/val/<label>/*.png
  • models/ : 训练检查点保存位置。

快速开始

  1. 安装依赖(示例,按需选择 CPU 或 CUDA 版本的 torch):
pip install -r requirements.txt # 若需要 CPU-only 的 PyTorch,可以参考:https://pytorch.org/get-started/locally/
  1. 下载/准备数据:
  • 若你有 HWDB 的下载 URL,可以使用 scripts/download_hwdb.py 来下载并解压(需提供 url)。
  • 或手动把数据解压后整理成 ImageFolder 结构并放到 data/hwdb/

如果你想直接从 NLPR 官方站点下载 HWDB1.0 的 feature 数据(我已测试可用),可以运行下面的脚本或命令:

使用提供的脚本(推荐):

bash scripts/download_hwdb_nlpr.sh

或直接运行等效命令(把工作目录切到仓库根目录):

mkdir -p data/hwdb_raw && cd data/hwdb_raw wget -c https://nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.0trn.zip -O HWDB1.0trn.zip wget -c https://nlpr.ia.ac.cn/databases/download/feature_data/HWDB1.0tst.zip -O HWDB1.0tst.zip unzip -o HWDB1.0trn.zip && unzip -o HWDB1.0tst.zip

下载后会把 feature 数据(.mpf 文件)解压到 data/hwdb_raw/HWDB1.0trn/data/hwdb_raw/HWDB1.0tst/

注意:上述 feature 数据为特征文件(.mpf),不是原始字符图像(.gnt)。如果你需要像素级图像用于 CNN 训练,请考虑下载“Character Sample Data”(Gnt1.0TrainPart*.zip / Gnt1.0Test.zip),这些文件体积较大(每个部分 ~1GB)。

  1. 训练模型(默认使用 CPU,如果你想使用 GPU,请在 train.py 中修改 device 为 cuda):
python train.py --data-dir data/hwdb --epochs 10 --batch-size 64

如果 data/hwdb 不存在或为空,脚本会自动回退到 MNIST 小示例来演示训练流程。

  1. 运行 Gradio 推理界面(会在本机打开一个 web GUI):
python infer_gradio.py --model-path models/checkpoint.pth

文件说明

  • scripts/download_hwdb.py : 下载并解压 HWDB(若可用)
  • train.py : 训练脚本,包含数据加载、模型定义(使用 src/models.py)、训练/验证循环
  • infer_gradio.py : Gradio 推理界面,加载模型并提供上传/绘制图像的接口
  • src/dataset.py : 数据集帮助函数(支持 ImageFolder 与 MNIST 回退)
  • src/models.py : 自定义 CNN 实现

注意

  • 本项目默认用 CPU 以保证兼容性;如果你确实有 NVIDIA GPU 并安装了 GPU 版 PyTorch,可修改 train.pyinfer_gradio.py 中的 device
  • 若需要我帮你把脚本调整为严格使用 GPU,请告诉我你希望我如何检测/选择设备。

About

Data from https://nlpr.ia.ac.cn/databases/handwriting/Download.html

Language
Python94.2%
Shell5.8%