/*
* MultiClassLearning.java
*
* Yaoyong Li 22/03/2007
*
* $Id: MultiClassLearning.java, v 1.0 2007-03-22 12:58:16 +0000 yaoyong $
*/
package gate.learning.learners;
import gate.learning.ConstantParameters;
import gate.learning.LabelsOfFV;
import gate.learning.LogService;
import gate.learning.SparseFeatureVector;
import gate.learning.DocFeatureVectors.LongCompactor;
import gate.util.GateException;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
/**
* Learning and application by converting the multi-class problem into several
* binary class problems.
*/
public class MultiClassLearning {
/** The training data -- FVs in doc format. */
public DataForLearning dataFVinDoc;
/** Use the data file to save the memory or not. */
//boolean isUsingDataFile = false;
/** Labels for training instances */
ThreadLocal<short[]> labelsTraining;
/** Feature vectors for training instances. */
ThreadLocal<SparseFeatureVector[]> fvsTraining;
/** Labels for training instances - non thread. */
short[] labelsTrainingNT;
/** Feature vectors for training instances - non thread. */
SparseFeatureVector[] fvsTrainingNT;
/** Number of classes for learning */
public int numClasses;
/** The name of class and the number of instances in training document */
public HashMap class2NumberInstances;
/**
* Use the one against all others, or use the one against another. 1 for one
* against all others, 2 for one against another.
*/
short multi2BinaryMode = 1;
/**
* The number of instances in the training data without label (or with label
* null).
*/
public int numNull = 0;
/**
* The executor used to run the (possibly concurrent) binary learning and
* classification tasks.
*/
ExecutorService executor = new InThreadExecutorService();
/** Constructor */
public MultiClassLearning() {
//isUsingDataFile = false;
}
/** Constructor with conversion mode. */
public MultiClassLearning(short mode) {
//isUsingDataFile = false;
multi2BinaryMode = mode;
}
public void setExecutor(ExecutorService executor) {
this.executor = executor;
}
/** Get the training data -- feature vectors and labels. */
public void getDataFromFile(int numDocs, File trainingDataFile, boolean isUsingFile, File tempFVDataFile) {
dataFVinDoc = new DataForLearning(numDocs);
//Open the temp file for writing the fv data
dataFVinDoc.readingFVsFromFile(trainingDataFile, isUsingFile, tempFVDataFile);
// First, get the unique labels from the trainign data
class2NumberInstances = new HashMap();
numNull = obtainUniqueLabels(dataFVinDoc, class2NumberInstances);
numClasses = class2NumberInstances.size();
return;
}
/** Reset the labels for learning for training data filtering. */
public int resetClassInData() {
// Reset the data's class as 1 or -1
int numNeg = 0;
for(int i = 0; i < dataFVinDoc.labelsFVDoc.length; ++i) {
// LabelsOfFeatureVectorDoc labelsDoc = dataFVinDoc.labelsFVDoc[i];
for(int j = 0; j < dataFVinDoc.labelsFVDoc[i].multiLabels.length; ++j) {
if(dataFVinDoc.labelsFVDoc[i].multiLabels[j].num > 0) { // if it has
// label
LabelsOfFV simpLabels = new LabelsOfFV(1);
simpLabels.labels = new int[1];
simpLabels.labels[0] = 1;
dataFVinDoc.labelsFVDoc[i].multiLabels[j] = simpLabels;
} else ++numNeg;
}
}
// Reset the label collection
class2NumberInstances = new HashMap();
numNull = obtainUniqueLabels(dataFVinDoc, class2NumberInstances);
numClasses = class2NumberInstances.size();
return numNeg;
}
/** Learn the models and write them into a set of files */
public void training(final SupervisedLearner learner, File modelFile) {
final int totalNumFeatures = dataFVinDoc.getTotalNumFeatures();
Set classesName = class2NumberInstances.keySet();
final ArrayList array1 = new ArrayList(classesName);
LongCompactor c = new LongCompactor();
Collections.sort(array1, c);
if(LogService.minVerbosityLevel > 1)
System.out.println("total Number of classes for learning is "
+ array1.size());
LogService.logMessage("total Number of classes for learning is "
+ array1.size(), 1);
// Open the mode file for writing the model into it
try {
if(modelFile.exists() && !modelFile.isDirectory()) {
if(!modelFile.delete()) { throw new IOException(
"Existing single-file model " + modelFile + " could not be deleted."); }
}
if(!modelFile.exists()) {
if(!modelFile.mkdirs()) { throw new IOException(
"Couldn't create directory for model files"); }
}
// create a temporary directory for the new learned models
File tmpDirFile = new File(modelFile, "tmp");
if(tmpDirFile.exists()) {
deleteRecursively(tmpDirFile);
}
if(!tmpDirFile.mkdir()) { throw new IOException(
"Couldn't create temporary directory for training"); }
File metaDataFile = new File(tmpDirFile,
ConstantParameters.FILENAMEOFModelMetaData);
BufferedWriter metaDataBuff = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(metaDataFile), "UTF-8"));
// convert the multi-class to binary class -- labels conversion
// can't share these arrays between concurrent threads, so must use
// ThreadLocal
labelsTraining = new ThreadLocal<short[]>() {
protected short[] initialValue() {
return new short[dataFVinDoc.numTraining];
}
};
fvsTraining = new ThreadLocal<SparseFeatureVector[]>() {
protected SparseFeatureVector[] initialValue() {
return new SparseFeatureVector[dataFVinDoc.numTraining];
}
};
List<Callable<String>> tasks = new ArrayList<Callable<String>>();
int classIndex = 1;
switch(multi2BinaryMode){
case 1: // if using the one vs all others appoach
// Write some meta information into the model as a header
LogService.logMessage(
"One against others for multi to binary class conversion.", 1);
writeTrainingMetaData(metaDataBuff, numClasses, numNull, dataFVinDoc
.getNumTrainingDocs(), dataFVinDoc.getTotalNumFeatures(), modelFile
.getAbsolutePath(), learner);
metaDataBuff.close();
for(int iCounter = 0; iCounter < array1.size(); ++iCounter) {
final int i = iCounter;
final File thisClassModelFile = new File(tmpDirFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
tasks.add(new Callable<String>() {
public String call() throws Exception {
short[] myLabelsTraining = labelsTraining.get();
SparseFeatureVector[] myFvsTraining = fvsTraining.get();
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
Multi2BinaryClass.oneVsOthers(dataFVinDoc, array1.get(i)
.toString(), myLabelsTraining, myFvsTraining);
int numTraining = myLabelsTraining.length;
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1)
if(myLabelsTraining[i1] > 0) ++numP;
modelBuff.append("Class=" + array1.get(i).toString()
+ " numTraining=" + numTraining + " numPos=" + numP + "\n");
long time1 = new Date().getTime();
learner.training(modelBuff, myFvsTraining, totalNumFeatures,
myLabelsTraining, numTraining);
long time2 = new Date().getTime();
time2 -= time1;
modelBuff.close();
LogService.logMessage("Training time for class "
+ array1.get(i).toString() + ": " + time2 + "ms", 1);
return null;
}
});
}
break;
case 2: // if using the one vs another appoach
// new numClasses
int numClasses0;
if(numNull > 0)
numClasses0 = (numClasses + 1) * numClasses / 2;
else numClasses0 = (numClasses - 1) * numClasses / 2;
LogService.logMessage(
"One against another for multi to binary class conversion.\n"
+ "So actually we have " + numClasses0 + " binary classes.", 1);
writeTrainingMetaData(metaDataBuff, numClasses0, numNull, dataFVinDoc
.getNumTrainingDocs(), dataFVinDoc.getTotalNumFeatures(), modelFile
.getAbsolutePath(), learner);
metaDataBuff.close();
// first for null vs label
if(numNull > 0) {
for(int jCounter = 0; jCounter < array1.size(); ++jCounter) {
final int j = jCounter;
final File thisClassModelFile = new File(tmpDirFile, String
.format(ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
tasks.add(new Callable<String>() {
public String call() throws Exception {
short[] myLabelsTraining = labelsTraining.get();
SparseFeatureVector[] myFvsTraining = fvsTraining.get();
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
int numTraining;
numTraining = Multi2BinaryClass.oneVsNull(dataFVinDoc, array1
.get(j).toString(), myLabelsTraining, myFvsTraining);
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1) {
if(myLabelsTraining[i1] > 0) ++numP;
}
modelBuff.append("Class1=_NULL" + " Class2="
+ array1.get(j).toString() + " numTraining=" + numTraining
+ " numPos=" + numP + "\n");
long time1 = new Date().getTime();
learner.training(modelBuff, myFvsTraining, totalNumFeatures,
myLabelsTraining, numTraining);
long time2 = new Date().getTime();
time2 -= time1;
modelBuff.close();
LogService.logMessage("Training time for class null against "
+ array1.get(j).toString() + ": " + time2 + "ms", 1);
return null;
}
});
}
}
// then for one vs. another
for(int iCounter = 0; iCounter < array1.size(); ++iCounter) {
final int i = iCounter;
for(int jCounter = i + 1; jCounter < array1.size(); ++jCounter) {
final int j = jCounter;
final File thisClassModelFile = new File(tmpDirFile, String
.format(ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
tasks.add(new Callable<String>() {
public String call() throws Exception {
short[] myLabelsTraining = labelsTraining.get();
SparseFeatureVector[] myFvsTraining = fvsTraining.get();
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
int numTraining;
numTraining = Multi2BinaryClass.oneVsAnother(dataFVinDoc,
array1.get(i).toString(), array1.get(j).toString(),
myLabelsTraining, myFvsTraining);
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1) {
if(myLabelsTraining[i1] > 0) ++numP;
}
modelBuff.append("Class1=" + array1.get(i).toString()
+ " Class2=" + array1.get(j).toString() + " numTraining="
+ numTraining + " numPos=" + numP + "\n");
long time1 = new Date().getTime();
learner.training(modelBuff, myFvsTraining, totalNumFeatures,
myLabelsTraining, numTraining);
long time2 = new Date().getTime();
time2 -= time1;
modelBuff.close();
LogService.logMessage("Training time for class "
+ array1.get(i).toString() + " against "
+ array1.get(j).toString() + ": " + time2 + "ms", 1);
return null;
}
});
}
}
break;
default:
System.out.println("Incorrect multi2BinaryMode value="
+ multi2BinaryMode);
LogService.logMessage("Incorrect multi2BinaryMode value="
+ multi2BinaryMode, 0);
}
// actually run the tasks, print any exception traces that result
LogService.logMessage("Running tasks using executor " + executor, 1);
List<Future<String>> futures = executor.invokeAll(tasks);
boolean success = true;
for(Future<String> f : futures) {
try {
String message = f.get();
if(message != null) {
LogService.logMessage(message, 1);
}
} catch(java.util.concurrent.ExecutionException e) {
success = false;
e.printStackTrace();
}
}
if(success) {
// replace the old model with the new one
moveAllFiles(tmpDirFile, modelFile);
// delete any classNNN.model files beyond the last one we have learned
// on
// this run
for(File orphanedModelFile = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex))); orphanedModelFile.exists(); classIndex++) {
orphanedModelFile.delete();
}
} else {
LogService.logMessage(
"Error during training, old model not overwritten", 1);
}
// delete the temporary directory
deleteRecursively(tmpDirFile);
} catch(IOException e) {
e.printStackTrace();
} catch(InterruptedException e) {
e.printStackTrace();
}
}
/**
* Delete a file or directory. If the argument is a directory, delete its
* contents first, then remove the directory itself.
*/
private void deleteRecursively(File fileOrDir) throws IOException {
if(fileOrDir.isDirectory()) {
for(File f : fileOrDir.listFiles()) {
deleteRecursively(f);
}
}
if(!fileOrDir.delete()) { throw new IOException("Couldn't delete "
+ (fileOrDir.isDirectory() ? "directory " : "file ") + fileOrDir); }
}
/**
* Move all the files from one directory into another.
*
* @param src
* the directory whose contents are to be moved
* @param dest
* the directory into which the files should go
*/
private void moveAllFiles(File src, File dest) throws IOException {
for(String fileName : src.list()) {
File srcFile = new File(src, fileName);
File targetFile = new File(dest, fileName);
if(targetFile.exists() && !targetFile.delete()) { throw new IOException(
"Couldn't delete file " + targetFile); }
if(!srcFile.renameTo(targetFile)) { throw new IOException(
"Couldn't move " + srcFile + " to directory " + dest); }
}
}
/** Apply the model to the data. */
public void apply(final SupervisedLearner learner, File modelFile) {
// Open the mode file and read the model
try {
if(modelFile.exists() && !modelFile.isDirectory()) {
// see whether we're trying to apply an old-style model
// stored all in one file
BufferedReader buff = new BufferedReader(new InputStreamReader(
new FileInputStream(modelFile), "UTF-8"));
String firstLine = buff.readLine();
buff.close();
if(firstLine != null && firstLine.endsWith("#numTrainingDocs")) {
// this is an old-style model, so try and transparently upgrade it to
// the new format
upgradeSingleFileModelToDirectory(modelFile);
} else {
throw new IOException("Unrecognised model file format for file "
+ modelFile);
}
}
if(!modelFile.exists()) { throw new IllegalStateException(
"Model directory " + modelFile + " does not exist"); }
File metaDataFile = new File(modelFile,
ConstantParameters.FILENAMEOFModelMetaData);
BufferedReader metaDataBuff = new BufferedReader(new InputStreamReader(
new FileInputStream(metaDataFile), "UTF-8"));
// Read the training meta information from the model file's header
// include the total number of features and number of tags (numClasses)
int totalNumFeatures;
String learnerNameFromModel = learner.getLearnerName();
// note that reading the training meta data also read the number of class
// in the model, e.g. changing the numClasses.
totalNumFeatures = ReadTrainingMetaData(metaDataBuff,
learnerNameFromModel);
if(LogService.minVerbosityLevel > 1)
System.out.println(" *** numClasses=" + numClasses + " totalfeatures="
+ totalNumFeatures);
metaDataBuff.close();
// compare with the meta data of test data
if(totalNumFeatures < dataFVinDoc.getTotalNumFeatures())
totalNumFeatures = dataFVinDoc.getTotalNumFeatures();
final int finalTotalNumFeatures = totalNumFeatures;
// Apply the model to test feature vectors
long time1 = new Date().getTime();
int classIndex = 1;
List<Callable<Boolean>> tasks = new ArrayList<Callable<Boolean>>();
List<Future<Boolean>> futures = null;
switch(multi2BinaryMode){
case 1:
LogService.logMessage(
"One against others for multi to binary class conversion.\n"
+ "Number of classes in model: " + numClasses, 1);
// Use the tau modification in all cases
learner.isUseTauALLCases = true;
for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i)
for(int j = 0; j < dataFVinDoc.trainingFVinDoc[i].getNumInstances(); ++j) {
dataFVinDoc.labelsFVDoc[i].multiLabels[j] = new LabelsOfFV(
numClasses);
dataFVinDoc.labelsFVDoc[i].multiLabels[j].probs = new float[numClasses];
}
// for each class
for(int iClassCounter = 0; iClassCounter < numClasses; ++iClassCounter) {
final int iClass = iClassCounter;
final File thisClassModelFile = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
tasks.add(new Callable<Boolean>() {
public Boolean call() throws Exception {
BufferedReader modelBuff = new BufferedReader(
new InputStreamReader(
new FileInputStream(thisClassModelFile), "UTF-8"));
learner.applying(modelBuff, dataFVinDoc, finalTotalNumFeatures,
iClass);
modelBuff.close();
return Boolean.TRUE;
}
});
}
// actually run the tasks, print any exception traces that result
futures = executor.invokeAll(tasks);
for(Future<Boolean> f : futures) {
try {
f.get();
} catch(java.util.concurrent.ExecutionException e) {
e.printStackTrace();
}
}
if(LogService.minVerbosityLevel > 1)
System.out.println("**** One against all others, numNull="
+ numNull);
break;
case 2:
LogService.logMessage(
"One against another for multi to binary class conversion.", 1);
// not use the tau modification in all cases
learner.isUseTauALLCases = false;
// set the multi class number and allocate the memory
for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i)
for(int j = 0; j < dataFVinDoc.trainingFVinDoc[i].getNumInstances(); ++j) {
dataFVinDoc.labelsFVDoc[i].multiLabels[j] = new LabelsOfFV(
numClasses);
dataFVinDoc.labelsFVDoc[i].multiLabels[j].probs = new float[numClasses];
}
// for each class
for(int iClassCounter = 0; iClassCounter < numClasses; ++iClassCounter) {
final int iClass = iClassCounter;
final File thisClassModelFile = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
tasks.add(new Callable<Boolean>() {
public Boolean call() throws Exception {
BufferedReader modelBuff = new BufferedReader(
new InputStreamReader(
new FileInputStream(thisClassModelFile), "UTF-8"));
learner.applying(modelBuff, dataFVinDoc, finalTotalNumFeatures,
iClass);
modelBuff.close();
return Boolean.TRUE;
}
});
}
// actually run the tasks, print any exception traces that result
futures = executor.invokeAll(tasks);
for(Future<Boolean> f : futures) {
try {
f.get();
} catch(java.util.concurrent.ExecutionException e) {
e.printStackTrace();
}
}
PostProcessing postProc = new PostProcessing();
// Get the number of classes of the problem, since the numClasses
// refers to the number
// of classes in the one against another method.
int numClassesL = numClasses * 2;
numClassesL = rootQuaEqn(numClassesL);
if(numNull == 0) numClassesL += 1;
LogService.logMessage("Number of classes in training data: "
+ numClassesL + "\nActuall number of binary classes in model: "
+ numClasses, 1);
if(LogService.minVerbosityLevel > 1)
System.out.println("**** One against another, numNull=" + numNull);
if(numNull > 0)
postProc.voteForOneVSAnotherNull(dataFVinDoc, numClassesL);
else postProc.voteForOneVSAnother(dataFVinDoc, numClassesL);
// Set the number of classes with the correct value.
numClasses = numClassesL;
break;
default:
System.out.println("Incorrect multi2BinaryMode value="
+ multi2BinaryMode);
LogService.logMessage("Incorrect multi2BinaryMode value="
+ multi2BinaryMode, 1);
}
long time2 = new Date().getTime();
time2 -= time1;
LogService.logMessage("Application time for class: " + time2 + "ms", 1);
} catch(IOException e) {
e.printStackTrace();
} catch(InterruptedException e) {
e.printStackTrace();
}
}
/**
* Upgrade an old-style single file model to directory format, with the meta
* data and the individual models in separate files.
*/
public void upgradeSingleFileModelToDirectory(File modelFile)
throws IOException {
// copy the old model file to a backup
byte[] buf = new byte[8192];
int bytesRead = 0;
File backupModelFile = new File(modelFile.getPath() + ".bak");
FileInputStream oldModelIn = new FileInputStream(modelFile);
FileOutputStream backupModelOut = new FileOutputStream(backupModelFile);
while((bytesRead = oldModelIn.read(buf)) >= 0) {
backupModelOut.write(buf, 0, bytesRead);
}
backupModelOut.close();
oldModelIn.close();
buf = null;
// delete the old model file and create a directory in its place
modelFile.delete();
modelFile.mkdir();
// open the backup model file and copy its sections into new files
BufferedReader oldModelsBuff = new BufferedReader(new InputStreamReader(
new FileInputStream(backupModelFile), "UTF-8"));
// first 8 lines are the meta data
File metaDataFile = new File(modelFile,
ConstantParameters.FILENAMEOFModelMetaData);
BufferedWriter metaDataBuff = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(metaDataFile), "UTF-8"));
for(int i = 0; i < 8; i++) {
metaDataBuff.write(oldModelsBuff.readLine());
metaDataBuff.write('\n');
}
metaDataBuff.close();
int classIndex = 1;
BufferedWriter modelWriter = null;
String line = null;
while((line = oldModelsBuff.readLine()) != null) {
if(line.startsWith("Class=") && line.contains("numTraining=")
&& line.contains("numPos=")) {
// found the start of a new model, so close the previous file and start
// the next one
if(modelWriter != null) {
modelWriter.close();
}
File nextModel = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
modelWriter = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(nextModel), "UTF-8"));
}
modelWriter.write(line);
modelWriter.write('\n');
}
if(modelWriter != null) {
modelWriter.close();
}
}
/**
* Get the number of classes in the problem from the number of classes in the
* one vs. another method, by solving a quadratic equation.
*/
private int rootQuaEqn(int numClassesL) {
// The positive root of quadratic equation x^2+x-numClassesL=0.
return (int)((-1 + Math.sqrt(1.0 + numClassesL * 4)) / 2.0);
}
/** Writting the meta information about the learning into the model file. */
public void writeTrainingMetaData(BufferedWriter modelsBuff, int numClasses,
int numNull, int numTrainingDocs, long totalFeatures, String modelFile,
SupervisedLearner learner) throws IOException {
modelsBuff.append(numTrainingDocs + " #numTrainingDocs\n");
modelsBuff.append(numClasses + " #numClasses\n");
modelsBuff.append(numNull + " #numNullLabelInstances\n");
long actualNum = totalFeatures - 5; // because added 5 in DataForLearning
// class
modelsBuff.append(actualNum + " #totalFeatures\n");
modelsBuff.append(modelFile + " #modelFile\n");
modelsBuff.append(learner.getLearnerName() + " #learnerName\n");
modelsBuff.append(learner.getLearnerExecutable() + " #learnerExecutable\n");
modelsBuff.append(learner.getLearnerParams() + " #learnerParams\n");
return;
}
/** Read the meta data from the header of the file. */
public int ReadTrainingMetaData(BufferedReader modelsBuff,
String learnerNameFromModel) throws IOException {
int totalFeatures;
String line;
modelsBuff.readLine(); // read the traing documents
line = modelsBuff.readLine(); // read the number of classes
numClasses = new Integer(line.substring(0, line.indexOf(" "))).intValue();
line = modelsBuff.readLine(); // read the number of classes
numNull = new Integer(line.substring(0, line.indexOf(" "))).intValue();
line = modelsBuff.readLine(); // read the total number of features
totalFeatures = new Integer(line.substring(0, line.indexOf(" ")))
.intValue();
totalFeatures += 5;
modelsBuff.readLine(); // read the model file name
line = modelsBuff.readLine(); // read the learner's name
learnerNameFromModel = line.substring(0, line.indexOf(" "));
modelsBuff.readLine(); // read the learnerExecutable string
modelsBuff.readLine(); // read the learnerParams string
return totalFeatures;
}
/** Obtain the unqilabels from the training data. */
int obtainUniqueLabels(DataForLearning dataFVinDoc,
HashMap class2NumberInstances) {
int numN = 0;
for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i)
for(int j = 0; j < dataFVinDoc.labelsFVDoc[i].multiLabels.length; ++j) {
// int label = dataFVinDoc.labelsFVDoc[i].labels[j];
LabelsOfFV multiLabel = dataFVinDoc.labelsFVDoc[i].multiLabels[j];
if(multiLabel.num == 0) ++numN;
for(int j1 = 0; j1 < multiLabel.num; ++j1) {
if(Integer.valueOf(multiLabel.labels[j1]) > 0) {
if(class2NumberInstances.containsKey(multiLabel.labels[j1]))
class2NumberInstances.put(multiLabel.labels[j1],
(new Integer(class2NumberInstances.get(multiLabel.labels[j1])
.toString())) + 1);
else class2NumberInstances.put(multiLabel.labels[j1], "1");
}
}
}
return numN;
}
/** Learn the models and write them into a file -- not use thread*/
public void trainingNoThread(SupervisedLearner learner, File modelFile, boolean isUsingTempDataFile, File tempFVDataFile) {
final int totalNumFeatures = dataFVinDoc.getTotalNumFeatures();
Set classesName = class2NumberInstances.keySet();
final ArrayList array1 = new ArrayList(classesName);
LongCompactor c = new LongCompactor();
Collections.sort(array1, c);
if(LogService.minVerbosityLevel>1)
System.out.println("total Number of classes for learning is "
+ array1.size());
LogService.logMessage("total Number of classes for learning is "
+ array1.size(), 1);
//Open the mode file for writing the model into it
try {
if(modelFile.exists() && !modelFile.isDirectory()) {
if(!modelFile.delete()) { throw new IOException(
"Existing single-file model " + modelFile + " could not be deleted."); }
}
if(!modelFile.exists()) {
if(!modelFile.mkdirs()) { throw new IOException(
"Couldn't create directory for model files"); }
}
// create a temporary directory for the new learned models
File tmpDirFile = new File(modelFile, "tmp");
if(tmpDirFile.exists()) {
deleteRecursively(tmpDirFile);
}
if(!tmpDirFile.mkdir()) { throw new IOException(
"Couldn't create temporary directory for training"); }
File metaDataFile = new File(tmpDirFile,
ConstantParameters.FILENAMEOFModelMetaData);
BufferedWriter metaDataBuff = new BufferedWriter(new OutputStreamWriter(
new FileOutputStream(metaDataFile), "UTF-8"));
int classIndex = 1; //class index for model's file name
// Open the mode file for writing the model into it
// convert the multi-class to binary class -- labels conversion
labelsTrainingNT = new short[dataFVinDoc.numTraining];
fvsTrainingNT = new SparseFeatureVector[dataFVinDoc.numTraining];
switch(multi2BinaryMode){
case 1: // if using the one vs all others appoach
// Write some meta information into the model as a header
LogService.logMessage(
"One against others for multi to binary class conversion.", 1);
writeTrainingMetaData(metaDataBuff, numClasses, numNull, dataFVinDoc
.getNumTrainingDocs(), dataFVinDoc.getTotalNumFeatures(), modelFile
.getAbsolutePath(), learner);
metaDataBuff.close();
for(int iCounter = 0; iCounter < array1.size(); ++iCounter) {
final int i = iCounter;
final File thisClassModelFile = new File(tmpDirFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
Multi2BinaryClass.oneVsOthers(dataFVinDoc,
array1.get(i).toString(), labelsTrainingNT, fvsTrainingNT);
int numTraining = labelsTrainingNT.length;
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1)
if(labelsTrainingNT[i1] > 0) ++numP;
modelBuff.append("Class=" + array1.get(i).toString()
+ " numTraining=" + numTraining + " numPos=" + numP + "\n");
long time1 = new Date().getTime();
if(isUsingTempDataFile) {//using the data file
BufferedReader fvTempRd = new BufferedReader(new InputStreamReader(
new FileInputStream(tempFVDataFile), "UTF-8"));
learner.trainingWithDataFile(modelBuff, fvTempRd, totalNumFeatures,
labelsTrainingNT, numTraining);
fvTempRd.close();
} else
learner.training(modelBuff, fvsTrainingNT, totalNumFeatures,
labelsTrainingNT, numTraining);
modelBuff.close();
long time2 = new Date().getTime();
time2 -= time1;
LogService.logMessage("Training time for class "
+ array1.get(i).toString() + ": " + time2 + "ms", 1);
}
break;
case 2: // if using the one vs another appoach
// new numClasses
int numClasses0;
if(numNull > 0)
numClasses0 = (numClasses + 1) * numClasses / 2;
else numClasses0 = (numClasses - 1) * numClasses / 2;
LogService.logMessage(
"One against another for multi to binary class conversion.\n"
+ "So actually we have " + numClasses0 + " binary classes.", 1);
writeTrainingMetaData(metaDataBuff, numClasses0, numNull, dataFVinDoc
.getNumTrainingDocs(), dataFVinDoc.getTotalNumFeatures(), modelFile
.getAbsolutePath(), learner);
metaDataBuff.close();
// first for null vs label
if(numNull > 0) {
for(int jCounter = 0; jCounter < array1.size(); ++jCounter) {
final int j = jCounter;
final File thisClassModelFile = new File(tmpDirFile, String
.format(ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
int numTraining;
numTraining = Multi2BinaryClass.oneVsNull(dataFVinDoc, array1
.get(j).toString(), labelsTrainingNT, fvsTrainingNT);
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1) {
if(labelsTrainingNT[i1] > 0) ++numP;
}
modelBuff.append("Class1=_NULL" + " Class2="
+ array1.get(j).toString() + " numTraining=" + numTraining
+ " numPos=" + numP + "\n");
long time1 = new Date().getTime();
if(isUsingTempDataFile) {//using the data file
BufferedReader fvTempRd = new BufferedReader(new InputStreamReader(
new FileInputStream(tempFVDataFile), "UTF-8"));
learner.trainingWithDataFile(modelBuff, fvTempRd, totalNumFeatures,
labelsTrainingNT, numTraining);
fvTempRd.close();
} else
learner.training(modelBuff, fvsTrainingNT, totalNumFeatures,
labelsTrainingNT, numTraining);
modelBuff.close();
long time2 = new Date().getTime();
time2 -= time1;
LogService.logMessage("Training time for class null against "
+ array1.get(j).toString() + ": " + time2 + "ms", 1);
}
}
// then for one vs. another
for(int iCounter = 0; iCounter < array1.size(); ++iCounter) {
final int i = iCounter;
for(int jCounter = i + 1; jCounter < array1.size(); ++jCounter) {
final int j = jCounter;
final File thisClassModelFile = new File(tmpDirFile, String
.format(ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
BufferedWriter modelBuff = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(
thisClassModelFile), "UTF-8"));
int numTraining;
numTraining = Multi2BinaryClass.oneVsAnother(dataFVinDoc, array1
.get(i).toString(), array1.get(j).toString(), labelsTrainingNT,
fvsTrainingNT);
int numP = 0;
for(int i1 = 0; i1 < numTraining; ++i1) {
if(labelsTrainingNT[i1] > 0) ++numP;
}
modelBuff.append("Class1=_NULL" + " Class2="
+ array1.get(j).toString() + " numTraining=" + numTraining
+ " numPos=" + numP + "\n");
long time1 = new Date().getTime();
if(isUsingTempDataFile) { //using the data file
BufferedReader fvTempRd = new BufferedReader(new InputStreamReader(
new FileInputStream(tempFVDataFile), "UTF-8"));
learner.trainingWithDataFile(modelBuff, fvTempRd, totalNumFeatures,
labelsTrainingNT, numTraining);
fvTempRd.close();
} else
learner.training(modelBuff, fvsTrainingNT, totalNumFeatures,
labelsTrainingNT, numTraining);
modelBuff.close();
long time2 = new Date().getTime();
time2 -= time1;
LogService.logMessage("Training time for class "
+ array1.get(i).toString() + " against "
+ array1.get(j).toString() + ": " + time2 + "ms", 1);
// }
}
}
break;
default:
System.out.println("Incorrect multi2BinaryMode value="
+ multi2BinaryMode);
LogService.logMessage("Incorrect multi2BinaryMode value="
+ multi2BinaryMode, 0);
}
// replace the old model with the new one
moveAllFiles(tmpDirFile, modelFile);
deleteRecursively(tmpDirFile);
//delete the temp data file
tempFVDataFile.delete();
} catch(IOException e) {
e.printStackTrace();
}
}
/** Apply the model to the data - not use thread. */
public void applyNoThread(SupervisedLearner learner, File modelFile) {
// Open the mode file and read the model
try {
if(modelFile.exists() && !modelFile.isDirectory()) {
// see whether we're trying to apply an old-style model
// stored all in one file
BufferedReader buff = new BufferedReader(new InputStreamReader(
new FileInputStream(modelFile), "UTF-8"));
String firstLine = buff.readLine();
buff.close();
if(firstLine != null && firstLine.endsWith("#numTrainingDocs")) {
// this is an old-style model, so try and transparently upgrade it to
// the new format
upgradeSingleFileModelToDirectory(modelFile);
} else {
throw new IOException("Unrecognised model file format for file "
+ modelFile);
}
}
if(!modelFile.exists()) { throw new IllegalStateException(
"Model directory " + modelFile + " does not exist"); }
File metaDataFile = new File(modelFile,
ConstantParameters.FILENAMEOFModelMetaData);
BufferedReader metaDataBuff = new BufferedReader(new InputStreamReader(
new FileInputStream(metaDataFile), "UTF-8"));
// Read the training meta information from the model file's header
// include the total number of features and number of tags (numClasses)
int totalNumFeatures;
String learnerNameFromModel = learner.getLearnerName();
// note that reading the training meta data also read the number of class
// in the model, e.g. changing the numClasses.
totalNumFeatures = ReadTrainingMetaData(metaDataBuff,
learnerNameFromModel);
if(LogService.minVerbosityLevel > 1)
System.out.println(" *** numClasses=" + numClasses + " totalfeatures="
+ totalNumFeatures);
metaDataBuff.close();
// compare with the meta data of test data
if(totalNumFeatures < dataFVinDoc.getTotalNumFeatures())
totalNumFeatures = dataFVinDoc.getTotalNumFeatures();
final int finalTotalNumFeatures = totalNumFeatures;
// Apply the model to test feature vectors
long time1 = new Date().getTime();
int classIndex = 1;
switch(multi2BinaryMode){
case 1:
LogService.logMessage(
"One against others for multi to binary class conversion.\n"
+ "Number of classes in model: " + numClasses, 1);
// Use the tau modification in all cases
learner.isUseTauALLCases = true;
for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i)
for(int j = 0; j < dataFVinDoc.trainingFVinDoc[i].getNumInstances(); ++j) {
dataFVinDoc.labelsFVDoc[i].multiLabels[j] = new LabelsOfFV(
numClasses);
dataFVinDoc.labelsFVDoc[i].multiLabels[j].probs = new float[numClasses];
}
// for each class
for(int iClassCounter = 0; iClassCounter < numClasses; ++iClassCounter) {
final int iClass = iClassCounter;
final File thisClassModelFile = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
BufferedReader modelBuff = new BufferedReader(
new InputStreamReader(
new FileInputStream(thisClassModelFile), "UTF-8"));
learner.applying(modelBuff, dataFVinDoc, finalTotalNumFeatures,
iClass);
modelBuff.close();
}
if(LogService.minVerbosityLevel > 1)
System.out.println("**** One against all others, numNull=" + numNull);
break;
case 2:
LogService.logMessage(
"One against another for multi to binary class conversion.", 1);
// not use the tau modification in all cases
learner.isUseTauALLCases = false;
//set the multi class number and allocate the memory
for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i)
for(int j = 0; j < dataFVinDoc.trainingFVinDoc[i].getNumInstances(); ++j) {
dataFVinDoc.labelsFVDoc[i].multiLabels[j] = new LabelsOfFV(
numClasses);
dataFVinDoc.labelsFVDoc[i].multiLabels[j].probs = new float[numClasses];
}
// for each class
for(int iClassCounter = 0; iClassCounter < numClasses; ++iClassCounter) {
final int iClass = iClassCounter;
final File thisClassModelFile = new File(modelFile, String.format(
ConstantParameters.FILENAMEOFPerClassModel, Integer
.valueOf(classIndex++)));
BufferedReader modelBuff = new BufferedReader(
new InputStreamReader(
new FileInputStream(thisClassModelFile), "UTF-8"));
learner.applying(modelBuff, dataFVinDoc, finalTotalNumFeatures,
iClass);
modelBuff.close();
}
PostProcessing postProc = new PostProcessing();
// Get the number of classes of the problem, since the numClasses
// refers to the number
// of classes in the one against another method.
int numClassesL = numClasses * 2;
numClassesL = rootQuaEqn(numClassesL);
if(numNull == 0) numClassesL += 1;
LogService.logMessage("Number of classes in training data: "
+ numClassesL + "\nActuall number of binary classes in model: "
+ numClasses, 1);
if(LogService.minVerbosityLevel > 1)
System.out.println("**** One against another, numNull=" + numNull);
if(numNull > 0)
postProc.voteForOneVSAnotherNull(dataFVinDoc, numClassesL);
else postProc.voteForOneVSAnother(dataFVinDoc, numClassesL);
// Set the number of classes with the correct value.
numClasses = numClassesL;
break;
default:
System.out.println("Incorrect multi2BinaryMode value="
+ multi2BinaryMode);
LogService.logMessage("Incorrect multi2BinaryMode value="
+ multi2BinaryMode, 1);
}
long time2 = new Date().getTime();
time2 -= time1;
LogService.logMessage("Application time for class: " + time2 + "ms", 1);
} catch(IOException e) {
e.printStackTrace();
}
}
/**
* Obtain the learner from the learner's name speficied by the learning
* configuration file.
*
* @throws GateException
*/
public static SupervisedLearner obtainLearnerFromName(String learnerName,
String commandLine, String dataFilesName) throws GateException {
SupervisedLearner learner = null;
if(learnerName.equalsIgnoreCase("SVMLibSvmJava")) {
learner = new SvmLibSVM();
learner.setLearnerName(learnerName);
learner.setCommandLine(commandLine + " " + dataFilesName);
learner.getParametersFromCommmand();
} else if(learnerName.equalsIgnoreCase("SVMExec")) {
learner = new SvmForExec();
learner.setLearnerName(learnerName);
learner.setCommandLine(commandLine);
learner.getParametersFromCommmand();
} else if(learnerName.equalsIgnoreCase("PAUM")) {
learner = new Paum();
learner.setCommandLine(commandLine);
learner.getParametersFromCommmand();
} else if(learnerName.equalsIgnoreCase("PAUMExec")) {
learner = new PaumForExec();
learner.setLearnerName(learnerName);
learner.setCommandLine(commandLine);
learner.getParametersFromCommmand();
} else {
throw new GateException("The learner's name \"" + learnerName
+ "\" is not defined!");
}
return learner;
}
}