Package ann4j
Class Trainer
java.lang.Object
ann4j.Trainer
-
Field Summary
Modifier and Type | Field | Description
(package private) double | label |
-
Constructor Summary
Constructor | Description
Trainer() |
-
Method Summary
Modifier and Type | Method | Description
void | forwardPropagatewithExclusionInputLayerOnKSamples(int noOfSamples) |
LayerManager | getLayerManager() | Returns the LayerManager object that is used to manage the layers of the network.
ModelEvaluator | getModelEvaluator() | Returns the model evaluator.
void | printConfusionMatrix() | Prints the confusion matrix of the model.
void | relevancePropagate(int layerNumber, int neuronNumber) | Propagates relevance from the output layer to the input layer.
void | test(int noOfSamples) | Tests the model on the given number of samples.
void | train() | Takes the current input layer and expected output layer, and uses the input layer to predict the expected output layer.
void | train(int noOfSamples, int epochs) | Trains the neural network by reading the training data from the MNIST database and updating the weights and biases of the neural network.
-
Field Details
-
myLayerManager
-
expectedLayer
-
trainingFileReader
-
testingFileReader
-
inputLayer
-
label
double label
-
-
Constructor Details
-
Trainer
public Trainer()
-
-
Method Details
-
getModelEvaluator
Returns the model evaluator.
Returns:
The ModelEvaluator object.
-
getLayerManager
Returns the LayerManager object that is used to manage the layers of the network.
Returns:
The LayerManager object.
-
train
public void train(int noOfSamples, int epochs)
Trains the neural network by reading the training data from the MNIST database and updating the weights and biases of the neural network.
Parameters:
noOfSamples - The number of samples to train on.
epochs - The number of passes over the training data.
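The epoch-then-sample structure of train(int noOfSamples, int epochs) can be illustrated with a minimal sketch. It assumes plain stochastic gradient descent on a single sigmoid neuron with cross-entropy loss; the class EpochTrainingSketch, the in-memory OR dataset that stands in for the MNIST reader, and the learning rate are all hypothetical and are not part of ann4j.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the ann4j implementation): one sigmoid neuron trained
 * with stochastic gradient descent, looping over epochs and then over samples.
 */
class EpochTrainingSketch {
    final double[] weights;
    double bias = 0.0;
    final double learningRate;

    EpochTrainingSketch(int inputSize, double learningRate) {
        this.weights = new double[inputSize]; // all weights start at zero
        this.learningRate = learningRate;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Forward pass: weighted sum of the inputs plus bias, squashed by the sigmoid. */
    double predict(double[] x) {
        double z = bias;
        for (int j = 0; j < weights.length; j++) {
            z += weights[j] * x[j];
        }
        return sigmoid(z);
    }

    /** Same shape as train(noOfSamples, epochs): every epoch visits the first noOfSamples samples. */
    void train(double[][] inputs, double[] targets, int noOfSamples, int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int s = 0; s < noOfSamples; s++) {
                double[] x = inputs[s];
                // For a sigmoid output with cross-entropy loss, dLoss/dz = prediction - target.
                double error = predict(x) - targets[s];
                for (int j = 0; j < weights.length; j++) {
                    weights[j] -= learningRate * error * x[j]; // update each weight
                }
                bias -= learningRate * error; // update the bias
            }
        }
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] targets = {0, 1, 1, 1}; // logical OR
        EpochTrainingSketch model = new EpochTrainingSketch(2, 0.5);
        model.train(inputs, targets, inputs.length, 2000);
        for (double[] x : inputs) {
            System.out.println(Arrays.toString(x) + " -> " + model.predict(x));
        }
    }
}
```

Updating after every sample (rather than once per epoch) is one common choice; batching the gradients per epoch would change only where the weight and bias updates happen.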
-
test
public void test(int noOfSamples)
Tests the model on the given number of samples.
Parameters:
noOfSamples - The number of samples to be tested.
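One common way for a method like test(int noOfSamples) to measure performance is classification accuracy: the predicted class is the index of the largest output activation, and it is compared against the label. The sketch below assumes the network's outputs and integer labels are already available as arrays; AccuracySketch and its methods are hypothetical, and the source does not state which metric ann4j reports.

```java
/** Hypothetical sketch: accuracy over the first noOfSamples network outputs. */
class AccuracySketch {
    /** Index of the largest output activation, i.e. the predicted class. */
    static int argmax(double[] outputs) {
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of the first noOfSamples samples whose predicted class matches the label. */
    static double accuracy(double[][] outputs, int[] labels, int noOfSamples) {
        if (noOfSamples <= 0) {
            throw new IllegalArgumentException("noOfSamples must be positive");
        }
        int correct = 0;
        for (int s = 0; s < noOfSamples; s++) {
            if (argmax(outputs[s]) == labels[s]) {
                correct++;
            }
        }
        return (double) correct / noOfSamples;
    }
}
```

For example, outputs {0.1, 0.9}, {0.8, 0.2} and {0.3, 0.7} with labels 1, 0 and 0 give two correct predictions out of three.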
-
printConfusionMatrix
public void printConfusionMatrix()
Prints the confusion matrix of the model.
-
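A confusion matrix counts, for every pair of classes, how often a sample of one class was predicted as another. The minimal sketch below assumes rows index the actual class and columns the predicted class (the orientation ann4j uses is not stated); ConfusionMatrixSketch and its methods are hypothetical.

```java
/** Hypothetical sketch: building and printing a confusion matrix. */
class ConfusionMatrixSketch {
    /** counts[actual][predicted]; the row/column orientation is an assumption. */
    static int[][] build(int[] actual, int[] predicted, int numClasses) {
        if (actual.length != predicted.length) {
            throw new IllegalArgumentException("actual and predicted must have the same length");
        }
        int[][] counts = new int[numClasses][numClasses];
        for (int i = 0; i < actual.length; i++) {
            counts[actual[i]][predicted[i]]++;
        }
        return counts;
    }

    /** Renders a header line, then one right-aligned row per actual class. */
    static String format(int[][] counts) {
        StringBuilder sb = new StringBuilder("actual \\ predicted\n");
        for (int[] row : counts) {
            for (int c : row) {
                sb.append(String.format("%6d", c));
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    static void print(int[][] counts) {
        System.out.print(format(counts));
    }
}
```

Correct predictions accumulate on the diagonal, so the off-diagonal cells show which classes the model confuses with each other.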
forwardPropagatewithExclusionInputLayerOnKSamples
public void forwardPropagatewithExclusionInputLayerOnKSamples(int noOfSamples)
-
relevancePropagate
public void relevancePropagate(int layerNumber, int neuronNumber)
Propagates relevance from the output layer to the input layer.
Parameters:
layerNumber - The number of the layer containing the neuron to propagate relevance from.
neuronNumber - The index of the neuron within that layer.
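Relevance propagation redistributes a neuron's score backwards, layer by layer, in proportion to each lower neuron's contribution to it. The minimal sketch below uses the epsilon rule of layer-wise relevance propagation (LRP), R_j = sum_k (a_j w_jk / (z_k ± ε)) R_k with z_k = sum_j a_j w_jk; whether ann4j uses this rule is an assumption. The array layouts, the bias-free layers and the names RelevanceSketch and EPSILON are hypothetical.

```java
/** Hypothetical sketch of LRP-epsilon relevance propagation through dense layers. */
class RelevanceSketch {
    static final double EPSILON = 1e-9; // stabiliser that keeps the denominator away from zero

    /**
     * activations[l][j]: activation of neuron j in layer l (layer 0 is the input layer).
     * weights[l][j][k]: weight from neuron j in layer l to neuron k in layer l + 1.
     * Returns the relevance of each input neuron for neuron neuronNumber in layer layerNumber.
     */
    static double[] relevancePropagate(double[][] activations, double[][][] weights,
                                       int layerNumber, int neuronNumber) {
        // All starting relevance sits on the chosen neuron and equals its activation.
        double[] relevance = new double[activations[layerNumber].length];
        relevance[neuronNumber] = activations[layerNumber][neuronNumber];

        for (int l = layerNumber - 1; l >= 0; l--) {
            double[] a = activations[l];
            double[][] w = weights[l];
            double[] lower = new double[a.length];
            for (int k = 0; k < relevance.length; k++) {
                // z_k: total input neuron k received from layer l (biases ignored).
                double z = 0.0;
                for (int j = 0; j < a.length; j++) {
                    z += a[j] * w[j][k];
                }
                double stabilised = z + (z >= 0 ? EPSILON : -EPSILON);
                // Share R_k among the lower neurons in proportion to a_j * w_jk.
                for (int j = 0; j < a.length; j++) {
                    lower[j] += a[j] * w[j][k] / stabilised * relevance[k];
                }
            }
            relevance = lower;
        }
        return relevance;
    }
}
```

Without biases, and with a small epsilon, the input relevances sum (up to the epsilon term) to the starting neuron's activation, which is a useful sanity check.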
-
train
public void train()
Takes the current input layer and expected output layer, and uses the input layer to predict the expected output layer.
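Because train() takes no arguments, it presumably works on the current input layer and expected output layer. A minimal sketch of one such step, assuming a single dense layer with softmax output, a one-hot expected layer and cross-entropy loss; SingleStepSketch, its weight layout w[k][j] and the learning rate are hypothetical and not taken from ann4j.

```java
/**
 * Hypothetical sketch of one training step: predict the expected layer from the
 * input layer with a dense softmax layer, then apply one gradient-descent update.
 */
class SingleStepSketch {
    /** Numerically stable softmax: subtract the maximum before exponentiating. */
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] p = new double[z.length];
        for (int k = 0; k < z.length; k++) {
            p[k] = Math.exp(z[k] - max);
            sum += p[k];
        }
        for (int k = 0; k < z.length; k++) {
            p[k] /= sum;
        }
        return p;
    }

    /** Forward pass: w[k][j] is the weight from input j to output neuron k. */
    static double[] forward(double[][] w, double[] b, double[] inputLayer) {
        double[] z = new double[b.length];
        for (int k = 0; k < b.length; k++) {
            z[k] = b[k];
            for (int j = 0; j < inputLayer.length; j++) {
                z[k] += w[k][j] * inputLayer[j];
            }
        }
        return softmax(z);
    }

    /** Cross-entropy between the predicted distribution and the expected (one-hot) layer. */
    static double crossEntropy(double[] predicted, double[] expectedLayer) {
        double loss = 0.0;
        for (int k = 0; k < predicted.length; k++) {
            if (expectedLayer[k] > 0) {
                loss -= expectedLayer[k] * Math.log(predicted[k]);
            }
        }
        return loss;
    }

    /** Updates w and b in place; returns the loss measured before the update. */
    static double trainStep(double[][] w, double[] b, double[] inputLayer,
                            double[] expectedLayer, double learningRate) {
        double[] predicted = forward(w, b, inputLayer);
        for (int k = 0; k < b.length; k++) {
            // For softmax with cross-entropy, dLoss/dz_k = predicted_k - expected_k.
            double delta = predicted[k] - expectedLayer[k];
            b[k] -= learningRate * delta;
            for (int j = 0; j < inputLayer.length; j++) {
                w[k][j] -= learningRate * delta * inputLayer[j];
            }
        }
        return crossEntropy(predicted, expectedLayer);
    }
}
```

With all-zero weights and three output neurons, the first prediction is uniform, so the loss before the step is ln 3; a small learning rate then lowers the loss on the same sample.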
-