// Web page navigation residue: Log in | Help | Print
// Path: Home > releases > gate-5.1-beta2-build3402-ALL > plugins > Learning > src > gate > learning > learners > PostProcessing.java
 
/*
 *  PostProcessing.java
 * 
 *  Yaoyong Li 22/03/2007
 *
 *  $Id: PostProcessing.java, v 1.0 2007-03-22 12:58:16 +0000 yaoyong $
 */
package gate.learning.learners;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.HashMap;
import java.util.List;

import gate.learning.ChunkLengthStats;
import gate.learning.LabelsOfFV;
import gate.learning.LabelsOfFeatureVectorDoc;
import gate.learning.LogService;
/**
 * Post-processing of the results from the classifiers used to annotate
 * text. Some specific post-processing procedures are used to improve
 * the performance of chunk learning.
 */
public class PostProcessing {
  /** Flag value stored for a candidate chunk that is still selected. */
  private static final int KEPT = 1;
  /** Flag value stored for a candidate chunk removed by overlap resolution. */
  private static final int DROPPED = 0;

  /** Threshold of the probability of the start and end tokens for chunk. */
  double boundaryProb = 0.42;
  /**
   * Threshold of the probability of the chunk (multiplication of the
   * probabilities for the start and end tokens).
   */
  double entityProb = 0.2;
  /** Threshold of the probability for text classification. */
  double thresholdC = 0.5;

  /**
   * Constructor with the three probability thresholds.
   *
   * @param boundaryP threshold for the start/end token probabilities
   * @param entityP threshold for the chunk probability
   * @param thresholdClassificaton threshold for classification and voting
   */
  public PostProcessing(float boundaryP, float entityP,
    float thresholdClassificaton) {
    this.boundaryProb = boundaryP;
    this.entityProb = entityP;
    this.thresholdC = thresholdClassificaton;
  }

  /** Trivial constructor; keeps the default thresholds. */
  public PostProcessing() {
  }

  /**
   * Post-processing the classification results for chunk learning,
   * by keeping the consistency of start and end tokens of the same label,
   * using the length information of chunk from the training data, and
   * selecting the chunk with the maximal probability from the overlapped
   * chunks.
   *
   * @param stage not used by this method
   * @param multiLabels per-token results; {@code probs[2k]} is read as the
   *          start-token probability and {@code probs[2k + 1]} as the
   *          end-token probability of label {@code k + 1}
   * @param numClasses number of entries of {@code probs} to consider
   * @param chunks output set; every surviving {@link ChunkOrEntity} is added
   * @param chunkLenHash chunk length statistics, looked up by the integer
   *          label {@code k + 1}; a label without an entry gets a fresh
   *          {@link ChunkLengthStats}
   */
  public void postProcessingChunk(short stage, LabelsOfFV[] multiLabels,
    int numClasses, HashSet chunks, HashMap chunkLenHash) {
    HashMap<ChunkOrEntity, Integer> candidates = collectCandidateChunks(
      multiLabels, numClasses, chunkLenHash);
    if(LogService.minVerbosityLevel > 1)
      System.out.println("*** numberinTempChunks=" + candidates.size());
    // Solve the overlap case so that every entity has just one label
    resolveOverlaps(candidates);
    for(ChunkOrEntity chunk : candidates.keySet()) {
      if(candidates.get(chunk).intValue() == KEPT) chunks.add(chunk);
    }
  }

  /**
   * Collects every (start, end, label) span whose boundary probabilities,
   * length statistics and combined probability pass the thresholds.
   * All returned candidates are flagged {@link #KEPT}.
   */
  private HashMap<ChunkOrEntity, Integer> collectCandidateChunks(
    LabelsOfFV[] multiLabels, int numClasses, HashMap chunkLenHash) {
    int numTokens = multiLabels.length;
    HashMap<ChunkOrEntity, Integer> candidates =
      new HashMap<ChunkOrEntity, Integer>();
    // Classes come in pairs: start token classifier, then end token classifier
    for(int startClass = 0; startClass < numClasses; startClass += 2) {
      int endClass = startClass + 1;
      int label = startClass / 2 + 1;
      Object stats = chunkLenHash.get(label);
      ChunkLengthStats chunkLen = stats != null
        ? (ChunkLengthStats)stats
        : new ChunkLengthStats();
      for(int start = 0; start < numTokens; ++start) {
        float startProb = multiLabels[start].probs[startClass];
        if(!(startProb > boundaryProb)) continue;
        for(int end = start; end < numTokens; ++end) {
          int length = end - start + 1;
          // The length only grows with end, so once it reaches maxLen no
          // later end token can qualify either.
          if(length >= ChunkLengthStats.maxLen) break;
          float endProb = multiLabels[end].probs[endClass];
          // Use the boundary probability and the chunk length statistics
          if(endProb > boundaryProb && chunkLen.lenStats[length] > 0) {
            float chunkProb = startProb * endProb;
            if(chunkProb > entityProb) {
              ChunkOrEntity chunk = new ChunkOrEntity(start, end);
              chunk.prob = chunkProb;
              chunk.name = label;
              candidates.put(chunk, KEPT);
            }
          }
        }
      }
    }
    return candidates;
  }

  /**
   * Flags overlapping candidates as {@link #DROPPED} so that, of two
   * overlapping chunks, only one stays {@link #KEPT}. The chunk with the
   * higher probability wins; on equal probability the shorter chunk wins,
   * and on equal length the chunk that comes first in the ordering wins.
   * Candidates are visited in the lexicographic order of their
   * "start_end_name" key strings.
   */
  private void resolveOverlaps(HashMap<ChunkOrEntity, Integer> candidates) {
    HashMap<String, ChunkOrEntity> byKey =
      new HashMap<String, ChunkOrEntity>();
    for(ChunkOrEntity chunk : candidates.keySet()) {
      byKey.put(chunk.start + "_" + chunk.end + "_" + chunk.name, chunk);
    }
    List<String> keys = new ArrayList<String>(byKey.keySet());
    Collections.sort(keys);
    int numKeys = keys.size();
    for(int i = 0; i < numKeys; ++i) {
      ChunkOrEntity first = byKey.get(keys.get(i));
      if(candidates.get(first).intValue() != KEPT) continue;
      // Entries with different keys always differ in start, end or name,
      // so no extra identity check is needed between first and second.
      for(int j = i + 1; j < numKeys; ++j) {
        ChunkOrEntity second = byKey.get(keys.get(j));
        if(candidates.get(second).intValue() != KEPT) continue;
        if(!overlaps(first, second)) continue;
        if(first.prob > second.prob) {
          candidates.put(second, DROPPED);
        } else if(first.prob < second.prob) {
          candidates.put(first, DROPPED);
          break; // first is gone, nothing more to compare it with
        } else if(first.end - first.start > second.end - second.start) {
          candidates.put(first, DROPPED);
          break;
        } else {
          candidates.put(second, DROPPED);
        }
      }
    }
  }

  /**
   * Tells whether two chunks share at least one token position. The test is
   * symmetric, so it also covers the case where one chunk strictly contains
   * the other, whichever of the two is passed first.
   */
  private static boolean overlaps(ChunkOrEntity a, ChunkOrEntity b) {
    return a.start <= b.end && b.start <= a.end;
  }

  /**
   * Post-processing the results for the classification problem
   * for the one vs all others method: select the label with the
   * maximal probability which is bigger than the pre-defined threshold,
   * for one instance.
   *
   * @param stage not used by this method
   * @param multiLabels per-instance class probabilities
   * @param selectedLabels output; entry {@code i} receives the index of the
   *          selected class, or -1 if no probability exceeds the threshold
   * @param valueLabels output; entry {@code i} receives the probability of
   *          the selected class, or the threshold itself if none is selected
   */
  public void postProcessingClassification(short stage,
    LabelsOfFV[] multiLabels, int[] selectedLabels, float[] valueLabels) {
    for(int i = 0; i < multiLabels.length; ++i) { // for each instance
      LabelsOfFV labels = multiLabels[i];
      float bestValue = (float)thresholdC;
      int bestLabel = -1;
      for(int j = 0; j < labels.num; ++j) { // for each class
        if(labels.probs[j] > bestValue) {
          bestValue = labels.probs[j];
          bestLabel = j;
        }
      }
      // Written once per instance, after the scan, so that an instance with
      // no classes is still reported as "no label" instead of being skipped.
      selectedLabels[i] = bestLabel;
      valueLabels[i] = bestValue;
    }
  }

  /**
   * Post-processing for the one vs another method,
   * for the training data with null label. The first {@code numClasses}
   * entries of each {@code probs} array are read as "class vs null"
   * results, followed by one entry per class pair. The null label wins
   * unless some class has strictly more votes, in which case the replacement
   * label has probability 1.0 for that class; otherwise all entries are 0.
   */
  public void voteForOneVSAnotherNull(DataForLearning dataFVinDoc,
    int numClasses) {
    voteOnPairwiseResults(dataFVinDoc, numClasses, true);
  }

  /**
   * Post-processing for the one vs another method,
   * for the training data without the null label. Each {@code probs} array
   * is read as one entry per class pair; the class with the most votes
   * (the lowest index on a tie) gets probability 1.0 in the replacement
   * label.
   */
  public void voteForOneVSAnother(DataForLearning dataFVinDoc, int numClasses) {
    voteOnPairwiseResults(dataFVinDoc, numClasses, false);
  }

  /**
   * Replaces every label of every document by the winner of a vote over the
   * pairwise classifier outputs.
   *
   * @param withNull if true, the first {@code numClasses} entries of
   *          {@code probs} are "class vs null" results and the null label
   *          takes part in the vote
   */
  private void voteOnPairwiseResults(DataForLearning dataFVinDoc,
    int numClasses, boolean withNull) {
    for(int i = 0; i < dataFVinDoc.getNumTrainingDocs(); ++i) {
      LabelsOfFV[] labels = dataFVinDoc.labelsFVDoc[i].multiLabels;
      for(int j = 0; j < labels.length; ++j) {
        float[] pairwise = labels[j].probs;
        int[] votes = new int[numClasses];
        int nullVotes = 0;
        int k = 0; // running position in the pairwise results
        if(withNull) {
          // each class against the null label
          for(int c = 0; c < numClasses; ++c) {
            if(pairwise[k] > thresholdC)
              votes[c]++;
            else nullVotes++;
            ++k;
          }
        }
        // each class against every later class
        for(int a = 0; a < numClasses; ++a)
          for(int b = a + 1; b < numClasses; ++b) {
            if(pairwise[k] > thresholdC)
              votes[a]++;
            else votes[b]++;
            ++k;
          }
        // Convert the vote results into label; a class must strictly beat
        // the current maximum, so ties keep the earlier winner (or null).
        int bestVotes = withNull ? nullVotes : -1;
        int winner = -1;
        for(int c = 0; c < numClasses; ++c)
          if(bestVotes < votes[c]) {
            bestVotes = votes[c];
            winner = c;
          }
        LabelsOfFV voted = new LabelsOfFV(numClasses);
        voted.probs = new float[voted.num];
        if(winner >= 0) voted.probs[winner] = 1.0f;
        labels[j] = voted;
      }
    }
  }
}