/*
 *  WekaLearner.java
 * 
 *  Yaoyong Li 22/03/2007
 *
 *  $Id: WekaLearner.java, v 1.0 2007-03-22 12:58:16 +0000 yaoyong $
 */
package gate.learning.learners.weka;

import gate.learning.LabelsOfFV;
import gate.learning.LabelsOfFeatureVectorDoc;
import gate.learning.LogService;

import java.io.Serializable;
import weka.core.Instances;
/**
 * The abstract class for a Weka learner.
 * It implements the training and applying methods shared by
 * the concrete learners.
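 * <p>
 * A minimal usage sketch; the concrete subclass name and the way the
 * {@link weka.core.Instances} data and the per-document label arrays are
 * obtained are illustrative assumptions only:
 * <pre>
 *   WekaLearner learner = new SomeConcreteWekaLearner(); // hypothetical subclass
 *   learner.getParametersFromOptionsLine(optionsLine);   // learner-specific options
 *   learner.training(trainingInstances);                 // builds the Weka classifier
 *   // labelsFVDoc: one LabelsOfFeatureVectorDoc per document, with multiLabels
 *   // entries matching the order of the instances in testInstances
 *   learner.applying(testInstances, labelsFVDoc, true);  // fills probs per label
 * </pre>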
 */
public abstract class WekaLearner implements Serializable {
  /** The name of the learner. */
  String learnerName = null;
  /** The Weka classifier. */
  public weka.classifiers.Classifier wekaCl;
  /** The options for one Weka learner. */
  String options;
  /** The abstract method for getting the learner's parameters from the options line. */
  public abstract void getParametersFromOptionsLine(String options);
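  /*
   * A sketch of how a concrete subclass might implement the method above;
   * the "-C" option name and the costC field are illustrative assumptions:
   *
   *   public void getParametersFromOptionsLine(String options) {
   *     String[] items = options.trim().split("\\s+"); // whitespace-separated tokens
   *     for(int i = 0; i < items.length - 1; ++i)
   *       if(items[i].equals("-C"))                     // e.g. a cost parameter
   *         costC = Double.parseDouble(items[i + 1]);
   *   }
   */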
  /** Training by calling the buildClassifier method of the learner. */
  public void training(Instances instancesData) {
    try {
      if(LogService.minVerbosityLevel>0)
        System.out.println("Learning start:");
      wekaCl.buildClassifier(instancesData);
    } catch(Exception e) {
      e.printStackTrace();
    }
  }
  /** Applying the model to the data. The output of the
   * application can be either a distribution over all labels
   * or only the single label with the maximal output.
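   * <p>
   * After the call, {@code labelsFVDoc[d].multiLabels[k].probs[c]} holds the
   * score of label {@code c + 1} for the k-th instance of document d: an
   * L2-normalised distribution value when {@code distributionOutput} is true,
   * or 1.0 for the single predicted label otherwise.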
   */
  public void applying(Instances instancesData,
    LabelsOfFeatureVectorDoc[] labelsFVDoc, boolean distributionOutput) {
    int numInst = 0;
    // Do not count the label for the null class
    int numClasses = instancesData.numClasses() - 1; 
    // Get the mapping from the class attribute index of the output
    // to the true labels of the problem
    int[] trueLabels = new int[numClasses + 1];
    for(int i = 0; i <= numClasses; ++i) {
      trueLabels[i] = Integer.parseInt(instancesData.classAttribute().value(i));
    }
    if(LogService.minVerbosityLevel>0)
      System.out.println("Application starts...");
    try {
      if(distributionOutput) {
        double[] distr;
        for(int iDoc = 0; iDoc < labelsFVDoc.length; ++iDoc) {
          int num = labelsFVDoc[iDoc].multiLabels.length;
          for(int i = 0; i < num; ++i) {
            distr = wekaCl.distributionForInstance(instancesData
              .instance(numInst++));
            labelsFVDoc[iDoc].multiLabels[i] = new LabelsOfFV(numClasses);
            labelsFVDoc[iDoc].multiLabels[i].probs = new float[numClasses];
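            // Normalise the class distribution to unit L2 norm so the per-label
            // scores are comparable across instances; a (near) zero norm is
            // replaced by 1.0 to avoid division by zero.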
            double sum = 0.0;
            for(int j = 0; j < distr.length; ++j)
              sum += distr[j] * distr[j];
            sum = Math.sqrt(sum);
            if(sum < 0.00000000001) sum = 1.0;
            for(int j = 0; j < distr.length; ++j)
              distr[j] /= sum;
            for(int j = 0; j <= numClasses; ++j)
              if(trueLabels[j] != -1) {
                labelsFVDoc[iDoc].multiLabels[i].probs[trueLabels[j] - 1] = (float)distr[j];
              }
          }
        }
      } else {
        double outputV;
        for(int iDoc = 0; iDoc < labelsFVDoc.length; ++iDoc) {
          int num = labelsFVDoc[iDoc].multiLabels.length;
          for(int i = 0; i < num; ++i) {
            outputV = wekaCl
              .classifyInstance(instancesData.instance(numInst++));
            labelsFVDoc[iDoc].multiLabels[i] = new LabelsOfFV(numClasses);
            labelsFVDoc[iDoc].multiLabels[i].probs = new float[numClasses];
            if(trueLabels[(int)outputV] != -1)
              labelsFVDoc[iDoc].multiLabels[i].probs[trueLabels[(int)outputV] - 1] = 1.0f;
          }
        }
      }
    } catch(Exception e) {
      e.printStackTrace();
    }
  }

  public String getLearnerName() {
    return learnerName;
  }
}