Training a Text Classification Model Using SQLFlow
This is a tutorial on how to train a text classification model using SQLFlow. Note that the steps in this tutorial may change as SQLFlow develops; we only provide a way that works with the current version.
If you want to know how SQLFlow is adding support for custom models such as CNN text classification, you may check out the current design for this ongoing development.
In this tutorial we use two datasets, one for English and one for Chinese text classification. The Chinese case is more complicated because Chinese sentences cannot be segmented by spaces. You can download the full datasets from:
Steps to Process and Train With IMDB Dataset
- The imdb database is already loaded in our Docker image, or you can use this script to download, preprocess and insert the data into your own MySQL database.
- Use the following statements to train and predict using SQLFlow:
%%sqlflow
SELECT content, class
FROM imdb.train
TO TRAIN DNNClassifier
WITH model.n_classes = 2, model.hidden_units = [128, 64]
LABEL class
INTO sqlflow_models.my_text_model_en;

%%sqlflow
SELECT *
FROM imdb.test
TO PREDICT imdb.predict.class
USING sqlflow_models.my_text_model_en;
- You can then get the prediction results from the table imdb.predict.
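The training statement above expects a table with a text column content and an integer label column class whose values lie in [0, n_classes), which here means 0 or 1. As a minimal sketch of that layout, the following uses Python's built-in sqlite3 module as a stand-in for MySQL (an assumption for illustration only; the real imdb database lives in MySQL and would be reached through a MySQL driver). The table contents, the helper names make_train_table and labels_are_valid, and the 1 = positive encoding are hypothetical.

```python
import sqlite3

N_CLASSES = 2  # mirrors model.n_classes = 2 in the TO TRAIN statement


def make_train_table(rows):
    """Create an in-memory stand-in for imdb.train holding (content, class) rows."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE train (content TEXT NOT NULL, class INTEGER NOT NULL)")
    conn.executemany("INSERT INTO train (content, class) VALUES (?, ?)", rows)
    conn.commit()
    return conn


def labels_are_valid(conn, n_classes=N_CLASSES):
    """Return True if every label lies in [0, n_classes)."""
    (bad,) = conn.execute(
        "SELECT COUNT(*) FROM train WHERE class < 0 OR class >= ?", (n_classes,)
    ).fetchone()
    return bad == 0


if __name__ == "__main__":
    # Hypothetical reviews; 1 = positive and 0 = negative is an assumed encoding.
    sample = [("a wonderful, moving film", 1), ("dull and far too long", 0)]
    conn = make_train_table(sample)
    print(conn.execute("SELECT COUNT(*) FROM train").fetchone()[0])
    print(labels_are_valid(conn))
```

A label outside that range (for example 2 with n_classes = 2) is exactly the kind of row the check reports, which is why the Chinese dataset's labels also have to be shifted to start from 0.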
Train and Predict Using Custom Keras Model
If you want to train your own custom model written in Keras, follow the steps below:
- Check out our “models” repo:
git clone https://github.com/sql-machine-learning/models.git
- Put your custom model under the sqlflow_models/ directory and add an import line for it in sqlflow_models/__init__.py. Only custom models written as Keras subclass models (subclasses of tf.keras.Model) are supported.
- Install the models repo on the server where you wish to run the training:
python setup.py install
- Modify the SQL statement above to use your custom model by changing the model name to sqlflow_models.YourAwesomeModel, for example (the LIMIT 100 clause trains on only 100 rows, which is handy for a quick check):
%%sqlflow
SELECT content, class
FROM imdb.train
LIMIT 100
TO TRAIN sqlflow_models.StackedBiLSTMClassifier
WITH model.n_classes = 2, model.stack_units = [64, 32], model.hidden_size = 64, train.epoch = 10, train.batch_size = 64
COLUMN EMBEDDING(SEQ_CATEGORY_ID(content, 16000), 128, sum)
LABEL class
INTO sqlflow_models.my_custom_model;
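The import line in sqlflow_models/__init__.py matters because a model name such as sqlflow_models.StackedBiLSTMClassifier must be reachable as an attribute of the installed package. The sketch below assumes SQLFlow resolves the name roughly this way (an assumption about its internals, not its actual code): import the package part of the dotted name, then look up the class with getattr. To stay self-contained it registers a hypothetical in-memory package called demo_models instead of the real sqlflow_models, and YourAwesomeModel is a placeholder class standing in for a tf.keras.Model subclass.

```python
import importlib
import sys
import types


def register_demo_package():
    """Register a hypothetical in-memory package standing in for sqlflow_models.

    Setting the class as a package attribute mimics an import line such as
    `from .my_awesome_model import YourAwesomeModel` in __init__.py.
    """
    package = types.ModuleType("demo_models")

    class YourAwesomeModel:  # placeholder for a tf.keras.Model subclass
        def __init__(self, n_classes=2):
            self.n_classes = n_classes

    package.YourAwesomeModel = YourAwesomeModel
    sys.modules["demo_models"] = package
    return package


def resolve_model(dotted_name):
    """Split 'package.ClassName', import the package, and return the class."""
    module_name, _, class_name = dotted_name.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'package.ClassName', got {dotted_name!r}")
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError:
        raise ImportError(
            f"{class_name} is not exported by {module_name}; "
            f"add an import line to {module_name}/__init__.py"
        ) from None


if __name__ == "__main__":
    register_demo_package()
    model_cls = resolve_model("demo_models.YourAwesomeModel")
    model = model_cls(n_classes=2)
    print(type(model).__name__, model.n_classes)
```

If the class exists in a file under the package but is never imported in __init__.py, the getattr lookup fails, which is the failure the import-line step is meant to prevent.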
Steps to Process and Train With Chinese Text Classification Dataset
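Chinese text needs an extra segmentation step because splitting on whitespace, which works for English, leaves an unspaced Chinese sentence as a single token. The minimal sketch below only demonstrates that contrast; the helper name whitespace_tokens is hypothetical, and a real pipeline would use a dedicated Chinese word segmenter (for example, the open-source jieba library) rather than whitespace splitting.

```python
def whitespace_tokens(sentence):
    """Split a sentence on runs of whitespace, as works for English text."""
    return sentence.split()


if __name__ == "__main__":
    english = "this movie is great"
    chinese = "这部电影很好看"  # roughly "this movie is very good"; no spaces between words
    print(whitespace_tokens(english))  # ['this', 'movie', 'is', 'great']
    print(whitespace_tokens(chinese))  # ['这部电影很好看']
```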
- Download the dataset from the above link and unpack toutiao_cat_data.txt.zip.
- Copy toutiao_cat_data.txt to /var/lib/mysql-files/ on the server where MySQL is running, because MySQL may refuse to import data from an untrusted location.
- Log in to the MySQL command line, for example with mysql -uroot -p, and create a database and tables to load the dataset. Note that the tables must be created with CHARSET=utf8 COLLATE=utf8_unicode_ci so that the Chinese text is stored and displayed correctly:
CREATE DATABASE toutiao;
USE toutiao;
CREATE TABLE `train` (
  `id` bigint(20) NOT NULL,
  `class_id` int(3) NOT NULL,
  `class_name` varchar(100) NOT NULL,
  `news_title` varchar(255) NOT NULL,
  `news_keywords` varchar(255) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `train_processed` (
  `id` bigint(20) NOT NULL,
  `class_id` int(3) NOT NULL,
  `class_name` varchar(100) NOT NULL,
  `news_title` TEXT NOT NULL,
  `news_keywords` varchar(255) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
CREATE TABLE `test_processed` (
  `id` bigint(20) NOT NULL,
  `class_id` int(3) NOT NULL,
  `class_name` varchar(100) NOT NULL,
  `news_title` TEXT NOT NULL,
  `news_keywords` varchar(255) NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
COMMIT;
- In the MySQL shell, run the following statement to load the dataset into the created table:
LOAD DATA LOCAL INFILE '/var/lib/mysql-files/toutiao_cat_data.txt' INTO TABLE toutiao.train CHARACTER SET utf8 FIELDS TERMINATED BY '_!_' LINES TERMINATED BY "\n";
- Run this python script to generate a vocabulary and convert the raw news title texts to padded word ids. The max length of the segmented sentence is 92. Note that this script also changes the class_id column's values from 100~116 to 0~16, since labels must start from 0 (with model.n_classes = 17, valid labels are 0 to 16).
- Split some of the data into a validation table, and remove the validation data from the training data:
%%sqlflow
INSERT INTO toutiao.test_processed (`id`, `class_id`, `class_name`, `news_title`, `news_keywords`)
SELECT `id`, `class_id`, `class_name`, `news_title`, `news_keywords`
FROM toutiao.train_processed
ORDER BY RAND()
LIMIT 5000;
DELETE FROM toutiao.train_processed
WHERE id IN (SELECT id FROM toutiao.test_processed AS p);
- Then use the following statements to train and predict using SQLFlow:
%%sqlflow
SELECT news_title, class_id
FROM toutiao.train_processed
TO TRAIN DNNClassifier
WITH model.n_classes = 17, model.hidden_units = [128, 512]
LABEL class_id
INTO sqlflow_models.my_text_model;

%%sqlflow
SELECT *
FROM toutiao.test_processed
TO PREDICT toutiao.predict.class_id
USING sqlflow_models.my_text_model;
- You can then get the prediction results from the table toutiao.predict.
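The preprocessing and split steps for the Chinese dataset can be sketched in plain Python to show the data flow: parse each '_!_'-delimited line into the five table columns, shift class_id so labels start from 0 (assumed here to be a subtraction of 100, consistent with the ranges described for the script), map tokens to integer ids through a vocabulary, pad or truncate each title to a fixed length, and hold out a random validation sample as the INSERT ... ORDER BY RAND() LIMIT and DELETE statements do. This is a minimal sketch under stated assumptions, not the tutorial's preprocessing script: it tokenizes by character instead of using a real Chinese word segmenter, drops unknown tokens, reserves id 0 for padding, stores ids as a comma-separated string, and uses hypothetical sample lines and function names. The tutorial's script pads to 92; the demo uses a small max_len.

```python
import random

FIELD_SEP = "_!_"   # field delimiter used in toutiao_cat_data.txt
LABEL_OFFSET = 100  # assumption: raw class_id values start at 100
PAD_ID = 0          # assumption: id 0 is reserved for padding


def parse_line(line):
    """Split one raw line into (id, label, class_name, news_title, news_keywords)."""
    news_id, class_id, class_name, title, keywords = line.rstrip("\n").split(FIELD_SEP)
    return int(news_id), int(class_id) - LABEL_OFFSET, class_name, title, keywords


def tokenize(title):
    """Character-level stand-in for a real Chinese word segmenter; drops whitespace."""
    return [ch for ch in title if not ch.isspace()]


def build_vocab(titles):
    """Give each new token the next id, starting from 1 so PAD_ID stays unused."""
    vocab = {}
    for title in titles:
        for token in tokenize(title):
            if token not in vocab:
                vocab[token] = len(vocab) + 1
    return vocab


def to_padded_ids(title, vocab, max_len):
    """Map known tokens to ids, truncate to max_len, right-pad with PAD_ID, join with commas."""
    ids = [vocab[token] for token in tokenize(title) if token in vocab][:max_len]
    ids += [PAD_ID] * (max_len - len(ids))
    return ",".join(str(i) for i in ids)


def split_validation(rows, n_valid, seed=0):
    """Hold out n_valid random rows and drop rows with those ids from the training set."""
    valid = random.Random(seed).sample(rows, n_valid)
    valid_ids = {row[0] for row in valid}
    train = [row for row in rows if row[0] not in valid_ids]
    return train, valid


if __name__ == "__main__":
    # Hypothetical lines in the dataset's five-field format, not real records.
    raw = [
        "1_!_101_!_news_culture_!_古诗词欣赏_!_诗词,文化",
        "2_!_103_!_news_sports_!_足球比赛结果_!_足球",
        "3_!_116_!_news_game_!_新游戏上线_!_游戏",
    ]
    rows = [parse_line(line) for line in raw]
    vocab = build_vocab(row[3] for row in rows)
    processed = [
        (news_id, label, name, to_padded_ids(title, vocab, max_len=8), keywords)
        for news_id, label, name, title, keywords in rows
    ]
    train, valid = split_validation(processed, n_valid=1)
    print(processed[0][1], processed[0][3])
    print(len(train), len(valid))
```

Building the vocabulary before the split mirrors the tutorial's order, where the script processes train_processed first and the validation rows are carved out of it afterwards.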