TextGrocery 中文 API
TextGrocery 是一个基于 LibLinear 和 结巴分词 的短文本分类工具,特点是高效易用,同时支持中文和英文语料。
安装
通过 GitHub (最新版本)
git clone https://github.com/2shou/TextGrocery.git cd TextGrocery make
通过 pip (更稳定)
pip install tgrocery
性能
- 训练集:来自 32 个类别的 4.8 万条中文新闻标题
- 测试集:来自 32 个类别的 1.6 万条中文新闻标题
- 与 scikit-learn 的 svm 和朴素贝叶斯算法做横向对比
| 分类器 | 准确率(%) | 计算时间(秒) |
|---|---|---|
| scikit-learn(朴素贝叶斯) | 76.8% | 134 |
| scikit-learn(svm) | 76.9% | 121 |
| TextGrocery | 79.6% | 49 |
API 文档
Grocery
class tgrocery.Grocery(name, custom_tokenize=None)
- 确定你的分类项目名
- custom_tokenize 会覆盖默认的分词单元(结巴分词),要求 custom_tokenize 的类型必须是函数
def Grocery.train(train_src, delimiter='\t')
获取训练样本,生成分类模型
train_src 可以是嵌套列表或文件路径
- 嵌套列表:实体是两个字符串构成的 tuple,第一个字符串是类别标签,第二个字符串是语料文本
- 文件路径:一行为一个训练样本,类别标签在前、语料文本在后,默认分隔符是
\t
delimiter 是解析训练样本时所用的分隔符,仅在 train_src 为文件路径时生效
def Grocery.get_load_status()
返回目前模型是否在已训练或已加载的状态
def Grocery.predict(single_text)
- 对单一文本预测其类别(预测前会检测模型是否已训练或已加载)
- 返回一个
GroceryPredictResult 对象
def Grocery.save()
保存模型到本地
- 默认文件夹名是 Grocery 的 name 属性
- 如果本地存在同名文件夹,将被覆盖
def Grocery.load()
从本地加载模型
- 默认文件夹名是 Grocery 的 name 属性
- 分词单元的信息不会被自动加载,如果自定义了分词单元,需要在创建 Grocery 的过程中再次指定
def Grocery.test(test_src, delimiter='\t')
测试模型在测试样本中取得的准确率
test_src 可以是嵌套列表或文件路径
- 嵌套列表:实体是两个字符串构成的 tuple,第一个字符串是类别标签,第二个字符串是语料文本
- 文件路径:一行为一个测试样本,类别标签在前、语料文本在后,默认分隔符是
\\t
delimiter 是解析测试样本时所用的分隔符,仅在 test_src 为文件路径时生效
返回一个 GroceryTestResult 对象
GroceryPredictResult
对新语料预测后的结果
GroceryPredictResult.predicted_y
预测的类别标签
GroceryPredictResult.dec_values
- 对所有类别的决策变量(一个浮点数,可正可负,越大表示归属于该类别的可能性越大)
- dict,key 是类别标签,value 是决策变量
GroceryTestResult
对测试样本测试后的结果
GroceryTestResult.accuracy_overall
不分类别的总体准确率,浮点数,0 到 1 之间
GroceryTestResult.accuracy_labels
- 区分类别的准确率
- dict,key 是类别标签,value 是准确率
GroceryTestResult.recall_labels
- 区分类别的召回率
- dict,key 是类别标签,value 是召回率
def GroceryTestResult.show_result()
- 打印各类别的准确率和召回率表格,方便比较
快速开始
>>> from tgrocery import Grocery
# 新开张一个杂货铺(别忘了取名)
>>> grocery = Grocery('sample')
# 训练文本可以用列表传入
>>> train_src = [
('education', '名师指导托福语法技巧:名词的复数形式'),
('education', '中国高考成绩海外认可 是“狼来了”吗?'),
('sports', '图文:法网孟菲尔斯苦战进 16 强 孟菲尔斯怒吼'),
('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与')
]
>>> grocery.train(train_src)
# 也可以用文件传入(默认以 tab 为分隔符,也支持自定义)
>>> grocery.train('train_ch.txt')
# 保存模型
>>> grocery.save()
# 加载模型(名字和保存的一样)
>>> new_grocery = Grocery('sample')
>>> new_grocery.load()
# 预测
>>> new_grocery.predict('考生必读:新托福写作考试评分标准')
education
# 测试
>>> test_src = [
('education', '福建春季公务员考试报名 18 日截止 2 月 6 日考试'),
('sports', '意甲首轮补赛交战记录:米兰客场 8 战不败国米 10 年连胜'),
]
>>> new_grocery.test(test_src)
# 输出测试的准确率
0.5
# 同样可支持文件传入
>>> new_grocery.test('test_ch.txt')
# 自定义分词模块(必须是一个函数)
>>> custom_grocery = Grocery('custom', custom_tokenize=list)
上一篇: 梯度下降优化算法综述
下一篇: 通过…删除数据来提高模型性能?
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!

发布评论