/*
* PostProcessing.java
*
* Yaoyong Li 22/03/2007
*
* $Id: PostProcessing.java, v 1.0 2007-03-22 12:58:16 +0000 yaoyong $
*/
package gate.learning.learners;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;
import gate.learning.ChunkLengthStats;
import gate.learning.LabelsOfFV;
import gate.learning.LabelsOfFeatureVectorDoc;
import gate.learning.LogService;
/**
* Post-processing of the results from the classifiers for annotating
* text. Some specific post-processing procedures are used
* for improving the performance of chunk learning.
*/
public class PostProcessing {
  /** Flag value in the candidate map: the chunk is still selected. */
  private static final int KEPT = 1;
  /** Flag value in the candidate map: the chunk lost an overlap contest. */
  private static final int DISCARDED = 0;

  /** Threshold of the probability of the start and end tokens for chunk. */
  double boundaryProb = 0.42;
  /**
   * Threshold of the probability of the chunk (multiplication of the
   * probabilities for the start and end tokens).
   */
  double entityProb = 0.2;
  /** Threshold of the probability for text classification. */
  double thresholdC = 0.5;

  /**
   * Constructor with the three probability thresholds.
   *
   * @param boundaryP threshold for the start/end token probabilities
   * @param entityP threshold for the chunk probability
   * @param thresholdClassificaton threshold used for classification and
   *          for the one-vs-another voting
   */
  public PostProcessing(float boundaryP, float entityP,
    float thresholdClassificaton) {
    this.boundaryProb = boundaryP;
    this.entityProb = entityP;
    this.thresholdC = thresholdClassificaton;
  }

  /** Trivial constructor; keeps the default thresholds. */
  public PostProcessing() {
  }

  /**
   * Post-processing the classification results for chunk learning.
   * Candidate chunks are built from pairs of start and end tokens of the
   * same label whose probabilities exceed the thresholds and whose length
   * is allowed by the chunk length statistics; among overlapping
   * candidates only one survives (see {@link #resolveOverlaps}).
   *
   * @param stage not used by this method
   * @param multiLabels per-token classifier outputs; for label
   *          {@code j / 2 + 1} the entry {@code probs[j]} is read as the
   *          start-token probability and {@code probs[j + 1]} as the
   *          end-token probability
   * @param numClasses number of entries of {@code probs} to scan, two per
   *          label
   * @param chunks output set; every surviving {@code ChunkOrEntity} is
   *          added to it
   * @param chunkLenHash maps the label number to its
   *          {@code ChunkLengthStats}; a label without an entry gets a
   *          fresh {@code ChunkLengthStats}
   */
  @SuppressWarnings("unchecked") // 'chunks' is a raw HashSet in the public signature
  public void postProcessingChunk(short stage, LabelsOfFV[] multiLabels,
    int numClasses, HashSet chunks, HashMap chunkLenHash) {
    HashMap<ChunkOrEntity, Integer> tempChunks =
      collectCandidateChunks(multiLabels, numClasses, chunkLenHash);
    if(LogService.minVerbosityLevel > 1)
      System.out.println("*** numberinTempChunks=" + tempChunks.size());
    // Solve the overlap case so that every entity has just one label
    resolveOverlaps(tempChunks);
    for(ChunkOrEntity candidate : tempChunks.keySet()) {
      if(tempChunks.get(candidate).intValue() == KEPT) chunks.add(candidate);
    }
  }

  /**
   * Builds every candidate chunk (start token, end token, label) that
   * passes the boundary, length and entity probability thresholds.
   * All candidates are returned flagged as {@link #KEPT}.
   */
  private HashMap<ChunkOrEntity, Integer> collectCandidateChunks(
    LabelsOfFV[] multiLabels, int numClasses, HashMap chunkLenHash) {
    HashMap<ChunkOrEntity, Integer> candidates =
      new HashMap<ChunkOrEntity, Integer>();
    int num = multiLabels.length;
    int maxLen = ChunkLengthStats.maxLen;
    for(int j = 0; j < numClasses; j += 2) { // for start and end token
      int label = j / 2 + 1;
      Object stats = chunkLenHash.get(Integer.valueOf(label));
      ChunkLengthStats chunkLen = (stats != null)
        ? (ChunkLengthStats)stats
        : new ChunkLengthStats();
      for(int start = 0; start < num; ++start) {
        if(!(multiLabels[start].probs[j] > boundaryProb)) continue;
        // A span of maxLen tokens or more can never qualify, so the scan
        // for end tokens stops as soon as that length is reached.
        for(int end = start; end < num && end - start + 1 < maxLen; ++end) {
          // Use the boundary probability and the length of chunk statistics
          if(multiLabels[end].probs[j + 1] > boundaryProb
            && chunkLen.lenStats[end - start + 1] > 0) {
            float entityP =
              multiLabels[start].probs[j] * multiLabels[end].probs[j + 1];
            if(entityP > entityProb) {
              ChunkOrEntity chunk = new ChunkOrEntity(start, end);
              chunk.prob = entityP;
              chunk.name = label;
              candidates.put(chunk, KEPT);
            }
          }
        }
      }// End of loop for each instance (start)
    }
    return candidates;
  }

  /**
   * Greedy overlap resolution: whenever two still-kept candidates share at
   * least one token, the one with the lower probability is flagged as
   * {@link #DISCARDED}. On equal probabilities the longer of the two is
   * discarded, and if the lengths are equal too the later one in the
   * visiting order is discarded.
   * <p>
   * Candidates are visited in the lexicographic order of the string
   * {@code start_end_name}. This is a string order, not a numeric one; it
   * is kept as is because the greedy result depends on the visiting order.
   */
  private void resolveOverlaps(HashMap<ChunkOrEntity, Integer> tempChunks) {
    HashMap<String, ChunkOrEntity> mapChunks =
      new HashMap<String, ChunkOrEntity>();
    for(ChunkOrEntity entity : tempChunks.keySet()) {
      mapChunks
        .put(entity.start + "_" + entity.end + "_" + entity.name, entity);
    }
    List<String> indexes = new ArrayList<String>(mapChunks.keySet());
    Collections.sort(indexes);
    for(int a = 0; a < indexes.size(); ++a) {
      ChunkOrEntity chunk1 = mapChunks.get(indexes.get(a));
      if(tempChunks.get(chunk1).intValue() != KEPT) continue;
      for(int b = a + 1; b < indexes.size(); ++b) {
        // Distinct keys imply the two chunks differ in start, end or name,
        // so no separate identity check is needed here.
        ChunkOrEntity chunk2 = mapChunks.get(indexes.get(b));
        if(tempChunks.get(chunk2).intValue() != KEPT) continue;
        // Symmetric interval test: also catches the case where one chunk
        // lies strictly inside the other.
        boolean overlap =
          chunk1.start <= chunk2.end && chunk2.start <= chunk1.end;
        if(!overlap) continue;
        if(chunk1.prob > chunk2.prob) {
          tempChunks.put(chunk2, DISCARDED);
        } else if(chunk1.prob < chunk2.prob) {
          tempChunks.put(chunk1, DISCARDED);
          break; // chunk1 is out, no need to compare it any further
        } else if(chunk1.end - chunk1.start > chunk2.end - chunk2.start) {
          tempChunks.put(chunk1, DISCARDED);
          break;
        } else {
          tempChunks.put(chunk2, DISCARDED);
        }
      }// end of the inner loop
    }// end of the outer loop
  }

  /**
   * Post-processing the results for the classification problem
   * for the one vs all others method: for each instance select the label
   * with the maximal probability which is bigger than the pre-defined
   * threshold.
   *
   * @param stage not used by this method
   * @param multiLabels classifier outputs, one element per instance
   * @param selectedLabels output; index of the selected class for each
   *          instance, or -1 if no probability exceeds the threshold
   * @param valueLabels output; probability of the selected class, or the
   *          threshold itself if no class was selected
   */
  public void postProcessingClassification(short stage,
    LabelsOfFV[] multiLabels, int[] selectedLabels, float[] valueLabels) {
    for(int i = 0; i < multiLabels.length; ++i) { // for each instance
      LabelsOfFV labels = multiLabels[i];
      float maxValue = (float)thresholdC;
      int maxLabel = -1;
      for(int j = 0; j < labels.num; ++j) { // for each class
        if(labels.probs[j] > maxValue) {
          maxValue = labels.probs[j];
          maxLabel = j;
        }
      }
      // Written after the loop so that an instance without any class
      // still gets a defined result instead of keeping stale values.
      selectedLabels[i] = maxLabel;
      valueLabels[i] = maxValue;
    }
  }

  /**
   * Post-processing for the one vs another method,
   * for the training data with null label.
   * The first {@code numClasses} entries of {@code probs} are read as
   * "class against null" votes, the remaining ones as pairwise votes
   * between class {@code i1} and class {@code j1} with {@code i1 < j1}.
   * Each label vector is replaced by a vector with a single 1.0 at the
   * winning class, or by an all-zero vector when no class gets strictly
   * more votes than null.
   *
   * @param dataFVinDoc holder of the label vectors, updated in place
   * @param numClasses number of real classes (null excluded)
   */
  public void voteForOneVSAnotherNull(DataForLearning dataFVinDoc,
    int numClasses) {
    for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i) {
      LabelsOfFeatureVectorDoc labelFVsDoc = dataFVinDoc.labelsFVDoc[i];
      for(int j = 0; j < labelFVsDoc.multiLabels.length; ++j) {
        LabelsOfFV multiLabel0 = labelFVsDoc.multiLabels[j];
        int[] voteResults = new int[numClasses];
        // for the null
        int voteNull = 0;
        for(int c = 0; c < numClasses; ++c) {
          if(multiLabel0.probs[c] > thresholdC)
            voteResults[c]++;
          else voteNull++;
        }
        // for other label: the pairwise votes follow the null votes
        addPairwiseVotes(multiLabel0, numClasses, numClasses, voteResults);
        // A class wins only with strictly more votes than null
        int winner = indexOfMaxAbove(voteResults, voteNull);
        labelFVsDoc.multiLabels[j] = oneHotLabel(numClasses, winner);
      }
    }
  }

  /**
   * Post-processing for the one vs another method,
   * for the training data without the null label.
   * The entries of {@code probs} are read as pairwise votes between class
   * {@code i1} and class {@code j1} with {@code i1 < j1}. Each label
   * vector is replaced by a vector with a single 1.0 at the class with
   * the most votes (the lowest index wins a tie).
   *
   * @param dataFVinDoc holder of the label vectors, updated in place
   * @param numClasses number of classes
   */
  public void voteForOneVSAnother(DataForLearning dataFVinDoc, int numClasses) {
    for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i) {
      LabelsOfFeatureVectorDoc labelFVsDoc = dataFVinDoc.labelsFVDoc[i];
      for(int j = 0; j < labelFVsDoc.multiLabels.length; ++j) {
        LabelsOfFV multiLabel0 = labelFVsDoc.multiLabels[j];
        int[] voteResults = new int[numClasses];
        addPairwiseVotes(multiLabel0, 0, numClasses, voteResults);
        int winner = indexOfMaxAbove(voteResults, -1);
        labelFVsDoc.multiLabels[j] = oneHotLabel(numClasses, winner);
      }
    }
  }

  /**
   * Adds one vote per class pair (i1, j1), i1 &lt; j1, reading consecutive
   * entries of {@code labels.probs} starting at {@code firstIndex}: the
   * vote goes to i1 if the entry exceeds the classification threshold,
   * otherwise to j1.
   */
  private void addPairwiseVotes(LabelsOfFV labels, int firstIndex,
    int numClasses, int[] votes) {
    int k = firstIndex;
    for(int i1 = 0; i1 < numClasses; ++i1)
      for(int j1 = i1 + 1; j1 < numClasses; ++j1) {
        if(labels.probs[k] > thresholdC)
          votes[i1]++;
        else votes[j1]++;
        ++k;
      }
  }

  /**
   * Returns the index of the first maximal element of {@code votes} that
   * is strictly greater than {@code floor}, or -1 if there is none.
   */
  private static int indexOfMaxAbove(int[] votes, int floor) {
    int maxVote = floor;
    int best = -1;
    for(int c = 0; c < votes.length; ++c)
      if(maxVote < votes[c]) {
        maxVote = votes[c];
        best = c;
      }
    return best;
  }

  /**
   * Creates a label vector whose probabilities are all zero except for a
   * 1.0 at {@code winner}; a negative {@code winner} gives all zeros.
   */
  private static LabelsOfFV oneHotLabel(int numClasses, int winner) {
    LabelsOfFV multiLabel = new LabelsOfFV(numClasses);
    multiLabel.probs = new float[multiLabel.num];
    if(winner >= 0) multiLabel.probs[winner] = (float)1.0;
    return multiLabel;
  }
}