数据挖掘与知识发现实验报告
----不平衡数据分类
一:问题描述
不平衡分类问题,是指训练样本数量在类间分布不平衡的模式分类问题.具体地说就是某些类的样本数量远远少于其他类.
二:思路
不平衡数据分类之所以难,用传统的分类方法处理效果差,主要就是两种数据数量的不平衡,差距太大,并且用评测指标衡量分类器时,分错一个正例(数量小的类)和分错一个反例的代价不相等。基于以上原因,自己在处理不平衡数据时第一反应就是消除“不平衡”。 参考老师的代码,使用38个分类器(RandomForest()使用2次), 划分出一个训练集和一个测试集来评估每个基分类器的准确率,根据准确率给基分类器排序,再对结果投票。当有分类器判断概率为1.0时,剔除这个分类器,最后由剩余sum个分类器投票决定最后的结果。
三:实验结果
UCI数据测试结果:
Cmc.arff : precision=0.4271 recall=0.7477
haberman_new.arff : precision=0.5 recall=0.6296
ionosphere_new.arff : precision=1.0 recall=0.9047
Pima.arff: precision=0.7095 recall=0.8022
letter-recognition.arff : precision=0.8470 recall=0.9759 结果对比:http://datamining./~zq/libid/UCI.htm
通过对比发现:ionosphere_new.arff 这组数据结果precision=1.0有点小偏差。
生物信息学数据测试结果:
feature_microRNA.arff: sn=0.9378 sp=0.8734
feature_SNP.arff: sn=0.8102 sp=0.7438
Cdbox.all.arff: sn=0.9673 sp=0.9411
Hacabox.all.arff sn=0.9384 sp=0.9476
结果对比:http://datamining./~zq/libid/experiments.htm 通过对比发现:这四组数据测试结果还行。
四:实验总结
本实验只是简单地参考老师代码中定义的38种基分类器,用这38种分类器分别训练实例集,当有分类器判断概率为1.0时,剔除这个分类器,最后由剩余sum个分类器投票决定最后的结果。尽量避免某些弱分类器对结果的影响,突出强分类器的地位。
2302010220不可不戒 2012.12.18
LezgClassifier:
import java.util.Random;
import weka.classifiers.*;
import weka.classifiers.bayes.*;
import weka.classifiers.trees.*;
import weka.classifiers.meta.ThresholdSelector;
import weka.classifiers.meta.Bagging;
import weka.classifiers.meta.MultiClassClassifier;
import weka.classifiers.meta.RandomSubSpace;
import weka.classifiers.lazy.IBk;
import weka.classifiers.meta.MultiBoostAB;
import weka.classifiers.meta.OrdinalClassClassifier;
import weka.classifiers.meta.RandomCommittee;
import weka.classifiers.meta.Dagging;
import weka.classifiers.lazy.IB1;
import weka.classifiers.meta.LogitBoost;
import weka.classifiers.meta.ClassificationViaRegression;
import weka.classifiers.meta.FilteredClassifier;
import weka.classifiers.meta.ClassificationViaClustering;
import weka.classifiers.meta.AttributeSelectedClassifier;
import weka.classifiers.meta.END;
import weka.classifiers.meta.AdaBoostM1;
import weka.classifiers.meta.Decorate;
import weka.core.Instance;
import weka.core.Instances;
public class LezgClassifier {
int[] sort;
int cf_num,sum=38;
int classifier_num = 38;
Classifier[] cf = new Classifier[classifier_num];
public void classifierArray() {
cf[0] = new J48(); // 0 0.24 0.27 0.2----0.23 0.29 cf[1] = new NaiveBayes(); // 0.07 0.48 0.08 0.59
cf[2] = new AdaBoostM1(); // 0.23 0.27 0.43 0.2
cf[3] = new ADTree(); // 0.0 0.2 0.37 0.23
cf[4] = new AttributeSelectedClassifier(); // 0.15 0.32 0.16---0.26
// 0.29
cf[5] = new Bagging(); // 0.3 0.272 0.21 0.24---0.20 0.28 0.24
cf[6] = new IBk(); // 0.22 0.25 cf[7] = new BayesNet(); // 0.07 0.512 0.59 0.22 cf[8] = new BFTree(); // 0.3 0.344 0.11 0.31---0.20 0.29
cf[9] = new ClassificationViaClustering(); // 效果不好 0.15 0.79 0.10 0.82
cf[10] = new ClassificationViaRegression(); // 0.07 0.416 0.27 0.26 cf[11] = new ComplementNaiveBayes(); // 0.07 0.536 0.14 0.34 cf[12] = new Dagging(); // 效果不好 0.07 0.44 0.27 0.30
cf[13] = new DecisionStump(); // 0.07 0.392 0.35 0.23
cf[14] = new Decorate(); // 0.15 0.296 0.11 0.23---0.15 0.27 cf[15] = new DMNBtext(); // 0 0.968 0 1
cf[16] = new END(); // 0.69 0.576 0.24 0.27
cf[17] = new REPTree(); // 0.15 0.336 0.14 0.44
cf[18] = new FilteredClassifier(); // 0.3 0.232 0.27 0.30---0.32 0.25
cf[19] = new FT(); // 0.38 0.32 0.30 0.25
cf[20] = new J48graft(); // 0.07 0.384 0.19 0.48
cf[21] = new RandomForest(); // 0.15 0.752 0.41 0.18
cf[22] = new LMT(); // 0.38 0.216 0.54 0.13
cf[23] = new LogitBoost(); // 0.15 0.344 0.22 0.31---0.24 0.28 cf[24] = new MultiBoostAB(); // 0.15 0.216 0.22 0.27-----0.31 0.27 cf[25] = new MultiClassClassifier(); // 0.30 0.368 0.27 0.30 cf[26] = new NaiveBayesMultinomial(); // 0.15 0.48 0.11 0.52
cf[27] = new NaiveBayesMultinomialUpdateable(); // 0.615 0.128 0.38 0.14
cf[28] = new NaiveBayesUpdateable(); // 0.07 0.576 0.22 0.63 cf[29] = new NBTree(); // 0.538 0.392 0.22 0.33
cf[30] = new OrdinalClassClassifier(); // 0 0.28 0.19 0.32------0.22 // 0.29
cf[31] = new RandomCommittee(); // 0.38 0.288 0.16 0.25
cf[32] = new RandomForest(); // 0.07 0.248 0.05 0.29------0.13 0.28 cf[33] = new RandomSubSpace(); // 0.23 0.46 0.19 0.28
cf[34] = new RandomTree(); // 0.46 0.168 0.16 0.35
cf[35] = new SimpleCart(); // 0.3 0.312 0.27 0.31-----0.26 0.28 cf[36] = new ThresholdSelector(); // 0.07 0.776 0.27 0.19 cf[37] = new IB1(); // 0.11 0.28
}
public String getRevision() {
return ("");
}
public double test(Classifier c, Instances train, Instances test) {// 分类器对于测试集的正确率
double rate = 1.0;
try {
c.buildClassifier(train);
int right = 0;
for (int i = 0; i < test.numInstances(); i++) {
if (c.classifyInstance(test.instance(i)) == test.instance(i)
.classValue()
&& c.classifyInstance(test.instance(i)) == 1) { right++;
}
}
//System.out.println(right);
rate = (double) right / (double) (test.numInstances());
} catch (Exception ex) {
System.out.println(ex.getMessage());
}
return rate;
}
public void buildClassifier(Instances ins) {
Instances t = new Instances(ins, 0);
Instances f = new Instances(ins, 0);
for (int i = 0; i < ins.numInstances(); i++) {
if (ins.instance(i).classValue() == 0.0) {
t.add(ins.instance(i));
} else {
f.add(ins.instance(i));
}
}
if (t.numInstances() > f.numInstances()) {
Instances tmp = t;
t = f;
f = tmp;
}
t = t.resample(new Random());
f = f.resample(new Random());
System.gc();
/* 计算划分的个数 */ int t_num = t.numInstances(); int f_num = f.numInstances(); cf_num = f_num / t_num;
if (cf_num % 2 == 0) { cf_num++; } Instances train_tmp = new Instances(t, 0, (int) (t_num * 0.8)); for (int i = 0; i < (int) (t_num * 0.8); i++) { train_tmp.add(f.instance(i)); } Instances test_tmp = new Instances(t, (int) (t_num * 0.8), (int) (t_num * 0.2)); for (int i = (int) (0.8 * t_num); i < t_num; i++) { test_tmp.add(f.instance(i)); } System.out.println("+++++++++++++++++++++++"); double rate[] = new double[classifier_num]; for (int i = 0; i < classifier_num; i++) { rate[i] = test(cf[i], train_tmp, test_tmp); if(rate[i]==1.0){ rate[i]=0.0;//当分类器判断真确概率为1.0时,剔除 sum--; } System.out.println(rate[i] + " "+i); } System.out.println("+++++++++++++++++++++++"); /* 根据准确率给基分类器排序 */ sort = new int[classifier_num]; for (int i = 0; i < classifier_num; i++) { sort[i] = 0; for (int j = 0; j < classifier_num; j++) { if (rate[j] > rate[sort[i]]) { sort[i] = j; } } rate[sort[i]] = 0; } /*for(int i=0;i<classifier_num;i++) System.out.println("======"+sort[i]);*/ }
//对剩余sum个分类器分类结果进行投票
public double classifyInstance(Instance ins) { double result = 0.0;
} } try { double t = 0.0, f = 0.0; for (int j = 0; j <sum ; j++) { if (cf[sort[j]].classifyInstance(ins) == 0.0) { t++; } else { f++; } } if (t / (t + f) < 0.5) result = 1.0; } catch (Exception ex) { System.out.println(ex.getMessage()); } return result;
Main:
import java.io.*;
import weka.core.*;
public class Main {
    /**
     * Loads an ARFF data set, trains the LezgClassifier ensemble, and prints
     * confusion-matrix statistics evaluated on the SAME data it trained on
     * (resubstitution estimate, as in the report).
     *
     * NOTE: class value 0.0 is treated as the positive class here, so
     * TP counts correctly-predicted class-0 instances.
     *
     * @throws Exception if the ARFF file cannot be read or parsed
     */
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new FileReader("hacabox.all.arff"));
        Instances ins = new Instances(br);
        br.close();
        ins.setClassIndex(ins.numAttributes() - 1); // last attribute is the class label — required by Weka
        LezgClassifier c = new LezgClassifier();
        c.classifierArray();
        c.buildClassifier(ins);
        long TN = 0, TP = 0, FN = 0, FP = 0;
        for (int j = 0; j < ins.numInstances(); j++) {
            double actual = ins.instance(j).classValue();
            // FIX: classify each instance exactly once. The original
            // re-invoked c.classifyInstance() in every else-if branch,
            // running the whole ensemble vote up to four times per instance.
            double predicted = c.classifyInstance(ins.instance(j));
            if (actual == 0.0 && predicted == 0.0) {
                TP++;
            } else if (actual == 1.0 && predicted == 1.0) {
                TN++;
            } else if (actual == 0.0 && predicted == 1.0) {
                FN++;
            } else if (actual == 1.0 && predicted == 0.0) {
                FP++;
            }
        }
        double sn = (double) TP / (double) (TP + FN);
        double pre = (double) TP / (double) (TP + FP);
        double sp = (double) TN / (double) (FP + TN);
        double acc = (double) (TP + TN) / (double) (TP + TN + FP + FN);
        // mcc is NaN when any confusion-matrix margin is zero (0/0 division).
        double mcc = (double) (TP * TN - FP * FN)
                / Math.sqrt((double) ((TN + FN) * (TN + FP) * (TP + FN) * (TP + FP)));
        System.out.println("|||" + FP);
        System.out.println("|||" + TP);
        System.out.println("sn=" + sn);
        System.out.println("sp=" + sp);
        System.out.println("acc=" + acc);
        System.out.println("mcc=" + mcc);
        System.out.println();
        System.out.println("precision=" + pre);
        System.out.println("recall=" + sn);
    }
}