Package ann4j

Class Trainer

java.lang.Object
ann4j.Trainer

public class Trainer extends Object
  • Field Details

  • Constructor Details

    • Trainer

      public Trainer()
  • Method Details

    • getModelEvaluator

      public ModelEvaluator getModelEvaluator()
      Returns the model evaluator.
      Returns:
      The ModelEvaluator object.
    • getLayerManager

      public LayerManager getLayerManager()
      Returns the LayerManager object that manages the layers of the neural network.
      Returns:
      The LayerManager object.
    • train

      public void train(int noOfSamples, int epochs)
      Trains the neural network by reading training data from the MNIST database and updating the network's weights and biases.
      Parameters:
      noOfSamples - The number of samples to train on.
      epochs - The number of passes to make over the training data.
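      The roles of noOfSamples and epochs can be illustrated with a minimal, self-contained sketch. It is not ann4j's implementation: the class EpochLoopSketch, the single linear neuron, the learningRate parameter, and the toy data are all assumptions made for illustration. It only shows the usual loop shape, where each epoch is one pass over the first noOfSamples samples and each sample triggers one weight and bias update.

```java
class EpochLoopSketch {
    // Hypothetical stand-in for a network: one linear neuron, y = w * x + b.
    // Returns {w, b} after training. Not part of ann4j.
    static double[] train(double[] xs, double[] ys, int noOfSamples, int epochs, double learningRate) {
        double w = 0.0;
        double b = 0.0;
        int n = Math.min(noOfSamples, xs.length);        // never read past the data
        for (int epoch = 0; epoch < epochs; epoch++) {   // one full pass per epoch
            for (int i = 0; i < n; i++) {                // one update per sample
                double error = (w * xs[i] + b) - ys[i];  // prediction minus target
                w -= learningRate * error * xs[i];       // gradient step for the weight
                b -= learningRate * error;               // gradient step for the bias
            }
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] xs = {0.0, 1.0, 2.0, 3.0};
        double[] ys = {1.0, 3.0, 5.0, 7.0};              // generated from y = 2x + 1
        double[] params = train(xs, ys, 4, 2000, 0.05);
        System.out.printf("w = %.3f, b = %.3f%n", params[0], params[1]);
    }
}
```

      In this sketch the total number of updates is noOfSamples multiplied by epochs, so raising either value increases training time proportionally.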
    • test

      public void test(int noOfSamples)
      Tests the model on the given number of samples.
      Parameters:
      noOfSamples - The number of samples to be tested.
    • printConfusionMatrix

      public void printConfusionMatrix()
      Prints the confusion matrix of the model.
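      For readers unfamiliar with the term, a confusion matrix counts how often each actual class was predicted as each class. The sketch below is a generic illustration, not ann4j's code: the class ConfusionMatrixSketch, its method names, and the row/column convention (rows are actual classes, columns are predicted classes) are assumptions, and the output format of printConfusionMatrix may differ.

```java
class ConfusionMatrixSketch {
    // Builds a numClasses x numClasses matrix.
    // Convention assumed here: row = actual class, column = predicted class.
    static int[][] build(int[] actual, int[] predicted, int numClasses) {
        int[][] matrix = new int[numClasses][numClasses];
        for (int i = 0; i < actual.length; i++) {
            matrix[actual[i]][predicted[i]]++;
        }
        return matrix;
    }

    // Accuracy is the sum of the diagonal divided by the sum of all cells.
    static double accuracy(int[][] matrix) {
        int correct = 0;
        int total = 0;
        for (int r = 0; r < matrix.length; r++) {
            for (int c = 0; c < matrix[r].length; c++) {
                total += matrix[r][c];
                if (r == c) {
                    correct += matrix[r][c];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    // Prints one row per actual class, with counts separated by spaces.
    static void print(int[][] matrix) {
        for (int[] row : matrix) {
            StringBuilder line = new StringBuilder();
            for (int count : row) {
                line.append(count).append(' ');
            }
            System.out.println(line.toString().trim());
        }
    }

    public static void main(String[] args) {
        int[] actual = {0, 0, 1, 1, 2, 2};
        int[] predicted = {0, 1, 1, 1, 2, 0};
        int[][] matrix = build(actual, predicted, 3);
        print(matrix);
        System.out.println("accuracy = " + accuracy(matrix));
    }
}
```

      Off-diagonal cells show which classes are mistaken for which, something a single accuracy figure hides.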
    • forwardPropagatewithExclusionInputLayerOnKSamples

      public void forwardPropagatewithExclusionInputLayerOnKSamples(int noOfSamples)
    • relevancePropagate

      public void relevancePropagate(int layerNumber, int neuronNumber)
      Propagates relevance from the output layer back to the input layer.
      Parameters:
      layerNumber - The number of the layer containing the neuron to propagate relevance from.
      neuronNumber - The number of the neuron within that layer.
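      The documentation does not state which relevance rule is used. As one possible reading, the sketch below shows a single backward step of layer-wise relevance propagation with the epsilon rule for one dense layer without bias. The class RelevanceSketch, the propagate method, and the weights[i][j] layout (lower neuron i to upper neuron j) are assumptions for illustration and are not taken from ann4j.

```java
class RelevanceSketch {
    // Redistributes the relevance of the upper layer onto the lower layer.
    // Epsilon rule: R_i = sum_j (a_i * w_ij / (z_j + eps * sign(z_j))) * R_j,
    // where z_j = sum_i a_i * w_ij.
    static double[] propagate(double[] activations, double[][] weights,
                              double[] upperRelevance, double epsilon) {
        int lower = activations.length;
        int upper = upperRelevance.length;
        double[] lowerRelevance = new double[lower];
        for (int j = 0; j < upper; j++) {
            double z = 0.0;
            for (int i = 0; i < lower; i++) {
                z += activations[i] * weights[i][j];   // pre-activation of upper neuron j
            }
            z += (z >= 0 ? epsilon : -epsilon);        // stabilizer against division by zero
            for (int i = 0; i < lower; i++) {
                // Each lower neuron receives a share proportional to its contribution.
                lowerRelevance[i] += activations[i] * weights[i][j] / z * upperRelevance[j];
            }
        }
        return lowerRelevance;
    }

    public static void main(String[] args) {
        double[] activations = {1.0, 2.0};
        double[][] weights = {{0.5, 1.0}, {0.25, 0.5}};
        // Start with all relevance on upper neuron 0, analogous to picking one neuron.
        double[] upperRelevance = {1.0, 0.0};
        double[] result = propagate(activations, weights, upperRelevance, 1e-9);
        System.out.println(result[0] + " " + result[1]);
    }
}
```

      With a small epsilon and no bias term, the relevance is approximately conserved: the lower-layer values sum to roughly the relevance that was placed on the chosen upper neuron.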
    • train

      public void train()
      Trains the network by using the input layer to predict the expected output layer. This overload takes no arguments.