//********** THIS IS THE MAIN PROGRAM **************************
//==============================================================
public static void main(String[] args)
{
//initiate the weights
initWeights();
//load in the data
initData();
//train the network
for(int j = 0;j <= numEpochs;j++)
{
for(int i = 0;i
//select a pattern at random
patNum = (int)((Math.random()*numPatterns)-0.001);
//calculate the current network output
//and error for this pattern
calcNet();
//change network weights
WeightChangesHO();
WeightChangesIH();
}
//display the overall network error
//after each epoch
calcOverallError();
System.out.println("epoch = " + j + " RMS Error = " + RMSerror);
}
//training has finished
//display the results
displayResults();
}
//============================================================
//********** END OF THE MAIN PROGRAM **************************
//=============================================================
//************************************
public static void calcNet()
{
//calculate the outputs of the hidden neurons
//the hidden neurons are tanh
for(int i = 0;i
hiddenVal[i] = 0.0;
for(int j = 0;j
hiddenVal[i] = tanh(hiddenVal[i]);
}
//calculate the output of the network
//the output neuron is linear
outPred = 0.0;
for(int i = 0;i
//calculate the error
errThisPat = outPred - trainOutput[patNum];
}
//************************************
public static void WeightChangesHO()
//adjust the weights hidden-output
{
for(int k = 0;k
double weightChange = LR_HO * errThisPat * hiddenVal[k];
weightsHO[k] = weightsHO[k] - weightChange;
//regularisation on the output weights
if (weightsHO[k] < -5)
weightsHO[k] = -5;
else if (weightsHO[k] > 5)
weightsHO[k] = 5;
}
}
//************************************
public static void WeightChangesIH()
//adjust the weights input-hidden
{
for(int i = 0;i
for(int k = 0;k
double x = 1 - (hiddenVal[i] * hiddenVal[i]);
x = x * weightsHO[i] * errThisPat * LR_IH;
x = x * trainInputs[patNum][k];
double weightChange = x;
weightsIH[k][i] = weightsIH[k][i] - weightChange;
}
}
}
//************************************
public static void initWeights()
{
for(int j = 0;j
weightsHO[j] = (Math.random() - 0.5)/2;
for(int i = 0;i
}
}
//************************************
/**
 * Loads the XOR training set, rescaled to the values -1 and 1.
 * Every input row has a third component fixed at 1 that serves as the bias input.
 * The target is 1 when exactly one of the first two inputs is 1, and -1 otherwise.
 */
public static void initData()
{
    System.out.println("initialising data");
    // columns: input A, input B, bias (always 1)
    final int[][] xorRows =
    {
        { 1, -1, 1},
        {-1,  1, 1},
        { 1,  1, 1},
        {-1, -1, 1}
    };
    final int[] xorTargets = {1, 1, -1, -1};
    for (int row = 0; row < xorRows.length; row++)
    {
        for (int col = 0; col < xorRows[row].length; col++)
        {
            trainInputs[row][col] = xorRows[row][col];
        }
        trainOutput[row] = xorTargets[row];
    }
}
//************************************
/**
 * Hyperbolic tangent activation used by the hidden neurons.
 * Delegates to {@link Math#tanh(double)}. That method stays accurate near zero, where
 * the former (e^x - e^-x)/(e^x + e^-x) formula lost precision through cancellation.
 * It also saturates to exactly +/-1.0 for large |x| without overflowing, which matches
 * the former explicit clamp at |x| > 20.
 *
 * @param x pre-activation value
 * @return tanh(x), within [-1, 1]; NaN if x is NaN
 */
public static double tanh(double x)
{
    return Math.tanh(x);
}
//************************************
public static void displayResults()
{
for(int i = 0;i
patNum = i;
calcNet();
System.out.println("pat = " + (patNum+1) + " actual = " + trainOutput[patNum] + " neural model = " + outPred);
}
}
//************************************
public static void calcOverallError()
{
RMSerror = 0.0;
for(int i = 0;i
patNum = i;
calcNet();
RMSerror = RMSerror + (errThisPat * errThisPat);
}
RMSerror = RMSerror/numPatterns;
RMSerror = java.lang.Math.sqrt(RMSerror);
}
}
// (Blog-page residue, translated from Indonesian: "No comments: Post a comment")